In [2]:
import torch
import torch.nn as nn 
import os 
import sys
project_root = os.path.abspath("..")  # Adjust if needed

# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

from src.models.attentionVae import *
from proteinshake.datasets import ProteinLigandInterfaceDataset, AlphaFoldDataset, GeneOntologyDataset
from src.utils import data_utils as dtu
from src.dataset_classes.sequenceDataset import SequenceDataset
from torch.utils.data import DataLoader, Dataset, Subset
from src.utils.data_utils import *

%load_ext autoreload
%autoreload 2


In [7]:
dataset = GeneOntologyDataset(root='../data').to_point().torch()

In [8]:
max_seq_length = 500
seq_dataset = SequenceDataset(dataset, max_seq_length, transformer_input=True)
seq_dataloader = DataLoader(seq_dataset, batch_size = 128, shuffle = False)

100%|██████████| 32633/32633 [00:08<00:00, 3809.27it/s]


In [24]:
# Positional Encoding Sanity check
embed_dim = 20
latent_dim = 2
pos_enc = PositionalEncoding(embed_dim = embed_dim, dropout= 0.1)

# Attention Encoder Block Sanity check
att_enc = AttentionEncoderBlock(embed_dim = embed_dim, num_heads = 5,dropout = 0.1, hidden_dim = 256)

# Linear Encoder
glob_attn_module = nn.Sequential(nn.Linear(embed_dim, 1),
                                            nn.Softmax(dim=1))
forward_latent_mean_layer = nn.Linear(embed_dim,latent_dim)

In [25]:
test_batch_out = pos_enc(next(iter(seq_dataloader)))
test_batch_out[0,0,:]

tensor([0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 0.0000, 0.0000,
        1.1111, 0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 0.0000,
        1.1111, 1.1111])

In [None]:
x = next(iter(seq_dataloader))
padding_mask = (x.sum(dim=-1) == 0).bool()  # Sum along the embedding dimension, if sum is 0, it's padding
# Convert padding mask to shape (B, 1, N) to match the attention's expected format
padding_mask = padding_mask
test_batch_out_att_enc, attn_w = att_enc.attention(x,x,x, key_padding_mask = padding_mask)

In [27]:
glob_attn = glob_attn_module(test_batch_out_att_enc)

In [None]:
print(glob_attn.transpose(-1, 1).shape)
print(test_batch_out_att_enc.shape) 

torch.Size([128, 1, 500])
torch.Size([128, 500, 20])


In [28]:
z_rep = torch.bmm(glob_attn.transpose(-1, 1), test_batch_out_att_enc).squeeze()

In [34]:
z_f = forward_latent_mean_layer(z_rep)

In [35]:
z_f.shape

torch.Size([128, 2])

In [63]:
hidden_dim = 64
seq_len = 500
part1 = nn.Linear(latent_dim, seq_len*(hidden_dim//2))

block = nn.Sequential(
            nn.Conv1d(in_channels=hidden_dim//2, out_channels=hidden_dim, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden_dim))
part2 = nn.Conv1d(hidden_dim,20, kernel_size=3, padding=1)

In [70]:
p1 = part1(z_f)
p1 = p1.reshape(-1, hidden_dim//2, seq_len)
p1.shape

torch.Size([128, 32, 500])

In [71]:
block1 = block(p1)

In [72]:
block1.shape

torch.Size([128, 64, 500])

In [73]:
p2 = part2(block1)

In [76]:
p2.permute(0,2,1).shape

torch.Size([128, 500, 20])

In [75]:
test_batch_out.shape

torch.Size([128, 500, 20])

In [84]:
N_layers =  4
embed_dim = 20
hidden_dim = 64
num_heads = 4
dropout = 0.1
latent_dim = 2
seq_len = 500
attention_vae_test = AttentionVAE(N_layers=N_layers,
                                  embed_dim=embed_dim,
                                  hidden_dim=hidden_dim,
                                  num_heads=num_heads,
                                  dropout=dropout,
                                  latent_dim=latent_dim,
                                  seq_len=seq_len)

In [85]:
test_in = next(iter(seq_dataloader))