In [5]:
!pip install -q torch  
!pip install -q transformers accelerate fair-esm obonet

In [1]:
import random
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

In [4]:
DATA_ROOT = "/kaggle/input/cafa-6-protein-function-prediction"

In [5]:
#### 1st let's read the terms, nd do some checkings
train_terms = pd.read_csv(
    f"{DATA_ROOT}/Train/train_terms.tsv",
    sep="\t",
    header=0
)

train_terms["aspect"] = train_terms["aspect"].str.strip()

print(train_terms.head())
print(train_terms.shape)
print(train_terms.aspect.unique())  ## we shall have F, C, and P

def build_go_vocab(df, aspect):
    df_aspect = df[df.aspect == aspect]
    assert len(df_aspect) > 0, f"No terms found for aspect {aspect}"
    gos = sorted(df_aspect.term.unique())
    return {go: i for i, go in enumerate(gos)}

train_seqs = {}

with open(f"{DATA_ROOT}/Train/train_sequences.fasta") as f:
    cur = None
    for line in f:
        if line.startswith(">"):
            cur = line.split("|")[1]   
            train_seqs[cur] = ""
        else:
            train_seqs[cur] += line.strip()

train_terms["aspect"] = train_terms["aspect"].str.strip()

# Getting all unique proteins that have annotations
all_proteins = train_terms.EntryID.unique()
random.seed(42)
random.shuffle(all_proteins)

