<a href="https://colab.research.google.com/github/shivendrra/biosaic/blob/main/training%20files/biosaic_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# further down in this notebook all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
class ModelConfig:
  """Hyperparameters for the DNA VQ-VAE, kept as plain class attributes.

  The class object itself (not an instance) is passed to ``DNA_VQVAE`` in this
  notebook, so every value is read straight off the class.
  """
  # Width of the transformer layers and of each codebook vector.
  d_model: int = 256
  # Size of the per-position input/output vector (one-hot over 4 bases).
  in_dim: int = 4
  # Number of vectors in the quantizer codebook.
  n_embed: int = 256
  # Weight on the (z_q.detach() - z_e) term of the quantizer loss.
  beta: float = 0.1
  # NOTE(review): not read by any module visible in this notebook.
  dropout: float = 0.2
  # Attention heads per transformer layer.
  n_heads: int = 8
  # Layers in both the encoder stack and the decoder stack.
  n_layers: int = 10

In [None]:
class encoder(nn.Module):
  """Linear input projection followed by a stack of Transformer encoder layers.

  Args:
    _in: size of the last dimension of the input.
    d_model: width the input is projected to and the transformer operates at.
    n_layers: number of ``nn.TransformerEncoderLayer`` copies in the stack.
    n_heads: attention heads per layer.
  """
  def __init__(self, _in, d_model, n_layers, n_heads):
    super().__init__()
    self.embed = nn.Linear(_in, d_model)
    layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
    self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

  def forward(self, x):
    """Encode ``x`` (assumed batch-first, (B, L, _in)) into (B, L, d_model)."""
    projected = self.embed(x)                  # (B, L, d_model)
    # The layers are built without batch_first, so they expect the sequence
    # dimension first.
    seq_first = projected.permute(1, 0, 2)     # (L, B, d_model)
    encoded = self.encoder(seq_first)
    return encoded.permute(1, 0, 2)            # back to (B, L, d_model)

class decoder(nn.Module):
  """Stack of Transformer decoder layers followed by a linear output head.

  Args:
    d_model: width of the incoming latent vectors and of the transformer.
    _out: number of logits produced per position.
    n_layers: number of ``nn.TransformerDecoderLayer`` copies in the stack.
    n_heads: attention heads per layer.
  """
  def __init__(self, d_model, _out, n_layers, n_heads):
    super().__init__()
    layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads)
    self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
    self.fc_out = nn.Linear(d_model, _out)  # per-position output logits

  def forward(self, z_q):
    """Map latents ``z_q`` (assumed (B, L, d_model)) to logits (B, L, _out)."""
    seq_first = z_q.permute(1, 0, 2)             # (L, B, d_model)
    # The same tensor is passed as both the target and the memory sequence.
    decoded = self.decoder(seq_first, seq_first)
    batch_first = decoded.permute(1, 0, 2)       # (B, L, d_model)
    return self.fc_out(batch_first)              # (B, L, _out)

class Quantizer(nn.Module):
  """Nearest-neighbour vector quantizer with a straight-through estimator.

  Args:
    n_embed: number of codebook vectors.
    d_model: dimensionality of each codebook vector.
    beta: weight on the term that pulls ``z_e`` toward the selected codes.
  """
  def __init__(self, n_embed, d_model, beta):
    super().__init__()
    self.n_embed = n_embed
    self.d_model = d_model
    self.beta = beta
    self.embeddings = nn.Embedding(n_embed, d_model)
    # Re-initialise the codebook uniformly in [-1/n_embed, 1/n_embed].
    bound = 1.0 / n_embed
    self.embeddings.weight.data.uniform_(-bound, bound)

  def forward(self, z_e):
    """Quantize ``z_e`` (last dim must be ``d_model``).

    Returns:
      A tuple ``(z_q, loss, indices)``: the quantized latents with the same
      shape as ``z_e``, the scalar quantization loss, and the chosen codebook
      indices shaped like ``z_e`` without its last dimension.
    """
    codebook = self.embeddings.weight
    flat = z_e.reshape(-1, self.d_model)
    # Pick, for every latent vector, the codebook row at minimum distance.
    nearest = torch.cdist(flat, codebook).argmin(dim=1)
    quantized = self.embeddings(nearest).view(z_e.shape)

    # First term moves the encoder output toward the (frozen) codes, second
    # term moves the codes toward the (frozen) encoder output.
    commitment = torch.mean((quantized.detach() - z_e) ** 2)
    codebook_term = torch.mean((z_e.detach() - quantized) ** 2)
    loss = self.beta * commitment + codebook_term

    # Straight-through: forward value is the code, gradient flows only to z_e.
    passthrough = z_e + (quantized - z_e).detach()
    return passthrough, loss, nearest.view(z_e.shape[:-1])

