In [1]:
import torch
import numpy as np
import torch.nn as nn
import sys
import os
project_root = os.path.abspath("..")  # Adjust if needed
import pytorch_lightning as pl
# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

from src.models.pointNetVae import PointNetVAE
from src.utils.data_utils import *
from src.dataset_classes.pointDataset import *
from proteinshake.datasets import ProteinFamilyDataset
from proteinshake.tasks import LigandAffinityTask
import random
from torch.utils.data import DataLoader, Dataset, Subset
%load_ext autoreload
%autoreload 2

In [2]:
optimizer = torch.optim.AdamW
optimizer_param = {'lr':0.001}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PointNetVAE(2, optimizer, optimizer_param)

In [3]:
from proteinshake.transforms import CenterTransform
dataset = ProteinFamilyDataset(root='../data').to_point().torch()

In [11]:
idx_list = range(len(dataset))
subset_size = int(len(dataset)//10)
val_idx = random.sample(idx_list, subset_size)  # Get random subset
train_idx = list(set(idx_list) - set(val_idx))
s = 500
train_subset = PointDataset(Subset(dataset, train_idx), s)
val_subset = PointDataset(Subset(dataset, val_idx), s)

100%|██████████| 27999/27999 [00:04<00:00, 6654.52it/s]
100%|██████████| 3110/3110 [00:00<00:00, 11817.45it/s]


In [12]:
train_subset[0][:,3].shape

torch.Size([500])

In [73]:
latent_dim = 128
epochs = 30
lr = 0.0001
batch_size = 256
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_subset,batch_size=batch_size, shuffle=False)
x_dim = train_subset[0].shape[0]
if torch.cuda.is_available():
    torch.cuda.current_device()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [80]:
next(iter(train_dataloader))[:,:,3]

tensor([[ 7., 15., 15.,  ..., 20., 20., 20.],
        [12., 16., 10.,  ..., 20., 20., 20.],
        [ 5., 19.,  5.,  ..., 20., 20., 20.],
        ...,
        [14.,  1.,  9.,  ..., 20., 20., 20.],
        [11., 10.,  2.,  ..., 20., 20., 20.],
        [ 8.,  9.,  0.,  ..., 20., 20., 20.]])

In [87]:
import torch.optim.adam
model = PointNetVAE(latent_dim,torch.optim.Adam,{'lr':0.0001}, beta=1, global_feature_size = 1024, conv_hidden_dim = 1024)
model


PointNetVAE(
  (embedding): Embedding(21, 32, padding_idx=20)
  (conv1_label): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
  (conv2_label): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
  (fc1_seq_enc): Linear(in_features=16000, out_features=1024, bias=True)
  (fc1_enc_mu): Linear(in_features=2048, out_features=128, bias=True)
  (fc1_enc_logvar): Linear(in_features=2048, out_features=128, bias=True)
  (conv1): Conv1d(3, 1024, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(2048, 1024, kernel_size=(1,), stride=(1,))
  (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (soft): Softmax(dim=-1)
  (max_pool): MaxPool1d(kernel_size=500, stride=500, padding=0, dilation=1, ceil_mode=False)
  (fc1

In [88]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl
optimizer = torch.optim.Adam
optimizer_param = {'lr':0.001}
trainer = pl.Trainer(max_epochs=epochs,
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs/"))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_dataloader, val_dataloader)


   | Name           | Type        | Params | Mode 
--------------------------------------------------------
0  | embedding      | Embedding   | 672    | train
1  | conv1_label    | Conv1d      | 1.1 K  | train
2  | conv2_label    | Conv1d      | 1.1 K  | train
3  | fc1_seq_enc    | Linear      | 16.4 M | train
4  | fc1_enc_mu     | Linear      | 262 K  | train
5  | fc1_enc_logvar | Linear      | 262 K  | train
6  | conv1          | Conv1d      | 4.1 K  | train
7  | conv2          | Conv1d      | 2.1 M  | train
8  | conv3          | Conv1d      | 2.1 M  | train
9  | bn1            | BatchNorm1d | 2.0 K  | train
10 | bn2            | BatchNorm1d | 4.1 K  | train
11 | bn3            | BatchNorm1d | 2.0 K  | train
12 | relu           | ReLU        | 0      | train
13 | soft           | Softmax     | 0      | train
14 | max_pool       | MaxPool1d   | 0      | train
15 | fc1_dec        | Linear      | 33.0 K | train
16 | fc2_dec        | Linear      | 131 K  | train
17 | fc3_dec        | Li

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['train_elbo_loss', 'train_elbo_loss_step', 'train_rec_loss', 'train_rec_loss_step', 'train_loss', 'train_loss_step', 'val_elbo_loss', 'val_elbo_loss_epoch', 'val_rec_loss', 'val_KL_loss', 'train_elbo_loss_epoch', 'train_rec_loss_epoch', 'train_KL_loss', 'train_loss_epoch', 'learning_rate']. Condition can be set using `monitor` key in lr scheduler dict