# Minimal Chemprop Example in Jupyter

This notebook shows how to:
1. **Load** RDKit molecules, features, and labels from a pickle file.
2. **Create** `MoleculeDatapoint` objects via `create_data_points`.
3. **Build** a Chemprop MPNN model.
4. **Train** the model (no validation set, purely training).

We'll also **print** the contents of `mols`, `features`, and `labels` to demonstrate the data being passed in.


In [None]:
import pickle

import numpy as np
from chemprop import data, models, nn
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


################################
# 1) Minimal Helper Function
################################
def create_data_points(mols, labels, features, how='features'):
    """
    Create a list of chemprop.data.MoleculeDatapoint objects.

    Each datapoint holds:
      - an RDKit Mol (passed as ``mol=``)
      - a label wrapped in a one-element array, ``np.array([y])``
      - optionally the matching entry of ``features``, passed as ``V_f=``
        or ``V_d=`` depending on ``how``

    Args:
        mols: iterable of molecules.
        labels: iterable of labels, aligned with ``mols``.
        features: iterable of per-molecule extras, aligned with ``mols``.
            Ignored (and may be ``None``) when ``how == 'no extra'``.
        how: one of
            'no extra'    -> only ``mol`` and ``y`` are set,
            'features'    -> each entry is passed as ``V_f``,
            'descriptors' -> each entry is passed as ``V_d``.
            NOTE(review): in chemprop ``V_f``/``V_d`` appear to be per-atom
            arrays (one row per atom) -- confirm the pickle matches.

    Returns:
        list of ``data.MoleculeDatapoint``, one per (mol, label) pair.

    Raises:
        ValueError: if ``how`` is not one of the three supported modes.
    """
    if how == 'no extra':
        # Features are not used in this mode, so they must not take part in
        # the zip: otherwise features=None crashes and a short features list
        # silently truncates the dataset.
        return [
            data.MoleculeDatapoint(mol=mol, y=np.array([y]))
            for mol, y in zip(mols, labels)
        ]

    if how == 'features':
        return [
            data.MoleculeDatapoint(mol=mol, y=np.array([y]), V_f=feature)
            for mol, y, feature in zip(mols, labels, features)
        ]

    if how == 'descriptors':
        return [
            data.MoleculeDatapoint(mol=mol, y=np.array([y]), V_d=feature)
            for mol, y, feature in zip(mols, labels, features)
        ]

    raise ValueError(f"Unknown how={how}. Choose from ['no extra','features','descriptors'].")

