In [15]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.tensorboard import SummaryWriter
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
import numpy as np
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import SelectQM9TargetProperties, SelectQM9NodeFeatures
import random
import os

# Silence RDKit: raise its logger threshold so only CRITICAL messages are emitted.
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-sample preprocessing, applied in order each time a sample is fetched.
# The two Select* transforms come from the local data_utils module;
# presumably they keep only the named targets / node features (confirm there).
transform = T.Compose(
    [
        SelectQM9TargetProperties(properties=["homo", "lumo"]),
        SelectQM9NodeFeatures(features=["atom_type"]),
        T.ToDevice(device=device),
    ]
)

# QM9 dataset stored under ./data.
dataset = QM9(root="./data", transform=transform)

def smiles_to_image(smiles: str) -> torch.Tensor:
    """Render a SMILES string as an image tensor with the channel axis first.

    Args:
        smiles: SMILES representation of the molecule to draw.

    Returns:
        The RDKit drawing as a tensor, transposed from HWC to CHW layout.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None instead of
        # raising, so fail here with a message that names the bad input.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    image = np.array(Draw.MolToImage(mol))
    # Convert to CHW format
    return torch.tensor(np.transpose(image, (2, 0, 1)))

def molecule_graph_data_to_image(data) -> torch.Tensor:
    """Rebuild an RDKit molecule from a graph sample and render it as an image.

    Args:
        data: Graph sample exposing ``x`` (one feature row per atom),
            ``edge_index`` and ``edge_attr`` (one feature row per edge).
            Assumes the first five columns of ``x`` one-hot encode the atom
            type in the order H, C, N, O, F, and that ``edge_attr`` one-hot
            encodes the bond type in the order single, double, triple,
            aromatic -- TODO confirm against the dataset/transforms in use.

    Returns:
        The RDKit drawing of the molecule (hydrogens removed) as a tensor,
        transposed from HWC to CHW layout.
    """
    # Start from an empty, editable molecule.
    mol = Chem.RWMol()

    class_index_to_atomic_number = {
        0: 1, 1: 6, 2: 7, 3: 8, 4: 9
    }
    # Add atoms. Read them from the `data` argument (not from a module-level
    # sample) so the function renders whatever graph it is given.
    for atom_features in data.x:
        # Convert the one-hot encoded atom class to the atomic number.
        class_index = torch.argmax(atom_features[:5]).item()
        atomic_number = class_index_to_atomic_number[class_index]
        mol.AddAtom(Chem.Atom(int(atomic_number)))

    # Collapse the edge list into a set of undirected bonds: sorting the two
    # endpoints makes (a, b) and (b, a) map to the same entry, so each bond is
    # added to the molecule only once.
    undirected_bonds = set()
    for (start_atom, end_atom), edge_feature in zip(
        data.edge_index.t().tolist(), data.edge_attr
    ):
        bond_type_index = torch.argmax(edge_feature).item()
        first_atom, second_atom = sorted((start_atom, end_atom))
        undirected_bonds.add((first_atom, second_atom, bond_type_index))

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC
    }
    # Add bonds
    for start_atom, end_atom, bond_type_index in undirected_bonds:
        mol.AddBond(int(start_atom), int(end_atom), bond_type_map[bond_type_index])

    # Check if the molecule is chemically valid; report the problem but keep
    # going so that a drawing is still attempted (best effort).
    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"Chemically invalid molecule! Reason: {e}")

    # Convert to a standard RDKit mol object
    mol = mol.GetMol()

    # Remove hydrogen atoms for visualization.
    # NOTE(review): RemoveHs may re-sanitize and raise for a molecule that
    # failed the check above -- confirm whether that is acceptable.
    mol = Chem.RemoveHs(mol)

    image = np.array(Draw.MolToImage(mol))
    # Convert to CHW format
    return torch.tensor(np.transpose(image, (2, 0, 1)))

# randrange excludes its upper bound, so the drawn index is always valid;
# randint(0, len(dataset)) could return len(dataset), which is out of range.
random_index = random.randrange(len(dataset))
# Pinned to a fixed sample: this overrides the random draw above.
random_index = 32
sample = dataset[random_index]

logdir = "./tb_logs/qm9-visualization"
os.makedirs(logdir, exist_ok=True)
# Number this run by how many entries the log directory already contains,
# zero-padded to three digits (000, 001, ...).
experiment_index = len(os.listdir(logdir))
writer = SummaryWriter(os.path.join(logdir, str(experiment_index).zfill(3)))

# Close the writer even if rendering or logging fails, so the event file is
# flushed and the handle is released.
try:
    writer.add_image('Molecule (from SMILES)', smiles_to_image(sample.smiles), 0)
    writer.add_image('Molecule (from Graph)', molecule_graph_data_to_image(sample), 0)
finally:
    writer.close()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
