***approach 1 2d graph , descriptors***

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score



In [17]:
# ---- Load graphs ----

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# ============================
# Text-Enhanced Model WITHOUT ECFP
# ============================

class TextEnhancedNoECFP(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks):
        super().__init__()
        
        self.n_tasks = n_tasks
        self.desc_feature_dim = desc_feature_dim

        # --- GNN Backbone ---
        nn1 = nn.Sequential(
            nn.Linear(node_feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128)
        )
        self.gnn_conv1 = GINEConv(nn1, edge_dim=edge_feature_dim)

        nn2 = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128)
        )
        self.gnn_conv2 = GINEConv(nn2, edge_dim=edge_feature_dim)

        self.gnn_batch_norm1 = nn.BatchNorm1d(128)
        self.gnn_batch_norm2 = nn.BatchNorm1d(128)

        # Graph output dimension
        gnn_out_dim = 256  # mean + max pool

        # --- Learnable Text Prompts for Each Assay ---
        self.assay_prompts = nn.Parameter(torch.randn(n_tasks, 128))
        
        # Text projection
        self.text_proj = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # --- Assay-Conditioned Fusion ---
        self.assay_weights = nn.Parameter(torch.ones(n_tasks, 2))  # [12, 2] for gnn, desc
        
        # Final classifier (NO ECFP dimension)
        classifier_input_dim = 256 + desc_feature_dim + 128  # gnn + desc + text ONLY
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_tasks)
        )

    def forward_gnn(self, x, edge_index, edge_attr, batch):
        # GNN processing
        x = self.gnn_conv1(x, edge_index, edge_attr)
        x = self.gnn_batch_norm1(x)
        x = F.elu(x)
        x = F.dropout(x, p=0.2, training=self.training)

        x = self.gnn_conv2(x, edge_index, edge_attr)
        x = self.gnn_batch_norm2(x)
        x = F.elu(x)

        # Readout
        mean_pool = global_mean_pool(x, batch)
        max_pool = global_max_pool(x, batch)
        return torch.cat([mean_pool, max_pool], dim=1)

    def forward(self, data, assay_attention=None):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        B = data.num_graphs

        # 1. Process GNN features
        graph_out = self.forward_gnn(x, edge_index, edge_attr, batch)  # [B, 256]
        desc_out = data.desc_features.view(B, self.desc_feature_dim)   # [B, desc_dim]

        # 2. Text prompts for assay conditioning
        if assay_attention is None:
            # Default: equal attention to all assays
            text_weights = torch.ones(B, self.n_tasks, device=graph_out.device) / self.n_tasks
        else:
            text_weights = assay_attention
        
        # Weighted average of text prompts
        text_feat = torch.einsum('bi,ij->bj', text_weights, self.assay_prompts)  # [B, 128]
        text_feat = self.text_proj(text_feat)  # [B, 128]

        # 3. Assay-weighted fusion of modalities
        modality_weights = F.softmax(self.assay_weights, dim=1)  # [12, 2]
        
        # Use average weights across assays
        avg_weights = modality_weights.mean(dim=0)  # [2]
        
        # Apply weights to modalities (NO ECFP)
        weighted_graph = graph_out * avg_weights[0]
        weighted_desc = desc_out * avg_weights[1]

        # 4. Concatenate all features (NO ECFP)
        combined = torch.cat([weighted_graph, weighted_desc, text_feat], dim=1)
        
        # 5. Final prediction
        return self.classifier(combined)

# ============================
# Text Prompts Definition
# ============================

# Define meaningful text prompts for each assay
ASSAY_DESCRIPTIONS = {
    "NR-AR": "androgen receptor binding and endocrine disruption potential",
    "NR-AR-LBD": "androgen receptor ligand binding domain interaction", 
    "NR-AhR": "aryl hydrocarbon receptor activation and xenobiotic metabolism",
    "NR-Aromatase": "aromatase enzyme inhibition and steroid metabolism",
    "NR-ER": "estrogen receptor binding and hormonal activity",
    "NR-ER-LBD": "estrogen receptor ligand binding domain interaction",
    "NR-PPAR-gamma": "peroxisome proliferator-activated receptor gamma activation",
    "SR-ARE": "antioxidant response element activation and oxidative stress",
    "SR-ATAD5": "ATAD5 biomarker response and genotoxicity",
    "SR-HSE": "heat shock response element activation and protein stress",
    "SR-MMP": "mitochondrial membrane potential disruption",
    "SR-p53": "p53 tumor suppressor pathway activation and DNA damage"
}

# Convert to list in correct order
ASSAY_TEXTS = [ASSAY_DESCRIPTIONS[assay] for assay in ASSAYS]

# ============================
# Training Components
# ============================

def train_text_enhanced_epoch(loader, model, optimizer, criterion, device, n_tasks):
    model.train()
    total_loss = 0.0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        B = batch.num_graphs
        
        # Strategy 1: Equal attention to all assays
        assay_attention = torch.ones(B, n_tasks, device=device) / n_tasks
        
        # Strategy 2: Focus on assays with positive labels in this batch
        y_batch = batch.y.float().view(B, n_tasks)
        w_batch = batch.weight.float().view(B, n_tasks)
        labeled_mask = (w_batch > 0).float()
        
        # If sample has specific assay labels, focus on those
        if labeled_mask.sum() > 0:
            assay_attention = labeled_mask / labeled_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
        
        logits = model(batch, assay_attention)

        # Targets and weights
        y = batch.y.float().view(-1, n_tasks)
        w = batch.weight.float().view(-1, n_tasks)

        # Compute loss (only on labeled positions)
        loss_unreduced = criterion(logits, y)
        mask = (w > 0).float()
        loss = (loss_unreduced * mask).sum() / mask.sum()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * B
        total_graphs += B

    return total_loss / total_graphs

