<a href="https://colab.research.google.com/github/vicben2/hgraph2graph/blob/main/Partb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/wengong-jin/hgraph2graph

In [None]:
%cd hgraph2graph

In [None]:
!pip install rdkit networkx tqdm
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
# The cells call ``tqdm(...)`` as a function, so bind the name to the tqdm
# class (this deliberately shadows the module imported on the line above).
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Draw, DataStructs
from rdkit import DataStructs
from torch.utils.data import DataLoader

from hgraph import HierVAE, MolGraph, common_atom_vocab, Vocab, PairVocab
from hgraph import MoleculeDataset

In [None]:
class Args:
    """Bundle of model hyper-parameters.

    ``vocab`` starts out as ``None`` and is filled in once the vocabulary
    file has been loaded; every other field is a fixed setting.
    """

    def __init__(self):
        defaults = dict(
            vocab=None,
            atom_vocab=common_atom_vocab,
            rnn_type='LSTM',
            hidden_size=250,
            embed_size=250,
            batch_size=20,
            latent_size=32,
            depthT=15,
            depthG=15,
            diterT=1,
            diterG=3,
            dropout=0.0,
        )
        # Same attributes, same values, same assignment order as a field-by-field __init__.
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

args = Args()

# Seed the RNGs so runs are repeatable.
seed = 7
torch.manual_seed(seed)
random.seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Vocabulary: each non-blank line of vocab.txt is whitespace-split into one
# entry (presumably a "<motif> <attachment-motif>" SMILES pair — confirm).
# Blank lines are skipped so they cannot become empty entries.
vocab_path = 'data/chembl/vocab.txt'
with open(vocab_path) as vocab_file:
    vocab_pairs = [line.strip("\r\n ").split() for line in vocab_file if line.strip()]

# Build the PairVocab once and bind it to both ``vocab`` and ``args.vocab``.
# Later cells pass ``vocab`` to filter_for_vocab (which reads ``.vmap``) and,
# presumably, MoleculeDataset/HierVAE expect the PairVocab object too — not
# the raw list of string pairs.
# ``cuda`` is tied to ``device`` explicitly (presumably controls whether
# PairVocab places its internal tensors on the GPU).
vocab = PairVocab(vocab_pairs, cuda=(device.type == 'cuda'))
args.vocab = vocab

def _tree_in_vocab(hmol, vmap):
    """Return True if every motif and attachment label of ``hmol``'s tree is in ``vmap``."""
    for _, attr in hmol.mol_tree.nodes(data=True):
        if attr['label'] not in vmap:
            return False
        motif_smiles = attr['smiles']
        if any((motif_smiles, s) not in vmap for _, s in attr['inter_label']):
            return False
    return True


def filter_for_vocab(smiles_list, vocab, max_atoms=50):
    """Keep only the SMILES that can be expressed with ``vocab``.

    A molecule is kept when RDKit parses it, ``mol.GetNumAtoms()`` is at most
    ``max_atoms``, RDKit can kekulize it, and every node ``label`` and every
    ``(motif_smiles, attachment_smiles)`` inter-label of its ``MolGraph`` tree
    is present in the vocabulary. Any molecule that raises while being
    checked is skipped (deliberate best-effort filtering).

    Args:
        smiles_list: iterable of SMILES strings.
        vocab: a ``PairVocab`` (anything exposing a ``vmap`` container), or
            the raw list of vocabulary pairs read from vocab.txt.
        max_atoms: largest ``GetNumAtoms()`` value accepted.

    Returns:
        List of the accepted SMILES, in input order.
    """
    # Resolve the lookup container once. Previously a raw pair list made every
    # ``vocab.vmap`` access raise AttributeError inside the try below, so every
    # molecule was silently dropped. The membership tests use
    # (motif, attachment) tuples, so a raw list becomes a set of such tuples
    # (assumed to match how PairVocab keys ``vmap`` — confirm).
    vmap = getattr(vocab, 'vmap', None)
    if vmap is None:
        vmap = {tuple(pair) for pair in vocab}

    valid_smiles = []
    for smi in tqdm(smiles_list, desc="Filtering for vocabulary"):
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None or mol.GetNumAtoms() > max_atoms:
                continue

            # Used purely as a check: raises on molecules RDKit cannot
            # kekulize, which are then skipped by the except clause.
            Chem.Kekulize(mol, clearAromaticFlags=False)

            if _tree_in_vocab(MolGraph(smi), vmap):
                valid_smiles.append(smi)
        except Exception:
            # Best effort: anything RDKit/hgraph cannot process is dropped.
            continue

    return valid_smiles