### we re only training on half the dataset, due to time
train_proteins = set(all_proteins[:len(all_proteins)//2])
train_terms_subset = (
    train_terms[train_terms.EntryID.isin(train_proteins)]
    .copy()
)

go_vocab = {
    "F": build_go_vocab(train_terms_subset, "F"),
    "P": build_go_vocab(train_terms_subset, "P"),
    "C": build_go_vocab(train_terms_subset, "C"),
}

for aspect in ["F", "P", "C"]:
    n = train_terms_subset[train_terms_subset.aspect==aspect].EntryID.nunique()
    print(aspect, n)



  EntryID        term aspect
0  Q5W0B1  GO:0000785      C
1  Q5W0B1  GO:0004842      F
2  Q5W0B1  GO:0051865      P
3  Q5W0B1  GO:0006275      P
4  Q5W0B1  GO:0006513      P
(537027, 3)
['C' 'F' 'P']
F 29021
P 30038
C 30168


In [6]:
from torch.utils.data import Dataset
import torch

class ProteinDataset(Dataset):
    def __init__(self, terms_df, seqs, go_vocab, aspect):
        self.data = []

        terms_df = terms_df.copy()
        terms_df["aspect"] = terms_df["aspect"].str.strip()

        df_aspect = terms_df[terms_df.aspect == aspect]
        print("ASPECT:", aspect)
        print("terms_df aspects:", terms_df.aspect.unique())
        print("terms_df rows:", len(terms_df))
        print("rows after aspect filter:", len(df_aspect))

        grouped = df_aspect.groupby("EntryID")

        for pid, grp in grouped:
            if pid in seqs:
                label = torch.zeros(len(go_vocab), dtype=torch.float32)
                for go in grp.term:
                    if go in go_vocab:
                        label[go_vocab[go]] = 1.0
                self.data.append((seqs[pid], label))

        print("final dataset size:", len(self.data))
        assert len(self.data) > 0, f"No proteins found for aspect {aspect}!"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [7]:
class ESM2Classifier(torch.nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = torch.nn.Linear(
            self.encoder.config.hidden_size, num_labels
        )

    def forward(self, input_ids, attention_mask):
        out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        pooled = out.last_hidden_state.mean(dim=1)  # we can choose other approachs, instead of mean pooling
        return self.classifier(pooled)


In [11]:
def train_model(aspect):
    MODEL_NAME = "facebook/esm2_t6_8M_UR50D" ### again, the t6 was used instead
                                        ### of larger models, to speed up the training
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    full_dataset = ProteinDataset(
        train_terms_subset, train_seqs, go_vocab[aspect], aspect
    )
    
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size]
    )
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    
    model = ESM2Classifier(MODEL_NAME, len(go_vocab[aspect])).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    print(f"[{aspect}] Training with {train_size} train, {val_size} val samples")
    

    for epoch in range(10):
        model.train()
        train_loss = 0.0
        train_batches = 0
        
        for seqs, labels in tqdm(train_loader, desc=f"Train {aspect} Epoch {epoch+1}"):
            enc = tokenizer(
                seqs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  
            ).to("cuda")
            
            labels = labels.cuda()
            logits = model(**enc)
            loss = loss_fn(logits, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_batches += 1
        
        avg_train_loss = train_loss / max(train_batches, 1)
        
        model.eval()
        val_loss = 0.0
        val_batches = 0
        
        with torch.no_grad():
            for seqs, labels in tqdm(val_loader, desc=f"Val {aspect} Epoch {epoch+1}"):
                enc = tokenizer(
                    seqs,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512
                ).to("cuda")
                
                labels = labels.cuda()
                logits = model(**enc)
                loss = loss_fn(logits, labels)
                
                val_loss += loss.item()
                val_batches += 1
        
        avg_val_loss = val_loss / max(val_batches, 1)
        
        print(f"[{aspect}] Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    
    # Save model, each head
    torch.save(model.state_dict(), f"esm2_{aspect}.pt")
    return model

In [11]:
for aspect in ["F", "P", "C"]:
    train_model(aspect)

ASPECT: F
terms_df aspects: ['C' 'F' 'P']
terms_df rows: 267985
rows after aspect filter: 64227
final dataset size: 29021


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[F] Loading pre-trained weights from /kaggle/input/esm-protein/transformers/default/1/esm2_F.pt
[F] Training with 23216 train, 5805 val samples


Train F Epoch 1: 100%|██████████| 1451/1451 [05:39<00:00,  4.27it/s]
Val F Epoch 1: 100%|██████████| 363/363 [00:29<00:00, 12.17it/s]


[F] Epoch 1: Train Loss: 0.0026, Val Loss: 0.0025


Train F Epoch 2: 100%|██████████| 1451/1451 [05:39<00:00,  4.28it/s]
Val F Epoch 2: 100%|██████████| 363/363 [00:29<00:00, 12.16it/s]


[F] Epoch 2: Train Loss: 0.0025, Val Loss: 0.0024


Train F Epoch 3: 100%|██████████| 1451/1451 [05:39<00:00,  4.27it/s]
Val F Epoch 3: 100%|██████████| 363/363 [00:29<00:00, 12.15it/s]


[F] Epoch 3: Train Loss: 0.0024, Val Loss: 0.0024


Train F Epoch 4: 100%|██████████| 1451/1451 [05:39<00:00,  4.27it/s]
Val F Epoch 4: 100%|██████████| 363/363 [00:29<00:00, 12.14it/s]


[F] Epoch 4: Train Loss: 0.0024, Val Loss: 0.0024


Train F Epoch 5: 100%|██████████| 1451/1451 [05:39<00:00,  4.27it/s]
Val F Epoch 5: 100%|██████████| 363/363 [00:29<00:00, 12.14it/s]


[F] Epoch 5: Train Loss: 0.0024, Val Loss: 0.0024
ASPECT: P
terms_df aspects: ['C' 'F' 'P']
terms_df rows: 267985
rows after aspect filter: 124630
final dataset size: 30038


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[P] Loading pre-trained weights from /kaggle/input/esm-protein/transformers/default/1/esm2_P.pt
[P] Training with 24030 train, 6008 val samples


Train P Epoch 1: 100%|██████████| 1502/1502 [05:54<00:00,  4.24it/s]
Val P Epoch 1: 100%|██████████| 376/376 [00:31<00:00, 11.92it/s]


[P] Epoch 1: Train Loss: 0.0025, Val Loss: 0.0024


Train P Epoch 2: 100%|██████████| 1502/1502 [05:54<00:00,  4.24it/s]
Val P Epoch 2: 100%|██████████| 376/376 [00:31<00:00, 11.92it/s]


[P] Epoch 2: Train Loss: 0.0024, Val Loss: 0.0024


Train P Epoch 3: 100%|██████████| 1502/1502 [05:54<00:00,  4.24it/s]
Val P Epoch 3: 100%|██████████| 376/376 [00:31<00:00, 11.95it/s]


[P] Epoch 3: Train Loss: 0.0024, Val Loss: 0.0024


Train P Epoch 4: 100%|██████████| 1502/1502 [05:54<00:00,  4.24it/s]
Val P Epoch 4: 100%|██████████| 376/376 [00:31<00:00, 11.95it/s]


[P] Epoch 4: Train Loss: 0.0024, Val Loss: 0.0024


Train P Epoch 5: 100%|██████████| 1502/1502 [05:54<00:00,  4.24it/s]
Val P Epoch 5: 100%|██████████| 376/376 [00:31<00:00, 11.94it/s]


[P] Epoch 5: Train Loss: 0.0024, Val Loss: 0.0024
ASPECT: C
terms_df aspects: ['C' 'F' 'P']
terms_df rows: 267985
rows after aspect filter: 79128
final dataset size: 30168


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[C] Loading pre-trained weights from /kaggle/input/esm-protein/transformers/default/1/esm2_C.pt
[C] Training with 24134 train, 6034 val samples


Train C Epoch 1: 100%|██████████| 1509/1509 [05:53<00:00,  4.27it/s]
Val C Epoch 1: 100%|██████████| 378/378 [00:31<00:00, 12.14it/s]


[C] Epoch 1: Train Loss: 0.0057, Val Loss: 0.0056


Train C Epoch 2: 100%|██████████| 1509/1509 [05:53<00:00,  4.27it/s]
Val C Epoch 2: 100%|██████████| 378/378 [00:31<00:00, 12.11it/s]


[C] Epoch 2: Train Loss: 0.0055, Val Loss: 0.0055


Train C Epoch 3: 100%|██████████| 1509/1509 [05:53<00:00,  4.27it/s]
Val C Epoch 3: 100%|██████████| 378/378 [00:31<00:00, 12.10it/s]


[C] Epoch 3: Train Loss: 0.0054, Val Loss: 0.0054


Train C Epoch 4: 100%|██████████| 1509/1509 [05:53<00:00,  4.27it/s]
Val C Epoch 4: 100%|██████████| 378/378 [00:31<00:00, 12.12it/s]


[C] Epoch 4: Train Loss: 0.0053, Val Loss: 0.0053


Train C Epoch 5: 100%|██████████| 1509/1509 [05:53<00:00,  4.27it/s]
Val C Epoch 5: 100%|██████████| 378/378 [00:31<00:00, 12.10it/s]

[C] Epoch 5: Train Loss: 0.0052, Val Loss: 0.0053





**Prediction On the test set**

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import gc
#### load the same pretrained model, then change the classifier's head
class GOClassifierHead(torch.nn.Module):
    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, pooled):
        return self.classifier(pooled)

class TestDataset(Dataset):
    def __init__(self, seqs):
        self.pids = list(seqs.keys())
        self.seqs = list(seqs.values())

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        return self.pids[idx], self.seqs[idx]

test_seqs = {}
with open(f"{DATA_ROOT}/Test/testsuperset.fasta") as f:
    cur = None
    for line in f:
        if line.startswith(">"):
            cur = line[1:].split()[0]
            test_seqs[cur] = ""
        else:
            test_seqs[cur] += line.strip()

print("Loaded", len(test_seqs), "test sequences")


MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

encoder = AutoModel.from_pretrained(MODEL_NAME).cuda().half()
encoder.eval()
hidden_size = encoder.config.hidden_size

### loading heads

heads = {}
base_path = "/kaggle/working/"

for aspect in ["F", "P", "C"]:
    ckpt = torch.load(f"{base_path}/esm2_{aspect}.pt", map_location="cpu")
    head = GOClassifierHead(hidden_size, len(go_vocab[aspect]))
    head.load_state_dict({k: v for k, v in ckpt.items() if k.startswith("classifier.")})
    head.cuda().half().eval()
    heads[aspect] = head

del ckpt
torch.cuda.empty_cache()
gc.collect()


test_dataset = TestDataset(test_seqs)
loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True
)

go_terms = {a: list(go_vocab[a].keys()) for a in ["F", "P", "C"]}


BUFFER_LIMIT = 50_000
write_buffer = []

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

with open("submission.tsv", "w") as f, torch.inference_mode():
    for pids, seqs in tqdm(loader, desc="Predicting test proteins"):

        enc = tokenizer(
            list(seqs),
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        enc = {k: v.cuda(non_blocking=True) for k, v in enc.items()}
        pooled = encoder(**enc).last_hidden_state[:, 0]

        TOPK = 200 ### since the average of GO/protein is about 370, later we'll adjust it

        for aspect in ["F", "P", "C"]:
            probs = torch.sigmoid(heads[aspect](pooled))
            terms = go_terms[aspect]

            vals, idxs = torch.topk(probs, TOPK, dim=1, largest=True, sorted=False)

            for i, pid in enumerate(pids):
                for j in idxs[i]:
                    write_buffer.append(
                        f"{pid}\t{terms[j]}\t{probs[i, j].item():.6f}\n"
                    )

        if len(write_buffer) >= BUFFER_LIMIT:
            f.write("".join(write_buffer))
            write_buffer.clear()

        del enc, pooled, probs, idxs
        torch.cuda.empty_cache()

    if write_buffer:
        f.write("".join(write_buffer))

print("✅ submission.tsv saved successfully!")


Loaded 224309 test sequences


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predicting test proteins:   0%|          | 0/14020 [00:00<?, ?it/s]

✅ submission.tsv saved successfully!