class DNA_VQVAE(nn.Module):
  """Encoder -> vector quantizer -> decoder pipeline.

  Args:
    args: configuration object exposing ``in_dim``, ``d_model``, ``n_layers``,
      ``n_heads``, ``n_embed`` and ``beta``.
  """
  def __init__(self, args: ModelConfig):
    super().__init__()
    width, depth, heads = args.d_model, args.n_layers, args.n_heads
    self.encoder = encoder(args.in_dim, width, depth, heads)
    self.vq_layer = Quantizer(args.n_embed, width, args.beta)
    self.decoder = decoder(width, args.in_dim, depth, heads)

  def forward(self, x):
    """Return ``(reconstruction logits, quantizer loss, codebook indices)``."""
    latents = self.encoder(x)
    quantized, vq_loss, indices = self.vq_layer(latents)
    return self.decoder(quantized), vq_loss, indices

In [None]:
DNA_VOCAB = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
INDEX_TO_DNA = {v: k for k, v in DNA_VOCAB.items()}  # Reverse mapping

class Dataset:
  """Loads a plain-text DNA file and turns it into one-hot tensors.

  Args:
    path: location of a text file containing the sequence, possibly split
      over several lines.
  """
  def __init__(self, path:str):
    self.path = path

  def load_simple(self):
    """Return the file contents as one string with blank lines and
    surrounding whitespace on each line removed."""
    with open(self.path, "r", encoding="utf-8") as f:
      stripped = (line.strip() for line in f)
      return "".join(piece for piece in stripped if piece)

  def load_encoded(self, seq=None):
    """One-hot encode ``seq`` (or the file contents when ``seq`` is falsy).

    Returns:
      An integer tensor of shape (L, 4), one row per character.

    Raises:
      KeyError: if a character is not a key of ``DNA_VOCAB``.
    """
    loaded_sequences = seq if seq else self.load_simple()
    seq_idx = [DNA_VOCAB[char] for char in loaded_sequences]
    return F.one_hot(torch.tensor(seq_idx, dtype=torch.long), num_classes=4) # shape (L, 4)

  def train_test_split(self, sequence:str=None, ratio:float=0.8):
    """Split the encoded sequence into a leading train part and a trailing
    test part.

    Args:
      sequence: raw sequence to split; the file at ``self.path`` is used
        when this is falsy.
      ratio: fraction of positions assigned to the train part.

    Returns:
      ``(train_data, test_data)`` one-hot tensors.

    Raises:
      ValueError: if ``ratio`` is outside [0, 1].
    """
    if not 0.0 <= ratio <= 1.0:
      raise ValueError(f"ratio must be within [0, 1], got {ratio}")
    encoded = self.load_encoded(seq=sequence)
    # Honour the caller's ratio; it used to be ignored in favour of a fixed 0.8.
    split_size = int(ratio * len(encoded))
    return encoded[:split_size], encoded[split_size:]

