# Minimal Chemprop Example in Jupyter

This notebook shows how to:
1. **Load** RDKit molecules, features, and labels from a pickle file.
2. **Create** `MoleculeDatapoint` objects via `create_data_points`.
3. **Build** a `MoleculeDataset`, a dataloader, and a Chemprop MPNN model.
4. **Train** the model on the training data only (no validation set is used).

We also **print** the contents of `mols`, `features`, and `labels` so you can see exactly what data is passed in.


In [63]:
import pickle

import numpy as np
from chemprop import data, models, nn
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

################################
# 1) Minimal Helper Function
################################
def create_data_points(mols, labels, features, how='features'):
    """
    Create a list of chemprop.data.MoleculeDatapoint objects.

    Each datapoint holds:
      - an RDKit Mol
      - a label y, wrapped as ``np.array([y])``
      - optionally, per-atom extras reshaped to ``(num_atoms, -1)``: stored as
        ``V_f`` when ``how='features'`` or as ``V_d`` when ``how='descriptors'``.

    Args:
        mols: RDKit molecules (anything exposing ``GetNumAtoms()``).
        labels: one label per molecule.
        features: one array-like per molecule whose first axis has one row per
            atom. Ignored entirely (and may be None) when ``how='no extra'``.
        how: 'no extra', 'features' or 'descriptors'.

    Returns:
        list of MoleculeDatapoint, in input order.

    Raises:
        ValueError: for an unknown ``how``, for inputs of different lengths,
            or when a feature array's first dimension differs from the
            molecule's atom count.
    """
    # Validate the mode up front so a bad value is reported even for empty input.
    if how not in ('no extra', 'features', 'descriptors'):
        raise ValueError(f"Unknown how={how}. Choose from ['no extra','features','descriptors'].")

    mols = list(mols)
    labels = list(labels)
    if how == 'no extra':
        # Extra features are discarded in this mode, so they are neither
        # required nor validated against the molecules.
        features = [None] * len(mols)
    else:
        features = list(features)

    # zip() would silently drop trailing items; mismatched inputs are a data bug.
    if not len(mols) == len(labels) == len(features):
        raise ValueError(
            f"mols, labels and features must have the same length; got "
            f"{len(mols)}, {len(labels)} and {len(features)}."
        )

    data_points = []
    for mol, y, feat in zip(mols, labels, features):
        target = np.array([y])
        if how == 'no extra':
            data_points.append(data.MoleculeDatapoint(mol=mol, y=target))
            continue

        # Accept lists as well as ndarrays.
        feat = np.asarray(feat)
        num_atoms = mol.GetNumAtoms()
        num_rows = feat.shape[0] if feat.ndim else 0
        if num_rows != num_atoms:
            raise ValueError(
                f"FATAL ERROR: Features array has length {num_rows}, "
                f"but the molecule has {num_atoms} atoms."
            )

        feat_reshaped = feat.reshape(num_atoms, -1)
        if how == 'features':
            dp = data.MoleculeDatapoint(mol=mol, y=target, V_f=feat_reshaped)
        else:  # how == 'descriptors'
            dp = data.MoleculeDatapoint(mol=mol, y=target, V_d=feat_reshaped)
        data_points.append(dp)

    return data_points

################################
# 2) Build dataset and dataloader
################################
def build_dataset_and_dataloader(moldata, batch_size=1):
    """
    Build a MoleculeDataset from ``moldata`` and wrap it in a dataloader.

    The width of the extra per-atom features (``V_f.shape[1]``) is read from
    the datapoints and passed to the featurizer as ``extra_atom_fdim``. It is 0
    when the datapoints carry no ``V_f``.

    Args:
        moldata: non-empty sequence of chemprop MoleculeDatapoint objects.
        batch_size: number of molecules per batch.

    Returns:
        (dataset, loader) tuple.

    Raises:
        ValueError: if ``moldata`` is empty, or if the datapoints disagree on
            whether ``V_f`` is present or on its second dimension.
    """
    if not moldata:
        raise ValueError("moldata is empty; cannot build a dataset.")

    # A single featurizer is configured for the whole dataset, so every
    # datapoint must carry V_f of the same width, or none carries V_f.
    widths = {None if dp.V_f is None else dp.V_f.shape[1] for dp in moldata}
    if len(widths) != 1:
        raise ValueError(f"Inconsistent V_f widths across datapoints: {widths}.")
    (width,) = widths
    extra_atom_fdim = 0 if width is None else width

    featurizer = SimpleMoleculeMolGraphFeaturizer(extra_atom_fdim=extra_atom_fdim)
    dataset = data.MoleculeDataset(moldata, featurizer=featurizer)
    loader = data.build_dataloader(dataset, batch_size=batch_size, num_workers=0)
    return dataset, loader

