In [2]:
import torch
import torch.nn as nn 
import os 
import sys
project_root = os.path.abspath("..")  # Adjust if needed

# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

from src.models.attentionVae import *
from proteinshake.datasets import ProteinLigandInterfaceDataset, AlphaFoldDataset, GeneOntologyDataset
from src.utils import data_utils as dtu
from src.dataset_classes.sequenceDataset import SequenceDataset
from torch.utils.data import DataLoader, Dataset, Subset
from src.utils.data_utils import *

%load_ext autoreload
%autoreload 2


In [3]:
dataset = ProteinLigandInterfaceDataset(root='../data').to_point().torch()

In [4]:
max_seq_length = 500
seq_dataset = SequenceDataset(dataset, max_seq_length, transformer_input=True, return_proteins = False)
seq_dataloader = DataLoader(seq_dataset, batch_size = 128, shuffle = False)

100%|██████████| 4642/4642 [00:01<00:00, 3562.03it/s]


In [5]:
next(iter(seq_dataloader)).shape

torch.Size([128, 500, 20])

In [17]:
# Positional Encoding Sanity check
embed_dim = 20
latent_dim = 2
pos_enc = PositionalEncoding(embed_dim = embed_dim, dropout= 0.1, device=device)

# Attention Encoder Block Sanity check
att_enc = AttentionEncoderBlock(embed_dim = embed_dim, num_heads = 5,dropout = 0.1, hidden_dim = 256)

# Linear Encoder
glob_attn_module = nn.Sequential(nn.Linear(embed_dim, 1),
                                            nn.Softmax(dim=1))
forward_latent_mean_layer = nn.Linear(embed_dim,latent_dim)

In [None]:
test_batch_out = pos_enc(next(iter(seq_dataloader)))
test_batch_out[0,0,:]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

In [20]:
x = next(iter(seq_dataloader))
padding_mask = (x.sum(dim=-1) == 0).bool()  # Sum along the embedding dimension, if sum is 0, it's padding
# Convert padding mask to shape (B, 1, N) to match the attention's expected format
padding_mask = padding_mask
test_batch_out_att_enc, attn_w = att_enc.attention(x,x,x, key_padding_mask = padding_mask)

In [21]:
glob_attn = glob_attn_module(test_batch_out_att_enc)

In [26]:
glob_attn.squeeze().shape

torch.Size([128, 500])

In [22]:
print(glob_attn.transpose(-1, 1).shape)
print(test_batch_out_att_enc.shape) 

torch.Size([128, 1, 500])
torch.Size([128, 500, 20])


In [62]:
hidden_dim = 64
seq_len = 500
part1 = nn.Linear(latent_dim, seq_len*(hidden_dim//2))

block = nn.Sequential(
            nn.Conv1d(in_channels=hidden_dim//2, out_channels=hidden_dim, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden_dim))
part2 = nn.Conv1d(hidden_dim,20, kernel_size=3, padding=1)

In [None]:
N_layers =  1
embed_dim = 20
hidden_dim = 256
num_heads = 2
dropout = 0.1
latent_dim = 256
seq_len = 500
if torch.cuda.is_available():
    torch.cuda.current_device()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

attention_vae_test = AttentionVAE(optimizer=torch.optim.Adam, 
                                  optimizer_param={'lr':0.001},
                                  N_layers=N_layers,
                                  embed_dim=embed_dim,
                                  hidden_dim=hidden_dim,
                                  num_heads=num_heads,
                                  dropout=dropout,
                                  latent_dim=latent_dim,
                                  seq_len=seq_len)

In [12]:
test_in = next(iter(seq_dataloader))

In [24]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [25]:
attention_vae_test.get_attention_weights = True
attention_vae_test.pos_encoder.device = device

In [26]:
device

device(type='cpu')

In [28]:
test_out = attention_vae_test(test_in)

In [32]:
attention_vae_test.attention_weights[0].shape

torch.Size([128, 500, 500])

In [14]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl
optimizer = torch.optim.Adam
optimizer_param = {'lr':0.001}
trainer = pl.Trainer(max_epochs=100,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger(save_dir="logs/"))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [15]:
trainer.fit(attention_vae_test, seq_dataloader)


  | Name                        | Type               | Params | Mode 
---------------------------------------------------------------------------
0 | pos_encoder                 | PositionalEncoding | 0      | train
1 | attention_encoders          | ModuleList         | 12.3 K | train
2 | global_forward_layer        | Linear             | 125 K  | train
3 | forward_latent_mean_layer   | Linear             | 1.3 M  | train
4 | forward_latent_logvar_layer | Linear             | 1.3 M  | train
5 | soft                        | Softmax            | 0      | train
6 | relu                        | ReLU               | 0      | train
7 | fc1_dec                     | Linear             | 1.3 M  | train
8 | fc3_dec                     | Linear             | 50.0 M | train
---------------------------------------------------------------------------
54.0 M    Trainable params
0         Non-trainable params
54.0 M    Total params
215.972   Total estimated model params size (MB)
22        Modules

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined