In [None]:
!pip install Bio



In [None]:
# libraries and config global variables
import gc
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import pandas as pd
from Bio import SeqIO
from sklearn.model_selection import train_test_split

CSV_PATH    = "train.tsv"         # training table (tab-separated: id, label)
TEST_CSV    = "test.tsv"          # test table (tab-separated: id)
FASTA_PATH  = "sequences.fasta"   # sequences keyed by record id
CHECKPOINT  = "LongSafari/hyenadna-small-32k-seqlen-hf"
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE  = 32
NUM_EPOCHS  = 5
LR_ENCODER  = 1e-5                # learning rate for the pretrained encoder
LR_HEAD     = 1e-3                # learning rate for the freshly initialised head
MAX_LEN     = 381                 # token width, incl. the first/last positions sliced off via [:, 1:-1]
VALID_RATIO = 0.1                 # fraction of the training table held out for validation
LOG_INTERVAL= 50                  # print running train loss every N steps
SEED        = 42

# Seed torch's RNGs (CPU and all CUDA devices) so head initialisation, dropout
# masks and DataLoader shuffling are repeatable across runs.
torch.manual_seed(SEED)

In [None]:
# function and class definitions
def pad_collate(batch, pad_id, max_len=None):
    """Collate (input_ids, attention_mask, labels) triples into fixed-width tensors.

    Each sequence is right-padded with ``pad_id`` (mask 0) up to ``max_len``
    columns; longer sequences are truncated. Labels are written starting at
    column 1 and limited to ``sl - 2`` entries (``sl`` = number of kept tokens),
    matching the ``[:, 1:-1]`` slice applied to encoder outputs elsewhere in the
    notebook. Every other label position is -100 so ``CrossEntropyLoss`` ignores it.

    Args:
        batch: sequence of (ids, mask, labels) 1-D tensors.
        pad_id: token id used to fill padding positions.
        max_len: output width; defaults to the global ``MAX_LEN``.

    Returns:
        Tuple (input_ids, attention_mask, labels), each a (B, max_len) long tensor.
    """
    width = MAX_LEN if max_len is None else max_len
    seqs, masks, lbls = zip(*batch)
    n = len(seqs)
    input_ids = torch.full((n, width), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((n, width), dtype=torch.long)
    labels = torch.full((n, width), -100, dtype=torch.long)
    for i, (ids, m, lbl) in enumerate(zip(seqs, masks, lbls)):
        sl = min(ids.size(0), width)
        input_ids[i, :sl] = ids[:sl]
        attention_mask[i, :sl] = m[:sl]
        # Skip column 0 and leave one trailing column; presumably these hold the
        # tokenizer's special tokens -- TODO confirm for this checkpoint.
        ml = min(lbl.size(0), max(0, sl - 2))
        if ml > 0:
            labels[i, 1:1 + ml] = lbl[:ml]
    return input_ids, attention_mask, labels

class ExonIntronDataset(Dataset):
    """Training/validation dataset pairing FASTA sequences with per-character labels.

    Rows whose ``id`` is not a key of ``seq_dict`` are skipped. Each kept label
    string is converted to a long tensor of its digits up front; tokenisation of
    the sequence is deferred to ``__getitem__``.
    """

    def __init__(self, df, seq_dict, tokenizer):
        # (sequence, label tensor) for every row that has a matching sequence
        self.rows = [
            (seq_dict[row['id']], self._encode_label(row['label']))
            for _, row in df.iterrows()
            if row['id'] in seq_dict
        ]
        self.tokenizer = tokenizer

    @staticmethod
    def _encode_label(label):
        """Turn a digit string such as '0110' into a 1-D long tensor."""
        return torch.tensor([int(ch) for ch in label], dtype=torch.long)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        seq, lbl = self.rows[idx]
        encoded = self.tokenizer(
            seq,
            truncation=True,
            max_length=MAX_LEN,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        return input_ids, attention_mask, lbl

class TestDataset(Dataset):
    """Inference dataset yielding (id, input_ids, attention_mask) per test id.

    Ids absent from ``seq_dict`` are not skipped: they are tokenised as the empty
    string, so every id in ``df`` produces exactly one item.
    """

    def __init__(self, df, seq_dict, tokenizer):
        self.tokenizer = tokenizer
        self.seq_dict = seq_dict
        self.ids = df['id'].tolist()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        seq_id = self.ids[idx]
        encoded = self.tokenizer(
            self.seq_dict.get(seq_id, ''),
            truncation=True,
            max_length=MAX_LEN,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        return seq_id, input_ids, attention_mask

class TokenClassifier(nn.Module):
    """Per-token MLP head: Dropout -> Linear(H, H//2) -> GELU -> Linear(H//2, num_labels).

    Args:
        hidden_dim: size H of the last dimension of the input hidden states.
        num_labels: number of output classes per token (default 2).

    ``forward`` maps a tensor of shape (..., hidden_dim) to logits of shape
    (..., num_labels); the notebook feeds it (batch, seq, hidden_dim).
    """

    def __init__(self, hidden_dim, num_labels=2):
        super().__init__()
        # Submodules are created in the original order, so seeded initialisation
        # and state_dict keys are unchanged.
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim // 2, num_labels)

    def forward(self, x):
        # nn.Linear operates on the last dimension, so any leading shape
        # (including non-contiguous slices) works without flattening.
        return self.fc2(self.act(self.fc1(self.dropout(x))))


In [None]:
  # load and split
  # Force string dtypes: labels are digit strings decoded character by character,
  # and numeric inference by pandas could strip leading zeros. Ids must be str
  # to match the FASTA record ids used as seq_dict keys.
  df = pd.read_csv(CSV_PATH, sep='\t', dtype={'id': str, 'label': str})
  train_df, val_df = train_test_split(df, test_size=VALID_RATIO, random_state=42)
  test_df = pd.read_csv(TEST_CSV, sep='\t', dtype={'id': str})
  seq_dict = {rec.id: str(rec.seq) for rec in SeqIO.parse(FASTA_PATH, 'fasta')}

  # tokenizer & encoder
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, trust_remote_code=True)
  pad_id = tokenizer.pad_token_id
  base_model = AutoModel.from_pretrained(
      CHECKPOINT,
      trust_remote_code=True,
      torch_dtype=torch.float32
  ).to(DEVICE)
  base_model.train()
  # The whole encoder is fine-tuned, not only the classification head.
  for p in base_model.parameters():
      p.requires_grad = True

  hidden_dim = base_model.config.d_model
  classifier = TokenClassifier(hidden_dim).to(DEVICE)

  # DataLoaders -- built before the scheduler so their length can size the schedule.
  train_loader = DataLoader(ExonIntronDataset(train_df, seq_dict, tokenizer),
      batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True,
      collate_fn=lambda b: pad_collate(b, pad_id)
  )
  val_loader = DataLoader(ExonIntronDataset(val_df, seq_dict, tokenizer),
      batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True,
      collate_fn=lambda b: pad_collate(b, pad_id)
  )

  def collate_test(batch):
      """Pad TestDataset items into (ids, input_ids, attention_mask).

      Reuses pad_collate with empty label tensors and discards its label output.
      """
      gids = [gid for gid, _, _ in batch]
      input_ids, attention_mask, _ = pad_collate(
          [(ids, mask, torch.empty(0, dtype=torch.long)) for _, ids, mask in batch],
          pad_id,
      )
      return gids, input_ids, attention_mask

  test_loader = DataLoader(TestDataset(test_df, seq_dict, tokenizer),
      batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True,
      collate_fn=collate_test
  )

  # optimizer: separate learning rates for the pretrained encoder and the new head
  enc_params = list(base_model.parameters())
  head_params = list(classifier.parameters())
  optimizer = optim.AdamW([
      {'params': enc_params, 'lr': LR_ENCODER},
      {'params': head_params, 'lr': LR_HEAD}
  ])
  # One scheduler step per training batch. len(train_loader) counts the final
  # partial batch and only the rows ExonIntronDataset kept, so the linear decay
  # ends on the actual last step (len(train_df)//BATCH_SIZE did not).
  total_steps = NUM_EPOCHS * len(train_loader)
  scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=int(0.1*total_steps),
      num_training_steps=total_steps
  )
  criterion = nn.CrossEntropyLoss(ignore_index=-100)

  # training
  for epoch in range(1, NUM_EPOCHS+1):
      # The encoder is fine-tuned as well, so both modules return to train mode
      # each epoch (validation below switches them to eval mode).
      base_model.train()
      classifier.train()
      total_loss = 0.0
      for step, (in_ids, in_mask, labels) in enumerate(train_loader, start=1):
          in_ids, in_mask, labels = in_ids.to(DEVICE), in_mask.to(DEVICE), labels.to(DEVICE)

          # Encoder runs without the mask; padded positions are zeroed afterwards
          # and the first/last token positions are dropped.
          out = base_model(in_ids)
          emb = (out.last_hidden_state * in_mask.unsqueeze(-1))[:,1:-1,:]

          logits = classifier(emb)
          # pad_collate writes labels from column 1, aligned with the [:,1:-1] slice
          tgt = labels[:,1:1+logits.size(1)].reshape(-1)
          loss = criterion(logits.reshape(-1, logits.size(-1)), tgt)
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          scheduler.step()
          total_loss += loss.item()
          if step % LOG_INTERVAL == 0:
              print(f"Epoch {epoch} step {step}/{len(train_loader)} loss {total_loss/step:.4f}")
      print(f"Epoch {epoch} Train loss: {total_loss/len(train_loader):.4f}")

      # validation: eval mode disables the head's dropout and any train-only
      # behaviour inside the encoder
      base_model.eval()
      classifier.eval()
      vloss, corr, tot = 0.0, 0, 0
      with torch.no_grad():
          for in_ids, in_mask, labels in val_loader:
              in_ids, in_mask, labels = in_ids.to(DEVICE), in_mask.to(DEVICE), labels.to(DEVICE)
              out = base_model(in_ids)
              emb = (out.last_hidden_state * in_mask.unsqueeze(-1))[:,1:-1,:]
              logits = classifier(emb)
              tgt = labels[:,1:1+logits.size(1)]
              loss = criterion(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
              vloss += loss.item()
              preds = logits.argmax(-1)
              mask = tgt != -100  # only positions that carry a real label
              corr += (preds[mask]==tgt[mask]).sum().item()
              tot += mask.sum().item()
      # guard against a validation split with no labelled positions
      val_acc = corr / tot if tot else float('nan')
      print(f"Epoch {epoch} Val loss: {vloss/len(val_loader):.4f}, Val acc: {val_acc:.2%}")

KeyboardInterrupt: 

In [None]:
# test
# Put both modules in eval mode before inference: the encoder was switched to
# train mode for fine-tuning and may still be in it.
base_model.eval()
classifier.eval()
outs = []
with torch.no_grad():
    for gids, in_ids, in_mask in test_loader:
        in_ids, in_mask = in_ids.to(DEVICE), in_mask.to(DEVICE)
        out = base_model(in_ids)
        # zero padded positions, then drop the first and last token positions
        emb = (out.last_hidden_state * in_mask.unsqueeze(-1))[:, 1:-1, :]
        preds = classifier(emb).argmax(-1).cpu()
        for gid, seq, p in zip(gids, in_ids.cpu(), preds):
            # real (non-pad) token count, minus the two stripped positions
            sl = (seq != pad_id).sum().item()
            s = max(0, sl - 2)
            outs.append({'id': gid, 'prediction': ''.join(map(str, p[:s].tolist()))})
# Explicit columns keep the header even if no predictions were produced.
# NOTE: the file is tab-separated despite its .csv extension.
pd.DataFrame(outs, columns=['id', 'prediction']).to_csv("predictions.csv", sep="\t", index=False)
print("Saved predictions.csv")

Saved predictions.csv
