In [169]:
import torch 
import torch.nn as nn
import os 
import sys
# Resolve the parent of the current working directory as the project root.
# NOTE(review): this is relative to wherever the kernel was started, not to the
# notebook file — adjust if the notebook is not run from a direct subdirectory.
project_root = os.path.abspath("..")  # Adjust if needed

# Put the project root on sys.path (once) so the `src` package imported below
# can be resolved.
if project_root not in sys.path:
    sys.path.append(project_root)

from proteinshake.datasets import ProteinLigandInterfaceDataset
from src.utils import data_utils as dtu
from torch.utils.data import DataLoader, Dataset

In [170]:
# Load the protein-ligand interface dataset rooted at ../data and convert it
# through proteinshake's .to_point().torch() chain.
# NOTE(review): presumably this yields point-cloud samples as torch tensors — confirm against the proteinshake docs.
dataset = ProteinLigandInterfaceDataset(root='../data').to_point().torch()

In [171]:
# Upper bound on sequence length used to select samples; it is also passed to
# the model below as its fixed sequence length.
# NOTE(review): this comment previously said "less than or equal to 150
# residues", but the limit set here is 2916 — confirm which value is intended.
max_seq_length = 2916
# presumably a mapping from sequence length to number of samples (it is summed via .values()) — TODO confirm
seq_lengths = dtu.get_dataset_seq_lengths(dataset,leq = max_seq_length)
# Total of the per-length values, displayed as the cell output.
total_num = sum(seq_lengths.values())
total_num

4620

In [172]:
def padd_tensors(tensor_list):
    """Zero-pad 2-D tensors along dim 0 to a common length and stack them.

    Args:
        tensor_list: non-empty sequence of 2-D tensors of shape
            (length_i, features). All tensors must share the same feature
            size, dtype and device for the final stack to succeed.

    Returns:
        Tensor of shape (len(tensor_list), max_length, features), where
        max_length is the largest length_i. Rows past a tensor's own length
        are zeros.

    Raises:
        ValueError: if tensor_list is empty.
    """
    tensor_list = list(tensor_list)
    if not tensor_list:
        raise ValueError("padd_tensors() requires at least one tensor")

    max_length = max(tensor.shape[0] for tensor in tensor_list)
    padded = []
    for tensor in tensor_list:
        pad_size = max_length - tensor.shape[0]
        # new_zeros inherits dtype and device from the tensor being padded, so
        # the concatenation neither promotes the dtype nor mixes devices.
        padding = tensor.new_zeros((pad_size, tensor.shape[1]))
        padded.append(torch.cat((tensor, padding), dim=0))
    return torch.stack(padded)

In [173]:
# Select the samples within the length limit, take the first element of each
# sample as its 2-D tensor, and zero-pad them into one stacked tensor.
data_subset = dtu.get_subset_leq_len(dataset, leq=max_seq_length)
tensor_list = [sample[0][:, :] for sample in data_subset]
# Y and padded_tensors are two names for the same stacked tensor.
Y = padded_tensors = padd_tensors(tensor_list)

In [174]:
# Min-max scale Y using the global minimum and maximum over every element
# (zero padding included); epsilon guards against a zero range.
epsilon = 1e-8
Y = (Y - Y.min()) / (Y.max() - Y.min() + epsilon)

