<a href="https://colab.research.google.com/github/vnsupara/esm2_protein_ss_prediction/blob/main/protein_language_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets seqio fair-esm torch torchvision torchaudio --upgrade
!pip install -q accelerate



In [None]:
!pip install transformers -U
!pip install Bio



In [None]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from torch import nn
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO

device = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", device)

# Seed torch's RNGs (CPU and all CUDA devices) so classifier weight init and
# DataLoader shuffling repeat across Restart & Run All. GPU kernels may still
# introduce small nondeterminism.
SEED = 42
torch.manual_seed(SEED)

DEVICE: cuda


In [None]:
from google.colab import drive
import os

# Everything written later with a relative path (embeddings/, the model
# checkpoint, prediction CSVs) lands inside this Drive folder, so it survives
# Colab runtime resets.
drive.mount('/content/drive')

checkpoint_dir = os.path.join('/content/drive/MyDrive', 'protein_checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
os.chdir(checkpoint_dir)

print(f"Working directory: {checkpoint_dir}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Working directory: /content/drive/MyDrive/protein_checkpoints
Checkpoints will be saved to Google Drive


In [None]:
DATA_DIR = "/content/drive/MyDrive"

train_df = pd.read_csv(f"{DATA_DIR}/train.tsv", sep="\t")
test_df  = pd.read_csv(f"{DATA_DIR}/test.tsv",  sep="\t")

# Fail fast on schema problems instead of hours into training: every row id is
# later split on "_" into protein / residue / position, and training needs labels.
assert "id" in train_df.columns and "id" in test_df.columns, "expected an 'id' column in both splits"
assert "secondary_structure" in train_df.columns, "train.tsv has no 'secondary_structure' label column"

In [None]:
# FASTA record id -> amino-acid sequence string. If an id repeats, the later
# record overwrites the earlier one.
seqs = {
    record.id: str(record.seq)
    for record in SeqIO.parse("/content/drive/MyDrive/sequences.fasta", "fasta")
}


In [None]:
# 9-class secondary-structure alphabet. The letters match DSSP codes
# (assumption from the alphabet). '.' presumably marks residues with no
# assigned structure.
labels = ['H','B','E','G','I','P','T','S','.']
id2label = dict(enumerate(labels))
label2id = {lab: idx for idx, lab in id2label.items()}

In [None]:

# ESM-2 checkpoint used as a frozen per-residue feature extractor
# (12 encoder layers, 480-dim hidden states).
model_name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Only last_hidden_state is used downstream. Skip the pooler head: the
# checkpoint has no weights for it, so loading it just adds randomly
# initialized parameters and the "newly initialized: pooler" warning.
esm_model = AutoModel.from_pretrained(model_name, add_pooling_layer=False).to(device)


Loading ESM model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Inference-only encoder: make sure it is in eval mode. The trailing ';'
# suppresses the very long module repr that the bare expression would display.
esm_model.eval();

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 480, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-11): 12 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=480, out_features=480, bias=True)
            (key): Linear(in_features=480, out_features=480, bias=True)
            (value): Linear(in_features=480, out_features=480, bias=True)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=480, out_features=480, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=480, out_features=1920, bias=True)
        )
        (output): EsmOutput(
      

In [None]:
os.makedirs("embeddings", exist_ok=True)

def embed_protein(pid, seq):
    """Return the ESM-2 per-token embedding matrix for one protein, cached on disk.

    The matrix is ``esm_model``'s ``last_hidden_state`` for ``seq``, tokenized
    with special tokens added, with the batch dimension squeezed out. It therefore
    includes rows for the tokenizer's special tokens as well as one per residue.
    The result is cached at ``embeddings/<pid>.npy`` and reloaded on later calls.

    Args:
        pid: Protein id; used as the cache file name.
        seq: Amino-acid sequence passed to the tokenizer.

    Returns:
        numpy array of shape (num_tokens, hidden_dim).
    """
    save_path = f"embeddings/{pid}.npy"
    if os.path.exists(save_path):
        return np.load(save_path)

    tokens = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
    with torch.no_grad():
        out = esm_model(**{k: v.to(device) for k, v in tokens.items()})
    emb = out.last_hidden_state.squeeze(0).cpu().numpy()

    # Save per protein so a Colab timeout doesn't lose finished work. Write to a
    # temp file and rename it into place: a run killed mid-write would otherwise
    # leave a truncated .npy that the exists() checks treat as a valid cache hit.
    tmp_path = save_path + ".part"
    with open(tmp_path, "wb") as f:  # file object: np.save won't append ".npy"
        np.save(f, emb)
    os.replace(tmp_path, save_path)
    return emb
# The protein id is the first "_"-separated field of every residue id.
all_pids = {rid.split("_")[0] for rid in [*train_df.id, *test_df.id]}

print("Embedding proteins...")
for pid in tqdm(all_pids):
    # Already-cached proteins are skipped without loading their arrays.
    if not os.path.exists(f"embeddings/{pid}.npy"):
        embed_protein(pid, seqs[pid])

Embedding proteins...


100%|██████████| 7711/7711 [00:00<00:00, 15321.91it/s]


In [None]:
class SSData(Dataset):
    """Residue-level dataset over cached ESM-2 embeddings.

    Each row of ``df`` is one residue. Its ``id`` is split on "_" into
    ``(protein_id, amino_acid, position)``. The protein's per-token embedding
    matrix is read from ``<emb_dir>/<protein_id>.npy``, and the row at
    ``position + 1`` is returned; row 0 is a tokenizer special token.

    Items are ``(x, y)``:
        x: float32 tensor whose length is the embedding width. It is all zeros
           when the position falls outside the residue rows.
        y: label index from ``label2id``, or -1 when ``df`` has no
           ``secondary_structure`` column (unlabeled / test data).

    Args:
        df: DataFrame with an ``id`` column and, for labeled data, a
            ``secondary_structure`` column.
        emb_dir: Directory holding the cached ``<protein_id>.npy`` files.
        mmap_mode: Passed to ``np.load``. The default ``"r"`` memory-maps the
            file, so only the requested row is read instead of the whole
            (tokens, dim) matrix for every sample. Pass ``None`` to load eagerly.
    """

    def __init__(self, df, emb_dir="embeddings", mmap_mode="r"):
        self.df = df
        self.emb_dir = emb_dir
        self.mmap_mode = mmap_mode

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Plain split (first field = protein id) to match how the cached
        # embedding files were named.
        pid, aa, pos_str = row.id.split("_")
        pos = int(pos_str)
        emb = np.load(os.path.join(self.emb_dir, f"{pid}.npy"), mmap_mode=self.mmap_mode)
        # Skip the leading special token. NOTE(review): this assumes `pos` is
        # 0-based; if ids are 1-based every residue is shifted by one. The
        # unused `aa` field could be checked against seqs[pid] to confirm.
        emb_index = pos + 1

        # Rows 0 and len(emb)-1 are the special tokens added by the tokenizer.
        # Anything outside the residue span gets an all-zero feature vector.
        if not (1 <= emb_index < len(emb) - 1):
            x = np.zeros(emb.shape[1], dtype=np.float32)
        else:
            # Copy the row out so the memory map can be released.
            x = np.array(emb[emb_index], dtype=np.float32)

        y = label2id[row.secondary_structure] if "secondary_structure" in row else -1
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y)

