In [1]:
from sklearn.metrics import silhouette_score, davies_bouldin_score
import torch.nn as nn
import sys
import os
# Resolve the parent of the current working directory to an absolute path.
# Presumably this is the repository root that contains the `src` package
# imported further down -- adjust if the notebook is run from elsewhere.
project_root = os.path.abspath("..")  # Adjust if needed
import pytorch_lightning as pl
# Add the project root to sys.path (only if it is not already there, so
# re-running this cell does not append duplicates). This has to execute before
# the first `from src...` import in this cell.
if project_root not in sys.path:
    sys.path.append(project_root)
from collections import Counter

from src.utils.data_utils import *
from src.dataset_classes.pointDataset import *
from proteinshake.datasets import ProteinFamilyDataset
from proteinshake.tasks import LigandAffinityTask
import random
from src.models.graphVAE import GraphVAE
from src.models.basicVae import LitBasicVae
from src.models.PointNetVae_chamfer_split import PointNetVAE
from torch.utils.data import Dataset, Subset
from torch_geometric.utils import to_dense_batch, to_dense_adj
import numpy as np
from src.utils.data_utils import *
from src.dataset_classes.graphDataset import *
from src.dataset_classes.sequenceDataset import *
from torch_geometric.loader import DataLoader as Pyg_DataLoader
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [None]:
# Build two representations of the protein-family data stored under ../data:
# a point representation (torch) and a graph representation (PyG).
point_view = ProteinFamilyDataset(root='../data').to_point()
data = point_view.torch()
graph_view = ProteinFamilyDataset(root='../data').to_graph(eps=8)
graph_data = graph_view.pyg()

# All loaders share the same batching settings and are left unshuffled.
loader_kwargs = dict(batch_size=128, shuffle=False)

# Sequence and point datasets are both derived from the point representation.
# NOTE(review): 500 is presumably the fixed per-protein length (the stacked
# sequence inputs have shape [N, 500, 21]) -- confirm against the dataset classes.
seq_dataset = SequenceDataset(data, 500, return_proteins=True)
seq_dataloader = DataLoader(seq_dataset, **loader_kwargs)

point_dataset = PointDataset(data, 500, return_proteins=True)
point_dataloader = DataLoader(point_dataset, **loader_kwargs)

# Graph samples must be batched with the PyG DataLoader, not the plain one.
graph_dataset = load_graph_data(graph_data, amnino_acids=21)
graph_dataloader = Pyg_DataLoader(graph_dataset, **loader_kwargs)

# First Pfam annotation of every entry in the sequence dataset's protein records.
families = [
    entry[1]['protein']['Pfam'][0]
    for entry in seq_dataset.org_protein_data
]

# Number of proteins per family.
counter = Counter(families)

100%|██████████| 31109/31109 [00:06<00:00, 4660.26it/s]
100%|██████████| 31109/31109 [00:07<00:00, 3975.23it/s]


In [6]:
# Use Apple's Metal backend (MPS) when it is available, otherwise run on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Run inference

In [7]:
# Checkpoints used for this run. The trailing comments keep the alternative
# checkpoint paths that were previously swapped in by hand.
seq_checkpoint_path = '../trained_models/Pfam/BVAE/BETA_EXP/0_LD16_HD512_Beta0.001_BetaInc0.ckpt'  # alt: '../trained_models/Pfam/BVAE/FINAL_MODEL/LD16_HD512_Beta0.005_BetaInc0.ckpt'
point_checkpoint_path = '../trained_models/Pfam/BETA_point_vae/Pfam_BETA_EXP/0_LD64_GF512_Beta0.001_HD512.ckpt'  # alt: "../trained_models/Pfam/PVAE/FINAL_MODEL/FINAL_PVAE_LD64_GF512_BetaInc0_Beta0.005_HD512_CH8.ckpt"
graph_checkpoint_path = '../trained_models/Pfam/GVAE/BETA_EXP/0_LD32_HD512_Beta0.001_GCH16_LR0.0001.ckpt'  # alt: '../trained_models/Pfam/GVAE/2_LD32_HD512_Beta0.005_GCH96_LR0.0001.ckpt'

# Restore the three trained VAEs from their Lightning checkpoints.
seq_model = LitBasicVae.load_from_checkpoint(seq_checkpoint_path)
point_model = PointNetVAE.load_from_checkpoint(point_checkpoint_path)
graph_model = GraphVAE.load_from_checkpoint(graph_checkpoint_path)