In [175]:
class vae_test(nn.Module):
    """Variational autoencoder over fixed-length sequences with 4 features per position.

    The encoder applies two kernel-size-1 convolutions, max-pools over the
    sequence axis and projects to a latent mean and log-variance. The decoder
    is a three-layer MLP ending in a sigmoid, reshaped back to
    (batch, max_seq_length, 4).

    Args:
        max_seq_length: sequence length the model is built for. It is used as
            the pooling window and to size the decoder output, so inputs are
            expected to have exactly this length.
        latent_dim: size of the latent vector.
    """

    def __init__(self, max_seq_length, latent_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.latent_dim = latent_dim

        # Encoder Layers
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=128, kernel_size=1, stride=1)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, stride=1)
        # Window spans the whole sequence, so a length-max_seq_length input
        # pools down to a single position.
        self.max_pool = nn.MaxPool1d(kernel_size=max_seq_length, stride=1)
        self.mean_fc = nn.Linear(256, latent_dim)
        # Attribute name kept for state_dict compatibility; its output is used
        # as the log-variance by forward() and by the KL term of the loss.
        self.std_fc = nn.Linear(256, latent_dim)

        # Decoder Layers
        self.dec_fc1 = nn.Linear(latent_dim, 128)
        self.dec_fc2 = nn.Linear(128, 256)
        self.dec_fc3 = nn.Linear(256, self.max_seq_length * 4)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Encode x, sample a latent vector and decode it.

        Args:
            x: tensor of shape (batch, max_seq_length, 4).

        Returns:
            Tuple (decoded, mu, logvar): the reconstruction with shape
            (batch, max_seq_length, 4), and the latent mean and log-variance,
            each with shape (batch, 1, latent_dim).
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterization(mu, logvar)
        decoded = self.decoder(z)
        return decoded, mu, logvar

    def reparameterization(self, mean, var):
        """Sample z = mean + std * eps with eps drawn from a standard normal.

        Args:
            mean: latent mean.
            var: latent LOG-variance (the parameter name is kept for backward
                compatibility); it is converted to a standard deviation here.

        Returns:
            A sample with the same shape as mean.
        """
        # std = sqrt(exp(log_var)) = exp(0.5 * log_var); this matches the KL
        # term of the loss, which exponentiates the same quantity.
        std = torch.exp(0.5 * var)
        # Standard-normal noise (randn), not uniform noise (rand).
        epsilon = torch.randn_like(std)
        return mean + std * epsilon

    def encoder(self, x):
        """Map x of shape (batch, length, 4) to (mean, log-variance)."""
        # Conv1d expects (batch, channels, length).
        x = x.permute(0, 2, 1)
        # Nonlinearity between the two pointwise convolutions; without it the
        # pair collapses to a single linear map.
        x = torch.relu(self.conv1(x))
        x = self.conv2(x)
        x = self.max_pool(x)

        # Back to (batch, pooled_length, 256) for the linear heads.
        x = x.permute(0, 2, 1)
        return self.mean_fc(x), self.std_fc(x)

    def decoder(self, z):
        """Decode latent z to values in (0, 1) with shape (batch, max_seq_length, 4)."""
        # ReLU between the fully connected layers so the MLP is not a single
        # affine map.
        z = torch.relu(self.dec_fc1(z))
        z = torch.relu(self.dec_fc2(z))
        z = self.dec_fc3(z)
        z = self.sig(z)
        return z.view(z.shape[0], self.max_seq_length, 4)

In [176]:
def loss_function(x, x_hat, mean, log_var):
    """Return the summed VAE loss: reconstruction BCE plus KL divergence.

    Args:
        x: target tensor with values in [0, 1].
        x_hat: reconstruction with values in [0, 1], same shape as x.
        mean: latent mean.
        log_var: latent log-variance.

    Returns:
        Scalar tensor: binary cross-entropy summed over all elements plus the
        KL term -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
    """
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    bce = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    return bce + kl_divergence

In [177]:
# Training setup: batch size, model, optimizer and a shuffling loader over Y.
bs = 100
model = vae_test(max_seq_length=max_seq_length, latent_dim=20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = DataLoader(Y, batch_size=bs, shuffle=True)

In [178]:
def train(model, optimizer, epochs, loader=None):
    """Train model for a number of epochs and report the per-sample loss.

    Each batch x is used both as the model input and as the reconstruction
    target. The model must return (x_hat, mean, log_var), which are passed to
    the notebook-level loss_function.

    Args:
        model: module to train; it is put into training mode.
        optimizer: optimizer already bound to the model's parameters.
        epochs: number of passes over the loader.
        loader: iterable of batch tensors. Defaults to the notebook-level
            train_loader, so existing calls keep working.

    Returns:
        The summed loss over the last epoch, or 0 when epochs is 0.
    """
    if loader is None:
        # Fall back to the loader defined at notebook level.
        loader = train_loader

    model.train()
    # Bound before the loop so epochs == 0 returns 0 instead of raising.
    overall_loss = 0
    for epoch in range(epochs):
        print(f'---Epoch {epoch}---')
        overall_loss = 0
        num_samples = 0
        for x in loader:

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)

            overall_loss += loss.item()
            # Count the samples actually seen: the last batch may be smaller,
            # and a single-batch epoch must not divide by zero.
            num_samples += len(x)

            loss.backward()
            optimizer.step()

        average_loss = overall_loss / num_samples if num_samples else float('nan')
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", average_loss)
    return overall_loss

# Run the training loop for 50 epochs with the model and optimizer created in
# the setup cell.
train(model, optimizer, epochs=50)

---Epoch 0---


KeyboardInterrupt: 