In [None]:
!pip install torch torch_geometric
!pip install deepchem
!pip install rdkit
!pip install optuna

# Download files
For all datasets and model files, download using this link: https://drive.google.com/drive/folders/1VFI8eS-SUUkvcUi4scdOVY5Ijq5J8boB?usp=sharing

# Data Pre-processing

Fetching the data

In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, Descriptors
import torch
from torch_geometric.data import Data
import deepchem as dc

In [None]:
# Seed NumPy and PyTorch with the same fixed value so repeated runs are
# comparable, then make sure the output folder for processed files exists.
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(42)

os.makedirs("tox21_processed", exist_ok=True)

In [None]:
# Fetch Tox21 through DeepChem's MolNet loader with the "Raw" featurizer,
# using the current directory for both the download and the cache.
print("ðŸ“¥ Loading Tox21 dataset...")
tasks, datasets, transformers = dc.molnet.load_tox21(
    featurizer="Raw",
    data_dir=".",
    save_dir=".",
)
# The loader returns the three splits as one tuple: (train, valid, test).
train_dataset, valid_dataset, test_dataset = datasets


In [None]:
# DeepChem names the label columns y1, y2, ...; build one mapping from those
# generic names to the task names (y1 -> tasks[0], y2 -> tasks[1], ...).
label_renames = {f'y{position}': task for position, task in enumerate(tasks, start=1)}

# Convert every split to a pandas DataFrame and apply the label mapping.
train_df = train_dataset.to_dataframe().rename(columns=label_renames)
valid_df = valid_dataset.to_dataframe().rename(columns=label_renames)
test_df = test_dataset.to_dataframe().rename(columns=label_renames)

# Persist the raw (un-featurized) splits as CSV.
for split_df, csv_path in (
    (train_df, "tox21_processed/train_raw.csv"),
    (valid_df, "tox21_processed/valid_raw.csv"),
    (test_df, "tox21_processed/test_raw.csv"),
):
    split_df.to_csv(csv_path, index=False)

Basic EDA

In [None]:
# Basic EDA: report the task list and the size of each split in one go.
summary_lines = [
    f"Number of tasks: {len(tasks)}",
    f"Tasks: {tasks}",
    f"Training set size: {len(train_df)}",
    f"Validation set size: {len(valid_df)}",
    f"Test set size: {len(test_df)}",
    "Sample data:",
]
print(*summary_lines, sep="\n")

print("âœ… Data loading and initial processing complete.")
# Last expression of the cell: rendered as the rich preview of the training frame.
train_df.head()

In [None]:
# `from rdkit import Chem` does not load the Draw submodule, so `Chem.Draw`
# can raise AttributeError; import it explicitly for the rendering below.
from rdkit.Chem import Draw

# check for missing values
print("Missing values in training set:")
print(train_df.isnull().sum())
print("Missing values in validation set:")
print(valid_df.isnull().sum())
print("Missing values in test set:")
print(test_df.isnull().sum())

# distribution of labels for each task (NaN counted as its own bucket)
for task in tasks:
    print(f"Label distribution for task {task} in training set:")
    print(train_df[task].value_counts(dropna=False))
    print(f"Label distribution for task {task} in validation set:")
    print(valid_df[task].value_counts(dropna=False))
    print(f"Label distribution for task {task} in test set:")
    print(test_df[task].value_counts(dropna=False))

# visualize some molecules
print("Sample molecules from training set:")
for smi in train_df['ids'].head(5):
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles returns None for unparseable SMILES; nothing to draw.
        print(f"Could not parse SMILES, skipping: {smi}")
        continue
    display(Draw.MolToImage(mol))

Extract molecular features for training basic machine learning models.

In [None]:
# extract molecular features using RDKit and save to CSV files
def featurize_molecule(smi):
    """Build a flat feature dict for one SMILES string.

    The dict holds five scalar RDKit descriptors, 167 MACCS key bits
    (``MACCS_0`` .. ``MACCS_166``) and 2048 Morgan fingerprint bits
    (``Morgan_0`` .. ``Morgan_2047``, radius 2), in that insertion order.

    Returns:
        dict | None: the feature mapping, or None when RDKit cannot parse
        the SMILES.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None

    # Scalar physico-chemical descriptors
    features = {
        'MolWt': Descriptors.MolWt(mol),
        'NumHDonors': Descriptors.NumHDonors(mol),
        'NumHAcceptors': Descriptors.NumHAcceptors(mol),
        'TPSA': Descriptors.TPSA(mol),
        'LogP': Descriptors.MolLogP(mol),
    }

    # MACCS structural keys, one column per bit
    maccs = MACCSkeys.GenMACCSKeys(mol)
    features.update({f'MACCS_{bit}': int(maccs.GetBit(bit)) for bit in range(167)})

    # Morgan (circular) fingerprint, one column per bit
    morgan_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    features.update({f'Morgan_{bit}': int(morgan_fp.GetBit(bit)) for bit in range(2048)})

    return features

def featurize_dataset(df):
    """Append molecular feature columns to a DataFrame of SMILES.

    Every value in ``df['ids']`` is passed to ``featurize_molecule``. Rows
    whose SMILES cannot be featurized contribute an empty record, so their
    feature columns come out as NaN and the row alignment with ``df`` is kept.

    Returns:
        pd.DataFrame: ``df`` (index reset) with the feature columns
        concatenated on the right.
    """
    records = []
    for smi in tqdm(df['ids'], desc="Featurizing molecules"):
        feats = featurize_molecule(smi)
        # Empty placeholder keeps one record per input row for invalid SMILES
        records.append({} if feats is None else feats)

    original_part = df.reset_index(drop=True)
    feature_part = pd.DataFrame(records).reset_index(drop=True)
    return pd.concat([original_part, feature_part], axis=1)

# Featurize the three splits in train -> valid -> test order.
train_features_df, valid_features_df, test_features_df = (
    featurize_dataset(split_df) for split_df in (train_df, valid_df, test_df)
)

# save featurized data to CSV files
for featurized_df, csv_path in (
    (train_features_df, "tox21_processed/train_featurized.csv"),
    (valid_features_df, "tox21_processed/valid_featurized.csv"),
    (test_features_df, "tox21_processed/test_featurized.csv"),
):
    featurized_df.to_csv(csv_path, index=False)


In [None]:
# Report the total number of missing cells in each featurized split
# (column-wise null counts summed into a single number per split).
for header, featurized_df in (
    ("Missing values in featurized training set:", train_features_df),
    ("Missing values in featurized validation set:", valid_features_df),
    ("Missing values in featurized test set:", test_features_df),
):
    print(header)
    print(featurized_df.isnull().sum().sum())

In [None]:
# Standardize the continuous descriptor columns.
from sklearn.preprocessing import StandardScaler

continuous_features = ['MolWt', 'NumHDonors', 'NumHAcceptors', 'TPSA', 'LogP']

# Fit on the training split only, then apply the same statistics to every split.
scaler = StandardScaler().fit(train_features_df[continuous_features])

# NOTE: the scaled values overwrite the original columns in place, so running
# this cell a second time would scale already-scaled data.
for featurized_df in (train_features_df, valid_features_df, test_features_df):
    featurized_df[continuous_features] = scaler.transform(featurized_df[continuous_features])

print("âœ… Normalization complete.")

In [None]:
# Write the normalized splits next to the raw and featurized CSVs.
for normalized_df, csv_path in (
    (train_features_df, "tox21_processed/train_normalized.csv"),
    (valid_features_df, "tox21_processed/valid_normalized.csv"),
    (test_features_df, "tox21_processed/test_normalized.csv"),
):
    normalized_df.to_csv(csv_path, index=False)

print("âœ… Data processing pipeline complete. Processed files saved in 'tox21_processed' directory.")

# Baseline Models

Random Forest

In [None]:
# train a simple model to verify the pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score


def _select_model_inputs(features_df):
    """Return the model input columns: continuous descriptors + MACCS + Morgan bits."""
    fingerprint_cols = [
        col for col in features_df.columns
        if col.startswith('MACCS_') or col.startswith('Morgan_')
    ]
    return features_df[continuous_features + fingerprint_cols]


# Missing labels are filled with 0, i.e. treated as negatives by this baseline.
X_train = _select_model_inputs(train_features_df)
y_train = train_features_df[tasks].fillna(0).astype(int)
X_valid = _select_model_inputs(valid_features_df)
y_valid = valid_features_df[tasks].fillna(0).astype(int)
X_test = _select_model_inputs(test_features_df)
y_test = test_features_df[tasks].fillna(0).astype(int)

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("Classification report on test set:")
print(classification_report(y_test, y_pred, zero_division=0))

# calculate roc_auc_score for each task
# ROC AUC needs continuous scores rather than hard labels. With a multi-output
# target, predict_proba returns one (n_samples, n_classes) array per task.
# (The previous code indexed `y_xgb_pred`, which is never defined.)
y_proba = model.predict_proba(X_test)
for i, task in enumerate(tasks):
    task_classes = list(model.classes_[i])
    if 1 in task_classes:
        task_scores = y_proba[i][:, task_classes.index(1)]
    else:
        # The positive class never appeared in training for this task, so the
        # forest can only assign it probability zero.
        task_scores = np.zeros(len(X_test))
    try:
        auc = roc_auc_score(y_test[task], task_scores)
        print(f"ROC AUC for task {task}: {auc:.4f}")
    except ValueError:
        print(f"ROC AUC for task {task}: Cannot be computed (only one class present in y_true)")
        continue

print("âœ… Model training and evaluation complete.")

DNN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, classification_report

# ============================
# Improved Deep Neural Network
# ============================
class DeepToxDNN(nn.Module):
    """Fully connected multi-task classifier that outputs raw logits.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout; a final
    Linear layer maps to ``output_dim`` logits (no activation, intended for
    BCEWithLogitsLoss).

    Args:
        input_dim: number of input features.
        output_dim: number of output logits (one per task).
        hidden_dims: widths of the hidden stages, in order.
        dropout: dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=[2048, 1024, 512, 256], dropout=0.5):
        super().__init__()

        # Consecutive (fan_in, fan_out) pairs describe the hidden stages.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]

        # Output projection: logits only
        stack.append(nn.Linear(widths[-1], output_dim))

        self.network = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every Linear weight; Linear biases set to zero."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the logits."""
        return self.network(x)

# ============================
# Training Function
# ============================
def train_dnn(model, train_loader, optimizer, criterion, device):
    """Run one training epoch and return the sample-weighted mean loss.

    Label entries that are NaN are excluded from the loss. A batch with no
    usable label is skipped entirely (no backward pass, no optimizer step) and
    does not count toward the average. Gradients are clipped to a total norm
    of 1.0 before each step.

    Returns:
        The accumulated ``loss * batch_size`` divided by the number of samples
        in the batches that were used, or 0 if no batch contributed.
    """
    model.train()
    running_loss = 0
    samples_seen = 0

    for features, targets in train_loader:
        features = features.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(features)

        # Keep only the label entries that are actually present
        labeled = torch.isnan(targets).logical_not()
        if not labeled.any():
            continue

        batch_loss = criterion(outputs[labeled], targets[labeled])
        batch_loss.backward()

        # Clip the gradient norm for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_size = features.size(0)
        running_loss += batch_loss.item() * batch_size
        samples_seen += batch_size

    if samples_seen > 0:
        return running_loss / samples_seen
    return 0

# ============================
# Evaluation Function
# ============================
def evaluate_dnn(model, loader, device, tasks):
    """Print per-task ROC-AUC for ``model`` on ``loader`` and return the mean.

    Predictions are sigmoid probabilities of the model's logits. For each
    task, NaN labels are dropped; tasks with fewer than 10 remaining labels
    are reported as "Insufficient data" and tasks for which ``roc_auc_score``
    raises ValueError are reported with the error text. Neither kind counts
    toward the mean.

    Returns:
        The mean of the computed AUCs, or 0 when none could be computed.
    """
    model.eval()
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for features, targets in loader:
            batch_probs = torch.sigmoid(model(features.to(device)))
            prob_chunks.append(batch_probs.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    probs = np.vstack(prob_chunks)
    labels = np.vstack(label_chunks)

    # Per-task metrics
    aucs = []
    print("\nPer-task ROC-AUC:")
    for column, task in enumerate(tasks):
        # Only labels that are present can be scored
        known = ~np.isnan(labels[:, column])
        if known.sum() < 10:
            print(f"  {task}: Insufficient data")
            continue

        try:
            auc = roc_auc_score(labels[known, column], probs[known, column])
        except ValueError as e:
            print(f"  {task}: Cannot compute - {e}")
        else:
            aucs.append(auc)
            print(f"  {task}: {auc:.4f}")

    return np.mean(aucs) if aucs else 0

# ============================
# Main Training Script
# ============================
def main():
    """Train DeepToxDNN on the notebook's feature matrices and report test metrics.

    Reads the globals ``X_train``/``y_train``, ``X_valid``/``y_valid``,
    ``X_test``/``y_test`` and ``tasks`` prepared by the Random Forest cell.
    Validation runs on epoch 1 and every 5th epoch; each validation drives the
    LR scheduler, checkpointing and the early-stopping counter (so ``patience``
    counts validations, not epochs).

    Returns:
        tuple: ``(model, test_auc)`` - the model carrying the best checkpoint's
        weights (or the final weights if nothing was checkpointed) and its mean
        test ROC-AUC.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Hyperparameters
    input_dim = X_train.shape[1]
    output_dim = len(tasks)
    batch_size = 128
    learning_rate = 0.001
    num_epochs = 100
    patience = 15
    checkpoint_path = 'best_dnn_model.pt'

    print(f"Input dimension: {input_dim}")
    print(f"Output dimension: {output_dim}")
    print(f"Training samples: {len(X_train)}")
    print(f"Validation samples: {len(X_valid)}")
    print(f"Test samples: {len(X_test)}\n")

    # Convert to tensors. The y_* frames were built with fillna(0), so no NaN
    # labels reach this point; the NaN masking in train_dnn/evaluate_dnn only
    # takes effect if the labels are passed in unfilled.
    X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)

    X_valid_tensor = torch.tensor(X_valid.values, dtype=torch.float32)
    y_valid_tensor = torch.tensor(y_valid.values, dtype=torch.float32)

    X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)

    # Create data loaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    valid_dataset = TensorDataset(X_valid_tensor, y_valid_tensor)
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = DeepToxDNN(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dims=[2048, 1024, 512, 256],
        dropout=0.5
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    # Loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    # The scheduler's `verbose` keyword is deprecated and may be rejected by
    # recent PyTorch releases, so it is not passed; learning-rate reductions
    # are reported manually after each scheduler step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )

    # Training loop
    best_val_auc = 0
    patience_counter = 0
    checkpoint_saved = False  # guards against loading a stale/missing file later

    print("="*60)
    print("Starting Training")
    print("="*60)

    for epoch in range(1, num_epochs + 1):
        train_loss = train_dnn(model, train_loader, optimizer, criterion, device)

        if epoch % 5 == 0 or epoch == 1:
            print(f"\nEpoch {epoch}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}")

            # Validation
            val_auc = evaluate_dnn(model, valid_loader, device, tasks)
            print(f"  Mean Validation AUC: {val_auc:.4f}")

            # Learning rate scheduling (with a manual replacement for verbose=True)
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_auc)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Save best model
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
                patience_counter = 0
                print(f"  âœ“ New best model saved! (AUC: {val_auc:.4f})")
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch}")
                break

    # Load best model and evaluate on test set
    print("\n" + "="*60)
    print("Final Test Evaluation")
    print("="*60)

    if checkpoint_saved:
        # map_location keeps the reload valid when the device differs from
        # the one the checkpoint was written on.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; evaluating the final weights.")
    test_auc = evaluate_dnn(model, test_loader, device, tasks)

    print(f"\n{'='*60}")
    print(f"Mean Test ROC-AUC: {test_auc:.4f}")
    print(f"{'='*60}")

    # Detailed predictions for classification report
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            logits = model(X_batch)
            probs = torch.sigmoid(logits)
            all_probs.append(probs.cpu().numpy())
            all_labels.append(y_batch.cpu().numpy())

    y_pred_probs = np.vstack(all_probs)
    y_true = np.vstack(all_labels)
    y_pred_binary = (y_pred_probs > 0.5).astype(int)

    print("\nClassification Report (averaged across tasks):")
    # Pool every (sample, task) entry that has a label into one binary report
    valid_mask = ~np.isnan(y_true)
    if valid_mask.any():
        print(classification_report(
            y_true[valid_mask].astype(int).flatten(),
            y_pred_binary[valid_mask].flatten(),
            zero_division=0
        ))

    print("\nâœ… DNN training and evaluation complete.")

    return model, test_auc

if __name__ == "__main__":
    model, test_auc = main()

GCN+GIN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, global_add_pool
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, classification_report

# ============================
# Improved GCN Model
# ============================
class ImprovedGCN(nn.Module):
    """Graph-level classifier: five GCNConv layers, sum pooling, 3-layer MLP head.

    Submodules are registered as ``conv1``..``conv5`` / ``bn1``..``bn5`` for
    the graph layers and ``fc1``, ``bn6``, ``fc2``, ``bn7``, ``fc3`` for the
    head. ``forward`` returns raw logits (for BCEWithLogitsLoss).
    """

    NUM_CONV_LAYERS = 5

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super().__init__()

        # Graph layers: only the first one consumes the raw node features
        for layer_idx in range(1, self.NUM_CONV_LAYERS + 1):
            in_channels = input_dim if layer_idx == 1 else hidden_dim
            setattr(self, f"conv{layer_idx}", GCNConv(in_channels, hidden_dim))
            setattr(self, f"bn{layer_idx}", nn.BatchNorm1d(hidden_dim))

        # MLP head: hidden -> hidden -> hidden // 2 -> output
        half_dim = hidden_dim // 2
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.bn6 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.bn7 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Compute per-graph logits from a batch exposing x, edge_index and batch."""
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        # conv -> batch norm -> ReLU, repeated for each graph layer
        for layer_idx in range(1, self.NUM_CONV_LAYERS + 1):
            conv = getattr(self, f"conv{layer_idx}")
            norm = getattr(self, f"bn{layer_idx}")
            x = F.relu(norm(conv(x, edge_index)))

        # Sum node embeddings per graph
        x = global_add_pool(x, batch)

        # MLP head with batch norm and dropout on the two hidden stages
        for linear, norm in ((self.fc1, self.bn6), (self.fc2, self.bn7)):
            x = self.dropout(F.relu(norm(linear(x))))
        return self.fc3(x)  # logits

# ============================
# Improved GIN Model
# ============================
class ImprovedGIN(nn.Module):
    """Graph-level classifier: five GINConv layers, sum pooling, 3-layer MLP head.

    Each GINConv wraps a Linear -> ReLU -> Linear network. Submodules are
    registered as ``conv1``..``conv5`` / ``bn1``..``bn5`` for the graph layers
    and ``fc1``, ``bn6``, ``fc2``, ``bn7``, ``fc3`` for the head. ``forward``
    returns raw logits.
    """

    NUM_CONV_LAYERS = 5

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super().__init__()

        def make_mlp(n_in, n_out):
            """Two-layer perceptron used inside a GINConv."""
            return nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.ReLU(),
                nn.Linear(n_out, n_out),
            )

        # Graph layers: only the first one consumes the raw node features
        for layer_idx in range(1, self.NUM_CONV_LAYERS + 1):
            n_in = input_dim if layer_idx == 1 else hidden_dim
            setattr(self, f"conv{layer_idx}", GINConv(make_mlp(n_in, hidden_dim)))
            setattr(self, f"bn{layer_idx}", nn.BatchNorm1d(hidden_dim))

        # MLP head: hidden -> hidden -> hidden // 2 -> output
        half_dim = hidden_dim // 2
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.bn6 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.bn7 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Compute per-graph logits from a batch exposing x, edge_index and batch."""
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        # conv -> batch norm -> ReLU, repeated for each graph layer
        for layer_idx in range(1, self.NUM_CONV_LAYERS + 1):
            conv = getattr(self, f"conv{layer_idx}")
            norm = getattr(self, f"bn{layer_idx}")
            x = F.relu(norm(conv(x, edge_index)))

        # Sum node embeddings per graph
        x = global_add_pool(x, batch)

        # MLP head
        for linear, norm in ((self.fc1, self.bn6), (self.fc2, self.bn7)):
            x = self.dropout(F.relu(norm(linear(x))))
        return self.fc3(x)  # logits

# ============================
# Training Function
# ============================
def train_gnn(model, train_loader, optimizer, criterion, device):
    """Run one training epoch over graph batches and return the mean loss.

    Labels are reshaped to the logits' shape; only entries inside [0, 1] are
    used for the loss (NaN fails both comparisons and is excluded as well).
    Batches without any usable label are skipped and do not count toward the
    average.

    Returns:
        The accumulated ``loss * num_graphs`` divided by the number of graphs
        in the batches that were used, or 0 if no batch contributed.
    """
    model.train()
    loss_sum = 0
    graphs_seen = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = model(batch)
        targets = batch.y.float().view(logits.shape)

        # Entries outside [0, 1] are treated as unlabeled
        usable = (targets >= 0) & (targets <= 1)
        if not usable.any():
            continue

        batch_loss = criterion(logits[usable], targets[usable])
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs

    if graphs_seen > 0:
        return loss_sum / graphs_seen
    return 0

# ============================
# Evaluation Function
# ============================
def evaluate_gnn(model, loader, device, tasks):
    """Print per-task ROC-AUC for a graph model and return the mean.

    Predictions are sigmoid probabilities of the model's logits. For each
    task only labels inside [0, 1] are scored; tasks with fewer than 10 such
    labels are reported as "Insufficient data", and tasks for which
    ``roc_auc_score`` raises ValueError are reported with the error text.
    Neither kind counts toward the mean.

    Returns:
        The mean of the computed AUCs, or 0 when none could be computed.
    """
    model.eval()
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_probs = torch.sigmoid(model(batch))
            prob_chunks.append(batch_probs.cpu().numpy())
            label_chunks.append(batch.y.cpu().numpy())

    probs = np.vstack(prob_chunks)
    labels = np.vstack(label_chunks)

    # AUC per task
    aucs = []
    for column, task in enumerate(tasks):
        task_labels = labels[:, column]
        usable = (task_labels >= 0) & (task_labels <= 1)
        if usable.sum() < 10:  # too few labeled samples to score
            print(f"{task}: Insufficient data")
            continue

        try:
            auc = roc_auc_score(labels[usable, column], probs[usable, column])
        except ValueError as e:
            print(f"{task}: Cannot compute AUC - {e}")
        else:
            aucs.append(auc)
            print(f"{task}: {auc:.4f}")

    return np.mean(aucs) if aucs else 0

# ============================
# Main Training Script
# ============================
def main():
    """Train the GCN and then the GIN model, reporting test AUC for each.

    Validation happens every 5th epoch; the best-scoring weights are written
    to ``best_<name>_model.pt`` and reloaded before the test evaluation. The
    early-stopping counter advances once per validation without improvement.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Hyperparameters
    # NOTE(review): train_graphs / valid_graphs / test_graphs are read as
    # globals and are not created in this cell - confirm they exist before
    # this cell is run.
    input_dim = train_graphs[0].x.shape[1]
    hidden_dim = 128
    output_dim = len(tasks)
    batch_size = 64
    learning_rate = 0.001
    num_epochs = 100
    patience = 15

    # Data loaders (only the training loader shuffles)
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    valid_loader, test_loader = (
        DataLoader(graphs, batch_size=batch_size, shuffle=False)
        for graphs in (valid_graphs, test_graphs)
    )

    # Train both architectures with the same recipe
    for model_name, ModelClass in (("GCN", ImprovedGCN), ("GIN", ImprovedGIN)):
        print(f"\n{'='*60}")
        print(f"Training {model_name} Model")
        print(f"{'='*60}")

        model = ModelClass(input_dim, hidden_dim, output_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()

        best_val_auc = 0
        patience_counter = 0

        for epoch in range(1, num_epochs + 1):
            train_loss = train_gnn(model, train_loader, optimizer, criterion, device)

            # Validation, checkpointing and early stopping only on every 5th epoch
            if epoch % 5 != 0:
                continue

            print(f"\nEpoch {epoch}/{num_epochs} | Loss: {train_loss:.4f}")
            print("Validation AUCs:")
            val_auc = evaluate_gnn(model, valid_loader, device, tasks)
            print(f"Mean Validation AUC: {val_auc:.4f}")

            improved = val_auc > best_val_auc
            if improved:
                best_val_auc = val_auc
                torch.save(model.state_dict(), f'best_{model_name.lower()}_model.pt')
                patience_counter = 0
                print(f"âœ“ New best model saved!")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        # Reload the best checkpoint and score it on the test split
        model.load_state_dict(torch.load(f'best_{model_name.lower()}_model.pt'))
        print(f"\n{'='*60}")
        print(f"{model_name} Test Results:")
        print(f"{'='*60}")
        test_auc = evaluate_gnn(model, test_loader, device, tasks)
        print(f"\nMean Test AUC: {test_auc:.4f}")

if __name__ == "__main__":
    main()

# Early Fusion Model

Build fusion dataset

In [None]:
"""
Builds the "Early Fusion" dataset from the raw DeepChem CSVs.

This script performs all preprocessing:
1.  Loads the raw CSV data (train, valid, test).
2.  Generates three feature sets:
    - Graph Features (Nodes, Edges)
    - Fingerprint Features (ECFP4)
    - Descriptor Features (RDKit 2D)
3.  Normalizes the Node features using a StandardScaler fit *only* on the training set.
4.  Handles and skips invalid SMILES.
5.  Saves the final, processed data lists (for PyTorch Geometric) to disk.
"""

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from sklearn.preprocessing import StandardScaler
import joblib
import torch
from torch_geometric.data import Data

# --- Configuration ---
# Input folder with the raw CSVs, output folder for the processed data, and
# the location of the node-feature scaler inside the output folder.
RAW_DATA_DIR = "dataset"
PROCESSED_DATA_DIR = "processed_fusion_data"
SCALER_PATH = os.path.join(PROCESSED_DATA_DIR, "node_feature_scaler.joblib")

# Column names from the CSV
SMILES_COLUMN = 'ids'
# The 12 task columns: seven nuclear-receptor (NR-*) and five stress-response (SR-*)
LABEL_COLUMNS = [
    'NR-AR',
    'NR-AR-LBD',
    'NR-AhR',
    'NR-Aromatase',
    'NR-ER',
    'NR-ER-LBD',
    'NR-PPAR-gamma',
    'SR-ARE',
    'SR-ATAD5',
    'SR-HSE',
    'SR-MMP',
    'SR-p53',
]
# One weight column per label column: w1 .. w12
WEIGHT_COLUMNS = [f'w{position}' for position in range(1, 13)]

# --- RDKit Descriptors ---
def _keep_descriptor(entry):
    """True for (name, func) pairs whose name passes the exclusion filters."""
    descriptor_name = entry[0]
    if '3D' in descriptor_name:  # skip 3D descriptors
        return False
    # Skip names containing any of these substrings (some complex descriptors)
    return not any(fragment in descriptor_name for fragment in ['Ipc', 'Kappa', 'Chi'])


# Name -> function table of the 2D descriptors to calculate
DESCRIPTOR_FUNCTIONS = dict(filter(_keep_descriptor, Descriptors.descList))
# Alphabetical name list fixes the order of the descriptor vector
DESCRIPTOR_NAMES = sorted(DESCRIPTOR_FUNCTIONS.keys())
print(f"Calculating {len(DESCRIPTOR_NAMES)} RDKit descriptors.")


# --- 1. Graph Featurization Functions ---

def get_atom_features(atom):
    """Encode one atom as a flat list of 41 numeric features.

    Layout, in order: element one-hot (10 symbols + 'other'), degree one-hot
    (0-6), hybridization one-hot (5 types + 'other'), formal-charge one-hot
    (-2..2), hydrogen-count one-hot (0-4), aromatic flag, in-ring flag, atomic
    number, atomic mass, and three chirality flags (unspecified, CW, CCW).
    Values outside a one-hot vocabulary without an 'other' slot encode as all
    zeros.

    NOTE(review): the caller in this file adds explicit hydrogens before
    featurizing; GetTotalNumHs() is called without includeNeighbors=True, so
    the hydrogen-count slot is presumably always 0 in that case - confirm.
    """
    def one_hot(value, choices):
        return [int(value == choice) for choice in choices]

    def one_hot_with_other(value, choices):
        # Trailing slot is set only when none of the known choices matched
        encoded = one_hot(value, choices)
        encoded.append(0 if any(encoded) else 1)
        return encoded

    hybridization_types = Chem.rdchem.HybridizationType
    chiral_types = Chem.rdchem.ChiralType
    chirality = atom.GetChiralTag()

    return [
        # Element symbol (+ 'other')
        *one_hot_with_other(
            atom.GetSymbol(),
            ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'I'],
        ),
        # Degree
        *one_hot(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]),
        # Hybridization (+ 'other')
        *one_hot_with_other(
            atom.GetHybridization(),
            [
                hybridization_types.SP,
                hybridization_types.SP2,
                hybridization_types.SP3,
                hybridization_types.SP3D,
                hybridization_types.SP3D2,
            ],
        ),
        # Formal charge
        *one_hot(atom.GetFormalCharge(), [-2, -1, 0, 1, 2]),
        # Number of hydrogens
        *one_hot(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]),
        # Boolean flags
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
        # Raw numeric properties
        atom.GetAtomicNum(),
        atom.GetMass(),
        # Chirality tag
        *one_hot(
            chirality,
            [
                chiral_types.CHI_UNSPECIFIED,
                chiral_types.CHI_TETRAHEDRAL_CW,
                chiral_types.CHI_TETRAHEDRAL_CCW,
            ],
        ),
    ]