with open('data/chembl/all.txt', 'r') as f:
    all_smiles = [line.strip() for line in f if line.strip()]

# Filter the full list that was just read (``candidate_smiles`` was never
# defined) against the PairVocab object, not the raw list of pairs.
valid_smiles = filter_for_vocab(all_smiles, args.vocab)

# Split: up to 2000 training molecules, then the next (up to) 200 for test.
# Clamp at 0 so a short list never produces a negative slice bound.
TEST_SIZE = 200
TRAIN_SIZE = max(0, min(2000, len(valid_smiles) - TEST_SIZE))
train_smiles = valid_smiles[:TRAIN_SIZE]
test_smiles = valid_smiles[TRAIN_SIZE:TRAIN_SIZE + TEST_SIZE]
print(f"{len(valid_smiles)} valid molecules -> "
      f"{len(train_smiles)} train / {len(test_smiles)} test")

with open('train_final.txt', 'w') as f:
    f.write('\n'.join(train_smiles))
with open('test_final.txt', 'w') as f:
    f.write('\n'.join(test_smiles))

In [None]:
import hgraph.hgnn as hgnn_module

def make_cuda_fixed(tensors):
    """Device-aware replacement for hgraph's ``make_cuda``.

    ``tensors`` is a ``(tree_tensors, graph_tensors)`` pair of sequences. In
    each one, every entry except the last is converted to a tensor (numpy
    arrays via ``torch.from_numpy``, other non-tensors via ``torch.tensor``),
    moved to the notebook-level ``device`` and cast to ``long``. The last
    entry is passed through unchanged (presumably scope metadata — confirm
    against hgraph).

    Returns:
        Tuple ``(tree_tensors, graph_tensors)`` of lists in the same layout.
    """
    # Local import: the ndarray check needs numpy, and this helper must not
    # depend on a notebook-level ``np`` being defined.
    import numpy as np

    def make_tensor(x):
        if isinstance(x, torch.Tensor):
            return x
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x)
        return torch.tensor(x)

    def to_device(items):
        # All but the last entry become long tensors on ``device``.
        return [make_tensor(x).to(device).long() for x in items[:-1]] + [items[-1]]

    tree_tensors, graph_tensors = tensors
    return to_device(tree_tensors), to_device(graph_tensors)

# Route hgraph's tensor-to-device helper through the device-aware version.
hgnn_module.make_cuda = make_cuda_fixed

import hgraph.nnutils as nnutils_module

# Keep a handle on the library implementation so the wrapper can delegate.
original_index_select_ND = nnutils_module.index_select_ND


def index_select_ND_fixed(source, dim, index):
    """Call hgraph's ``index_select_ND`` after coercing ``index`` to a tensor.

    Tensor indices pass through untouched; anything else is converted to a
    ``torch.long`` tensor on ``source``'s device.
    """
    idx = index if isinstance(index, torch.Tensor) else torch.tensor(
        index, device=source.device, dtype=torch.long
    )
    return original_index_select_ND(source, dim, idx)


# NOTE(review): this rebinds the name inside hgraph.nnutils only; any module
# that already imported index_select_ND by name keeps the original function.
# Confirm the patch actually reaches the intended callers.
nnutils_module.index_select_ND = index_select_ND_fixed