In [None]:
# Build the train / validation tensors from the DNA text file on Drive.
file_path = "/content/drive/MyDrive/dna_data2.txt"
data = Dataset(path=file_path)
# A 0.75 train fraction is requested from the splitter here.
train_data, val_data = data.train_test_split(sequence=None, ratio=0.75)

In [None]:
# Training hyperparameters shared by the cells below.

# NOTE(review): the model and batches below are moved to the literal "cpu",
# so this value is not what places them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimizer settings.
learning_rate = 1e-4  # bumped from 1e-5
# NOTE(review): the optimizer cell passes its own literals for weight decay
# and amsgrad rather than reading these two names.
weight_decay = 1e-4
amsgrad = True

# Schedule: each "epoch" in the training loop is a single optimisation step.
warmup_epochs = 50  # length of the linear warm-up
epochs = 2000

# Evaluation cadence and how many batches each evaluation averages over.
eval_interval = 100
eval_iters = 30

# Batch geometry: sequences per batch and positions per sequence.
batch_size = 6
block_size = 256

# (step, train_loss, val_loss) tuples appended at every evaluation.
loss_history = []

In [None]:
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

# Build the model on the CPU and report its size in millions of parameters.
_model = DNA_VQVAE(ModelConfig).to("cpu")
n_param = sum(p.numel() for p in _model.parameters()) / 1e6
print(f"{n_param:.2f} million")

# NOTE(review): weight_decay and amsgrad are literals here; the train-params
# cell defines weight_decay = 1e-4, which differs from the 1e-5 used below.
optimizer = torch.optim.Adam(
  _model.parameters(),
  lr=learning_rate,
  betas=(0.9, 0.95),
  weight_decay=1e-5,
  amsgrad=True,
)

# Learning-rate schedulers (both wrap the same optimizer; the training loop
# decides which one to step).

def _warmup_factor(epoch):
  """LR multiplier that grows linearly from 1/warmup_epochs up to 1.0."""
  return min((epoch+1)/warmup_epochs, 1.0)

# 1) Warm-up: scale the base LR by _warmup_factor during the first steps.
warmup_scheduler = LambdaLR(optimizer, lr_lambda=_warmup_factor)

# 2) Afterwards: cosine annealing toward eta_min over the remaining steps.
cosine_scheduler = CosineAnnealingLR(
  optimizer,
  T_max=epochs - warmup_epochs,
  eta_min=1e-6,
)

In [None]:
# Restore previously saved weights into the CPU model.
checkpoint_path = '/content/drive/MyDrive/biosaic_30m.pth'
# map_location keeps loading from failing when the checkpoint was written
# from a CUDA device but this runtime has no GPU; the model lives on the CPU
# either way.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
_model.load_state_dict(checkpoint)
_model = _model.to("cpu")

In [None]:
def get_batch(split):
  """Draw a random batch of windows from the train or validation tensor.

  Args:
    split: ``'train'`` selects ``train_data``; anything else selects
      ``val_data``.

  Returns:
    ``(x, y)`` float tensors on the CPU, each holding ``batch_size`` windows
    of ``block_size`` rows; ``y`` is the same window as ``x`` shifted forward
    by one position.
  """
  source = train_data if split == 'train' else val_data
  # Upper bound leaves room for the one-position shift used by the targets.
  starts = torch.randint(len(source) - block_size, (batch_size,))
  inputs, targets = [], []
  for start in starts:
    inputs.append(source[start:start+block_size])
    targets.append(source[start+1:start+block_size+1])
  x = torch.stack(inputs).float()   # one-hot ints -> float
  y = torch.stack(targets).float()  # one-hot ints -> float
  return x.to("cpu"), y.to("cpu")

# Fix the RNG seed so batch sampling below is repeatable.
torch.manual_seed(1800)