def mol_to_graph_data_obj(mol):
    """Turn an RDKit Mol into a PyG ``Data`` object (x, edge_index, edge_attr).

    Every bond contributes two directed edges carrying the same 10-value
    attribute vector: 4 bond-type flags, conjugation flag, in-ring flag and
    4 stereo flags. Returns None when ``mol`` is None.
    """
    if mol is None:
        return None

    # Node features: one row per atom
    x = torch.tensor(
        [get_atom_features(atom) for atom in mol.GetAtoms()],
        dtype=torch.float,
    )

    bond_type_choices = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    stereo_choices = (
        Chem.rdchem.BondStereo.STEREONONE,
        Chem.rdchem.BondStereo.STEREOZ,
        Chem.rdchem.BondStereo.STEREOE,
        Chem.rdchem.BondStereo.STEREOANY,
    )

    # Edge index and edge features
    edge_indices = []
    edge_attrs = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        edge_indices += [[begin, end], [end, begin]]

        bond_type = bond.GetBondType()
        stereo = bond.GetStereo()
        attr = (
            [int(bond_type == choice) for choice in bond_type_choices]
            + [int(bond.GetIsConjugated()), int(bond.IsInRing())]
            + [int(stereo == choice) for choice in stereo_choices]
        )
        edge_attrs += [attr, attr]  # same attributes for both directions

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attrs, dtype=torch.float)
    else:
        # No bonds: empty tensors. edge_attrs is empty whenever edge_indices
        # is, so the attribute width is the fixed 10 columns.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 10), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


# --- 2. Fingerprint Featurization Function ---

def get_fingerprint_features(mol):
    """
    Return the 2048-bit Morgan fingerprint of ``mol`` as a bit string.

    A Morgan fingerprint with radius=2 corresponds to ECFP4. Returns None
    when ``mol`` is None.
    """
    if mol is not None:
        bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
        return bit_vector.ToBitString()
    return None


# --- 3. Descriptor Featurization Function ---

def get_descriptor_features(mol):
    """Compute the configured RDKit descriptors for ``mol``.

    Returns a list ordered like ``DESCRIPTOR_NAMES``. A descriptor whose
    function raises is recorded as NaN; a None molecule yields a list of NaN
    of the same length.
    """
    if mol is None:
        return [np.nan] * len(DESCRIPTOR_NAMES)

    def safe_call(func):
        # A failing descriptor must not abort the whole molecule
        try:
            return func(mol)
        except Exception:
            return np.nan

    values_by_name = {
        name: safe_call(func) for name, func in DESCRIPTOR_FUNCTIONS.items()
    }

    # Emit in the canonical (sorted-name) order
    return [values_by_name[name] for name in DESCRIPTOR_NAMES]


# --- Main Data Creation Function ---

def create_data_list(df, scaler=None):
    """
    Build one PyG Data object per valid row of ``df``, carrying graph,
    fingerprint and descriptor features plus labels and weights.

    Args:
        df (pd.DataFrame): raw data with the SMILES, label and weight columns.
        scaler (StandardScaler, optional): a *fitted* scaler applied to the
                                          node features. If None, node
                                          features are left unscaled.

    Returns:
        list: processed torch_geometric.data.Data objects. Rows whose SMILES
        cannot be parsed, whose node features fail to scale, or whose
        fingerprint is missing are skipped with a printed warning.
    """
    processed = []

    row_iter = tqdm(df.iterrows(), total=len(df), desc="Featurizing molecules")
    for _, row in row_iter:
        smi = row[SMILES_COLUMN]

        # Parse and add explicit hydrogens; all three feature sets below are
        # computed on this hydrogen-expanded molecule.
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            mol = Chem.AddHs(mol)

        # 1. Graph features
        graph_data = mol_to_graph_data_obj(mol)
        if graph_data is None:
            print(f"Warning: Skipping invalid SMILES: {smi}")
            continue

        # Optional node-feature normalization
        if scaler is not None:
            try:
                scaled_nodes = scaler.transform(graph_data.x.numpy())
                graph_data.x = torch.tensor(scaled_nodes, dtype=torch.float)
            except Exception as e:
                print(f"Warning: Failed to scale features for {smi}. Skipping. Error: {e}")
                continue

        # 2. Fingerprint features (bit string -> 1 x n_bits float tensor)
        fp_str = get_fingerprint_features(mol)
        if fp_str is None:
            print(f"Warning: Skipping molecule with failed FP: {smi}")
            continue
        fp_bits = [int(bit) for bit in fp_str]
        graph_data.fp_features = torch.tensor(fp_bits, dtype=torch.float).reshape(1, -1)

        # 3. Descriptor features (1 x n_descriptors)
        descriptor_values = get_descriptor_features(mol)
        graph_data.desc_features = torch.tensor(descriptor_values, dtype=torch.float).reshape(1, -1)

        # 4. Labels and 5. per-task weights, each stored as a 1 x 12 row
        graph_data.y = torch.tensor(list(row[LABEL_COLUMNS]), dtype=torch.float).reshape(1, 12)
        graph_data.w = torch.tensor(list(row[WEIGHT_COLUMNS]), dtype=torch.float).reshape(1, 12)

        processed.append(graph_data)

    return processed


Model architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GINEConv,
    global_mean_pool,
    global_max_pool
)

class EarlyFusionModel(nn.Module):
    """Multi-task classifier that fuses three views of a molecule.

    The three branches are:
      * a two-layer GINEConv graph encoder with mean+max pooling,
      * a single-layer MLP over the fingerprint vector,
      * the descriptor vector, passed through unchanged.

    Each branch output is multiplied by its own learnable scalar
    (``branch_weights``) before the three are concatenated and fed to a
    two-layer classifier head that emits one logit per task.
    """

    def __init__(self,
                 node_feature_dim,
                 edge_feature_dim,
                 fp_feature_dim,
                 desc_feature_dim,
                 n_tasks,
                 graph_hidden_dim=128,
                 graph_out_dim=64,
                 gnn_dropout=0.2,
                 fp_out_dim=256,
                 classifier_dropout_1=0.5,
                 classifier_dropout_2=0.3):
        """Build all sub-modules.

        Args:
            node_feature_dim: width of ``graph_data.x``.
            edge_feature_dim: width of ``graph_data.edge_attr``.
            fp_feature_dim: width of ``graph_data.fp_features``.
            desc_feature_dim: width of ``graph_data.desc_features``.
            n_tasks: number of output logits.
            graph_hidden_dim: hidden width of both GINE layers.
            graph_out_dim: width of the graph embedding after pooling + MLP.
            gnn_dropout: dropout used after the first GINE layer and inside
                the post-pooling MLP.
            fp_out_dim: width of the fingerprint embedding.
            classifier_dropout_1: dropout applied inside the fingerprint branch.
            classifier_dropout_2: dropout applied inside the classifier head.
        """
        super().__init__()

        # --- Graph branch: two GINE layers, each followed by batch norm ---
        # Modules are created in a fixed order so that seeded initialisation
        # and state_dict keys stay stable.
        self.gnn_conv1 = GINEConv(
            self._gine_mlp(node_feature_dim, graph_hidden_dim),
            edge_dim=edge_feature_dim,
        )
        self.gnn_conv2 = GINEConv(
            self._gine_mlp(graph_hidden_dim, graph_hidden_dim),
            edge_dim=edge_feature_dim,
        )
        self.gnn_batch_norm1 = nn.BatchNorm1d(graph_hidden_dim)
        self.gnn_batch_norm2 = nn.BatchNorm1d(graph_hidden_dim)
        self.gnn_dropout = gnn_dropout

        # Mean and max pooling are concatenated, hence twice the hidden width.
        self.graph_fc = nn.Sequential(
            nn.Linear(2 * graph_hidden_dim, graph_out_dim),
            nn.ReLU(),
            nn.Dropout(gnn_dropout),
        )

        # --- Fingerprint branch ---
        self.fp_fc = nn.Sequential(
            nn.Linear(fp_feature_dim, fp_out_dim),
            nn.ReLU(),
            nn.Dropout(classifier_dropout_1),
        )

        # --- Learnable per-branch scale (graph, fingerprint, descriptor) ---
        # All three start at 1.0.
        self.branch_weights = nn.Parameter(torch.ones(3))

        # --- Classifier head over the concatenated branches ---
        # Descriptors are not projected, so they contribute their raw width.
        fused_dim = graph_out_dim + fp_out_dim + desc_feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(classifier_dropout_2),
            nn.Linear(128, n_tasks),
        )

    @staticmethod
    def _gine_mlp(in_dim, hidden_dim):
        """Return the Linear-ReLU-Linear network wrapped by one GINEConv layer."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward_graph(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs into one ``graph_out_dim`` vector per graph."""
        # Layer 1: conv -> batch norm -> ELU -> dropout
        hidden = F.elu(self.gnn_batch_norm1(self.gnn_conv1(x, edge_index, edge_attr)))
        hidden = F.dropout(hidden, p=self.gnn_dropout, training=self.training)

        # Layer 2: conv -> batch norm -> ELU (no dropout here)
        hidden = F.elu(self.gnn_batch_norm2(self.gnn_conv2(hidden, edge_index, edge_attr)))

        # Readout: concatenate mean- and max-pooled node states per graph.
        pooled = torch.cat(
            [global_mean_pool(hidden, batch), global_max_pool(hidden, batch)],
            dim=1,
        )
        return self.graph_fc(pooled)

    def forward(self, graph_data):
        """Return task logits for a batched graph object.

        ``graph_data`` must expose ``x``, ``edge_index``, ``edge_attr``,
        ``fp_features``, ``desc_features`` and ``batch``.
        """
        x = graph_data.x
        edge_index = graph_data.edge_index
        edge_attr = graph_data.edge_attr
        fp_features = graph_data.fp_features
        desc_features = graph_data.desc_features
        batch = graph_data.batch

        # Branch outputs, in the same order as ``branch_weights``.
        branch_outputs = (
            self.forward_graph(x, edge_index, edge_attr, batch),  # graph
            self.fp_fc(fp_features),                              # fingerprint
            desc_features,                                        # descriptors (identity)
        )

        # Scale each branch by its learned weight, then concatenate.
        fused_vector = torch.cat(
            [feat * self.branch_weights[k] for k, feat in enumerate(branch_outputs)],
            dim=1,
        )

        return self.classifier(fused_vector)

# =============================================================================
# --- SELF-TEST BLOCK ---
# To run this test: `python model.py`
# =============================================================================
# if __name__ == "__main__":

#     print("--- Running Model Self-Test ---")

#     # --- 1. Define Mock Dimensions ---
#     B = 4  # Batch size
#     N_TASKS = 12

#     # Feature dimensions (must match build_fusion_dataset.py)
#     NODE_DIM = 41
#     EDGE_DIM = 11
#     FP_DIM = 2048
#     DESC_DIM = 200

#     # Model hyperparameters
#     GRAPH_HIDDEN = 128
#     GRAPH_OUT = 64
#     FP_OUT = 256

#     # --- 2. Create a Dummy Batch (simulating PyG DataLoader) ---
#     # We'll create 4 "molecules" of different sizes
#     from torch_geometric.data import Data, Batch
#     # Mol 1: 3 nodes, 2 edges
#     d1 = Data(
#         x=torch.rand(3, NODE_DIM),
#         edge_index=torch.tensor([[0, 1], [1, 2]], dtype=torch.long).t().contiguous(),
#         edge_attr=torch.rand(2, EDGE_DIM),
#         fp=torch.rand(1, FP_DIM),
#         desc=torch.rand(1, DESC_DIM),
#         y=torch.rand(1, N_TASKS), # Not used in forward, but good to have
#         w=torch.rand(1, N_TASKS)  # Not used in forward
#     )

#     # Mol 2: 5 nodes, 4 edges
#     d2 = Data(
#         x=torch.rand(5, NODE_DIM),
#         edge_index=torch.tensor([[0, 1, 1, 2], [1, 2, 3, 4]], dtype=torch.long).t().contiguous(),
#         edge_attr=torch.rand(4, EDGE_DIM),
#         fp=torch.rand(1, FP_DIM),
#         desc=torch.rand(1, DESC_DIM),
#         y=torch.rand(1, N_TASKS),
#         w=torch.rand(1, N_TASKS)
#     )

#     # Mol 3: 2 nodes, 1 edge
#     d3 = Data(
#         x=torch.rand(2, NODE_DIM),
#         edge_index=torch.tensor([[0], [1]], dtype=torch.long).t().contiguous(),
#         edge_attr=torch.rand(1, EDGE_DIM),
#         fp=torch.rand(1, FP_DIM),
#         desc=torch.rand(1, DESC_DIM),
#         y=torch.rand(1, N_TASKS),
#         w=torch.rand(1, N_TASKS)
#     )

#     # Mol 4: 4 nodes, 3 edges
#     d4 = Data(
#         x=torch.rand(4, NODE_DIM),
#         edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long).t().contiguous(),
#         edge_attr=torch.rand(3, EDGE_DIM),
#         fp=torch.rand(1, FP_DIM),
#         desc=torch.rand(1, DESC_DIM),
#         y=torch.rand(1, N_TASKS),
#         w=torch.rand(1, N_TASKS)
#     )

#     # Create a PyG Batch from this list
#     data_list = [d1, d2, d3, d4]
#     data_batch = Batch.from_data_list(data_list)

#     print(f"Created a dummy batch of {B} graphs.")
#     print(f"  Batch.x shape (total nodes):         {data_batch.x.shape}")
#     print(f"  Batch.edge_index shape:              {data_batch.edge_index.shape}")
#     print(f"  Batch.edge_attr shape (total edges): {data_batch.edge_attr.shape}")
#     print(f"  Batch.fp_features shape:             {data_batch.fp_features.shape}")
#     print(f"  Batch.desc_features shape:           {data_batch.desc_features.shape}")
#     print(f"  Batch.batch vector shape:            {data_batch.batch.shape}")