class TrainingArgs:
    """Hyper-parameters for training ``HierVAE`` in this notebook.

    Model-shape fields use the same values as ``Args``; the remaining fields
    control optimisation, KL annealing, checkpointing and logging.
    """

    def __init__(self):
        # Use the PairVocab built in the setup cell (``args.vocab``), never the
        # raw list of string pairs: HierVAE(train_args) reads this field.
        self.vocab = args.vocab
        self.atom_vocab = common_atom_vocab
        self.save_dir = 'checkpoints'  # where model.epoch_<N> files are written

        # Model architecture.
        self.rnn_type = 'LSTM'
        self.hidden_size = 250
        self.embed_size = 250
        self.batch_size = 20
        self.latent_size = 32
        self.depthT = 15
        self.depthG = 15
        self.diterT = 1
        self.diterG = 3
        self.dropout = 0.0

        # Optimisation.
        self.lr = 1e-3         # Adam learning rate
        self.clip_norm = 5.0   # gradient-norm clipping threshold
        self.max_beta = 0.1    # cap for the annealed KL weight

        # Schedule / logging.
        self.num_epochs = 6
        self.print_iter = 20   # print running averages every N steps

train_args = TrainingArgs()
os.makedirs(train_args.save_dir, exist_ok=True)

from hgraph import HierVAE

model = HierVAE(train_args).to(device)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

# Weight init: zero every 1-D parameter, Xavier-normal for everything else.
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

optimizer = optim.Adam(model.parameters(), lr=train_args.lr)

# Save the untrained weights as epoch 0 so evaluation has a baseline.
torch.save(model.state_dict(), os.path.join(train_args.save_dir, "model.epoch_0"))
total_step = 0
beta = 0.01  # KL weight; raised by 0.02 after each epoch, capped at max_beta


def to_tensor(x):
    """Convert one tensorized-batch entry to a torch tensor; lists pass through."""
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    if isinstance(x, list):
        return x
    return torch.tensor(x)


for epoch in range(1, train_args.num_epochs + 1):
    # Dataset is rebuilt every epoch; the loader shuffles the batch order.
    dataset = MoleculeDataset(train_smiles, args.vocab, common_atom_vocab, train_args.batch_size)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=lambda x: x[0])

    model.train()
    epoch_losses = []
    epoch_kl = []
    epoch_wacc = []
    epoch_tacc = []

    successful_batches = 0
    failed_batches = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for batch_idx, batch in enumerate(pbar):
        total_step += 1

        try:
            model.zero_grad()
            graphs, tensors, orders = batch
            tree_tensors, graph_tensors = tensors
            tensors = (
                [to_tensor(x) for x in tree_tensors],
                [to_tensor(x) for x in graph_tensors],
            )

            # Forward pass.
            loss, kl_div, wacc, iacc, tacc, sacc = model(graphs, tensors, orders, beta=beta)

            # Backward pass with gradient-norm clipping.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), train_args.clip_norm)
            optimizer.step()
        except Exception as exc:
            # Skip batches the model fails on, but report them instead of
            # hiding them. Gradients left over from a partial step are cleared
            # by zero_grad() at the start of the next batch.
            failed_batches += 1
            print(f"\n  [{total_step}] Skipped batch {batch_idx}: {exc!r}")
            continue

        # Record per-batch metrics (accuracies stored as percentages).
        epoch_losses.append(loss.item())
        epoch_kl.append(kl_div)
        epoch_wacc.append(wacc * 100)
        epoch_tacc.append(tacc * 100)

        successful_batches += 1

        if total_step % train_args.print_iter == 0 and len(epoch_losses) >= train_args.print_iter:
            recent_loss = np.mean(epoch_losses[-train_args.print_iter:])
            recent_kl = np.mean(epoch_kl[-train_args.print_iter:])
            recent_wacc = np.mean(epoch_wacc[-train_args.print_iter:])
            recent_tacc = np.mean(epoch_tacc[-train_args.print_iter:])
            print(f"\n  [{total_step}] Loss: {recent_loss:.3f}, KL:  {recent_kl:.2f}, "
                  f"Word Acc: {recent_wacc:.1f}%, Topo Acc: {recent_tacc:.1f}%")

    # Anneal the KL weight.
    beta = min(train_args.max_beta, beta + 0.02)

    # Checkpoint after every epoch.
    ckpt_path = os.path.join(train_args.save_dir, f"model.epoch_{epoch}")
    torch.save(model.state_dict(), ckpt_path)

    avg_loss = np.mean(epoch_losses) if epoch_losses else 0
    print(f"\nEpoch {epoch}")
    print(f"Avg Loss: {avg_loss:.4f}")
    print(f"Batches: {successful_batches} ok, {failed_batches} skipped")

