<a href="https://www.kaggle.com/code/phoenix301123/graph-diffusion-transformer-for-multi-conditional?scriptVersionId=280985356" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install torch torch_geometric tqdm

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collectin

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
import numpy as np
import os
import math
from tqdm import tqdm

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [4]:
ATOM_TYPES = [6, 7, 8, 9, 0] # C, N, O, F, Mask
NUM_NODE_FEATURES = 11
NUM_EDGE_FEATURES = 4 # Single, Double, Triple, Aromatic

In [5]:
HIDDEN_DIM = 128
NUM_LAYERS = 4
NUM_HEADS = 4
BATCH_SIZE = 64
EPOCHS = 20
LEARNING_RATE = 1e-4
MAX_ATOMS = 9

In [6]:
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Used to embed the diffusion time step (t) into a continuous vector.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=time.device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


In [7]:
class ConditionEncoder(nn.Module):
    """
    Maps the molecular property (condition) to a high-dimensional embedding.
    The condition is a single normalized QM9 property (e.g., U0).
    """
    def __init__(self, condition_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, condition):
        return self.encoder(condition)

In [8]:
class GraphDiTBlock(nn.Module):
    """
    A conceptual Graph Transformer Block with FiLM conditioning (similar to DiT).
    Processes node features and integrates time/condition embeddings.
    """
    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, 
            num_heads=num_heads, 
            dropout=dropout, 
            batch_first=True
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.film_gamma = nn.Linear(hidden_dim, hidden_dim)
        self.film_beta = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, batch_mask, cond_embedding):
        gamma = self.film_gamma(cond_embedding)
        beta = self.film_beta(cond_embedding)   
        x_modulated = x * (1 + gamma[:, None, :]) + beta[:, None, :]
        q = k = v = x_modulated
        
        attn_mask = ~batch_mask
        key_padding_mask = ~batch_mask
        attn_output, _ = self.attn(
            q, k, v, 
            key_padding_mask=key_padding_mask
        )
        x = x + attn_output 
        x_modulated = self.norm2(x)
        x_modulated = x_modulated * (1 + gamma[:, None, :]) + beta[:, None, :]
        x = x + self.ffn(x_modulated)

        return x