#     # --- 3. Instantiate Model ---
#     model = EarlyFusionModel(
#         node_feature_dim=NODE_DIM,
#         edge_feature_dim=EDGE_DIM,
#         fp_feature_dim=FP_DIM,
#         desc_feature_dim=DESC_DIM,
#         n_tasks=N_TASKS,
#         graph_hidden_dim=GRAPH_HIDDEN,
#         graph_out_dim=GRAPH_OUT,
#         fp_out_dim=FP_OUT
#     )

#     model.train() # Set to training mode
#     print("\nModel instantiated successfully.")

#     # --- 4. Run Forward Pass ---
#     try:
#         out = model(data_batch)

#         print("\n--- TEST SUCCESSFUL ---")
#         print(f"Forward pass ran without errors.")
#         print(f"Input batch size:  {B}")
#         print(f"Output shape:      {out.shape}")

#         # Check if the output shape is correct
#         assert out.shape == (B, N_TASKS)
#         print("Output shape is correct! (Batch Size, Num Tasks)")

#     except Exception as e:
#         print("\n--- TEST FAILED ---")
#         print(f"An error occurred during the forward pass:")
#         print(e)
#         import traceback
#         traceback.print_exc()

Optuna tuning code (please uncomment the code to run)

In [None]:
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch_geometric.loader import DataLoader
# from sklearn.metrics import roc_auc_score
# from tqdm import tqdm
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# import optuna


# # Import our custom *tunable* model
# # from temp import EarlyFusionModel
# # import gine_branchw model
# from gine_with_branchw import EarlyFusionModel

# # --- 1. Constants and Setup ---
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"Using device: {DEVICE}")

# DATA_DIR = "processed_fusion_data"
# # DB_STORAGE_PATH = "sqlite:///tox21_tuning3.db"
# # STUDY_NAME = "early_fusion_v3"
# DB_STORAGE_PATH = "sqlite:///optuna_tuning_gine_bw_clsimb1.db" # <-- New DB file
# STUDY_NAME = "early_fusion_v5_gine"      # <-- New study name
# LABEL_COLUMNS = [
#         'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER',
#         'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5',
#         'SR-HSE', 'SR-MMP', 'SR-p53'
#     ]

# # Hyperparameters
# EPOCHS = 200      # Max epochs *per trial*
# PATIENCE = 15    # Early stopping patience
# N_TASKS = 12

# # --- 2. Load Data (Load ONCE, outside the objective) ---
# print("Loading processed data...")
# try:
#     train_data = torch.load(os.path.join(DATA_DIR, "train_data.pt"), weights_only=False)
#     valid_data = torch.load(os.path.join(DATA_DIR, "valid_data.pt"), weights_only=False)
# except FileNotFoundError:
#     print(f"Error: Processed data not found in '{DATA_DIR}'.")
#     print("Please run 'build_fusion_dataset.py' first.")
#     exit()

# # Get feature dimensions (ONCE)
# first_data = train_data[0]
# NODE_DIM = first_data.x.shape[1]
# EDGE_DIM = first_data.edge_attr.shape[1]
# FP_DIM = first_data.fp.shape[1]
# DESC_DIM = first_data.desc.shape[1]
# print(f"Data loaded. Feature Dims: Node={NODE_DIM}, Edge={EDGE_DIM}, FP={FP_DIM}, Desc={DESC_DIM}")

# # --- Calculate pos_weight for Class Imbalance (NEW!) ---
# def calculate_pos_weights(data_list):
#     num_pos = torch.zeros(N_TASKS)
#     num_neg = torch.zeros(N_TASKS)

#     for data in data_list:
#         labels = data.y.squeeze() # Shape [12]
#         weights = data.w.squeeze() # Shape [12]
#         is_valid = (weights > 0) & (~torch.isnan(labels))

#         pos_mask = (labels == 1) & is_valid
#         neg_mask = (labels == 0) & is_valid

#         # --- THIS IS THE FIX ---
#         # We add the boolean mask (shape [12]) directly.
#         # This correctly adds 1s and 0s element-wise.
#         # The old code had `.sum(dim=0)`, which was a bug.
#         num_pos += pos_mask
#         num_neg += neg_mask
#         # --- END OF FIX ---

#     pos_weight = num_neg / (num_pos + 1e-6)

#     # Optional: clip the weights to [1, 15] to stop runaway gradients (currently disabled, see next line).
#     # pos_weight = torch.clamp(pos_weight, min=1.0, max=15.0)

#     print("--- Class Imbalance (unCLIPPED pos_weight) Calculated ---")
#     for name, weight in zip(LABEL_COLUMNS, pos_weight):
#         print(f"  {name:<16}: {weight:.2f}") # Now these will all be different!
#     print("-----------------------------------------------")

#     return pos_weight.to(DEVICE)

# # Calculate weights *only* from the training set
# pos_weight = calculate_pos_weights(train_data)

# # --- 3. Training/Evaluation Functions (copied from train.py) ---

# # Loss function (defined globally)
# loss_fn = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)

# def train_epoch(model, loader, loss_fn, optimizer):
#     model.train()
#     total_loss = 0
#     for batch in loader:
#         batch = batch.to(DEVICE)
#         logits = model(batch)
#         y_true = batch.y
#         weights = batch.w

#         raw_loss = loss_fn(logits, y_true)
#         weighted_loss = raw_loss * weights
#         final_loss = weighted_loss.sum() / (weights.sum() + 1e-8)

#         optimizer.zero_grad()
#         final_loss.backward()
#         optimizer.step()
#         total_loss += final_loss.item() * batch.num_graphs
#     return total_loss / len(loader.dataset)

# @torch.no_grad()
# def eval_model(model, loader):
#     model.eval()
#     all_preds, all_labels, all_weights = [], [], []
#     for batch in loader:
#         batch = batch.to(DEVICE)
#         logits = model(batch)
#         all_preds.append(torch.sigmoid(logits).cpu().numpy())
#         all_labels.append(batch.y.cpu().numpy())
#         all_weights.append(batch.w.cpu().numpy())

#     all_preds = np.concatenate(all_preds, axis=0)
#     all_labels = np.concatenate(all_labels, axis=0)
#     all_weights = np.concatenate(all_weights, axis=0)

#     task_aucs = []
#     for i in range(N_TASKS):
#         valid_indices = all_weights[:, i] > 0
#         if np.sum(valid_indices) > 0 and len(np.unique(all_labels[valid_indices, i])) == 2:
#             task_aucs.append(roc_auc_score(all_labels[valid_indices, i], all_preds[valid_indices, i]))
#         else:
#             task_aucs.append(np.nan)

#     # print(f"--- Trial {trial.number} AUCs ---")
#     # use task names
#     for i, auc in enumerate(task_aucs):
#         print(f"  Task {i+1} ({LABEL_COLUMNS[i]}): AUC = {auc:.4f}" if not np.isnan(auc) else f"  Task {i+1} ({LABEL_COLUMNS[i]}): AUC = N/A")

#     return np.nanmean(task_aucs)

# # --- 4. Optuna Objective Function ---

# def objective(trial):
#     """
#     This function is called by Optuna for each trial.
#     """
#     # --- A. Suggest Hyperparameters ---
#     print(f"\n--- Starting Trial {trial.number} ---")

#     # gnn type
#     # gnn = trial.suggest_categorical("gnn_type", ['gat', 'gin'])

#     # Optimization params
#     lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
#     weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
#     batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128])

#     # Model architecture params
#     # gat_heads = trial.suggest_categorical("gat_heads", [2, 4, 8, 16])
#     gnn_dropout = trial.suggest_float("gnn_dropout", 0.0, 0.5)
#     classifier_dropout_1 = trial.suggest_float("classifier_dropout_1", 0.0, 0.6)
#     classifier_dropout_2 = trial.suggest_float("classifier_dropout_2", 0.0, 0.5)

#     # Model size params (also tuned by Optuna)
#     graph_hidden_dim = trial.suggest_categorical("graph_hidden_dim", [64, 128, 256])
#     graph_out_dim = trial.suggest_categorical("graph_out_dim", [32, 64, 128])
#     fp_out_dim = trial.suggest_categorical("fp_out_dim", [128, 256, 512])

#     # --- B. Setup Model, Loaders, Optimizer ---

#     # DataLoaders must be created inside objective to use batch_size
#     train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
#     valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

#     model = EarlyFusionModel(
#         node_feature_dim=NODE_DIM,
#         edge_feature_dim=EDGE_DIM,
#         fp_feature_dim=FP_DIM,
#         desc_feature_dim=DESC_DIM,
#         n_tasks=N_TASKS,
#         graph_hidden_dim=graph_hidden_dim,
#         graph_out_dim=graph_out_dim,
#         gnn_dropout=gnn_dropout,
#         fp_out_dim=fp_out_dim,
#         classifier_dropout_1=classifier_dropout_1,
#         classifier_dropout_2=classifier_dropout_2
#     ).to(DEVICE)

#     optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

#     scheduler = ReduceLROnPlateau(
#         optimizer,
#         mode='max',      # We monitor AUC, so we want to maximize it
#         factor=0.5,      # Reduce LR by half
#         patience=5,      # Wait 5 epochs for improvement
#         min_lr=1e-7      # Don't go below this
#     )

#     # --- C. Run Training Loop ---

#     best_valid_auc = 0.0
#     epochs_no_improve = 0

#     for epoch in range(1, EPOCHS + 1):
#         train_loss = train_epoch(model, train_loader, loss_fn, optimizer)
#         valid_auc = eval_model(model, valid_loader)

#         scheduler.step(valid_auc)

#         print(f"Trial {trial.number} Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Valid AUC: {valid_auc:.4f}")

#         # Optuna Pruning: Stop unpromising trials early
#         trial.report(valid_auc, epoch)
#         if trial.should_prune():
#             print("--- Trial Pruned ---")
#             raise optuna.TrialPruned()

#         # Early Stopping
#         if valid_auc > best_valid_auc:
#             best_valid_auc = valid_auc
#             epochs_no_improve = 0
#         else:
#             epochs_no_improve += 1

#         if epochs_no_improve >= PATIENCE:
#             print(f"--- Trial Early-Stopped ---")
#             break

#     return best_valid_auc # Return the best validation AUC for this trial

# # --- 5. Main Study Execution ---

# if __name__ == "__main__":
#     print(f"Starting Optuna study: {STUDY_NAME}")
#     print(f"Database will be saved to: {DB_STORAGE_PATH}")

#     # Create a study object and specify direction to "maximize" AUC
#     study = optuna.create_study(
#         study_name=STUDY_NAME,
#         storage=DB_STORAGE_PATH,
#         direction="maximize",
#         load_if_exists=True  # Allows you to resume tuning
#     )

#     # redoing trials with branchw
#     print("Running trials with branch weights")

#     # Start the optimization
#     try:
#         study.optimize(objective, n_trials=200)  # Run 200 trials
#     except KeyboardInterrupt:
#         print("Tuning interrupted by user.")

#     # --- 6. Print Results ---
#     print("\n--- Tuning Complete ---")

#     pruned_trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.PRUNED])
#     completed_trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])

#     print("Study statistics: ")
#     print(f"  Number of finished trials: {len(study.trials)}")
#     print(f"  Number of pruned trials:   {len(pruned_trials)}")
#     print(f"  Number of complete trials: {len(completed_trials)}")

#     print("\n--- Best Trial ---")
#     trial = study.best_trial
#     print(f"  Value (AUC): {trial.value:.4f}")
#     print("  Params: ")
#     for key, value in trial.params.items():
#         print(f"    {key}: {value}")

Training with branch and positive weights

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score
import numpy as np
import os
import joblib
from gine_with_branchw import EarlyFusionModel

# --- Configuration ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Folder the preprocessed datasets and fitted scalers are read from.
PROCESSED_DATA_DIR = "processed_fusion_data_3d"
# File the best-on-validation weights are written to during training.
MODEL_SAVE_PATH = "test_run_best_model_class_imbalance_3d.pth"

# Training length and early-stopping budget (both counted in epochs).
N_EPOCHS = 200
EARLY_STOP_PATIENCE = 15

# Number of prediction tasks and their display names, in label-column order.
N_TASKS = 12
TASK_NAMES = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Baseline hyperparameters.
# NOTE(review): main() below hard-codes its own values and does not read this dict.
DEFAULT_PARAMS = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "batch_size": 64,
    "gnn_dropout": 0.3,
    "classifier_dropout_1": 0.2,
    "classifier_dropout_2": 0.5,
    "graph_hidden_dim": 128,
    "graph_out_dim": 64,
    "fp_out_dim": 256,
}

# --- Load the scalers / imputer fitted during preprocessing ---
# NOTE(review): these objects are loaded here but not referenced again in the
# code visible in this cell; the feature dimensions below come from the data.
try:
    node_scaler = joblib.load(os.path.join(PROCESSED_DATA_DIR, "node_feature_scaler.joblib"))
    desc_imputer = joblib.load(os.path.join(PROCESSED_DATA_DIR, "desc_feature_imputer.joblib"))
    desc_scaler = joblib.load(os.path.join(PROCESSED_DATA_DIR, "desc_feature_scaler.joblib"))
except FileNotFoundError:
    print("Error: Scaler/imputer files not found. Please run 'build_fusion_dataset_3d.py' first.")
    # Re-raise instead of calling exit(): inside a notebook exit() does not
    # stop the cell cleanly, so later statements would fail on undefined names.
    raise

# --- Load the datasets and determine feature dimensions ---
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
# Only load .pt files that you produced yourself or otherwise trust.
try:
    train_data_list = torch.load(os.path.join(PROCESSED_DATA_DIR, "train_data.pt"), weights_only=False)
    if not train_data_list:
        raise ValueError("train_data.pt contains no samples; cannot infer feature dimensions.")

    # Every sample is expected to share the first sample's feature widths.
    _first_sample = train_data_list[0]
    NODE_FEATURE_DIM = _first_sample.x.shape[1]
    EDGE_FEATURE_DIM = _first_sample.edge_attr.shape[1]
    FP_FEATURE_DIM = _first_sample.fp_features.shape[1]
    DESC_FEATURE_DIM = _first_sample.desc_features.shape[1]
    del _first_sample

    print(f"--- Feature Dimensions Detected ---")
    print(f"Node Features:   {NODE_FEATURE_DIM}")
    print(f"Edge Features:   {EDGE_FEATURE_DIM}")
    print(f"FP Features:     {FP_FEATURE_DIM}")
    print(f"Desc Features:   {DESC_FEATURE_DIM}")
    print(f"---------------------------------")

    valid_data_list = torch.load(os.path.join(PROCESSED_DATA_DIR, "valid_data.pt"), weights_only=False)
    test_data_list = torch.load(os.path.join(PROCESSED_DATA_DIR, "test_data.pt"), weights_only=False)
except Exception as e:
    print(f"Error loading data: {e}")
    # Propagate the original error (with traceback) so execution stops here.
    raise


# --- 1. Calculate pos_weight for Class Imbalance (NEW!) ---
def calculate_pos_weights(data_list):
    """Compute a per-task ``pos_weight`` (negatives / positives) for BCE.

    A label is counted only when its weight is positive and it is not NaN.
    The ratio is not clipped (the clamp below is left disabled), so a task
    with very few positives gets a very large weight.

    Args:
        data_list: iterable of samples exposing ``y`` and ``w`` tensors that
            squeeze to length ``N_TASKS``.

    Returns:
        torch.Tensor: tensor of length ``N_TASKS`` moved to ``DEVICE``.
    """
    positives = torch.zeros(N_TASKS)
    negatives = torch.zeros(N_TASKS)

    for sample in data_list:
        y = sample.y.squeeze()
        w = sample.w.squeeze()
        usable = (w > 0) & (~torch.isnan(y))

        # Boolean masks are accumulated element-wise as 0/1 counts per task.
        positives += (y == 1) & usable
        negatives += (y == 0) & usable

    # The small epsilon avoids division by zero for tasks without positives.
    pos_weight = negatives / (positives + 1e-6)

    # Clipping is intentionally disabled; re-enable to cap extreme weights.
    # pos_weight = torch.clamp(pos_weight, min=1.0, max=15.0)

    print("--- Class Imbalance (unCLIPPED pos_weight) Calculated ---")
    for task_name, task_weight in zip(TASK_NAMES, pos_weight):
        print(f"  {task_name:<16}: {task_weight:.2f}")
    print("-----------------------------------------------")

    return pos_weight.to(DEVICE)

# Calculate weights *only* from the training set
pos_weight_tensor = calculate_pos_weights(train_data_list)


# --- 2. Update Loss Function to use pos_weight (NEW!) ---
def weighted_bce_loss(y_pred, y_true, weights, pos_weight):
    """Masked, class-balanced binary cross-entropy on logits.

    - ``weights`` handles MISSING labels (w=0 removes an entry).
    - ``pos_weight`` handles CLASS IMBALANCE (rare positives).
    - NaN labels are treated as missing regardless of their weight.

    NaN targets are replaced *before* the loss is evaluated. Masking only the
    loss value afterwards keeps the forward pass finite but still produces
    NaN gradients, because the backward pass multiplies a zero upstream
    gradient by a NaN local gradient.

    Args:
        y_pred: logits, shape ``[batch, n_tasks]``.
        y_true: targets of the same shape; may contain NaN.
        weights: per-entry weights of the same shape.
        pos_weight: per-task positive-class weight, shape ``[n_tasks]``.

    Returns:
        torch.Tensor: scalar loss, the weighted sum of per-entry losses
        divided by the sum of weights over valid (non-NaN) entries.
    """
    is_valid = ~torch.isnan(y_true)

    # Replace NaN targets with a dummy value and zero their weight so they
    # contribute neither to the numerator, the normaliser, nor the gradient.
    safe_targets = torch.where(is_valid, y_true, torch.zeros_like(y_true))
    effective_weights = torch.where(is_valid, weights, torch.zeros_like(weights))

    # pos_weight is shape [n_tasks] and broadcasts across the batch.
    loss_fn = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
    raw_loss = loss_fn(y_pred, safe_targets)

    weighted_loss = raw_loss * effective_weights

    # Normalise by the total weight; epsilon guards an all-missing batch.
    final_loss = weighted_loss.sum() / (effective_weights.sum() + 1e-8)
    return final_loss

# ... (eval_model is unchanged) ...
@torch.no_grad()
def eval_model(model, loader, print_scores=False):
    model.eval()
    all_preds, all_labels, all_weights = [], [], []
    for data in loader:
        data = data.to(DEVICE)
        out = model(data)
        all_preds.append(out.cpu())
        all_labels.append(data.y.cpu())
        all_weights.append(data.w.cpu())
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_weights = torch.cat(all_weights, dim=0)

    val_loss = weighted_bce_loss(all_preds.to(DEVICE), all_labels.to(DEVICE), all_weights.to(DEVICE), pos_weight_tensor)
    # print(f"Validation Loss: {val_loss.item():.4f}")
    task_aucs = []
    valid_labels_mask = ~torch.isnan(all_labels) & (all_weights > 0)
    for i in range(N_TASKS):
        task_labels = all_labels[valid_labels_mask[:, i], i]
        task_preds = all_preds[valid_labels_mask[:, i], i]
        if len(task_labels) > 1 and len(torch.unique(task_labels)) > 1:
            try: task_aucs.append(roc_auc_score(task_labels.numpy(), torch.sigmoid(task_preds).numpy()))
            except ValueError: task_aucs.append(np.nan)
        else: task_aucs.append(np.nan)
    if print_scores:
        print("\n--- Per-Task ROC-AUC Scores ---")
        for name, auc in zip(TASK_NAMES, task_aucs):
            if not np.isnan(auc): print(f"  {name:<16}: {auc:.4f}")
            else: print(f"  {name:<16}: N/A (not enough samples)")
        print("---------------------------------")
    mean_auc = np.nanmean(task_aucs)
    return val_loss, mean_auc