@torch.no_grad()
def estimate_loss():
  """Average the evaluation loss over ``eval_iters`` random batches per split.

  The model is put in eval mode for the measurement and switched back to
  train mode before returning.

  Returns:
    Dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the mean
    of (cross-entropy reconstruction loss + quantizer loss).
  """
  def _one_batch(split):
    # Loss of a single freshly sampled batch, as a Python float.
    X, Y = get_batch(split)
    logits, vq_loss, _ = _model(X)
    recon_loss = F.cross_entropy(logits.view(-1, 4), Y.view(-1, 4))
    return (recon_loss + vq_loss).item()

  _model.eval()
  out = {}
  for split in ('train', 'val'):
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      losses[k] = _one_batch(split)
    out[split] = losses.mean()
  _model.train()
  return out

import timeit

# Main optimisation loop: one randomly sampled batch per iteration.
start_time = timeit.default_timer()
for epoch in range(epochs):
  xb, yb = get_batch('train')

  x_recon, vq_loss, _ = _model(xb)
  recon_ce  = F.cross_entropy(x_recon.view(-1,4), yb.view(-1,4))
  recon_mse = F.mse_loss(torch.softmax(x_recon, dim=-1), yb)
  recon_loss = recon_ce + 0.5*recon_mse
  # The quantizer loss must be part of the objective: the straight-through
  # path detaches the codebook, so vq_loss is the only source of gradient for
  # the codebook vectors (and of the commitment penalty on the encoder).
  loss = recon_loss + vq_loss

  optimizer.zero_grad()
  loss.backward()
  # Gradient clipping
  torch.nn.utils.clip_grad_norm_(_model.parameters(), max_norm=1.0)

  optimizer.step()

  # Scheduler step: warm-up first, cosine annealing afterwards.
  if epoch < warmup_epochs:
    warmup_scheduler.step()
  else:
    cosine_scheduler.step()

  # Logging & eval
  if (epoch+1) % eval_interval == 0:
    losses = estimate_loss()
    print(f"Epoch {epoch+1:4d} | train {losses['train']:.4f}  val {losses['val']:.4f}")
    loss_history.append((epoch+1, losses['train'], losses['val']))

end_time = timeit.default_timer()
print(f"Total time taken: {(end_time - start_time) / 3600} hrs")

In [None]:
# Save the trained weights to Drive; the file name carries the parameter
# count rounded to whole millions.
n_param = sum(map(torch.numel, _model.parameters())) / 1e6
model_save_name = f'biosaic_{n_param:.0f}m.pth'
path = f"/content/drive/MyDrive/{model_save_name}"
torch.save(_model.state_dict(), path)

In [None]:
import matplotlib.pyplot as plt

# Plot the logged train / validation losses against the step they were
# recorded at. loss_history holds (step, train_loss, val_loss) tuples.
epochs_logged, train_losses, val_losses = zip(*loss_history)

plt.figure(figsize=(8, 5))
for series, label, linestyle in (
  (train_losses, "Train Loss", '-'),
  (val_losses, "Val Loss", '--'),
):
  plt.plot(epochs_logged, series, label=label, marker='o', linestyle=linestyle)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training & Validation Loss Over Time")
plt.grid(True)
plt.show()

In [None]:
# saving safe-tensors
from safetensors.torch import save_file

model_save_name = 'consolidated_00.safetensors'
path = f"/content/drive/MyDrive/{model_save_name}"
# Export the trained network (_model). The name `model` is only bound in a
# later cell, to a newly constructed instance, so it must not be used here.
save_file(_model.state_dict(), path)

In [None]:
# Base <-> index tables for the tokenizer. The order must match the vocabulary
# the training data was encoded with (A, C, G, T), and the reverse table is
# derived from the forward one so the two can never disagree.
DNA_VOCAB = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
INDEX_TO_DNA = {index: base for base, index in DNA_VOCAB.items()}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# eval() switches off dropout so the same sequence always maps to the same
# tokens.
# NOTE(review): this is a newly constructed instance; no state dict is loaded
# into it here, so load trained weights before relying on encode/decode.
model = DNA_VQVAE(ModelConfig).to(device).eval()

