In [144]:
import torch
import numpy as np
import torch.nn as nn
import sys
import os
project_root = os.path.abspath("..")  # Adjust if needed
import pytorch_lightning as pl
# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

from src.models.pointNetVae import PointNetVAE
from src.utils.data_utils import *
from src.dataset_classes.pointDataset import *
from proteinshake.datasets import ProteinFamilyDataset
from proteinshake.tasks import LigandAffinityTask
import random
from torch.utils.data import DataLoader, Dataset, Subset
from src.utils.data_utils import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [207]:
# Exclusive upper bound on sequence length; proteins at or above it are dropped.
MAX_SEQ_LEN = 500

# Keep the unfiltered graph view under its own name. The original rebound
# `dataset` to the filtered list, so the filter could never be re-applied
# (its input no longer had the item[0] / item[1] layout it indexes into).
dataset_raw = ProteinFamilyDataset(root='../data').to_graph(eps = 8).pyg()
point_d = ProteinFamilyDataset(root='../data').to_point().torch()

# Each raw item is indexed as item[0] (kept as the graph) and item[1]
# (a nested dict holding the protein sequence). Only item[0] is retained.
dataset = [
    item[0]
    for item in dataset_raw
    if len(item[1]['protein']['sequence']) < MAX_SEQ_LEN
]

In [199]:
dataset[0][0].x.shape

torch.Size([277])

In [156]:
# Replace each graph's integer labels in `x` with 21-way one-hot float features.
# The assignment mutates the graphs in place, so guard on dtype: without it a
# second run of this cell would hand the already-converted float features to
# `one_hot`, which only accepts integer index tensors.
for data in dataset:
    if not torch.is_floating_point(data.x):
        data.x = torch.nn.functional.one_hot(data.x, num_classes=21).float()

In [None]:
# Preview: turn the first graph's features back into class indices (assumes
# `x` is already one-hot encoded), then right-pad them to length 500 using
# the fill value 21.
first_labels = dataset[0].x.argmax(dim=-1)
pad_amount = 500 - first_labels.shape[0]
torch.nn.functional.pad(first_labels, (0, pad_amount), value=21).shape

torch.Size([500])

In [148]:
# NOTE(review): this indexes dataset[0] as a (graph, metadata) pair, i.e. the
# layout before the filtering step keeps only the graph — confirm cell order.
first_sequence = dataset[0][1]['protein']['sequence']
one_hot_encode_seq(first_sequence, 500).shape

torch.Size([500, 21])

In [151]:
# Preview the 21-way one-hot encoding of the first graph's labels.
# NOTE(review): relies on dataset[0] still being a (graph, metadata) pair.
first_graph = dataset[0][0]
torch.nn.functional.one_hot(first_graph.x, num_classes=21)