In [None]:
def calculate_tanimoto(smiles1, smiles2, radius=2, nBits=2048):
    """Return the Tanimoto similarity of two molecules' Morgan bit fingerprints.

    Both SMILES are parsed with RDKit; if either fails to parse the result is
    ``None``. Fingerprints are Morgan bit vectors of the given ``radius`` and
    length ``nBits``.
    """
    parsed = [Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)]
    if any(mol is None for mol in parsed):
        return None

    fp_a, fp_b = [
        AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        for mol in parsed
    ]
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)

def is_exact_match(smiles1, smiles2):
    """Return True when both SMILES parse to molecules with identical canonical SMILES.

    Returns False if either string fails to parse.
    """
    parsed = [Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)]
    if any(mol is None for mol in parsed):
        return False

    canonical_a, canonical_b = (Chem.MolToSmiles(mol) for mol in parsed)
    return canonical_a == canonical_b

def is_valid_molecule(smiles):
    """Return True if ``smiles`` is neither ``None`` nor empty and RDKit can parse it."""
    is_blank = smiles is None or smiles == ""
    return False if is_blank else Chem.MolFromSmiles(smiles) is not None

def evaluate_model_on_test(model, test_smiles, vocab, atom_vocab, batch_size=20):
    """Reconstruct ``test_smiles`` with ``model`` and score the reconstructions.

    The SMILES are grouped into batches by ``MoleculeDataset``; each batch is
    passed to ``model.reconstruct`` and every decoded SMILES is compared with
    the input SMILES it came from.

    Args:
        model: model exposing ``eval()`` and ``reconstruct(batch)``.
        test_smiles: SMILES strings to evaluate.
        vocab, atom_vocab: vocabularies forwarded to ``MoleculeDataset``.
        batch_size: molecules per batch.

    Returns:
        dict with
          ``exact_match``: fraction of evaluated molecules whose decoded
            SMILES is canonically identical to the input;
          ``mean_tanimoto`` / ``median_tanimoto``: over valid decodes only
            (0 if there are none);
          ``validity``: fraction of evaluated molecules decoded to a SMILES
            RDKit can parse;
          ``n_evaluated``: molecules in the dataset's batches;
          ``n_failed_batches``: batches whose reconstruction raised. Their
            molecules still count in ``n_evaluated`` (as invalid, non-matching)
            so failures lower the scores instead of vanishing from the
            denominator.
    """
    model.eval()

    dataset = MoleculeDataset(test_smiles, vocab, atom_vocab, batch_size)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x[0])

    exact_matches = 0
    valid_outputs = 0
    tanimoto_scores = []
    n_evaluated = 0
    n_failed_batches = 0

    with torch.no_grad():
        # shuffle=False, so batches arrive in dataset.batches order and
        # batch_idx identifies the SMILES that produced each batch.
        for batch_idx, batch in enumerate(loader):
            orig_batch = dataset.batches[batch_idx]
            n_evaluated += len(orig_batch)

            try:
                dec_smiles = model.reconstruct(batch)
            except Exception as exc:
                n_failed_batches += 1
                print(f"Reconstruction failed for test batch {batch_idx}: {exc!r}")
                continue

            for orig, dec in zip(orig_batch, dec_smiles):
                if not is_valid_molecule(dec):
                    continue
                valid_outputs += 1

                tanimoto = calculate_tanimoto(orig, dec)
                if tanimoto is not None:
                    tanimoto_scores.append(tanimoto)

                if is_exact_match(orig, dec):
                    exact_matches += 1

    return {
        'exact_match': exact_matches / n_evaluated if n_evaluated > 0 else 0,
        'mean_tanimoto': np.mean(tanimoto_scores) if tanimoto_scores else 0,
        'median_tanimoto': np.median(tanimoto_scores) if tanimoto_scores else 0,
        'validity': valid_outputs / n_evaluated if n_evaluated > 0 else 0,
        'n_evaluated': n_evaluated,
        'n_failed_batches': n_failed_batches,
    }