# Inference only: switch every model to evaluation mode and place it on
# `device`. The inference loops send each batch to `device` for all three
# models, so all three must live there -- not just the graph model.
seq_model.eval().to(device)
point_model.eval().to(device)
# Last expression of the cell: evaluates to graph_model, so its architecture
# is displayed as the cell output.
graph_model.eval().to(device)

GraphVAE(
  (conv1): GCNConv(21, 16)
  (conv2): GCNConv(16, 32)
  (fc_mu): Linear(in_features=16000, out_features=32, bias=True)
  (fc_logvar): Linear(in_features=16000, out_features=32, bias=True)
  (fc1_dec): Linear(in_features=32, out_features=512, bias=True)
  (fc2_dec_feature): Linear(in_features=512, out_features=10500, bias=True)
  (fc_adj_dec): Linear(in_features=32, out_features=16000, bias=True)
  (tanh): Tanh()
  (sigmoid): Sigmoid()
  (soft): Softmax(dim=-1)
)

In [None]:
# Run every dataset through its model and collect the reconstructions
# (`x_rec`, the fourth value returned by each model) as NumPy arrays.
# `tqdm` is already imported in the first cell.
original_seq_data = []   # input batches of the sequence loader, kept for later comparison
seq_res_mu0 = []         # per-batch sequence-model reconstructions
seq_res_mu1 = []         # currently unused in this cell
seq_res_mu2 = []         # currently unused in this cell

# Inference only: disable autograd so no computation graph is built and held
# for each forward pass.
with torch.no_grad():
    for batch in tqdm(seq_dataloader):
        original_seq_data.append(batch)
        rep_z, x_mu, x_logvar, x_rec, logit = seq_model(batch.to(device))
        seq_res_mu0.append(x_rec.detach().cpu().numpy())

# NOTE: despite its name, this holds the concatenated reconstructions (x_rec),
# not latent codes. The name is kept because the following cell reads it.
seq_latent_res = np.concatenate(seq_res_mu0, axis=0)

point_seq_res = []
with torch.no_grad():
    for batch in tqdm(point_dataloader):
        rep_z, x_mu, x_logvar, x_rec, logit = point_model(batch.to(device))
        point_seq_res.append(x_rec.detach().cpu().numpy())

point_seq_res = np.concatenate(point_seq_res, axis=0)

graph_seq_res = []
with torch.no_grad():
    for batch in tqdm(graph_dataloader):
        # The graph model returns one extra value (the adjacency matrix).
        rep_z, x_mu, x_logvar, x_rec, logit_feature, adj_matrix = graph_model(batch.to(device))
        graph_seq_res.append(x_rec.detach().cpu().numpy())

graph_seq_res = np.concatenate(graph_seq_res, axis=0)

100%|██████████| 225/225 [00:03<00:00, 65.04it/s]
100%|██████████| 225/225 [00:02<00:00, 91.30it/s] 
100%|██████████| 225/225 [00:47<00:00,  4.74it/s]


In [None]:


def _as_class_indices(scores):
    """Collapse per-class scores to class indices along the last axis.

    Args:
        scores: NumPy array or tensor of per-class scores.

    Returns:
        Tensor of argmax indices over the last axis. Input that is not
        floating point is returned unchanged (as a tensor): argmax output is
        integer-typed, so such input is treated as already collapsed.
    """
    scores = torch.as_tensor(scores)
    if not torch.is_floating_point(scores):
        return scores
    return torch.argmax(scores, dim=-1)


# point_seq_res / graph_seq_res are overwritten with their own argmax, so a
# plain argmax would reduce them a second time if this cell were re-run.
# The guard in _as_class_indices makes re-running the cell a no-op.
# Assumes the model reconstructions are floating point -- TODO confirm.
seq_res = _as_class_indices(seq_latent_res)
point_seq_res = _as_class_indices(point_seq_res)
graph_seq_res = _as_class_indices(graph_seq_res)

In [34]:
# Sanity check: stack the collected sequence input batches along the first
# axis and display the resulting shape.
stacked_seq_inputs = torch.vstack(original_seq_data)
stacked_seq_inputs.shape

torch.Size([28733, 500, 21])