# --- Main Training Function ---
def main():
    """Train EarlyFusionModel with early stopping, then report test metrics.

    Uses the module-level datasets, ``pos_weight_tensor`` and configuration
    constants. The weights with the best validation mean ROC-AUC are written
    to ``MODEL_SAVE_PATH`` and reloaded for the final test evaluation. If no
    epoch ever improves on the initial best AUC (for example when the AUC is
    NaN), no checkpoint is written and the final-epoch weights are evaluated
    instead of loading a missing or stale file.
    """
    print("--- Starting Test Run (with Imbalance Fix) ---")
    print(f"Using device: {DEVICE}")

    # Data loader creation
    print(f"Loading {len(train_data_list)} train, {len(valid_data_list)} valid, {len(test_data_list)} test samples.")
    train_loader = DataLoader(train_data_list, batch_size=8, shuffle=True)
    valid_loader = DataLoader(valid_data_list, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_data_list, batch_size=8, shuffle=False)

    # Model (uses branch_weights); hyperparameters are the best values found
    # by the Optuna tuning run.
    model = EarlyFusionModel(
        node_feature_dim=NODE_FEATURE_DIM,
        edge_feature_dim=EDGE_FEATURE_DIM,
        fp_feature_dim=FP_FEATURE_DIM,
        desc_feature_dim=DESC_FEATURE_DIM,
        n_tasks=N_TASKS,
        graph_hidden_dim=256,
        graph_out_dim=64,
        fp_out_dim=64,
        gnn_dropout=0.317,
        classifier_dropout_1=0.572,
        classifier_dropout_2=0.432
    ).to(DEVICE)

    print(f"Model Architecture:{model}")

    optimizer = optim.Adam(model.parameters(), lr=0.00024, weight_decay=3.948e-06)
    # The scheduler tracks validation AUC, hence mode='max'.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, min_lr=1e-7)
    print(f"Optimizer:{optimizer}")
    print(f"Scheduler:{scheduler}")

    # Training loop
    best_valid_auc = 0.0
    epochs_no_improve = 0
    checkpoint_saved = False  # becomes True once this run writes MODEL_SAVE_PATH

    for epoch in range(N_EPOCHS):
        model.train()
        total_loss = 0
        for data in train_loader:
            data = data.to(DEVICE)
            optimizer.zero_grad()

            out = model(data)
            loss = weighted_bce_loss(out, data.y, data.w, pos_weight_tensor)

            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch average is per-sample.
            total_loss += loss.item() * data.num_graphs

        avg_train_loss = total_loss / len(train_loader.dataset)
        val_loss, valid_auc = eval_model(model, valid_loader, print_scores=False)
        print(f"Epoch {epoch+1:02d}/{N_EPOCHS:02d} | Train Loss: {avg_train_loss:.4f} | Valid Loss: {val_loss.item():.4f} | Valid AUC: {valid_auc:.4f} | LR: {optimizer.param_groups[0]['lr']:.1e}")

        # Print the learned branch weights each epoch.
        weights = model.branch_weights.data.cpu().numpy()
        print(f"  Branch Weights (GNN, FP, Desc): {weights[0]:.2f}, {weights[1]:.2f}, {weights[2]:.2f}")

        scheduler.step(valid_auc)

        if valid_auc > best_valid_auc:
            best_valid_auc = valid_auc
            epochs_no_improve = 0
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            checkpoint_saved = True
            print(f"  -> New best model saved to {MODEL_SAVE_PATH} (AUC: {best_valid_auc:.4f})")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOP_PATIENCE:
            print(f"\n--- Early stopping at epoch {epoch+1} ---")
            break

    print("\n--- Test Run Training Complete ---")

    # Final evaluation on the test split.
    if checkpoint_saved:
        print(f"Loading best model from {MODEL_SAVE_PATH} for final test evaluation...")
        # map_location keeps the load valid even when DEVICE differs from
        # the device the tensors were saved from.
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    else:
        print("Warning: validation AUC never improved, so no checkpoint was written; "
              "evaluating the final-epoch weights instead.")
    test_loss, test_mean_auc = eval_model(model, test_loader, print_scores=True)
    print(f"\n======================================")
    print(f" Final Test Loss: {test_loss.item():.4f}")
    print(f" Final Test Mean ROC-AUC: {test_mean_auc:.4f}")
    print(f"======================================")

if __name__ == "__main__":
    main()

# Graph Transformer with MolCLR
Adapted from https://github.com/yuyangw/MolCLR

Dataloading for Tox21

In [None]:
import os
import csv
import math
import time
import random
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

from torch_scatter import scatter
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from rdkit import RDLogger
from openbabel import pybel
RDLogger.DisableLog('rdApp.*')


# Feature vocabularies: a value's position in its list is the integer code
# looked up with ``.index(...)`` style encodings elsewhere in this section
# (NOTE(review): confirm against the featuriser, which is outside this view).

# Atomic numbers 1..118 inclusive.
ATOM_LIST = [*range(1, 119)]

# Atom chirality tags.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,      # 0
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,   # 1
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,  # 2
    Chem.rdchem.ChiralType.CHI_OTHER,            # 3
]

# Bond orders.
BOND_LIST = [
    BT.SINGLE,    # 0
    BT.DOUBLE,    # 1
    BT.TRIPLE,    # 2
    BT.AROMATIC,  # 3
]

# Bond stereo directions.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,          # 0
    Chem.rdchem.BondDir.ENDUPRIGHT,    # 1
    Chem.rdchem.BondDir.ENDDOWNRIGHT,  # 2
]

def canonical(s):
    """Return the canonical SMILES for ``s``, or None if RDKit cannot parse it."""
    parsed = Chem.MolFromSmiles(s)
    if not parsed:
        return None
    return Chem.MolToSmiles(parsed, canonical=True)

def _generate_scaffold(smiles, include_chirality=False):
    """Return the Bemis-Murcko scaffold SMILES of ``smiles``.

    The SMILES is parsed with RDKit and handed to ``MurckoScaffoldSmiles``;
    ``include_chirality`` is forwarded as ``includeChirality``.
    """
    return MurckoScaffoldSmiles(
        mol=Chem.MolFromSmiles(smiles),
        includeChirality=include_chirality,
    )


def generate_scaffolds(dataset, log_every_n=1000):
    """Group dataset indices by Murcko scaffold, largest group first.

    Args:
        dataset: object supporting ``len()`` and exposing ``smiles_data``.
            Entries may be plain SMILES strings or ``(smiles, positions)``
            pairs, as stored by ``MolTestDataset`` when ``use3D=True``; for
            pairs only the SMILES part is used.
        log_every_n: print a progress line every this many molecules;
            a falsy value disables progress output.

    Returns:
        list[list[int]]: index groups sorted by (group size, smallest index),
        descending. Indices inside each group are ascending.
    """
    scaffolds = {}
    data_len = len(dataset)
    print(data_len)

    print("About to generate scaffolds")
    for ind, entry in enumerate(dataset.smiles_data):
        if log_every_n and ind % log_every_n == 0:
            print("Generating scaffold %d/%d" % (ind, data_len))
        # 3D datasets store (smiles, coordinates); the scaffold only needs the SMILES.
        smiles = entry[0] if isinstance(entry, (tuple, list)) else entry
        scaffolds.setdefault(_generate_scaffold(smiles), []).append(ind)

    # Indices were appended in increasing order, so each group is already
    # sorted and group[0] is its smallest index (used as the tie-breaker).
    ordered_groups = sorted(
        scaffolds.values(),
        key=lambda group: (len(group), group[0]),
        reverse=True,
    )
    return ordered_groups


def scaffold_split(dataset, valid_size, test_size, seed=None, log_every_n=1000):
    """Split dataset indices into train/valid/test by scaffold group.

    Scaffold groups (largest first) are assigned whole: a group goes to
    train while it fits under the train cutoff, otherwise to valid while it
    fits under the train+valid cutoff, otherwise to test.

    Args:
        dataset: object accepted by ``generate_scaffolds``.
        valid_size: fraction of the data targeted for validation.
        test_size: fraction of the data targeted for test.
        seed: unused; the split is deterministic. Kept for interface
            compatibility.
        log_every_n: progress interval forwarded to ``generate_scaffolds``.

    Returns:
        tuple[list[int], list[int], list[int]]: train, valid and test indices.
    """
    train_size = 1.0 - valid_size - test_size
    scaffold_sets = generate_scaffolds(dataset, log_every_n=log_every_n)

    n_total = len(dataset)
    train_cutoff = train_size * n_total
    valid_cutoff = (train_size + valid_size) * n_total
    train_inds, valid_inds, test_inds = [], [], []

    print("About to sort in scaffold sets")
    for scaffold_set in scaffold_sets:
        group_size = len(scaffold_set)
        if len(train_inds) + group_size <= train_cutoff:
            train_inds += scaffold_set
        elif len(train_inds) + len(valid_inds) + group_size <= valid_cutoff:
            valid_inds += scaffold_set
        else:
            test_inds += scaffold_set
    return train_inds, valid_inds, test_inds


def read_smiles(data_path, target, task):
    """Read SMILES strings and one label column from a CSV file.

    Rows are kept only when the ``smiles`` value parses with RDKit and the
    label in column ``target`` is present and non-empty.

    ``csv.DictReader`` already consumes the header line, so every yielded
    row is data. The previous ``if i != 0`` guard therefore dropped the
    first molecule; it has been removed.

    Args:
        data_path: path to a comma-separated file with a header row that
            contains ``smiles`` and ``target`` columns.
        target: name of the label column.
        task: ``'classification'`` (labels parsed with ``int``) or
            ``'regression'`` (labels parsed with ``float``).

    Returns:
        tuple[list[str], list]: kept SMILES strings and their parsed labels.

    Raises:
        ValueError: if ``task`` is neither of the two supported values.
    """
    if task == 'classification':
        parse_label = int
    elif task == 'regression':
        parse_label = float
    else:
        raise ValueError('task must be either regression or classification')

    smiles_data, labels = [], []
    with open(data_path, newline='') as csv_file:
        for row in csv.DictReader(csv_file, delimiter=','):
            smiles = row['smiles']
            label = row[target]
            # Skip missing labels (empty cell or short row) and unparsable SMILES.
            if not label or Chem.MolFromSmiles(smiles) is None:
                continue
            smiles_data.append(smiles)
            labels.append(parse_label(label))
    print(len(smiles_data))
    return smiles_data, labels
def generate_3d(mol):
    """Embed a molecule in 3D with ETKDGv3 and return centred coordinates.

    Args:
        mol: RDKit molecule. Explicit hydrogens are added before embedding.

    Returns:
        ``(pos, True)`` where ``pos`` is a float tensor holding the conformer
        positions shifted so that their mean is the origin, or
        ``(None, False)`` when ``EmbedMolecule`` returns a non-zero status.

    Raises:
        Exception: Errors raised by RDKit or torch propagate unchanged. The
            previous ``except ... raise e`` wrapper did the same, and the
            ``return`` after it could never run.
    """
    mol3d = Chem.AddHs(mol)
    status = AllChem.EmbedMolecule(mol3d, AllChem.ETKDGv3())
    if status != 0:
        return None, False

    # No force-field optimisation is run on the embedded conformer.
    pos = torch.tensor(mol3d.GetConformer().GetPositions(), dtype=torch.float)
    return pos - pos.mean(dim=0, keepdim=True), True

def generate_3d_openbabel(smiles):
    """Build 3D coordinates for a SMILES string through Open Babel (pybel).

    Args:
        smiles: SMILES string handed to ``pybel.readstring("smi", ...)``.

    Returns:
        ``(coords, True)`` with ``coords`` a float tensor built from each
        atom's ``coords``, or ``(None, False)`` if any step raises.
    """
    try:
        ob_mol = pybel.readstring("smi", smiles)
        ob_mol.make3D()
        xyz = [list(atom.coords) for atom in ob_mol.atoms]
        return torch.tensor(xyz, dtype=torch.float), True
    except Exception:
        # Best effort: every failure is reported through the flag.
        return None, False

class MolTestDataset(Dataset):
    """Dataset that turns SMILES strings into graph ``Data`` objects on access.

    SMILES and labels come from ``read_smiles``. With ``use3D`` enabled, each
    molecule is embedded once at construction time through ``generate_3d`` and
    molecules that cannot be embedded are dropped.

    Attributes:
        smiles_data: List of SMILES strings, or ``(smiles, pos)`` tuples when
            ``use3D`` is true.
        labels: Labels parallel to ``smiles_data``.
        task: ``'classification'`` or ``'regression'``.
        conversion: Multiplier applied to regression labels (fixed to 1 here).
    """

    def __init__(self, data_path, target, task, use3D=False):
        """Load and filter the molecules for one target column.

        Args:
            data_path: CSV path forwarded to ``read_smiles``.
            target: Label column forwarded to ``read_smiles``.
            task: Task type forwarded to ``read_smiles``.
            use3D: When true, pre-compute and store 3D coordinates.
        """
        # NOTE(review): super(Dataset, self) starts the MRO lookup *after*
        # Dataset, so Dataset.__init__ is not run. Presumably deliberate;
        # confirm before changing it.
        super(Dataset, self).__init__()

        self.use3D = use3D
        self.task = task
        self.conversion = 1

        smiles_list, labels = read_smiles(data_path, target, task)
        filtered_smiles, filtered_labels = [], []
        invalid_count = 0

        for s, l in zip(smiles_list, labels):
            try:
                mol = Chem.AddHs(Chem.MolFromSmiles(s))
            except Exception:
                # The counter was previously misspelled (`invalid_smiles`),
                # which turned every parse failure into a NameError.
                invalid_count += 1
                continue

            if use3D:
                pos, flag = generate_3d(mol)
                if not flag or pos is None:
                    invalid_count += 1
                    continue
                filtered_smiles.append((s, pos))
            else:
                filtered_smiles.append(s)
            filtered_labels.append(l)

        self.smiles_data = filtered_smiles
        self.labels = filtered_labels
        print(f" Filtered out {invalid_count} invalid SMILES.")

    def __getitem__(self, index):
        """Build the graph for one molecule.

        Returns:
            ``Data`` with ``x`` (atom-type and chirality indices, one row per
            atom), ``edge_index`` / ``edge_attr`` (two directed edges per bond
            carrying bond-type and bond-direction indices), ``y`` of shape
            ``(1, 1)``, and ``pos`` when ``use3D`` is true.

        Raises:
            ValueError: If ``self.task`` is not a supported task.
            RuntimeError: If ``use3D`` is true but the stored position is None.
        """
        if self.use3D:
            s, pos = self.smiles_data[index]
        else:
            s, pos = self.smiles_data[index], None

        mol = Chem.AddHs(Chem.MolFromSmiles(s))

        type_idx = []
        chirality_idx = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            # Each bond becomes two directed edges with the same features.
            row += [start, end]
            col += [end, start]
            edge_feat += [feat, list(feat)]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # Explicit (num_edges, 2) shape keeps bond-less molecules at (0, 2)
        # instead of (0,), which is what column indexing downstream needs.
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long).view(len(edge_feat), 2)

        if self.task == 'classification':
            y = torch.tensor(self.labels[index], dtype=torch.long).view(1, -1)
        elif self.task == 'regression':
            y = torch.tensor(self.labels[index] * self.conversion, dtype=torch.float).view(1, -1)
        else:
            raise ValueError('task must be either regression or classification')

        data = Data(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr)
        if self.use3D:
            if pos is None:
                raise RuntimeError(f"pos became None at index {index}, SMILES={s}")
            data.pos = pos

        return data

    def __len__(self):
        """Number of molecules kept after filtering."""
        return len(self.smiles_data)


class MolTestDatasetWrapper(object):
    """Creates train / validation / test loaders over one ``MolTestDataset``.

    The split strategy is chosen with ``splitting``: ``'random'`` shuffles the
    indices with NumPy's global RNG, ``'scaffold'`` delegates to
    ``scaffold_split``, and ``'tox21'`` reads SMILES lists from text files
    that sit next to ``data_path``.
    """

    def __init__(self,
        batch_size, num_workers, valid_size, test_size,
        data_path, target, task, splitting, use_3D= False
    ):
        """Store the loader and split configuration.

        Raises:
            AssertionError: If ``splitting`` is not one of ``'random'``,
                ``'scaffold'`` or ``'tox21'``.
        """
        super().__init__()
        assert splitting in ['random', 'scaffold', 'tox21']
        self.splitting = splitting
        self.data_path = data_path
        self.target = target
        self.task = task
        self.use_3D = use_3D
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.test_size = test_size

    def _load_tox21_split(self, dataset):
        """Resolve the predefined split files to indices into ``dataset``.

        ``train_smiles.txt``, ``valid_smiles.txt`` and ``test_smiles.txt`` are
        read from the directory containing ``self.data_path``; blank lines are
        ignored. Each listed SMILES is looked up in a table keyed by
        ``canonical(...)`` of the dataset's SMILES, and entries missing from
        the table are skipped.

        Returns:
            ``(train_idx, valid_idx, test_idx)`` lists of dataset indices.
        """
        split_dir = os.path.dirname(self.data_path)

        def load_lines(file_name):
            with open(os.path.join(split_dir, file_name)) as handle:
                stripped = (line.strip() for line in handle)
                return [line for line in stripped if line]

        split_lists = [
            load_lines("train_smiles.txt"),
            load_lines("valid_smiles.txt"),
            load_lines("test_smiles.txt"),
        ]

        # Canonical SMILES -> position in the dataset. Entries are
        # (smiles, pos) tuples when 3D coordinates are stored.
        if self.use_3D:
            dataset_smiles = (entry[0] for entry in dataset.smiles_data)
        else:
            dataset_smiles = iter(dataset.smiles_data)
        lookup = {}
        for position, smi in enumerate(dataset_smiles):
            lookup[canonical(smi)] = position

        train_idx, valid_idx, test_idx = (
            [lookup[smi] for smi in listed if smi in lookup]
            for listed in split_lists
        )

        print(f"train={len(train_idx)}, valid={len(valid_idx)}, test={len(test_idx)}")

        return train_idx, valid_idx, test_idx

    def get_data_loaders(self):
        """Build the dataset from the stored settings and return its loaders."""
        full_dataset = MolTestDataset(data_path=self.data_path, target=self.target, task=self.task, use3D=self.use_3D)
        return self.get_train_validation_data_loaders(full_dataset)

    def get_train_validation_data_loaders(self, train_dataset):
        """Split ``train_dataset`` and wrap each subset in a ``DataLoader``.

        Returns:
            ``(train_loader, valid_loader, test_loader)``; each one samples
            its index subset through ``SubsetRandomSampler``.
        """
        if self.splitting == 'random':
            n_total = len(train_dataset)
            order = list(range(n_total))
            np.random.shuffle(order)

            n_valid = int(np.floor(self.valid_size * n_total))
            n_test = int(np.floor(self.test_size * n_total))
            # Leading slice is validation, the next one test, the rest train.
            valid_idx = order[:n_valid]
            test_idx = order[n_valid:n_valid + n_test]
            train_idx = order[n_valid + n_test:]

        elif self.splitting == 'scaffold':
            train_idx, valid_idx, test_idx = scaffold_split(train_dataset, self.valid_size, self.test_size)

        elif self.splitting == 'tox21':
            train_idx, valid_idx, test_idx = self._load_tox21_split(train_dataset)

        def make_loader(subset):
            return DataLoader(
                train_dataset, batch_size=self.batch_size,
                sampler=SubsetRandomSampler(subset),
                num_workers=self.num_workers, drop_last=False
            )

        train_loader = make_loader(train_idx)
        valid_loader = make_loader(valid_idx)
        test_loader = make_loader(test_idx)

        return train_loader, valid_loader, test_loader


Graph Transformer Model

In [None]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_add_pool, global_mean_pool, global_max_pool

# Vocabulary sizes of the embedding tables created in GraphTransformer.
# NOTE(review): every index produced by the dataset must be smaller than the
# matching size below; confirm against the lengths of ATOM_LIST,
# CHIRALITY_LIST, BOND_LIST and BONDDIR_LIST.
num_atom_type = 119  # rows of x_embedding1, indexed by x[:, 0]
num_chirality_tag = 3  # rows of x_embedding2, indexed by x[:, 1]
num_bond_type = 5  # rows of edge_embedding1, indexed by edge_attr[:, 0]
num_bond_direction = 3  # rows of edge_embedding2, indexed by edge_attr[:, 1]


