In [None]:
# import necessary libraries

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, Descriptors
import torch
from torch_geometric.data import Data
import deepchem as dc

In [None]:
# Seed NumPy and PyTorch with the same value and make sure the output folder exists.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
os.makedirs("tox21_processed", exist_ok=True)

In [None]:
# Load Tox21 through DeepChem with the "Raw" featurizer, then unpack the
# (tasks, datasets, transformers) triple and the three dataset splits.
tox21_bundle = dc.molnet.load_tox21(featurizer="Raw", data_dir=".", save_dir=".")
tasks, datasets, transformers = tox21_bundle
train_dataset, valid_dataset, test_dataset = datasets

In [None]:
# Map the generic label columns (y1, y2, ...) onto the task names, position by position.
label_names = {f'y{pos}': task for pos, task in enumerate(tasks, start=1)}

train_df = train_dataset.to_dataframe().rename(columns=label_names)
valid_df = valid_dataset.to_dataframe().rename(columns=label_names)
test_df = test_dataset.to_dataframe().rename(columns=label_names)

# save raw data to CSV files
raw_exports = (
    (train_df, "tox21_processed/train_raw.csv"),
    (valid_df, "tox21_processed/valid_raw.csv"),
    (test_df, "tox21_processed/test_raw.csv"),
)
for frame, csv_path in raw_exports:
    frame.to_csv(csv_path, index=False)

In [None]:

# Collect the summary lines first and print them in a single call; the text
# that reaches stdout is the same as printing each line separately.
summary_lines = [
    f"Number of tasks: {len(tasks)}",
    f"Tasks: {tasks}",
    f"Training set size: {len(train_df)}",
    f"Validation set size: {len(valid_df)}",
    f"Test set size: {len(test_df)}",
    "Sample data:",
    "Data loading and initial processing complete.",
]
print("\n".join(summary_lines))

# Last expression of the cell: the notebook renders the first rows as a table.
train_df.head()


In [None]:
# `from rdkit import Chem` does not load the drawing submodule, so `Chem.Draw`
# only resolves if something else imported it first. Import it explicitly.
from rdkit.Chem import Draw

# Per-column count of missing values in every split.
print("Missing values in training set:")
print(train_df.isnull().sum())
print("Missing values in validation set:")
print(valid_df.isnull().sum())
print("Missing values in test set:")
print(test_df.isnull().sum())

# distribution of labels for each task
for task in tasks:
    print(f"Label distribution for task {task} in training set:")
    print(train_df[task].value_counts(dropna=False))
    print(f"Label distribution for task {task} in validation set:")
    print(valid_df[task].value_counts(dropna=False))
    print(f"Label distribution for task {task} in test set:")
    print(test_df[task].value_counts(dropna=False))

# visualize some molecules
print("Sample molecules from training set:")
for smi in train_df['ids'].head(5):
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles returns None for unparsable input; drawing None would raise.
        print(f"Could not parse SMILES: {smi}")
        continue
    display(Draw.MolToImage(mol))

In [None]:

import matplotlib.pyplot as plt
import numpy as np
import os
import os.path as osp
import pandas as pd

def add_no_finding_label(df, tasks, label_name="No-finding"):
    """Add a 0/1 column that flags rows whose task labels sum to zero.

    Missing task labels are treated as 0 before summing. The new column is
    written onto ``df`` itself (the input is modified) and the same frame is
    returned. A one-line summary is printed.

    Args:
        df: DataFrame holding one column per entry of ``tasks``.
        tasks: list of label column names to inspect.
        label_name: name of the column to create or overwrite.

    Returns:
        The same ``df`` object, now containing ``label_name``.
    """
    positives_per_row = df[tasks].fillna(0).values.sum(axis=1)
    df[label_name] = (positives_per_row == 0).astype(int)
    print(f" Added '{label_name}' column — {df[label_name].sum()} samples marked as 'no finding'.")
    return df


def show_split_distribution_tox21_stacked(train_df, val_df, test_df, tasks, original_df=None, save_name="tox21_label_distribution_stacked.png"):
    """Plot, per task, the negative/positive label percentages of each split as stacked bars.

    For every frame the positive percentage of a task is the sum of its
    non-missing labels divided by the number of non-missing labels; the
    negative percentage is the complement. A task with no labels at all
    contributes 0 for both. The figure is written to ``imgs/<save_name>`` and shown.

    Args:
        train_df, val_df, test_df: frames with one column per entry of ``tasks``.
        tasks: label column names, one bar group each.
        original_df: optional fourth frame drawn as the "Orig" bars.
        save_name: file name of the PNG inside the ``imgs`` folder.
    """
    def compute_ratios(df):
        positives, negatives = [], []
        for task in tasks:
            labelled = df[task].notna()
            n_labelled = labelled.sum()
            if n_labelled == 0:
                positives.append(0)
                negatives.append(0)
            else:
                pos = df.loc[labelled, task].sum()
                positives.append((pos / n_labelled) * 100)
                negatives.append(100 - (pos / n_labelled) * 100)
        return positives, negatives

    width = 0.2
    # (frame, horizontal offset, (neg label, neg colour), (pos label, pos colour))
    bar_groups = [
        (train_df, -(width * 1.5), ('Train (neg)', 'skyblue'), ('Train (pos)', 'blue')),
        (val_df, -(width / 2), ('Val (neg)', 'navajowhite'), ('Val (pos)', 'orange')),
        (test_df, width / 2, ('Test (neg)', 'palegreen'), ('Test (pos)', 'green')),
    ]
    if original_df is not None:
        bar_groups.append((original_df, width * 1.5, ('Orig (neg)', 'lightgrey'), ('Orig (pos)', 'grey')))

    ratios = [compute_ratios(frame) for frame, _, _, _ in bar_groups]

    plt.figure(figsize=(16, 6))
    x = np.arange(len(tasks))

    # Negatives form the base of each bar, positives are stacked on top of them.
    for (_, offset, (neg_label, neg_color), (pos_label, pos_color)), (pos_pct, neg_pct) in zip(bar_groups, ratios):
        centers = x + offset
        plt.bar(centers, neg_pct, width, label=neg_label, color=neg_color, alpha=0.7)
        plt.bar(centers, pos_pct, width, bottom=neg_pct, label=pos_label, color=pos_color, alpha=0.8)

    plt.ylabel("Percentage (%)")
    plt.title("Tox21 Label Distribution (Positive vs Negative Ratios per Task)")
    plt.xticks(x, tasks, rotation=45, ha='right')
    plt.legend(ncol=4, bbox_to_anchor=(0.5, -0.2), loc='upper center')
    plt.tight_layout()

    save_dir = "imgs"
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(osp.join(save_dir, save_name), dpi=300, bbox_inches="tight")
    plt.show()

# Label column names used from here on.
tasks = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Tag every row with the split it came from (written onto the frames in place).
for frame, split_tag in ((train_df, 'train'), (valid_df, 'valid'), (test_df, 'test')):
    frame['split'] = split_tag

train_df, valid_df, test_df = [
    add_no_finding_label(frame, tasks) for frame in (train_df, valid_df, test_df)
]
tasks_with_no = [*tasks, "No-finding"]

# Stack them together and move the 'split' column to the front.
tox21_df = pd.concat([train_df, valid_df, test_df], ignore_index=True)
other_cols = [c for c in tox21_df.columns if c != 'split']
tox21_df = tox21_df[['split', *other_cols]]

show_split_distribution_tox21_stacked(train_df, valid_df, test_df, tasks_with_no, original_df=tox21_df)

### Stratified, Round-Robin, Rare-First Sampling

In [None]:
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.preprocessing import MultiLabelBinarizer
import random
import pandas as pd
from collections import defaultdict, Counter

