# Training

# Import packages

In [None]:
import pandas as pd
from pathlib import Path
import numpy as np
import ast

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from chemprop import data, featurizers, models, nn
from chemprop.utils import make_mol

# Change data inputs here

In [None]:
# Presumably the chemprop repository root; assumes the notebook is run from a
# directory one level below it.
chemprop_dir = Path.cwd().parent
# CSV file with a SMILES column plus one column per target.
input_path = chemprop_dir.joinpath("tests", "data", "atomic_regression_atom_mapped_input_copy.csv")
# DataLoader worker processes; 0 loads batches in the main process.
num_workers = 0
# Column holding the SMILES strings.
smiles_column = "smiles"
# Columns holding the targets to learn.
target_columns = ["charges", "charges2"]

## Load data

In [None]:
# Read the raw input table; the trailing bare expression renders it in the notebook.
df_input = pd.read_csv(filepath_or_buffer=input_path)
df_input

## Get SMILES and targets

In [None]:
# SMILES strings as an array, and the (still string-encoded) target columns as a DataFrame.
smis = df_input[smiles_column].values
ys = df_input[target_columns]

In [None]:
smis[0:2]  # preview the first two SMILES strings

In [None]:
ys.head(5)  # preview the target cells of the first five molecules

In [None]:
# Build one target array per molecule. Each target cell holds a Python-literal
# string (presumably a list of per-atom values), parsed with the safe
# `ast.literal_eval`. Each parsed value becomes a column vector, and the columns
# for all `target_columns` are placed side by side. For flat lists this gives an
# array of shape (n_atoms, len(target_columns)).
Y = [
    np.hstack(
        [
            np.expand_dims(np.array(ast.literal_eval(ys.iloc[row_idx][col])), axis=1)
            for col in target_columns
        ]
    )
    for row_idx in range(len(ys))
]

## Get molecule datapoints

In [None]:
# Wrap every SMILES/target pair in a datapoint. `keep_h=True` presumably keeps
# explicit hydrogens from the SMILES so that the atom count lines up with the
# per-atom target arrays (confirm).
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, targets, keep_h=True)
    for smiles, targets in zip(smis, Y)
]

## Perform data splitting for training, validation, and testing

In [None]:
# Names of the split types accepted by `data.make_split_indices`.
[*data.SplitType.keys()]

In [None]:
# RDKit Mol objects are needed by the structure-based split types, so collect
# them from the datapoints; the split below is random with 80/10/10 fractions.
mols = [datapoint.mol for datapoint in all_data]
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "random", (0.8, 0.1, 0.1)
)
train_data, val_data, test_data = data.split_data_by_indices(
    all_data,
    train_indices,
    val_indices,
    test_indices,
)

## Get AtomDataset

In [None]:
# Featurizer that turns each datapoint's molecule into the graph used by the datasets below.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Fit target normalization on the training split only; the returned scaler is
# reused for the validation split and later to build the model's UnscaleTransform.
train_dset = data.AtomDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()

val_dset = data.AtomDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)

# NOTE(review): the test split is left unnormalized, presumably because the model's
# output transform un-scales predictions at test time — confirm.
test_dset = data.AtomDataset(test_data, featurizer)

# Dataset over every molecule, used for the final prediction/export step.
all_dset = data.AtomDataset(all_data, featurizer)
# Used later to map per-atom predictions back to their molecules.
# NOTE(review): `_slices` is a private AtomDataset attribute and may change across chemprop versions.
slices = all_dset._slices

## Get DataLoader

In [None]:
# The training loader uses build_dataloader's default shuffling. The remaining
# loaders keep dataset order (shuffle=False); the per-molecule export of the
# `all_loader` predictions depends on that order.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader, test_loader, all_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset, all_dset)
)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `Message passing` module operates on molecular graphs, passing messages between neighboring atoms to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [None]:
# Message-passing module (alternative: `nn.AtomMessagePassing()`).
# Its per-atom hidden states are not pooled into a molecule vector afterwards
# (the MPNN is built with `nn.NoAggregation`), because the targets in this
# notebook are per-atom rather than per-molecule.
mp = nn.BondMessagePassing()

## Feed-Forward Network (FFN)

