<a href="https://colab.research.google.com/github/yifan-grace-tang/final-project/blob/main/Renee/renee's_work_from_kaize_grace_mlcb_project_04_23.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<table style="background-color: transparent; border: none;">   
  <tr>  
    <td>
    <br/>   
    <img src="https://cdn.prod.website-files.com/6606dc3fd5f6645318003df4/6678476dc198b5a75b8c8873_ES_Logo_Black_5.png" width="75" alt="img"/>
    <br/>
    </td>     
    <td><h1>Project Research&nbsp;&nbsp;&nbsp;</h1></td>   
    
  </tr>
</table>

<br/>

> __Author:__ Grace Tang

> __Edited:__ `04.22.25`

---
<br/>

#### __Description__

Predicting protein function is a central problem in protein science. Many earlier computational methods predict function from sequence information alone. With the advancement of `AlphaFold2`, accurate structure data is now available for hundreds of millions of proteins.

Incorporating structure information should further improve protein function prediction. There are several potential ways to use it: apply computer-vision-style models to contact maps; encode 3D structures with graph neural networks; or use a pretrained protein structure model (e.g., `ESM-IF`).

In addition, with the advancement of LLMs, the text annotations of proteins can serve as an extra information source, e.g., by using an LLM as an encoder for those annotations. By combining sequence, structure, and text annotations, we aim to develop a model that accurately predicts protein functions.

__Using a subset of the dataset used for [DeepFRI](https://github.com/flatironinstitute/DeepFRI) (a similar protein function prediction model), we aim to combine sequence data and structure data to predict protein function in the form of [EC Numbers](https://en.wikipedia.org/wiki/Enzyme_Commission_number).__

[[1]](https://www.biorxiv.org/content/10.1101/2022.11.29.518451v1), [[2]](https://www.nature.com/articles/s42003-024-07359-z), [[3]](https://www.nature.com/articles/s41467-021-23303-9), [[4]](https://www.biorxiv.org/content/10.1101/2024.05.14.594226v1)
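
EC numbers have four dot-separated fields, running from the broadest enzyme class to the specific enzyme, with `-` marking a field that is left unspecified. The labels in this dataset mix both levels (for example `4.3.3.-` alongside `4.3.3.7`). As an illustration only, a hypothetical helper `ec_ancestors` (not used elsewhere in this notebook) can list an EC number together with its coarser parents:

```python
def ec_ancestors(ec):
    """Return an EC number followed by its less specific parents.

    Assumes the four-field dotted form, with '-' marking an unspecified field.
    """
    specified = [field for field in ec.split(".") if field != "-"]
    out = []
    for depth in range(len(specified), 0, -1):
        # Keep the first `depth` fields and pad the rest with '-'
        out.append(".".join(specified[:depth] + ["-"] * (4 - depth)))
    return out


print(ec_ancestors("4.3.3.7"))  # ['4.3.3.7', '4.3.3.-', '4.3.-.-', '4.-.-.-']
```

A helper like this could be useful when comparing a predicted label against a true label at a coarser level of the hierarchy.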


#### Methodology

> _This section relies on having_ `annotations.tsv`, `train.txt`, `validation.txt`, _and_ `sequences.fasta` _in your runtime_

We need a model that turns raw protein sequences into one or more EC numbers. Our goal is simplicity, speed, and the ability to predict multiple functions per protein.

__Embedding:__

We map each amino acid to a **128-dimensional vector** via a small, learnable embedding layer.  
  - Captures key sequence patterns without the heavy storage or compute of large pre-trained models.  
  - Fits comfortably in GPU memory when training on ~15,000 proteins.
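
A minimal sketch of the lookup, using the same vocabulary convention as the code cells in this notebook (`0` for padding, `1` to `20` for the standard amino acids, `21` for anything else). The example sequence `"MKTX"` is made up for illustration:

```python
import torch
import torch.nn as nn

# Same convention as the notebook: 0 = padding, 1-20 = amino acids, 21 = unknown.
aa_list = list("ACDEFGHIKLMNPQRSTVWY")
aa2idx = {aa: i + 1 for i, aa in enumerate(aa_list)}
pad_idx, unk_idx = 0, len(aa_list) + 1

# One learnable 128-dimensional vector per token index.
embed = nn.Embedding(num_embeddings=unk_idx + 1, embedding_dim=128, padding_idx=pad_idx)

seq = "MKTX"  # "X" is not in aa_list, so it maps to unk_idx
indices = torch.tensor([[aa2idx.get(res, unk_idx) for res in seq]])  # shape [1, 4]
vectors = embed(indices)                                             # shape [1, 4, 128]

print(indices.tolist())  # [[11, 9, 17, 21]]
print(vectors.shape)     # torch.Size([1, 4, 128])
```

Because `padding_idx` is set, the vector for index `0` is initialized to zeros and receives no gradient updates.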

__Pooling:__

We apply **masked mean pooling** to average embeddings across the sequence.  
  - Converts variable-length sequences into fixed-size fingerprints.  
  - Ensures padding tokens do not skew the average, keeping representations true to the real residues.
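
The effect of the mask can be seen on a toy batch. This sketch uses hand-written 2-dimensional "embeddings" instead of a trained layer, with zero vectors at padded positions (as an embedding layer with `padding_idx` would produce):

```python
import torch

pad_idx = 0
# Two sequences of true lengths 3 and 1, padded to length 3.
tokens = torch.tensor([[5, 2, 7],
                       [4, 0, 0]])
# Toy embeddings with shape [B, L, E]; padded positions hold zero vectors.
emb = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]],
                    [[4., 4.], [0., 0.], [0., 0.]]])

mask = (tokens != pad_idx).float()              # [B, L], 1 for real residues
summed = (emb * mask.unsqueeze(-1)).sum(dim=1)  # [B, E]
lengths = mask.sum(dim=1, keepdim=True)         # [B, 1], true sequence lengths
masked_mean = summed / lengths                  # divides by 3 and by 1
naive_mean = emb.mean(dim=1)                    # divides both rows by 3

print(masked_mean)  # rows [2., 2.] and [4., 4.]
print(naive_mean)   # second row shrinks to about [1.33, 1.33]
```

The naive mean divides the short sequence by the padded length, so its representation depends on how long the other sequences in the batch happen to be; the masked mean does not.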

__Model Architecture:__

A simple two-layer feedforward network with a 256-unit hidden layer and dropout (the code below uses `nn.Dropout(0.1)`, i.e. 10%).  
  - Two layers provide enough capacity to learn interactions between sequence features, but are not so deep as to overfit.  
  - Dropout promotes generalization by preventing reliance on any single feature.

Each EC category is a separate output node with a sigmoid activation, trained using binary cross-entropy.  
  - Proteins can carry **multiple EC numbers**.  
  - Independent sigmoids allow the model to assign all applicable ECs without forcing a single choice.
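
A small sketch of the difference between independent sigmoids and a softmax, using made-up logits for one protein and three hypothetical EC classes:

```python
import torch

logits = torch.tensor([[3.0, 2.5, -4.0]])  # one protein, three hypothetical EC classes

sigmoid_probs = torch.sigmoid(logits)         # each class is scored independently
softmax_probs = torch.softmax(logits, dim=1)  # classes compete; probabilities sum to 1

multi_label = (sigmoid_probs > 0.5).int()
single_label = (softmax_probs > 0.5).int()

print(multi_label.tolist())   # [[1, 1, 0]]
print(single_label.tolist())  # [[1, 0, 0]]
```

With sigmoids, both strongly supported classes pass the `0.5` cutoff. With a softmax, the two classes share the probability mass, so only the stronger one can exceed `0.5`.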

__Training:__

- **Loss:** `BCEWithLogitsLoss` for multi-label classification  
- **Batch size:** `32`  
- **Optimizer:** `AdamW` (`lr=1e-3`; the code leaves `weight_decay` at the PyTorch default)  
- **Epochs:** `10` (validation precision, recall, and `F1` are printed each epoch; early stopping is not implemented yet)  
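
A minimal sketch of how early stopping on validation `F1` could be wired in, assuming a hypothetical `EarlyStopper` helper that is not part of this notebook. The `F1` values in the usage example are made up for illustration and are not results from this project:

```python
class EarlyStopper:
    """Hypothetical helper: stop when validation F1 has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1):
        """Record one epoch's validation F1; return True if training should stop."""
        if val_f1 > self.best + self.min_delta:
            self.best = val_f1
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up F1 values, for illustration only.
stopper = EarlyStopper(patience=2)
history = [0.10, 0.15, 0.14, 0.15, 0.13]
for epoch, f1 in enumerate(history, start=1):
    if stopper.step(f1):
        print(f"Stopping after epoch {epoch}; best F1 = {stopper.best:.2f}")
        break
```

In a real training loop, `stopper.step(...)` would be called once per epoch with the micro-averaged validation `F1`, typically alongside saving the best model weights.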

__Next Steps:__

- **Threshold Calibration:** Use validation data to pick optimal probability cutoffs for each EC.  
- **Feature Extensions:** Add graph information from a contact map to the embedding __and__ pick a sequence embedding approach that is less expensive than `ESM`.


In [1]:
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.nn.utils.rnn import pad_sequence

In [2]:
def parse_fasta(path):
    seqs = {}
    with open(path) as f:
        curr, lines = None, []
        for line in f:
            if line.startswith('>'):
                if curr:
                    seqs[curr] = ''.join(lines)
                curr = line[1:].split()[0]
                lines = []
            else:
                lines.append(line.strip())
        if curr:
            seqs[curr] = ''.join(lines)
    return seqs

def parse_annotations(path):
    ann_map, all_ec = {}, set()
    with open(path) as f:
        for ln in f:
            if ln.startswith('### PDB-chain'):
                break
        for ln in f:
            if not ln.strip() or ln.startswith('#'): continue
            pid, ecs = ln.strip().split('\t')
            ec_list = [e.strip() for e in ecs.split(',') if e.strip()]
            ann_map[pid] = ec_list
            all_ec.update(ec_list)
    classes = sorted(all_ec)
    return ann_map, classes

def build_df(ids, seqs, ann_map, classes):
    rows = []
    for pid in ids:
        seq = seqs.get(pid, '')
        if not seq:
            print(f"Warning: {pid} not found")
            continue
        labels = [1 if ec in ann_map.get(pid, []) else 0 for ec in classes]
        rows.append({'id': pid, 'sequence': seq, 'labels': labels})
    return pd.DataFrame(rows)

In [None]:
from google.colab import files
uploaded = files.upload()



In [3]:
train_ids = [l.strip() for l in open('train.txt') if l.strip()]
val_ids   = [l.strip() for l in open('validation.txt') if l.strip()]

In [4]:
db_seqs = parse_fasta('sequences.fasta')
ann_map, classes = parse_annotations('annotations.tsv')

In [5]:
df_train = build_df(train_ids, db_seqs, ann_map, classes)
df_val   = build_df(val_ids,   db_seqs, ann_map, classes)

In [34]:
!pip install --upgrade biopython



In [None]:
import os
import requests
import numpy as np
from Bio.PDB import MMCIFParser as CIFParser
from Bio.PDB.Polypeptide import is_aa

def download_cif(pdb_id, out_dir="cifs"):
    # Download mmCIF file
    pdb_id = pdb_id.lower()
    url = f"https://files.rcsb.org/download/{pdb_id}.cif"
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.path.join(out_dir, f"{pdb_id}.cif")

    if not os.path.exists(file_path):
        r = requests.get(url)
        if r.status_code == 200:
            with open(file_path, "w") as f:
                f.write(r.text)
        else:
            raise ValueError(f"Could not download PDBx/mmCIF ID: {pdb_id}")
    return file_path

def compute_contact_map(cif_file, chain_id="A", threshold=8.0):
    # Get contact map
    parser = CIFParser()
    structure = parser.get_structure("protein", cif_file)
    model = structure[0]
    chain = model[chain_id]

    coords = []
    sequence = []
    for res in chain:
        if is_aa(res, standard=True) and "CA" in res:
            coords.append(res["CA"].coord)
            sequence.append(res.get_resname())

    coords = np.array(coords)
    if coords.shape[0] == 0:
        raise ValueError("No valid C-alpha atoms found")

    dist_matrix = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
    cmap = (dist_matrix < threshold).astype(np.uint8)
    return cmap, sequence

def save_as_npz(pdb_id, chain_id="A", out_dir="npz", threshold=8.0):
    # save .npz
    os.makedirs(out_dir, exist_ok=True)
    cif_file = download_cif(pdb_id, out_dir="cifs")
    cmap, sequence = compute_contact_map(cif_file, chain_id=chain_id, threshold=threshold)

    outfile = os.path.join(out_dir, f"{pdb_id.upper()}-{chain_id}.npz")
    np.savez_compressed(outfile, cmap=cmap, sequence=sequence)

def extract_pdb_and_chain(pdb_with_chain):
    pdb_id, chain_id = pdb_with_chain.split('-')
    return pdb_id, chain_id

def process_pdbs(pdb_ids_with_chains, out_dir="npz"):
    for pdb_with_chain in pdb_ids_with_chains:
        try:
            pdb_id, chain_id = extract_pdb_and_chain(pdb_with_chain)
            save_as_npz(pdb_id, chain_id=chain_id, out_dir=out_dir)
        except (ValueError, KeyError) as e:  # KeyError: chain ID not present in the structure
            print(f"Skipping {pdb_with_chain}: {e}")
            continue  # Skip to the next protein

process_pdbs(train_ids, out_dir="npz")
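
As a quick sanity check of the distance-threshold logic in `compute_contact_map`, the same two lines can be run on three made-up C-alpha positions placed 5 Å apart on a line (hypothetical coordinates, not taken from any structure):

```python
import numpy as np

# Three hypothetical C-alpha positions on a line, 5 Å apart.
coords = np.array([[0.0, 0.0, 0.0],
                   [5.0, 0.0, 0.0],
                   [10.0, 0.0, 0.0]])

# Pairwise Euclidean distances via broadcasting: [N, 1, 3] - [1, N, 3] -> [N, N, 3]
dist = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
cmap = (dist < 8.0).astype(np.uint8)

print(cmap)
# [[1 1 0]
#  [1 1 1]
#  [0 1 1]]
```

Neighbors at 5 Å count as contacts, the pair at 10 Å does not, and the diagonal is always 1 because each residue is at distance 0 from itself.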



In [6]:
aa_list = list("ACDEFGHIKLMNPQRSTVWY")
aa2idx = {aa: i+1 for i, aa in enumerate(aa_list)}
pad_idx = 0
unk_idx = len(aa_list) + 1

def seq_to_indices(seq):
    return [aa2idx.get(res, unk_idx) for res in seq]

class ProteinDataset(Dataset):
    def __init__(self, df):
        self.seqs = [seq_to_indices(s) for s in df['sequence']]
        self.labels = torch.tensor(df['labels'].tolist(), dtype=torch.float32)
    def __len__(self): return len(self.seqs)
    def __getitem__(self, idx):
        return torch.tensor(self.seqs[idx], dtype=torch.long), self.labels[idx]

In [7]:
def collate_fn(batch):
    seqs, labs = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
    return padded, torch.stack(labs)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
train_ds = ProteinDataset(df_train)
val_ds   = ProteinDataset(df_val)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, collate_fn=collate_fn)

In [8]:
class ECModel(nn.Module):
    def __init__(self, num_tokens, embed_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(num_tokens+1, embed_dim, padding_idx=pad_idx)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        # x: [B, L]
        mask = (x != pad_idx).float()
        emb = self.embed(x)                # [B, L, E]
        summed = (emb * mask.unsqueeze(-1)).sum(1)
        lens = mask.sum(1, keepdim=True)
        pooled = summed / lens.clamp(min=1)  # [B, E]; clamp avoids 0/0 for an all-padding row
        return self.classifier(pooled)

In [9]:
model = ECModel(num_tokens=unk_idx, embed_dim=128, num_classes=len(classes)).to(device)
opt = optim.AdamW(model.parameters(), lr=1e-3)
crit = nn.BCEWithLogitsLoss()

In [10]:
epochs = 10
for ep in range(1, epochs+1):
    model.train(); total_loss = 0
    for seqs, labs in train_loader:
        seqs, labs = seqs.to(device), labs.to(device)
        logits = model(seqs)
        loss = crit(logits, labs)
        opt.zero_grad(); loss.backward(); opt.step()
        total_loss += loss.item() * seqs.size(0)
    print(f"Epoch {ep} Train Loss: {total_loss/len(train_ds):.4f}")

    model.eval(); preds, targs = [], []
    with torch.no_grad():
        for seqs, labs in val_loader:
            seqs = seqs.to(device)
            logit = model(seqs)
            prob = torch.sigmoid(logit).cpu()
            preds.append(prob)
            targs.append(labs)
    preds = torch.vstack(preds).numpy() > 0.5
    targs = torch.vstack(targs).numpy() > 0.5
    print(
        "Val P, R, F1=",
        precision_score(targs, preds, average='micro', zero_division=0),
        recall_score(targs, preds, average='micro', zero_division=0),
        f1_score(targs, preds, average='micro', zero_division=0)
    )

Epoch 1 Train Loss: 0.0393
Val P, R, F1= 1.0 0.000676132521974307 0.0013513513513513514
Epoch 2 Train Loss: 0.0177
Val P, R, F1= 0.6 0.004056795131845842 0.008059100067159167
Epoch 3 Train Loss: 0.0172
Val P, R, F1= 0.5454545454545454 0.002028397565922921 0.0040417649040080834
Epoch 4 Train Loss: 0.0168
Val P, R, F1= 0.6 0.008113590263691683 0.016010673782521682
Epoch 5 Train Loss: 0.0164
Val P, R, F1= 0.6304347826086957 0.00980392156862745 0.019307589880159785
Epoch 6 Train Loss: 0.0160
Val P, R, F1= 0.5538461538461539 0.012170385395537525 0.023817399933840556
Epoch 7 Train Loss: 0.0158
Val P, R, F1= 0.6666666666666666 0.012846517917511832 0.025207296849087894
Epoch 8 Train Loss: 0.0155
Val P, R, F1= 0.625 0.015212981744421906 0.0297029702970297
Epoch 9 Train Loss: 0.0152
Val P, R, F1= 0.675 0.018255578093306288 0.03554970375246873
Epoch 10 Train Loss: 0.0150
Val P, R, F1= 0.6777777777777778 0.020622041920216362 0.04002624671916011


In [19]:
import numpy as np

def find_best_thresholds_per_ec(y_true, y_probs, metric='f1'):
    thresholds = np.linspace(0.1, 0.9, 81)
    best_thresholds = []

    for i in range(y_true.shape[1]):
        best_score = -1
        best_t = 0.5
        for t in thresholds:
            preds = (y_probs[:, i] >= t).astype(int)
            if metric == 'f1':
                score = f1_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'precision':
                score = precision_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'recall':
                score = recall_score(y_true[:, i], preds, zero_division=0)
            if score > best_score:
                best_score = score
                best_t = t
        best_thresholds.append(best_t)

    return np.array(best_thresholds)

In [20]:
model.eval()
all_true, all_probs = [], []
with torch.no_grad():
    for seqs, labs in val_loader:
        seqs = seqs.to(device)
        logits = model(seqs)
        probs  = torch.sigmoid(logits).cpu()  # probabilities
        all_true.append(labs)
        all_probs.append(probs)

y_val_true = torch.cat(all_true, dim=0).numpy()
y_val_probs = torch.cat(all_probs, dim=0).numpy()
best_thresholds = find_best_thresholds_per_ec(y_val_true, y_val_probs, metric='f1')

true = torch.cat(all_true, dim=0)
pred = (torch.cat(all_probs, dim=0).numpy() >= best_thresholds).astype(int)

rows = []
for pid, trow, prow in zip(val_ids, true, pred):
    true_ecs = [c for c, flag in zip(classes, trow) if flag]
    pred_ecs = [c for c, flag in zip(classes, prow) if flag]
    rows.append({
        'strain': pid,
        'validation EC class': ';'.join(true_ecs) or '-',
        'predicted': ';'.join(pred_ecs) or '-'
    })

results_df = pd.DataFrame(rows, columns=['strain', 'validation EC class', 'predicted'])
print(results_df)

        strain validation EC class                  predicted
0       1EF9-A             4.1.1.-                          -
1       4BYF-A             3.6.4.-                          -
2       1MVP-A            3.4.23.-                    2.7.7.-
3       2BIH-A             1.7.1.-                          -
4     6UE0-AAA     4.3.3.-;4.3.3.7                          -
...        ...                 ...                        ...
1724    2F9R-A             4.6.1.-                    3.2.1.-
1725    1QQW-A   1.11.1.-;1.11.1.6                          -
1726    6AHR-E   3.1.26.-;3.1.26.5                    2.1.1.-
1727    2BJI-A    3.1.3.-;3.1.3.25                    1.1.1.-
1728    3BAL-A           1.13.11.-  1.15.1.1;3.2.1.-;3.2.1.14

[1729 rows x 3 columns]


In [21]:
num_nothing_predicted = (pred.sum(axis=1) == 0).sum()
print(f"Examples with no EC predicted: {num_nothing_predicted}/{len(pred)}")

Examples with no EC predicted: 1013/1729


In [22]:
pip install transformers torch



In [None]:
pip install transformers torch scikit-learn




In [23]:
with open('validation.txt', 'r') as f:
    validation_ids = [line.strip() for line in f.readlines()]

validation_sequences = [db_seqs[protein_id] for protein_id in validation_ids if protein_id in db_seqs]
print(f"First 5 validation sequences: {validation_sequences[:5]}")


First 5 validation sequences: ['MSYQYVNVVTINKVAVIEFNYGRKLNALSKVFIDDLMQALSDLNRPEIRCIILRAPSGSKVFSAGHDIHELPSGGRDPLSYDDPLRQITRMIQKFPKPIISMVEGSVWGGAFEMIMSSDLIIAASTSTFSMTPVNLGVPYNLVGIHNLTRDAGFHIVKELIFTASPITAQRALAVGILNHVVEVEELEDFTLQMAHHISEKAPLAIAVIKEELRVLGEAHTMNSDEFERIQGMRRAVYDSEDYQEGMNAFLEKRKPNFVGH', 'MESALTARDRVGVQDFVLLENFTSEAAFIENLRRRFRENLIYTYIGPVLVSVNPYRDLQIYSRQHMERYRGVSFYEVPPHLFAVADTVYRALRTERRDQAVMISGESGAGKTEATKRLLQFYAETCPAPERGGAVRDRLLQSNPVLEAFGNAKTLRNDNSSRFGKYMDVQFDFKGAPVGGHILSYLLEKSRVVHQNHGERNFHIFYQLLEGGEEETLRRLGLERNPQSYLYLVKGQCAKVSSINDKSDWKVVRKALTVIDFTEDEVEDLLSIVASVLHLGNIHFAANEESNAQVTTENQLKYLTRLLSVEGSTLREALTHRKIIAKGEELLSPLNLEQAAYARDALAKAVYSRTFTWLVGKINRSLASKDVESPSWRSTTVLGLLDIYGFEVFQHNSFEQFCINYCNEKLQQLFIELTLKSEQEEYEAEGIAWEPVQYFNNKIICDLVEEKFKGIISILDEECLRPGEATDLTFLEKLEDTVKHHPHFLTHKLADQRTRKSLGRGEFRLLHYAGEVTYSVTGFLDKNNDLLFRNLKETMCSSKNPIMSQCFDRSELSDKKRPETVATQFKMSLLQLVEILQSKEPAYVRCIKPNDAKQPGRFDEVLIRHQVKYLGLLENLRVRRAGFAYRRKYEAFLQRYKSLCPETWPTWAGRPQDGVAVLVRHLGYKPEEYKMGRTKIFIRFPKTLFATEDALEVRRQSLA

In [24]:
pip install biopython




In [25]:
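import re
from transformers import BertTokenizer, BertModel
import torch
from Bio import SeqIO

# ProtBert expects upper-case residues, so do not lower-case the input.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# Use a separate name so the ECModel stored in `model` is not overwritten.
bert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

def get_embeddings_batch(sequences, batch_size=8):
    embeddings = []
    for i in range(0, len(sequences), batch_size):
        # Map rare residues (U, Z, O, B) to X and separate residues with spaces,
        # as the ProtBert tokenizer expects; an unspaced sequence is not split into residues.
        batch = [" ".join(re.sub(r"[UZOB]", "X", s)) for s in sequences[i:i + batch_size]]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = bert_model(**inputs)
        # Masked mean pooling: average over non-padding tokens only.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        batch_embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
        embeddings.append(batch_embeddings.cpu())
    return torch.cat(embeddings, dim=0)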

train_sequences = [db_seqs[protein_id] for protein_id in train_ids if protein_id in db_seqs]
validation_sequences = [db_seqs[protein_id] for protein_id in val_ids if protein_id in db_seqs]

train_embeddings = get_embeddings_batch(train_sequences, batch_size=8)
validation_embeddings = get_embeddings_batch(validation_sequences, batch_size=8)

print(f"Train embeddings shape: {train_embeddings.shape}")
print(f"Validation embeddings shape: {validation_embeddings.shape}")




Train embeddings shape: torch.Size([15551, 1024])
Validation embeddings shape: torch.Size([1729, 1024])


In [None]:
import torch

label_to_idx = {label: idx for idx, label in enumerate(classes)}

def encode_labels(labels, label_to_idx):
    encoded_labels = torch.zeros(len(label_to_idx), dtype=torch.float32)
    for label in labels:
        if label in label_to_idx:
            encoded_labels[label_to_idx[label]] = 1.0
    return encoded_labels

# Filter on db_seqs, exactly as the embedding cell does, so label rows stay aligned with embedding rows.
train_labels = [encode_labels(ann_map.get(protein_id, []), label_to_idx) for protein_id in train_ids if protein_id in db_seqs]
validation_labels = [encode_labels(ann_map.get(protein_id, []), label_to_idx) for protein_id in val_ids if protein_id in db_seqs]

train_labels = torch.stack(train_labels)
validation_labels = torch.stack(validation_labels)

print(f"Sample train labels: {train_labels[:5]}")
print(f"Sample validation labels: {validation_labels[:5]}")



In [None]:
print(f"Train labels (sample): {train_labels[:5]}")
print(f"Validation labels (sample): {validation_labels[:5]}")

train_non_zero = torch.sum(train_labels != 0).item()
validation_non_zero = torch.sum(validation_labels != 0).item()

print(f"Number of non-zero labels in train set: {train_non_zero}")
print(f"Number of non-zero labels in validation set: {validation_non_zero}")


In [None]:
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score

class MultiLabelClassifier(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MultiLabelClassifier, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        return self.sigmoid(x)

model = MultiLabelClassifier(input_dim=1024, output_dim=len(classes))
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(train_embeddings)
    loss = criterion(outputs, train_labels)
    loss.backward()
    optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

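model.eval()
with torch.no_grad():
    y_pred = model(validation_embeddings).numpy()
y_pred_bin = (y_pred > 0.5).astype(int)

# For multi-label targets, accuracy_score is subset accuracy: every label of a protein must match.
accuracy = accuracy_score(validation_labels.numpy(), y_pred_bin)
f1 = f1_score(validation_labels.numpy(), y_pred_bin, average='micro', zero_division=0)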

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Validation F1 Score: {f1:.4f}")


In [None]:
outputs = model(validation_embeddings).detach().numpy()
print(f"Sample validation outputs: {outputs[:5]}")


In [None]:
import matplotlib.pyplot as plt

plt.hist(train_labels.numpy().sum(axis=1), bins=50)
plt.title('Distribution of Non-Zero Labels in Training Set')
plt.xlabel('Number of Non-Zero Labels per Protein')
plt.ylabel('Frequency')
plt.show()