class GraphTransformer(nn.Module):
    """Graph transformer encoder with a two-class prediction head.

    Atom-type and chirality embeddings are summed (plus a linear projection of
    3D positions when ``use_3D`` is set). The result passes through
    ``num_layer`` blocks of ``TransformerConv`` attention and a feed-forward
    network, each followed by a residual connection and layer norm. Node
    states are pooled per graph, projected to ``feat_dim`` and classified.

    Args:
        task: Only ``'classification'`` is supported.
        num_layer: Number of transformer blocks (at least 2).
        emb_dim: Node embedding width; must be divisible by ``heads``.
        feat_dim: Width of the pooled graph feature.
        heads: Attention heads per ``TransformerConv``.
        drop_ratio: Dropout probability used throughout.
        pool: Graph readout: ``'mean'``, ``'max'`` or ``'add'``.
        edge_emb_dim: Width of the edge embeddings.
        use_3D: Add a learned projection of ``data.pos`` to the node features.

    Raises:
        ValueError: On an unsupported ``task`` or ``pool``, fewer than two
            layers, or an ``emb_dim`` that ``heads`` does not divide.
    """

    def __init__(
        self,
        task='classification',
        num_layer=5,
        emb_dim=300,
        feat_dim=256,
        heads=4,
        drop_ratio=0.1,
        pool='mean',
        edge_emb_dim=32,
        use_3D = False
    ):
        super(GraphTransformer, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        self.heads = heads
        self.task = task
        self.edge_emb_dim = edge_emb_dim
        self.use_3D = use_3D

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Each conv outputs heads * (emb_dim // heads) channels. Unless that
        # equals emb_dim, the residual add in forward() fails with an opaque
        # shape error, so reject the configuration here.
        if heads < 1 or emb_dim % heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by heads ({heads})."
            )
        # Fail before any layers are allocated.
        if self.task != 'classification':
            raise ValueError("task must be 'classification'")

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight)
        nn.init.xavier_uniform_(self.x_embedding2.weight)

        self.edge_embedding1 = nn.Embedding(num_bond_type, edge_emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, edge_emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight)
        nn.init.xavier_uniform_(self.edge_embedding2.weight)
        self.edge_proj = nn.Linear(edge_emb_dim, edge_emb_dim)

        if use_3D:
            self.pos_emb = nn.Linear(3, emb_dim)

        self.layers = nn.ModuleList()
        self.norms1 = nn.ModuleList()
        self.norms2 = nn.ModuleList()
        self.ffns = nn.ModuleList()

        for _ in range(num_layer):
            self.layers.append(
                TransformerConv(
                    in_channels=emb_dim,
                    out_channels=emb_dim // heads,
                    heads=heads,
                    dropout=drop_ratio,
                    edge_dim=edge_emb_dim,
                )
            )
            self.norms1.append(nn.LayerNorm(emb_dim))
            self.norms2.append(nn.LayerNorm(emb_dim))
            self.ffns.append(
                nn.Sequential(
                    nn.Linear(emb_dim, 4 * emb_dim),
                    nn.ReLU(inplace=True),
                    nn.Dropout(drop_ratio),
                    nn.Linear(4 * emb_dim, emb_dim),
                    nn.Dropout(drop_ratio),
                )
            )

        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        else:
            raise ValueError("Not defined pooling!")

        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)

        self.pred_head = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim // 2),
            nn.Softplus(),
            nn.Linear(self.feat_dim // 2, 2),
        )

    def _edge_encode(self, edge_attr):
        """Embed bond-type and bond-direction indices and project their sum."""
        e = self.edge_embedding1(edge_attr[:, 0].long()) + \
            self.edge_embedding2(edge_attr[:, 1].long())
        return self.edge_proj(e)

    def forward(self, data):
        """Encode a batch of graphs.

        Args:
            data: Batch exposing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` (plus ``pos`` when ``use_3D`` is set).

        Returns:
            ``(h, logits)``: the ``feat_lin`` output for each graph and the
            ``pred_head`` output computed from it.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Positional term is a plain 0 when 3D input is disabled.
        pe = self.pos_emb(data.pos) if self.use_3D else 0

        h = self.x_embedding1(x[:, 0].long()) + self.x_embedding2(x[:, 1].long())
        h = h + pe
        edge_feat = self._edge_encode(edge_attr)

        for conv, norm1, norm2, ffn in zip(self.layers, self.norms1, self.norms2, self.ffns):
            # Attention sub-block: residual connection, then layer norm.
            h_res = h
            h = conv(h, edge_index, edge_attr=edge_feat)
            h = F.dropout(h, p=self.drop_ratio, training=self.training)
            h = norm1(h_res + h)

            # Feed-forward sub-block: residual connection, then layer norm.
            h_res2 = h
            h = ffn(h)
            h = norm2(h_res2 + h)

        h = self.pool(h, batch)
        h = self.feat_lin(h)

        return h, self.pred_head(h)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model.

        Keys absent from the model are ignored, so a checkpoint with a
        different head can still initialise the shared layers.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                param = param.data
            own_state[name].copy_(param)


Finetuning Loop for Pretrained Graph Transformer

In [None]:
import os
import shutil
import sys
import yaml
import numpy as np
import pandas as pd
from datetime import datetime

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error

def _save_config_file(model_checkpoints_folder):
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        shutil.copy('./config_finetune.yaml', os.path.join(model_checkpoints_folder, 'config_finetune.yaml'))

class Normalizer(object):
    """Standardises tensors using the mean and std of a reference sample."""

    def __init__(self, tensor):
        """Compute and keep ``torch.mean`` and ``torch.std`` of ``tensor``."""
        self.mean, self.std = torch.mean(tensor), torch.std(tensor)

    def norm(self, tensor):
        """Return ``tensor`` shifted by the stored mean and divided by the std."""
        centered = tensor - self.mean
        return centered / self.std

    def denorm(self, normed_tensor):
        """Invert ``norm``: rescale by the stored std and add the mean back."""
        rescaled = normed_tensor * self.std
        return rescaled + self.mean

    def state_dict(self):
        """Return the stored statistics as ``{'mean': ..., 'std': ...}``."""
        return dict(mean=self.mean, std=self.std)

    def load_state_dict(self, state_dict):
        """Replace the stored statistics with those found in ``state_dict``."""
        restored_mean = state_dict['mean']
        self.mean = restored_mean
        restored_std = state_dict['std']
        self.std = restored_std


class FineTune(object):
    """Fine-tunes a ``GraphTransformer`` on one target and evaluates it.

    Training uses cross-entropy. Losses are written to a TensorBoard
    ``SummaryWriter`` under ``finetune/<time>_<task_name>_<target>``, and the
    weights with the best validation ROC AUC are kept in ``checkpoints``.

    Attributes:
        roc_auc: Test ROC AUC, set by ``_test`` at the end of ``train``.
    """

    def __init__(self, dataset, config):
        """Store the data wrapper and config and create the log writer.

        Args:
            dataset: Object exposing ``get_data_loaders()``.
            config: Nested configuration mapping (see the keys read below and
                in ``train``).
        """
        self.config = config
        self.device = self._get_device()

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        dir_name = current_time + '_' + config['task_name'] + '_' + config['dataset']['target']
        log_dir = os.path.join('finetune', dir_name)
        self.writer = SummaryWriter(log_dir=log_dir)
        self.dataset = dataset
        self.criterion = nn.CrossEntropyLoss()
        # Defined here so evaluation works even if train() has not run yet.
        self.normalizer = None

    def _get_device(self):
        """Return ``config['gpu']`` when CUDA is usable, else ``'cpu'``."""
        if torch.cuda.is_available() and self.config['gpu'] != 'cpu':
            device = self.config['gpu']
            torch.cuda.set_device(device)
        else:
            device = 'cpu'
        print("Running on:", device)

        return device

    def _step(self, model, data, n_iter):
        """Forward ``data`` through ``model`` and return the cross-entropy loss.

        ``n_iter`` is accepted for interface compatibility and not used.
        """
        __, pred = model(data)  # [N,C]
        loss = self.criterion(pred, data.y.flatten())

        return loss

    def train(self):
        """Run the full fine-tuning loop, then evaluate on the test split."""
        train_loader, valid_loader, test_loader = self.dataset.get_data_loaders()

        self.normalizer = None

        model = GraphTransformer(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        # The prediction head uses the default learning rate; everything else
        # uses `init_base_lr`.
        head_params, base_params = [], []
        for name, param in model.named_parameters():
            if 'pred_head' in name:
                print(name, param.requires_grad)
                head_params.append(param)
            else:
                base_params.append(param)

        weight_decay = self.config['weight_decay']
        if isinstance(weight_decay, str):
            # NOTE(review): eval() executes arbitrary text from the config
            # file. It is kept so expression-style values continue to work;
            # prefer float() if the config only ever holds plain numbers.
            weight_decay = eval(weight_decay)

        optimizer = torch.optim.Adam(
            [{'params': base_params, 'lr': self.config['init_base_lr']}, {'params': head_params}],
            self.config['init_lr'], weight_decay=weight_decay
        )

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_cls = 0

        for epoch_counter in range(self.config['epochs']):
            for bn, data in enumerate(train_loader):
                optimizer.zero_grad()

                data = data.to(self.device)
                loss = self._step(model, data, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
                    print(epoch_counter, bn, loss.item())
                loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                if self.config['dataset']['task'] == 'classification':
                    valid_loss, valid_cls = self._validate(model, valid_loader)
                    if valid_cls > best_valid_cls:
                        # save the model weights
                        best_valid_cls = valid_cls
                        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
                    # Logged inside the branch: `valid_loss` only exists when
                    # validation actually ran.
                    self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

        self._test(model, test_loader)

    def _load_pre_trained_weights(self, model):
        """Load ``./ckpt/<fine_tune_from>/checkpoints/model.pth`` if present."""
        try:
            checkpoints_folder = os.path.join('./ckpt', self.config['fine_tune_from'], 'checkpoints')
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=self.device)
            # Partial load: keys missing from the model are skipped.
            model.load_my_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

    def _evaluate(self, model, loader):
        """Score ``loader`` without gradients using one forward pass per batch.

        Returns:
            ``(mean_loss, predictions, labels)`` where ``predictions`` holds
            one row per sample (softmax probabilities for classification) and
            ``labels`` the flattened targets, both as NumPy arrays.
        """
        predictions = []
        labels = []
        total_loss = 0.0
        num_data = 0

        with torch.no_grad():
            model.eval()
            for data in loader:
                data = data.to(self.device)

                # Loss is computed from the same logits that are reported;
                # no second forward pass is needed.
                __, pred = model(data)
                loss = self.criterion(pred, data.y.flatten())

                total_loss += loss.item() * data.y.size(0)
                num_data += data.y.size(0)

                if self.normalizer:
                    pred = self.normalizer.denorm(pred)

                if self.config['dataset']['task'] == 'classification':
                    pred = F.softmax(pred, dim=-1)

                # .cpu() is a no-op for CPU tensors, so one path covers both.
                predictions.extend(pred.detach().cpu().numpy())
                labels.extend(data.y.detach().cpu().flatten().numpy())

            mean_loss = total_loss / num_data

        model.train()
        return mean_loss, np.array(predictions), np.array(labels)

    def _validate(self, model, valid_loader):
        """Return ``(validation_loss, roc_auc)`` for ``valid_loader``."""
        valid_loss, predictions, labels = self._evaluate(model, valid_loader)
        roc_auc = roc_auc_score(labels, predictions[:,1])
        print('Validation loss:', valid_loss, 'ROC AUC:', roc_auc)
        return valid_loss, roc_auc

    def _test(self, model, test_loader):
        """Evaluate on ``test_loader`` and store the ROC AUC in ``self.roc_auc``.

        The best saved checkpoint is restored first when one exists;
        otherwise the weights currently held by ``model`` are evaluated.
        """
        model_path = os.path.join(self.writer.log_dir, 'checkpoints', 'model.pth')
        if os.path.exists(model_path):
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            print("Loaded trained model with success.")
        else:
            print("No saved checkpoint found; testing the current model weights.")

        test_loss, predictions, labels = self._evaluate(model, test_loader)
        self.roc_auc = roc_auc_score(labels, predictions[:,1])
        print('Test loss:', test_loss, 'Test ROC AUC:', self.roc_auc)


def main(config):
    """Fine-tune on the target named in ``config`` and return its test score.

    Returns:
        The test ROC AUC stored by ``FineTune`` for classification tasks,
        otherwise ``None``.
    """
    loaders = MolTestDatasetWrapper(config['batch_size'], **config['dataset'])

    runner = FineTune(loaders, config)
    runner.train()

    is_classification = config['dataset']['task'] == 'classification'
    return runner.roc_auc if is_classification else None


if __name__ == "__main__":
    # Read the run configuration; the context manager closes the file handle,
    # which the bare open() call used to leave open.
    with open("config_finetune.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    if config['task_name'] == 'Tox21':
        config['dataset']['task'] = 'classification'
        config['dataset']['data_path'] = 'data/tox21/tox21.csv'
        target_list = [
            "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
            "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
        ]

    else:
        raise ValueError('Undefined downstream task!')

    # One independent fine-tuning run per assay; each run returns its test score.
    results_list = []
    for target in target_list:
        config['dataset']['target'] = target
        result = main(config)
        results_list.append([target, result])

    # Append (target, score) rows to the per-checkpoint results file.
    os.makedirs('experiments', exist_ok=True)
    df = pd.DataFrame(results_list)
    df.to_csv(
        'experiments/{}_{}_finetune.csv'.format(config['fine_tune_from'], config['task_name']),
        mode='a', index=False, header=False
    )

# GINE with assay-prompt multi-head cross-attention

In [None]:
import os
import numpy as np
import pandas as pd

from rdkit import Chem

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool

from sklearn.metrics import roc_auc_score, average_precision_score

from sentence_transformers import SentenceTransformer


# ============================================================
# 0. CONFIG
# ============================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Folder holding the processed splits. The original machine-specific location
# stays the default; set the TOX21_DATA_DIR environment variable to run the
# notebook elsewhere without editing this cell.
DATA_DIR = os.environ.get("TOX21_DATA_DIR", r"E:\graphml project\novel\processed")

TRAIN_CSV  = os.path.join(DATA_DIR, "train_clean.csv")
VAL_CSV    = os.path.join(DATA_DIR, "val_clean.csv")
TEST_CSV   = os.path.join(DATA_DIR, "test_clean.csv")

TRAIN_DESC = os.path.join(DATA_DIR, "train_rdkit_desc.npz")
VAL_DESC   = os.path.join(DATA_DIR, "val_rdkit_desc.npz")
TEST_DESC  = os.path.join(DATA_DIR, "test_rdkit_desc.npz")

# Label columns read from the split CSVs, in model-output order.
ASSAYS  = [
    "NR-AR","NR-AR-LBD","NR-AhR","NR-Aromatase",
    "NR-ER","NR-ER-LBD","NR-PPAR-gamma",
    "SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
]
# Per-assay weight columns w1..w12, read alongside ASSAYS.
WEIGHTS = [f"w{i}" for i in range(1, 13)]
N_TASKS = len(ASSAYS)

# Natural-language description of each assay, keyed by the names in ASSAYS.
# Further down, the values are looked up in ASSAYS order and encoded by the
# SentenceTransformer model into one text embedding per assay.
ASSAY_DESCRIPTIONS = {
    "NR-AR":          "androgen receptor binding and endocrine disruption potential",
    "NR-AR-LBD":      "androgen receptor ligand binding domain interaction",
    "NR-AhR":         "aryl hydrocarbon receptor activation and xenobiotic metabolism",
    "NR-Aromatase":   "aromatase enzyme inhibition and steroid metabolism disruption",
    "NR-ER":          "estrogen receptor binding and hormonal activity modulation",
    "NR-ER-LBD":      "estrogen receptor ligand binding domain interaction",
    "NR-PPAR-gamma":  "peroxisome proliferator activated receptor gamma activation",
    "SR-ARE":         "antioxidant response element activation and oxidative stress response",
    "SR-ATAD5":       "ATAD5 biomarker response indicating genotoxicity",
    "SR-HSE":         "heat shock response element activation and protein stress",
    "SR-MMP":         "mitochondrial membrane potential disruption and cytotoxicity",
    "SR-p53":         "p53 tumor suppressor pathway activation and DNA damage response"
}


# ============================================================
# 1. CHEM -> GRAPH HELPERS
# ============================================================

def atom_features(atom):
    """Return five numeric descriptors of an atom as a float32 vector.

    Order: atomic number, total degree, formal charge, total number of
    attached hydrogens, aromaticity flag (0 or 1).
    """
    values = (
        atom.GetAtomicNum(),
        atom.GetTotalDegree(),
        atom.GetFormalCharge(),
        atom.GetTotalNumHs(),
        int(atom.GetIsAromatic()),
    )
    return np.asarray(values, dtype=np.float32)


def bond_features(bond):
    """Return three numeric descriptors of a bond as a float32 vector.

    Order: bond order (1.0 single, 2.0 double, 3.0 triple, 1.5 aromatic,
    0.0 for any other type), conjugation flag, ring-membership flag.
    """
    order_by_type = {
        Chem.BondType.SINGLE: 1.0,
        Chem.BondType.DOUBLE: 2.0,
        Chem.BondType.TRIPLE: 3.0,
        Chem.BondType.AROMATIC: 1.5,
    }
    order = order_by_type.get(bond.GetBondType(), 0.0)
    conjugated = float(bond.GetIsConjugated())
    in_ring = float(bond.IsInRing())
    return np.array([order, conjugated, in_ring], dtype=np.float32)


def smiles_to_graph_with_pe(smiles, y_vec, w_vec):
    """Convert a SMILES string into a graph and apply ``pe_transform`` to it.

    Args:
        smiles: Molecule as a SMILES string.
        y_vec: Label vector, stored as ``data.y`` with shape ``(1, len)``.
        w_vec: Weight vector, stored as ``data.weight`` with shape ``(1, len)``.

    Returns:
        The ``Data`` object returned by ``pe_transform``, or ``None`` when the
        SMILES cannot be parsed or describes a molecule without atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # A molecule with zero atoms would make np.stack fail on an empty list.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    atom_feats = [atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.stack(atom_feats, axis=0), dtype=torch.float)

    rows, cols, eattr = [], [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bf = bond_features(b)
        # Store both directions so the graph is undirected.
        rows += [i, j]
        cols += [j, i]
        eattr += [bf, bf]

    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    if eattr:
        edge_attr = torch.tensor(np.stack(eattr, axis=0), dtype=torch.float)
    else:
        # bond_features yields 3 values per bond, so the empty tensor must
        # also have 3 columns. It used to be (0, 6), which cannot be
        # concatenated with the edge attributes of other graphs in a batch.
        edge_attr = torch.zeros((0, 3), dtype=torch.float)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor(y_vec, dtype=torch.float).view(1, -1),
        weight=torch.tensor(w_vec, dtype=torch.float).view(1, -1)
    )

    # NOTE(review): pe_transform is not defined in this cell; it is presumably
    # a positional-encoding transform created elsewhere in the notebook.
    data = pe_transform(data)

    return data


# ============================================================
# 2. LOAD CSV + DESCRIPTORS -> GRAPH DATASETS
# ============================================================

def load_split(csv_path, desc_path):
    """Load one data split as graphs plus the matching descriptor rows.

    Args:
        csv_path: CSV with a ``smiles_canonical`` or ``smiles`` column and the
            ``ASSAYS`` / ``WEIGHTS`` columns.
        desc_path: ``.npz`` archive whose ``"X"`` array has one descriptor row
            per CSV row.

    Returns:
        ``(graphs, desc)`` where ``desc[k]`` is the descriptor row of
        ``graphs[k]``. Rows whose molecule could not be converted are absent
        from both, so the two stay aligned.

    Raises:
        ValueError: If the CSV has no SMILES column.
        AssertionError: If the CSV and descriptor row counts differ.
    """
    df = pd.read_csv(csv_path)
    # Read the array inside a context manager so the archive is closed.
    with np.load(desc_path) as archive:
        desc_npz = archive["X"]  # [N, D]

    # Which SMILES column?
    if "smiles_canonical" in df.columns:
        smi_col = "smiles_canonical"
    elif "smiles" in df.columns:
        smi_col = "smiles"
    else:
        raise ValueError("No 'smiles' or 'smiles_canonical' column found.")

    assert len(df) == desc_npz.shape[0], "CSV and desc shapes mismatch."

    graphs = []
    kept_rows = []
    for pos, (_, row) in enumerate(df.iterrows()):
        smiles = str(row[smi_col])
        y_vec  = row[ASSAYS].astype(float).values
        w_vec  = row[WEIGHTS].astype(float).values
        d_vec  = desc_npz[pos]

        # NOTE(review): smiles_to_data is not defined in this cell; it is
        # presumably provided by an earlier cell of the notebook.
        g = smiles_to_data(smiles, y_vec, w_vec, d_vec)
        if g is None:
            continue
        graphs.append(g)
        kept_rows.append(pos)

    skipped = len(df) - len(graphs)
    print(f"[{os.path.basename(csv_path)}] Loaded {len(graphs)} graphs, skipped {skipped} invalid.")
    # Only the rows of kept molecules are returned. Returning the full matrix
    # would let callers that pair graphs[k] with desc[k] attach the wrong
    # descriptors after any skipped row.
    return graphs, desc_npz[np.asarray(kept_rows, dtype=np.intp)]


print("Loading datasets...")
train_graphs, train_desc = load_split(TRAIN_CSV, TRAIN_DESC)
val_graphs,   val_desc   = load_split(VAL_CSV,   VAL_DESC)
test_graphs,  test_desc  = load_split(TEST_CSV,  TEST_DESC)

# ---- Standardize descriptor features (important!) ----
# Statistics come from the training split only and are reused for the
# validation and test splits; the epsilon keeps constant columns finite.
desc_mean = train_desc.mean(axis=0, keepdims=True)
desc_std  = train_desc.std(axis=0, keepdims=True) + 1e-8

train_desc_norm = (train_desc - desc_mean) / desc_std
val_desc_norm   = (val_desc   - desc_mean) / desc_std
test_desc_norm  = (test_desc  - desc_mean) / desc_std

# Attach normalized desc back to each Data object
for _split_name, _split_graphs, _split_desc in (
    ("train", train_graphs, train_desc_norm),
    ("val", val_graphs, val_desc_norm),
    ("test", test_graphs, test_desc_norm),
):
    # Graph k is paired with descriptor row k, so the counts must agree;
    # a mismatch means molecules were dropped without dropping their rows.
    if len(_split_graphs) != len(_split_desc):
        raise ValueError(
            f"{_split_name}: {len(_split_graphs)} graphs but "
            f"{len(_split_desc)} descriptor rows; cannot align them by position."
        )
    for i, g in enumerate(_split_graphs):
        g.desc_features = torch.tensor(_split_desc[i], dtype=torch.float)

# Descriptor dimension
desc_dim = train_graphs[0].desc_features.size(0)
print("Descriptor dim:", desc_dim)


# ============================================================
# 3. DATALOADERS
# ============================================================

BATCH_SIZE = 64

# Only the training loader shuffles; validation and test keep dataset order.
train_loader, val_loader, test_loader = (
    DataLoader(split_graphs, batch_size=BATCH_SIZE, shuffle=is_train)
    for split_graphs, is_train in (
        (train_graphs, True),
        (val_graphs, False),
        (test_graphs, False),
    )
)


# ============================================================
# 4. BUILD TEXT PROMPT EMBEDDINGS (SentenceTransformer)
# ============================================================

print("Encoding assay descriptions with SentenceTransformer...")
text_model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim

# Descriptions are encoded in ASSAYS order so that row k of the embedding
# matrix belongs to label column k.
assay_names = ASSAYS  # same order as labels
descriptions = list(map(ASSAY_DESCRIPTIONS.__getitem__, assay_names))
text_embs = torch.tensor(
    text_model.encode(descriptions, convert_to_numpy=True),  # [12, 384]
    dtype=torch.float,
)
text_dim = text_embs.shape[1]
print(f"Text embedding dim: {text_dim}")


# ============================================================
# 5. GINE + ASSAY-PROMPT CROSS-ATTENTION MODEL
# ============================================================

