In [1]:
!pip install torch torch_geometric tqdm rdkit transformers

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting rdkit
  Downloading rdkit-2025.9.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaPreTrainedModel, RobertaConfig, RobertaTokenizer
from transformers.models.roberta.modeling_roberta import RobertaSelfAttention, RobertaAttention, RobertaEncoder, RobertaLayer
import numpy as np
import os
import pickle
from typing import List, Tuple

2025-11-29 14:44:35.651245: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764427475.840414      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764427475.894419      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [19]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEXT_VOCAB_SIZE = 50265 
ATOM_TYPES = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'B', 'Si', 'Unknown']
BOND_TYPES = [0, 1, 2, 3] # 0: No Bond, 1: Single, 2: Double, 3: Triple
EDGE_VOCAB_SIZE = 4
NODE_VOCAB_SIZE = len(ATOM_TYPES) + 1 
FULL_VOCAB_SIZE = TEXT_VOCAB_SIZE + NODE_VOCAB_SIZE
NODE_VOCAB_START_ID = TEXT_VOCAB_SIZE
MASK_NODE_ID = NODE_VOCAB_START_ID + len(ATOM_TYPES)

In [20]:
MAX_SEQ_LEN_TEXT = 128
MAX_SEQ_LEN_GRAPH = 50 
DIFFUSION_STEPS = 1000

In [56]:
class PubChemDataset(InMemoryDataset):
    """
    Custom Dataset class to correctly load the PyG data tuple (data, slices)
    from the .pt file using torch.load.
    """
    def __init__(self, path):
        super(PubChemDataset, self).__init__()
        self.data, self.slices = torch.load(path, weights_only=False)

In [26]:
def graph_to_V_E(data_point: Data, max_nodes: int = MAX_SEQ_LEN_GRAPH):
    """ 
    Converts a PyG Data object into padded Node Tokens (V) and Edge Matrix (E). 
    """
    x = data_point.x 
    edge_index = data_point.edge_index 
    edge_attr = data_point.edge_attr 
    
    N = x.size(0)
    N_padded = min(N, max_nodes)
    
    # 1. Node Token IDs (V)
    atom_type_ids = torch.argmax(x[:N_padded].float(), dim=-1)
    V_tokens = atom_type_ids + NODE_VOCAB_START_ID 
    
    V_padded = torch.ones(max_nodes, dtype=torch.long) * MASK_NODE_ID
    V_padded[:N_padded] = V_tokens
    
    # 2. Edge Matrix (E)
    E_padded = torch.zeros(max_nodes, max_nodes, dtype=torch.long) # 0: No Bond
    
    for i in range(edge_index.size(1)):
        u, v = edge_index[:, i]
        if u < N_padded and v < N_padded:
            bond_type = torch.argmax(edge_attr[i].float()) + 1 
            E_padded[u, v] = bond_type
            E_padded[v, u] = bond_type

    return V_padded, E_padded, N_padded