def multilabel_split_tox21(
    df,
    tasks,
    train_size=20000,
    val_size=2500,
    test_size=2500,
    seed=42,
    holdout_test_df=None,
    id_col="ids"
):
    """Split ``df`` into train/val(/test) with multilabel stratification.

    Args:
        df: pooled DataFrame with one column per entry of ``tasks``.
        tasks: label column names used for stratification (NaN counts as 0).
        train_size, val_size, test_size: requested sizes. Train and val are
            turned into fractions of their sum (plus ``test_size`` when no
            holdout is given) and passed to ``MultilabelStratifiedShuffleSplit``,
            so the resulting sizes are proportional rather than exact.
        seed: seed for ``random`` and for the splitter.
        holdout_test_df: optional fixed test set. Its rows are removed from
            ``df`` before splitting and a copy is returned as the test split.
        id_col: column used to recognise held-out rows. When the column is
            missing from either frame the index labels are compared instead.

    Returns:
        ``(train_df, val_df, test_df)``.
    """
    random.seed(seed)
    df = df.reset_index(drop=True).copy()

    # Exclude held-out test set
    if holdout_test_df is not None:
        test_df = holdout_test_df.copy()
        if id_col in df.columns and id_col in test_df.columns:
            # Match on the identifier: index labels are not comparable once the
            # pooled frame has been re-indexed, and matching on them would drop
            # unrelated rows while leaving the test molecules in the pool.
            held_out = df[id_col].isin(set(test_df[id_col])).to_numpy()
        else:
            held_out = df.index.isin(test_df.index)
        df = df[~held_out].reset_index(drop=True)
        print(f"Held out existing test set with {len(test_df)} samples.")
    else:
        test_df = None

    # Convert to numpy multilabel matrix
    Y = df[tasks].fillna(0).astype(int).values
    subsz = train_size + val_size + (0 if test_df is not None else test_size)

    msss = MultilabelStratifiedShuffleSplit(
        n_splits=1,
        train_size=train_size / subsz,
        test_size=val_size / subsz,
        random_state=seed
    )
    tr_i, val_i = next(msss.split(np.zeros(len(df)), Y))
    train_df = df.iloc[tr_i].reset_index(drop=True)
    val_df = df.iloc[val_i].reset_index(drop=True)

    if test_df is None:
        # Drop exactly the rows that went to train/val. The split frames were
        # re-indexed above, so their own indices cannot be used for this.
        remain = df.drop(index=df.index[np.concatenate([tr_i, val_i])])
        test_df = remain.sample(n=min(test_size, len(remain)), random_state=seed).reset_index(drop=True)

    print(f"Stratified Split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")
    return train_df, val_df, test_df

def multilabel_balanced_split_tox21(
    df,
    tasks,
    train_size=20000,
    val_size=2500,
    test_size=2500,
    seed=42,
    holdout_test_df=None,
    id_col="ids"
):
    """Split ``df`` by drawing positives of every task in round-robin order.

    Each pass takes one not-yet-sampled positive row per task. If the positives
    run out before enough rows are collected, the rest is filled with randomly
    chosen remaining rows. The sampled rows are shuffled and cut into train and
    val; without a holdout the test split is sampled from the unused rows.

    Args:
        df: pooled DataFrame with one column per entry of ``tasks``.
        tasks: label column names; a row is "positive" for a task when its value equals 1.
        train_size, val_size, test_size: requested split sizes.
        seed: seed for ``random`` and for the test sampling.
        holdout_test_df: optional fixed test set. Its rows are removed from
            ``df`` before sampling and a copy is returned as the test split.
        id_col: column used to recognise held-out rows. When the column is
            missing from either frame the index labels are compared instead.

    Returns:
        ``(train_df, val_df, test_df)``.
    """
    random.seed(seed)
    df = df.reset_index(drop=True).copy()

    if holdout_test_df is not None:
        test_df = holdout_test_df.copy()
        if id_col in df.columns and id_col in test_df.columns:
            # Match on the identifier: index labels are not comparable once the
            # pooled frame has been re-indexed, and matching on them would drop
            # unrelated rows while leaving the test molecules in the pool.
            held_out = df[id_col].isin(set(test_df[id_col])).to_numpy()
        else:
            held_out = df.index.isin(test_df.index)
        df = df[~held_out].reset_index(drop=True)
        print(f"Held out fixed test set ({len(test_df)} samples).")
    else:
        test_df = None

    total_size = train_size + val_size + (0 if test_df is not None else test_size)

    # Shuffled positive row indices per task.
    label2idxs = {}
    for task in tasks:
        pos_idx = df.index[df[task] == 1].tolist()
        random.shuffle(pos_idx)
        label2idxs[task] = pos_idx

    labels = list(label2idxs)
    label_ptr = {t: 0 for t in labels}
    sampled_idx = set()

    # Round-robin sampling
    while len(sampled_idx) < total_size:
        for label in labels:
            candidates = label2idxs[label]
            ptr = label_ptr[label]
            # Advance to the next positive of this task that is still unsampled.
            while ptr < len(candidates):
                xi = candidates[ptr]
                ptr += 1
                if xi not in sampled_idx:
                    sampled_idx.add(xi)
                    break
            label_ptr[label] = ptr
        # Stop if all labels exhausted
        if all(label_ptr[t] >= len(label2idxs[t]) for t in labels):
            break

    # Fill missing with negatives if needed
    remaining_needed = total_size - len(sampled_idx)
    if remaining_needed > 0:
        remain_pool = sorted(set(df.index) - sampled_idx)
        random.shuffle(remain_pool)
        sampled_idx.update(remain_pool[:remaining_needed])

    # Sort before shuffling so the outcome depends only on the seed.
    sampled_idx = sorted(sampled_idx)
    random.shuffle(sampled_idx)

    train_ids = sampled_idx[:train_size]
    val_ids = sampled_idx[train_size:train_size + val_size]

    train_df = df.loc[train_ids].reset_index(drop=True)
    val_df = df.loc[val_ids].reset_index(drop=True)

    if test_df is None:
        # Drop exactly the rows that went to train/val. The split frames were
        # re-indexed above, so their own indices cannot be used for this.
        remain = df.drop(index=train_ids + val_ids)
        test_df = remain.sample(n=min(test_size, len(remain)), random_state=seed).reset_index(drop=True)

    print(f" Balanced Split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")
    return train_df, val_df, test_df

