# In [None]:
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
from sklearn.model_selection import KFold
from chemprop import data, featurizers, models, nn
import torch
import chemprop.nn.metrics as chem_metrics
# Custom function for scaffold splitting (use your own function)
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import MolFromSmiles
import concurrent.futures
import logging
from chemprop.data import datapoints, dataloader
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
from chemprop.data import MoleculeDataset
from chemprop.data.datapoints import MoleculeDatapoint
num_workers: int = 12  # worker processes for each chemprop DataLoader built below
# Define the function to create MoleculeDatapoints
def create_molecule_datapoints(smiles, targets):
    return [
        MoleculeDatapoint.from_smi(smi, y=target.tolist())
        for smi, target in zip(smiles, targets)
    ]

# The data arrive pre-split into separate train / validation / external-test CSVs.
df_train = pd.read_csv('./NIHDataset/train_df_class.csv')
df_test = pd.read_csv('./NIHDataset/test_df_class.csv')
df_val = pd.read_csv('./NIHDataset/val_df_class.csv')

# Column names
smiles_column = 'smiles'
# A list (not a bare string), so df[target_columns].values is 2-D: (N, n_targets).
target_columns = ['LD50_class']

# Extract SMILES and targets per split (no splitting happens here).
train_smis = df_train[smiles_column].values
train_targets = df_train[target_columns].values
val_smis = df_val[smiles_column].values
val_targets = df_val[target_columns].values
smis_test = df_test[smiles_column].values  # SMILES from test (external)
ys_test = df_test[target_columns].values  # Targets from test (external)

# One list of datapoints per pre-split partition, built in train/val/test order.
train_dp, val_dp, test_dp = (
    create_molecule_datapoints(split_smis, split_ys)
    for split_smis, split_ys in (
        (train_smis, train_targets),
        (val_smis, val_targets),
        (smis_test, ys_test),
    )
)

# Wrap each datapoint list in a chemprop MoleculeDataset.
train_dset, val_dset, test_dset = (
    MoleculeDataset(split_dps) for split_dps in (train_dp, val_dp, test_dp)
)

# Seed the global RNGs via Lightning before the model is constructed;
# workers=True also derives seeds for dataloader worker processes.
pl.seed_everything(42, workers=True)

# Dataloaders

# Only the training loader shuffles; val/test keep file order so that
# predictions stay row-aligned with their targets (e.g. ys_test).
train_loader = data.build_dataloader(train_dset, shuffle=True, num_workers=num_workers)
val_loader, test_loader = (
    data.build_dataloader(eval_dset, shuffle=False, num_workers=num_workers)
    for eval_dset in (val_dset, test_dset)
)

# NOTE: creation order (ffn first, then mp) matters for reproducibility,
# since modules with parameters draw initial weights from the seeded RNG.

# Prediction head: 4-class classifier, no output transform.
ffn = nn.MulticlassClassificationFFN(n_classes=4, output_transform=None)
batch_norm = False
# Bond-based message passing encoder: depth 4, dropout 0.1.
mp = nn.BondMessagePassing(dropout=0.1, depth=4)
agg = nn.NormAggregation()

# Full model, tracking multiclass MCC (monitored by the checkpoint callback).
mpnn = models.MPNN(
    mp,
    agg,
    ffn,
    batch_norm,
    metrics=[chem_metrics.MulticlassMCCMetric()],
)
# Plain string: the original f-string had no placeholders.
ckpt_dir = Path("./checkpointsclassificationChemprop/")
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Keep the checkpoint with the highest validation MCC, plus last.ckpt.
# The monitored key contains "/". With Lightning's default auto-inserted
# "name=value" filename, that "/" becomes a path separator and creates nested
# "best-epoch=N-val/" folders. Disabling auto-insertion and writing a
# slash-free label avoids this; the {val/multiclass-mcc:.4f} placeholder
# still reads the same logged metric.
checkpoint_cb = ModelCheckpoint(
    dirpath=ckpt_dir,
    filename="best-epoch={epoch}-val_mcc={val/multiclass-mcc:.4f}",
    monitor="val/multiclass-mcc",
    mode="max",
    save_last=True,
    auto_insert_metric_name=False,
)
# -------------------------
# Trainer
# -------------------------
# Single-device trainer with no external logger; checkpointing comes from checkpoint_cb.
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=100,
    logger=False,
    enable_progress_bar=True,
    callbacks=[checkpoint_cb],
)

# Train on train_loader, validating on val_loader.
trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

external_test_metrics = []  # NOTE(review): not populated in this cell; kept for later cells

# Predict the external test set with the best (max val MCC) checkpoint.
preds = trainer.predict(
    mpnn,
    dataloaders=test_loader,
    ckpt_path="best",
    weights_only=False,
)

# chemprop's MulticlassClassificationFFN applies softmax in forward(), so these
# are already class probabilities, not logits. Do not softmax them again:
# that flattens them toward uniform (argmax is unchanged, probabilities are not).
probs_t = torch.cat(preds, dim=0).squeeze(1)  # (N, 1, n_classes) -> (N, n_classes)
# Log-probabilities are valid logits (softmax(logits) == probs_t); the old name
# is kept for any downstream cells.
logits = probs_t.log()

probs = probs_t.cpu().numpy()
y_pred = probs_t.argmax(dim=1).cpu().numpy()
y_true = ys_test.flatten()

# Fail loudly instead of building misaligned results. NOTE(review): chemprop's
# build_dataloader may drop a trailing batch of size 1 (TODO confirm).
if len(y_pred) != len(y_true):
    raise ValueError(
        f"{len(y_pred)} predictions for {len(y_true)} test targets; "
        "predictions are not aligned with ys_test"
    )

print(probs_t.shape, y_pred.shape, y_true.shape)

test_results = pd.DataFrame({
    "actual": y_true,
    "predicted": y_pred,
    **{f"prob_class_{i}": probs[:, i] for i in range(probs.shape[1])},
})

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, digits=4))