In [None]:
!pip install torch pandas numpy h5py tqdm

In [None]:
%load_ext autoreload
%autoreload 2

import joblib
import pandas as pd
import torch
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.preprocessing import StandardScaler

from src.dataset import ProteinDataset
from src.utils import test_model, train_model

In [None]:
# Load and prepare data
# Inputs: the chemical-shift CSV plus the HDF5 embedding files handed to ProteinDataset.
csv_file = 'data/disorder/strict.csv'
(
    prott5_file,
    prott5_res_file,
    prostt5_file,
    esm_file,
    esm_res_file,
) = (
    'data/disorder/embeddings/unfiltered_all_prott5.h5',
    'data/disorder/embeddings/unfiltered_all_prott5_res.h5',
    'data/disorder/embeddings/prostt5.h5',
    'data/disorder/embeddings/unfiltered_all_esm2_3b.h5',
    'data/disorder/embeddings/unfiltered_all_esm2_3b_res.h5',
)

chemical_shifts_df = pd.read_csv(csv_file)
# Summary statistics of the raw table (last expression -> rich display).
chemical_shifts_df.describe()

In [7]:
# Alternative full target set: ['C', 'CA', 'CB', 'HA', 'H', 'N', 'HB']
target_columns = ['N']

# --- Split / preprocessing configuration ---------------------------------
SEED = 42
TEST_FRACTION = 0.2
VAL_FRACTION_OF_TRAIN = 0.25  # 0.25 x 0.8 = 0.2 of all labelled rows
# NOTE(review): rows look per-residue while the features include per-protein
# embeddings, so a row-wise split can put residues of the same protein in both
# train and test (leakage). Set this to the protein-identifier column (its name
# is not visible here) to keep each protein in a single split. None keeps the
# original row-wise split.
SPLIT_GROUP_COLUMN = None
# The original run left targets un-normalized; set True to train on z-scored
# targets (predictions must then be un-normalized with `scaler`).
NORMALIZE_TARGETS = False


def split_rows(df, test_size, seed, group_column=None):
    """Split ``df`` into ``(rest, held_out)`` frames.

    With ``group_column=None`` this is exactly ``train_test_split`` (row-wise).
    Otherwise ``GroupShuffleSplit`` keeps every group on one side; ``test_size``
    is then a fraction of *groups*, not rows.
    """
    if group_column is None:
        return train_test_split(df, test_size=test_size, random_state=seed)
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    rest_idx, held_idx = next(splitter.split(df, groups=df[group_column]))
    return df.iloc[rest_idx], df.iloc[held_idx]


# Keep only rows that have every target; build a new frame instead of mutating
# the raw `chemical_shifts_df`, so re-running this cell is side-effect free.
n_rows_raw = len(chemical_shifts_df)
labelled_df = chemical_shifts_df.dropna(subset=target_columns)
print(f'Rows in CSV: {n_rows_raw}; rows with {target_columns} present: {len(labelled_df)}')

# Split data into train, validation, and test sets (60 / 20 / 20).
train_df, test_df = split_rows(labelled_df, TEST_FRACTION, SEED, SPLIT_GROUP_COLUMN)
train_df, val_df = split_rows(train_df, VAL_FRACTION_OF_TRAIN, SEED, SPLIT_GROUP_COLUMN)

# Fit target statistics on the training split only (no val/test leakage).
scaler = StandardScaler().fit(train_df[target_columns])
means = scaler.mean_
stds = scaler.scale_
# NOTE(review): the scaler is saved even when NORMALIZE_TARGETS is False; any
# code that loads it to un-normalize predictions would then be wrong - verify
# against src.utils.
joblib.dump(scaler, 'scaler.joblib')


def normalize_targets(frame):
    """Return a copy of ``frame`` with ``target_columns`` z-scored by ``scaler``."""
    normalized = frame.copy()  # explicit copy avoids SettingWithCopy on split slices
    normalized[target_columns] = scaler.transform(frame[target_columns])
    return normalized


if NORMALIZE_TARGETS:
    train_df, val_df, test_df = (normalize_targets(f) for f in (train_df, val_df, test_df))

# Create datasets. The embedding-file order matches the original calls
# (note esm_res_file is passed before esm_file).
embedding_files = (prott5_file, prott5_res_file, prostt5_file, esm_res_file, esm_file)
train_dataset, val_dataset, test_dataset = (
    ProteinDataset(target_columns, frame, *embedding_files)
    for frame in (train_df, val_df, test_df)
)

In [8]:
# Report the size of each split ("Trainng" typo in the original label fixed).
for split_name, split_dataset in (
    ('Training', train_dataset),
    ('Validation', val_dataset),
    ('Test', test_dataset),
):
    print(f'{split_name} dataset length:', len(split_dataset))

Trainng dataset length: 115233
Validation dataset length: 38411
Test dataset length: 38412


In [12]:
# Training hyperparameters shared by the experiment cells below.
batch_size = 128
num_epochs = 5          # some later cells pass 50 explicitly instead of this
learning_rate = 1e-3
weight_decay = 1e-5     # forwarded to train_model; presumably the optimizer's weight decay
patience = 10           # forwarded to train_model; presumably early-stopping patience - confirm in src.utils

In [13]:
# Run: ProstT5 features with `use_protein_mean` enabled, configured epochs.
# One flag dict feeds both calls so evaluation always matches training.
feature_flags = dict(use_prostt5=True, use_protein_mean=True)
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=batch_size,
    **feature_flags,
)
test_model(trained_model, test_dataset, batch_size=batch_size, **feature_flags)

ChemicalShiftsPredictor(
  (light_attention): LightAttention(
    (feature_convolution): Conv1d(1024, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
    (attention_convolution): Conv1d(1024, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
    (softmax): Softmax(dim=-1)
    (dropout): Dropout(p=0.25, inplace=False)
  )
  (fc_layers): Sequential(
    (0): Linear(in_features=5120, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)


  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
  torch.tensor(protein_prott5_emb, dtype=torch.float32).squeeze(),
Epoch 1/5:   1%|          | 39/3602 [00:06<10:16,  5.78batch/s] 


KeyboardInterrupt: 

In [None]:
# Ablation: same as the previous run but with `use_protein_mean` disabled.
# One flag dict feeds both calls so evaluation always matches training.
feature_flags = dict(use_prostt5=True, use_protein_mean=False)
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=batch_size,
    **feature_flags,
)
test_model(trained_model, test_dataset, batch_size=batch_size, **feature_flags)

In [None]:
# Longer run: ProstT5 + `use_protein_mean`, 50 epochs instead of `num_epochs`.
long_run_epochs = 50
feature_flags = dict(use_prostt5=True, use_protein_mean=True)
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=long_run_epochs,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=batch_size,
    **feature_flags,
)
test_model(trained_model, test_dataset, batch_size=batch_size, **feature_flags)

In [None]:
# Large-batch run without ESM-2 features.
# empty_cache() only releases cached blocks that nothing references, so drop
# the previous run's model first; otherwise (if it lives on the GPU) its memory
# stays allocated for the whole of the next training run.
trained_model = None
torch.cuda.empty_cache()

large_batch_size = 2048  # training only; evaluation keeps the shared `batch_size`
long_run_epochs = 50
# One flag dict feeds both calls so evaluation always matches training.
feature_flags = dict(use_prostt5=True, use_protein_mean=True, use_esm2=False)
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=long_run_epochs,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=large_batch_size,
    **feature_flags,
)
test_model(trained_model, test_dataset, batch_size=batch_size, **feature_flags)