In [57]:
def utgdiff_collate_fn(batch):
    """ Custom collate function for S+V sequence formation. """
    
    V_padded_list, E_padded_list = [], []
    text_input_ids, text_attention_masks = [], []
    num_nodes_list = []
    
    for data_point in batch:
        # Process Graph
        V_padded, E_padded, num_nodes = graph_to_V_E(data_point)
        V_padded_list.append(V_padded)
        E_padded_list.append(E_padded)
        num_nodes_list.append(num_nodes)
        
        # Process Text (extract attribute 'text' from Data object)
        text_content = data_point.text
        # Ensure text is a string; some datasets might have lists
        if isinstance(text_content, list): text_content = text_content[0]
            
        text_encoded = tokenizer.encode_plus(
            text_content,
            max_length=MAX_SEQ_LEN_TEXT,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        text_input_ids.append(text_encoded['input_ids'].squeeze(0))
        text_attention_masks.append(text_encoded['attention_mask'].squeeze(0))

    V_true_0 = torch.stack(V_padded_list) 
    E_true_0 = torch.stack(E_padded_list) 
    S_input_ids = torch.stack(text_input_ids) 
    S_attention_mask = torch.stack(text_attention_masks) 
    
    L_s_max = S_input_ids.size(1)
    N_max = V_true_0.size(1)
    B = len(batch)
    L_unified = L_s_max + N_max 
    
    input_ids_SV = torch.zeros(B, L_unified, dtype=torch.long)
    attention_mask_SV = torch.zeros(B, L_unified, dtype=torch.long)
    
    text_indices = []
    node_indices = []
    
    for i in range(B):
        S_i = S_input_ids[i]
        L_s_i = torch.sum(S_attention_mask[i]).item() 
        N_i = num_nodes_list[i]
        V_i = V_true_0[i, :N_i] 
        
        full_sequence = torch.cat([S_i[:L_s_i], V_i])
        L_total = full_sequence.size(0)
        
        input_ids_SV[i, :L_total] = full_sequence
        attention_mask_SV[i, :L_total] = 1
        
        text_indices.append(list(range(L_s_i)))
        node_indices.append(list(range(L_s_i, L_s_i + N_i)))
        
    return {
        'input_ids_SV': input_ids_SV.to(DEVICE),
        'attention_mask_SV': attention_mask_SV.to(DEVICE),
        'V_true_0': V_true_0.to(DEVICE),
        'E_true_0': E_true_0.to(DEVICE),
        'S_true_0': S_input_ids.to(DEVICE),
        'S_attention_mask': S_attention_mask.to(DEVICE),
        'text_indices': text_indices,
        'node_indices': node_indices,
    }

In [58]:
class UTGDiffSelfAttention(RobertaSelfAttention):
    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.FloatTensor = None, edge_bias: torch.Tensor = None, **kwargs) -> Tuple[torch.Tensor]:
        
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / np.sqrt(self.attention_head_size)

        # Apply Edge Bias B
        if edge_bias is not None:
            attention_scores = attention_scores + edge_bias

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        
        output_layer = self.output.dense(context_layer)
        output_layer = self.output.dropout(output_layer)
        output = self.output.LayerNorm(output_layer + mixed_query_layer)
        
        return (output, attention_probs)

In [90]:
class UTGDiffAttention(RobertaAttention):
    """ FIX: Custom Attention layer that replaces RobertaAttention to accept 'edge_bias'. """
    def __init__(self, config):
        super().__init__(config)
        # Ensure the internal self-attention uses our custom implementation
        self.self = UTGDiffSelfAttention(config)

    def forward(self, hidden_states, attention_mask=None, edge_bias=None, **kwargs):
        # The crucial step: pass the edge_bias into the custom self-attention module
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            edge_bias=edge_bias,
            **kwargs
        )
        # Pass through the standard output layer (which is part of RobertaAttention)
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output,) + self_outputs[1:]

In [91]:
class UTGDiffLayer(RobertaLayer):
    def __init__(self, config):
        super().__init__(config)
        # FIX: Use the custom attention wrapper here
        self.attention = UTGDiffAttention(config)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        edge_bias = kwargs.pop('edge_bias', None)
        
        # Now self.attention is UTGDiffAttention and correctly accepts edge_bias
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            edge_bias=edge_bias
        )
        attention_output = self_attention_outputs[0]
        layer_output = self.feed_forward_chunk(attention_output)
        return (layer_output,) + self_attention_outputs[1:]

In [92]:
class UTGDiffEncoder(RobertaEncoder):
    def __init__(self, config):
        super().__init__(config)
        # Use custom layers
        self.layer = nn.ModuleList([UTGDiffLayer(config) for _ in range(config.num_hidden_layers)])
        
    def forward(self, hidden_states, **kwargs):
        edge_bias = kwargs.pop('edge_bias', None)
        
        for layer_module in self.layer:
            layer_outputs = layer_module(
                hidden_states, 
                edge_bias=edge_bias,
                **kwargs
            )
            hidden_states = layer_outputs[0]
        
        return (hidden_states,)

