In [1]:
import torch
import torch.nn as nn 
import os 
import sys
# Make the repository root importable so that the `src.*` imports below resolve
# when the kernel is started from the notebook's own directory.
project_root = os.path.abspath("..")  # parent of the working directory; adjust if needed

# Append only once so that re-running this cell does not grow sys.path.
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

from src.models.attentionVae import *
from proteinshake.datasets import ProteinLigandInterfaceDataset, AlphaFoldDataset, GeneOntologyDataset
from src.utils import data_utils as dtu
from src.dataset_classes.sequenceDataset import SequenceDataset
from torch.utils.data import DataLoader, Dataset, Subset
from src.utils.data_utils import *

%load_ext autoreload
%autoreload 2


In [2]:
# Load the Gene Ontology dataset from ../data, convert it to its point
# representation and wrap it for use with torch.
dataset = (
    GeneOntologyDataset(root='../data')
    .to_point()
    .torch()
)

In [91]:
# Maximum sequence length handed to SequenceDataset.
max_seq_length = 500

# Sequence view of the dataset, requested in transformer-input form.
seq_dataset = SequenceDataset(dataset, max_seq_length, transformer_input=True)

# Unshuffled loader so that the first batch is the same on every run.
seq_dataloader = DataLoader(
    seq_dataset,
    batch_size=128,
    shuffle=False,
)

100%|██████████| 32633/32633 [00:17<00:00, 1873.01it/s]


In [92]:
# Sanity check: instantiate the positional encoding on 20-dimensional embeddings.
pos_enc = PositionalEncoding(
    embed_dim=20,
    dropout=0.1,
)

# Sanity check: instantiate one attention encoder block with the same embedding
# width (20 is divisible by the 5 heads).
att_enc = AttentionEncoderBlock(
    embed_dim=20,
    num_heads=5,
    hidden_dim=128,
    dropout=0.1,
)

In [93]:
# Push the first batch through the positional encoding and display the feature
# vector of the first position of the first sequence.
# NOTE(review): pos_enc is used as constructed (never switched to eval()), so
# dropout is presumably active here; the 1.1111 values in the recorded output
# are consistent with 1 / (1 - 0.1) rescaling. Confirm before relying on it.
test_batch_out = pos_enc(next(iter(seq_dataloader)))
test_batch_out[0, 0]

tensor([0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 1.1111, 0.0000, 0.0000, 0.0000,
        1.1111, 0.0000, 1.1111, 0.0000, 0.0000, 0.0000, 1.1111, 0.0000, 1.1111,
        1.1111, 1.1111])

In [94]:
# Take the first batch; a position whose feature vector sums to zero is treated
# as padding.
x = next(iter(seq_dataloader))

# Boolean mask with one entry per (batch, position); True marks padding.
# It is handed to the attention layer as-is (no reshape is applied).
padding_mask = x.sum(dim=-1).eq(0)

# Self-attention over the batch (query = key = value = x), ignoring padded keys.
test_batch_out_att_enc, attn_w = att_enc.attention(
    x, x, x,
    key_padding_mask=padding_mask,
)

In [99]:
class Block(nn.Module):
    """1-D residual block: two 3-wide convolutions plus a skip connection.

    The main path is Conv1d -> BatchNorm1d -> ReLU -> Conv1d -> BatchNorm1d.
    Its output is added to the skip path and passed through a final ReLU.
    The skip path is the identity when the input and output shapes already
    agree; otherwise it is a 1x1 convolution followed by batch normalisation.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          stride (int):       Stride of the first convolution of the main path
                              and of the 1x1 projection on the skip path.

          from:
          https://stackoverflow.com/questions/60817390/implementing-a-simple-resnet-block-with-pytorch
        """
        super(Block, self).__init__()

        # A projection is only needed when the skip path must change the number
        # of channels or the length; otherwise the input is added unchanged.
        # (Created before `self.block` so parameter-initialisation order is kept.)
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels))
        else:
            self.skip = None

        # The first convolution carries the stride so that the main path and the
        # skip path always produce the same length. With a hard-coded stride of 1
        # here, any stride != 1 made the residual addition fail on a shape mismatch.
        self.block = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channels))

    def forward(self, x):
        """Apply the residual block.

        Args:
          x (torch.Tensor): Input in Conv1d layout, (batch, in_channels, length).

        Returns:
          torch.Tensor: Non-negative tensor of shape
          (batch, out_channels, length_out); length_out equals length when
          stride is 1.
        """
        identity = x if self.skip is None else self.skip(x)

        out = self.block(x)
        # In-place add is safe: `out` is a fresh tensor produced by self.block.
        out += identity
        return torch.nn.functional.relu(out)


In [102]:
# Shape of the attention output for the first batch.
test_batch_out_att_enc.size()

torch.Size([128, 500, 20])

In [None]:
# NOTE(review): this cell looks like a fragment copied out of a class body
# (presumably a model's __init__). It reads and writes `self.*`, which is not
# defined at notebook scope, and its last line is indented relative to the
# first, so the cell cannot run as written (it has no execution count).
#
# It lists three decoder layers and registers them in an nn.ModuleList:
#   1. Linear from latent_dim to seq_len * (hidden_dim // 2) features,
#   2. residual Block from hidden_dim // 2 to hidden_dim channels,
#   3. Conv1d (kernel 3, padding 1) from hidden_dim to input_dim channels.
#
# NOTE(review): the Linear output is flat while Block/Conv1d are 1-D conv
# layers; presumably the forward pass reshapes to
# (batch, hidden_dim // 2, seq_len) in between. TODO confirm.
dec_layers = [nn.Linear(self.latent_dim, self.seq_len*(self.hidden_dim//2)),
                     Block(self.hidden_dim//2, self.hidden_dim),
                     nn.Conv1d(self.hidden_dim, self.input_dim, kernel_size=3, padding=1)]
        self.dec_conv_module = nn.ModuleList(dec_layers)

torch.Size([128, 500, 20])

In [96]:
# Shape of the raw input batch, for comparison with the attention output.
x.size()

torch.Size([128, 500, 20])