In [11]:
import torch 
import torch.nn as nn
import os 
import sys
# Get the absolute path of the project root
project_root = os.path.abspath("..")  # Adjust if needed

# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

from proteinshake.datasets import ProteinLigandInterfaceDataset, AlphaFoldDataset, GeneOntologyDataset, ProteinFamilyDataset
from src.utils import data_utils as dtu
from torch.utils.data import DataLoader, Dataset, Subset
from src.models.LSTMVae import LSTMVae
# from src.models.basicVae_pyt import BasicVae
import numpy as np
from src.dataset_classes.sequenceDataset import *
from sklearn.model_selection import KFold
import random
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [16]:
dataset = ProteinFamilyDataset(root='../data').to_point().torch()
indices = list(range(100))  # Example list of indices
subset_size = 10  # Size of the random subset

random_subset = random.sample(indices, subset_size)  # Get random subset
remaining_subset = list(set(indices) - set(random_subset))  # Get remaining indices

print("Random Subset:", random_subset)
print("Remaining Subset:", remaining_subset)

idx_list = range(len(dataset))
subset_size = int(len(dataset)//10)
val_idx = random.sample(idx_list, subset_size)  # Get random subset
train_idx = list(set(idx_list) - set(val_idx))

Random Subset: [17, 71, 48, 98, 46, 79, 4, 66, 61, 35]
Remaining Subset: [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 63, 64, 65, 67, 68, 69, 70, 72, 73, 74, 75, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 99]


In [17]:
s = 500
train_subset = SequenceDataset(Subset(dataset, train_idx), s)
val_subset = SequenceDataset(Subset(dataset, val_idx), s)

100%|██████████| 27999/27999 [00:09<00:00, 2911.92it/s]
100%|██████████| 3110/3110 [00:01<00:00, 2382.02it/s]


## Data Loader

In [18]:
latent_dim = 64
epochs = 100
lr = 0.001
batch_size = 128
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_subset,batch_size=batch_size, shuffle=False)
x_dim = train_subset[0].shape[0]
if torch.cuda.is_available():
    torch.cuda.current_device()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
dummy_input = next(iter(train_dataloader))

In [97]:
torch.sum(dummy_input.argmax(-1) != 20, axis = -1).shape

torch.Size([128])

In [37]:
encoder = nn.LSTM(21, 32, 2, batch_first=True).eval()

In [45]:
_, (x_hn, _) = encoder(dummy_input)


In [44]:
x_hn.shape

torch.Size([2, 128, 32])

In [41]:
dummy_out[1][1].shape

torch.Size([2, 128, 32])

In [None]:
dummy_out[0][0] == dummy_out[1][0]

RuntimeError: The size of tensor a (500) must match the size of tensor b (128) at non-singleton dimension 1

In [37]:
model = LSTMVae(latent_dim = 256, seq_len = 500, hidden_dim = 512, amino_acids = 21, optimizer=torch.optim.Adam, optimizer_param={'lr':0.001}, beta = 1, dropout = 0.0, reconstruction_loss_weight = 1)
model

LSTMVae(
  (tanh): Tanh()
  (soft): Softmax(dim=2)
  (dropout_layer): Dropout(p=0.0, inplace=False)
  (fc1_enc): Linear(in_features=10500, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3_enc_mean): Linear(in_features=512, out_features=256, bias=True)
  (fc3_enc_logvar): Linear(in_features=512, out_features=256, bias=True)
  (fc1_dec): Linear(in_features=256, out_features=512, bias=True)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3_dec): Linear(in_features=512, out_features=10500, bias=True)
  (bn3): BatchNorm1d(10500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_length1): Linear(in_features=256, out_features=1024, bias=True)
  (fc_length2): Linear(in_features=1024, out_features=1, bias=True)
  (length_relu): ReLU()
)

In [38]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl
optimizer = torch.optim.Adam
optimizer_param = {'lr':0.001}
trainer = pl.Trainer(max_epochs=epochs,
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs/"))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [39]:
trainer.fit(model, train_dataloader, val_dataloader)


   | Name           | Type        | Params | Mode 
--------------------------------------------------------
0  | tanh           | Tanh        | 0      | train
1  | soft           | Softmax     | 0      | train
2  | dropout_layer  | Dropout     | 0      | train
3  | fc1_enc        | Linear      | 5.4 M  | train
4  | bn1            | BatchNorm1d | 1.0 K  | train
5  | fc3_enc_mean   | Linear      | 131 K  | train
6  | fc3_enc_logvar | Linear      | 131 K  | train
7  | fc1_dec        | Linear      | 131 K  | train
8  | bn2            | BatchNorm1d | 1.0 K  | train
9  | fc3_dec        | Linear      | 5.4 M  | train
10 | bn3            | BatchNorm1d | 21.0 K | train
11 | fc_length1     | Linear      | 263 K  | train
12 | fc_length2     | Linear      | 1.0 K  | train
13 | length_relu    | ReLU        | 0      | train
--------------------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.778    Total estimated model params 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [170]:
model(dummy_input)[0]

tensor([[-1.2268, -1.8560,  1.3069,  ...,  0.4998, -1.7541, -1.4200],
        [-1.4645,  1.2581, -1.7342,  ...,  0.6334,  0.9132, -0.3911],
        [ 1.8033,  2.0661, -0.9577,  ...,  0.0785, -0.7675,  0.1191],
        ...,
        [-0.1844, -0.5075, -1.4278,  ...,  1.5344,  1.4427,  2.0966],
        [ 0.3775, -0.6507,  1.3097,  ...,  0.1902, -1.3920, -0.5289],
        [-0.3449,  0.4200,  0.9115,  ..., -0.1417, -0.1503, -0.5963]],
       grad_fn=<AddBackward0>)