In [1]:
import torch

In [2]:
import pandas as pd

In [3]:
data = pd.read_csv('/home/bio_science/drug_generation/vae_denovo/data/zinc_250k_splits.csv')

In [16]:
data.iloc[1]['smiles']

'CC1CC(C)CC(Nc2cncc(-c3nncn3C)c2)C1'

In [17]:
from rdkit import Chem

In [18]:
mol = Chem.MolFromSmiles(data.iloc[1]['smiles'])

In [19]:
from torch_geometric.utils import from_smiles

In [20]:
d = from_smiles(data.iloc[1]['smiles'])

In [21]:
d

Data(x=[21, 9], edge_index=[2, 46], edge_attr=[46, 3], smiles='CC1CC(C)CC(Nc2cncc(-c3nncn3C)c2)C1')

In [14]:
from torch_geometric.nn import GCNConv, global_mean_pool

In [15]:
d.size()

(24, 24)

In [22]:
layer = GCNConv(24, 32)

In [25]:
layer(d.x, d.edge_index)

RuntimeError: Found dtype Long but expected Float

In [56]:
import argparse
import math
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils import from_smiles


class SmilesDataset(Dataset):
    """Read a CSV and convert each SMILES to a PyG graph Data object."""

    def __init__(self, csv_path: str, transform=None):
        super().__init__(transform=transform)
        self.df = pd.read_csv(csv_path)
        if "smiles" not in self.df.columns:
            raise ValueError("CSV must contain a 'smiles' column")

    def len(self):
        return len(self.df)

    def get(self, idx):
        smiles = self.df.iloc[idx]["smiles"]
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES at idx {idx}: {smiles}")
        # Convert to a torch_geometric Data object
        data = from_smiles(smiles)

        # Global graph‑level feature: Morgan fingerprint (2048‑bit)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
        fp = torch.tensor(fp, dtype=torch.float)
        data.y = fp  # store reconstruction target on graph level
        return data


    def collate_fn(self, batch):
        batch = Batch.from_data_list(batch)
        batch.x = batch.x.float()
        if batch.edge_attr is not None:
            batch.edge_attr = batch.edge_attr.float()
        return batch

In [47]:
def collate_float(batch):
    print("=== collate_float called ===") 
    batch = Batch.from_data_list(batch)
    batch.x = batch.x.float()
    if batch.edge_attr is not None:
        batch.edge_attr = batch.edge_attr.float()
    return batch

In [57]:
def to_float(data: Data) -> Data:
    if data.x is not None:
        data.x = data.x.float()
    if data.edge_attr is not None:
        data.edge_attr = data.edge_attr.float()
    return data

In [58]:
dataset = SmilesDataset('/home/bio_science/drug_generation/vae_denovo/data/zinc_250k_splits.csv', transform=to_float)

In [36]:
from torch_geometric.loader import DataLoader as GeoLoader

In [59]:
dl = GeoLoader(dataset, batch_size=32, shuffle=True)

In [51]:
print(dl.__class__, dl.__class__.__module__)

<class 'torch_geometric.loader.dataloader.DataLoader'> torch_geometric.loader.dataloader


In [53]:
print("Loader.collate_fn:", dl.collate_fn)

Loader.collate_fn: <torch_geometric.loader.dataloader.Collater object at 0x7f82fd1b57e0>


In [55]:
import torch_geometric
print(torch_geometric.__version__)

2.6.1


In [60]:
for data in dl:
    print(data)
    break

DataBatch(x=[740, 9], edge_index=[2, 1598], edge_attr=[1598, 3], smiles=[32], y=[65536], batch=[740], ptr=[33])


In [61]:
data

DataBatch(x=[740, 9], edge_index=[2, 1598], edge_attr=[1598, 3], smiles=[32], y=[65536], batch=[740], ptr=[33])

In [62]:
data.x.dtype

torch.float32

In [41]:
sample.x.size(1)

9

In [42]:
conv1 = GCNConv(9, 32)

In [43]:
sample

Data(x=[21, 9], edge_index=[2, 46], edge_attr=[46, 3], smiles='CC1CC(C)CC(Nc2cncc(-c3nncn3C)c2)C1', y=[2048])

In [7]:
out = conv1(sample.x, sample.edge_index)

NameError: name 'conv1' is not defined

In [47]:
out.shape

torch.Size([21, 32])

In [8]:
class GNNEncoder(nn.Module):
    def __init__(self, num_node_feats, hidden_dim=128, latent_dim=64):
        super().__init__()
        self.conv1 = GCNConv(num_node_feats, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.mu_lin = nn.Linear(hidden_dim, latent_dim)
        self.logvar_lin = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        mu = self.mu_lin(x)
        logvar = self.logvar_lin(x)
        return mu, logvar

In [9]:
class Decoder(nn.Module):
    def __init__(self, latent_dim=64, out_dim=2048):
        super().__init__()
        self.lin1 = nn.Linear(latent_dim, 128)
        self.lin2 = nn.Linear(128, 256)
        self.lin3 = nn.Linear(256, out_dim)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x

In [None]:
class GraphVAE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        num_node_feats = cfg.num_node_feats
        hidden_dim=cfg.hidden_dim
        latent_dim=cfg.latent_dim

        self.encoder = GNNEncoder(num_node_feats, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, out_dim=2048)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, data):
        mu, logvar = self.encoder(data.x, data.edge_index, data.batch)
        z = self.reparameterize(mu, logvar)
        recon_y = self.decoder(z)
        return recon_y, mu, logvar
        

In [64]:
gvae = GraphVAE(9)

In [14]:
data.x

tensor([[6, 0, 4,  ..., 4, 0, 0],
        [6, 0, 4,  ..., 4, 0, 0],
        [6, 0, 4,  ..., 4, 0, 0],
        ...,
        [6, 0, 3,  ..., 3, 0, 1],
        [8, 0, 1,  ..., 3, 0, 0],
        [7, 0, 3,  ..., 3, 0, 1]])

In [68]:
a,b,c = gvae(data)

In [70]:
a.shape, b.shape, c.shape

(torch.Size([32, 2048]), torch.Size([32, 64]), torch.Size([32, 64]))

In [None]:
b.sh

In [None]:
class VaeTrainer():
    def __init__(self, cfg):
        self.cfg = cfg

    def _build_model(self):
        if cfg.model == 'gnn':
            self.model = GraphVAE(cfg)
        else:
            raise ValueError(f"Unknown model: {cfg.model}")

        params_groups = [
            {"params": self.model.parameters(), "lr": cfg.lr}
        ]

        if cfg.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(params_groups)
        elif cfg.optimizer == 'sgd':
            self.optimizer = torch.optim.SGD(params_groups)
        else:
            raise ValueError(f"Unknown optimizer: {cfg.optimizer}")

    def train(self, train_dataloader, val_dataloader=None, test_dataloader=None):
        self.model.train()
        for epoch in range(self.epochs):
            self.train_one_epoch(train_dataloader, val_dataloader)

    def train_one_epoch(self, train_dataloader, val_dataloader=None):
        total = 0.0
        for idx, batch in enumerate(train_dataloader):
            batch = batch.to(self.cfg.device, pin_memory=True, non_blocking=True)
            self.optimizer.zero_grad()
            recon, mu, logvar = self.model(batch)
            loss = self.loss_fn(recon, batch.y.to(self.cfg.device), mu, logvar)
            loss.backward()
            self.optimizer.step()
            total += loss.item()

        return total / len(train_dataloader.dataset)


    def loss_fn(self, recon, target, mu, logvar):
        bce = F.binary_cross_entropy_with_logits(recon, target, reduction="sum")
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kl