In [None]:
!pip install torch pandas numpy h5py tqdm

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import h5py
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib
from src.dataset import ProteinDataset
from src.model import ChemicalShiftsPredictor
from src.utils import train_model, test_model

In [27]:
# Load and prepare data
csv_file = 'data/disorder/strict.csv'
prott5_file = 'data/disorder/embeddings/unfiltered_all_prott5.h5'
prott5_res_file = 'data/disorder/embeddings/unfiltered_all_prott5_res.h5'
prostt5_file = 'data/disorder/embeddings/prostt5.h5'
chemical_shifts_df = pd.read_csv(csv_file)
chemical_shifts_df.describe()

Unnamed: 0.1,Unnamed: 0,entryID,stID,entity_assemID,entityID,seq_index,k,zscores,pscores,C,CA,CB,HA,H,N,HB
count,217493.0,217493.0,217493.0,217493.0,217493.0,217493.0,217493.0,209016.0,209016.0,154402.0,203289.0,184908.0,151129.0,192492.0,192056.0,133093.0
mean,950.551788,25516.888194,1.069805,1.086191,1.028737,72.199183,16.539553,11.35494,0.16265,175.888378,56.793136,38.003339,4.343245,8.27458,119.59808,2.311165
std,551.256443,10008.917334,0.445658,0.467657,0.175259,58.331021,4.661,4.08679,0.229569,5.180103,4.960242,12.878662,0.870714,0.6885,5.324304,2.330996
min,0.0,7349.0,1.0,1.0,1.0,1.0,0.0,-5.0427,-0.0,0.0,4.186,-34.477,-2.4295,-0.914,0.0,-1.626
25%,470.0,17602.0,1.0,1.0,1.0,29.0,15.0,10.6949,0.046,174.517,54.36,30.481,4.019,7.882,116.803,1.7215
50%,943.0,25387.0,1.0,1.0,1.0,60.0,18.0,12.6262,0.0766,175.996,56.971,34.6515,4.299,8.273,120.034,2.042
75%,1431.0,30293.0,1.0,1.0,1.0,99.0,21.0,13.9357,0.1363,177.528,59.859,41.354,4.619,8.654,122.818,2.8125
max,1909.0,51871.0,8.0,8.0,4.0,471.0,21.0,16.1497,0.9995,187.57,176.857,137.868,173.538,119.596,181.029,790.2155


In [28]:
#target_columns = ['C', 'CA', 'CB', 'HA', 'H', 'N', 'HB']
target_columns = ['N']
print(len(chemical_shifts_df))
chemical_shifts_df.dropna(inplace=True, subset=target_columns)
print(len(chemical_shifts_df))

# Split data into train, validation, and test sets
train_df, test_df = train_test_split(chemical_shifts_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.25, random_state=42)  # 0.25 x 0.8 = 0.2

# Normalize targets based on training data statistics
scaler = StandardScaler()
train_targets = train_df[target_columns]
scaler.fit(train_targets)

# Save the mean and std for later un-normalizing
means = scaler.mean_
stds = scaler.scale_

# Save the scaler object
joblib.dump(scaler, 'scaler.joblib')

# Apply normalization to the training targets
#train_df[target_columns] = scaler.transform(train_targets)

# Apply the same normalization to validation and test sets
# val_df[target_columns] = scaler.transform(val_df[target_columns])
# test_df[target_columns] = scaler.transform(test_df[target_columns])

# Create datasets
train_dataset = ProteinDataset(target_columns, train_df, prott5_file, prott5_res_file, prostt5_file)
val_dataset = ProteinDataset(target_columns, val_df, prott5_file, prott5_res_file, prostt5_file)
test_dataset = ProteinDataset(target_columns, test_df, prott5_file, prott5_res_file, prostt5_file)

217493
192056


In [29]:
print('Trainng dataset length:', len(train_dataset))
print('Validation dataset length:', len(val_dataset))
print('Test dataset length:', len(test_dataset))

Trainng dataset length: 115233
Validation dataset length: 38411
Test dataset length: 38412


In [30]:
# Assuming you have defined model, train_loader, val_loader, test_loader
learning_rate = 0.001
weight_decay = 1e-5
patience = 10
batch_size = 1024
num_epochs = 5