def evaluate(loader, model, device, assays):
    model.eval()
    all_probs = []
    all_labels = []
    all_weights = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            B = batch.num_graphs
            
            # For evaluation, use equal attention to all assays
            assay_attention = torch.ones(B, len(assays), device=device) / len(assays)
            
            logits = model(batch, assay_attention)
            probs = torch.sigmoid(logits)

            y = batch.y.float().view(-1, len(assays))
            w = batch.weight.float().view(-1, len(assays))

            all_probs.append(probs.cpu())
            all_labels.append(y.cpu())
            all_weights.append(w.cpu())

    probs = torch.cat(all_probs, dim=0).numpy()
    labels = torch.cat(all_labels, dim=0).numpy()
    weights = torch.cat(all_weights, dim=0).numpy()

    roc_scores = {}
    pr_scores = {}

    for j, assay in enumerate(assays):
        mask = weights[:, j] > 0
        if mask.sum() < 5:
            roc_scores[assay] = np.nan
            pr_scores[assay] = np.nan
            continue

        y_true = labels[mask, j]
        y_pred = probs[mask, j]

        try:
            roc_scores[assay] = roc_auc_score(y_true, y_pred)
            pr_scores[assay] = average_precision_score(y_true, y_pred)
        except ValueError:
            roc_scores[assay] = np.nan
            pr_scores[assay] = np.nan

    mean_roc = np.nanmean(list(roc_scores.values()))
    mean_pr = np.nanmean(list(pr_scores.values()))
    return roc_scores, pr_scores, mean_roc, mean_pr

# ============================
# Data Preparation (NO ECFP)
# ============================

# Remove ECFP from your data loading
print("Preparing data WITHOUT ECFP...")
train_graphs = torch.load("graphs/train_2d.pt")
val_graphs   = torch.load("graphs/val_2d.pt")
test_graphs  = torch.load("graphs/test_2d.pt")
# Create zero ECFP features (minimal dimension to avoid errors)
train_fp = np.zeros((len(train_graphs), 1), dtype=np.float32)
val_fp = np.zeros((len(val_graphs), 1), dtype=np.float32)  
test_fp = np.zeros((len(test_graphs), 1), dtype=np.float32)
fp_dim = 1

# Keep descriptors
if use_desc:
    train_desc = np.load(r"E:\graphml project\novel\processed\train_rdkit_desc.npz")["X"]
    val_desc = np.load(r"E:\graphml project\novel\processed\val_rdkit_desc.npz")["X"]
    test_desc = np.load(r"E:\graphml project\novel\processed\test_rdkit_desc.npz")["X"]
    desc_dim = train_desc.shape[1]
else:
    desc_dim = 32
    train_desc = np.zeros((len(train_graphs), desc_dim), dtype=np.float32)
    val_desc = np.zeros((len(val_graphs), desc_dim), dtype=np.float32)
    test_desc = np.zeros((len(test_graphs), desc_dim), dtype=np.float32)

# Attach features (ECFP will be zeros)
def attach_features_no_ecfp(graph_list, desc_array):
    for i, g in enumerate(graph_list):
        g.fp_features = torch.zeros(1).float()  # Minimal ECFP
        g.desc_features = torch.from_numpy(desc_array[i]).float()
    return graph_list

train_graphs = attach_features_no_ecfp(train_graphs, train_desc)
val_graphs = attach_features_no_ecfp(val_graphs, val_desc)
test_graphs = attach_features_no_ecfp(test_graphs, test_desc)

# DataLoaders (same as before)
BATCH_SIZE = 64
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

# ============================
# Main Training Pipeline
# ============================

def main():
    EPOCHS = 100
    LR = 2e-4
    WEIGHT_DECAY = 1e-5

    print("Initializing Text-Enhanced GNN WITHOUT ECFP...")
    
    # Get dimensions
    sample = train_graphs[0]
    node_dim = sample.x.size(1)
    edge_dim = sample.edge_attr.size(1)
    
    print(f"Node features: {node_dim}")
    print(f"Edge features: {edge_dim}")
    print(f"Descriptor features: {desc_dim}")
    print(f"Number of tasks: {len(ASSAYS)}")
    print("\nUsing Assay Prompts:")
    for assay, desc in ASSAY_DESCRIPTIONS.items():
        print(f"  {assay}: {desc}")
    
    # Initialize model
    model = TextEnhancedNoECFP(
        node_feature_dim=node_dim,
        edge_feature_dim=edge_dim,
        desc_feature_dim=desc_dim,
        n_tasks=len(ASSAYS)
    ).to(device)
    
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Optimizer and loss
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=LR, 
        weight_decay=WEIGHT_DECAY
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=10, verbose=True
    )
    
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    
    # Training loop
    best_val_roc = -1.0
    best_state = None
    patience = 15
    patience_counter = 0
    
    print("\nStarting Training...")
    print("Epoch | Train Loss | Val ROC-AUC | Val PR-AUC | LR")
    print("-" * 55)
    
    for epoch in range(1, EPOCHS + 1):
        # Training
        train_loss = train_text_enhanced_epoch(
            train_loader, model, optimizer, criterion, device, len(ASSAYS)
        )
        
        # Validation
        roc_val, pr_val, mean_roc_val, mean_pr_val = evaluate(val_loader, model, device, ASSAYS)
        
        # Update learning rate
        scheduler.step(mean_roc_val)
        
        print(f"{epoch:5d} | {train_loss:.4f}      | {mean_roc_val:.4f}      | {mean_pr_val:.4f}    | {optimizer.param_groups[0]['lr']:.2e}")
        
        # Save best model
        if mean_roc_val > best_val_roc:
            best_val_roc = mean_roc_val
            best_state = model.state_dict().copy()
            patience_counter = 0
            torch.save(model.state_dict(), "text_enhanced_no_ecfp_best.pt")
            print(f"  → New best! (ROC-AUC: {best_val_roc:.4f})")
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
    
    print(f"\nTraining completed. Best validation ROC-AUC: {best_val_roc:.4f}")
    
    # Load best model for testing
    if best_state is not None:
        model.load_state_dict(best_state)
        print("Loaded best model for testing")
    
    # Final evaluation
    roc_test, pr_test, mean_roc_test, mean_pr_test = evaluate(test_loader, model, device, ASSAYS)
    
    print("\n" + "=" * 65)
    print("FINAL TEST METRICS (Text-Enhanced GNN WITHOUT ECFP)")
    print("=" * 65)
    for assay in ASSAYS:
        print(f"{assay:15s} | ROC-AUC: {roc_test[assay]:.4f} | PR-AUC: {pr_test[assay]:.4f}")
    print("-" * 65)
    print(f"{'Mean':15s} | ROC-AUC: {mean_roc_test:.4f} | PR-AUC: {mean_pr_test:.4f}")
    
    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'assay_prompts': model.assay_prompts.detach().cpu(),
        'assay_descriptions': ASSAY_DESCRIPTIONS,
        'test_metrics': {
            'roc_auc': roc_test,
            'pr_auc': pr_test,
            'mean_roc': mean_roc_test,
            'mean_pr': mean_pr_test
        },
        'config': {
            'use_ecfp': False,
            'use_text_prompts': True,
            'use_descriptors': True
        }
    }, "text_enhanced_no_ecfp_final.pt")
    
    print("\nModel saved as 'text_enhanced_no_ecfp_final.pt'")
    
    # Show learned prompt similarities
    print("\nLearned assay prompt similarities:")
    prompts = model.assay_prompts.detach().cpu()
    similarities = F.cosine_similarity(prompts.unsqueeze(1), prompts.unsqueeze(0), dim=2)
    
    # Show top similar assay pairs
    similar_pairs = []
    for i in range(len(ASSAYS)):
        for j in range(i + 1, len(ASSAYS)):
            similar_pairs.append((i, j, similarities[i, j].item()))
    
    similar_pairs.sort(key=lambda x: x[2], reverse=True)
    for i, j, sim in similar_pairs[:5]:  # Top 5 most similar
        print(f"  {ASSAYS[i]:15s} ↔ {ASSAYS[j]:15s}: {sim:.3f}")