In [None]:
# One item per residue, reshuffled every epoch.
train_dataset = SSData(df=train_df)
print("training data...")
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True,
)

training data...


In [None]:
class BiLSTMClassifier(nn.Module):
    """Per-residue classifier: single-layer bidirectional LSTM followed by a linear head.

    Input is a batch of single embedding vectors, ``(batch, input_dim)``.
    Each vector is fed to the LSTM as a sequence of length 1, so the LSTM sees
    no neighbouring residues. In practice it acts as a gated feed-forward
    layer rather than a sequence model.

    Args:
        input_dim: Embedding width (480 for the ESM-2 model used here).
        hidden: LSTM hidden size per direction.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=480, hidden=256, num_classes=9):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated.
        self.fc = nn.Linear(2 * hidden, num_classes)

    def forward(self, x):
        seq = x[:, None, :]                   # (B, 1, input_dim)
        lstm_out, _ = self.lstm(seq)          # (B, 1, 2*hidden)
        return self.fc(lstm_out[:, -1])       # (B, num_classes)


# Relative path, so it resolves inside the Drive working directory set earlier.
CHECKPOINT = "model_checkpoint.pt"

model = BiLSTMClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Resume from a previous run's weights if present.
# - map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
# - weights_only restricts unpickling to tensors/containers, since the file is
#   just a state_dict.
# Only model weights are stored: Adam state and the best-accuracy tracker start fresh.
if os.path.exists(CHECKPOINT):
    print("Loading checkpoint...")
    model.load_state_dict(torch.load(CHECKPOINT, map_location=device, weights_only=True))

In [None]:
EPOCHS = 4

best_acc = 0.0
for epoch_idx in range(EPOCHS):
    model.train()
    n_correct = n_seen = 0

    for xb, yb in tqdm(train_loader):
        # Drop unlabeled residues (SSData yields y == -1 when there is no label column).
        keep = yb != -1
        xb = xb[keep].to(device)
        yb = yb[keep].to(device)

        optimizer.zero_grad()
        logits = model(xb)
        batch_loss = criterion(logits, yb)
        batch_loss.backward()
        optimizer.step()

        # Running accuracy uses the same forward pass as the update.
        n_correct += (logits.argmax(1) == yb).sum().item()
        n_seen += yb.size(0)

    acc = n_correct / n_seen * 100
    print(f"Epoch {epoch_idx + 1}/{EPOCHS} - Train Acc: {acc:.2f}%")

    # Keep only the weights with the best *training* accuracy so far.
    # NOTE(review): there is no held-out split, so this tracks fit to the
    # training data rather than generalization.
    if not acc > best_acc:
        print("no improvement in accuracy, skipping checkpoint save")
        continue

    best_acc = acc
    torch.save(model.state_dict(), CHECKPOINT)
    print(f"new best model saved at Epoch {epoch_idx + 1} with accuracy: {best_acc:.2f}%")

100%|██████████| 10580/10580 [2:06:00<00:00,  1.40it/s]


Epoch 1/6 - Train Acc: 67.85%
New best model saved at Epoch 1 with accuracy: 67.85%


100%|██████████| 10580/10580 [1:58:13<00:00,  1.49it/s]


Epoch 2/6 - Train Acc: 71.58%
New best model saved at Epoch 2 with accuracy: 71.58%


  0%|          | 26/10580 [00:17<1:57:10,  1.50it/s]


KeyboardInterrupt: 

In [None]:
test_dataset = SSData(test_df)
test_loader  = DataLoader(test_dataset, batch_size=256)

# Predict with the saved best weights, not whatever is in memory. After the
# training loop, `model` holds the last (possibly interrupted, partially
# updated) epoch, which is not necessarily the state saved as best.
if os.path.exists(CHECKPOINT):
    model.load_state_dict(torch.load(CHECKPOINT, map_location=device, weights_only=True))

preds = []

model.eval()
with torch.no_grad():
    for x, _ in tqdm(test_loader):
        logits = model(x.to(device))
        preds.extend(logits.argmax(1).cpu().tolist())

assert len(preds) == len(test_df), "expected exactly one prediction per test residue"
test_df["prediction"] = [id2label[i] for i in preds]

test_df.to_csv("prediction_4.csv", index=False)
print("Saved prediction_4.csv")

100%|██████████| 2634/2634 [27:14<00:00,  1.61it/s]


Saved prediction_4.csv — ready to zip + upload to Codabench


In [None]:
# Also keep a copy at the Drive root, outside the protein_checkpoints working directory.
test_df.to_csv(os.path.join("/content/drive/MyDrive", "prediction_4.csv"), index=False)