In [1]:
import torch
import numpy as np
import torch.nn as nn
import sys
import os
# Absolute path of the parent of the current working directory. The `src.*`
# imports further down rely on this being the directory that contains `src/`
# (assumes the notebook is run from a sub-folder of the repo — TODO confirm).
project_root = os.path.abspath("..")  # Adjust if the notebook is moved
import pytorch_lightning as pl
# Append the project root to sys.path once; the membership check keeps
# re-running this cell from adding duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.utils.data_utils import *
from src.dataset_classes.pointDataset import *
from proteinshake.datasets import ProteinFamilyDataset
from proteinshake.tasks import LigandAffinityTask
import random
from src.models.graphVAE import GraphVAE
from torch.utils.data import Dataset, Subset
from src.utils.data_utils import *
from src.dataset_classes.graphDataset import *
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import GAE, VGAE, GCNConv, TopKPooling, global_mean_pool, InnerProductDecoder
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.transforms import Pad

%load_ext autoreload
%autoreload 2

In [2]:
# Build the ProteinShake protein-family dataset rooted at ../data, convert it
# to graphs (eps=8 — presumably the neighbour distance cutoff; verify against
# the ProteinShake docs), export it in PyTorch Geometric form, and hand the
# result to the project helper `load_graph_data`.
dataset = load_graph_data(
    ProteinFamilyDataset(root='../data')
    .to_graph(eps=8)
    .pyg()
)

In [3]:
from torch_geometric.loader import DataLoader

# Mini-batch size; the train/validation loaders built later reuse this value.
batch_size = 128

# Unshuffled loader over the whole dataset. Only its first batch is taken,
# which a later cell feeds through the model as a forward-pass check.
loader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)
test_batch = next(iter(loader))

In [4]:
# Random ~90/10 train/validation split over dataset indices.
#
# Fixes relative to the previous version of this cell:
#   * `Subset(...).dataset` is the *full* underlying dataset, so both loaders
#     used to iterate over every graph and the split was silently discarded.
#     The Subset object itself is now passed to the DataLoader.
#   * The validation loader was built from `train_idx`; it now uses `val_idx`.
#   * The split is drawn from a dedicated seeded generator so it is identical
#     across kernel restarts (and does not disturb the global `random` state).
SPLIT_SEED = 0

idx_list = range(len(dataset))
subset_size = len(dataset) // 10  # size of the held-out validation set
val_idx = random.Random(SPLIT_SEED).sample(idx_list, subset_size)
# Everything not held out is used for training; sorted for a stable order.
train_idx = sorted(set(idx_list) - set(val_idx))

train_dataloader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False)

In [5]:
# Experiment hyperparameters.
latent_dim = 32  # first positional argument given to GraphVAE
epochs = 30      # used as max_epochs when the Trainer is created
# NOTE(review): `lr` is not referenced by the cells visible here — the model
# is constructed with a literal lr of 0.001. Confirm which value is intended.
lr = 1e-4

# When CUDA is present, query the current device (return value unused) and
# select the first GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.current_device()
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [9]:
from src.models.graphVAE import GraphVAE

# Instantiate the model. The three positional arguments are the latent size,
# an optimizer class and a dict of optimizer settings (presumably consumed by
# the model's optimizer setup — verify in src/models/graphVAE.py).
gvae = GraphVAE(
    latent_dim,
    torch.optim.Adam,
    dict(lr=0.001),
    conv_hidden_dim=16,
    hidden_dim=256,
    beta=0.05,
    beta_increment=0,
)

# Run the first batch through the freshly built model before training starts.
test_out = gvae(test_batch)

In [10]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl

# NOTE(review): `optimizer`, `optimizer_param` and `EarlyStopping` are not
# used by any cell visible here (the model was already built with its own
# optimizer settings) — confirm whether they are still needed.
optimizer_param = {'lr': 0.001}
optimizer = torch.optim.Adam

# Lightning trainer: runs for at most `epochs` epochs, lets Lightning choose
# the accelerator and devices, and writes TensorBoard logs under logs/.
trainer = pl.Trainer(
    logger=TensorBoardLogger(save_dir="logs/"),
    accelerator="auto",
    devices="auto",
    max_epochs=epochs,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train the model, running validation on the validation loader.
trainer.fit(
    model=gvae,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)


  | Name            | Type    | Params | Mode 
----------------------------------------------------
0 | conv1           | GCNConv | 336    | train
1 | conv2           | GCNConv | 544    | train
2 | fc_mu           | Linear  | 512 K  | train
3 | fc_logvar       | Linear  | 512 K  | train
4 | fc1_dec         | Linear  | 8.4 K  | train
5 | fc2_dec_feature | Linear  | 2.6 M  | train
6 | fc_adj_dec      | Linear  | 528 K  | train
7 | tanh            | Tanh    | 0      | train
8 | sigmoid         | Sigmoid | 0      | train
9 | soft            | Softmax | 0      | train
----------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.526    Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/ProteinManifoldLearning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/opt/anaconda3/envs/ProteinManifoldLearning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
