In [1]:
%load_ext autoreload
%autoreload 2
# %cd 'python/AggNet'
%ls

'AggNet: Advancing protein aggregation analysis through deep learning and protein language model - He.html'
 APNet_CPAD2_20aa_results.csv
 APNet_CPAD2_6aa_results.csv
 APNet_CPAD2_results.csv
 APNet_results.csv
 APRNet_results.csv
 [0m[01;34mcheckpoint[0m/
 CPAD2_combined_predictions.csv
 CPAD2.ipynb
 CPAD2_length_comparison.png
 CPAD2_length_stratified_metrics.csv
 CPAD2_overall_comparison.png
 CPAD2_overall_metrics.csv
 CPAD2_roc_curves.png
 [01;34mdata[0m/
 example.ipynb
 [01;34mlightning_logs[0m/
 [01;34mlightning_logs_og[0m/
 [01;34mmodel[0m/
 NNK_TRAINING_GUIDE.md
 ProjectStructure.md
 README.md
 [01;34mscript[0m/
 [01;34mutils[0m/


In [2]:
import pandas as pd
import torch
from pathlib import Path

from model.APNet.data_module import DataModule
from model.APNet.lightning_module import LightningModule
from model.APRNet import APRNet
from utils.dataset import ProteinDataset
from utils.file import read_fasta
from utils.lightning import LitModelInference, merge_batch_prediction
from utils.metric.bin_cls import compute_metrics

root_path: /nobackup/autodelete/usr/ssgardin/AggNet


In [3]:
# Utility helpers for model comparison
def evaluate_model_on_df(model, df, sample=None, batch_size=256, seed=7,
                         threshold=0.5, dataset_name="thompson_eval", num_workers=1):
    """Score a labelled DataFrame with an APNet inference wrapper.

    Args:
        model: inference wrapper exposing ``set_batch_size`` and ``predict``
            (e.g. ``LitModelInference``). NOTE: ``set_batch_size`` is called on
            it, so the new batch size persists for later calls on the same model.
        df: metadata frame handed to ``ProteinDataset`` (presumably needs the
            ``sequence``/``label`` columns it reads -- confirm against
            ``ProteinDataset``). ``df`` itself is not modified.
        sample: if set and smaller than ``len(df)``, evaluate a random subset of
            this many rows drawn with ``seed`` (same seed -> same rows for every
            model, which keeps cross-model comparisons row-aligned).
        batch_size: prediction batch size.
        seed: ``random_state`` used for the subsample.
        threshold: probability cutoff for ``pred_label`` (default 0.5).
        dataset_name: ``name`` passed to ``ProteinDataset``.
        num_workers: data-loader workers passed to ``set_batch_size``.

    Returns:
        ``(metrics, eval_df)``: the result of
        ``compute_metrics(..., softmax=True, only_df=True)`` and the evaluated
        rows (fresh 0..n-1 index) with ``pred_prob`` (softmax probability of
        class 1) and ``pred_label`` (1 if ``pred_prob >= threshold``) added.

    Raises:
        ValueError: if there are no rows to evaluate, or the number of
            predictions differs from the number of rows.
    """
    eval_df = df
    if sample is not None and len(eval_df) > sample:
        eval_df = eval_df.sample(n=sample, random_state=seed)
    # A single reset shared by the dataset and the returned table keeps the
    # predictions row-aligned with their inputs.
    eval_df = eval_df.reset_index(drop=True)
    if eval_df.empty:
        raise ValueError("evaluate_model_on_df: no rows to evaluate")

    dataset = ProteinDataset(name=dataset_name, metadata=eval_df)
    model.set_batch_size(batch_size=batch_size, num_workers=num_workers)
    merged = merge_batch_prediction(model.predict(dataset=dataset))
    logits = merged['preds'].cpu()
    labels = merged['labels'].cpu()
    metrics = compute_metrics(logits, labels, softmax=True, only_df=True)

    probs = logits.softmax(dim=-1)[:, 1].numpy()
    if len(probs) != len(eval_df):
        raise ValueError(
            f"evaluate_model_on_df: expected {len(eval_df)} predictions, got {len(probs)}"
        )
    # assign() returns a new frame, so the metadata given to the dataset is not mutated.
    return metrics, eval_df.assign(
        pred_prob=probs,
        pred_label=lambda d: (d['pred_prob'] >= threshold).astype(int),
    )


def compare_peptide_predictions(models, peptides, threshold=0.5):
    """Predict amyloid probability for each peptide with every model, side by side.

    Args:
        models: mapping of display name -> inference wrapper whose
            ``predict(sequence=...)`` output is accepted by ``merge_batch_prediction``.
        peptides: sequence of peptide strings, passed unchanged to each model.
        threshold: probability cutoff for the ``'amyloid'`` label (default 0.5).

    Returns:
        DataFrame with one row per input peptide (input order, duplicates kept):
        a ``peptide`` column, then one ``<name>_prob`` column per model (softmax
        probability of class 1), then one ``<name>_label`` column per model
        (``'amyloid'`` / ``'non-amyloid'``).
    """
    comparison = pd.DataFrame({'peptide': list(peptides)})
    for model_name, model in models.items():
        merged = merge_batch_prediction(model.predict(sequence=peptides))
        probs = merged['preds'].cpu().softmax(dim=-1).numpy()[:, 1]
        # Each model's probabilities are paired with `peptides` by position, so
        # columns are attached positionally. Joining on the peptide string
        # would duplicate rows whenever a peptide occurs more than once.
        comparison[f'{model_name}_prob'] = probs
    for model_name in models:
        is_amyloid = comparison[f'{model_name}_prob'] >= threshold
        comparison[f'{model_name}_label'] = is_amyloid.map({True: 'amyloid', False: 'non-amyloid'})
    return comparison


def run_aprnet_profile(aprnet, sequence, params, structure=None):
    """Run an APRNet on a single sequence and return its per-residue outputs.

    Args:
        aprnet: callable taking ``(sequences, structures, **params)`` and
            returning ``(labels, scores)`` with one entry per input sequence.
        sequence: the protein sequence string.
        params: keyword arguments forwarded to ``aprnet``.
        structure: optional structure for ``sequence``. It is wrapped in a
            one-element list when given; ``None`` (the default, and the previous
            behavior) passes no structure.

    Returns:
        ``(labels, scores)`` for ``sequence``.
    """
    structures = None if structure is None else [structure]
    labels, scores = aprnet([sequence], structures, **params)
    return labels[0], scores[0]

In [4]:
# Common paths, relative to the notebook's working directory.
data_dir = Path('./data')
amyhex_fasta = data_dir / 'AmyHex' / 'Hex142.fasta'
nnk_metadata_path = data_dir / 'NNK' / 'metadata.csv'

checkpoint_original = Path('./checkpoint') / 'APNet.ckpt'
# NOTE: "thompon" (sic) is the checkpoint's filename as saved -- do not "fix" it here.
checkpoint_thompson = Path('./checkpoint') / 'thompon_dataset_20aa_epoch14.ckpt'

default_batch_size = 256

# Amyloid Peptide Prediction

## Load AmyHex hexapeptides

In [5]:
# Load the AmyHex benchmark peptides. read_fasta returns a pair; only the first
# element (the sequences) is kept.
fasta_file = str(amyhex_fasta)
peptides, _ = read_fasta(fasta_file)
print("Loaded {} hexapeptides from AmyHex".format(len(peptides)))

Loaded 142 hexapeptides from AmyHex


## Thompson dataset (NNK)

In [6]:
# NNK (Thompson) metadata: keep the held-out splits, then derive two
# length-capped subsets. In the run recorded here every held-out sequence is
# <=20 aa, so the 20 aa subset equals eval_df.
thompson_df = pd.read_csv(nnk_metadata_path)
eval_df = thompson_df.loc[thompson_df['split'].isin(['valid', 'test'])].copy()
thompson_6aa, thompson_20aa = (
    eval_df.loc[eval_df['length'].le(max_len)].copy() for max_len in (6, 20)
)

print("\n".join([
    f"Thompson eval set: {len(eval_df)} sequences",
    f"  ≤6 aa subset: {len(thompson_6aa)} sequences",
    f"  ≤20 aa subset: {len(thompson_20aa)} sequences",
]))

Thompson eval set: 12192 sequences
  ≤6 aa subset: 1542 sequences
  ≤20 aa subset: 12192 sequences


## Load APNet checkpoints for comparison

In [7]:
batch_size = default_batch_size
checkpoint_original_path = str(checkpoint_original)
checkpoint_thompson_path = str(checkpoint_thompson)


def _load_apnet(ckpt_path):
    """Load an APNet checkpoint for inference and apply the shared batch settings."""
    lit_model = LitModelInference(LightningModule, DataModule, ckpt_path)
    lit_model.set_batch_size(batch_size=batch_size, num_workers=1)
    return lit_model


# Original APNet checkpoint.
APNet_original = _load_apnet(checkpoint_original_path)
# APNet trained on the Thompson dataset with a 20 AA window.
APNet_thompson = _load_apnet(checkpoint_thompson_path)

models = dict(original=APNet_original, thompson20aa=APNet_thompson)

[loading checkpoint]: checkpoint/APNet.ckpt


/home/ssgardin/.conda/envs/agnet/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ssgardin/.conda/envs/agnet/lib/python3.13/site ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
Seed set to 42


[loading checkpoint]: checkpoint/thompon_dataset_20aa_epoch14.ckpt


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
Seed set to 42


In [8]:
# Sanity check: the device holding each model's first parameter.
{
    model_name: str(next(iter(lit_model.ckpt_model.parameters())).device)
    for model_name, lit_model in models.items()
}

{'original': 'cuda:0', 'thompson20aa': 'cuda:0'}

## Run AmyHex predictions with both checkpoints

In [9]:
# Side-by-side AmyHex predictions from both checkpoints (first rows shown).
amyhex_comparison = compare_peptide_predictions(models=models, peptides=peptides)
amyhex_comparison.head(5)

[prepare custom predict dataset] 142
[self.predict_dataset] 142


  0%|          | 0/1 [00:00<?, ?it/s]

[prepare custom predict dataset] 142
[self.predict_dataset] 142


  0%|          | 0/1 [00:00<?, ?it/s]



Unnamed: 0,peptide,original_prob,thompson20aa_prob,original_label,thompson20aa_label
0,YVEYIG,0.282254,0.006226,non-amyloid,non-amyloid
1,IQIVYK,0.965807,0.050248,amyloid,non-amyloid
2,GYVIIK,0.977053,0.011451,amyloid,non-amyloid
3,STVIIL,0.961463,0.186739,amyloid,non-amyloid
4,SGVIIE,0.697697,0.050646,amyloid,non-amyloid


In [10]:
# Count predicted labels per model. A label that some model never predicts is
# NaN after column alignment; fillna(0) alone would leave every count as a
# float, so cast back to int.
label_counts = {
    model_name: amyhex_comparison[f'{model_name}_label'].value_counts().rename(model_name)
    for model_name in models
}
pd.DataFrame(label_counts).fillna(0).astype(int)

Unnamed: 0,original,thompson20aa
non-amyloid,94,139
amyloid,48,3


## Thompson dataset evaluation (≤6 aa vs ≤20 aa subsets)

In [11]:
# Evaluate every checkpoint on each Thompson subset. Subsets larger than
# `subset_sample` are subsampled; evaluate_model_on_df uses the same default
# seed for every call, so both models are scored on identical rows.
subset_sample = 4000  # set to None to use the full subset
metrics_rows = []
thompson_predictions = {}

for subset_name, subset_df in {'6aa': thompson_6aa, '20aa': thompson_20aa}.items():
    subset_preds = thompson_predictions.setdefault(subset_name, {})
    for model_name, model in models.items():
        subset_metrics, subset_preds[model_name] = evaluate_model_on_df(
            model, subset_df, sample=subset_sample, batch_size=batch_size
        )
        metrics_rows.append(subset_metrics.assign(model=model_name, subset=subset_name))

metrics_table = pd.concat(metrics_rows, ignore_index=True)
metrics_table.loc[:, ['subset', 'model', 'ACC', 'AUC', 'MCC', 'F1', 'SE', 'SP']]

[prepare custom predict dataset] 1542
[self.predict_dataset] 1542


  0%|          | 0/7 [00:00<?, ?it/s]

[prepare custom predict dataset] 1542
[self.predict_dataset] 1542


  0%|          | 0/7 [00:00<?, ?it/s]

[prepare custom predict dataset] 4000
[self.predict_dataset] 4000


  0%|          | 0/16 [00:00<?, ?it/s]

[prepare custom predict dataset] 4000
[self.predict_dataset] 4000


  0%|          | 0/16 [00:00<?, ?it/s]



Unnamed: 0,subset,model,ACC,AUC,MCC,F1,SE,SP
0,6aa,original,0.685,0.486,-0.031,0.157,0.187,0.778
1,6aa,thompson20aa,0.846,0.62,0.117,0.063,0.033,0.996
2,20aa,original,0.644,0.554,0.057,0.291,0.267,0.787
3,20aa,thompson20aa,0.788,0.761,0.401,0.435,0.298,0.973


In [12]:
def merge_subset_predictions(subset_name, predictions=None):
    """Join both checkpoints' per-sequence probabilities for one Thompson subset.

    Args:
        subset_name: key into the predictions mapping (e.g. ``'6aa'``, ``'20aa'``).
        predictions: ``{subset: {model_name: table}}``, where each table has
            ``sequence``, ``label`` and ``pred_prob`` columns and models
            ``'original'`` and ``'thompson20aa'`` are present. Defaults to the
            notebook-level ``thompson_predictions``.

    Returns:
        DataFrame with ``sequence``, ``label``, ``original_prob`` and
        ``thompson20aa_prob`` columns (inner join, left row order kept).
        Repeated (sequence, label) rows are paired one-to-one instead of being
        multiplied.
    """
    if predictions is None:
        predictions = thompson_predictions
    key_cols = ['sequence', 'label']

    def _prepare(model_name, prob_col):
        # Rank repeated (sequence, label) pairs so duplicates pair up 1:1 in the
        # merge rather than producing a cartesian product.
        table = predictions[subset_name][model_name][key_cols + ['pred_prob']]
        table = table.rename(columns={'pred_prob': prob_col})
        return table.assign(_dup_rank=table.groupby(key_cols, dropna=False).cumcount())

    merged = _prepare('original', 'original_prob').merge(
        _prepare('thompson20aa', 'thompson20aa_prob'), on=key_cols + ['_dup_rank']
    )
    return merged.drop(columns='_dup_rank')

# Per-sequence probabilities from both checkpoints on the ≤6 aa subset (first rows).
merge_subset_predictions('6aa').head()

Unnamed: 0,sequence,label,original_prob,thompson20aa_prob
0,LVF,1,0.980386,0.024242
1,FI,0,0.840692,0.065336
2,SYFC,1,0.733882,0.059166
3,RI,0,0.077235,0.096387
4,RYPDRS,0,2e-06,0.032377


In [13]:
# Per-sequence probabilities from both checkpoints on the ≤20 aa subset (first rows).
merge_subset_predictions('20aa').head()

Unnamed: 0,sequence,label,original_prob,thompson20aa_prob
0,NFPRRSRVLKYCITLSSNHS,0,0.005828,0.114493
1,RYPLVKFQCYTGNYGKFGNL,1,0.071707,0.361847
2,YNNVCSEAEWS,0,0.449845,0.051732
3,NGGINNSWWGLIFLNLWPTI,0,0.109334,0.37226
4,SGFCQGGNSSRYMSII,1,0.547547,0.033944


# Protein Aggregation Profile comparison (APRNet)

In [14]:
# APRNet hyperparameters. The structure-aware set adds `beta` and `delta`.
# The set actually used is chosen by whether a structure is supplied.
# NOTE(review): the origin of these values is not recorded here -- TODO cite.
APRNet_struct_params = dict(
    beta=3.36,
    delta=0.4,
    t_start=0.51,
    t_expand=0.37,
    t_patience=9,
)
APRNet_seq_params = dict(
    t_start=0.46,
    t_expand=0.37,
    t_patience=7,
)

## Load input sequence

In [15]:
# WFL VH
sequence = 'QVQLVQSGAEVKKPGSSVKVSCKASGGTFWFGAFTWVRQAPGQGLEWMGGIIPIFGLTNLAQNFQGRVTITADESTSTVYMELSSLRSEDTAVYYCARSSRIYDLNPSLTAYYDMDVWGQGTMVTVSS'
structure = None
checkpoint = './checkpoint/APNet.ckpt'

## Build APRNet variants

In [16]:
# Pick the parameter set matching the available inputs; APRNet takes a list of structures.
params = APRNet_seq_params if structure is None else APRNet_struct_params
structure_input = [structure] if structure is not None else None

# APRNet wrappers: the original checkpoint with pep_len=6, and the Thompson
# checkpoint with pep_len=6 and pep_len=20.
aprnet_original_6, aprnet_thompson_6, aprnet_thompson_20 = (
    APRNet.APRNet(base_model, pep_len=window, log=False)
    for base_model, window in ((APNet_original, 6), (APNet_thompson, 6), (APNet_thompson, 20))
)

aprnet_variants = dict(
    original_pep6=aprnet_original_6,
    thompson_pep6=aprnet_thompson_6,
    thompson_pep20=aprnet_thompson_20,
)
aprnet_variants

{'original_pep6': <model.APRNet.APRNet.APRNet at 0x7fcf33e4ae40>,
 'thompson_pep6': <model.APRNet.APRNet.APRNet at 0x7fcf35b8b390>,
 'thompson_pep20': <model.APRNet.APRNet.APRNet at 0x7fcf35b8b250>}

## Run APR profile prediction

In [17]:
# Profile `sequence` with every APRNet variant. Each APRNet is called directly
# with `structure_input` (prepared alongside `params` in the setup cell).
# Structure-aware params are only selected when a structure exists, so that
# structure must be passed along with them rather than dropped as None.
apr_results = {}
for name, apr in aprnet_variants.items():
    batch_labels, batch_scores = apr([sequence], structure_input, **params)
    labels, scores = batch_labels[0], batch_scores[0]
    apr_results[name] = {'labels': labels, 'scores': scores}

[prepare custom predict dataset] 123
[self.predict_dataset] 123


  0%|          | 0/1 [00:00<?, ?it/s]

[prepare custom predict dataset] 123
[self.predict_dataset] 123


  0%|          | 0/1 [00:00<?, ?it/s]

[prepare custom predict dataset] 109
[self.predict_dataset] 109


  0%|          | 0/1 [00:00<?, ?it/s]



## Merge per-residue profiles

In [18]:
# One row per residue; each APRNet variant contributes a score column and an APR-label column.
profile_df = pd.DataFrame({'residue': list(sequence)}).assign(**{
    column: values
    for variant_name, variant_result in apr_results.items()
    for column, values in (
        (f'{variant_name}_score', variant_result['scores']),
        (f'{variant_name}_APR', variant_result['labels']),
    )
})

profile_df.head()

Unnamed: 0,residue,original_pep6_score,original_pep6_APR,thompson_pep6_score,thompson_pep6_APR,thompson_pep20_score,thompson_pep20_APR
0,Q,0.078544,0,0.001447,0,0.003865,0
1,V,0.149857,0,0.013708,0,0.008118,0
2,Q,0.154817,0,0.02188,0,0.008118,0
3,L,0.319754,0,0.057259,0,0.016237,0
4,V,0.322026,0,0.07993,0,0.016237,0


In [None]:
# Plot per-residue APR score profiles for each APRNet variant.
import matplotlib.pyplot as plt

# Use numeric positions on the x axis. Passing the residue letters as x makes
# matplotlib treat them as categories: every repeated residue (e.g. each 'Q')
# maps to the same x position, so the profile lines fold back on themselves.
positions = profile_df.index + 1  # 1-based residue numbering
fig, ax = plt.subplots(figsize=(12, 6))
for name in aprnet_variants:
    ax.plot(positions, profile_df[f'{name}_score'], label=f'{name} Score')
ax.set_xlabel('Residue Position')
ax.set_ylabel('APR Score')
ax.set_title('APR Score Profiles by APRNet Variant')
ax.margins(x=0)
ax.legend()
plt.show()