In [38]:
import sys
import os 
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import gpytorch
# Get the absolute path of the project root.
# ".." is resolved against the current working directory, so this assumes the
# notebook is run from a folder directly below the project root.
project_root = os.path.abspath("..")  # Adjust if needed

# Add the project root to sys.path (only once, so re-running the cell does not
# keep appending duplicates). This must happen before the `src` import below.
if project_root not in sys.path:
    sys.path.append(project_root)

from proteinshake.datasets import ProteinLigandInterfaceDataset
from src.utils import data_utils as dtu
# Reload edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2


# This is for running the notebook in our testing framework:
# the flag is True whenever an environment variable named CI exists,
# whatever its value.
smoke_test = ('CI' in os.environ)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
# Load the protein-ligand interface dataset from ../data, then convert it with
# .to_point() and .torch() (presumably a point representation served as torch
# tensors -- confirm against the proteinshake documentation).
dataset = (
    ProteinLigandInterfaceDataset(root='../data')
    .to_point()
    .torch()
)

In [26]:
# First experiments are restricted to proteins of at most 150 residues.
max_seq_length = 150
seq_lengths = dtu.get_dataset_seq_lengths(dataset, leq=max_seq_length)

# Sum the values of the returned mapping (presumably a count per sequence
# length -- verify in data_utils) to get the number of proteins kept.
total_num = sum(count for count in seq_lengths.values())
total_num

371

In [27]:
def padd_tensors(tensor_list):
    """
    Zero-pad tensors along their first dimension and stack them into one batch.

    Every tensor is extended with rows of zeros until it is as long as the
    longest tensor in ``tensor_list``; the padded tensors are then stacked
    along a new leading dimension.

    Args:
        tensor_list: Non-empty iterable of tensors whose shapes agree on every
            dimension except the first one.

    Returns:
        A tensor of shape ``(len(tensor_list), max_length, *trailing_dims)``.

    Raises:
        ValueError: If ``tensor_list`` is empty.
    """
    # Materialise once so that one-shot iterables survive the two passes below.
    tensors = list(tensor_list)
    if not tensors:
        raise ValueError("padd_tensors() requires at least one tensor")

    max_length = max(tensor.shape[0] for tensor in tensors)

    padded_tensors = []
    for tensor in tensors:
        pad_size = max_length - tensor.shape[0]
        # new_zeros inherits dtype and device from `tensor`, so integer inputs
        # are not silently promoted to float and non-CPU tensors do not make
        # torch.cat fail. Trailing dims are copied so any rank >= 1 works.
        padding = tensor.new_zeros((pad_size, *tensor.shape[1:]))
        padded_tensors.append(torch.cat((tensor, padding), dim=0))
    return torch.stack(padded_tensors)

In [56]:
# Keep only the samples selected by the length filter, then take the first
# element of every sample (presumably the per-residue point tensor -- confirm
# against the dataset layout).
data_subset = dtu.get_subset_leq_len(dataset, leq=max_seq_length)
tensor_list = [sample[0][:, :] for sample in data_subset]
padded_tensors = padd_tensors(tensor_list)

# Min-max scale with a single global minimum and maximum.
# NOTE(review): both extrema are taken over the padded tensor, so the zero
# padding takes part in the statistics and is rescaled like real data.
min_val, max_val = padded_tensors.min(), padded_tensors.max()
Y = (padded_tensors - min_val) / (max_val - min_val)

In [59]:
# Flatten every padded structure into one feature vector, one row per sample.
# The shape is derived from Y itself instead of hard-coding (371, 4*150), so
# the cell keeps working when the dataset or max_seq_length changes, and
# re-running it on the already flattened Y is a no-op.
Y = Y.reshape(Y.shape[0], -1)
Y[0]

tensor([0.2614, 0.3603, 0.2078, 0.2674, 0.2544, 0.3639, 0.1993, 0.2585, 0.2520,
        0.3526, 0.1978, 0.2405, 0.2480, 0.3521, 0.2086, 0.2376, 0.2422, 0.3620,
        0.2074, 0.2376, 0.2375, 0.3576, 0.1977, 0.2316, 0.2344, 0.3481, 0.2033,
        0.2704, 0.2290, 0.3542, 0.2115, 0.2614, 0.2241, 0.3606, 0.2034, 0.2256,
        0.2201, 0.3510, 0.1983, 0.2376, 0.2162, 0.3474, 0.2084, 0.2525, 0.2106,
        0.3574, 0.2100, 0.2256, 0.2062, 0.3570, 0.1994, 0.2495, 0.2029, 0.3462,
        0.2017, 0.2614, 0.1983, 0.3489, 0.2118, 0.2525, 0.1929, 0.3581, 0.2078,
        0.2256, 0.1884, 0.3514, 0.1996, 0.2286, 0.1847, 0.3435, 0.2070, 0.2794,
        0.1808, 0.3517, 0.2139, 0.2704, 0.1752, 0.3571, 0.2055, 0.2465, 0.1702,
        0.3473, 0.2022, 0.2256, 0.1666, 0.3459, 0.2128, 0.2525, 0.1633, 0.3568,
        0.2140, 0.2226, 0.1568, 0.3567, 0.2044, 0.2495, 0.1512, 0.3471, 0.2075,
        0.2316, 0.1407, 0.3503, 0.2113, 0.2555, 0.1413, 0.3423, 0.2195, 0.2256,
        0.1504, 0.3478, 0.2244, 0.2614, 

In [18]:
class CloudPointKernel(gpytorch.kernels.Kernel):
    """
    Custom Kernel for 3D protein structures modeled as point clouds.

    Computes an RBF-like covariance between the rows of ``x1`` and ``x2``:

        k(a, b) = exp(-0.5 * ||(a - b) / lengthscale||^2)

    The lengthscale is the learnable parameter that GPyTorch registers for
    kernels declaring ``has_lengthscale = True``.
    """

    # GPyTorch reads this flag as a class attribute inside Kernel.__init__ to
    # decide whether to register the lengthscale parameter. Passing it as a
    # keyword to super().__init__() is swallowed by **kwargs, leaving
    # self.lengthscale as None.
    has_lengthscale = True

    def __init__(self, **kwargs):
        """Forward all keyword arguments (e.g. ard_num_dims, priors) to Kernel."""
        super().__init__(**kwargs)

    def forward(self, x1, x2, diag=False, **params):
        """
        Evaluate the kernel.

        Args:
            x1: Tensor of shape (..., n, d).
            x2: Tensor of shape (..., m, d).
            diag: If True, return only the kernel values between matching rows
                of ``x1`` and ``x2`` (which must then have the same number of
                rows) instead of the full matrix.
            **params: Extra options passed by GPyTorch; unused here.

        Returns:
            Tensor of shape (..., n, m), or (..., n) when ``diag`` is True.
        """
        # Scale the inputs rather than the distances: identical for a single
        # lengthscale, and it also broadcasts correctly for per-dimension
        # (ARD) lengthscales.
        x1_scaled = x1.div(self.lengthscale)
        x2_scaled = x2.div(self.lengthscale)

        if diag:
            # Row-wise squared distances only; avoids building the full matrix.
            dist_sq = (x1_scaled - x2_scaled).pow(2).sum(dim=-1)
        else:
            # Pairwise squared Euclidean distance between all rows.
            dist_sq = torch.cdist(x1_scaled, x2_scaled, p=2).pow(2)

        # Apply RBF-like kernel function
        return torch.exp(-0.5 * dist_sq)