An `FFN` takes the learned representations (here per-atom, since no aggregation is applied) and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` (will be available in a future version)

In [None]:
print(str(nn.PredictorRegistry))  # show the registered predictor (FFN) options

In [None]:
# Built from the training-set scaler; presumably maps the model's normalized
# predictions back to the original target units.
output_transform = nn.UnscaleTransform.from_standard_scaler(
    scaler
)

In [None]:
# Regression head with one output per target column, derived from
# `target_columns` so that editing the data-input cell cannot leave it out of
# sync. Predictions pass through `output_transform`.
ffn = nn.RegressionFFN(output_transform=output_transform, n_tasks=len(target_columns))

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Whether to use batch norm

In [None]:
# Whether the MPNN applies batch normalization.
batch_norm: bool = True

## Metrics
`Metrics` are used to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [None]:
print(str(nn.metrics.MetricRegistry))  # show the registered metric options

In [None]:
# Metrics reported during training and evaluation.
# NOTE(review): only the first metric is believed to drive training and early
# stopping — confirm against MPNN; checkpointing below monitors `val_loss`.
metric_list = [
    nn.metrics.RMSEMetric(),
    nn.metrics.MAEMetric(),
]

## Construct MPNN

In [None]:
# `NoAggregation` skips pooling atom representations into a molecule vector, so
# the FFN predicts per atom, matching the per-atom targets.
agg = nn.NoAggregation()
atom_mpnn = models.MPNN(
    mp,
    agg,
    ffn,
    batch_norm,
    metric_list,
)

atom_mpnn  # display the assembled model

# Set up trainer

In [None]:
# Keep the checkpoint with the lowest validation loss (file name records the
# epoch and loss) in the `checkpoints` directory, and always save the most
# recent checkpoint as well.
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)

# Single-device trainer without a logger. Checkpointing is enabled so the
# callback configured above writes its files; set it to False to skip saving.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # number of training epochs
    callbacks=[checkpointing],
)

# Start training

In [None]:
# Train on the training split, validating on the validation split each epoch.
trainer.fit(atom_mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Test results

In [None]:
# Evaluate the trained model on the held-out test split.
results = trainer.test(atom_mpnn, dataloaders=test_loader)

# Predictions

In [None]:
from chemprop.models import load_model
import torch

# Predict on every molecule with the best checkpoint from training.
# Predictions are collected per model in `individual_preds` so more models
# could be ensembled; with a single model, the "average" below is simply that
# model's predictions.
if not checkpointing.best_model_path:
    raise RuntimeError(
        "No best checkpoint was saved during training; cannot load a model for prediction."
    )

individual_preds = []
# NOTE(review): the two positional `False` flags are passed straight to
# `load_model`; their meaning is not visible here — confirm against its signature.
model = load_model(checkpointing.best_model_path, False, False)
trainer = pl.Trainer(
    logger=False,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

# One prediction tensor per batch. `all_loader` is built with shuffle=False, so
# concatenating along dim 0 keeps the dataset's atom order.
predss = trainer.predict(model, all_loader)
individual_preds.append(torch.concat(predss, 0))

average_preds = torch.mean(torch.stack(individual_preds).float(), dim=0)

# Re-read the exact CSV the predictions were made on. Using `input_path` keeps
# the output rows aligned with `all_dset` even when the input file is changed.
test_path = input_path
df_test = pd.read_csv(test_path, header="infer", index_col=False)

## Loaded Model

In [None]:
# Display the architecture of the model reloaded from the checkpoint.
model

In [None]:
# Write each molecule's per-atom predictions back into its row of `df_test`.
# NOTE(review): assumes `slices` holds, for every row of `average_preds`, the
# index of the molecule that atom belongs to, with each molecule's atoms stored
# contiguously — confirm against AtomDataset.
# Atom offsets are computed in a single pass over `slices` (O(n_atoms)) instead
# of scanning the whole list once per molecule.
first_atom_of = {}
n_atoms_of = {}
for atom_idx, mol_idx in enumerate(slices):
    first_atom_of.setdefault(mol_idx, atom_idx)
    n_atoms_of[mol_idx] = n_atoms_of.get(mol_idx, 0) + 1

for i in range(len(df_test)):
    if i not in first_atom_of:
        raise ValueError(f"No atoms found for molecule {i} in `slices`.")
    first_atom = first_atom_of[i]
    last_atom = first_atom + n_atoms_of[i]
    mol_avg_preds = average_preds[first_atom:last_atom]
    # Store each target column as the string form of its per-atom value list,
    # mirroring the list-literal cells that the input CSV is parsed from.
    df_test.loc[i, target_columns] = [
        str(mol_avg_preds[:, j].tolist()) for j in range(len(target_columns))
    ]

output_path = chemprop_dir / "tests" / "data" / "atomic_regression_atom_mapped_output.csv"
# Pickle when a `.pkl` output path is configured; otherwise write a CSV.
if output_path.suffix == ".pkl":
    df_test = df_test.reset_index(drop=True)
    df_test.to_pickle(output_path)
else:
    df_test.to_csv(output_path, index=False)

df_test