tensor([[0, 0, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [208]:
# NOTE: this import rebinds `DataLoader`, shadowing the torch.utils.data
# DataLoader imported in the first cell; from here on the name refers to the
# torch_geometric loader.
from torch_geometric.loader import DataLoader
from torch_geometric.nn import InnerProductDecoder

# `dataset` is expected to be a list of torch_geometric Data objects (one per graph).
batch_size = 16
# Shuffle with a dedicated, seeded generator so `test_batch` (and every shape
# echoed from it below) is the same after a kernel restart. A private
# generator leaves the global torch RNG untouched.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
test_batch = next(iter(loader))

In [232]:
# test_batch

# Build one fixed-length row of labels per graph in the batch.
PAD_LEN = 500      # every row is right-padded to this length
MASK_VALUE = 21    # fill value written into the padded positions

indices_with_mask_val = []
# Iterate over the graphs actually present in the batch rather than the
# configured batch size: a short batch would otherwise yield empty selections
# that get padded into rows consisting solely of MASK_VALUE.
for i in range(test_batch.num_graphs):
    # Boolean mask selects the entries of `x` whose batch assignment is graph i.
    x_true_indices = test_batch.x[test_batch.batch == i]
    x_true_indices = torch.nn.functional.pad(
        x_true_indices,
        (0, PAD_LEN - x_true_indices.shape[0]),
        value=MASK_VALUE,
    )
    indices_with_mask_val.append(x_true_indices)

In [233]:
# Stack the per-graph rows into a single tensor. The name is rebound from a
# list to a tensor, so only stack while it is not yet a tensor; this makes
# re-running the cell a no-op instead of stacking an already-stacked tensor.
if not torch.is_tensor(indices_with_mask_val):
    indices_with_mask_val = torch.stack(indices_with_mask_val)

In [237]:
test_batch.batch[-1]

tensor(15)

In [132]:
test_batch[0].edge_index.shape

torch.Size([2, 1562])

In [136]:
# NOTE(review): `dense_to_sparse` is not imported in any cell visible here.
# If it is torch_geometric.utils.dense_to_sparse, it expects a dense adjacency
# matrix, whereas `edge_index` is already a [2, num_edges] index tensor; the
# echoed result (rows 0/1, columns up to num_edges - 1) suggests the index
# tensor is being read as a 2 x num_edges adjacency — TODO confirm intent.
dense_to_sparse(test_batch[0].edge_index)[0]

tensor([[   0,    0,    0,  ...,    1,    1,    1],
        [   6,    7,    8,  ..., 1559, 1560, 1561]])

In [123]:
# NOTE(review): neither `dec` nor `testIn` is defined in any cell visible
# here, so this cell depends on leftover kernel state. Presumably `dec` is an
# InnerProductDecoder instance and `testIn[0]` holds node embeddings — TODO
# confirm and define both before this cell.
dec(testIn[0], test_batch.edge_index)

tensor([0.6120, 0.6623, 0.9430,  ..., 0.3138, 0.4440, 0.7019])

In [7]:
test_batch.edge_index

tensor([[   0,    0,    0,  ..., 4215, 4216, 4216],
        [   1,    2,   30,  ..., 4212, 4215, 4214]])

In [8]:
from torch_geometric.nn import GCNConv, dense_diff_pool, global_mean_pool, TopKPooling
hidden_dim = 32
latent_dim = 2
class encoder(nn.Module):
    """Graph encoder that maps each graph in a batch to `mu` and `logvar`.

    Node features pass through two GCN layers, are mean-pooled per graph, and
    are then projected by two independent linear heads of size `latent_dim`.
    Layer widths come from the module-level `hidden_dim` and `latent_dim`.
    """

    def __init__(self):
        super().__init__()
        # Deliberately NOT `cached=True`: caching stores the normalised
        # adjacency from the first forward pass and reuses it afterwards,
        # which is only valid when every call sees the same graph
        # (transductive learning). `forward` pools by `x.batch`, i.e. it is
        # fed batches of different graphs, so a cached adjacency would be
        # applied to graphs it does not belong to.
        self.conv1 = GCNConv(1, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, 2 * hidden_dim)

        self.fc_mu = nn.Linear(2 * hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(2 * hidden_dim, latent_dim)

    def forward(self, x):
        """Encode a batch object exposing `.x`, `.edge_index` and `.batch`.

        Args:
            x: batch whose `.x` is unsqueezed on the last axis and cast to
                float, giving one scalar input feature per node.

        Returns:
            Tuple `(mu, logvar)`, each produced by a linear head applied to
            the per-graph mean-pooled node embeddings.
        """
        x_f, x_edg, x_batch = x.x.unsqueeze(-1).float(), x.edge_index, x.batch
        # NOTE(review): there is no non-linearity between the two GCN layers;
        # left as in the original — confirm whether an activation is intended.
        x_f = self.conv1(x_f, x_edg)
        x_f = self.conv2(x_f, x_edg)

        # Collapse node embeddings to one vector per graph.
        pooled_x = global_mean_pool(x_f, x_batch)
        mu = self.fc_mu(pooled_x)
        logvar = self.fc_logvar(pooled_x)
        return mu, logvar

In [86]:
test_batch[3].x

tensor([ 0, 11, 10,  7, 13, 10, 13,  0,  7,  6,  7,  0,  6, 18, 19,  7, 12,  7,
        11,  5, 13, 13,  3,  2, 13,  5,  5, 15, 11,  5, 19, 13, 11,  1, 15, 15,
         5,  0, 10,  7,  9,  3, 12,  5,  5, 10,  4, 13,  2,  3, 14,  5,  7, 10,
        10,  2, 11, 16,  5, 13, 16,  6, 14,  0,  9,  9, 16, 16,  2, 12,  0,  9,
        10, 16,  0, 10,  3, 11, 10,  7, 19, 11, 15,  8,  9, 15,  4,  7, 10, 15,
        10,  7,  5, 18, 15,  0, 10,  9,  8, 15,  7,  0,  9,  2, 13,  5,  3,  7,
        19, 11, 10, 19, 11, 11,  1,  7, 11, 13, 12,  6,  5,  0, 19,  0,  5,  7,
         9,  7,  7, 12, 19,  0, 19, 10,  1, 12, 16, 14,  5,  6, 19,  3,  5,  9,
         9,  5, 11, 15, 15, 14, 18,  7,  9, 19,  5,  7,  0,  2, 18,  2, 15, 14,
         7,  6,  9, 19,  9, 15,  7,  5, 10, 19,  0, 10,  5, 11,  0, 12,  5, 13,
         9, 11,  5, 19,  7,  7,  1,  0,  9, 11, 10, 14, 19, 15,  0, 14, 13,  8,
         4, 15, 12, 10,  6, 14,  0,  0,  5, 11, 10,  5,  3,  5, 10,  2, 11,  9,
        15,  9,  2, 11, 10,  2,  7,  9, 

In [49]:
# Scratch layers for the pooling experiment below.
hidden_dim = 32
# An integer ratio is passed, so pooling is configured with the value 4
# (the echoed pooled shape below is 64 rows for a 16-graph batch).
kpool = TopKPooling(in_channels=hidden_dim, ratio=4)
gcn = GCNConv(in_channels=1, out_channels=hidden_dim)

In [50]:
# One scalar feature per node: add a trailing feature axis and cast to float.
node_inputs = test_batch.x.unsqueeze(-1).float()
out_gcn = gcn(node_inputs, test_batch.edge_index)
# Pool per graph, using the batch assignment vector to separate graphs.
pool_out = kpool(out_gcn, test_batch.edge_index, batch=test_batch.batch)

In [51]:
out_gcn.shape

torch.Size([4217, 32])

In [52]:
# Print the shape of every pooling output that has one. Entries without a
# `.shape` attribute (e.g. a None placeholder) are skipped explicitly instead
# of via a bare `except:`, which would also hide unrelated errors and
# swallow KeyboardInterrupt.
for out in pool_out:
    shape = getattr(out, "shape", None)
    if shape is not None:
        print(shape)

torch.Size([64, 32])
torch.Size([2, 38])
torch.Size([64])
torch.Size([64])
torch.Size([64])


In [55]:
# View the pooled node features as (graphs, kept nodes per graph, channels).
nodes_per_graph = 4
pooled_nodes = pool_out[0]
pooled_nodes.reshape(batch_size, nodes_per_graph, -1).shape

torch.Size([16, 4, 32])

In [59]:
pool_out[0].shape

torch.Size([64, 32])

In [61]:
# Show the third-from-last pooling output for the first graph's 4 kept nodes;
# the echoed value (all zeros) is consistent with it being the batch
# assignment vector. Uses `batch_size` instead of a hard-coded 16 so it stays
# in step with the loader configuration.
pool_out[-3].reshape(batch_size, 4)[0]

tensor([0, 0, 0, 0])

In [26]:
test_batch.edge_index.shape

torch.Size([2, 42486])

In [28]:
test_batch.x.shape

torch.Size([4217])