<a href="https://colab.research.google.com/github/shivendrra/biosaic/blob/main/training%20files/biosaic_evoformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Make Google Drive visible under /content/drive; the checkpoint cell saves to
# /content/drive/MyDrive/.
drive.mount("/content/drive")

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
class ModelConfig:
  """Model hyper-parameters, read as class attributes (no instance needed)."""

  # Use the GPU when torch can see one, otherwise the CPU.
  DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Size of the MSA alphabet = last dim of the MSA input; 4 for DNA.
  A: int = 4
  # Number of input pair features = last dim of the pair input. This sizes the
  # pair embedding and is independent of the MSA alphabet A.
  C: int = 21

  d_msa: int = 128    # width of the MSA embedding
  d_pair: int = 64    # width of the pair embedding
  n_heads: int = 8    # heads per MSA attention layer (must divide d_msa)
  n_blocks: int = 4   # number of stacked blocks

In [None]:
class RowAttention(nn.Module):
  """Self-attention across the N sequences of an MSA, separately at each position.

  For every batch element b and alignment position l, the N vectors
  msa[b, :, l, :] form one attention sequence.
  NOTE(review): AlphaFold2 calls this pattern "column-wise" attention; the
  class name is kept unchanged for existing callers.
  """
  def __init__(self, d_msa, n_heads):
    super().__init__()
    self.attn = nn.MultiheadAttention(d_msa, n_heads, batch_first=True)

  def forward(self, msa):
    """Map msa (B, N, L, d_msa) to an update of the same shape (no residual added)."""
    B, N, L, D = msa.shape
    # Bring L next to B so each (b, l) pair becomes one length-N sequence.
    # A bare .view(B*L, N, D) on the (B, N, L, D) layout would instead slice
    # the flat memory into chunks that mix sequences and positions.
    x = msa.permute(0, 2, 1, 3).reshape(B * L, N, D)
    out, _ = self.attn(x, x, x, need_weights=False)
    # Undo the grouping: (B*L, N, D) -> (B, L, N, D) -> (B, N, L, D).
    return out.reshape(B, L, N, D).permute(0, 2, 1, 3)

class ColAttention(nn.Module):
  """Self-attention across the L positions of each MSA sequence.

  For every batch element b and sequence n, the L vectors msa[b, n, :, :]
  form one attention sequence.
  NOTE(review): AlphaFold2 calls this pattern "row-wise" attention; the class
  name is kept unchanged for existing callers.
  """
  def __init__(self, d_msa, n_heads):
    super().__init__()
    self.attn = nn.MultiheadAttention(d_msa, n_heads, batch_first=True)

  def forward(self, msa):
    """Map msa (B, N, L, d_msa) to an update of the same shape (no residual added)."""
    B, N, L, D = msa.shape
    # (B, N, L, D) is already sequence-major, so merging B and N yields one
    # length-L sequence per (b, n). Permuting first would scramble the groups.
    x = msa.reshape(B * N, L, D)
    out, _ = self.attn(x, x, x, need_weights=False)
    return out.reshape(B, N, L, D)

class TriMulUpdate(nn.Module):
  """Simplified triangular multiplicative update on the pair representation.

  Returns only the update term; the caller is responsible for the residual
  connection, in the same way as RowAttention / ColAttention.
  NOTE(review): the sum over k is unnormalised, so its magnitude grows with L.
  AlphaFold2 wraps this step in layer norms and gating.
  """
  def __init__(self, d_pair):
    super().__init__()
    self.linear_a = nn.Linear(d_pair, d_pair)
    self.linear_b = nn.Linear(d_pair, d_pair)

  def forward(self, pair):
    """pair: (B, L, L, d_pair); dims 1 and 2 must be equal. Returns a tensor of the same shape."""
    left = self.linear_a(pair)   # (B, L, L, d)
    right = self.linear_b(pair)  # (B, L, L, d)
    # Per-channel matrix product over the shared index k:
    #   update[b, i, j, d] = sum_k left[b, i, k, d] * right[b, k, j, d]
    return torch.einsum("bikd,bkjd->bijd", left, right)

class Block(nn.Module):
  """One simplified Evoformer block: two MSA attention layers and a pair update.

  Each sub-layer's output is added to its input. Feed-forward layers, layer
  norms and gating are not implemented yet. The MSA and pair tensors are
  updated independently; nothing in this block couples them.
  """
  def __init__(self, d_msa, d_pair, n_heads):
    super().__init__()
    self.row_attn = RowAttention(d_msa, n_heads)
    self.col_attn = ColAttention(d_msa, n_heads)
    self.tri_mul = TriMulUpdate(d_pair)

  def forward(self, msa, pair):
    # Row attention first, then column attention, each with its own residual add.
    for attention in (self.row_attn, self.col_attn):
      msa = msa + attention(msa)
    pair = pair + self.tri_mul(pair)
    return msa, pair