def multilabel_rare_first_split_tox21(
    df,
    tasks,
    train_size=20000,
    val_size=2500,
    test_size=2500,
    seed=42,
    holdout_test_df=None,
    id_col="ids"
):
    """Split ``df`` by assigning positives of the rarest tasks first.

    Tasks are visited from fewest to most positives. Their positive rows fill
    train until it holds ``train_size`` rows, then val until it holds
    ``val_size`` rows. Remaining capacity is filled with randomly ordered
    unassigned rows. Without a holdout the test split is taken from the rows
    that ended up in neither train nor val.

    Args:
        df: pooled DataFrame with one column per entry of ``tasks``.
        tasks: label column names; a row is "positive" for a task when its value equals 1.
        train_size, val_size, test_size: requested split sizes.
        seed: seed for ``random``.
        holdout_test_df: optional fixed test set. Its rows are removed from
            ``df`` before assignment and a copy is returned as the test split.
        id_col: column used to recognise held-out rows. When the column is
            missing from either frame the index labels are compared instead.

    Returns:
        ``(train_df, val_df, test_df)``.
    """
    random.seed(seed)
    df = df.reset_index(drop=True).copy()

    if holdout_test_df is not None:
        test_df = holdout_test_df.copy()
        if id_col in df.columns and id_col in test_df.columns:
            # Match on the identifier: index labels are not comparable once the
            # pooled frame has been re-indexed, and matching on them would drop
            # unrelated rows while leaving the test molecules in the pool.
            held_out = df[id_col].isin(set(test_df[id_col])).to_numpy()
        else:
            held_out = df.index.isin(test_df.index)
        df = df[~held_out].reset_index(drop=True)
        print(f" Held out {len(test_df)} samples for test.")
    else:
        test_df = None

    # Count positives per label and order the labels from rare to common.
    label_counts = {t: int(df[t].fillna(0).sum()) for t in tasks}
    sorted_labels = sorted(label_counts.items(), key=lambda x: x[1])

    # Map every task to the (shuffled) row indices where its label == 1.
    label2idxs = {t: df.index[df[t] == 1].tolist() for t in tasks}
    for idx_list in label2idxs.values():
        random.shuffle(idx_list)

    train_idx, val_idx = set(), set()

    def _assign(row):
        """Put ``row`` into train, else val; return False once both are full."""
        if len(train_idx) < train_size:
            train_idx.add(row)
            return True
        if len(val_idx) < val_size:
            val_idx.add(row)
            return True
        return False

    splits_full = False
    for task, _ in sorted_labels:
        available = [i for i in label2idxs[task] if i not in train_idx and i not in val_idx]
        random.shuffle(available)
        for i in available:
            if not _assign(i):
                splits_full = True
                break
        if splits_full:
            break

    # Fill remaining capacity from unassigned samples (sorted first so the
    # result depends only on the seed).
    leftovers = sorted(set(df.index) - train_idx - val_idx)
    random.shuffle(leftovers)
    for i in leftovers:
        if not _assign(i):
            break

    train_df = df.loc[sorted(train_idx)].reset_index(drop=True)
    val_df = df.loc[sorted(val_idx)].reset_index(drop=True)

    if test_df is None:
        remain = sorted(set(df.index) - train_idx - val_idx)
        random.shuffle(remain)
        test_df = df.loc[remain[:test_size]].reset_index(drop=True)

    print(f" Rare-first Split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")
    return train_df, val_df, test_df



# Re-split the pooled data three ways, passing the current test frame as the fixed holdout.
split_kwargs = dict(
    train_size=len(train_df),
    val_size=len(valid_df),
    test_size=len(test_df),
    holdout_test_df=test_df,
)
strat_train_df, strat_valid_df, strat_test_df = multilabel_split_tox21(tox21_df, tasks_with_no, **split_kwargs)
bal_train_df, bal_valid_df, bal_test_df = multilabel_balanced_split_tox21(tox21_df, tasks_with_no, **split_kwargs)
rare_train_df, rare_valid_df, rare_test_df = multilabel_rare_first_split_tox21(tox21_df, tasks_with_no, **split_kwargs)

# Give every figure its own file name: with the default `save_name` all three
# calls would write to the same PNG and only the last plot would survive.
show_split_distribution_tox21_stacked(strat_train_df, strat_valid_df, strat_test_df, tasks_with_no, original_df=tox21_df, save_name="tox21_label_distribution_stacked_strat.png")
show_split_distribution_tox21_stacked(bal_train_df, bal_valid_df, bal_test_df, tasks_with_no, original_df=tox21_df, save_name="tox21_label_distribution_stacked_bal.png")
show_split_distribution_tox21_stacked(rare_train_df, rare_valid_df, rare_test_df, tasks_with_no, original_df=tox21_df, save_name="tox21_label_distribution_stacked_rare.png")

### Extract molecular features for training basic machine learning models.