In [9]:
class GraphDiT(nn.Module):
    """
    The complete Graph Diffusion Transformer Denoiser Model.
    """
    def __init__(self, num_node_features, num_edge_features, hidden_dim, num_layers, num_heads, max_atoms, conditional_dim=1):
        super().__init__()
        
        self.max_atoms = max_atoms
        self.hidden_dim = hidden_dim
        self.node_embed = nn.Linear(num_node_features, hidden_dim)
        self.time_embed = SinusoidalPositionEmbeddings(hidden_dim)
        self.cond_encoder = ConditionEncoder(conditional_dim, hidden_dim)

        self.transformer_blocks = nn.ModuleList([
            GraphDiTBlock(hidden_dim, num_heads) for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.final_projection = nn.Linear(hidden_dim, num_node_features)

    def forward(self, x_noisy, t, condition, batch):
        x = self.node_embed(x_noisy)
        x_dense, batch_mask = to_dense_batch(x, batch) 
        time_emb = self.time_embed(t)             
        cond_emb = self.cond_encoder(condition)     
        joint_cond_emb = time_emb + cond_emb
        for block in self.transformer_blocks:
            x_dense = block(x_dense, batch_mask, joint_cond_emb)
        x_dense = self.final_norm(x_dense)
        output_dense = self.final_projection(x_dense) 
        output = output_dense[batch_mask].contiguous()
        return output

In [10]:
class GraphDiffusion(nn.Module):
    """
    Implements the simplified DDPM forward process and training loss.
    Operates on a discrete (categorical) noise model for node types (atoms).
    The loss is MSE between predicted noise and actual noise (epsilon).
    """
    def __init__(self, model, num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
        super().__init__()
        self.model = model
        self.num_timesteps = num_timesteps
        self.register_buffer(
            'betas', torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
        )
        self.register_buffer('alphas', 1.0 - self.betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))

    def forward_diffusion(self, x_start, t, noise=None):
        """
        Applies noise to the original node features (x_start).
        Assumes continuous features for the sake of simplicity in DDPM framework, 
        where x is a feature vector.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha_prod_t = self.sqrt_alphas_cumprod[t].view(-1, 1)
        sqrt_one_minus_alpha_prod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
        x_noisy = sqrt_alpha_prod_t * x_start + sqrt_one_minus_alpha_prod_t * noise
        return x_noisy, noise

    def get_loss(self, data):
        """
        The DDPM training objective: predict the noise (epsilon) added.
        """
        x_start = data.x.float()
        batch_size = data.y.size(0)
        t = torch.randint(0, self.num_timesteps, (batch_size,), device=DEVICE).long()
        batch_index = data.batch
        t_nodes = t[batch_index]
        condition_unnorm = data.y[:, 2].unsqueeze(1).float() # [B, 1]
        MEAN, STD = 40.0, 10.0
        condition = (condition_unnorm - MEAN) / STD
        x_noisy, noise = self.forward_diffusion(x_start, t_nodes)
        noise_pred = self.model(x_noisy, t, condition, data.batch)
        loss = F.mse_loss(noise_pred, noise, reduction='none')
        loss = loss.mean()
        
        return loss
        
    @torch.no_grad()
    def reverse_diffusion(self, condition, num_nodes, steps=None):
        """
        The DDPM sampling process: Denoises data iteratively from pure noise (x_T) to x_0.
        
        Args:
            condition (torch.Tensor): The normalized property constraint ([1, 1]).
            num_nodes (int): The number of nodes (atoms) in the molecule to generate.
            steps (int): Number of steps to run the reverse process. Defaults to num_timesteps.
        """
        steps = steps if steps is not None else self.num_timesteps
        x_t = torch.randn(num_nodes, NUM_NODE_FEATURES, device=self.betas.device).float()
        batch_tensor = torch.zeros(num_nodes, dtype=torch.long, device=self.betas.device)
        cond_batch = condition.expand(1, -1) 

        for t in tqdm(reversed(range(1, steps + 1)), desc="Sampling", total=steps):
            time_tensor = torch.tensor([t], device=self.betas.device).long()
            noise_pred = self.model(x_t, time_tensor, cond_batch, batch_tensor)
            t_idx = t - 1
            alpha_t_val = self.alphas[t_idx]
            beta_t = self.betas[t_idx]
            sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t_idx]
            sqrt_recip_alpha_t = 1.0 / alpha_t_val.sqrt()
            mu_t = sqrt_recip_alpha_t * (x_t - beta_t / sqrt_one_minus_alpha_cumprod_t * noise_pred)
            if t > 1:
                sigma_t = beta_t.sqrt() 
                z = torch.randn_like(x_t)
                x_t = mu_t + sigma_t * z
            else:
                x_t = mu_t
        return x_t.float()

In [11]:
class QM9PreTransform:
    """
    A pre_transform to discretize node features (atom types) and select target.
    This prepares the data for a categorical graph diffusion model.
    """
    def __init__(self, max_atoms):
        self.max_atoms = max_atoms
        self.atom_types = ATOM_TYPES
        
    def __call__(self, data: Data) -> Data:
        return data

In [12]:
def load_data(root='./data/QM9_DiT'):
    print("Loading QM9 dataset...")

    transform = QM9PreTransform(max_atoms=MAX_ATOMS)
    dataset = QM9(root=root, pre_filter=transform)
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset = dataset[:train_size]
    val_dataset = dataset[train_size:train_size + val_size]
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True, 
        num_workers=2, 
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=2, 
        pin_memory=True
    )
    
    print(f"Dataset loaded: {len(dataset)} total samples.")
    print(f"Train/Val samples: {len(train_dataset)} / {len(val_dataset)}")
    
    return train_loader, val_loader

In [13]:
def train_model(train_loader, val_loader):
    # Initialize Model, Diffusion Process, Optimizer
    dit_model = GraphDiT(
        num_node_features=NUM_NODE_FEATURES,
        num_edge_features=NUM_EDGE_FEATURES,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        max_atoms=MAX_ATOMS,
        conditional_dim=1
    ).to(DEVICE)

    diffusion_process = GraphDiffusion(dit_model).to(DEVICE)
    optimizer = torch.optim.Adam(dit_model.parameters(), lr=LEARNING_RATE)
    
    print("\nStarting training...")
    
    best_val_loss = float('inf')
    KAGGLE_SAVE_DIR = '/kaggle/working/checkpoint'
    checkpoint_path = os.path.join(KAGGLE_SAVE_DIR, 'best_graph_dit_qm9.pth')
    os.makedirs(KAGGLE_SAVE_DIR, exist_ok=True)
    for epoch in range(1, EPOCHS + 1):
        # Training Phase
        dit_model.train()
        train_loss_sum = 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
        for data in train_pbar:
            data = data.to(DEVICE)
            if data is None: continue 
            if data.x.size(0) == 0: continue
            data.x = data.x.to(torch.float32)
            optimizer.zero_grad()

            loss = diffusion_process.get_loss(data)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(dit_model.parameters(), 1.0)
            optimizer.step()
            
            train_loss_sum += loss.item()
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_train_loss = train_loss_sum / len(train_loader)
        dit_model.eval()
        val_loss_sum = 0
        
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch} [Valid]", leave=False)
        
        with torch.no_grad():
            for data in val_pbar:
                data = data.to(DEVICE)
                if data is None: continue
                if data.x.size(0) == 0: continue
                
                loss = diffusion_process.get_loss(data)
                val_loss_sum += loss.item()
                val_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_val_loss = val_loss_sum / len(val_loader)

        print(f"Epoch {epoch}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(dit_model.state_dict(), checkpoint_path)
            print(f"-> Saved best model checkpoint to {checkpoint_path} (Val Loss: {best_val_loss:.4f})")

In [14]:
if __name__ == '__main__':
    train_loader, val_loader = load_data()
    train_model(train_loader, val_loader)

Loading QM9 dataset...


Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting data/QM9_DiT/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


Dataset loaded: 130831 total samples.
Train/Val samples: 104664 / 13083

Starting training...


                                                                               

Epoch 1/20 | Train Loss: 0.0888 | Val Loss: 0.0433
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0433)


                                                                               

Epoch 2/20 | Train Loss: 0.0475 | Val Loss: 0.0336
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0336)


                                                                               

Epoch 3/20 | Train Loss: 0.0417 | Val Loss: 0.0353


                                                                               

Epoch 4/20 | Train Loss: 0.0383 | Val Loss: 0.0307
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0307)


                                                                               

Epoch 5/20 | Train Loss: 0.0362 | Val Loss: 0.0285
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0285)


                                                                               

Epoch 6/20 | Train Loss: 0.0346 | Val Loss: 0.0291


                                                                               

Epoch 7/20 | Train Loss: 0.0334 | Val Loss: 0.0294


                                                                               

Epoch 8/20 | Train Loss: 0.0321 | Val Loss: 0.0271
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0271)


                                                                               

Epoch 9/20 | Train Loss: 0.0315 | Val Loss: 0.0264
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0264)


                                                                                

Epoch 10/20 | Train Loss: 0.0310 | Val Loss: 0.0269


                                                                                

Epoch 11/20 | Train Loss: 0.0305 | Val Loss: 0.0275


                                                                                

Epoch 12/20 | Train Loss: 0.0302 | Val Loss: 0.0264


                                                                                

Epoch 13/20 | Train Loss: 0.0294 | Val Loss: 0.0261
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0261)


                                                                                

Epoch 14/20 | Train Loss: 0.0288 | Val Loss: 0.0256
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0256)


                                                                                

Epoch 15/20 | Train Loss: 0.0285 | Val Loss: 0.0247
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0247)


                                                                                

Epoch 16/20 | Train Loss: 0.0279 | Val Loss: 0.0250


                                                                                

Epoch 17/20 | Train Loss: 0.0276 | Val Loss: 0.0244
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0244)


                                                                                

Epoch 18/20 | Train Loss: 0.0274 | Val Loss: 0.0246


                                                                                

Epoch 19/20 | Train Loss: 0.0274 | Val Loss: 0.0243
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0243)


                                                                                

Epoch 20/20 | Train Loss: 0.0270 | Val Loss: 0.0238
-> Saved best model checkpoint to /kaggle/working/checkpoint/best_graph_dit_qm9.pth (Val Loss: 0.0238)


In [15]:
import os
import shutil

cache_dir = './data/QM9_DiT' 

print(f"Checking for cache directory: {cache_dir}")
if os.path.exists(cache_dir):
    print("Found cache. Deleting old cache to force full dataset re-processing...")
    shutil.rmtree(cache_dir)
    print("Cache deleted. Run the training script now to download the full QM9 dataset (~134k samples).")
else:
    print("Cache directory not found. Proceeding with load.")

Checking for cache directory: ./data/QM9_DiT
Found cache. Deleting old cache to force full dataset re-processing...
Cache deleted. Run the training script now to download the full QM9 dataset (~134k samples).


In [16]:
def interpret_continuous_features(x_continuous):
    print("\n--- Interpreting Final Denoised Node Features ---")
    num_generated_atoms = len(x_continuous)
    atomic_number_feature = x_continuous[:, 0].cpu().numpy()
    
    print(f"Number of generated atoms: {len(x_continuous)}")
    print(f"First Atomic Number features (continuous output): {atomic_number_feature[:(num_generated_atoms)]}")
    suggested_atomic_numbers = np.round(atomic_number_feature).astype(int)
    print(f"First Suggested Atomic Numbers (Rounding the continuous output): {suggested_atomic_numbers[:(num_generated_atoms)]}")
    print("These features would need a dedicated classifier/decoder to map to valid discrete molecular structures.")
    
    return x_continuous

In [17]:
def sample_molecule(checkpoint_path, target_u0=50.0, num_atoms=9):
    """
    Loads the model and generates a new molecule conditioned on a target U0 energy.
    """
    print(f"\n--- Starting Conditional Molecule Sampling ---")
    print(f"Target property U0: {target_u0} | Generating molecule with {num_atoms} atoms.")
    dit_model = GraphDiT(
        num_node_features=NUM_NODE_FEATURES,
        num_edge_features=NUM_EDGE_FEATURES,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        max_atoms=MAX_ATOMS,
        conditional_dim=1
    ).to(DEVICE)

    diffusion_process = GraphDiffusion(dit_model).to(DEVICE) 
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint file not found at {checkpoint_path}. Please train the model first.")
        return
        
    try:
        dit_model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        print(f"Successfully loaded model weights from {checkpoint_path}")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    dit_model.eval()
    MEAN, STD = 40.0, 10.0
    condition_unnorm = torch.tensor([[target_u0]], dtype=torch.float32, device=DEVICE)
    condition_norm = (condition_unnorm - MEAN) / STD
    final_continuous_features = diffusion_process.reverse_diffusion(
        condition=condition_norm, 
        num_nodes=num_atoms
    )
    interpret_continuous_features(final_continuous_features)
    
    print("\nSampling complete. The result is the continuous node feature matrix.")

In [18]:
CHECKPOINT_PATH = '/kaggle/working/checkpoint/best_graph_dit_qm9.pth'
sample_molecule(CHECKPOINT_PATH, target_u0=30, num_atoms=4)


--- Starting Conditional Molecule Sampling ---
Target property U0: 30 | Generating molecule with 4 atoms.
Successfully loaded model weights from /kaggle/working/checkpoint/best_graph_dit_qm9.pth


Sampling: 100%|██████████| 1000/1000 [00:04<00:00, 243.50it/s]


--- Interpreting Final Denoised Node Features ---
Number of generated atoms: 4
First Atomic Number features (continuous output): [ 0.9934025   0.99875605  0.99524206 -0.04846189]
First Suggested Atomic Numbers (Rounding the continuous output): [1 1 1 0]
These features would need a dedicated classifier/decoder to map to valid discrete molecular structures.

Sampling complete. The result is the continuous node feature matrix.