In [93]:
class UTGDiffModel(RobertaPreTrainedModel):
    def __init__(self, config, roberta_base_model=None):
        super().__init__(config)
        
        if roberta_base_model is not None:
            self.roberta = roberta_base_model
            
            # Transfer weights and inject custom encoder
            old_encoder = self.roberta.encoder
            new_encoder = UTGDiffEncoder(config)
            
            # The layers are structurally the same, so state_dict loading should work
            new_encoder.load_state_dict(old_encoder.state_dict(), strict=False)
            self.roberta.encoder = new_encoder
        else:
            self.roberta = RobertaModel(config, add_pooling_layer=False)
            self.roberta.encoder = UTGDiffEncoder(config)
            self.roberta.resize_token_embeddings(FULL_VOCAB_SIZE) 

        self.node_prediction_head = nn.Linear(config.hidden_size, NODE_VOCAB_SIZE)
        self.edge_prediction_head = nn.Linear(config.hidden_size, EDGE_VOCAB_SIZE)
        self.text_prediction_head = nn.Linear(config.hidden_size, TEXT_VOCAB_SIZE)
        
        self.bond_embedding = nn.Embedding(EDGE_VOCAB_SIZE, config.hidden_size // config.num_attention_heads)
        self.bond_projection = nn.Linear(config.hidden_size // config.num_attention_heads, config.num_attention_heads)

    def generate_edge_bias(self, edge_tensor, seq_len):
        batch_size, N_v = edge_tensor.shape[:2]
        L_s = seq_len - N_v       
        
        edge_features = self.bond_embedding(edge_tensor)
        B_graph = self.bond_projection(edge_features).permute(0, 3, 1, 2)
        
        B_full = torch.zeros(batch_size, self.config.num_attention_heads, seq_len, seq_len, device=edge_tensor.device)
        B_full[:, :, L_s:, L_s:] = B_graph
        return B_full

    def forward(self, input_ids, attention_mask, edge_tensor, text_token_indices, node_token_indices):
        seq_len = input_ids.size(1)
        edge_bias = self.generate_edge_bias(edge_tensor, seq_len)
        
        # Manually call embeddings and encoder to pass edge_bias
        embedding_output = self.roberta.embeddings(input_ids=input_ids)
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
        
        encoder_outputs = self.roberta.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            edge_bias=edge_bias 
        )
        
        sequence_output = encoder_outputs[0]
        
        batch_size, N_max = edge_tensor.shape[:2]
        L_s_max = input_ids.size(1) - N_max
        hidden_size = sequence_output.size(-1)

        node_logits = torch.zeros(batch_size, N_max, NODE_VOCAB_SIZE, device=DEVICE)
        text_logits = torch.zeros(batch_size, L_s_max, TEXT_VOCAB_SIZE, device=DEVICE)
        edge_logits = torch.zeros(batch_size, N_max, N_max, EDGE_VOCAB_SIZE, device=DEVICE)
        
        for i in range(batch_size):
            v_idx = node_token_indices[i]
            N_i = len(v_idx)
            node_h = sequence_output[i, v_idx, :]
            node_logits[i, :N_i] = self.node_prediction_head(node_h) 

            node_A = node_h.unsqueeze(1).expand(-1, N_i, hidden_size) 
            node_B = node_h.unsqueeze(0).expand(N_i, -1, hidden_size) 
            edge_logits[i, :N_i, :N_i] = self.edge_prediction_head(node_A * node_B)
            
            s_idx = text_token_indices[i]
            text_logits[i, :len(s_idx)] = self.text_prediction_head(sequence_output[i, s_idx, :])
        
        return node_logits, edge_logits, text_logits

In [94]:
def compute_utgdiff_loss(node_logits, edge_logits, text_logits, V_true, E_true, S_true, text_mask):
    L_nodes = F.cross_entropy(node_logits.permute(0, 2, 1), V_true, ignore_index=MASK_NODE_ID)
    L_edges = F.cross_entropy(edge_logits.permute(0, 3, 1, 2), E_true) 
    
    active_loss = text_mask.view(-1) == 1
    L_text = F.cross_entropy(
        text_logits.transpose(1, 2).reshape(-1, TEXT_VOCAB_SIZE)[active_loss], 
        S_true.view(-1)[active_loss]
    )
    return L_nodes + L_edges + L_text, L_nodes, L_edges, L_text

In [95]:
class DiffusionForwardModel(nn.Module):
    def __init__(self, mask_node_id, mask_edge_id, steps):
        super().__init__()
        self.mask_node_id = mask_node_id
        
    def forward(self, V_true, E_true, t):
        mask_V = torch.rand_like(V_true.float()) < 0.3
        V_t = torch.where(mask_V, torch.full_like(V_true, self.mask_node_id), V_true)
        
        mask_E = torch.rand_like(E_true.float()) < 0.3
        E_t = torch.where(mask_E, torch.zeros_like(E_true), E_true)
        return V_t, E_t

In [96]:
LEARNING_RATE = 1e-5

def train_model(data_path, batch_size, num_epochs):
    print(f"Initializing Model on {DEVICE}...")
    
    # 1. Load RoBERTa
    base_roberta = RobertaModel.from_pretrained('roberta-base', add_pooling_layer=False)
    if base_roberta.config.vocab_size != FULL_VOCAB_SIZE:
        print(f"Resizing embeddings to {FULL_VOCAB_SIZE}...")
        base_roberta.resize_token_embeddings(FULL_VOCAB_SIZE)
        
    # 2. Instantiate UTGDiffModel
    config = base_roberta.config
    config.vocab_size = FULL_VOCAB_SIZE
    model = UTGDiffModel(config, roberta_base_model=base_roberta).to(DEVICE)
    
    # 3. Setup
    diffusion_forward = DiffusionForwardModel(MASK_NODE_ID, 0, DIFFUSION_STEPS)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    
    # 4. Load Data
    train_dataset = PubChemDataset(data_path)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=utgdiff_collate_fn)

    print(f"Starting training: {num_epochs} epochs, {len(train_dataset)} samples.")
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            
            # Forward Diffusion
            t = torch.randint(1, DIFFUSION_STEPS + 1, (batch['input_ids_SV'].size(0),), device=DEVICE)
            V_t, E_t = diffusion_forward(batch['V_true_0'], batch['E_true_0'], t)

            # Reconstruct Unified Input
            input_ids_SV_t = batch['input_ids_SV'].clone() 
            for i in range(input_ids_SV_t.size(0)):
                v_indices = batch['node_indices'][i]
                input_ids_SV_t[i, v_indices] = V_t[i, :len(v_indices)]
                
            # Denoising Network Forward Pass
            node_logits, edge_logits, text_logits = model(
                input_ids_SV_t, batch['attention_mask_SV'], E_t, 
                batch['text_indices'], batch['node_indices']
            )
            
            # Compute Loss
            loss_total, L_v, L_e, L_s = compute_utgdiff_loss(
                node_logits, edge_logits, text_logits, 
                batch['V_true_0'], batch['E_true_0'], batch['S_true_0'], batch['S_attention_mask'] 
            )
            
            loss_total.backward()
            optimizer.step()
            total_loss += loss_total.item()

            if (step + 1) % 10 == 0:
                print(f"  Step {step+1}: Loss: {loss_total.item():.4f} (V: {L_v:.3f}, E: {L_e:.3f}, S: {L_s:.3f})")
        
        print(f"Epoch {epoch+1} finished. Avg Loss: {total_loss / len(train_dataloader):.4f}")

In [97]:
KAGGLE_DATA_PATH = "/kaggle/input/pubchem324k-dataset/train.pt"

In [98]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

In [99]:
train_model(KAGGLE_DATA_PATH, batch_size=64, num_epochs=20)

Initializing Model on cuda...
Resizing embeddings to 50278...
Starting training: 20 epochs, 12000 samples.


AttributeError: 'UTGDiffSelfAttention' object has no attribute 'output'