In [None]:
# Collect the epoch number of every checkpoint the training cell wrote
# ("model.epoch_<N>"); files with the prefix but a non-numeric suffix are
# ignored instead of crashing int().
CKPT_PREFIX = 'model.epoch_'
checkpoint_epochs = []
for f in os.listdir(train_args.save_dir):
    if f.startswith(CKPT_PREFIX):
        suffix = f[len(CKPT_PREFIX):]
        if suffix.isdigit():
            checkpoint_epochs.append(int(suffix))

checkpoint_epochs = sorted(checkpoint_epochs)
print(f"Found checkpoints: {checkpoint_epochs}")

results = []

for epoch in tqdm(checkpoint_epochs, desc="Evaluating checkpoints"):
    ckpt_path = os.path.join(train_args.save_dir, f"model.epoch_{epoch}")

    eval_model = HierVAE(train_args).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
    eval_model.load_state_dict(torch.load(ckpt_path, map_location=device))
    eval_model.eval()

    # args.vocab is the PairVocab built in the setup cell.
    metrics = evaluate_model_on_test(eval_model, test_smiles, args.vocab, common_atom_vocab)
    metrics['epoch'] = epoch
    results.append(metrics)

    print(f"Epoch {epoch}:  Exact={metrics['exact_match']:.3f}, "
          f"Tanimoto={metrics['mean_tanimoto']:.3f}, Valid={metrics['validity']:.3f}")

    del eval_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

checkpoint_df = pd.DataFrame(results)
# An empty DataFrame has no 'epoch' column, so only sort when results exist.
if not checkpoint_df.empty:
    checkpoint_df = checkpoint_df.sort_values('epoch').reset_index(drop=True)

In [None]:
#graph
# Two panels against checkpoint epoch: exact-match accuracy and Tanimoto
# similarity. Only two panels are drawn, so use a 1x2 grid (a 2x2 grid
# leaves two empty axes in the saved figure).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

#exact match acc vs training step
ax1 = axes[0]
ax1.plot(checkpoint_df['epoch'], checkpoint_df['exact_match'] * 100,
         'o-', color='steelblue', linewidth=2, markersize=10)
ax1.set_xlabel('Training Epoch', fontsize=12)
ax1.set_ylabel('Exact Match Accuracy (%)', fontsize=12)
ax1.set_title('Exact Match Accuracy vs Training Progress', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xticks(checkpoint_df['epoch'])

#tanimoto mean/median vs training step
ax2 = axes[1]
ax2.plot(checkpoint_df['epoch'], checkpoint_df['mean_tanimoto'],
         's-', color='coral', linewidth=2, markersize=10, label='Mean')
ax2.plot(checkpoint_df['epoch'], checkpoint_df['median_tanimoto'],
         '^--', color='green', linewidth=2, markersize=10, label='Median')
ax2.set_xlabel('Training Epoch', fontsize=12)
ax2.set_ylabel('Tanimoto Similarity', fontsize=12)
ax2.set_title('Tanimoto Similarity vs Training Progress', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_xticks(checkpoint_df['epoch'])
ax2.set_ylim(0, 1.05)

plt.tight_layout()
plt.savefig('part_b_checkpoint_dynamics.png', dpi=150, bbox_inches='tight')
plt.show()