################################
# 2) Build dataset and dataloader
################################
def build_dataset_and_dataloader(moldata, batch_size=1):
    """
    Wrap datapoints in a MoleculeDataset and build a dataloader over it.

    Args:
        moldata: the datapoints to put in the dataset (e.g. the list returned
            by ``create_data_points``).
        batch_size: batch size forwarded to ``data.build_dataloader``.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    # A default-constructed featurizer is handed to the dataset together
    # with the datapoints.
    mol_dataset = data.MoleculeDataset(moldata, SimpleMoleculeMolGraphFeaturizer())
    # num_workers=0 is forwarded as-is to chemprop's dataloader builder.
    batch_loader = data.build_dataloader(
        mol_dataset,
        num_workers=0,
        batch_size=batch_size,
    )
    return mol_dataset, batch_loader

################################
# 3) Build a simple Chemprop model
################################
def build_model():
    """
    Build a small Chemprop MPNN for binary classification.

    The model is assembled from default-constructed components:
    ``nn.AtomMessagePassing`` (message passing), ``nn.MeanAggregation``
    (aggregation) and ``nn.BinaryClassificationFFN`` (predictor), and is
    evaluated with ``nn.metrics.BinaryF1Score``.

    NOTE(review): the message-passing block uses its default input sizes.
    If the datapoints carry extra per-atom features/descriptors (V_f / V_d),
    the featurizer and message-passing dimensions presumably need to be
    widened to match -- confirm against the chemprop docs.

    Returns:
        A ``models.MPNN`` instance.
    """
    mp = nn.AtomMessagePassing()
    agg = nn.MeanAggregation()
    ffn = nn.BinaryClassificationFFN()
    # The metric must go in by keyword and as a list: the fourth positional
    # parameter of MPNN is `batch_norm`, so passing the metric positionally
    # would toggle batch norm and leave the metric unregistered.
    # `batch_norm` is deliberately left at the library default here.
    model = models.MPNN(mp, agg, ffn, metrics=[nn.metrics.BinaryF1Score()])
    return model

################################
# 4) Minimal training loop
################################
def run_training(model, loader, work_dir="/tmp", max_epochs=3, val_loader=None):
    """
    Fit ``model`` on ``loader`` with a Lightning Trainer and checkpoint it.

    Checkpoints are written to ``{work_dir}/Chemprop_example``: the best one
    under the filename ``best`` and, because ``save_last=True``, the most
    recent one as well.

    Args:
        model: the Lightning module to train (e.g. from ``build_model``).
        loader: training dataloader.
        work_dir: parent directory for the checkpoint folder.
        max_epochs: number of training epochs.
        val_loader: optional validation dataloader. When omitted, training
            runs without validation, as before.

    Returns:
        The ``Trainer`` after ``fit`` has completed.
    """
    # "val_loss" only exists when a validation loop runs. Without validation
    # data the checkpoint callback would never find its monitored key and no
    # "best" checkpoint would be written, so fall back to the training loss.
    # Assumes the model logs "train_loss" / "val_loss" (chemprop's MPNN does;
    # confirm for other modules).
    monitored_metric = "val_loss" if val_loader is not None else "train_loss"

    checkpointing = ModelCheckpoint(
        dirpath=f"{work_dir}/Chemprop_example",
        filename="best",
        monitor=monitored_metric,
        mode="min",
        save_last=True
    )
    trainer = Trainer(
        logger=False,
        enable_checkpointing=True,
        enable_progress_bar=True,
        accelerator="auto",
        devices=1,
        max_epochs=max_epochs,
        callbacks=[checkpointing]
    )
    # val_loader=None means "fit with no validation".
    trainer.fit(model, loader, val_loader)
    return trainer

################################
# 5) Minimal Main Logic
################################
def run_minimal_chemprop_example(pickle_file, how="features"):
    """
    Run the whole minimal example: load data, build the model, train.

    Steps: unpickle the input, split it into mols / features / labels, print
    them, create datapoints, build the dataset and dataloader (batch size 2),
    build the model and train it for 3 epochs with checkpoints under ``/tmp``.

    Args:
        pickle_file: path to a pickle holding a sequence of entries, each
            indexable as ``(mol, feature, label)``.
        how: forwarded to ``create_data_points``; selects how the features
            are attached ('no extra', 'features' or 'descriptors').

    Returns:
        None. Results are reported via printed output and the checkpoints
        written by ``run_training``.
    """
    # 1) Load the data from a pickle
    # Expecting each entry to be (mol, feature, label)
    # SECURITY: pickle.load can execute arbitrary code while loading; only
    # open pickle files that come from a trusted source.
    with open(pickle_file, "rb") as f:
        all_data = pickle.load(f)

    # 2) Extract separate lists
    mols = [entry[0] for entry in all_data]
    features = [entry[1] for entry in all_data]
    labels = [entry[2] for entry in all_data]

    # Print them to show exactly what is handed to Chemprop
    print("Mols:", mols)
    print("\nFeatures:", features)
    print("\nLabels:", labels)

    # 3) Create data points
    datapoints = create_data_points(mols, labels, features, how)

    # 4) Build dataset + loader (the dataset itself is not needed afterwards)
    _, loader = build_dataset_and_dataloader(datapoints, batch_size=2)

    # 5) Build model
    model = build_model()

    # 6) Train
    run_training(model, loader, work_dir="/tmp", max_epochs=3)
    print("Finished minimal Chemprop training example.")


In [None]:
# 7) Actually run the code
# Provide the path to your minimal sample input pickle file.
# The file is expected to hold a sequence of (mol, feature, label) entries,
# which is how run_minimal_chemprop_example unpacks it.
pickle_file = "sample_Chemprop_input.pkl"
# how="features" attaches each entry's feature to its datapoint as V_f.
run_minimal_chemprop_example(pickle_file, how="features")