In [None]:
# extract molecular features using RDKit and save to CSV files
def featurize_molecule(smi):
    """Build a flat feature dict for one SMILES string.

    The dict holds five RDKit descriptors (``MolWt``, ``NumHDonors``,
    ``NumHAcceptors``, ``TPSA``, ``LogP``), the MACCS bits 0..166 as
    ``MACCS_<i>`` and a 2048-bit radius-2 Morgan fingerprint as ``Morgan_<i>``.

    Returns:
        The feature dict, or ``None`` when RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None

    # Basic descriptors
    features = {
        'MolWt': Descriptors.MolWt(mol),
        'NumHDonors': Descriptors.NumHDonors(mol),
        'NumHAcceptors': Descriptors.NumHAcceptors(mol),
        'TPSA': Descriptors.TPSA(mol),
        'LogP': Descriptors.MolLogP(mol),
    }

    # MACCS keys
    maccs = MACCSkeys.GenMACCSKeys(mol)
    features.update({f'MACCS_{i}': int(maccs.GetBit(i)) for i in range(167)})

    # Morgan fingerprint
    morgan_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    features.update({f'Morgan_{i}': int(morgan_fp.GetBit(i)) for i in range(2048)})

    return features

def featurize_dataset(df):
    """Return ``df`` with the features of every molecule in ``df['ids']`` appended as columns.

    Rows whose SMILES cannot be featurized contribute an empty record, so
    their feature columns come out as missing values. Both sides are
    re-indexed from 0 before the column-wise concatenation.
    """
    records = []
    for smi in tqdm(df['ids'], desc="Featurizing molecules"):
        feats = featurize_molecule(smi)
        records.append({} if feats is None else feats)

    features_df = pd.DataFrame(records)
    left = df.reset_index(drop=True)
    right = features_df.reset_index(drop=True)
    return pd.concat([left, right], axis=1)

# Featurize the current train/valid/test frames ...
train_features_df = featurize_dataset(train_df)
valid_features_df = featurize_dataset(valid_df)
test_features_df = featurize_dataset(test_df)

# ... and the three re-split variants.
strat_feat_train_df = featurize_dataset(strat_train_df)
strat_feat_valid_df = featurize_dataset(strat_valid_df)
strat_feat_test_df = featurize_dataset(strat_test_df)

bal_feat_train_df = featurize_dataset(bal_train_df)
bal_feat_valid_df = featurize_dataset(bal_valid_df)
bal_feat_test_df = featurize_dataset(bal_test_df)

rare_feat_train_df = featurize_dataset(rare_train_df)
rare_feat_valid_df = featurize_dataset(rare_valid_df)
rare_feat_test_df = featurize_dataset(rare_test_df)

# Write every featurized frame to its CSV file, in the order listed.
featurized_exports = {
    "tox21_processed/train_featurized.csv": train_features_df,
    "tox21_processed/valid_featurized.csv": valid_features_df,
    "tox21_processed/test_featurized.csv": test_features_df,
    "tox21_processed/strat_feat_train_df.csv": strat_feat_train_df,
    "tox21_processed/strat_feat_valid_df.csv": strat_feat_valid_df,
    "tox21_processed/strat_feat_test_df.csv": strat_feat_test_df,
    "tox21_processed/bal_feat_train_df.csv": bal_feat_train_df,
    "tox21_processed/bal_feat_valid_df.csv": bal_feat_valid_df,
    "tox21_processed/bal_feat_test_df.csv": bal_feat_test_df,
    "tox21_processed/rare_feat_train_df.csv": rare_feat_train_df,
    "tox21_processed/rare_feat_valid_df.csv": rare_feat_valid_df,
    "tox21_processed/rare_feat_test_df.csv": rare_feat_test_df,
}
for csv_path, frame in featurized_exports.items():
    frame.to_csv(csv_path, index=False)

In [None]:
# check the new dataframes
print("Featurized training set sample:")
# Last expression of the cell: the notebook renders the first rows as a table.
train_features_df.head()

In [None]:
# DataFrame.info() writes its report itself and returns None, so it is called
# directly; wrapping it in print() would add a stray "None" line after each report.
print("Featurized training set info:")
train_features_df.info()
print("Featurized validation set info:")
valid_features_df.info()
print("Featurized test set info:")
test_features_df.info()
print(" Featurization complete.")

In [None]:
# print the object type columns to verify
# (lists every column of the featurized training frame whose dtype is `object`)
print(train_features_df.select_dtypes(include=['object']).columns)

In [None]:
# check for any remaining missing values: total number of missing cells per frame
missing_reports = (
    ("Missing values in featurized training set:", train_features_df),
    ("Missing values in featurized validation set:", valid_features_df),
    ("Missing values in featurized test set:", test_features_df),
)
for header, frame in missing_reports:
    print(header)
    print(frame.isnull().sum().sum())

In [None]:
from sklearn.preprocessing import StandardScaler
continuous_features = ['MolWt', 'NumHDonors', 'NumHAcceptors', 'TPSA', 'LogP']

# One (train, valid, test) trio per split family. A fresh scaler is fitted on
# the train frame of each trio and then applied to its valid and test frames.
# The scaled values overwrite the columns in place.
split_families = (
    (train_features_df, valid_features_df, test_features_df),
    (strat_feat_train_df, strat_feat_valid_df, strat_feat_test_df),
    (bal_feat_train_df, bal_feat_valid_df, bal_feat_test_df),
    (rare_feat_train_df, rare_feat_valid_df, rare_feat_test_df),
)
for fit_frame, *apply_frames in split_families:
    scaler = StandardScaler()
    fit_frame[continuous_features] = scaler.fit_transform(fit_frame[continuous_features])
    for frame in apply_frames:
        frame[continuous_features] = scaler.transform(frame[continuous_features])


print("Normalization complete.")

In [None]:
# save normalized data to CSV files, in the order listed
normalized_exports = {
    "tox21_processed/train_normalized.csv": train_features_df,
    "tox21_processed/valid_normalized.csv": valid_features_df,
    "tox21_processed/test_normalized.csv": test_features_df,
    "tox21_processed/strat_normalized_train_df.csv": strat_feat_train_df,
    "tox21_processed/strat_normalized_valid_df.csv": strat_feat_valid_df,
    "tox21_processed/strat_normalized_test_df.csv": strat_feat_test_df,
    "tox21_processed/bal_normalized_train_df.csv": bal_feat_train_df,
    "tox21_processed/bal_normalized_valid_df.csv": bal_feat_valid_df,
    "tox21_processed/bal_normalized_test_df.csv": bal_feat_test_df,
    "tox21_processed/rare_normalized_train_df.csv": rare_feat_train_df,
    "tox21_processed/rare_normalized_valid_df.csv": rare_feat_valid_df,
    "tox21_processed/rare_normalized_test_df.csv": rare_feat_test_df,
}
for csv_path, frame in normalized_exports.items():
    frame.to_csv(csv_path, index=False)


### Extract graph features for training graph machine learning models.

In [None]:
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
import torch

# convert data to graph format for graph neural networks
def mol_to_graph_data_obj(mol, df, tasks=tasks, smiles=None, embed_seed=42):
    """Convert an RDKit molecule into a PyTorch Geometric ``Data`` graph.

    Node features (``x``) are, per atom: atomic number, degree, formal charge,
    hybridization (as its integer enum value) and an aromaticity flag.
    Every bond becomes two directed edges. ``edge_attr`` is a 3-way one-hot for
    single/double/triple bonds; any other bond type (aromatic included) is
    encoded as all zeros. ``pos`` holds the 3D coordinates of the conformer.

    When ``mol`` has no conformer, hydrogens are added and a conformer is
    embedded with ETKDG, so in that case the graph contains explicit hydrogen
    atoms as nodes.

    Args:
        mol: RDKit molecule, or ``None``.
        df: frame with an ``'ids'`` column of SMILES and one column per task.
        tasks: label column names copied into ``data.y``.
        smiles: optional SMILES exactly as stored in ``df['ids']``. When given
            it is tried first for the label lookup.
        embed_seed: random seed handed to the ETKDG embedding.

    Returns:
        A ``Data`` object, or ``None`` when ``mol`` is ``None`` or no conformer
        could be embedded. ``data.y`` is all-NaN when no row of ``df`` matches.
    """
    if mol is None:
        return None

    # Keys for the label lookup. The SMILES must be generated BEFORE hydrogens
    # are added: the SMILES of the hydrogen-added molecule spells out every H
    # and therefore does not match the strings stored in df['ids'].
    lookup_keys = [Chem.MolToSmiles(mol)]
    if smiles is not None:
        lookup_keys.insert(0, smiles)

    # ensure molecule has 3D conformer
    if mol.GetNumConformers() == 0:
        mol = Chem.AddHs(mol)
        params = AllChem.ETKDG()
        params.randomSeed = embed_seed  # reproducible coordinates
        if AllChem.EmbedMolecule(mol, params) < 0:
            # Embedding failed: there is no conformer to read coordinates from.
            return None
        # Only optimize when UFF has parameters for every atom of the molecule.
        if AllChem.UFFHasAllMoleculeParams(mol):
            AllChem.UFFOptimizeMolecule(mol)
    conf = mol.GetConformer()

    coords = []
    atom_features_list = []
    for atom in mol.GetAtoms():
        point = conf.GetAtomPosition(atom.GetIdx())
        coords.append([point.x, point.y, point.z])
        atom_features_list.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),  # explicit int: the getter returns an enum
            int(atom.GetIsAromatic()),
        ])
    pos = torch.tensor(coords, dtype=torch.float)

    # standardize node features to tensor
    x = torch.tensor(atom_features_list, dtype=torch.float)

    # edge index and edge features
    bond_one_hot = {
        Chem.rdchem.BondType.SINGLE: [1, 0, 0],
        Chem.rdchem.BondType.DOUBLE: [0, 1, 0],
        Chem.rdchem.BondType.TRIPLE: [0, 0, 1],
    }
    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        encoding = bond_one_hot.get(bond.GetBondType(), [0, 0, 0])
        # one entry per direction
        edge_index.extend([(i, j), (j, i)])
        edge_attr.extend([encoding, encoding])
    if len(edge_index) > 0:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 3), dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Labels: first row of df whose 'ids' equals one of the lookup keys.
    labels = []
    for key in lookup_keys:
        labels = df.loc[df['ids'] == key, tasks].values
        if len(labels) > 0:
            break
    if len(labels) > 0:
        data.y = torch.tensor(np.asarray(labels[0], dtype=float), dtype=torch.float)
    else:
        data.y = torch.tensor([float('nan')] * len(tasks), dtype=torch.float)

    return data

def featurize_graph_dataset(df, tasks=tasks):
    """Convert every molecule of ``df`` into a graph carrying that row's labels.

    Labels are taken from the row being processed (by position), not looked up
    again by SMILES, so every graph gets exactly its own row's values and
    duplicate SMILES keep their individual labels. The source SMILES is stored
    on each graph as ``graph.smiles``.

    Molecules that cannot be parsed or converted are skipped and counted. The
    returned list can therefore be shorter than ``df``; use ``graph.smiles``
    rather than the list position to pair a graph with its row.

    Args:
        df: frame with an ``'ids'`` column of SMILES and one column per task.
        tasks: label column names copied into ``graph.y``.

    Returns:
        List of graph objects, in row order.
    """
    graph_list = []
    skipped = 0
    label_matrix = df[tasks].to_numpy(dtype=float)

    for row_pos, smi in enumerate(tqdm(df['ids'], desc="Featurizing molecules to graphs")):
        mol = Chem.MolFromSmiles(smi)
        try:
            # Hand over only this molecule's row so the callee's own label
            # lookup does not scan the whole frame for every molecule.
            graph = mol_to_graph_data_obj(mol, df.iloc[[row_pos]], tasks=tasks)
        except (ValueError, RuntimeError):
            # Conformer generation / force-field errors: skip this molecule
            # instead of aborting the whole run.
            graph = None

        if graph is None:
            skipped += 1
            continue

        # Authoritative labels for this molecule.
        graph.y = torch.tensor(label_matrix[row_pos], dtype=torch.float)
        graph.smiles = smi
        graph_list.append(graph)

    if skipped:
        print(f"Skipped {skipped} of {len(df)} molecules that could not be converted to graphs.")
    return graph_list

# Graphs for the current train/valid/test frames (default task list).
train_graphs = featurize_graph_dataset(train_df)
valid_graphs = featurize_graph_dataset(valid_df)
test_graphs = featurize_graph_dataset(test_df)

# Graphs for the re-split variants, labelled with the extended task list.
# The stratified graphs are built from the plain split frames like the other
# two variants: only the 'ids' and label columns are read, and those are the
# same in the featurized frames, which are just much wider.
strat_train_graphs = featurize_graph_dataset(strat_train_df, tasks = tasks_with_no)
strat_valid_graphs = featurize_graph_dataset(strat_valid_df, tasks = tasks_with_no)
strat_test_graphs = featurize_graph_dataset(strat_test_df, tasks = tasks_with_no)

bal_train_graphs = featurize_graph_dataset(bal_train_df, tasks = tasks_with_no)
bal_valid_graphs = featurize_graph_dataset(bal_valid_df, tasks = tasks_with_no)
bal_test_graphs = featurize_graph_dataset(bal_test_df, tasks = tasks_with_no)

rare_train_graphs = featurize_graph_dataset(rare_train_df, tasks = tasks_with_no)
rare_valid_graphs = featurize_graph_dataset(rare_valid_df, tasks = tasks_with_no)
rare_test_graphs = featurize_graph_dataset(rare_test_df, tasks = tasks_with_no)

print(f"Number of training graphs: {len(train_graphs)}")
print(f"Number of validation graphs: {len(valid_graphs)}")
print(f"Number of test graphs: {len(test_graphs)}")
print("Sample graph data object:")
print(train_graphs[0])

In [None]:
# print(train_graphs[138].y)

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw
import torch

# pick one molecule and its graph
idx = 138
mol = Chem.MolFromSmiles(train_df['ids'].iloc[idx])
graph = train_graphs[idx]

# extract label: a scalar when y has a single element, otherwise a list
label_value = graph.y.item() if graph.y.numel() == 1 else graph.y.tolist()
print("Label(s):", label_value)

# draw with RDKit
img = Draw.MolToImage(mol, size=(300, 300), legend=f"Label: {label_value}")
# Render inline in the notebook. Image.show() hands the picture to an external
# viewer instead, which shows nothing in the notebook output.
display(img)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

def visualize_graph_by_label(data, title_prefix="Graph"):
    """Draw a graph with networkx, colouring the nodes by its label.

    A single-element ``data.y`` is read as a scalar: all nodes are red when it
    is above 0.5 and blue otherwise. When ``data.y`` has several elements the
    label is the list of values and the nodes are drawn gray. The label is
    also shown in the plot title after ``title_prefix``.
    """
    G = to_networkx(data, to_undirected=True)

    # get label
    label = data.y.item() if data.y.numel() == 1 else data.y.tolist()

    # choose color based on label
    if not isinstance(label, (int, float)):
        color = "gray"
    elif label > 0.5:
        color = "red"  # binary coloring
    else:
        color = "blue"

    draw_options = dict(
        with_labels=True,
        node_color=color,
        edge_color="gray",
        node_size=600,
        font_weight="bold",
    )
    plt.figure(figsize=(5, 5))
    nx.draw(G, **draw_options)
    plt.title(f"{title_prefix} | Label: {label}")
    plt.show()

# Example: the title now names the graph that is actually drawn (index 138).
visualize_graph_by_label(train_graphs[138], title_prefix="Training Graph 138")

In [None]:
# Display the repr of one graph object as the cell output.
train_graphs[138]

In [None]:
# save graph data objects using torch.save, one file per graph list
graph_exports = {
    "tox21_processed/train_graphs.pt": train_graphs,
    "tox21_processed/valid_graphs.pt": valid_graphs,
    "tox21_processed/test_graphs.pt": test_graphs,
    "tox21_processed/strat_train_graphs.pt": strat_train_graphs,
    "tox21_processed/strat_valid_graphs.pt": strat_valid_graphs,
    "tox21_processed/strat_test_graphs.pt": strat_test_graphs,
    "tox21_processed/bal_train_graphs.pt": bal_train_graphs,
    "tox21_processed/bal_valid_graphs.pt": bal_valid_graphs,
    "tox21_processed/bal_test_graphs.pt": bal_test_graphs,
    "tox21_processed/rare_train_graphs.pt": rare_train_graphs,
    "tox21_processed/rare_valid_graphs.pt": rare_valid_graphs,
    "tox21_processed/rare_test_graphs.pt": rare_test_graphs,
}
for out_path, graphs in graph_exports.items():
    torch.save(graphs, out_path)

In [None]:
import deepchem as dc
from rdkit import Chem
from collections import Counter
import matplotlib.pyplot as plt

# NOTE(review): this cell reloads the dataset and rebinds tasks / train_df /
# valid_df / test_df, replacing the frames that earlier cells had extended.
tox21_bundle = dc.molnet.load_tox21(featurizer="Raw", data_dir=".", save_dir=".")
tasks, datasets, transformers = tox21_bundle
train_dataset, valid_dataset, test_dataset = datasets

# Convert to pandas DataFrames
train_df = train_dataset.to_dataframe()
valid_df = valid_dataset.to_dataframe()
test_df = test_dataset.to_dataframe()

# Combine all Mol objects
all_mols = [*train_df["X"], *valid_df["X"], *test_df["X"]]

# Count atom occurrences (hydrogens and missing molecules are ignored)
atom_counter = Counter()
for mol in all_mols:
    if mol is not None:
        atom_counter.update(
            symbol
            for symbol in (atom.GetSymbol() for atom in mol.GetAtoms())
            if symbol != "H"
        )

# Sort by frequency, most frequent first
atoms, counts = zip(*atom_counter.most_common())

# Plot
plt.figure(figsize=(10, 6))
plt.bar(atoms, counts, color='skyblue')
plt.xlabel("Atom Type", fontsize=12)
plt.ylabel("Total Count (Train + Valid + Test)", fontsize=12)
plt.title("Tox21 Atom Type Frequency", fontsize=14)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [None]:
from rdkit import Chem

# Element symbols treated as "rare" by the checks below.
rare_atoms = {
    "Br", "I", "Si", "B", "Se", "Sn", "Te", "Zr",
    "Cr", "V", "Ti", "Ge", "As", "Cu", "Zn", "Pd",
    "Ru", "Pt", "Mn", "Ni", "Co", "Na", "K", "Ca",
    "Mg", "Li", "Al", "Bi", "Sb", "Hg", "Pb",
}

def has_rare_atoms(mol):
    """Return True when any atom symbol of ``mol`` is in ``rare_atoms``; False for ``None``."""
    if mol is None:
        return False
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in rare_atoms:
            return True
    return False

# Combine all Mol objects again
all_mols = [*train_df["X"], *valid_df["X"], *test_df["X"]]

# Count how many contain any rare atom
total_count = len(all_mols)
rare_count = sum(map(has_rare_atoms, all_mols))

print(f"Total molecules: {total_count}")
print(f"Molecules with Br or rarer atoms: {rare_count}")
print(f"Fraction: {rare_count/total_count:.4f}")


In [None]:
import pandas as pd
import numpy as np
from rdkit import Chem
# Element symbols treated as "rare" by the checks below.
# (Same set and helper as defined in an earlier cell; repeated so this cell runs on its own.)
rare_atoms = {
    "Br", "I", "Si", "B", "Se", "Sn", "Te", "Zr",
    "Cr", "V", "Ti", "Ge", "As", "Cu", "Zn", "Pd",
    "Ru", "Pt", "Mn", "Ni", "Co", "Na", "K", "Ca",
    "Mg", "Li", "Al", "Bi", "Sb", "Hg", "Pb",
}

def has_rare_atoms(mol):
    """Return True when any atom symbol of ``mol`` is in ``rare_atoms``; False for ``None``."""
    if mol is None:
        return False
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in rare_atoms:
            return True
    return False

# Combine train/valid/test
all_df = pd.concat(
    [split.to_dataframe() for split in (train_dataset, valid_dataset, test_dataset)],
    ignore_index=True,
)

# Extract relevant columns: every column whose name starts with "y", as float32
y_cols = [c for c in all_df.columns if c.startswith("y")]
mols = list(all_df["X"])
y = all_df[y_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=np.float32)

# Filter out molecules with rare atoms
rare_mask = np.array([has_rare_atoms(m) for m in mols])
filtered_y = y[~rare_mask]

# Identify "no finding" (all zeros or NaNs)
zero_or_missing = np.isnan(filtered_y) | (filtered_y == 0)
is_no_finding = np.all(zero_or_missing, axis=1)

# Print summary
total = len(filtered_y)
no_findings = np.sum(is_no_finding)
print(f"Total molecules (without rare atoms): {total}")
print(f"'No-finding' molecules (all y* = 0 or NaN): {no_findings}")
print(f"Fraction: {no_findings/total:.3f}")


In [None]:

# Element symbols treated as "rare" by the checks below.
# (Same set and helper as defined in earlier cells; repeated so this cell runs on its own.)
rare_atoms = {
    "Br", "I", "Si", "B", "Se", "Sn", "Te", "Zr",
    "Cr", "V", "Ti", "Ge", "As", "Cu", "Zn", "Pd",
    "Ru", "Pt", "Mn", "Ni", "Co", "Na", "K", "Ca",
    "Mg", "Li", "Al", "Bi", "Sb", "Hg", "Pb",
}

def has_rare_atoms(mol):
    """Return True when any atom symbol of ``mol`` is in ``rare_atoms``; False for ``None``."""
    if mol is None:
        return False
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in rare_atoms:
            return True
    return False

def compute_split_stats(dataset, split_name):
    """Print the share of "no-finding" molecules in one split.

    Molecules flagged by ``has_rare_atoms`` are removed first. A molecule
    counts as "no finding" when every ``y*`` label is 0 or NaN. The function
    prints the totals and the fraction; it does not return a value.

    Args:
        dataset: object whose ``to_dataframe()`` yields an ``X`` column of
            molecules and label columns whose names start with ``y``.
        split_name: label used in the printed header (upper-cased).
    """
    df = dataset.to_dataframe()
    y_cols = [c for c in df.columns if c.startswith("y")]
    y = df[y_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=np.float32)

    # rare-atom mask, then keep only the molecules without rare atoms
    rare_mask = np.array([has_rare_atoms(m) for m in list(df["X"])])
    filtered_y = y[~rare_mask]

    # molecules with no findings (all y* = 0 or NaN)
    zero_or_missing = np.isnan(filtered_y) | (filtered_y == 0)
    is_no_finding = np.all(zero_or_missing, axis=1)

    total = len(filtered_y)
    no_findings = np.sum(is_no_finding)
    print(f"=== {split_name.upper()} SPLIT ===")
    print(f"Total molecules (no rare atoms): {total}")
    print(f"'No-finding' molecules: {no_findings}")
    print(f"Fraction: {no_findings/total:.3f}\n")

# ---- Run for each split ----
for split_dataset, split_label in (
    (train_dataset, "train"),
    (valid_dataset, "valid"),
    (test_dataset, "test"),
):
    compute_split_stats(split_dataset, split_label)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem

# Element symbols treated as "rare" by the checks below.
# (Same set and helper as defined in earlier cells; repeated so this cell runs on its own.)
rare_atoms = {
    "Br", "I", "Si", "B", "Se", "Sn", "Te", "Zr",
    "Cr", "V", "Ti", "Ge", "As", "Cu", "Zn", "Pd",
    "Ru", "Pt", "Mn", "Ni", "Co", "Na", "K", "Ca",
    "Mg", "Li", "Al", "Bi", "Sb", "Hg", "Pb",
}

def has_rare_atoms(mol):
    """Return True when any atom symbol of ``mol`` is in ``rare_atoms``; False for ``None``."""
    if mol is None:
        return False
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in rare_atoms:
            return True
    return False

def label_distribution(dataset, split_name):
    """Print and plot, per label column, the share of positives among molecules with a finding.

    Molecules flagged by ``has_rare_atoms`` are removed, then only molecules
    with at least one label equal to 1 are kept. For every ``y*`` column the
    number of positives and its fraction of the kept molecules are printed as
    a table and drawn as a bar chart.

    Args:
        dataset: object whose ``to_dataframe()`` yields an ``X`` column of
            molecules and label columns whose names start with ``y``.
        split_name: label used in the printed header and the plot title.
    """
    df = dataset.to_dataframe()
    y_cols = [c for c in df.columns if c.startswith("y")]
    y = df[y_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=np.float32)

    # remove rare-atom molecules
    rare_mask = np.array([has_rare_atoms(m) for m in list(df["X"])])
    filtered_y = y[~rare_mask]

    # keep only molecules with at least one positive finding
    is_positive = filtered_y == 1
    y_with_findings = filtered_y[np.any(is_positive, axis=1)]

    # compute per-label positive counts
    total = len(y_with_findings)
    pos_counts = np.nansum(y_with_findings == 1, axis=0)
    fractions = pos_counts / total

    print(f"=== {split_name.upper()} SPLIT ===")
    print(f"Molecules with findings (no rare atoms): {total}")
    print(pd.DataFrame({'Assay': y_cols, 'Positives': pos_counts, 'Fraction': fractions}))

    # plot
    plt.figure(figsize=(8,4))
    plt.bar(y_cols, fractions, color="mediumseagreen")
    plt.xticks(rotation=45)
    plt.ylabel("Fraction positive")
    plt.title(f"{split_name.title()} Split – Label Distribution (with findings, no rare atoms)")
    plt.tight_layout()
    plt.show()

# Run for each split
for split_dataset, split_label in (
    (train_dataset, "train"),
    (valid_dataset, "valid"),
    (test_dataset, "test"),
):
    label_distribution(split_dataset, split_label)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from rdkit import Chem

def molecule_length(mol):
    """Count the atoms of `mol` whose symbol is not "H".

    Returns `np.nan` when `mol` is None, otherwise an int.
    """
    if mol is None:
        return np.nan
    heavy_atoms = 0
    for atom in mol.GetAtoms():
        if atom.GetSymbol() != "H":
            heavy_atoms += 1
    return heavy_atoms

def plot_molecule_length_distribution(dataset, split_name):
    """Plot a histogram of `molecule_length` values for one split and print summary stats.

    Molecules are read from the "X" column of `dataset.to_dataframe()`; `None`
    entries are skipped. Mean and median are drawn as vertical lines, and
    mean / median / min / max are printed after the figure is shown.
    """
    frame = dataset.to_dataframe()
    lengths = np.array(
        [molecule_length(mol) for mol in list(frame["X"]) if mol is not None],
        dtype=float,
    )
    mean_len = np.nanmean(lengths)
    median_len = np.nanmedian(lengths)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(lengths, bins=25, color="cornflowerblue", edgecolor="black", alpha=0.75)
    ax.set_title(f"{split_name.title()} Split Molecule Length Distribution", fontsize=13)
    ax.set_xlabel("Number of heavy atoms (no H)", fontsize=11)
    ax.set_ylabel("Count", fontsize=11)
    ax.axvline(mean_len, color="red", linestyle="--", label=f"Mean = {mean_len:.1f}")
    ax.axvline(median_len, color="orange", linestyle=":", label=f"Median = {median_len:.1f}")
    ax.legend()
    fig.tight_layout()
    plt.show()

    shortest = np.nanmin(lengths)
    longest = np.nanmax(lengths)
    print(f"{split_name.title()} split  mean length: {mean_len:.2f}, median: {median_len:.2f}, "
          f"min: {shortest}, max: {longest}")

# Run for each split, in train / valid / test order
for _dataset, _label in (
    (train_dataset, "train"),
    (valid_dataset, "valid"),
    (test_dataset, "test"),
):
    plot_molecule_length_distribution(_dataset, _label)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from rdkit import Chem

def molecule_edges(mol):
    """Return `mol.GetNumBonds()`, or `np.nan` when `mol` is None."""
    return np.nan if mol is None else mol.GetNumBonds()

def plot_edge_distribution(dataset, split_name):
    """Plot a histogram of `molecule_edges` values for one split and print summary stats.

    Molecules are read from the "X" column of `dataset.to_dataframe()`; `None`
    entries are skipped. Mean and median are drawn as vertical lines, and
    mean / median / min / max are printed after the figure is shown.
    """
    frame = dataset.to_dataframe()
    bond_counts = np.array(
        [molecule_edges(mol) for mol in list(frame["X"]) if mol is not None],
        dtype=float,
    )
    mean_edges = np.nanmean(bond_counts)
    median_edges = np.nanmedian(bond_counts)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(bond_counts, bins=25, color="lightcoral", edgecolor="black", alpha=0.75)
    ax.set_title(f"{split_name.title()} Split Edge (Bond) Count Distribution", fontsize=13)
    ax.set_xlabel("Number of bonds (edges)", fontsize=11)
    ax.set_ylabel("Count", fontsize=11)
    ax.axvline(mean_edges, color="red", linestyle="--", label=f"Mean = {mean_edges:.1f}")
    ax.axvline(median_edges, color="orange", linestyle=":", label=f"Median = {median_edges:.1f}")
    ax.legend()
    fig.tight_layout()
    plt.show()

    fewest = np.nanmin(bond_counts)
    most = np.nanmax(bond_counts)
    print(f"{split_name.title()} split mean edges: {mean_edges:.2f}, "
          f"median: {median_edges:.2f}, min: {fewest}, max: {most}")


# Run for each split, in train / valid / test order
for _dataset, _label in (
    (train_dataset, "train"),
    (valid_dataset, "valid"),
    (test_dataset, "test"),
):
    plot_edge_distribution(_dataset, _label)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem


# Element symbols that mark a molecule as containing a "rare" atom
# (listed alphabetically). Defined in this cell so the cell runs on its own.
rare_atoms = {
    "Al", "As", "B", "Bi", "Br", "Ca", "Co", "Cr",
    "Cu", "Ge", "Hg", "I", "K", "Li", "Mg", "Mn",
    "Na", "Ni", "Pb", "Pd", "Pt", "Ru", "Sb", "Se",
    "Si", "Sn", "Te", "Ti", "V", "Zn", "Zr",
}


def has_rare_atoms(mol):
    """Return True when any atom of `mol` has a symbol listed in `rare_atoms`.

    `mol` must expose `GetAtoms()`, whose items expose `GetSymbol()`.
    A `None` molecule is reported as having no rare atoms (returns False).
    """
    if mol is None:
        return False
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in rare_atoms:
            return True
    return False

def molecule_length(mol):
    """Count the atoms of `mol` whose symbol is not "H".

    Returns `np.nan` when `mol` is None, otherwise an int.
    """
    if mol is None:
        return np.nan
    heavy_atoms = 0
    for atom in mol.GetAtoms():
        if atom.GetSymbol() != "H":
            heavy_atoms += 1
    return heavy_atoms

def molecule_edges(mol):
    """Return `mol.GetNumBonds()`, or `np.nan` when `mol` is None."""
    return np.nan if mol is None else mol.GetNumBonds()

def filter_and_plot(dataset, split_name):
    """Plot heavy-atom and bond-count histograms for the filtered molecules of one split.

    Filtering is done in two stages: molecules flagged by `has_rare_atoms` are
    removed, then only molecules with at least one label equal to 1 are kept.
    Mean and median of both quantities are printed and marked on the plots.

    Parameters
    ----------
    dataset
        Object exposing `to_dataframe()`; the frame must have an "X" column
        (molecules) and label columns whose names start with "y".
    split_name : str
        Name used in the printed header and in the plot titles.
    """
    frame = dataset.to_dataframe()
    all_mols = list(frame["X"])
    label_cols = [col for col in frame.columns if col.startswith("y")]
    labels = (
        frame[label_cols]
        .apply(pd.to_numeric, errors="coerce")
        .to_numpy(dtype=np.float32)
    )

    # Stage 1: discard molecules that contain a rare atom.
    is_rare = np.array([has_rare_atoms(mol) for mol in all_mols])
    kept_mols = [mol for mol, rare in zip(all_mols, is_rare) if not rare]
    kept_labels = labels[~is_rare]

    # Stage 2: keep molecules with at least one positive (== 1) label.
    positive_rows = np.any(kept_labels == 1, axis=1)
    mols_final = [mol for mol, keep in zip(kept_mols, positive_rows) if keep]

    # Heavy-atom and bond counts; None molecules are skipped for the statistics.
    present = [mol for mol in mols_final if mol is not None]
    lengths = np.array([molecule_length(mol) for mol in present], dtype=float)
    edges = np.array([molecule_edges(mol) for mol in present], dtype=float)

    mean_len, mean_edges = np.nanmean(lengths), np.nanmean(edges)
    median_len, median_edges = np.nanmedian(lengths), np.nanmedian(edges)

    print(f"=== {split_name.upper()} SPLIT ===")
    print(f"Molecules after filtering: {len(mols_final)}")
    print(f"Mean length: {mean_len:.2f}, edges: {mean_edges:.2f}")
    print(f"Median length: {median_len:.2f}, edges: {median_edges:.2f}\n")

    def _histogram(values, mean, median, color, title, xlabel):
        # One 25-bin histogram with mean / median marker lines and a legend.
        fig, ax = plt.subplots(figsize=(7, 4))
        ax.hist(values, bins=25, color=color, edgecolor="black", alpha=0.75)
        ax.set_title(title, fontsize=13)
        ax.set_xlabel(xlabel, fontsize=11)
        ax.set_ylabel("Count", fontsize=11)
        ax.axvline(mean, color="red", linestyle="--", label=f"Mean = {mean:.1f}")
        ax.axvline(median, color="orange", linestyle=":", label=f"Median = {median:.1f}")
        ax.legend()
        fig.tight_layout()
        plt.show()

    _histogram(
        lengths, mean_len, median_len, "steelblue",
        f"{split_name.title()} – Molecule Length (no rare atoms, with findings)",
        "Number of heavy atoms",
    )
    _histogram(
        edges, mean_edges, median_edges, "salmon",
        f"{split_name.title()} – Edge Count (no rare atoms, with findings)",
        "Number of bonds",
    )


# Run for each split, in train / valid / test order
for _dataset, _label in (
    (train_dataset, "train"),
    (valid_dataset, "valid"),
    (test_dataset, "test"),
):
    filter_and_plot(_dataset, _label)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_12xagon_scatter(df, exclude_no_finding=True, n_samples=None):
    """Scatter molecules inside a regular polygon whose vertices are the label columns.

    Each label column (name starting with "y") gets one vertex on the unit
    circle. A molecule is placed at the weighted average of those vertices,
    with weights equal to its labels divided by their sum, so a molecule with a
    single positive label sits on that label's vertex and multi-label
    molecules fall inside the polygon.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with label columns whose names start with "y". Missing or
        non-numeric labels are treated as 0.
    exclude_no_finding : bool, default True
        Drop rows whose labels are all 0 or missing before sampling. Rows
        whose labels sum to 0 are never drawn either way, because they have
        no defined position.
    n_samples : int or None, default None
        When given, at most this many rows are drawn at random (without
        replacement, using NumPy's global RNG). `None` keeps every row.

    Raises
    ------
    ValueError
        If `df` has no column whose name starts with "y".
    """
    task_names = [
        "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
        "NR-ER", "NR-ER-LBD", "NR-PPAR-γ",
        "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
    ]

    y_cols = [c for c in df.columns if c.startswith("y")]
    if not y_cols:
        # Without label columns there is no polygon to draw; fail with a clear
        # message instead of an IndexError further down.
        raise ValueError(
            "plot_12xagon_scatter: df has no label columns starting with 'y'"
        )
    y = df[y_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=np.float32)

    if exclude_no_finding:
        mask = ~np.all(np.isnan(y) | (y == 0), axis=1)
        y = y[mask]

    # Sub-sample only when a sample size was requested. Previously
    # min(n_samples, len(y)) ran unconditionally (and only inside the
    # exclude_no_finding branch), so the default n_samples=None raised TypeError.
    if n_samples is not None and len(y) > 0:
        idx = np.random.choice(len(y), size=min(n_samples, len(y)), replace=False)
        y = y[idx]

    y = np.nan_to_num(y, nan=0)

    # One vertex direction per label column.
    N = y.shape[1]
    angles = np.linspace(0, 2*np.pi, N, endpoint=False)
    cos_a, sin_a = np.cos(angles), np.sin(angles)

    # Rows whose labels sum to 0 cannot be normalised and are skipped.
    row_sums = y.sum(axis=1)
    placeable = row_sums != 0
    weights = y[placeable] / row_sums[placeable, None]
    x_coords = weights @ cos_a
    y_coords = weights @ sin_a

    plt.figure(figsize=(6,6))
    plt.scatter(
        x_coords, y_coords, s=20, alpha=0.6,
        color='mediumseagreen', edgecolors='black', linewidths=0.4
    )
    plt.title("Tox21 12-Label Dodecagon Scatter (molecules with findings only)", fontsize=13)
    plt.axis('equal')
    plt.xlabel("x")
    plt.ylabel("y")

    # Draw the polygon boundary (closed by repeating the first vertex).
    boundary_x = cos_a.tolist() + [cos_a[0]]
    boundary_y = sin_a.tolist() + [sin_a[0]]
    plt.plot(boundary_x, boundary_y, color='gray', linestyle='--', linewidth=1)

    # Label vertices with the task names; fall back to the column names when
    # the number of label columns does not match the hard-coded list.
    vertex_labels = task_names if len(task_names) == N else y_cols
    for i, (bx, by) in enumerate(zip(boundary_x[:-1], boundary_y[:-1])):
        plt.text(1.15*bx, 1.15*by, vertex_labels[i], ha='center', va='center', fontsize=9)

    plt.tight_layout()
    plt.show()

# Rebuild the frame from the dataset so the label columns carry the "y<N>" names
# produced by to_dataframe(); plot_12xagon_scatter selects label columns by the "y" prefix.
train_df = train_dataset.to_dataframe()
plot_12xagon_scatter(train_df, exclude_no_finding=True, n_samples=None)



In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 

def plot_12xagon_3d_frequency(df, exclude_no_finding=True, bins=40):
    """Draw a 3D surface of how often molecules land at each position in the label polygon.

    Each label column (name starting with "y") gets one vertex on the unit
    circle. A molecule is placed at the average of those vertices weighted by
    its labels divided by their sum; rows whose labels sum to 0 are placed at
    the origin. Positions are binned with a 2D histogram whose counts form
    the surface height.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with label columns whose names start with "y". Missing or
        non-numeric labels are treated as 0.
    exclude_no_finding : bool, default True
        Drop rows whose labels are all 0 before computing positions.
    bins : int, default 40
        Passed to `np.histogram2d`.
    """
    task_names = [
        "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
        "NR-ER", "NR-ER-LBD", "NR-PPAR-γ",
        "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
    ]

    label_cols = [col for col in df.columns if col.startswith("y")]
    labels = np.nan_to_num(
        df[label_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=np.float32),
        nan=0,
    )

    if exclude_no_finding:
        # Keep rows that have at least one non-zero label.
        labels = labels[np.any(labels != 0, axis=1)]

    # One vertex direction per label column.
    n_labels = labels.shape[1]
    angles = np.linspace(0, 2*np.pi, n_labels, endpoint=False)
    cos_a, sin_a = np.cos(angles), np.sin(angles)

    # Position of every row inside the polygon.
    x_coords = np.empty(len(labels))
    y_coords = np.empty(len(labels))
    for row_idx, row in enumerate(labels):
        row_total = row.sum()
        if row_total > 0:
            weights = row / row_total
        else:
            weights = np.zeros_like(row)
        x_coords[row_idx] = np.sum(weights * cos_a)
        y_coords[row_idx] = np.sum(weights * sin_a)

    # Bin the positions; bin centres become the surface grid.
    counts, x_edges, y_edges = np.histogram2d(x_coords, y_coords, bins=bins)
    x_mid = 0.5 * (x_edges[:-1] + x_edges[1:])
    y_mid = 0.5 * (y_edges[:-1] + y_edges[1:])
    grid_x, grid_y = np.meshgrid(x_mid, y_mid)
    grid_z = counts.T  # transpose so rows follow y, matching meshgrid's layout

    fig = plt.figure(figsize=(8,6))
    ax = fig.add_subplot(111, projection="3d")
    surface = ax.plot_surface(grid_x, grid_y, grid_z, cmap="YlGnBu", edgecolor="none", alpha=0.95)
    fig.colorbar(surface, shrink=0.6, aspect=10, label="Molecule Frequency")

    # Polygon outline at z = 0, closed by repeating the first vertex.
    ring_x = cos_a.tolist() + [cos_a[0]]
    ring_y = sin_a.tolist() + [sin_a[0]]
    ax.plot(ring_x, ring_y, zs=0, color='gray', linestyle='--', linewidth=1)
    for vertex in range(n_labels):
        ax.text(
            1.2*ring_x[vertex], 1.2*ring_y[vertex], 0, task_names[vertex],
            fontsize=9, ha='center', va='center',
        )

    ax.set_zlabel("Frequency")
    ax.set_title("Tox21 12-Label Dodecagon 3D Frequency Plot", fontsize=13)
    plt.tight_layout()
    plt.show()


# Build every frame straight from its dataset. plot_12xagon_3d_frequency selects
# label columns by the "y" prefix, but the valid_df / test_df created at the top of
# the notebook have those columns renamed to the task names, so passing them would
# leave the function with no label columns.
# NOTE(review): assumes no intermediate cell rebuilt valid_df / test_df; building
# fresh frames here is correct either way and does not rebind those names.
train_df = train_dataset.to_dataframe()
plot_12xagon_3d_frequency(train_df, exclude_no_finding=True, bins=50)
plot_12xagon_3d_frequency(valid_dataset.to_dataframe(), exclude_no_finding=True, bins=50)
plot_12xagon_3d_frequency(test_dataset.to_dataframe(), exclude_no_finding=True, bins=50)