class Evoformer(nn.Module):
  """Stack of simplified Evoformer blocks with a per-position MSA token head."""

  def __init__(self, params: ModelConfig):
    """Build the model from a config object.

    Args:
      params: config exposing A (MSA alphabet size, e.g. 4 for DNA, 21 for
        protein), C (number of initial pair features), d_msa, d_pair,
        n_heads and n_blocks. Values are read from `params` itself, so a
        different config class or instance takes effect; passing
        ModelConfig reproduces the original behaviour.
    """
    super().__init__()
    self.embed_msa = nn.Linear(params.A, params.d_msa)
    self.embed_pair = nn.Linear(params.C, params.d_pair)
    self.blocks = nn.ModuleList([
      Block(params.d_msa, params.d_pair, params.n_heads)
      for _ in range(params.n_blocks)
    ])
    # Projects MSA features back to alphabet logits for masked-token prediction.
    self.msa_out = nn.Linear(params.d_msa, params.A)

  def forward(self, msa, pair):
    """Embed the inputs, run every block, and project the MSA to logits.

    Args:
      msa: (B, N, L, A) float tensor.
      pair: (B, L, L, C) float tensor.
    Returns:
      (logits, pair): logits of shape (B, N, L, A) for each MSA cell, and the
      final pair representation of shape (B, L, L, d_pair).
    """
    msa = self.embed_msa(msa)
    pair = self.embed_pair(pair)
    for blk in self.blocks:
      msa, pair = blk(msa, pair)
    return self.msa_out(msa), pair

In [None]:
class TrainConfig:
  """Optimisation and sampling settings, read as class attributes by the training cells."""

  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # AdamW settings
  LR = 1e-4     # peak learning rate (reached at the end of warm-up)
  WD = 1e-4     # weight decay
  AMS = True    # enable the AMSGrad variant

  # Schedule: each "epoch" is a single optimiser step on one sampled batch.
  WARMUP = 50   # linear warm-up steps before cosine decay
  EPOCHS = 500  # total steps
  BATCH = 8     # MSAs sampled per step

  # Data shape hints. NOTE(review): not read by any code in this notebook.
  MSA_SEQ = 32  # sequences per MSA
  L_SEQ = 256   # sequence length

  # Evaluation
  EVAL_ITERS = 5  # batches averaged per split in estimate_loss
  EVAL_INTV = 50  # evaluate every EVAL_INTV steps

In [None]:
import numpy as np

msa_data  = np.load("msa.npy")   # expected shape (D, N, L, A)
pair_data = np.load("pair.npy")  # expected shape (D, L, L, C)

# Validate with real exceptions: `assert` statements are stripped under `python -O`.
if msa_data.ndim != 4 or pair_data.ndim != 4:
  raise ValueError(
    f"expected 4-D arrays, got msa {msa_data.shape} and pair {pair_data.shape}")
# get_batch indexes both arrays with the same sample ids, so D must agree.
if msa_data.shape[0] != pair_data.shape[0]:
  raise ValueError(
    f"sample count mismatch: msa has {msa_data.shape[0]}, pair has {pair_data.shape[0]}")
# The pair update contracts dim 2 of the pair tensor against dim 1, so it must be L x L.
if pair_data.shape[1] != pair_data.shape[2]:
  raise ValueError(f"pair features must be (D, L, L, C), got {pair_data.shape}")

In [None]:
# ------ 3. Train/Val Split ------
D = msa_data.shape[0]
# Fail fast: with fewer than 2 samples one split is empty, and get_batch's
# np.random.randint(0, 0) would crash later in the training loop.
if D < 2:
  raise ValueError(f"need at least 2 samples for a train/val split, got {D}")
# Contiguous 85/15 split (not shuffled); for D >= 2 both parts are non-empty.
split = int(D * 0.85)
msa_train, msa_val   = msa_data[:split], msa_data[split:]
pair_train, pair_val = pair_data[:split], pair_data[split:]

In [None]:
# ------ 4. Model, Optimizer, Scheduler ------
# Imported here because this cell uses them before the batch-sampler cell runs.
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

model = Evoformer(ModelConfig).to(ModelConfig.DEVICE)
n_param = sum(p.numel() for p in model.parameters()) / 1e6
print(f"{n_param:.2f} million")

opt = AdamW(model.parameters(), lr=TrainConfig.LR, weight_decay=TrainConfig.WD,
            amsgrad=TrainConfig.AMS)