class GINECrossAttention(nn.Module):
    """GINE graph encoder fused with learnable assay prompts.

    A two-layer GINE backbone produces one vector per graph. A weighted mix of
    the assay prompt embeddings is projected to the same width and used as the
    query of a multi-head attention layer whose key and value are the graph
    vector. The result passes through a residual feed-forward block, is
    concatenated with the descriptor features and classified into ``n_tasks``
    logits.

    Args:
        node_feature_dim: Width of the node features ``data.x``.
        edge_feature_dim: Width of the edge features ``data.edge_attr``.
        desc_feature_dim: Width of the per-graph descriptor vector.
        n_tasks: Number of assays / output logits.
        text_dim: Width of each assay prompt embedding.
        gnn_hidden: Hidden width of the backbone and the attention block.
        initial_prompts: Optional ``(n_tasks, text_dim)`` tensor used to
            initialise the prompts; random normal values are used otherwise.
    """

    def __init__(self, node_feature_dim, edge_feature_dim,
                 desc_feature_dim, n_tasks,
                 text_dim=384, gnn_hidden=128,
                 initial_prompts=None):
        super().__init__()

        self.n_tasks = n_tasks
        self.desc_feature_dim = desc_feature_dim

        def two_layer_mlp(in_dim):
            # Update network wrapped by each GINE convolution.
            return nn.Sequential(
                nn.Linear(in_dim, gnn_hidden),
                nn.ReLU(),
                nn.Linear(gnn_hidden, gnn_hidden),
                nn.ReLU(),
            )

        # ----- Graph backbone -----
        self.conv1 = GINEConv(two_layer_mlp(node_feature_dim), edge_dim=edge_feature_dim)
        self.conv2 = GINEConv(two_layer_mlp(gnn_hidden), edge_dim=edge_feature_dim)

        self.bn1 = nn.BatchNorm1d(gnn_hidden)
        self.bn2 = nn.BatchNorm1d(gnn_hidden)
        self.gnn_dropout = 0.2

        # Mean and max readouts are concatenated, hence the doubled input.
        self.gnn_proj = nn.Sequential(
            nn.Linear(gnn_hidden * 2, gnn_hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # ----- Assay prompts -----
        if initial_prompts is not None:
            assert initial_prompts.shape == (n_tasks, text_dim)
            prompt_init = initial_prompts.clone()
        else:
            prompt_init = torch.randn(n_tasks, text_dim)
        self.assay_prompts = nn.Parameter(prompt_init)

        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, gnn_hidden),
            nn.LayerNorm(gnn_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # ----- Cross-attention block -----
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=gnn_hidden,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )
        self.layer_norm1 = nn.LayerNorm(gnn_hidden)
        self.layer_norm2 = nn.LayerNorm(gnn_hidden)
        self.cross_ffn = nn.Sequential(
            nn.Linear(gnn_hidden, gnn_hidden * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(gnn_hidden * 2, gnn_hidden),
        )

        # ----- Classifier over fused features + descriptors -----
        self.classifier = nn.Sequential(
            nn.Linear(gnn_hidden + desc_feature_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_tasks),
        )

    def encode_graph(self, data):
        """Return one ``gnn_hidden``-wide vector per graph in the batch."""
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        h = F.elu(self.bn1(self.conv1(data.x, edge_index, edge_attr)))
        h = F.dropout(h, p=self.gnn_dropout, training=self.training)
        h = F.elu(self.bn2(self.conv2(h, edge_index, edge_attr)))

        pooled = torch.cat(
            [global_mean_pool(h, batch), global_max_pool(h, batch)], dim=1
        )
        return self.gnn_proj(pooled)   # [B, g_hidden]

    def forward(self, data, assay_attention=None):
        """Compute assay logits for a batch of graphs.

        Args:
            data: Batch exposing the graph tensors and ``desc_features``.
            assay_attention: Optional ``[B, n_tasks]`` mixing weights for the
                assay prompts; uniform weights are used when omitted.

        Returns:
            Logits of shape ``[B, n_tasks]``.
        """
        n_graphs = data.num_graphs

        graph_emb = self.encode_graph(data)                      # [B, g_hidden]
        descriptors = data.desc_features.view(n_graphs, -1)      # [B, desc_dim]

        if assay_attention is None:
            assay_attention = torch.ones(n_graphs, self.n_tasks, device=graph_emb.device) / self.n_tasks

        # Weighted sum of the prompt rows, then projection to the graph width.
        mixed_prompt = torch.einsum('bi,ij->bj', assay_attention, self.assay_prompts)  # [B, text_dim]
        prompt_emb = self.text_proj(mixed_prompt)                # [B, g_hidden]

        # Length-1 sequences: the prompt is the query, the graph is key/value.
        memory = graph_emb.unsqueeze(1)                          # [B, 1, g_hidden]
        attn_out, _ = self.cross_attention(
            query=prompt_emb.unsqueeze(1),
            key=memory,
            value=memory,
        )

        fused = self.layer_norm1(prompt_emb + attn_out.squeeze(1))
        fused = self.layer_norm2(fused + self.cross_ffn(fused))

        return self.classifier(torch.cat([fused, descriptors], dim=1))  # [B, n_tasks]


# ============================================================
# 6. ASYMMETRIC LOSS FOR IMBALANCE
# ============================================================

class AsymmetricLoss(nn.Module):
    """Masked asymmetric focal loss for imbalanced multi-label targets.

    Each element's log-likelihood is scaled by ``(1 - p_t) ** gamma_t`` where
    ``p_t`` is the probability assigned to the true class (using the
    clip-shifted negative probability for negatives) and ``gamma_t`` is
    ``gamma_pos`` for positive targets and ``gamma_neg`` for negative ones.
    Confident, correct predictions are therefore down-weighted, and with
    ``gamma_neg > gamma_pos`` easy negatives are suppressed more strongly.

    The previous implementation multiplied every element by
    ``xs_neg ** gamma_neg + xs_pos ** gamma_pos`` regardless of its target;
    that factor stays close to 1 and did not suppress easy examples.

    Args:
        gamma_neg: focusing exponent applied to negative targets.
        gamma_pos: focusing exponent applied to positive targets.
        clip: probability margin added to ``1 - p`` for negatives (capped at
            1), so negatives with ``p <= clip`` contribute zero loss.
            ``None`` or ``0`` disables it.
        eps: lower clamp applied before ``log`` for numerical safety.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, logits, targets, weights):
        """
        logits, targets, weights: [B, T]
        weights is used as mask (1 where label present, 0 where missing)

        Returns the weighted sum of per-element losses divided by
        ``weights.sum() + 1e-8`` (a scalar tensor).
        """
        x_sigmoid = torch.sigmoid(logits)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping: shift the negative-class probability up so that
        # very easy negatives reach probability 1 and drop out of the loss.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1.0)

        # Plain (log-likelihood) cross-entropy terms.
        los_pos = targets * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1.0 - targets) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Target-dependent focusing: p_t is the probability of the true class,
        # gamma_t picks the exponent that belongs to that class.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            pt = xs_pos * targets + xs_neg * (1.0 - targets)
            one_sided_gamma = self.gamma_pos * targets + self.gamma_neg * (1.0 - targets)
            loss = loss * torch.pow(1.0 - pt, one_sided_gamma)

        # Negate (log-likelihood -> loss) and average over labelled entries.
        weighted_loss = -loss * weights
        return weighted_loss.sum() / (weights.sum() + 1e-8)


# ============================================================
# 7. TRAIN & EVAL FUNCTIONS
# ============================================================

def train_epoch(loader, model, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Relies on the module-level ``device`` and ``N_TASKS``. ``criterion`` is
    called as ``criterion(logits, targets, mask)``.

    Returns:
        The batch losses averaged with each batch weighted by its number of
        graphs.
    """
    model.train()
    loss_sum, graph_count = 0.0, 0

    for batch in loader:
        batch = batch.to(device)
        n_graphs = batch.num_graphs

        targets = batch.y.float().view(n_graphs, N_TASKS)
        label_weights = batch.weight.float().view(n_graphs, N_TASKS)
        label_mask = (label_weights > 0).float()   # 1 where a label is present

        # Prompt mixing weights: spread attention evenly over the assays that
        # are labelled for each graph; uniform over all assays only when the
        # whole batch carries no label at all.
        # NOTE(review): evaluate() always uses uniform attention, so training
        # and evaluation condition the model differently -- confirm intent.
        if label_mask.sum() <= 0:
            assay_attention = torch.ones(n_graphs, N_TASKS, device=device) / N_TASKS
        else:
            labels_per_graph = label_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
            assay_attention = label_mask / labels_per_graph

        logits = model(batch, assay_attention)   # [B, T]
        loss = criterion(logits, targets, label_mask)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_sum += loss.item() * n_graphs
        graph_count += n_graphs

    return loss_sum / graph_count


def evaluate(loader, model):
    """Score ``model`` on ``loader`` with per-assay ROC-AUC and PR-AUC.

    Relies on the module-level ``device``, ``N_TASKS`` and ``ASSAYS``.
    Predictions are made with uniform assay attention. An assay with fewer
    than 5 labelled rows, or for which the metric call raises ``ValueError``,
    is recorded as NaN and ignored by the means.

    Returns:
        ``(roc_scores, pr_scores, mean_roc, mean_pr)`` where the first two are
        dicts keyed by assay name.
    """
    model.eval()
    prob_chunks, label_chunks, weight_chunks = [], [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            n_graphs = batch.num_graphs

            batch_labels = batch.y.float().view(n_graphs, N_TASKS)
            batch_weights = batch.weight.float().view(n_graphs, N_TASKS)

            # Every assay prompt gets the same weight at evaluation time.
            uniform_attention = torch.ones(n_graphs, N_TASKS, device=device) / N_TASKS
            batch_probs = torch.sigmoid(model(batch, uniform_attention))

            prob_chunks.append(batch_probs.cpu())
            label_chunks.append(batch_labels.cpu())
            weight_chunks.append(batch_weights.cpu())

    probs = torch.cat(prob_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    weights = torch.cat(weight_chunks, dim=0).numpy()

    roc_scores, pr_scores = {}, {}
    for col, assay in enumerate(ASSAYS):
        labeled = weights[:, col] > 0
        roc, pr = np.nan, np.nan

        if labeled.sum() >= 5:
            y_true = labels[labeled, col]
            y_pred = probs[labeled, col]
            try:
                roc = roc_auc_score(y_true, y_pred)
                pr = average_precision_score(y_true, y_pred)
            except ValueError:
                # Either metric failing invalidates both for this assay.
                roc, pr = np.nan, np.nan

        roc_scores[assay] = roc
        pr_scores[assay] = pr

    mean_roc = np.nanmean(list(roc_scores.values()))
    mean_pr = np.nanmean(list(pr_scores.values()))
    return roc_scores, pr_scores, mean_roc, mean_pr


# ============================================================
# 8. MAIN TRAIN LOOP
# ============================================================

def main():
    """Train ``GINECrossAttention``, keep the best-validation weights, test.

    Reads the module-level ``train_graphs``, ``desc_dim``, ``N_TASKS``,
    ``text_dim``, ``text_embs``, ``device``, ``train_loader``, ``val_loader``,
    ``test_loader``, ``DATA_DIR`` and ``ASSAYS``. The weights with the highest
    mean validation ROC-AUC are restored, written to
    ``gine_textfusion_best.pt`` under ``DATA_DIR`` and scored on the test set.
    """
    EPOCHS = 60
    LR = 2e-4
    WEIGHT_DECAY = 1e-4

    # Feature widths are read off the first training graph.
    sample = train_graphs[0]
    node_dim = sample.x.size(1)
    edge_dim = sample.edge_attr.size(1)
    print("Node dim:", node_dim, "| Edge dim:", edge_dim)

    model = GINECrossAttention(
        node_feature_dim=node_dim,
        edge_feature_dim=edge_dim,
        desc_feature_dim=desc_dim,
        n_tasks=N_TASKS,
        text_dim=text_dim,
        gnn_hidden=128,
        initial_prompts=text_embs  # <-- text-initialized prompts
    ).to(device)

    print("Model params:", sum(p.numel() for p in model.parameters()))

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS, eta_min=1e-6
    )
    criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)

    best_val_roc = -1.0
    best_state = None

    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(train_loader, model, optimizer, criterion)
        roc_val, pr_val, mean_roc_val, mean_pr_val = evaluate(val_loader, model)
        scheduler.step()

        print(f"Epoch {epoch:03d} | loss={train_loss:.4f} | "
              f"val ROC={mean_roc_val:.4f} | val PR={mean_pr_val:.4f} | "
              f"LR={scheduler.get_last_lr()[0]:.2e}")

        if mean_roc_val > best_val_roc:
            best_val_roc = mean_roc_val
            # Snapshot by value. `.cpu()` alone returns the very same tensor
            # when the model already lives on the CPU, so without `.clone()`
            # the "best" state would keep tracking the live parameters and
            # end up equal to the last epoch.
            best_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }
            print("  â†’ New best model!")

    if best_state is not None:
        model.load_state_dict(best_state)
        model.to(device)
        torch.save(model.state_dict(), os.path.join(DATA_DIR, "gine_textfusion_best.pt"))
        print("Saved best model to gine_textfusion_best.pt")

    # Final test
    roc_test, pr_test, mean_roc_test, mean_pr_test = evaluate(test_loader, model)
    print("\n=== FINAL TEST METRICS (GINE + text prompts + desc) ===")
    for a in ASSAYS:
        print(f"{a:13s} | ROC-AUC={roc_test[a]:.4f} | PR-AUC={pr_test[a]:.4f}")
    print("---------------------------------------------")
    print(f"Mean            | ROC-AUC={mean_roc_test:.4f} | PR-AUC={mean_pr_test:.4f}")


if __name__ == "__main__":
    main()


***PNA cross attention***

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import PNAConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# ============================
# Data Loading
# ============================

# The 12 Tox21 assay names; their order defines the task/column order used
# throughout this cell.
ASSAYS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
    "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma",
    "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Load graphs ----
# NOTE(review): torch.load unpickles arbitrary objects; only load files you
# produced yourself. Recent PyTorch releases default to weights_only=True,
# which may reject pickled graph objects -- confirm against the installed
# version.
train_graphs = torch.load("graphs/train_2d.pt")
val_graphs   = torch.load("graphs/val_2d.pt")
test_graphs  = torch.load("graphs/test_2d.pt")

# ---- Load descriptors ----
# Single place to change the descriptor location (previously repeated as an
# absolute path on every np.load call).
DESC_DIR = r"E:\graphml project\novel\processed"

use_desc = True
if use_desc:
    train_desc = np.load(os.path.join(DESC_DIR, "train_rdkit_desc.npz"))["X"]
    val_desc   = np.load(os.path.join(DESC_DIR, "val_rdkit_desc.npz"))["X"]
    test_desc  = np.load(os.path.join(DESC_DIR, "test_rdkit_desc.npz"))["X"]
    desc_dim = train_desc.shape[1]
else:
    # Descriptor ablation: all-zero placeholder rows of a fixed width.
    desc_dim = 32
    train_desc = np.zeros((len(train_graphs), desc_dim), dtype=np.float32)
    val_desc   = np.zeros((len(val_graphs), desc_dim), dtype=np.float32)
    test_desc  = np.zeros((len(test_graphs), desc_dim), dtype=np.float32)
# ---- Attach features ----
def attach_features(graph_list, desc_array):
    """Attach row ``i`` of ``desc_array`` to graph ``i`` as ``desc_features``.

    Args:
        graph_list: sequence of graph objects; mutated in place.
        desc_array: array with one descriptor row per graph, in the same
            order as ``graph_list``.

    Returns:
        ``graph_list`` (the same object), for call-chaining.

    Raises:
        ValueError: if the number of graphs and descriptor rows differ.
            Previously extra descriptor rows were silently ignored, which
            hides a graph/descriptor misalignment.
    """
    if len(graph_list) != len(desc_array):
        raise ValueError(
            f"Got {len(graph_list)} graphs but {len(desc_array)} descriptor rows; "
            "they must be aligned one-to-one."
        )
    for g, desc_row in zip(graph_list, desc_array):
        g.desc_features = torch.from_numpy(desc_row).float()
    return graph_list

# Attach descriptor rows to each split (attach_features mutates and returns
# the same list).
train_graphs = attach_features(train_graphs, train_desc)
val_graphs   = attach_features(val_graphs, val_desc)
test_graphs  = attach_features(test_graphs, test_desc)

# Only the training loader shuffles; validation/test keep the stored order.
BATCH_SIZE = 64
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

# ============================
# Text Prompts Definition
# ============================

# Human-readable description per assay. In this cell they are only printed;
# PNAWithText below initialises its prompt vectors randomly.
ASSAY_DESCRIPTIONS = {
    "NR-AR": "androgen receptor binding and endocrine disruption potential",
    "NR-AR-LBD": "androgen receptor ligand binding domain interaction",
    "NR-AhR": "aryl hydrocarbon receptor activation and xenobiotic metabolism",
    "NR-Aromatase": "aromatase enzyme inhibition and steroid metabolism",
    "NR-ER": "estrogen receptor binding and hormonal activity",
    "NR-ER-LBD": "estrogen receptor ligand binding domain interaction",
    "NR-PPAR-gamma": "peroxisome proliferator-activated receptor gamma activation",
    "SR-ARE": "antioxidant response element activation and oxidative stress",
    "SR-ATAD5": "ATAD5 biomarker response and genotoxicity",
    "SR-HSE": "heat shock response element activation and protein stress",
    "SR-MMP": "mitochondrial membrane potential disruption and cytotoxicity",
    "SR-p53": "p53 tumor suppressor pathway activation and DNA damage response"
}

print("Assay Prompts:")
for assay, desc in ASSAY_DESCRIPTIONS.items():
    print(f"  {assay}: {desc}")

# ============================
# PNA + Text Enhanced Model
# ============================

class PNAWithText(nn.Module):
    """PNA graph encoder fused with learnable assay prompts and descriptors.

    Args:
        node_feature_dim: width of the input node features.
        edge_feature_dim: width of the edge features passed to ``PNAConv``.
        desc_feature_dim: width of the per-graph descriptor vector.
        n_tasks: number of assays (output logits).
        hidden_dim: hidden width of the node projection and PNA layers.
        num_layers: number of PNA layers.
        dropout: dropout probability used in the node projection and after
            every PNA layer (the latter was previously hard-coded to 0.2,
            which equals the default, so default behaviour is unchanged).
        deg: optional degree tensor forwarded to ``PNAConv(deg=...)``.
            PyG documents this as the histogram of node in-degrees over the
            training set. When ``None`` the original placeholder
            ``tensor([0, 1, 2, 3, 4])`` is used for backward compatibility.
    """

    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks,
                 hidden_dim=128, num_layers=4, dropout=0.2, deg=None):
        super().__init__()

        self.n_tasks = n_tasks
        self.desc_feature_dim = desc_feature_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        # --- PNA Configuration ---
        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        # Degree statistics for the PNA scalers.
        # NOTE(review): the fallback is a placeholder, not a histogram computed
        # from the data; pass a real training-set in-degree histogram via `deg`.
        if deg is None:
            deg = torch.tensor([0, 1, 2, 3, 4])
        self.deg = deg

        # --- Feature Projections ---
        self.node_proj = nn.Sequential(
            nn.Linear(node_feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # --- PNA Layers ---
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for i in range(num_layers):
            conv = PNAConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                aggregators=aggregators,
                scalers=scalers,
                deg=self.deg,
                edge_dim=edge_feature_dim,
                towers=1,
                pre_layers=1,
                post_layers=1
            )
            self.convs.append(conv)
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # --- Text Prompts ---
        # One randomly initialised, learnable 128-d prompt vector per assay.
        self.assay_prompts = nn.Parameter(torch.randn(n_tasks, 128))
        self.text_proj = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # --- Cross-Attention Fusion ---
        # hidden_dim must be divisible by the 8 heads.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # --- Enhanced Classifier ---
        # Input = [mean-pool | max-pool] (2 * hidden) + attended (hidden) + descriptors.
        classifier_input_dim = hidden_dim * 2 + hidden_dim + desc_feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_tasks)
        )

    def forward(self, data, assay_attention=None):
        """Return per-assay logits for a batch of graphs.

        Args:
            data: batched graph object providing ``x``, ``edge_index``,
                ``edge_attr``, ``batch`` and ``desc_features``.
            assay_attention: optional ``[B, n_tasks]`` mixing weights over the
                assay prompts; uniform ``1 / n_tasks`` when ``None``.

        Returns:
            Tensor with one row per graph and ``n_tasks`` columns.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # 1. Project node features
        x = self.node_proj(x)

        # 2. PNA Message Passing
        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            x_residual = x
            x = conv(x, edge_index, edge_attr)
            x = bn(x)
            x = F.elu(x)

            # Residual connection on every second layer (i = 1, 3, ...)
            if i % 2 == 1:
                x = x + x_residual

            x = F.dropout(x, p=self.dropout, training=self.training)

        # 3. Graph Readout
        mean_pool = global_mean_pool(x, batch)
        max_pool = global_max_pool(x, batch)
        graph_features = torch.cat([mean_pool, max_pool], dim=1)

        # 4. Text Conditioning
        if assay_attention is None:
            assay_attention = torch.ones(len(graph_features), self.n_tasks,
                                       device=graph_features.device) / self.n_tasks

        text_embeddings = torch.einsum('bi,ij->bj', assay_attention, self.assay_prompts)
        text_features = self.text_proj(text_embeddings)

        # 5. Cross-Attention
        # Only the mean-pooled half of graph_features is used as key/value.
        # NOTE(review): with a single key/value token the softmax weight is 1
        # whatever the query is, so attended_features does not depend on
        # text_features (and the prompts get no gradient through this path).
        # Attending over node-level tokens would make the text query matter.
        text_as_query = text_features.unsqueeze(1)
        graph_as_kv = graph_features[:, :self.hidden_dim].unsqueeze(1)

        attended_features, _ = self.cross_attention(
            query=text_as_query,
            key=graph_as_kv,
            value=graph_as_kv
        )
        attended_features = attended_features.squeeze(1)

        # 6. Descriptor Features
        desc_features = data.desc_features.view(len(graph_features), -1)

        # 7. Final Fusion
        combined_features = torch.cat([graph_features, attended_features, desc_features], dim=1)
        logits = self.classifier(combined_features)

        return logits

# ============================
# Training Functions
# ============================

def train_pna_epoch(loader, model, optimizer, criterion, device, n_tasks):
    """Run one training pass and return the mean loss per graph.

    Args:
        loader: iterable of batches exposing ``to(device)``, ``num_graphs``,
            ``y`` and ``weight``.
        model: called as ``model(batch, assay_attention)``.
        optimizer: optimiser over ``model.parameters()``.
        criterion: element-wise loss called as ``criterion(logits, y)``; it
            must return an unreduced tensor (e.g. ``reduction='none'``).
        device: device the batches are moved to.
        n_tasks: number of assays (columns of ``y`` / ``weight``).

    Returns:
        Batch losses averaged with each batch weighted by its graph count.
    """
    model.train()
    total_loss = 0.0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        B = batch.num_graphs

        y = batch.y.float().view(B, n_tasks)
        mask = (batch.weight.float().view(B, n_tasks) > 0).float()   # 1 = labelled
        n_labeled = mask.sum()

        # Smart assay attention: spread attention over the labelled assays of
        # each graph; fall back to uniform if the batch has no labels at all.
        if n_labeled > 0:
            assay_attention = mask / mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
        else:
            assay_attention = torch.ones(B, n_tasks, device=device) / n_tasks

        logits = model(batch, assay_attention)

        # Masked mean over labelled entries. The denominator is clamped so a
        # batch without any label yields a zero loss instead of 0/0 = NaN,
        # which would otherwise propagate NaN gradients into the parameters.
        loss_unreduced = criterion(logits, y)
        loss = (loss_unreduced * mask).sum() / n_labeled.clamp(min=1.0)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * B
        total_graphs += B

    return total_loss / total_graphs

def evaluate(loader, model, device, assays):
    """Score ``model`` on ``loader`` with per-assay ROC-AUC and PR-AUC.

    Predictions use uniform assay attention. An assay with fewer than 5
    labelled rows, or for which a metric call raises ``ValueError``, is
    recorded as NaN and skipped by the means.

    Returns:
        ``(roc_scores, pr_scores, mean_roc, mean_pr)``; the first two are
        dicts keyed by assay name.
    """
    n_assays = len(assays)
    model.eval()
    prob_chunks, label_chunks, weight_chunks = [], [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            # Same weight for every assay prompt at evaluation time.
            uniform_attention = torch.ones(batch.num_graphs, n_assays, device=device) / n_assays
            batch_probs = torch.sigmoid(model(batch, uniform_attention))

            batch_labels = batch.y.float().view(-1, n_assays)
            batch_weights = batch.weight.float().view(-1, n_assays)

            prob_chunks.append(batch_probs.cpu())
            label_chunks.append(batch_labels.cpu())
            weight_chunks.append(batch_weights.cpu())

    probs = torch.cat(prob_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    weights = torch.cat(weight_chunks, dim=0).numpy()

    roc_scores, pr_scores = {}, {}
    for col, assay in enumerate(assays):
        labeled = weights[:, col] > 0
        roc, pr = np.nan, np.nan

        if labeled.sum() >= 5:
            y_true = labels[labeled, col]
            y_pred = probs[labeled, col]
            try:
                roc = roc_auc_score(y_true, y_pred)
                pr = average_precision_score(y_true, y_pred)
            except ValueError:
                # Either metric failing invalidates both for this assay.
                roc, pr = np.nan, np.nan

        roc_scores[assay] = roc
        pr_scores[assay] = pr

    mean_roc = np.nanmean(list(roc_scores.values()))
    mean_pr = np.nanmean(list(pr_scores.values()))
    return roc_scores, pr_scores, mean_roc, mean_pr

# ============================
# Main Execution
# ============================

def main():
    """Train ``PNAWithText`` and report test metrics of the best checkpoint.

    Uses the module-level ``device``, ``train_graphs``, ``desc_dim``,
    ``ASSAYS`` and the three data loaders. Whenever the mean validation
    ROC-AUC improves, the weights are written to ``pna_text_best.pt``; that
    file is reloaded before the final test evaluation.
    """
    n_epochs = 100
    learning_rate = 2e-4
    weight_decay = 1e-5
    n_tasks = len(ASSAYS)

    print(f"Using device: {device}")

    # Feature widths are read off the first training graph.
    first_graph = train_graphs[0]
    node_dim = first_graph.x.size(1)
    edge_dim = first_graph.edge_attr.size(1)

    for label, value in (
        ("Node features", node_dim),
        ("Edge features", edge_dim),
        ("Descriptor features", desc_dim),
        ("Number of tasks", n_tasks),
    ):
        print(f"{label}: {value}")

    model = PNAWithText(
        node_feature_dim=node_dim,
        edge_feature_dim=edge_dim,
        desc_feature_dim=desc_dim,
        n_tasks=n_tasks,
        hidden_dim=128,
        num_layers=4,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"PNA Model parameters: {n_params:,}")

    # AdamW + cosine schedule over the full run; element-wise BCE so the
    # training step can mask out unlabelled entries itself.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    best_val_roc = -1.0
    print("\nStarting PNA Training...")

    for epoch in range(1, n_epochs + 1):
        train_loss = train_pna_epoch(train_loader, model, optimizer, criterion, device, n_tasks)
        _, _, mean_roc_val, _ = evaluate(val_loader, model, device, ASSAYS)
        scheduler.step()

        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch:03d} | Loss: {train_loss:.4f} | Val ROC: {mean_roc_val:.4f} | LR: {current_lr:.2e}")

        # Checkpoint only on a strict improvement of the validation ROC-AUC.
        if not mean_roc_val > best_val_roc:
            continue
        best_val_roc = mean_roc_val
        torch.save(model.state_dict(), "pna_text_best.pt")
        print(f"  â†’ New best! (ROC: {best_val_roc:.4f})")

    # Final test on the best checkpoint.
    model.load_state_dict(torch.load("pna_text_best.pt"))
    roc_test, _, mean_roc_test, mean_pr_test = evaluate(test_loader, model, device, ASSAYS)

    print("\n FINAL PNA TEST RESULTS:")
    print(f"ROC-AUC: {mean_roc_test:.4f}")
    print(f"PR-AUC: {mean_pr_test:.4f}")

    print("\nPer-assay ROC-AUC:")
    for assay in ASSAYS:
        print(f"  {assay}: {roc_test[assay]:.4f}")

if __name__ == "__main__":
    main()

***Text enhanced weighted assay GINE MLP projection***

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# ============================
# Text-Enhanced Model WITHOUT ECFP
# ============================

class TextEnhancedNoECFP(nn.Module):
    """Two-layer GINE encoder fused with descriptors and learnable assay prompts.

    The classifier input is the concatenation of the (weighted) pooled graph
    features, the (weighted) descriptor vector and the projected prompt
    mixture; no fingerprint branch is involved.

    Args:
        node_feature_dim: width of the input node features.
        edge_feature_dim: width of the edge features given to ``GINEConv``.
        desc_feature_dim: width of the per-graph descriptor vector.
        n_tasks: number of assays (output logits).
    """

    def __init__(self, node_feature_dim, edge_feature_dim, desc_feature_dim, n_tasks):
        super().__init__()

        self.n_tasks = n_tasks
        self.desc_feature_dim = desc_feature_dim

        hidden = 128

        # --- GNN backbone: two GINE layers, each wrapping a 2-layer MLP ---
        self.gnn_conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(node_feature_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            ),
            edge_dim=edge_feature_dim,
        )
        self.gnn_conv2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            ),
            edge_dim=edge_feature_dim,
        )

        self.gnn_batch_norm1 = nn.BatchNorm1d(hidden)
        self.gnn_batch_norm2 = nn.BatchNorm1d(hidden)

        # Pooled graph vector = mean pool + max pool.
        gnn_out_dim = 2 * hidden

        # --- One learnable, randomly initialised prompt vector per assay ---
        self.assay_prompts = nn.Parameter(torch.randn(n_tasks, hidden))

        self.text_proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # --- Per-assay logits over the two modalities (graph, descriptors) ---
        self.assay_weights = nn.Parameter(torch.ones(n_tasks, 2))

        # --- Classifier over [graph | descriptors | text] ---
        classifier_input_dim = 256 + desc_feature_dim + hidden
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_tasks),
        )

    def forward_gnn(self, x, edge_index, edge_attr, batch):
        """Encode nodes with both GINE layers and pool to one row per graph."""
        h = F.elu(self.gnn_batch_norm1(self.gnn_conv1(x, edge_index, edge_attr)))
        h = F.dropout(h, p=0.2, training=self.training)
        h = F.elu(self.gnn_batch_norm2(self.gnn_conv2(h, edge_index, edge_attr)))

        # Readout: concatenate mean- and max-pooled node states.
        pooled = [global_mean_pool(h, batch), global_max_pool(h, batch)]
        return torch.cat(pooled, dim=1)

    def forward(self, data, assay_attention=None):
        """Return per-assay logits for a batch of graphs.

        Args:
            data: batched graph object providing ``x``, ``edge_index``,
                ``edge_attr``, ``batch``, ``num_graphs`` and ``desc_features``.
            assay_attention: optional ``[B, n_tasks]`` mixing weights over the
                assay prompts; uniform ``1 / n_tasks`` when ``None``.
        """
        n_graphs = data.num_graphs

        # 1. Graph and descriptor branches
        graph_out = self.forward_gnn(data.x, data.edge_index, data.edge_attr, data.batch)  # [B, 256]
        desc_out = data.desc_features.view(n_graphs, self.desc_feature_dim)                # [B, desc_dim]

        # 2. Prompt mixture (uniform over assays unless weights are supplied)
        text_weights = assay_attention
        if text_weights is None:
            text_weights = torch.ones(n_graphs, self.n_tasks, device=graph_out.device) / self.n_tasks

        text_feat = self.text_proj(
            torch.einsum('bi,ij->bj', text_weights, self.assay_prompts)
        )  # [B, 128]

        # 3. Modality weights: softmax per assay, then averaged over assays,
        #    giving one scalar for the graph branch and one for descriptors.
        graph_w, desc_w = F.softmax(self.assay_weights, dim=1).mean(dim=0).unbind(0)

        # 4. Scale, concatenate and classify
        fused = torch.cat([graph_out * graph_w, desc_out * desc_w, text_feat], dim=1)
        return self.classifier(fused)

# ============================
# Text Prompts Definition
# ============================

# Human-readable description of each assay, keyed by assay name.
# NOTE(review): the SR-MMP and SR-p53 texts are shorter than the ones in the
# PNA cell's ASSAY_DESCRIPTIONS; this assignment replaces that dict when the
# cells run in order -- confirm which wording is intended.
ASSAY_DESCRIPTIONS = {
    "NR-AR": "androgen receptor binding and endocrine disruption potential",
    "NR-AR-LBD": "androgen receptor ligand binding domain interaction",
    "NR-AhR": "aryl hydrocarbon receptor activation and xenobiotic metabolism",
    "NR-Aromatase": "aromatase enzyme inhibition and steroid metabolism",
    "NR-ER": "estrogen receptor binding and hormonal activity",
    "NR-ER-LBD": "estrogen receptor ligand binding domain interaction",
    "NR-PPAR-gamma": "peroxisome proliferator-activated receptor gamma activation",
    "SR-ARE": "antioxidant response element activation and oxidative stress",
    "SR-ATAD5": "ATAD5 biomarker response and genotoxicity",
    "SR-HSE": "heat shock response element activation and protein stress",
    "SR-MMP": "mitochondrial membrane potential disruption",
    "SR-p53": "p53 tumor suppressor pathway activation and DNA damage"
}

# Descriptions as a list aligned with ASSAYS (ASSAYS is not defined in this
# cell; it must already exist in the kernel from an earlier cell).
ASSAY_TEXTS = [ASSAY_DESCRIPTIONS[assay] for assay in ASSAYS]

# ============================
# Training Components
# ============================

def train_text_enhanced_epoch(loader, model, optimizer, criterion, device, n_tasks):
    """Run one training pass and return the mean loss per graph.

    Args:
        loader: iterable of batches exposing ``to(device)``, ``num_graphs``,
            ``y`` and ``weight``.
        model: called as ``model(batch, assay_attention)``.
        optimizer: optimiser over ``model.parameters()``.
        criterion: element-wise loss called as ``criterion(logits, y)``; it
            must return an unreduced tensor (e.g. ``reduction='none'``).
        device: device the batches are moved to.
        n_tasks: number of assays (columns of ``y`` / ``weight``).

    Returns:
        Batch losses averaged with each batch weighted by its graph count.
    """
    model.train()
    total_loss = 0.0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        B = batch.num_graphs

        # Targets and label-presence mask (1 where weight > 0)
        y = batch.y.float().view(B, n_tasks)
        mask = (batch.weight.float().view(B, n_tasks) > 0).float()
        n_labeled = mask.sum()

        # Strategy 1 (fallback): equal attention to all assays.
        # Strategy 2: if the batch has any label, spread each graph's
        # attention evenly over its labelled assays.
        if n_labeled > 0:
            assay_attention = mask / mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
        else:
            assay_attention = torch.ones(B, n_tasks, device=device) / n_tasks

        logits = model(batch, assay_attention)

        # Masked mean over labelled entries. The denominator is clamped so a
        # batch without any label yields a zero loss instead of 0/0 = NaN,
        # which would otherwise propagate NaN gradients into the parameters.
        loss_unreduced = criterion(logits, y)
        loss = (loss_unreduced * mask).sum() / n_labeled.clamp(min=1.0)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * B
        total_graphs += B

    return total_loss / total_graphs

def evaluate(loader, model, device, assays):
    """Score ``model`` on ``loader`` with per-assay ROC-AUC and PR-AUC.

    Predictions use uniform assay attention. An assay with fewer than 5
    labelled rows, or for which a metric call raises ``ValueError``, is
    recorded as NaN and skipped by the means.

    Returns:
        ``(roc_scores, pr_scores, mean_roc, mean_pr)``; the first two are
        dicts keyed by assay name.
    """
    n_assays = len(assays)
    model.eval()
    prob_chunks, label_chunks, weight_chunks = [], [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            # For evaluation every assay prompt gets the same weight.
            uniform_attention = torch.ones(batch.num_graphs, n_assays, device=device) / n_assays
            batch_probs = torch.sigmoid(model(batch, uniform_attention))

            batch_labels = batch.y.float().view(-1, n_assays)
            batch_weights = batch.weight.float().view(-1, n_assays)

            prob_chunks.append(batch_probs.cpu())
            label_chunks.append(batch_labels.cpu())
            weight_chunks.append(batch_weights.cpu())

    probs = torch.cat(prob_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    weights = torch.cat(weight_chunks, dim=0).numpy()

    roc_scores, pr_scores = {}, {}
    for col, assay in enumerate(assays):
        labeled = weights[:, col] > 0
        roc, pr = np.nan, np.nan

        if labeled.sum() >= 5:
            y_true = labels[labeled, col]
            y_pred = probs[labeled, col]
            try:
                roc = roc_auc_score(y_true, y_pred)
                pr = average_precision_score(y_true, y_pred)
            except ValueError:
                # Either metric failing invalidates both for this assay.
                roc, pr = np.nan, np.nan

        roc_scores[assay] = roc
        pr_scores[assay] = pr

    mean_roc = np.nanmean(list(roc_scores.values()))
    mean_pr = np.nanmean(list(pr_scores.values()))
    return roc_scores, pr_scores, mean_roc, mean_pr

# ============================
# Data Preparation (NO ECFP)
# ============================

# Reload the graph splits for the run without ECFP fingerprints.
# NOTE(review): torch.load unpickles arbitrary objects; only load files you
# produced yourself.
print("Preparing data WITHOUT ECFP...")
train_graphs = torch.load("graphs/train_2d.pt")
val_graphs   = torch.load("graphs/val_2d.pt")
test_graphs  = torch.load("graphs/test_2d.pt")
# Create zero ECFP features (minimal dimension to avoid errors)
train_fp = np.zeros((len(train_graphs), 1), dtype=np.float32)
val_fp = np.zeros((len(val_graphs), 1), dtype=np.float32)
test_fp = np.zeros((len(test_graphs), 1), dtype=np.float32)
fp_dim = 1

# Keep descriptors
# Single place to change the descriptor location (previously repeated as an
# absolute path on every np.load call).
DESC_DIR = r"E:\graphml project\novel\processed"

# NOTE(review): `use_desc` is not defined in this cell; it must already exist
# in the kernel (the PNA cell sets it).
if use_desc:
    train_desc = np.load(os.path.join(DESC_DIR, "train_rdkit_desc.npz"))["X"]
    val_desc = np.load(os.path.join(DESC_DIR, "val_rdkit_desc.npz"))["X"]
    test_desc = np.load(os.path.join(DESC_DIR, "test_rdkit_desc.npz"))["X"]
    desc_dim = train_desc.shape[1]
else:
    # Descriptor ablation: all-zero placeholder rows of a fixed width.
    desc_dim = 32
    train_desc = np.zeros((len(train_graphs), desc_dim), dtype=np.float32)
    val_desc = np.zeros((len(val_graphs), desc_dim), dtype=np.float32)
    test_desc = np.zeros((len(test_graphs), desc_dim), dtype=np.float32)

# Attach features (ECFP will be zeros)
def attach_features_no_ecfp(graph_list, desc_array):
    """Attach descriptor rows and a 1-element zero fingerprint to each graph.

    Args:
        graph_list: sequence of graph objects; mutated in place.
        desc_array: array with one descriptor row per graph, in the same
            order as ``graph_list``.

    Returns:
        ``graph_list`` (the same object), for call-chaining.

    Raises:
        ValueError: if the number of graphs and descriptor rows differ.
            Previously extra descriptor rows were silently ignored, which
            hides a graph/descriptor misalignment.
    """
    if len(graph_list) != len(desc_array):
        raise ValueError(
            f"Got {len(graph_list)} graphs but {len(desc_array)} descriptor rows; "
            "they must be aligned one-to-one."
        )
    for g, desc_row in zip(graph_list, desc_array):
        g.fp_features = torch.zeros(1).float()  # Minimal ECFP placeholder
        g.desc_features = torch.from_numpy(desc_row).float()
    return graph_list

# Attach descriptor rows (and the placeholder fingerprint) to each split;
# the helper mutates and returns the same list.
train_graphs = attach_features_no_ecfp(train_graphs, train_desc)
val_graphs = attach_features_no_ecfp(val_graphs, val_desc)
test_graphs = attach_features_no_ecfp(test_graphs, test_desc)

# DataLoaders (same as before): only the training loader shuffles.
BATCH_SIZE = 64
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

# ============================
# Main Training Pipeline
# ============================

def main():
    """Train ``TextEnhancedNoECFP`` with early stopping and report test metrics.

    Uses the module-level ``device``, ``train_graphs``, ``desc_dim``,
    ``ASSAYS``, ``ASSAY_DESCRIPTIONS`` and the three data loaders. The weights
    with the highest mean validation ROC-AUC are kept in memory (and written
    to ``text_enhanced_no_ecfp_best.pt``), restored before the final test
    evaluation, and saved together with the test metrics to
    ``text_enhanced_no_ecfp_final.pt``.
    """
    EPOCHS = 100
    LR = 2e-5
    WEIGHT_DECAY = 1e-5

    print("Initializing Text-Enhanced GNN WITHOUT ECFP...")

    # Get dimensions from the first training graph
    sample = train_graphs[0]
    node_dim = sample.x.size(1)
    edge_dim = sample.edge_attr.size(1)

    print(f"Node features: {node_dim}")
    print(f"Edge features: {edge_dim}")
    print(f"Descriptor features: {desc_dim}")
    print(f"Number of tasks: {len(ASSAYS)}")
    print("\nUsing Assay Prompts:")
    for assay, desc in ASSAY_DESCRIPTIONS.items():
        print(f"  {assay}: {desc}")

    # Initialize model
    model = TextEnhancedNoECFP(
        node_feature_dim=node_dim,
        edge_feature_dim=edge_dim,
        desc_feature_dim=desc_dim,
        n_tasks=len(ASSAYS)
    ).to(device)

    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer and loss
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY
    )

    # Halve the LR after 10 epochs without validation ROC-AUC improvement.
    # NOTE(review): `verbose` is deprecated for LR schedulers in recent
    # PyTorch releases -- confirm against the installed version.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=10, verbose=True
    )

    # Element-wise BCE so the training step can mask unlabelled entries.
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    # Training loop
    best_val_roc = -1.0
    best_state = None
    patience = 15
    patience_counter = 0

    print("\nStarting Training...")
    print("Epoch | Train Loss | Val ROC-AUC | Val PR-AUC | LR")
    print("-" * 55)

    for epoch in range(1, EPOCHS + 1):
        # Training
        train_loss = train_text_enhanced_epoch(
            train_loader, model, optimizer, criterion, device, len(ASSAYS)
        )

        # Validation
        roc_val, pr_val, mean_roc_val, mean_pr_val = evaluate(val_loader, model, device, ASSAYS)

        # Update learning rate
        scheduler.step(mean_roc_val)

        print(f"{epoch:5d} | {train_loss:.4f}      | {mean_roc_val:.4f}      | {mean_pr_val:.4f}    | {optimizer.param_groups[0]['lr']:.2e}")

        # Save best model
        if mean_roc_val > best_val_roc:
            best_val_roc = mean_roc_val
            # Snapshot by value. `state_dict().copy()` is a shallow copy whose
            # tensors alias the live parameters, so it would silently follow
            # later optimizer steps and "best" would be the last epoch.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
            torch.save(model.state_dict(), "text_enhanced_no_ecfp_best.pt")
            print(f"  â†’ New best! (ROC-AUC: {best_val_roc:.4f})")
        else:
            patience_counter += 1

        # Early stopping after `patience` epochs without improvement
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    print(f"\nTraining completed. Best validation ROC-AUC: {best_val_roc:.4f}")

    # Load best model for testing
    if best_state is not None:
        model.load_state_dict(best_state)
        print("Loaded best model for testing")

    # Final evaluation
    roc_test, pr_test, mean_roc_test, mean_pr_test = evaluate(test_loader, model, device, ASSAYS)

    print("\n" + "=" * 65)
    print("FINAL TEST METRICS (Text-Enhanced GNN WITHOUT ECFP)")
    print("=" * 65)
    for assay in ASSAYS:
        print(f"{assay:15s} | ROC-AUC: {roc_test[assay]:.4f} | PR-AUC: {pr_test[assay]:.4f}")
    print("-" * 65)
    print(f"{'Mean':15s} | ROC-AUC: {mean_roc_test:.4f} | PR-AUC: {mean_pr_test:.4f}")

    # Save final model together with its prompts, metrics and configuration
    torch.save({
        'model_state_dict': model.state_dict(),
        'assay_prompts': model.assay_prompts.detach().cpu(),
        'assay_descriptions': ASSAY_DESCRIPTIONS,
        'test_metrics': {
            'roc_auc': roc_test,
            'pr_auc': pr_test,
            'mean_roc': mean_roc_test,
            'mean_pr': mean_pr_test
        },
        'config': {
            'use_ecfp': False,
            'use_text_prompts': True,
            'use_descriptors': True
        }
    }, "text_enhanced_no_ecfp_final.pt")

    print("\nModel saved as 'text_enhanced_no_ecfp_final.pt'")

    # Show learned prompt similarities (pairwise cosine between prompt vectors)
    print("\nLearned assay prompt similarities:")
    prompts = model.assay_prompts.detach().cpu()
    similarities = F.cosine_similarity(prompts.unsqueeze(1), prompts.unsqueeze(0), dim=2)

    # Collect each unordered assay pair once (upper triangle)
    similar_pairs = []
    for i in range(len(ASSAYS)):
        for j in range(i + 1, len(ASSAYS)):
            similar_pairs.append((i, j, similarities[i, j].item()))

    similar_pairs.sort(key=lambda x: x[2], reverse=True)
    for i, j, sim in similar_pairs[:5]:  # Top 5 most similar
        print(f"  {ASSAYS[i]:15s} â†” {ASSAYS[j]:15s}: {sim:.3f}")

# Run the training
if __name__ == "__main__":
    main()

***Graph transformer with positional encoding Text+Graph***

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, LayerNorm
from torch_geometric.utils import to_dense_batch, degree
from torch_geometric.data import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score
import warnings

# Silence every warning category for the remainder of the session.
warnings.filterwarnings("ignore")

# ============================
# CONFIGURATION
# ============================
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f" Using device: {device}")

# --- Model architecture ---
HIDDEN_DIM, NUM_HEADS, NUM_LAYERS = 128, 4, 4
DROPOUT = 0.4

# --- Optimisation schedule ---
BATCH_SIZE = 32
LR = 1e-4
WEIGHT_DECAY = 1e-3
EPOCHS = 100
PATIENCE = 15

# --- Prediction targets ---
# Nuclear-receptor (NR-*) assays followed by stress-response (SR-*) assays.
# Position i in this list is used as the display name of label column i.
ASSAYS = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
] + [
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
N_TASKS = len(ASSAYS)

# ============================
# 1. THE ENHANCED LOSS: ASYMMETRIC LOSS (ASL)
# ============================
class AsymmetricLossOptimized(nn.Module):
    """Element-wise Asymmetric Loss (ASL) for multi-label binary targets.

    Positives and negatives get separate focusing exponents, and the negative
    probability is shifted before the log so that confidently-rejected
    negatives contribute nothing.

    Args:
        gamma_neg: Focusing exponent applied to negative targets.
        gamma_pos: Focusing exponent applied to positive targets.
        clip: Margin added to the negative probability ``1 - p`` (then capped
            at 1). Negatives predicted with ``p <= clip`` end up with zero
            loss. ``None`` or ``0`` disables the shift.
        eps: Lower clamp applied inside ``log`` for numerical stability.
        disable_torch_grad_focal_loss: When True the focusing weights are
            treated as constants (detached), so gradients flow only through
            the cross-entropy terms. Defaults to False, which back-propagates
            through the weights as well.

    The forward pass returns the un-reduced loss (same shape as the logits);
    masking and reduction are left to the caller. Targets are expected to be
    0 or 1 -- missing labels (NaN / -1) must be masked out by the caller.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """Compute the per-element loss for logits ``x`` and targets ``y``."""
        prob_pos = torch.sigmoid(x)
        prob_neg = 1 - prob_pos

        # Asymmetric clipping: shift the negative probability up by `clip`.
        if self.clip is not None and self.clip > 0:
            prob_neg = (prob_neg + self.clip).clamp(max=1)

        # Cross-entropy components; each is zero for the opposite class.
        log_pos = y * torch.log(prob_pos.clamp(min=self.eps))
        log_neg = (1 - y) * torch.log(prob_neg.clamp(min=self.eps))

        # Asymmetric focusing weights (the negative one uses the shifted prob).
        weight_pos = (1 - prob_pos) ** self.gamma_pos
        weight_neg = (1 - prob_neg) ** self.gamma_neg

        # Honour the constructor flag, which was previously stored but ignored.
        if self.disable_torch_grad_focal_loss:
            weight_pos = weight_pos.detach()
            weight_neg = weight_neg.detach()

        return -1 * log_pos * weight_pos - 1 * log_neg * weight_neg