# Run the training
if __name__ == "__main__":
    main()

Preparing data WITHOUT ECFP...
Initializing Text-Enhanced GNN WITHOUT ECFP...
Node features: 5
Edge features: 6
Descriptor features: 10
Number of tasks: 12

Using Assay Prompts:
  NR-AR: androgen receptor binding and endocrine disruption potential
  NR-AR-LBD: androgen receptor ligand binding domain interaction
  NR-AhR: aryl hydrocarbon receptor activation and xenobiotic metabolism
  NR-Aromatase: aromatase enzyme inhibition and steroid metabolism
  NR-ER: estrogen receptor binding and hormonal activity
  NR-ER-LBD: estrogen receptor ligand binding domain interaction
  NR-PPAR-gamma: peroxisome proliferator-activated receptor gamma activation
  SR-ARE: antioxidant response element activation and oxidative stress
  SR-ATAD5: ATAD5 biomarker response and genotoxicity
  SR-HSE: heat shock response element activation and protein stress
  SR-MMP: mitochondrial membrane potential disruption
  SR-p53: p53 tumor suppressor pathway activation and DNA damage

Model parameters: 408,007

Starting

***PNA TEXT***


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import PNAConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# ============================
# Data Loading (ADD THIS)
# ============================

ASSAYS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
    "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", 
    "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Load graphs ----
train_graphs = torch.load("graphs/train_2d.pt")
val_graphs   = torch.load("graphs/val_2d.pt")
test_graphs  = torch.load("graphs/test_2d.pt")

# ---- Load descriptors ----
use_desc = True
if use_desc:
    train_desc = np.load(r"E:\graphml project\novel\processed\train_rdkit_desc.npz")["X"]
    val_desc   = np.load(r"E:\graphml project\novel\processed\val_rdkit_desc.npz")["X"]
    test_desc  = np.load(r"E:\graphml project\novel\processed\test_rdkit_desc.npz")["X"]
    desc_dim = train_desc.shape[1]
else:
    desc_dim = 32
    train_desc = np.zeros((len(train_graphs), desc_dim), dtype=np.float32)
    val_desc   = np.zeros((len(val_graphs), desc_dim), dtype=np.float32)
    test_desc  = np.zeros((len(test_graphs), desc_dim), dtype=np.float32)

# ---- Attach features ----
def attach_features(graph_list, desc_array):
    for i, g in enumerate(graph_list):
        g.desc_features = torch.from_numpy(desc_array[i]).float()
    return graph_list

train_graphs = attach_features(train_graphs, train_desc)
val_graphs   = attach_features(val_graphs, val_desc)
test_graphs  = attach_features(test_graphs, test_desc)

BATCH_SIZE = 64
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

# ============================
# Text Prompts Definition (ADD THIS)
# ============================

ASSAY_DESCRIPTIONS = {
    "NR-AR": "androgen receptor binding and endocrine disruption potential",
    "NR-AR-LBD": "androgen receptor ligand binding domain interaction", 
    "NR-AhR": "aryl hydrocarbon receptor activation and xenobiotic metabolism",
    "NR-Aromatase": "aromatase enzyme inhibition and steroid metabolism",
    "NR-ER": "estrogen receptor binding and hormonal activity",
    "NR-ER-LBD": "estrogen receptor ligand binding domain interaction",
    "NR-PPAR-gamma": "peroxisome proliferator-activated receptor gamma activation",
    "SR-ARE": "antioxidant response element activation and oxidative stress",
    "SR-ATAD5": "ATAD5 biomarker response and genotoxicity",
    "SR-HSE": "heat shock response element activation and protein stress",
    "SR-MMP": "mitochondrial membrane potential disruption and cytotoxicity",
    "SR-p53": "p53 tumor suppressor pathway activation and DNA damage response"
}

print("Assay Prompts:")
for assay, desc in ASSAY_DESCRIPTIONS.items():
    print(f"  {assay}: {desc}")

# ============================
# PNA + Text Enhanced Model
# ============================

class PNAWithText(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
                 hidden_dim=128, num_layers=4, dropout=0.2):
        super().__init__()
        
        self.n_tasks = n_tasks
        self.desc_feature_dim = desc_feature_dim
        self.hidden_dim = hidden_dim
        
        # --- PNA Configuration ---
        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']
        
        # Degree distribution for molecular graphs
        self.deg = torch.tensor([0, 1, 2, 3, 4])
        
        # --- Feature Projections ---
        self.node_proj = nn.Sequential(
            nn.Linear(node_feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # --- PNA Layers ---
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        for i in range(num_layers):
            conv = PNAConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                aggregators=aggregators,
                scalers=scalers,
                deg=self.deg,
                edge_dim=edge_feature_dim,
                towers=1,
                pre_layers=1,
                post_layers=1
            )
            self.convs.append(conv)
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        
        # --- Text Prompts ---
        self.assay_prompts = nn.Parameter(torch.randn(n_tasks, 128))
        self.text_proj = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # --- Cross-Attention Fusion ---
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        # --- Enhanced Classifier ---
        classifier_input_dim = hidden_dim * 2 + hidden_dim + desc_feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_tasks)
        )
    
    def forward(self, data, assay_attention=None):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        # 1. Project node features
        x = self.node_proj(x)
        
        # 2. PNA Message Passing
        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            x_residual = x
            x = conv(x, edge_index, edge_attr)
            x = bn(x)
            x = F.elu(x)
            
            # Residual connection
            if i % 2 == 1:
                x = x + x_residual
            
            x = F.dropout(x, p=0.2, training=self.training)
        
        # 3. Graph Readout
        mean_pool = global_mean_pool(x, batch)
        max_pool = global_max_pool(x, batch)
        graph_features = torch.cat([mean_pool, max_pool], dim=1)
        
        # 4. Text Conditioning
        if assay_attention is None:
            assay_attention = torch.ones(len(graph_features), self.n_tasks, 
                                       device=graph_features.device) / self.n_tasks
        
        text_embeddings = torch.einsum('bi,ij->bj', assay_attention, self.assay_prompts)
        text_features = self.text_proj(text_embeddings)
        
        # 5. Cross-Attention
        text_as_query = text_features.unsqueeze(1)
        graph_as_kv = graph_features[:, :self.hidden_dim].unsqueeze(1)
        
        attended_features, _ = self.cross_attention(
            query=text_as_query,
            key=graph_as_kv,
            value=graph_as_kv
        )
        attended_features = attended_features.squeeze(1)
        
        # 6. Descriptor Features
        desc_features = data.desc_features.view(len(graph_features), -1)
        
        # 7. Final Fusion
        combined_features = torch.cat([graph_features, attended_features, desc_features], dim=1)
        logits = self.classifier(combined_features)
        
        return logits