trained_model = train_model(train_dataset, val_dataset, learning_rate=learning_rate, num_epochs=num_epochs, weight_decay=weight_decay, patience=patience, batch_size=batch_size, use_prostt5=False, use_protein_mean=False)
test_model(trained_model, test_dataset, batch_size=batch_size, use_prostt5=False, use_protein_mean=False)

Epoch 1/5: 100%|██████████| 113/113 [01:12<00:00,  1.55batch/s]


Epoch 1, Train Loss: 3568.5859, Validation Loss: 252.9527


Epoch 2/5: 100%|██████████| 113/113 [01:13<00:00,  1.55batch/s]


Epoch 2, Train Loss: 160.4907, Validation Loss: 108.3401


Epoch 3/5: 100%|██████████| 113/113 [01:10<00:00,  1.60batch/s]


Epoch 3, Train Loss: 77.9913, Validation Loss: 58.1687


Epoch 4/5: 100%|██████████| 113/113 [01:10<00:00,  1.60batch/s]


Epoch 4, Train Loss: 45.2795, Validation Loss: 37.5724


Epoch 5/5: 100%|██████████| 113/113 [01:11<00:00,  1.58batch/s]


Epoch 5, Train Loss: 30.6636, Validation Loss: 27.9441
Test Loss: 28.4074


In [31]:
trained_model = train_model(train_dataset, val_dataset, learning_rate=learning_rate, num_epochs=num_epochs, weight_decay=weight_decay, patience=patience, batch_size=batch_size, use_prostt5=True, use_protein_mean=False)
test_model(trained_model, test_dataset, batch_size=batch_size, use_prostt5=True, use_protein_mean=False)

Epoch 1/5:   0%|          | 0/113 [00:00<?, ?batch/s]

Epoch 1/5: 100%|██████████| 113/113 [01:09<00:00,  1.63batch/s]


Epoch 1, Train Loss: 2809.6434, Validation Loss: 133.9024


Epoch 2/5: 100%|██████████| 113/113 [01:07<00:00,  1.67batch/s]


Epoch 2, Train Loss: 75.2289, Validation Loss: 49.4624


Epoch 3/5: 100%|██████████| 113/113 [01:08<00:00,  1.66batch/s]


Epoch 3, Train Loss: 36.5987, Validation Loss: 31.3096


Epoch 4/5: 100%|██████████| 113/113 [01:09<00:00,  1.62batch/s]


Epoch 4, Train Loss: 25.2554, Validation Loss: 24.5307


Epoch 5/5: 100%|██████████| 113/113 [01:11<00:00,  1.59batch/s]


Epoch 5, Train Loss: 19.9593, Validation Loss: 20.8266
Test Loss: 21.2452


In [32]:
trained_model = train_model(train_dataset, val_dataset, learning_rate=learning_rate, num_epochs=50, weight_decay=weight_decay, patience=patience, batch_size=batch_size, use_prostt5=True, use_protein_mean=True)
test_model(trained_model, test_dataset, batch_size=batch_size, use_prostt5=True, use_protein_mean=True)

Epoch 1/5:   0%|          | 0/113 [00:00<?, ?batch/s]

Epoch 1/5: 100%|██████████| 113/113 [01:09<00:00,  1.62batch/s]


Epoch 1, Train Loss: 2277.2400, Validation Loss: 46.7049


Epoch 2/5: 100%|██████████| 113/113 [01:13<00:00,  1.54batch/s]


Epoch 2, Train Loss: 24.2542, Validation Loss: 16.0128


Epoch 3/5: 100%|██████████| 113/113 [01:11<00:00,  1.58batch/s]


Epoch 3, Train Loss: 13.1453, Validation Loss: 12.5087


Epoch 4/5: 100%|██████████| 113/113 [01:10<00:00,  1.59batch/s]


Epoch 4, Train Loss: 10.6051, Validation Loss: 11.2528


Epoch 5/5: 100%|██████████| 113/113 [01:09<00:00,  1.63batch/s]


Epoch 5, Train Loss: 9.2809, Validation Loss: 10.4593
Test Loss: 10.8457