# ============================
# 2. MULTI-TASK LEARNABLE WRAPPER
# ============================
class MultiTaskLossWrapper(nn.Module):
    """Combine per-task losses with one learnable log-variance per task.

    With ``s_i = log_vars[i]`` the returned scalar is

        sum(exp(-s_i) * losses[..., i]) + sum(s_i)

    so each task is scaled by ``exp(-s_i)`` and the ``sum(s_i)`` term
    penalises large ``s_i`` (which would otherwise drive every weight to 0).
    ``log_vars`` starts at zero, i.e. every task begins with weight 1.

    Args:
        num_tasks: Number of tasks, and therefore the length of ``log_vars``.
    """

    def __init__(self, num_tasks):
        super().__init__()
        self.num_tasks = num_tasks
        # One learnable scalar per task, initialised to 0 (weight exp(0) = 1).
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """Return the weighted total for ``losses`` ([NumTasks] or [Batch, NumTasks])."""
        task_weights = torch.exp(-self.log_vars).to(losses.device)

        # Data term: every loss element scaled by its task weight, then summed.
        data_term = (task_weights * losses).sum()

        # Penalty term; added once regardless of the batch size.
        penalty = self.log_vars.sum()

        return data_term + penalty

# ============================
# 3. GRAPH TRANSFORMER MODEL
# ============================
class LaplacianPE(nn.Module):
    """Embed per-node degree statistics as an additive positional signal.

    Despite the class name, no Laplacian eigenvectors are computed: the raw
    features are ``deg``, ``deg**2``, ``1/deg`` (0 for isolated nodes) and
    ``log(deg + 1)``, where ``deg`` counts occurrences of each node in
    ``edge_index[0]``. The feature matrix is zero-padded or truncated to ``k``
    columns and projected to ``hidden_dim`` by a linear layer.

    Args:
        k: Width of the feature vector fed to the linear layer.
        hidden_dim: Output embedding size.
    """

    def __init__(self, k=8, hidden_dim=128):
        super().__init__()
        self.k = k
        self.embedding = nn.Linear(k, hidden_dim)

    def forward(self, data):
        """Return a ``[num_nodes, hidden_dim]`` embedding for ``data``."""
        node_deg = degree(data.edge_index[0], data.num_nodes, dtype=torch.float)

        # Reciprocal degree; isolated nodes would give inf, so zero them out.
        inv_deg = node_deg.pow(-1)
        inv_deg[node_deg == 0] = 0

        # Stack the four statistics as columns -> [num_nodes, 4].
        feats = torch.stack(
            [node_deg, node_deg.pow(2), inv_deg, torch.log(node_deg + 1)],
            dim=1,
        )

        # Bring the feature width to exactly k columns.
        width = feats.size(1)
        if width >= self.k:
            feats = feats[:, :self.k]
        else:
            feats = F.pad(feats, (0, self.k - width))

        return self.embedding(feats)