# ============================
# Training Functions
# ============================

def train_pna_epoch(loader, model, optimizer, criterion, device, n_tasks):
    model.train()
    total_loss = 0.0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        B = batch.num_graphs
        
        # Smart assay attention
        y_batch = batch.y.float().view(B, n_tasks)
        w_batch = batch.weight.float().view(B, n_tasks)
        labeled_mask = (w_batch > 0).float()
        
        if labeled_mask.sum() > 0:
            assay_attention = labeled_mask / labeled_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
        else:
            assay_attention = torch.ones(B, n_tasks, device=device) / n_tasks
        
        logits = model(batch, assay_attention)

        # Loss computation
        y = batch.y.float().view(-1, n_tasks)
        w = batch.weight.float().view(-1, n_tasks)
        
        loss_unreduced = criterion(logits, y)
        mask = (w > 0).float()
        loss = (loss_unreduced * mask).sum() / mask.sum()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * B
        total_graphs += B

    return total_loss / total_graphs

def evaluate(loader, model, device, assays):
    model.eval()
    all_probs = []
    all_labels = []
    all_weights = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            B = batch.num_graphs
            
            assay_attention = torch.ones(B, len(assays), device=device) / len(assays)
            
            logits = model(batch, assay_attention)
            probs = torch.sigmoid(logits)

            y = batch.y.float().view(-1, len(assays))
            w = batch.weight.float().view(-1, len(assays))

            all_probs.append(probs.cpu())
            all_labels.append(y.cpu())
            all_weights.append(w.cpu())

    probs = torch.cat(all_probs, dim=0).numpy()
    labels = torch.cat(all_labels, dim=0).numpy()
    weights = torch.cat(all_weights, dim=0).numpy()

    roc_scores = {}
    pr_scores = {}

    for j, assay in enumerate(assays):
        mask = weights[:, j] > 0
        if mask.sum() < 5:
            roc_scores[assay] = np.nan
            pr_scores[assay] = np.nan
            continue

        y_true = labels[mask, j]
        y_pred = probs[mask, j]

        try:
            roc_scores[assay] = roc_auc_score(y_true, y_pred)
            pr_scores[assay] = average_precision_score(y_true, y_pred)
        except ValueError:
            roc_scores[assay] = np.nan
            pr_scores[assay] = np.nan

    mean_roc = np.nanmean(list(roc_scores.values()))
    mean_pr = np.nanmean(list(pr_scores.values()))
    return roc_scores, pr_scores, mean_roc, mean_pr

# ============================
# Main Execution
# ============================