class tokenizer:
  """Thin wrapper that turns DNA strings into VQ codebook indices and back.

  Uses the module-level ``model`` for the neural parts and the module-level
  ``DNA_VOCAB`` / ``INDEX_TO_DNA`` / ``device`` captured at construction.
  """
  def __init__(self):
    self.vocab = DNA_VOCAB
    self.ids_to_dna = INDEX_TO_DNA
    self.device = device

  def __str__(self):
    return "\t/Biosaic VQ-VAE tokenizer v1.0.0/\t"

  def dna_to_onehot(self, seq):
    """One-hot encode ``seq`` into a float tensor of shape (len(seq), 4).

    Raises:
      KeyError: if ``seq`` contains a character missing from the vocabulary.
    """
    # Look bases up in the instance vocabulary (not the module global), and
    # force an integer dtype so an empty sequence still encodes to (0, 4).
    seq_idx = [self.vocab[char] for char in seq]
    one_hot = F.one_hot(torch.tensor(seq_idx, dtype=torch.long), num_classes=4)
    return one_hot.float()

  def onehot_to_dna(self, logits):
    """Convert per-position scores (leading batch dim of 1) to a DNA string
    by taking the arg-max over the last dimension."""
    decoded_out = torch.argmax(logits, dim=-1).squeeze(0)
    return ''.join(self.ids_to_dna[idx] for idx in decoded_out.tolist())

  def encode(self, seq: str):
    """Return the codebook indices for ``seq`` as a NumPy array of shape
    (1, len(seq))."""
    one_hot_seq = self.dna_to_onehot(seq).unsqueeze(0).to(self.device)
    # Inference only: no autograd graph is needed to read the indices.
    with torch.no_grad():
      _, _, tokens = model(one_hot_seq)
    return tokens.cpu().numpy()

  def decode(self, tokens: list):
    """Map codebook indices (as produced by ``encode``) back to a DNA string."""
    tokens = torch.tensor(tokens, dtype=torch.long).to(self.device)
    with torch.no_grad():
      z_q = model.vq_layer.embeddings(tokens)
      logits = model.decoder(z_q)
    return self.onehot_to_dna(logits)

In [None]:
# Read the raw DNA text, drop blank lines and per-line whitespace, join the
# pieces into one string and upper-case it.
with open('/content/drive/MyDrive/dna_data2.txt', 'r', encoding='utf-8') as file:
  dataset = "".join(piece for piece in map(str.strip, file) if piece).upper()
# Length is reported in millions of characters.
print(len(dataset)/1e6, 'million words')

In [None]:
# Round-trip check: encode a sample sequence to codebook indices, decode the
# indices back to bases, and compare with the input.
seq = """ACTACGCCACGCCAGTTGAACTGGTGCCGCTGTTAGAGGAAAAATCTTCATGGATGAGCCATGCCGCGCTGGTGTTTGGTCGCGAAGATTCCGGGTTGACTAACGAAGAGTTAGCGTTGGCTGACGTTCTTACTGGTGTGCCGATGGTGGCGGATTATCCTTCGCTCAATCTGGGGCAGGCGGTGATGGTCTATTGCTATCAATTAGCAACATTAATACAACAACCGGCGAAAAGTGATGCAACGGCAGACCAACATCAACTGCAAGCTTTACGCGAACGAGCCATGACATTGCTGACGACTCTGGCAGTGGCAGATGACATAAAACTGGTCGACTGGTTACAACAACGCCTGGGGCTTTTAGAGCAACGAGACACGGCAATGTTGCACCGTTTGCTGCATGATATTGAAAAAAATATCACCAAATAAAAAACGCCTTAGTAAGTATTTTTC"""
token = tokenizer()
encoded = token.encode(seq)  # array of codebook indices, one per base
decoded = token.decode(encoded)  # DNA string rebuilt from those indices
print(encoded)
print(decoded)
# True only when the reconstruction matches the input exactly.
print(seq == decoded)