class DeepGraphTransformer(nn.Module):
    """Residual TransformerConv stack with one learnable query prompt per task.

    Node and edge features are linearly projected to ``hidden_dim``, a
    degree-based positional embedding is added to the nodes, and
    ``num_layers`` blocks of TransformerConv -> LayerNorm -> ReLU -> Dropout
    are applied with a residual connection around each block. A learnable
    prompt per task then attends over the nodes of each graph, and a shared
    linear head maps every attended prompt to a single logit.

    Args:
        node_dim: Size of the input node features.
        edge_dim: Size of the input edge features.
        num_tasks: Number of task prompts (one output logit each).
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Attention heads, used by both the graph layers and the
            prompt attention.
        num_layers: Number of TransformerConv blocks.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, node_dim, edge_dim, num_tasks,
                 hidden_dim=128, num_heads=4, num_layers=4, dropout=0.4):
        super().__init__()

        # Input projections and positional embedding.
        self.node_emb = nn.Linear(node_dim, hidden_dim)
        self.edge_emb = nn.Linear(edge_dim, hidden_dim)
        self.pe_enc = LaplacianPE(k=8, hidden_dim=hidden_dim)

        # One learnable query vector per task.
        self.task_prompts = nn.Parameter(torch.randn(num_tasks, hidden_dim))

        # Message-passing blocks; heads are concatenated back to hidden_dim.
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        per_head_dim = hidden_dim // num_heads
        for _ in range(num_layers):
            block = TransformerConv(
                hidden_dim,
                per_head_dim,
                heads=num_heads,
                dropout=dropout,
                edge_dim=hidden_dim,
            )
            self.layers.append(block)
            self.norms.append(LayerNorm(hidden_dim))

        self.dropout = nn.Dropout(dropout)

        # Cross-attention from task prompts (queries) to nodes (keys/values).
        self.prompt_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True, dropout=dropout
        )
        self.prompt_norm = nn.LayerNorm(hidden_dim)

        # Shared head: one logit per prompt.
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        """Return logits of shape ``[num_graphs, num_tasks]`` for a batch."""
        graph_index = data.batch

        node_h = self.node_emb(data.x)
        edge_h = self.edge_emb(data.edge_attr)
        node_h = node_h + self.pe_enc(data)

        for conv, norm in zip(self.layers, self.norms):
            residual = node_h
            node_h = conv(node_h, data.edge_index, edge_attr=edge_h)
            node_h = self.dropout(F.relu(norm(node_h, graph_index)))
            node_h = node_h + residual

        # Pad graphs to a common node count; `present` marks the real nodes.
        dense_nodes, present = to_dense_batch(node_h, graph_index)

        # The same prompts are used as queries for every graph in the batch.
        n_graphs = dense_nodes.size(0)
        queries = self.task_prompts.unsqueeze(0).expand(n_graphs, -1, -1)

        # Padded positions are excluded from attention via key_padding_mask.
        attended, _ = self.prompt_attn(
            queries, dense_nodes, dense_nodes, key_padding_mask=present.logical_not()
        )
        prompt_h = self.prompt_norm(queries + attended)

        # [Batch, Tasks, Dim] -> [Batch, Tasks]
        return self.classifier(prompt_h).squeeze(-1)

# ============================
# 4. DATA LOADING
# ============================
def _ensure_edge_attr_2d(graphs):
    """Give every graph a 2-D ``edge_attr`` tensor (modifies the graphs in place).

    Graphs without edge features receive a single constant feature of 1.0 per
    edge; 1-D edge features are reshaped to ``[num_edges, 1]``. Returns a new
    list holding the same graph objects.
    """
    prepared = []
    for g in graphs:
        edge_attr = getattr(g, 'edge_attr', None)
        if edge_attr is None:
            # Dummy edge features so the edge embedding layer has an input.
            g.edge_attr = torch.ones((g.edge_index.size(1), 1), dtype=torch.float)
        elif edge_attr.dim() == 1:
            g.edge_attr = edge_attr.unsqueeze(1)
        prepared.append(g)
    return prepared


def load_data():
    """Load the pre-built train/val/test graph lists from ``graphs/*.pt``.

    Returns:
        Tuple ``(train_graphs, val_graphs, test_graphs)`` of lists in which
        every graph has a 2-D ``edge_attr``.
    """
    print("Loading Data...")

    def read(path):
        # The files contain pickled graph objects, not plain tensors, so the
        # weights-only unpickler (the default since PyTorch 2.6) rejects them.
        # SECURITY: full unpickling can execute arbitrary code -- only load
        # files you created yourself or obtained from a trusted source.
        return torch.load(path, weights_only=False)

    train_graphs = read("graphs/train_2d.pt")
    val_graphs = read("graphs/val_2d.pt")
    test_graphs = read("graphs/test_2d.pt")

    return (
        _ensure_edge_attr_2d(train_graphs),
        _ensure_edge_attr_2d(val_graphs),
        _ensure_edge_attr_2d(test_graphs),
    )

# ============================
# 5. TRAINING LOOP (BCE WITH PER-TASK POSITIVE WEIGHTING)
# ============================
def calculate_pos_weights(loader, max_weight=30.0):
    """Compute a negative/positive ratio per task for use as BCE ``pos_weight``.

    Labels equal to -1 or NaN are treated as missing and ignored. For each
    task the weight is ``neg / pos`` (1.0 when the task has no positives),
    capped at ``max_weight``; e.g. 1 positive per 10 negatives gives 10, while
    1 positive per 50 negatives gives 30 under the default cap.

    Args:
        loader: Iterable of batches exposing ``y`` and ``num_graphs``; ``y`` is
            reshaped to ``[num_graphs, num_tasks]``.
        max_weight: Upper bound applied to every weight.

    Returns:
        1-D tensor with one weight per task, placed on the global ``device``.
    """
    all_y = torch.cat(
        [batch.y.view(batch.num_graphs, -1) for batch in loader], dim=0
    )

    # Take the task count from the labels instead of assuming 12 columns.
    num_tasks = all_y.size(1)
    weights = []

    print("\nCalculated Task Weights:")
    for i in range(num_tasks):
        task_y = all_y[:, i]

        # Drop missing labels (-1 or NaN) before counting.
        valid = (task_y != -1) & (~torch.isnan(task_y))
        pos = (task_y[valid] == 1).sum().item()
        neg = (task_y[valid] == 0).sum().item()

        w = neg / pos if pos > 0 else 1.0

        # Cap the weight so very rare positives cannot blow up the loss.
        w = min(w, max_weight)
        weights.append(w)
        print(f"  Task {i+1}: {w:.2f}")

    return torch.tensor(weights, dtype=torch.float).to(device)

def train_epoch(model, loader, optimizer, pos_weights):
    """Run one training epoch with masked, positively-weighted BCE.

    Labels that are NaN or -1 are treated as missing: they are excluded from
    the loss through a mask. Batches without a single valid label are skipped
    entirely and do not count towards the reported average.

    Args:
        model: Module called as ``model(batch)`` that returns logits of shape
            ``[num_graphs, num_tasks]``.
        loader: Iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``.
        optimizer: Optimizer stepping ``model``'s parameters.
        pos_weights: Per-task weights passed to BCE as ``pos_weight``.

    Returns:
        Mean loss over the batches that were actually trained on, or ``0.0``
        when no batch contained a valid label.
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0

    for batch in loader:
        batch = batch.to(device)
        y = batch.y.view(batch.num_graphs, -1).float()

        # True where a real 0/1 label is present.
        mask = (~torch.isnan(y)) & (y != -1)
        if mask.sum() == 0:
            # Nothing to learn from: skip before spending a forward pass.
            continue

        # Missing entries get a harmless target; the mask removes them below.
        y_clean = torch.where(mask, y, torch.zeros_like(y))

        optimizer.zero_grad()
        logits = model(batch)

        # Class imbalance is handled by the fixed per-task positive weights.
        loss_ele = F.binary_cross_entropy_with_logits(
            logits,
            y_clean,
            reduction='none',
            pos_weight=pos_weights
        )

        # Average over valid labels only.
        loss = (loss_ele * mask).sum() / mask.sum().clamp(min=1)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        trained_batches += 1

    # Divide by the batches that contributed, not by len(loader): skipped
    # batches used to deflate the average and an empty loader divided by zero.
    if trained_batches == 0:
        return 0.0
    return total_loss / trained_batches

def _gather_predictions(model, loader):
    """Run ``model`` over ``loader`` in eval mode and collect labels and scores.

    Returns:
        ``(y_true, y_pred, y_mask)`` as NumPy arrays of shape
        ``[num_graphs, num_tasks]`` -- raw labels, sigmoid probabilities and a
        boolean mask that is False where the label is NaN or -1 -- or ``None``
        when the loader yields no batches.
    """
    model.eval()
    true_chunks, pred_chunks = [], []

    with torch.no_grad():
        for b in loader:
            b = b.to(device)
            logits = model(b)
            true_chunks.append(b.y.view(b.num_graphs, -1).cpu())
            pred_chunks.append(torch.sigmoid(logits).cpu())

    if not true_chunks:
        return None

    y_true = torch.cat(true_chunks, 0)
    y_pred = torch.cat(pred_chunks, 0)
    y_mask = (~torch.isnan(y_true)) & (y_true != -1)
    return y_true.numpy(), y_pred.numpy(), y_mask.numpy()


def _masked_task_auc(y_true, y_pred, y_mask, task_idx):
    """ROC-AUC of one task over its valid labels, or ``None`` if undefined.

    A task is skipped when it has fewer than 5 valid labels or when only one
    class is present among them.
    """
    valid = y_mask[:, task_idx].astype(bool)
    if valid.sum() < 5 or len(np.unique(y_true[valid, task_idx])) < 2:
        return None
    return roc_auc_score(y_true[valid, task_idx], y_pred[valid, task_idx])


def evaluate(model, loader):
    """Return the mean ROC-AUC over all tasks that can be scored.

    The number of tasks is read from the label matrix. Returns ``0.0`` when no
    task can be scored or the loader is empty.
    """
    gathered = _gather_predictions(model, loader)
    if gathered is None:
        return 0.0
    y_true, y_pred, y_mask = gathered

    rocs = []
    for i in range(y_true.shape[1]):
        auc = _masked_task_auc(y_true, y_pred, y_mask, i)
        if auc is not None:
            rocs.append(auc)

    return np.mean(rocs) if rocs else 0.0


def evaluate_per_prompt(model, loader):
    """Print a per-assay ROC-AUC table plus the mean, and return the scores.

    Label column ``i`` is reported under the name ``ASSAYS[i]``; columns
    without a name, and names without a column, are ignored.

    Returns:
        Dict mapping assay name to ROC-AUC for every task that could be
        scored (empty when none could).
    """
    gathered = _gather_predictions(model, loader)

    print(f"\n{'TASK':<20} | {'AUC':<10}")
    print("-" * 35)

    res = {}
    if gathered is not None:
        y_true, y_pred, y_mask = gathered
        n_cols = min(len(ASSAYS), y_true.shape[1])
        for i, a in enumerate(ASSAYS[:n_cols]):
            auc = _masked_task_auc(y_true, y_pred, y_mask, i)
            if auc is None:
                continue
            res[a] = auc
            print(f"{a:<20} | {auc:.4f}")

    # Avoid np.mean([]) when nothing could be scored; still prints "nan".
    mean_auc = float(np.mean(list(res.values()))) if res else float("nan")
    print("-" * 35)
    print(f"MEAN: {mean_auc:.4f}")

    return res

# ============================
# 6. MAIN
# ============================
def main():
    """Train the graph transformer, checkpoint the best epoch, report test AUC.

    Steps: load the graph splits, derive per-task positive weights from the
    training split, train for ``EPOCHS`` epochs while saving the weights with
    the best validation ROC-AUC, then reload that checkpoint and print the
    per-assay test breakdown.
    """
    train_data, val_data, test_data = load_data()

    train_loader = DataLoader(train_data, BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_data, BATCH_SIZE)
    test_loader = DataLoader(test_data, BATCH_SIZE)

    # Per-task imbalance ratios, computed once from the training labels.
    pos_weights = calculate_pos_weights(train_loader)

    # Feature sizes are read from the first training graph.
    sample = train_data[0]
    model = DeepGraphTransformer(
        sample.x.shape[1], sample.edge_attr.shape[1], N_TASKS,
        hidden_dim=HIDDEN_DIM, num_heads=NUM_HEADS, num_layers=NUM_LAYERS, dropout=DROPOUT
    ).to(device)

    # NOTE(review): weight_decay=1e-2 differs from the WEIGHT_DECAY constant
    # (1e-3) in the configuration; the literal is kept so training behaviour
    # is unchanged -- confirm which value is intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
    # mode='max' because the monitored quantity (validation AUC) should rise.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5)

    print("\nStarting Stabilized Transformer Training...")
    checkpoint_path = "best_stable_transformer.pt"
    # Start below any reachable AUC so the first epoch always writes a
    # checkpoint; starting at 0 could leave a stale file from an earlier run
    # (or no file at all) to be loaded for the final test.
    best_auc = float("-inf")
    checkpoint_saved = False

    # NOTE(review): PATIENCE is declared in the configuration but no early
    # stopping is implemented here; the loop always runs all EPOCHS epochs.
    for epoch in range(1, EPOCHS+1):
        loss = train_epoch(model, train_loader, optimizer, pos_weights)
        val_auc = evaluate(model, val_loader)
        scheduler.step(val_auc)

        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  >>> New Best! (AUC: {best_auc:.4f})")

    # Final Test
    print("\nFinal Test Breakdown:")
    if checkpoint_saved:
        # map_location keeps the load working if the device changed since saving.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was written during this run; evaluating the current weights.")
    evaluate_per_prompt(model, test_loader)

# Entry point for this cell. Inside a notebook kernel `__name__` is "__main__",
# so executing the cell starts the full training run defined by main() above.
if __name__ == "__main__":
    main()