def main():
    EPOCHS = 100
    LR = 2e-4
    WEIGHT_DECAY = 1e-5
    
    print(f"Using device: {device}")
    
    # Initialize model
    sample = train_graphs[0]
    node_dim = sample.x.size(1)
    edge_dim = sample.edge_attr.size(1)
    
    print(f"Node features: {node_dim}")
    print(f"Edge features: {edge_dim}")
    print(f"Descriptor features: {desc_dim}")
    print(f"Number of tasks: {len(ASSAYS)}")
    
    model = PNAWithText(
        node_feature_dim=node_dim,
        edge_feature_dim=edge_dim,
        desc_feature_dim=desc_dim,
        n_tasks=len(ASSAYS),
        hidden_dim=128,
        num_layers=4
    ).to(device)
    
    print(f"PNA Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Optimizer and loss
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=LR, 
        weight_decay=WEIGHT_DECAY
    )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    
    # Training loop
    best_val_roc = -1.0
    print("\nStarting PNA Training...")
    
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_pna_epoch(train_loader, model, optimizer, criterion, device, len(ASSAYS))
        roc_val, pr_val, mean_roc_val, mean_pr_val = evaluate(val_loader, model, device, ASSAYS)
        scheduler.step()
        
        print(f"Epoch {epoch:03d} | Loss: {train_loss:.4f} | Val ROC: {mean_roc_val:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")
        
        if mean_roc_val > best_val_roc:
            best_val_roc = mean_roc_val
            torch.save(model.state_dict(), "pna_text_best.pt")
            print(f"  → New best! (ROC: {best_val_roc:.4f})")
    
    # Final test
    model.load_state_dict(torch.load("pna_text_best.pt"))
    roc_test, pr_test, mean_roc_test, mean_pr_test = evaluate(test_loader, model, device, ASSAYS)
    
    print(f"\n🎯 FINAL PNA TEST RESULTS:")
    print(f"ROC-AUC: {mean_roc_test:.4f}")
    print(f"PR-AUC: {mean_pr_test:.4f}")
    
    # Show per-assay results
    print("\nPer-assay ROC-AUC:")
    for assay in ASSAYS:
        print(f"  {assay}: {roc_test[assay]:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda


  train_graphs = torch.load("graphs/train_2d.pt")
  val_graphs   = torch.load("graphs/val_2d.pt")
  test_graphs  = torch.load("graphs/test_2d.pt")


Assay Prompts:
  NR-AR: androgen receptor binding and endocrine disruption potential
  NR-AR-LBD: androgen receptor ligand binding domain interaction
  NR-AhR: aryl hydrocarbon receptor activation and xenobiotic metabolism
  NR-Aromatase: aromatase enzyme inhibition and steroid metabolism
  NR-ER: estrogen receptor binding and hormonal activity
  NR-ER-LBD: estrogen receptor ligand binding domain interaction
  NR-PPAR-gamma: peroxisome proliferator-activated receptor gamma activation
  SR-ARE: antioxidant response element activation and oxidative stress
  SR-ATAD5: ATAD5 biomarker response and genotoxicity
  SR-HSE: heat shock response element activation and protein stress
  SR-MMP: mitochondrial membrane potential disruption and cytotoxicity
  SR-p53: p53 tumor suppressor pathway activation and DNA damage response
Using device: cuda
Node features: 5
Edge features: 6
Descriptor features: 10
Number of tasks: 12
PNA Model parameters: 1,543,564

Starting PNA Training...




Epoch 001 | Loss: 0.4582 | Val ROC: 0.5769 | LR: 2.00e-04
  → New best! (ROC: 0.5769)




Epoch 002 | Loss: 0.2485 | Val ROC: 0.6654 | LR: 2.00e-04
  → New best! (ROC: 0.6654)




Epoch 003 | Loss: 0.2050 | Val ROC: 0.6935 | LR: 2.00e-04
  → New best! (ROC: 0.6935)




Epoch 004 | Loss: 0.1913 | Val ROC: 0.6914 | LR: 1.99e-04




Epoch 005 | Loss: 0.1860 | Val ROC: 0.7147 | LR: 1.99e-04
  → New best! (ROC: 0.7147)




Epoch 006 | Loss: 0.1829 | Val ROC: 0.7163 | LR: 1.98e-04
  → New best! (ROC: 0.7163)




Epoch 007 | Loss: 0.1791 | Val ROC: 0.7220 | LR: 1.98e-04
  → New best! (ROC: 0.7220)




Epoch 008 | Loss: 0.1780 | Val ROC: 0.7306 | LR: 1.97e-04
  → New best! (ROC: 0.7306)




Epoch 009 | Loss: 0.1761 | Val ROC: 0.7320 | LR: 1.96e-04
  → New best! (ROC: 0.7320)




Epoch 010 | Loss: 0.1738 | Val ROC: 0.7372 | LR: 1.95e-04
  → New best! (ROC: 0.7372)




Epoch 011 | Loss: 0.1726 | Val ROC: 0.7397 | LR: 1.94e-04
  → New best! (ROC: 0.7397)




Epoch 012 | Loss: 0.1716 | Val ROC: 0.7419 | LR: 1.93e-04
  → New best! (ROC: 0.7419)




Epoch 013 | Loss: 0.1689 | Val ROC: 0.7474 | LR: 1.92e-04
  → New best! (ROC: 0.7474)




Epoch 014 | Loss: 0.1682 | Val ROC: 0.7447 | LR: 1.90e-04




Epoch 015 | Loss: 0.1669 | Val ROC: 0.7503 | LR: 1.89e-04
  → New best! (ROC: 0.7503)




Epoch 016 | Loss: 0.1656 | Val ROC: 0.7528 | LR: 1.88e-04
  → New best! (ROC: 0.7528)




Epoch 017 | Loss: 0.1642 | Val ROC: 0.7506 | LR: 1.86e-04




Epoch 018 | Loss: 0.1642 | Val ROC: 0.7539 | LR: 1.84e-04
  → New best! (ROC: 0.7539)




Epoch 019 | Loss: 0.1632 | Val ROC: 0.7416 | LR: 1.83e-04




Epoch 020 | Loss: 0.1630 | Val ROC: 0.7557 | LR: 1.81e-04
  → New best! (ROC: 0.7557)




Epoch 021 | Loss: 0.1615 | Val ROC: 0.7527 | LR: 1.79e-04




Epoch 022 | Loss: 0.1613 | Val ROC: 0.7514 | LR: 1.77e-04




Epoch 023 | Loss: 0.1600 | Val ROC: 0.7586 | LR: 1.75e-04
  → New best! (ROC: 0.7586)




Epoch 024 | Loss: 0.1596 | Val ROC: 0.7546 | LR: 1.73e-04




Epoch 025 | Loss: 0.1583 | Val ROC: 0.7574 | LR: 1.71e-04




Epoch 026 | Loss: 0.1582 | Val ROC: 0.7535 | LR: 1.68e-04




Epoch 027 | Loss: 0.1559 | Val ROC: 0.7562 | LR: 1.66e-04




Epoch 028 | Loss: 0.1565 | Val ROC: 0.7653 | LR: 1.64e-04
  → New best! (ROC: 0.7653)




Epoch 029 | Loss: 0.1557 | Val ROC: 0.7541 | LR: 1.61e-04




Epoch 030 | Loss: 0.1548 | Val ROC: 0.7641 | LR: 1.59e-04




Epoch 031 | Loss: 0.1541 | Val ROC: 0.7603 | LR: 1.56e-04




Epoch 032 | Loss: 0.1538 | Val ROC: 0.7661 | LR: 1.54e-04
  → New best! (ROC: 0.7661)




Epoch 033 | Loss: 0.1523 | Val ROC: 0.7623 | LR: 1.51e-04




Epoch 034 | Loss: 0.1529 | Val ROC: 0.7608 | LR: 1.48e-04




Epoch 035 | Loss: 0.1511 | Val ROC: 0.7599 | LR: 1.45e-04




Epoch 036 | Loss: 0.1511 | Val ROC: 0.7663 | LR: 1.43e-04
  → New best! (ROC: 0.7663)




Epoch 037 | Loss: 0.1488 | Val ROC: 0.7614 | LR: 1.40e-04




Epoch 038 | Loss: 0.1498 | Val ROC: 0.7633 | LR: 1.37e-04




Epoch 039 | Loss: 0.1487 | Val ROC: 0.7635 | LR: 1.34e-04




Epoch 040 | Loss: 0.1480 | Val ROC: 0.7657 | LR: 1.31e-04




Epoch 041 | Loss: 0.1471 | Val ROC: 0.7637 | LR: 1.28e-04




Epoch 042 | Loss: 0.1467 | Val ROC: 0.7681 | LR: 1.25e-04
  → New best! (ROC: 0.7681)




Epoch 043 | Loss: 0.1474 | Val ROC: 0.7638 | LR: 1.22e-04




Epoch 044 | Loss: 0.1466 | Val ROC: 0.7634 | LR: 1.19e-04




Epoch 045 | Loss: 0.1452 | Val ROC: 0.7652 | LR: 1.16e-04




Epoch 046 | Loss: 0.1453 | Val ROC: 0.7690 | LR: 1.13e-04
  → New best! (ROC: 0.7690)




Epoch 047 | Loss: 0.1441 | Val ROC: 0.7631 | LR: 1.09e-04




Epoch 048 | Loss: 0.1439 | Val ROC: 0.7661 | LR: 1.06e-04




Epoch 049 | Loss: 0.1430 | Val ROC: 0.7635 | LR: 1.03e-04




Epoch 050 | Loss: 0.1432 | Val ROC: 0.7678 | LR: 1.00e-04




Epoch 051 | Loss: 0.1435 | Val ROC: 0.7632 | LR: 9.69e-05




Epoch 052 | Loss: 0.1413 | Val ROC: 0.7678 | LR: 9.37e-05




Epoch 053 | Loss: 0.1416 | Val ROC: 0.7658 | LR: 9.06e-05




Epoch 054 | Loss: 0.1397 | Val ROC: 0.7660 | LR: 8.75e-05




Epoch 055 | Loss: 0.1405 | Val ROC: 0.7667 | LR: 8.44e-05




Epoch 056 | Loss: 0.1404 | Val ROC: 0.7673 | LR: 8.13e-05




Epoch 057 | Loss: 0.1394 | Val ROC: 0.7676 | LR: 7.82e-05




Epoch 058 | Loss: 0.1388 | Val ROC: 0.7666 | LR: 7.51e-05




Epoch 059 | Loss: 0.1388 | Val ROC: 0.7677 | LR: 7.21e-05




Epoch 060 | Loss: 0.1375 | Val ROC: 0.7691 | LR: 6.91e-05
  → New best! (ROC: 0.7691)




Epoch 061 | Loss: 0.1377 | Val ROC: 0.7645 | LR: 6.61e-05




Epoch 062 | Loss: 0.1370 | Val ROC: 0.7677 | LR: 6.32e-05




Epoch 063 | Loss: 0.1369 | Val ROC: 0.7639 | LR: 6.03e-05




Epoch 064 | Loss: 0.1354 | Val ROC: 0.7640 | LR: 5.74e-05




Epoch 065 | Loss: 0.1358 | Val ROC: 0.7669 | LR: 5.46e-05




Epoch 066 | Loss: 0.1351 | Val ROC: 0.7654 | LR: 5.18e-05




Epoch 067 | Loss: 0.1349 | Val ROC: 0.7689 | LR: 4.91e-05




Epoch 068 | Loss: 0.1342 | Val ROC: 0.7675 | LR: 4.64e-05




Epoch 069 | Loss: 0.1339 | Val ROC: 0.7674 | LR: 4.38e-05




Epoch 070 | Loss: 0.1332 | Val ROC: 0.7645 | LR: 4.12e-05




Epoch 071 | Loss: 0.1341 | Val ROC: 0.7670 | LR: 3.87e-05




Epoch 072 | Loss: 0.1333 | Val ROC: 0.7673 | LR: 3.63e-05




Epoch 073 | Loss: 0.1322 | Val ROC: 0.7658 | LR: 3.39e-05




Epoch 074 | Loss: 0.1318 | Val ROC: 0.7675 | LR: 3.15e-05




Epoch 075 | Loss: 0.1331 | Val ROC: 0.7690 | LR: 2.93e-05




Epoch 076 | Loss: 0.1331 | Val ROC: 0.7691 | LR: 2.71e-05
  → New best! (ROC: 0.7691)




Epoch 077 | Loss: 0.1322 | Val ROC: 0.7709 | LR: 2.50e-05
  → New best! (ROC: 0.7709)




Epoch 078 | Loss: 0.1314 | Val ROC: 0.7693 | LR: 2.29e-05




Epoch 079 | Loss: 0.1320 | Val ROC: 0.7679 | LR: 2.10e-05




Epoch 080 | Loss: 0.1320 | Val ROC: 0.7690 | LR: 1.91e-05




Epoch 081 | Loss: 0.1311 | Val ROC: 0.7674 | LR: 1.73e-05




Epoch 082 | Loss: 0.1305 | Val ROC: 0.7678 | LR: 1.56e-05




Epoch 083 | Loss: 0.1318 | Val ROC: 0.7696 | LR: 1.39e-05




Epoch 084 | Loss: 0.1311 | Val ROC: 0.7698 | LR: 1.24e-05




Epoch 085 | Loss: 0.1297 | Val ROC: 0.7681 | LR: 1.09e-05




Epoch 086 | Loss: 0.1295 | Val ROC: 0.7678 | LR: 9.52e-06




Epoch 087 | Loss: 0.1307 | Val ROC: 0.7687 | LR: 8.22e-06




Epoch 088 | Loss: 0.1307 | Val ROC: 0.7693 | LR: 7.02e-06




Epoch 089 | Loss: 0.1306 | Val ROC: 0.7693 | LR: 5.91e-06




Epoch 090 | Loss: 0.1286 | Val ROC: 0.7699 | LR: 4.89e-06




Epoch 091 | Loss: 0.1298 | Val ROC: 0.7699 | LR: 3.97e-06




Epoch 092 | Loss: 0.1299 | Val ROC: 0.7690 | LR: 3.14e-06




Epoch 093 | Loss: 0.1298 | Val ROC: 0.7685 | LR: 2.41e-06




Epoch 094 | Loss: 0.1294 | Val ROC: 0.7709 | LR: 1.77e-06




Epoch 095 | Loss: 0.1298 | Val ROC: 0.7691 | LR: 1.23e-06




Epoch 096 | Loss: 0.1295 | Val ROC: 0.7686 | LR: 7.89e-07




Epoch 097 | Loss: 0.1279 | Val ROC: 0.7693 | LR: 4.44e-07




Epoch 098 | Loss: 0.1292 | Val ROC: 0.7682 | LR: 1.97e-07




Epoch 099 | Loss: 0.1290 | Val ROC: 0.7687 | LR: 4.93e-08




Epoch 100 | Loss: 0.1303 | Val ROC: 0.7707 | LR: 0.00e+00


  model.load_state_dict(torch.load("pna_text_best.pt"))



🎯 FINAL PNA TEST RESULTS:
ROC-AUC: 0.7398
PR-AUC: 0.2310

Per-assay ROC-AUC:
  NR-AR: 0.7649
  NR-AR-LBD: 0.7996
  NR-AhR: 0.8087
  NR-Aromatase: 0.7024
  NR-ER: 0.6388
  NR-ER-LBD: 0.7250
  NR-PPAR-gamma: 0.7562
  SR-ARE: 0.6823
  SR-ATAD5: 0.6953
  SR-HSE: 0.7715
  SR-MMP: 0.7869
  SR-p53: 0.7459


***gt plus gnn***

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool
import numpy as np

# ============================
# REAL TEXT PROMPTS FOR TOX21 ASSAYS
# ============================

ASSAY_DESCRIPTIONS = {
    "NR-AR": "androgen receptor binding activity and endocrine disruption potential",
    "NR-AR-LBD": "androgen receptor ligand binding domain interaction and binding affinity", 
    "NR-AhR": "aryl hydrocarbon receptor activation and xenobiotic metabolism pathway",
    "NR-Aromatase": "aromatase enzyme inhibition and steroid hormone metabolism",
    "NR-ER": "estrogen receptor binding and hormonal activity assessment",
    "NR-ER-LBD": "estrogen receptor ligand binding domain interaction specificity",
    "NR-PPAR-gamma": "peroxisome proliferator-activated receptor gamma activation",
    "SR-ARE": "antioxidant response element activation and oxidative stress response",
    "SR-ATAD5": "ATAD5 biomarker response and DNA damage genotoxicity",
    "SR-HSE": "heat shock response element activation and protein stress response",
    "SR-MMP": "mitochondrial membrane potential disruption and cytotoxicity",
    "SR-p53": "p53 tumor suppressor pathway activation and DNA damage response"
}

# Convert to embedding initialization
def initialize_prompts_from_descriptions(descriptions_dict, assays_list, embedding_dim=128):
    """Initialize prompts using text descriptions"""
    prompts = []
    
    for assay in assays_list:
        description = descriptions_dict[assay]
        # Simple hash-based initialization
        hash_val = hash(description) % 10000
        torch.manual_seed(hash_val)
        prompt = torch.randn(embedding_dim) * 0.1
        prompts.append(prompt)
    
    return torch.stack(prompts)

# ============================
# Enhanced GNN Branch WITH REAL PROMPTS
# ============================

class EnhancedGNNBranch(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
                 hidden_dim=128, num_layers=4, dropout=0.2, assay_descriptions=None):
        super().__init__()
        
        self.n_tasks = n_tasks
        self.hidden_dim = hidden_dim
        
        # Node projection
        self.node_proj = nn.Sequential(
            nn.Linear(node_feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # GINE layers
        self.gnn_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        for i in range(num_layers):
            nn_mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim)
            )
            conv = GINEConv(nn_mlp, edge_dim=edge_feature_dim)
            self.gnn_layers.append(conv)
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        
        # REAL TEXT PROMPTS INITIALIZATION
        if assay_descriptions:
            initial_prompts = initialize_prompts_from_descriptions(
                assay_descriptions, 
                list(assay_descriptions.keys())
            )
            self.assay_prompts = nn.Parameter(initial_prompts)
        else:
            self.assay_prompts = nn.Parameter(torch.randn(n_tasks, 128))
        
        self.text_proj = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Cross-attention
        self.cross_attention = nn.MultiheadAttention(
            hidden_dim, num_heads=8, dropout=0.1, batch_first=True
        )
        
        # Output
        self.output_proj = nn.Linear(hidden_dim * 2 + hidden_dim + desc_feature_dim, n_tasks)
        
        self._init_weights()
    
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, -1.0)
    
    def forward(self, data, assay_attention=None):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        B = data.num_graphs
        
        # 1. GNN Processing
        x = self.node_proj(x)
        
        for i, (conv, bn) in enumerate(zip(self.gnn_layers, self.batch_norms)):
            x_res = x
            x = conv(x, edge_index, edge_attr)
            x = bn(x)
            x = F.elu(x)
            if i % 2 == 1:
                x = x + x_res
            x = F.dropout(x, p=0.2, training=self.training)
        
        # 2. Graph Readout
        mean_pool = global_mean_pool(x, batch)
        max_pool = global_max_pool(x, batch)
        graph_features = torch.cat([mean_pool, max_pool], dim=1)
        
        # 3. TEXT PROMPT CONDITIONING
        if assay_attention is None:
            assay_attention = torch.ones(B, self.n_tasks, device=graph_features.device) / self.n_tasks
        
        text_emb = torch.einsum('bi,ij->bj', assay_attention, self.assay_prompts)
        text_features = self.text_proj(text_emb)
        
        # 4. Cross-Attention
        text_as_query = text_features.unsqueeze(1)
        graph_as_kv = graph_features[:, :self.hidden_dim].unsqueeze(1)
        attended, _ = self.cross_attention(text_as_query, graph_as_kv, graph_as_kv)
        attended_features = attended.squeeze(1)
        
        # 5. Fusion
        desc_features = data.desc_features.view(B, -1)
        combined = torch.cat([graph_features, attended_features, desc_features], dim=1)
        
        return self.output_proj(combined)

