<table style="background-color: transparent; border: none;">   
  <tr>  
    <td>
    <br/>   
    <img src="https://cdn.prod.website-files.com/6606dc3fd5f6645318003df4/6678476dc198b5a75b8c8873_ES_Logo_Black_5.png" width="75" alt="img"/>
    <br/>
    </td>     
    <td><h1>Project Research&nbsp;&nbsp;&nbsp;</h1></td>   
    
  </tr>
</table>

<br/>

> __Author:__ Grace Tang

> __Edited:__ `04.22.25`

---
<br/>

#### __Description__

Predicting protein function is an essential problem in protein science. Many previous works have developed computational methods that use protein sequence information to predict protein functions. With the advancement of `AlphaFold2`, accurate predicted structures are now available for hundreds of millions of proteins.

Incorporating structure information should further boost the performance of protein function prediction methods. There are several potential ways to use it: apply computer-vision-like models to extract information from contact maps; use graph neural networks to encode 3D structures; or use a pretrained protein structure model (e.g., `ESM-IF`).
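As a concrete illustration of the first two options, the sketch below turns per-residue Cα coordinates into a binary contact map (an image-like input for a CNN) and an edge list (the graph input for a GNN). It is a minimal sketch, not this project's pipeline: it assumes the coordinates have already been extracted into an `[L, 3]` NumPy array, and the 8 Å cutoff and the helper names `contact_map` / `contact_edges` are illustrative assumptions.

```python
import numpy as np

def contact_map(ca_coords: np.ndarray, cutoff: float = 8.0) -> np.ndarray:
    """Binary residue contact map from C-alpha coordinates.

    ca_coords: [L, 3] array, one C-alpha atom per residue (Angstroms).
    cutoff: contact distance in Angstroms (8.0 is an assumed, illustrative value).
    Returns an [L, L] int8 array with 1 where two residues are closer than the cutoff.
    """
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]  # [L, L, 3] pairwise offsets
    dist = np.sqrt((diff ** 2).sum(-1))                   # [L, L] Euclidean distances
    return (dist < cutoff).astype(np.int8)

def contact_edges(cmap: np.ndarray) -> np.ndarray:
    """Convert a contact map into a (2, E) edge list for a GNN, dropping self-loops."""
    no_self = cmap - np.eye(len(cmap), dtype=cmap.dtype)  # zero the diagonal
    src, dst = np.nonzero(no_self)
    return np.stack([src, dst])

# Hypothetical toy chain: 4 residues spaced 3.8 A apart along the x-axis
coords = np.array([[0.0, 0, 0], [3.8, 0, 0], [7.6, 0, 0], [11.4, 0, 0]])
cmap = contact_map(coords)
edges = contact_edges(cmap)
print(cmap)          # only residues 0 and 3 (11.4 A apart) are not in contact
print(edges.shape)   # (2, 10)
```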

In addition, with the advancement of LLMs, the text annotations of proteins can be leveraged as an extra information source for protein function prediction, e.g., by using an LLM as an encoder for those annotations. By combining information from sequence, structure, and text annotations, we can develop a model that accurately predicts protein functions.

__Using a subset of the dataset used for [DeepFRI](https://github.com/flatironinstitute/DeepFRI) (a similar protein function prediction model), we aim to combine sequence data and structure data to predict protein function in the form of [EC Numbers](https://en.wikipedia.org/wiki/Enzyme_Commission_number).__

[[1]](https://www.biorxiv.org/content/10.1101/2022.11.29.518451v1), [[2]](https://www.nature.com/articles/s42003-024-07359-z), [[3]](https://www.nature.com/articles/s41467-021-23303-9), [[4]](https://www.biorxiv.org/content/10.1101/2024.05.14.594226v1)
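EC numbers are hierarchical. Each has four dot-separated fields (class, subclass, sub-subclass, serial number), and `-` marks a level that is not assigned. The following minimal sketch expands an EC number into the coarser partial numbers it implies. The helper names are hypothetical and are not part of the project's code. Any field other than `-` is treated as assigned, which also covers preliminary serials such as `n1`.

```python
from itertools import takewhile

def ec_fields(ec: str) -> list[str]:
    """Split an EC number such as '4.3.3.7' into its four fields."""
    fields = ec.split('.')
    if len(fields) != 4:
        raise ValueError(f"expected 4 dot-separated fields, got {ec!r}")
    return fields

def ec_ancestors(ec: str) -> list[str]:
    """Return every partial EC number implied by `ec`, from top-level class down to `ec`."""
    # leading fields up to the first '-' are the assigned levels
    assigned = list(takewhile(lambda f: f != '-', ec_fields(ec)))
    return ['.'.join(assigned[:d] + ['-'] * (4 - d)) for d in range(1, len(assigned) + 1)]

print(ec_ancestors('4.3.3.7'))   # ['4.-.-.-', '4.3.-.-', '4.3.3.-', '4.3.3.7']
print(ec_ancestors('1.11.1.-'))  # ['1.-.-.-', '1.11.-.-', '1.11.1.-']
```

Several chains in the validation output further down carry both a complete EC number and its third-level parent (e.g. `6UE0-AAA`: `4.3.3.-;4.3.3.7`). This suggests the label set already mixes levels of the hierarchy.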


#### Methodology

> _This section relies on having_ `annotations.tsv`, `train.txt`, `validation.txt`, _and_ `sequences.fasta` _in your runtime_

We need a model that turns raw protein sequences into one or more EC numbers. Our goal is simplicity, speed, and the ability to predict multiple functions per protein.

__Embedding:__

We map each amino acid to a **128-dimensional vector** via a small, learnable embedding layer.  
  - Learns a task-specific representation for each residue type without the heavy storage or compute of large pre-trained models.  
  - Fits comfortably in GPU memory when training on ~15,000 proteins.
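This quick check shows how small the embedding layer is. It assumes the same 22-token vocabulary the notebook builds further down: 20 standard amino acids at indices 1..20, padding at 0, and one unknown index at 21. The toy batch is hypothetical. The row reserved by `padding_idx` starts at zero and is never updated by gradients.

```python
import torch
import torch.nn as nn

# Vocabulary mirroring the notebook: amino acids -> 1..20, 0 = padding, 21 = unknown
aa_list = list("ACDEFGHIKLMNPQRSTVWY")
aa2idx = {aa: i + 1 for i, aa in enumerate(aa_list)}
pad_idx, unk_idx = 0, len(aa_list) + 1

embed = nn.Embedding(unk_idx + 1, 128, padding_idx=pad_idx)  # 22 rows x 128 dims

# Two toy sequences: "MKX" ('X' falls back to unk_idx) and "A" padded to length 3
batch = torch.tensor([
    [aa2idx['M'], aa2idx['K'], aa2idx.get('X', unk_idx)],
    [aa2idx['A'], pad_idx, pad_idx],
])
out = embed(batch)
print(out.shape)             # torch.Size([2, 3, 128])
print(embed.weight.numel())  # 2816 learnable values (22 * 128)
```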

__Pooling:__

We apply **masked mean pooling** to average embeddings across the sequence.  
  - Converts variable-length sequences into fixed-size fingerprints.  
  - Ignores padding positions, so the average reflects only the real residues.  
  - Because each residue is embedded independently, the pooled vector depends on amino-acid composition rather than residue order.
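Here is a standalone version of the pooling step, using 1-dimensional toy "embeddings" so the arithmetic is easy to check by hand. It is a sketch of the same computation `ECModel.forward` performs later. The function name is hypothetical, and the `clamp` guarding against all-padding rows is an added safety assumption.

```python
import torch

def masked_mean_pool(emb: torch.Tensor, tokens: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Average residue embeddings over real residues only.

    emb:    [B, L, E] residue embeddings
    tokens: [B, L] token indices, with pad_idx marking padding
    returns [B, E]
    """
    mask = (tokens != pad_idx).unsqueeze(-1).float()  # [B, L, 1], 1 for real residues
    summed = (emb * mask).sum(dim=1)                   # padding contributes nothing
    lengths = mask.sum(dim=1).clamp(min=1.0)           # true lengths; avoid divide-by-zero
    return summed / lengths

tokens = torch.tensor([[5, 7, 0], [3, 0, 0]])          # lengths 2 and 1 after padding
emb = torch.tensor([[[2.0], [4.0], [100.0]],           # 100.0 sits at padded positions
                    [[6.0], [100.0], [100.0]]])
pooled = masked_mean_pool(emb, tokens)
print(pooled.tolist())     # [[3.0], [6.0]]
print(emb.mean(dim=1))     # naive mean lets padding leak in: roughly [[35.33], [68.67]]
```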

__Model Architecture:__

A simple two-layer feedforward network with a 256-unit hidden layer and 10% dropout (the rate used in `ECModel` below).  
  - Two layers provide enough capacity to learn interactions between the pooled features without being so deep that the model overfits easily.  
  - Dropout promotes generalization by preventing reliance on any single feature.
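This quick sketch mirrors the `nn.Sequential` head in `ECModel` further down and shows how big the head is and how dropout behaves. The number of EC classes is a hypothetical placeholder (500); the real model uses `len(classes)`. Most of the parameters sit in the output layer, which grows with the number of EC classes.

```python
import torch
import torch.nn as nn

EMBED_DIM, HIDDEN = 128, 256
NUM_CLASSES = 500   # hypothetical placeholder; the notebook uses len(classes)

head = nn.Sequential(
    nn.Linear(EMBED_DIM, HIDDEN),
    nn.ReLU(),
    nn.Dropout(0.1),               # same rate as ECModel in the notebook
    nn.Linear(HIDDEN, NUM_CLASSES),
)
n_params = sum(p.numel() for p in head.parameters())
print(n_params)   # (128*256 + 256) + (256*500 + 500) = 161524

# Dropout is only active in training mode; PyTorch rescales kept units by 1 / (1 - p)
drop = nn.Dropout(0.1)
h = torch.ones(10)
drop.train()
print(drop(h))    # random mix of 0.0 and ~1.1111
drop.eval()
print(drop(h))    # identical to h: dropout is a no-op at evaluation time
```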

Each EC category is a separate output node with a sigmoid activation, trained using binary cross-entropy.  
  - Proteins can carry **multiple EC numbers**.  
  - Independent sigmoids allow the model to assign all applicable ECs without forcing a single choice.
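This toy example uses made-up logits for one protein and three hypothetical EC classes. It shows why independent sigmoids suit multi-label targets, while a softmax forces the classes to compete. It also checks that `BCEWithLogitsLoss` equals a sigmoid followed by `BCELoss`. The fused version is preferred in training because it is more numerically stable.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.5, -3.0]])   # one protein, three hypothetical EC classes
targets = torch.tensor([[1.0, 1.0, 0.0]])    # the protein carries two ECs

probs = torch.sigmoid(logits)                # independent per-class probabilities
predicted = (probs > 0.5).int()
print(predicted.tolist())                    # [[1, 1, 0]]: both true ECs can fire

# A softmax shares one unit of probability across classes, so at most one exceeds 0.5
softmax_pred = (torch.softmax(logits, dim=1) > 0.5).int()
print(softmax_pred.tolist())                 # [[1, 0, 0]]: the second EC is lost

# BCEWithLogitsLoss fuses the sigmoid and binary cross-entropy
fused = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(probs, targets)
print(torch.allclose(fused, manual))         # True
```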

__Training:__

- **Loss:** `BCEWithLogitsLoss` for multi-label classification  
- **Batch size:** `32`  
- **Optimizer:** `AdamW` (`lr=1e-3`; the run below keeps PyTorch's default `weight_decay=1e-2`)  
- **Epochs:** `10` (validation `F1` is printed every epoch and can be monitored for early stopping)  
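The training loop in this notebook always runs for 10 epochs. One way to add early stopping on validation F1 is a small patience counter. This is a minimal sketch: the class name, `patience`, and `min_delta` are assumptions, not part of the project's code, and the F1 trace is made up.

```python
class EarlyStopping:
    """Signal a stop when validation F1 has not improved for `patience` epochs (hypothetical helper)."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum gain that counts as an improvement
        self.best = float('-inf')
        self.best_epoch = 0
        self.bad_epochs = 0

    def step(self, score: float, epoch: int) -> bool:
        """Record this epoch's validation score; return True when training should stop."""
        if score > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up F1 trace: improves for three epochs, then declines
stopper = EarlyStopping(patience=2)
for epoch, f1 in enumerate([0.10, 0.20, 0.25, 0.24, 0.23, 0.30], start=1):
    if stopper.step(f1, epoch):
        print(f"stopping at epoch {epoch}; best F1 {stopper.best} at epoch {stopper.best_epoch}")
        break
# prints: stopping at epoch 5; best F1 0.25 at epoch 3
```

In the notebook's loop, one would call `stopper.step(f1, ep)` right after computing the micro F1. Saving `model.state_dict()` whenever `stopper.best_epoch == ep` would let the best weights be restored afterwards.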

__Next Steps:__

- **Threshold Calibration:** Use validation data to pick optimal probability cutoffs for each EC.  
- **Feature Extensions:** Add graph-structured information from residue contact maps to the embedding, __and__ use a sequence embedding approach that is less expensive than `ESM`.
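Here is a minimal sketch of the threshold-calibration step. It assumes the validation probabilities and labels are available as `[N, C]` NumPy arrays, such as the stacked sigmoid outputs from the validation loop. The search grid, the per-class F1 criterion, and the function name are assumptions. Choosing cutoffs on the same split that is later reported will overstate performance, so a separate held-out split is safer for the final numbers.

```python
import numpy as np
from sklearn.metrics import f1_score

def calibrate_thresholds(probs, targets, grid=None, default=0.5):
    """Choose, for each class, the probability cutoff that maximizes that class's F1.

    probs, targets: [N, C] arrays of sigmoid outputs and 0/1 labels (validation set).
    Classes with no positive validation example keep `default`.
    """
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)   # assumed search grid: 0.05, 0.10, ..., 0.95
    thresholds = np.full(probs.shape[1], default, dtype=float)
    for c in range(probs.shape[1]):
        if targets[:, c].sum() == 0:
            continue                           # nothing to calibrate against
        scores = [f1_score(targets[:, c], (probs[:, c] >= t).astype(int), zero_division=0)
                  for t in grid]
        thresholds[c] = grid[int(np.argmax(scores))]   # lowest cutoff reaching the best F1
    return thresholds

# Toy validation set: 4 proteins x 3 hypothetical EC classes (made-up numbers)
probs = np.array([[0.32, 0.88, 0.60],
                  [0.27, 0.83, 0.40],
                  [0.12, 0.22, 0.70],
                  [0.07, 0.12, 0.10]])
targets = np.array([[1, 1, 0],
                    [1, 1, 0],
                    [0, 0, 0],
                    [0, 0, 0]])
thresholds = calibrate_thresholds(probs, targets)
print(thresholds)                  # approximately [0.15, 0.25, 0.5]
preds = probs >= thresholds        # broadcasting applies one cutoff per column
```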


```python
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import precision_score, recall_score, f1_score
```

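```python
def parse_fasta(path):
    """Read a FASTA file into {record_id: sequence}; the ID is the first word after '>'."""
    seqs = {}
    with open(path) as f:
        curr, lines = None, []
        for line in f:
            if line.startswith('>'):
                if curr:
                    seqs[curr] = ''.join(lines)
                curr = line[1:].split()[0]
                lines = []
            else:
                lines.append(line.strip())
        if curr:
            seqs[curr] = ''.join(lines)
    return seqs

def parse_annotations(path):
    """Parse the EC annotation file.

    Lines up to the '### PDB-chain' header are skipped; each following line is
    '<chain-id>\t<ec1>,<ec2>,...'. Returns {chain_id: [ECs]} and the sorted EC vocabulary.
    """
    ann_map, all_ec = {}, set()
    with open(path) as f:
        for ln in f:
            if ln.startswith('### PDB-chain'):
                break
        for ln in f:
            if not ln.strip() or ln.startswith('#'):
                continue
            # partition tolerates a missing EC column instead of failing to unpack
            pid, _, ecs = ln.rstrip('\n').partition('\t')
            ec_list = [e.strip() for e in ecs.split(',') if e.strip()]
            ann_map[pid.strip()] = ec_list
            all_ec.update(ec_list)
    classes = sorted(all_ec)
    return ann_map, classes

def build_df(ids, seqs, ann_map, classes):
    """One row per chain with a multi-hot label vector over `classes`; IDs without a sequence are skipped."""
    rows = []
    for pid in ids:
        seq = seqs.get(pid, '')
        if not seq:
            print(f"Warning: {pid} not found")
            continue
        ecs = set(ann_map.get(pid, []))
        labels = [1 if ec in ecs else 0 for ec in classes]
        rows.append({'id': pid, 'sequence': seq, 'labels': labels})
    return pd.DataFrame(rows)
```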

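```python
# Chain IDs for each split (one per line)
train_ids = [l.strip() for l in open('train.txt') if l.strip()]
val_ids   = [l.strip() for l in open('validation.txt') if l.strip()]

# Sequences keyed by chain ID, plus the EC annotation map and the sorted EC vocabulary
db_seqs = parse_fasta('sequences.fasta')
ann_map, classes = parse_annotations('annotations.tsv')

# One row per chain: id, sequence, multi-hot label vector over `classes`
df_train = build_df(train_ids, db_seqs, ann_map, classes)
df_val   = build_df(val_ids,   db_seqs, ann_map, classes)
```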

```python
# Token vocabulary: 20 standard amino acids -> 1..20, 0 = padding, 21 = any other residue
aa_list = list("ACDEFGHIKLMNPQRSTVWY")
aa2idx = {aa: i + 1 for i, aa in enumerate(aa_list)}
pad_idx = 0
unk_idx = len(aa_list) + 1

def seq_to_indices(seq):
    """Map a sequence string to token indices; non-standard residues become unk_idx."""
    return [aa2idx.get(res, unk_idx) for res in seq]

class ProteinDataset(Dataset):
    """Pairs each tokenized sequence with its multi-hot EC label vector."""
    def __init__(self, df):
        self.seqs = [seq_to_indices(s) for s in df['sequence']]
        self.labels = torch.tensor(df['labels'].tolist(), dtype=torch.float32)

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return torch.tensor(self.seqs[idx], dtype=torch.long), self.labels[idx]
```

```python
def collate_fn(batch):
    """Right-pad a batch of variable-length sequences with pad_idx and stack the labels."""
    seqs, labs = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
    return padded, torch.stack(labs)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
train_ds = ProteinDataset(df_train)
val_ds   = ProteinDataset(df_val)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# not shuffled, so validation batches stay in df_val row order
val_loader   = DataLoader(val_ds,   batch_size=batch_size, collate_fn=collate_fn)
```

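```python
class ECModel(nn.Module):
    """Embedding -> masked mean pooling -> two-layer MLP producing one logit per EC class."""
    def __init__(self, num_tokens, embed_dim, num_classes):
        super().__init__()
        # num_tokens is the largest token index (unk_idx), so the table needs num_tokens + 1 rows
        self.embed = nn.Embedding(num_tokens + 1, embed_dim, padding_idx=pad_idx)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        # x: [B, L] token indices
        mask = (x != pad_idx).float()                  # [B, L], 1 for real residues
        emb = self.embed(x)                            # [B, L, E]
        summed = (emb * mask.unsqueeze(-1)).sum(1)     # [B, E]
        lens = mask.sum(1, keepdim=True).clamp(min=1)  # [B, 1], guard against empty rows
        pooled = summed / lens                         # [B, E]
        return self.classifier(pooled)                 # raw logits; the loss applies the sigmoid
```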

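```python
model = ECModel(num_tokens=unk_idx, embed_dim=128, num_classes=len(classes)).to(device)
# weight_decay is not passed, so AdamW's default (1e-2) applies
opt = optim.AdamW(model.parameters(), lr=1e-3)
crit = nn.BCEWithLogitsLoss()   # expects raw logits and multi-hot float targets
```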

```python
epochs = 10
for ep in range(1, epochs + 1):
    # --- training ---
    model.train()
    total_loss = 0.0
    for seqs, labs in train_loader:
        seqs, labs = seqs.to(device), labs.to(device)
        logits = model(seqs)
        loss = crit(logits, labs)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item() * seqs.size(0)   # loss is a batch mean; weight by batch size
    print(f"Epoch {ep} Train Loss: {total_loss/len(train_ds):.4f}")

    # --- validation: micro-averaged over all EC labels with a fixed 0.5 cutoff ---
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for seqs, labs in val_loader:
            probs = torch.sigmoid(model(seqs.to(device))).cpu()
            preds.append(probs)
            targs.append(labs)
    preds = torch.vstack(preds).numpy() > 0.5
    targs = torch.vstack(targs).numpy() > 0.5
    print(
        "Val P, R, F1=",
        precision_score(targs, preds, average='micro', zero_division=0),
        recall_score(targs, preds, average='micro', zero_division=0),
        f1_score(targs, preds, average='micro', zero_division=0)
    )
```

```text
Epoch 1 Train Loss: 0.0390
Val P, R, F1= 0.6 0.0010141987829614604 0.0020249746878164025
Epoch 2 Train Loss: 0.0177
Val P, R, F1= 0.6666666666666666 0.0033806626098715348 0.006727211570803902
Epoch 3 Train Loss: 0.0171
Val P, R, F1= 0.8 0.004056795131845842 0.008072653884964682
Epoch 4 Train Loss: 0.0167
Val P, R, F1= 0.5757575757575758 0.006423258958755916 0.012704781009695755
Epoch 5 Train Loss: 0.0163
Val P, R, F1= 0.5652173913043478 0.013184584178498986 0.02576808721506442
Epoch 6 Train Loss: 0.0159
Val P, R, F1= 0.5769230769230769 0.010141987829614604 0.019933554817275746
Epoch 7 Train Loss: 0.0157
Val P, R, F1= 0.6428571428571429 0.009127789046653144 0.018
Epoch 8 Train Loss: 0.0154
Val P, R, F1= 0.6721311475409836 0.013860716700473293 0.02716131169261345
Epoch 9 Train Loss: 0.0151
Val P, R, F1= 0.6 0.018255578093306288 0.03543307086614173
Epoch 10 Train Loss: 0.0149
Val P, R, F1= 0.6704545454545454 0.019945909398242055 0.038739330269205514
```


```python
# Per-chain comparison of true vs. predicted EC numbers on the validation set
model.eval()
all_true, all_pred = [], []
with torch.no_grad():
    for seqs, labs in val_loader:
        probs = torch.sigmoid(model(seqs.to(device))).cpu()
        all_true.append(labs)
        all_pred.append((probs > 0.5).int())

true = torch.cat(all_true, dim=0)
pred = torch.cat(all_pred, dim=0)

rows = []
# use df_val['id'], not val_ids: build_df skips IDs without a sequence, so val_ids can misalign
for pid, trow, prow in zip(df_val['id'], true, pred):
    true_ecs = [c for c, flag in zip(classes, trow) if flag]
    pred_ecs = [c for c, flag in zip(classes, prow) if flag]
    rows.append({
        'strain': pid,
        'validation EC class': ';'.join(true_ecs) or '-',
        'predicted': ';'.join(pred_ecs) or '-'
    })

results_df = pd.DataFrame(rows, columns=['strain', 'validation EC class', 'predicted'])
print(results_df)
```

```text
        strain validation EC class predicted
0       1EF9-A             4.1.1.-         -
1       4BYF-A             3.6.4.-         -
2       1MVP-A            3.4.23.-         -
3       2BIH-A             1.7.1.-         -
4     6UE0-AAA     4.3.3.-;4.3.3.7         -
...        ...                 ...       ...
1724    2F9R-A             4.6.1.-   3.2.1.-
1725    1QQW-A   1.11.1.-;1.11.1.6         -
1726    6AHR-E   3.1.26.-;3.1.26.5         -
1727    2BJI-A    3.1.3.-;3.1.3.25         -
1728    3BAL-A           1.13.11.-         -

[1729 rows x 3 columns]
```