################################
# 3) Build a simple Chemprop model
################################
def build_model(dataset):
    """
    Assemble a Chemprop MPNN for binary classification.

    The message-passing layer is sized from the dataset: ``d_v`` comes from
    the atom feature width reported by ``dataset.featurizer.atom_fdim``, and
    ``d_vd`` from ``dataset.d_vd``. Graph embeddings are pooled with
    NormAggregation and fed to a BinaryClassificationFFN. Training is scored
    with a binary F1 metric.

    Args:
        dataset: a chemprop MoleculeDataset with an attached featurizer.

    Returns:
        models.MPNN ready to be passed to a Lightning Trainer.
    """
    message_passing = nn.BondMessagePassing(
        d_v=dataset.featurizer.atom_fdim,
        d_vd=dataset.d_vd,
    )
    # Keyword arguments are evaluated left to right: aggregation, then
    # predictor (sized to the message-passing output), then metrics.
    return models.MPNN(
        message_passing=message_passing,
        agg=nn.NormAggregation(),
        predictor=nn.BinaryClassificationFFN(input_dim=message_passing.output_dim),
        metrics=[nn.metrics.BinaryF1Score()],
    )

################################
# 4) Minimal training loop
################################
def run_training(model, loader, work_dir="/tmp", max_epochs=3, val_loader=None):
    """
    Fit ``model`` with a Lightning Trainer and return the trainer.

    Checkpoints are written to ``<work_dir>/Chemprop_example``.
    - With ``val_loader``: the checkpoint with the lowest ``val_loss`` is kept
      as ``best.ckpt``, plus ``last.ckpt``.
    - Without validation data: there is no ``val_loss`` to monitor, so no
      "best" checkpoint is selected and only ``last.ckpt`` is kept.

    Args:
        model: the Chemprop MPNN to train.
        loader: training dataloader.
        work_dir: parent directory for the checkpoint folder.
        max_epochs: number of training epochs.
        val_loader: optional validation dataloader.

    Returns:
        The fitted ``lightning.pytorch.Trainer``.
    """
    import os

    checkpoint_dir = os.path.join(work_dir, "Chemprop_example")
    if val_loader is not None:
        checkpointing = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="best",
            monitor="val_loss",
            mode="min",
            save_last=True,
        )
    else:
        # Without validation data "val_loss" is never produced, so monitoring
        # it could never pick a best checkpoint. Monitor nothing and keep only
        # the most recent checkpoint.
        checkpointing = ModelCheckpoint(
            dirpath=checkpoint_dir,
            monitor=None,
            save_top_k=0,
            save_last=True,
        )

    trainer = Trainer(
        logger=False,
        enable_checkpointing=True,
        enable_progress_bar=True,
        accelerator="auto",
        devices=1,
        max_epochs=max_epochs,
        callbacks=[checkpointing],
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=val_loader)
    return trainer

################################
# 5) Minimal Main Logic
################################
def run_minimal_chemprop_example(pickle_file, how, batch_size=2, max_epochs=3, work_dir="/tmp"):
    """
    Run the end-to-end example.

    Steps: load data from a pickle, print it, build datapoints, dataset and
    model, then train.

    Args:
        pickle_file: path to a pickle holding a sequence of
            ``(mol, features, label)`` entries.
        how: how per-atom features are attached; forwarded to
            ``create_data_points`` ('no extra', 'features' or 'descriptors').
        batch_size: dataloader batch size.
        max_epochs: number of training epochs.
        work_dir: parent directory for checkpoints.

    Returns:
        The Trainer returned by ``run_training``.

    Raises:
        ValueError: if the pickle contains no entries.
    """
    # 1) Load the data from a pickle.
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only load pickles from sources you trust.
    with open(pickle_file, "rb") as f:
        all_data = pickle.load(f)
    if len(all_data) == 0:
        raise ValueError(f"No entries found in {pickle_file!r}.")

    # 2) Split the (mol, features, label) entries into parallel lists.
    mols = [entry[0] for entry in all_data]
    features = [entry[1] for entry in all_data]
    labels = [entry[2] for entry in all_data]

    # Show exactly what is being fed to Chemprop.
    print("Mols:", mols)
    print("Number of atoms per mol:", [mol.GetNumAtoms() for mol in mols])
    print("\nFeatures:", features)
    print("\nLabels:", labels)

    # 3) Create data points
    datapoints = create_data_points(mols, labels, features, how=how)

    # 4) Build dataset + loader
    dataset, loader = build_dataset_and_dataloader(datapoints, batch_size=batch_size)

    # 5) Build model
    model = build_model(dataset)

    # 6) Train, and hand the trainer back so callers can inspect it.
    trainer = run_training(model, loader, work_dir=work_dir, max_epochs=max_epochs)
    print("Finished minimal Chemprop training example.")
    return trainer

In [64]:
# 7) Run the whole pipeline on the sample data.
# Replace this path with your own pickle of (mol, features, label) entries.
pickle_file = "sample_Chemprop_input.pkl"
run_minimal_chemprop_example(pickle_file=pickle_file, how="no extra")

Mols: [<rdkit.Chem.rdchem.Mol object at 0xb7811bf6610>, <rdkit.Chem.rdchem.Mol object at 0xb7811bf72e0>]
Number of atoms per mol: [11, 12]

Features: [array([0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]

Labels: [1, 0]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params
------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K 
1 | agg             | NormAggregation         | 0     
2 | bn              | Identity                | 0     
3 | predictor       | BinaryClassificationFFN | 90.6 K
4 | X_d_transform   | Identity                | 0     
5 | metrics         | ModuleList              | 0     
------------------------------------------------------------
318 K     Trainable params
0         Non-trainable params
318 K     Total params
1.273     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


Finished minimal Chemprop training example.