# ============================
# Enhanced Transformer Branch WITH REAL PROMPTS
# ============================

class MolecularGraphTransformer(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
                 hidden_dim=128, num_layers=4, num_heads=8, dropout=0.1, assay_descriptions=None):
        super().__init__()
        
        self.n_tasks = n_tasks
        self.hidden_dim = hidden_dim
        
        # Feature projections
        self.node_proj = nn.Linear(node_feature_dim, hidden_dim)
        
        # CLS token
        self.cls_token = nn.Parameter(torch.randn(hidden_dim))
        
        # Transformer layers
        self.transformer_layers = nn.ModuleList([
            GraphTransformerLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        
        # REAL TEXT PROMPTS INITIALIZATION
        if assay_descriptions:
            initial_prompts = initialize_prompts_from_descriptions(
                assay_descriptions,
                list(assay_descriptions.keys())
            )
            self.assay_prompts = nn.Parameter(initial_prompts)
        else:
            self.assay_prompts = nn.Parameter(torch.randn(n_tasks, 128))
        
        self.text_proj = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Multi-modal fusion
        self.fusion_attention = nn.MultiheadAttention(
            hidden_dim, num_heads=8, dropout=0.1, batch_first=True
        )
        
        # Output
        self.output_proj = nn.Linear(hidden_dim + hidden_dim + desc_feature_dim, n_tasks)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
    
    def forward(self, data, assay_attention=None):
        x, batch = data.x, data.batch
        B = data.num_graphs
        
        # 1. Project node features
        node_features = self.node_proj(x)
        
        # 2. Process each graph with transformer
        graph_representations = []
        unique_batches = torch.unique(batch)
        
        for graph_idx in unique_batches:
            graph_mask = (batch == graph_idx)
            graph_nodes = node_features[graph_mask]
            
            # Add CLS token
            cls_tokens = self.cls_token.unsqueeze(0)
            sequence = torch.cat([cls_tokens, graph_nodes], dim=0)
            
            # Transformer processing
            for transformer_layer in self.transformer_layers:
                sequence = transformer_layer(sequence)
            
            graph_rep = sequence[0]
            graph_representations.append(graph_rep)
        
        graph_features = torch.stack(graph_representations)
        
        # 3. TEXT PROMPT CONDITIONING
        if assay_attention is None:
            assay_attention = torch.ones(B, self.n_tasks, device=graph_features.device) / self.n_tasks
        
        text_emb = torch.einsum('bi,ij->bj', assay_attention, self.assay_prompts)
        text_features = self.text_proj(text_emb)
        
        # 4. Multi-modal fusion
        graph_as_query = graph_features.unsqueeze(1)
        text_as_kv = text_features.unsqueeze(1)
        fused, _ = self.fusion_attention(graph_as_query, text_as_kv, text_as_kv)
        fused_features = fused.squeeze(1)
        
        # 5. Final fusion
        desc_features = data.desc_features.view(B, -1)
        combined = torch.cat([graph_features, fused_features, desc_features], dim=1)
        
        return self.output_proj(combined)

class GraphTransformerLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        attn_out, _ = self.self_attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x

# ============================
# Hybrid Model WITH REAL PROMPTS
# ============================

class HybridTextGraphModel(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks, assay_descriptions):
        super().__init__()
        
        self.n_tasks = n_tasks
        
        # Two branches with REAL PROMPTS
        self.gnn_branch = EnhancedGNNBranch(
            node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
            assay_descriptions=assay_descriptions
        )
        
        self.transformer_branch = MolecularGraphTransformer(
            node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
            assay_descriptions=assay_descriptions
        )
        
        # Learnable fusion
        self.fusion_weights = nn.Parameter(torch.ones(2))
        self.task_gates = nn.Parameter(torch.ones(n_tasks, 2))
    
    def forward(self, data, assay_attention=None):
        gnn_logits = self.gnn_branch(data, assay_attention)
        transformer_logits = self.transformer_branch(data, assay_attention)
        
        global_weights = F.softmax(self.fusion_weights, dim=0)
        task_weights = F.softmax(self.task_gates, dim=1)
        
        global_weights_expanded = global_weights.unsqueeze(0).unsqueeze(0)
        task_weights_expanded = task_weights.unsqueeze(0)
        
        combined_weights = global_weights_expanded * task_weights_expanded
        logits_stacked = torch.stack([gnn_logits, transformer_logits], dim=-1)
        fused_logits = (logits_stacked * combined_weights).sum(dim=-1)
        
        return fused_logits, gnn_logits, transformer_logits, global_weights

# ============================
# Training Pipeline
# ============================

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, logits, targets, weights=None):
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        
        if weights is not None:
            focal_loss = focal_loss * weights
            
        if weights is not None and weights.sum() > 0:
            return focal_loss.sum() / weights.sum()
        return focal_loss.mean()

def main():
    ASSAYS = [
        "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
        "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma",
        "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
    ]
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Print the REAL prompts we're using
    print(" REAL TEXT PROMPTS BEING USED:")
    for assay, desc in ASSAY_DESCRIPTIONS.items():
        print(f"   {assay}: {desc}")
    
    # FIXED: Pass assay_descriptions argument
    sample = train_graphs[0]
    model = HybridTextGraphModel(
        node_feature_dim=sample.x.size(1),
        edge_feature_dim=sample.edge_attr.size(1),
        desc_feature_dim=desc_dim,
        n_tasks=len(ASSAYS),
        assay_descriptions=ASSAY_DESCRIPTIONS  # THIS WAS MISSING!
    ).to(device)
    
    print(f"\n Model initialized with REAL text prompts for {len(ASSAYS)} assays")
    print(f"   GNN prompts shape: {model.gnn_branch.assay_prompts.shape}")
    print(f"   Transformer prompts shape: {model.transformer_branch.assay_prompts.shape}")
    
    # Show prompt values
    print(f"\n Sample prompt values (first 5 dimensions):")
    for i, assay in enumerate(ASSAYS):
        prompt_vals = model.gnn_branch.assay_prompts[i][:5].detach().cpu().numpy()
        print(f"   {assay}: {prompt_vals}")
    
    # Training setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
    criterion = FocalLoss(alpha=0.75, gamma=2.0)
    
    print("\n Starting training with real text prompts...")
    
    # Training loop
    for epoch in range(1, 101):
        model.train()
        total_loss = 0.0
        
        for batch in train_loader:
            batch = batch.to(device)
            B = batch.num_graphs
            
            # Smart assay attention
            y_batch = batch.y.float().view(B, len(ASSAYS))
            w_batch = batch.weight.float().view(B, len(ASSAYS))
            labeled_mask = (w_batch > 0).float()
            
            if labeled_mask.sum() > 0:
                assay_attention = labeled_mask / labeled_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
            else:
                assay_attention = torch.ones(B, len(ASSAYS), device=device) / len(ASSAYS)
            
            # Forward pass with REAL prompts
            fused_logits, gnn_logits, transformer_logits, fusion_weights = model(batch, assay_attention)
            
            # Loss
            y = batch.y.float().view(-1, len(ASSAYS))
            w = batch.weight.float().view(-1, len(ASSAYS))
            loss = criterion(fused_logits, y, w)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        if epoch % 10 == 0:
            gnn_w, tr_w = fusion_weights[0].item(), fusion_weights[1].item()
            print(f"Epoch {epoch:3d} | Loss: {total_loss:.4f} | Fusion: GNN:{gnn_w:.2f}/TR:{tr_w:.2f}")

if __name__ == "__main__":
    main()

Using device: cuda
🔬 REAL TEXT PROMPTS BEING USED:
   NR-AR: androgen receptor binding activity and endocrine disruption potential
   NR-AR-LBD: androgen receptor ligand binding domain interaction and binding affinity
   NR-AhR: aryl hydrocarbon receptor activation and xenobiotic metabolism pathway
   NR-Aromatase: aromatase enzyme inhibition and steroid hormone metabolism
   NR-ER: estrogen receptor binding and hormonal activity assessment
   NR-ER-LBD: estrogen receptor ligand binding domain interaction specificity
   NR-PPAR-gamma: peroxisome proliferator-activated receptor gamma activation
   SR-ARE: antioxidant response element activation and oxidative stress response
   SR-ATAD5: ATAD5 biomarker response and DNA damage genotoxicity
   SR-HSE: heat shock response element activation and protein stress response
   SR-MMP: mitochondrial membrane potential disruption and cytotoxicity
   SR-p53: p53 tumor suppressor pathway activation and DNA damage response

✅ Model initialized with R

***3 d feature extraction***