# Both schedulers drive the same optimiser. The training loop steps `warm` for
# the first WARMUP steps (LR scaled by (e+1)/WARMUP, capped at 1) and `cos` after.
warm = LambdaLR(opt, lambda e: min((e + 1) / TrainConfig.WARMUP, 1.0))
cos = CosineAnnealingLR(opt, T_max=TrainConfig.EPOCHS - TrainConfig.WARMUP, eta_min=1e-6)

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

# ------ 5. Batch Sampler ------
def get_batch(split):
  """Draw TrainConfig.BATCH random samples (with replacement) from one split.

  Args:
    split: "train" selects the training arrays; any other value selects the
      validation arrays.
  Returns:
    (msa, pair) as float32 tensors on TrainConfig.DEVICE, shaped
    (B, N, L, A) and (B, L, L, C).
  """
  src_msa, src_pair = (msa_train, pair_train) if split == "train" else (msa_val, pair_val)
  picks = np.random.randint(0, src_msa.shape[0], size=TrainConfig.BATCH)

  def to_tensor(arr):
    return torch.tensor(arr, dtype=torch.float32, device=TrainConfig.DEVICE)

  return to_tensor(src_msa[picks]), to_tensor(src_pair[picks])

# ------ 6. Eval Loss (masked‑token CE) ------
@torch.no_grad()
def estimate_loss(mask_prob=0.15):
  """Average masked-token cross-entropy over EVAL_ITERS random batches per split.

  About `mask_prob` of the MSA cells are hidden from the model by zeroing their
  input vectors. The loss is computed only on those cells, so the model has to
  infer them from context rather than copy its input.
  NOTE(review): assumes M is one-hot over the alphabet (target = argmax) — confirm.

  Args:
    mask_prob: probability that a given MSA cell is masked.
  Returns:
    {"train": mean_loss, "val": mean_loss}
  Side effects: puts `model` in eval mode and always restores train mode.
  """
  model.eval()
  out = {}
  try:
    for split in ("train", "val"):
      losses = []
      for _ in range(TrainConfig.EVAL_ITERS):
        M, P = get_batch(split)
        mask = torch.rand(M.shape[:-1], device=M.device) < mask_prob  # (B, N, L)
        target = M.argmax(-1)  # taken from the unmasked input
        # Hide the masked cells; without this the target is visible in the input.
        logits, _ = model(M.masked_fill(mask.unsqueeze(-1), 0.0), P)
        losses.append(F.cross_entropy(logits[mask], target[mask]).item())
      out[split] = sum(losses) / len(losses)
  finally:
    model.train()
  return out

# ------ 7. Training Loop ------
# Each "epoch" below is a single optimiser step on one randomly sampled batch.
history = []
for epoch in range(TrainConfig.EPOCHS):
  M, P = get_batch("train")
  # Masked-token objective: hide ~15% of MSA cells from the model and score
  # only those cells. The input must be masked, or the target is visible to the model.
  mask = torch.rand(M.shape[:-1], device=M.device) < 0.15  # (B, N, L)
  target = M.argmax(-1)  # taken from the unmasked input
  opt.zero_grad()
  logits, _ = model(M.masked_fill(mask.unsqueeze(-1), 0.0), P)
  loss = F.cross_entropy(logits[mask], target[mask])
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  opt.step()
  if epoch < TrainConfig.WARMUP:
    warm.step()
  else:
    cos.step()

  if (epoch + 1) % TrainConfig.EVAL_INTV == 0:
    losses = estimate_loss()
    print(f"Epoch {epoch+1:4d} | train {losses['train']:.4f}  val {losses['val']:.4f}")
    history.append((epoch + 1, losses['train'], losses['val']))

In [None]:
# ------ 8. Save & Plot ------
# `model` is the trained network (the previous `_model` name was never defined).
n_param = sum(p.numel() for p in model.parameters()) / 1e6
# File name carries the parameter count rounded to the nearest whole million.
model_save_name = f'evo_{n_param:.0f}m.pth'
path = f"/content/drive/MyDrive/{model_save_name}"
torch.save(model.state_dict(), path)

In [None]:
# Plot train/val loss curves. Plotting is optional: skip it when matplotlib is
# missing or nothing was recorded, but let any other error surface.
try:
  import matplotlib.pyplot as plt
except ImportError:
  plt = None
  print("matplotlib not available; skipping loss plot")

if plt is not None:
  if not history:
    print("no evaluation history recorded; nothing to plot")
  else:
    e, t, v = zip(*history)
    plt.plot(e, t, label="train")
    plt.plot(e, v, label="val")
    plt.legend()
    plt.show()