### Example script for training MPNN-POM model

In [20]:
import deepchem as dc
from tqdm import tqdm
# from openpom.feat.graph_featurizer import GraphFeaturizer, GraphConvConstants
# from openpom.utils.data_utils import get_class_imbalance_ratio
# from openpom.models.mpnn_pom import MPNNPOMModel
from datetime import datetime
from sklearn.metrics import roc_auc_score
import pandas as pd

### Utils

#### Optimizer

In [21]:
from deepchem.models.optimizers import Optimizer
from deepchem.models.optimizers import Adam
from deepchem.models.optimizers import AdaGrad
from deepchem.models.optimizers import AdamW
from deepchem.models.optimizers import SparseAdam
from deepchem.models.optimizers import RMSProp
from deepchem.models.optimizers import GradientDescent
from deepchem.models.optimizers import KFAC


def get_optimizer(optimizer_name: str = 'adam') -> Optimizer:
    """
    Build a deepchem optimizer from its short name.

    Every optimizer is constructed with its default arguments.

    Parameters
    ---------
    optimizer_name: str
      Name of the optimizer to build.
      choices: [adam, adagrad, adamw, sparseadam, rmsprop, sgd, kfac]
      default: 'adam'

    Returns
    -------
    Optimizer
      A freshly constructed deepchem optimizer. When the name is not one
      of the choices above, a message is printed and `Adam()` is returned.
    """
    # (name, optimizer class) pairs, checked in order with `==`.
    known_optimizers = (
        ('adam', Adam),
        ('adagrad', AdaGrad),
        ('adamw', AdamW),
        ('sparseadam', SparseAdam),
        ('rmsprop', RMSProp),
        ('sgd', GradientDescent),
        ('kfac', KFAC),
    )
    for name, optimizer_cls in known_optimizers:
        if optimizer_name == name:
            return optimizer_cls()

    # Unknown name: fall back to Adam rather than failing.
    print("INVALID OPTIMISER NAME!, using ADAM optimizer by default")
    return Adam()


#### Molecule_Feature_Utils

In [22]:
from typing import List
from deepchem.utils.typing import RDKitAtom
from deepchem.utils.molecule_feature_utils import one_hot_encode


def get_atomic_num_one_hot(atom: RDKitAtom,
                           allowable_set: List[int],
                           include_unknown_set: bool = True) -> List[float]:
    """
    One-hot encode the atomic number of an atom.

    The value that is encoded is `atom.GetAtomicNum() - 1`, i.e. the
    atomic number shifted to start at zero, so `allowable_set` must hold
    zero-based values (e.g. `list(range(100))`).

    Parameters
    ---------
    atom: RDKitAtom
        RDKit atom object
    allowable_set: List[int]
        Zero-based atomic numbers to consider.
    include_unknown_set: bool, default True
        If true, the index of all types not in
        `allowable_set` is `len(allowable_set)`.

    Returns
    -------
    List[float]
        One-hot vector for the atom's (shifted) atomic number.
        Its length is `len(allowable_set)` when `include_unknown_set`
        is False and `len(allowable_set) + 1` when it is True.

    """
    zero_based_atomic_num = atom.GetAtomicNum() - 1
    return one_hot_encode(zero_based_atomic_num, allowable_set,
                          include_unknown_set)


def get_atom_total_valence_one_hot(
        atom: RDKitAtom,
        allowable_set: List[int],
        include_unknown_set: bool = True) -> List[float]:
    """One-hot encode the total valence of an atom.

    Parameters
    ---------
    atom: rdkit.Chem.rdchem.Atom
        RDKit atom object
    allowable_set: List[int]
        Total valence values to consider.
    include_unknown_set: bool, default True
        If true, the index of all types not in
        `allowable_set` is `len(allowable_set)`.

    Returns
    -------
    List[float]
        One-hot vector for the value of `atom.GetTotalValence()`.
        Its length is `len(allowable_set)` when `include_unknown_set`
        is False and `len(allowable_set) + 1` when it is True.

    """
    total_valence = atom.GetTotalValence()
    return one_hot_encode(total_valence, allowable_set, include_unknown_set)


#### Loss

In [23]:
import torch
from typing import Optional, Callable, List
from deepchem.models.losses import Loss

class CustomMultiLabelLoss(Loss):
    """
    Custom Multi-Label Loss function for multi-label classification.

    The objective function is a summed (or averaged) cross-entropy loss
    over all tasks, with each task's contribution to the loss being
    weighted by a factor of log(1 + class_imbalance_ratio).

    This loss function is based on:
    `A Principal Odor Map Unifies Diverse Tasks in Human Olfactory Perception
    preprint <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.

    The labels should have shape (batch_size, tasks) (a trailing singleton
    axis is squeezed away) and be integer class labels. The model output is
    indexed with three indices, so it is expected to have shape
    (batch_size, tasks, classes); only class channel 0 is read.
    """

    def __init__(self,
                 class_imbalance_ratio: Optional[List] = None,
                 loss_aggr_type: str = 'sum',
                 sample_weights: Optional[List[float]] = None,
                 device: Optional[str] = None):
        """
        Parameters
        ----------
        class_imbalance_ratio: Optional[List]
            List of class imbalance ratios, one per task. When None, a
            message is printed and the per-task losses are left unweighted.
        loss_aggr_type: str
            Loss aggregation type; 'sum' or 'mean'.
        sample_weights: Optional[List[float]]
            List of sample weights for each data point.
        device: Optional[str]
            Device the stored tensors are moved to. If None they stay where
            `torch.Tensor(...)` created them and are moved to the loss
            device lazily when the loss is evaluated.

        Raises
        ------
        ValueError
            If `loss_aggr_type` is neither 'sum' nor 'mean'.
        """
        super(CustomMultiLabelLoss, self).__init__()
        if class_imbalance_ratio is None:
            print(Warning("No class imbalance ratio provided!"))
            self.class_imbalance_ratio: Optional[torch.Tensor] = None
        else:
            self.class_imbalance_ratio = torch.Tensor(class_imbalance_ratio)

        self.sample_weights: Optional[torch.Tensor] = None
        if sample_weights is not None:
            self.sample_weights = torch.Tensor(sample_weights)

        if loss_aggr_type not in ['sum', 'mean']:
            raise ValueError(f"Invalid loss aggregate type: {loss_aggr_type}")
        self.loss_aggr_type: str = loss_aggr_type

        if device is not None:
            if self.class_imbalance_ratio is not None:
                self.class_imbalance_ratio = self.class_imbalance_ratio.to(device)
            if self.sample_weights is not None:
                self.sample_weights = self.sample_weights.to(device)

    def _create_pytorch_loss(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Returns loss function for pytorch backend
        """
        ce_loss_fn: torch.nn.CrossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='none')

        def loss(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
            """
            Compute the (optionally class-balanced and sample-weighted)
            cross-entropy loss for one batch.

            Parameters
            ----------
            output: torch.Tensor
                Output from model's forward pass per batch.
            labels: torch.Tensor
                Target labels per batch.

            Returns
            -------
            loss: torch.Tensor
                The per-sample loss (summed or averaged over tasks,
                depending on the aggregation type) repeated along a new
                last axis of size `output.shape[-1]`.
            """
            # Convert (batch_size, tasks, classes) to (batch_size, classes, tasks)
            if len(output.shape) == 3:
                output = output.permute(0, 2, 1)

            if len(labels.shape) == len(output.shape):
                labels = labels.squeeze(-1)

            # Handle multilabel: channel 0 is treated as the score of the
            # positive class and `1 - score` as the negative class; both
            # are stacked on the class axis and fed to CrossEntropyLoss.
            probabilities: torch.Tensor = output[:, 0, :]
            complement_probabilities: torch.Tensor = 1 - probabilities
            binary_output: torch.Tensor = torch.stack(
                [complement_probabilities, probabilities], dim=1)

            task_losses: torch.Tensor = ce_loss_fn(binary_output, labels.long())

            if self.class_imbalance_ratio is not None:
                # Move the factors to the loss device: the stored ratio is
                # only relocated in __init__ when `device` was given, so
                # without this a GPU loss would meet a CPU tensor.
                balancing_factors: torch.Tensor = torch.log(
                    1 + self.class_imbalance_ratio).to(task_losses.device)
                task_losses = torch.mul(task_losses, balancing_factors)

            if self.loss_aggr_type == 'sum':
                sample_losses: torch.Tensor = task_losses.sum(dim=1)
            else:
                sample_losses = task_losses.mean(dim=1)

            if self.sample_weights is not None:
                # NOTE(review): this always takes the first `batch_size`
                # stored weights, whichever samples are in the batch —
                # confirm the batching order makes that correct.
                batch_sample_weights = self.sample_weights[
                    :sample_losses.size(0)].to(sample_losses.device)
                sample_losses = sample_losses * batch_sample_weights

            return sample_losses.unsqueeze(-1).repeat(1, output.shape[-1])

        return loss


#### Data_Utils

In [24]:
import tempfile
import pandas as pd
import numpy as np
from typing import List, Optional, Tuple, Iterator
from deepchem.data.datasets import DiskDataset, NumpyDataset
from skmultilearn.model_selection import IterativeStratification
from deepchem.splits import Splitter


def get_class_imbalance_ratio(dataset: DiskDataset) -> List:
    """
    Compute a per-task frequency ratio from the labels of a dataset.

    For every task (column of `dataset.y`) the labels are summed to get a
    count, and each count is divided by the largest count. The most
    frequent task therefore gets 1.0 and every other task a smaller value.

    NOTE(review): the imbalance ratio per label (IRLbl) defined in the
    reference below is the reciprocal (majority count / label count, so
    1 for the most frequent label and greater for the rest). Confirm
    which direction downstream consumers expect.

    Parameters
    ---------
    dataset: DiskDataset
        Deepchem DiskDataset (a NumpyDataset is accepted as well).

    Returns
    -------
    class_imbalance_ratio: List
        One ratio per task, in task order.

    Raises
    ------
    Exception
        If `dataset` is neither a DiskDataset nor a NumpyDataset.

    References
    ----------
    .. TarekegnA.N. et al.
       "A review of methods for imbalanced multi-label classification"
       Pattern Recognit. (2021)
    """
    if not isinstance(dataset, (DiskDataset, NumpyDataset)):
        raise Exception("The dataset should be a deepchem DiskDataset or NumpyDataset")

    per_task_counts: np.ndarray = pd.DataFrame(dataset.y).sum().to_numpy()
    largest_count = max(per_task_counts)
    return (per_task_counts / largest_count).tolist()


class IterativeStratifiedSplitter(Splitter):
    """
    Iteratively stratify a multi-label data set into folds/splits.

    Construct an iterative stratifier that splits the dataset
    trying to maintain balanced representation with respect to
    order-th label combinations.

    Available splits:
        - train_valid_test_split()
        - train_test_split()

    Note:
        Requires `skmultilearn` library to be installed.
    """

    def __init__(self, order: int = 2) -> None:
        """
        Parameters
        ---------
        order: int
            order for iterative stratification (default: 2)
        """
        self.order: int = order

    def split(
        self,
        dataset: DiskDataset,
        frac_train: float = 0.8,
        frac_valid: float = 0.1,
        frac_test: float = 0.1,
        seed: Optional[int] = None,
        log_every_n: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Return indices for iterative stratified split

        The dataset is first split into a train part and a hold-out part,
        then the hold-out part is split again into validation and test.
        All returned indices refer to rows of the original `dataset`.

        Parameters
        ----------
        dataset: dc.data.Dataset
            Dataset to be split.
        seed: int, optional (default None)
            Random seed to use.
        frac_train: float, optional (default 0.8)
            The fraction of data to be used for the training split.
        frac_valid: float, optional (default 0.1)
            The fraction of data to be used for the validation split.
        frac_test: float, optional (default 0.1)
            The fraction of data to be used for the test split.
        log_every_n: int, optional (default None)
            Unused by this splitter; kept for interface compatibility.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            A tuple `(train_indices, valid_indices, test_indices)`
            for the various splits.
        """
        features: np.ndarray = dataset.X
        labels: np.ndarray = dataset.y

        # Stage 1: fold 0 (yielded second) is the hold-out part
        # (valid + test), everything else is train.
        stratifier1: IterativeStratification = IterativeStratification(
            n_splits=2,
            order=self.order,
            sample_distribution_per_fold=[frac_test + frac_valid, frac_train],
            random_state=seed,
        )
        train_indices: np.ndarray
        other_indices: np.ndarray
        train_indices, other_indices = next(
            stratifier1.split(pd.DataFrame(features), pd.DataFrame(labels)))
        other_indices = np.asarray(other_indices)

        if other_indices.size == 0:
            # Nothing was held out, so there is nothing to stratify further.
            empty: np.ndarray = np.array([], dtype=int)
            return train_indices, empty, empty.copy()

        # Stage 2: stratify the hold-out rows into valid/test. Slicing the
        # in-memory arrays avoids writing a throw-away dataset to disk.
        new_split_ratio: float = round(frac_test / (frac_test + frac_valid), 2)
        stratifier2: IterativeStratification = IterativeStratification(
            n_splits=2,
            order=self.order,
            sample_distribution_per_fold=[
                new_split_ratio, 1 - new_split_ratio
            ],
            random_state=seed,
        )
        valid_positions: np.ndarray
        test_positions: np.ndarray
        valid_positions, test_positions = next(
            stratifier2.split(pd.DataFrame(features[other_indices]),
                              pd.DataFrame(labels[other_indices])))

        # The second stratifier yields positions *within the hold-out
        # subset*; translate them back to indices of the original dataset,
        # otherwise valid/test would point at (and overlap) unrelated rows.
        valid_indices: np.ndarray = other_indices[valid_positions]
        test_indices: np.ndarray = other_indices[test_positions]
        return train_indices, valid_indices, test_indices

    def k_fold_split(
        self,
        dataset: DiskDataset,
        k: int,
        directories: Optional[List[str]] = None
    ) -> List[Tuple[DiskDataset, DiskDataset]]:
        """
        Split the dataset into k stratified (train, cv) folds.

        Parameters
        ----------
        dataset: DiskDataset
            DiskDataset to do a k-fold split
        k: int
            Number of folds to split `DiskDataset` into. (k>1)
        directories: List[str], optional (default None)
            List of length 2*k filepaths to save the result disk-datasets.
            Entries `2*i` and `2*i + 1` hold the train and cv set of fold
            `i`. Temporary directories are created when omitted.

        Returns
        -------
        List[Tuple[DiskDataset, DiskDataset]]
            List of length k tuples of (train, cv)
            where `train` and `cv` are both `DiskDataset`.
        """
        assert k != 1
        if directories is None:
            directories = [tempfile.mkdtemp() for _ in range(2 * k)]
        else:
            assert len(directories) == 2 * k

        X: pd.DataFrame = pd.DataFrame(dataset.X)
        y: pd.DataFrame = pd.DataFrame(dataset.y)
        stratifier: IterativeStratification = IterativeStratification(
            n_splits=k,
            order=self.order,
        )

        folds: List[Tuple[DiskDataset, DiskDataset]] = []
        for fold, (train_indices, cv_indices) in enumerate(
                stratifier.split(X, y)):
            train_dir, cv_dir = directories[2 * fold], directories[2 * fold + 1]
            train_dataset: DiskDataset = dataset.select(
                train_indices.tolist(), train_dir)
            cv_dataset: DiskDataset = dataset.select(cv_indices.tolist(),
                                                     cv_dir)
            folds.append((train_dataset, cv_dataset))
        return folds


In [25]:
def get_sample_weights(dataset: DiskDataset, intersection: set) -> List[float]:
    """
    Weight each sample by whether it belongs to `intersection`.

    The second column of `dataset.X` (`dataset.X[:, 1]`, presumably the
    SMILES string stored next to each featurized molecule) is looked up in
    `intersection`: members get weight 2, every other sample gets weight 1.

    Parameters
    ----------
    dataset: DiskDataset
        Deepchem diskdataset object to get sample weights for. Its `X`
        must be indexable as a 2-D array with at least two columns.
    intersection: set
        Collection of identifiers whose samples are up-weighted.

    Returns
    -------
    sample_weights: List[float]
        One weight (2 or 1) per sample, in dataset order.
    """
    return [2 if entry in intersection else 1 for entry in dataset.X[:, 1]]

### GraphFeaturizer

In [26]:
import numpy as np
import re
from rdkit import Chem
from typing import List, Union, Dict, Sequence
from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.feat.graph_data import GraphData
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils \
    import get_atom_formal_charge_one_hot
from deepchem.utils.molecule_feature_utils \
    import get_atom_total_num_Hs_one_hot
from deepchem.utils.molecule_feature_utils \
    import get_atom_hybridization_one_hot
# from openpom.utils.molecule_feature_utils \
#     import get_atomic_num_one_hot, get_atom_total_valence_one_hot
import logging

logger = logging.getLogger(__name__)

# Reorder atoms so they follow the order in which their element symbols
# appear in a SMILES string.
def smilesAtomReorder(atoms, smiles):
    """
    Return the atoms rearranged to follow the symbols found in `smiles`.

    The string is scanned left to right. At each position a two-letter
    element symbol is tried first (exact match), then a one-letter symbol
    (exact match, or match against the lower-cased atom symbol so that
    lower-case SMILES letters are recognised). The first atom in `atoms`
    with a matching symbol is consumed.

    Parameters
    ----------
    atoms: list
        Objects exposing `GetSymbol()`. The list is modified in place:
        every matched atom is removed from it.
    smiles: str
        SMILES string that dictates the ordering.

    Returns
    -------
    list
        The matched atoms, in the order they were matched.
    """

    def take_first(symbol_matches):
        # Remove and return the first remaining atom whose symbol matches.
        for candidate in atoms:
            if symbol_matches(candidate.GetSymbol()):
                atoms.remove(candidate)
                return candidate
        return None

    ordered = []
    total = len(smiles)
    pos = 0
    while pos < total:
        # Two-letter symbols take priority over one-letter symbols.
        pair = smiles[pos:pos + 2]
        if len(pair) == 2 and pair.isalpha():
            found = take_first(lambda sym: sym == pair)
            if found is not None:
                ordered.append(found)
                pos += 2
                continue

        # Otherwise try the single character at this position.
        char = smiles[pos]
        if char.isalpha():
            found = take_first(
                lambda sym: sym == char or sym.lower() == char)
            if found is not None:
                ordered.append(found)
        pos += 1

    return ordered
                    

class GraphConvConstants(object):
    """
    Namespace holding the featurization parameters.
    """

    # Number of distinct atomic-number slots in the one-hot encoding.
    MAX_ATOMIC_NUM = 100

    # Allowed values for each one-hot encoded atom property.
    ATOM_FEATURES: Dict[str, List[int]] = {
        'valence': list(range(7)),
        'degree': list(range(6)),
        'num_Hs': list(range(5)),
        'formal_charge': [-1, -2, 1, 2, 0],
        'atomic_num': list(range(MAX_ATOMIC_NUM)),
    }

    # Allowed hybridization labels.
    ATOM_FEATURES_HYBRIDIZATION: List[str] = [
        "SP",
        "SP2",
        "SP3",
        "SP3D",
        "SP3D2",
    ]

    # Length of the atom feature vector: every one-hot group contributes
    # its number of choices plus one extra slot for the unknown set, and
    # the hybridization group does the same.
    ATOM_FDIM = (sum(map(len, ATOM_FEATURES.values())) + len(ATOM_FEATURES) +
                 len(ATOM_FEATURES_HYBRIDIZATION) + 1)

    # Length of the bond feature vector.
    BOND_FDIM = 6


def atom_features(atom: RDKitAtom) -> Sequence[Union[bool, int, float]]:
    """
    Build the feature vector of a single atom.

    The vector is the concatenation, in this order, of the one-hot
    encodings of total valence, total degree, number of hydrogens, formal
    charge, atomic number and hybridization, with every entry cast to int.

    Parameters
    ----------
    atom: RDKitAtom
        Atom to compute features on. `None` yields an all-zero vector of
        length `GraphConvConstants.ATOM_FDIM`.

    Returns
    -------
    features: Sequence[Union[bool, int, float]]
        A list of atom features.
    """
    if atom is None:
        return [0] * GraphConvConstants.ATOM_FDIM

    choices = GraphConvConstants.ATOM_FEATURES
    # The order of the groups defines the layout of the feature vector.
    one_hot_groups = [
        get_atom_total_valence_one_hot(atom, choices['valence']),
        get_atom_total_degree_one_hot(atom, choices['degree']),
        get_atom_total_num_Hs_one_hot(atom, choices['num_Hs']),
        get_atom_formal_charge_one_hot(atom, choices['formal_charge']),
        get_atomic_num_one_hot(atom, choices['atomic_num']),
        get_atom_hybridization_one_hot(
            atom, GraphConvConstants.ATOM_FEATURES_HYBRIDIZATION, True),
    ]
    return [int(value) for group in one_hot_groups for value in group]


def bond_features(bond: RDKitBond) -> Sequence[Union[bool, int, float]]:
    """
    Build the feature vector of a single bond.

    Layout: a leading "no bond" flag, four bond-type indicators (single,
    double, triple, aromatic) and an "is in ring" flag.

    Parameters
    ----------
    bond: RDKitBond
        Bond to compute features on. `None` yields a vector whose first
        entry is 1 and whose remaining entries are 0.

    Returns
    -------
    features: Sequence[Union[bool, int, float]]
        A list of bond features.
    """
    if bond is None:
        return [1] + [0] * (GraphConvConstants.BOND_FDIM - 1)

    bond_type = bond.GetBondType()
    known_types = Chem.rdchem.BondType
    encoded = [0]
    encoded.extend(bond_type == reference for reference in (
        known_types.SINGLE,
        known_types.DOUBLE,
        known_types.TRIPLE,
        known_types.AROMATIC,
    ))
    encoded.append(bond.IsInRing())
    return encoded


class GraphFeaturizer(MolecularFeaturizer):
    """
    This class is a featurizer for GNN (MESSAGE PASSING) implementation for
    Principal Odor Map.

    The default node(atom) and edge(bond) representations are based on
    `A Principal Odor Map Unifies Diverse Tasks in Human Olfactory Perception
    preprint <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.

    The default node representation are constructed by concatenating
    the following values, and the feature length is 134.

    - Valence: A one-hot vector for total valence (0-6) of an atom.
    - Degree: A one-hot vector of the degree (0-5) of this atom.
    - Number of Hydrogens: A one-hot vector of the number of hydrogens
      (0-4) that this atom connected.
    - Formal charge: Integer electronic charge, -1, -2, 1, 2, 0.
    - Atomic num: A one-hot vector of this atom, in a range of first 100 atoms.
    - Hybridization: A one-hot vector of "SP", "SP2", "SP3", "SP3D", "SP3D2".

    The default edge representation are constructed by concatenating
    the following values, and the feature length is 6.

    - Bond type: A one-hot vector of the bond type,
      "single", "double", "triple", or "aromatic".
    - Is in ring: Boolean value to specify whether
      the bond is in a ring or not.

    Each featurized molecule is returned together with its SMILES string,
    i.e. as a `(GraphData, smiles)` pair.

    If you want to know more details about features,
    please check the paper [1]_ and utilities in
    deepchem.utils.molecule_feature_utils.py.

    References
    ----------
    .. [1] Kearnes, Steven, et al.
       "Molecular graph convolutions: moving beyond fingerprints."
        Journal of computer-aided molecular design 30.8 (2016):595-608.

    Note
    ----
    This class requires RDKit to be installed.

    """

    def __init__(self, is_adding_hs=False):
        """
        Parameters
        ----------
        is_adding_hs: bool, default False
            Whether to add Hs or not.
        """
        self.is_adding_hs = is_adding_hs
        # `super(GraphFeaturizer).__init__()` only initialised an unbound
        # super object; the zero-argument form runs the parent initializer.
        super().__init__()

    def _construct_bond_index(self, datapoint: RDKitMol) -> np.ndarray:
        """
        Construct edge (bond) index

        Every bond contributes two directed edges, (begin -> end) followed
        by (end -> begin).

        Parameters
        ----------
        datapoint: RDKitMol
            RDKit mol object.

        Returns
        -------
        edge_index: np.ndarray
            Edge (Bond) index with shape [2, 2 * num_bonds]; row 0 holds
            source atoms and row 1 destination atoms.

        """
        src: List[int] = []
        dest: List[int] = []
        for bond in datapoint.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            src += [start, end]
            dest += [end, start]
        return np.asarray([src, dest], dtype=int)

    def _featurize(self, datapoint: RDKitMol,
                   **kwargs) -> Tuple[GraphData, str]:
        """Calculate molecule graph features from RDKit mol object.

        Parameters
        ----------
        datapoint: RDKitMol
            RDKit mol object.

        Returns
        -------
        Tuple[GraphData, str]
            A pair `(graph, smiles)` where `smiles` is
            `Chem.MolToSmiles` of the (possibly H-augmented) molecule and
            `graph` is a molecule graph object with features:
            - node_features: Node feature matrix with shape
              [num_nodes, num_node_features]
            - edge_index: Graph connectivity in COO format with shape
              [2, num_edges]
            - edge_features: Edge feature matrix with shape
              [num_edges, num_edge_features]

        Raises
        ------
        ValueError
            If `datapoint` is not an RDKit `Mol`.
        """
        if not isinstance(datapoint, Chem.rdchem.Mol):
            raise ValueError(
                "Feature field should contain smiles for featurizer!")
        if self.is_adding_hs:
            datapoint = Chem.AddHs(datapoint)

        # get atom features
        f_atoms: np.ndarray = np.asarray(
            [atom_features(atom) for atom in datapoint.GetAtoms()],
            dtype=float)

        # get edge(bond) features; each bond is listed twice so the rows
        # line up with the two directed edges built for it in
        # `_construct_bond_index`.
        if len(datapoint.GetBonds()) == 0:
            f_bonds: np.ndarray = np.empty((0, GraphConvConstants.BOND_FDIM))
        else:
            f_bonds_list = []
            for bond in datapoint.GetBonds():
                b_feat = bond_features(bond)
                f_bonds_list.extend((b_feat, b_feat))
            f_bonds = np.asarray(f_bonds_list, dtype=float)

        # get edge index
        edge_index: np.ndarray = self._construct_bond_index(datapoint)

        graph = GraphData(node_features=f_atoms,
                          edge_index=edge_index,
                          edge_features=f_bonds)
        return (graph, Chem.MolToSmiles(datapoint))

### pom_ffn

In [27]:
import torch
import torch.nn as nn
from typing import List, Optional, Callable, Any


class CustomPositionwiseFeedForward(nn.Module):
    """
    Customised PositionwiseFeedForward layer from deepchem
    for:
        - hidden layers of variable sizes
        - batch normalization before every activation function
        - additional output of embedding layer (penultimate layer)
          for POM embeddings.
    """

    def __init__(
        self,
        d_input: int = 1024,
        d_hidden_list: List = [1024],
        d_output: int = 1024,
        activation: str = 'leakyrelu',
        dropout_p: float = 0.0,
        dropout_at_input_no_act: bool = False,
        batch_norm: bool = True,
    ):
        """Initialize a PositionwiseFeedForward layer.

        Parameters
        ----------
        d_input: int
            Size of input layer.
        d_hidden_list: List
            List of hidden sizes. It is only read, never modified.
        d_output: int (same as d_input if d_output = 0)
            Size of output layer.
        activation: str
            Activation function to be used. Can choose between 'relu' for ReLU,
            'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
            'tanh' for TanH, 'selu' for SELU, 'elu' for ELU
            and 'linear' for linear activation.
        dropout_p: float
            Dropout probability.
        dropout_at_input_no_act: bool
            If true, dropout is applied on the input tensor.
            For single layer, it is not passed to an activation function.
        batch_norm: bool
            If true, applies batch normalization
            'before' every activation function

        Raises
        ------
        ValueError
            If `activation` is not one of the supported names.
        """
        super(CustomPositionwiseFeedForward, self).__init__()

        self.dropout_at_input_no_act: bool = dropout_at_input_no_act
        self.batch_norm: bool = batch_norm

        # Fail here on an unknown name instead of leaving `activation`
        # unset and crashing later inside `forward`.
        self.activation: Callable[[Any], Any] = self._build_activation(
            activation)

        d_output = d_output if d_output != 0 else d_input
        hidden_sizes: List = list(d_hidden_list)

        # Set n_layers
        self.n_layers: int = len(hidden_sizes) + 1

        # Set linear layers: consecutive sizes are chained
        # d_input -> hidden[0] -> ... -> hidden[-1] -> d_output.
        layer_sizes: List = [d_input] + hidden_sizes + [d_output]
        self.linears: nn.ModuleList = nn.ModuleList([
            nn.Linear(size_in, size_out)
            for size_in, size_out in zip(layer_sizes, layer_sizes[1:])
        ])

        # One shared dropout module, referenced once per layer.
        dropout_layer: nn.Dropout = nn.Dropout(dropout_p)
        self.dropout_p: nn.ModuleList = nn.ModuleList(
            [dropout_layer for _ in range(self.n_layers)])

        if batch_norm:
            self.batchnorms: nn.ModuleList = nn.ModuleList(
                [nn.BatchNorm1d(size) for size in hidden_sizes])

    @staticmethod
    def _build_activation(name: str) -> nn.Module:
        """Map an activation name to a freshly built activation module."""
        factories = {
            'relu': nn.ReLU,
            'leakyrelu': lambda: nn.LeakyReLU(0.1),
            'prelu': nn.PReLU,
            'tanh': nn.Tanh,
            'selu': nn.SELU,
            'elu': nn.ELU,
            # nn.Identity instead of a lambda keeps the module picklable.
            'linear': nn.Identity,
        }
        try:
            factory = factories[name]
        except (KeyError, TypeError):
            raise ValueError(
                f"Invalid activation: {name!r}; expected one of "
                f"{sorted(factories)}") from None
        return factory()

    def _normalize(self, index: int, x: torch.Tensor) -> torch.Tensor:
        """Apply the batch norm of hidden layer `index` when enabled."""
        if self.batch_norm:
            return self.batchnorms[index](x)
        return x

    def forward(self, x: torch.Tensor) -> List[Optional[torch.Tensor]]:
        """
        Output Computation for the Customised
        PositionwiseFeedForward layer

        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        List[Optional[torch.Tensor]]
            `[embeddings, output]`. `embeddings` is the output of the last
            hidden linear layer (before batch norm, activation and
            dropout) and is None when there are no hidden layers.
        """
        if self.n_layers == 1:
            if self.dropout_at_input_no_act:
                return [None, self.linears[0](self.dropout_p[0](x))]
            return [
                None,
                self.dropout_p[0](self.activation(self.linears[0](x)))
            ]

        if self.dropout_at_input_no_act:
            x = self.dropout_p[-1](x)

        last_hidden: int = self.n_layers - 2
        for i in range(last_hidden):
            x = self.dropout_p[i](self.activation(
                self._normalize(i, self.linears[i](x))))

        # The raw output of the last hidden layer is exposed as embedding.
        embeddings: torch.Tensor = self.linears[last_hidden](x)
        x = self.dropout_p[last_hidden](self.activation(
            self._normalize(last_hidden, embeddings)))

        output: torch.Tensor = self.linears[-1](x)
        return [embeddings, output]


### pom_mpnn_gnn

In [28]:
import torch.nn as nn
from dgl.nn.pytorch import NNConv
from dgllife.model.gnn import MPNNGNN


class CustomMPNNGNN(MPNNGNN):
    """
    MPNNGNN variant built on top of the ``MPNNGNN`` layer from dgllife.

    Compared to the parent class it exposes two extra knobs:
    -> ``residual``: toggle for a residual connection in the gnn layer
    -> ``message_aggregator_type``: choice of message aggregator

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN
    and returns the updated node representations.
    """

    def __init__(self,
                 node_in_feats: int = 50,
                 edge_in_feats: int = 50,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 num_step_message_passing: int = 6,
                 residual: bool = True,
                 message_aggregator_type: str = 'sum'):
        """
        Parameters
        ----------
        node_in_feats: int
            Size for the input node features. Default to 50.
        edge_in_feats: int
            Size for the input edge features. Default to 50.
        node_out_feats: int
            Size for the output node representations. Default to 64.
        edge_hidden_feats: int
            Size for the hidden edge representations. Default to 128.
        num_step_message_passing: int
            Number of message passing steps. Default to 6.
        residual: bool
            If true, adds residual layer to gnn layer. Default to True.
        message_aggregator_type: str
            message aggregator type, 'sum', 'mean' or 'max'.
            Default to 'sum'.
        """
        super().__init__(node_in_feats=node_in_feats,
                         edge_in_feats=edge_in_feats,
                         node_out_feats=node_out_feats,
                         edge_hidden_feats=edge_hidden_feats,
                         num_step_message_passing=num_step_message_passing)

        # Edge network: maps each edge feature vector to a flattened
        # (node_out_feats x node_out_feats) weight matrix for NNConv.
        edge_to_hidden = nn.Linear(edge_in_feats, edge_hidden_feats)
        hidden_to_weights = nn.Linear(edge_hidden_feats,
                                      node_out_feats * node_out_feats)
        edge_mlp = nn.Sequential(edge_to_hidden, nn.ReLU(), hidden_to_weights)

        # Assigned after the parent constructor so that this layer, with the
        # custom aggregator/residual settings, is the one stored on
        # ``gnn_layer`` (presumably overriding the parent's own layer --
        # TODO confirm against dgllife's MPNNGNN).
        self.gnn_layer = NNConv(in_feats=node_out_feats,
                                out_feats=node_out_feats,
                                edge_func=edge_mlp,
                                aggregator_type=message_aggregator_type,
                                residual=residual)


### mpnnpom model

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union, Optional, Callable, Dict

from deepchem.models.losses import Loss, L2Loss
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.optimizers import Optimizer, LearningRateSchedule

# from openpom.layers.pom_ffn import CustomPositionwiseFeedForward
# from openpom.utils.loss import CustomMultiLabelLoss
# from openpom.utils.optimizer import get_optimizer

try:
    import dgl
    from dgl import DGLGraph
    from dgl.nn.pytorch import Set2Set
    from openpom.layers.pom_mpnn_gnn import CustomMPNNGNN
except (ImportError, ModuleNotFoundError):
    raise ImportError('This module requires dgl and dgllife')


class MPNNPOM(nn.Module):
    """
    MPNN model computes a principal odor map
    using multilabel-classification based on the pre-print:
    "A Principal Odor Map Unifies Diverse Tasks in Human
        Olfactory Perception" [1]

    This model proceeds as follows:

    * Combine latest node representations and edge features in
        updating node representations, which involves multiple
        rounds of message passing.
    * For each graph, compute its representation by radius 0 combination
        to fold atom and bond embeddings together, followed by
        'set2set' or 'global_sum_pooling' readout.
    * Perform the final prediction using a feed-forward layer.

    References
    ----------
    .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
        Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
        Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
        Joel D. Mainland, Alexander B. Wiltschko
        `A Principal Odor Map Unifies Diverse Tasks
        in Human Olfactory Perception preprint
        <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.

    .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
        Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko
        `Machine Learning for Scent:
        Learning Generalizable Perceptual Representations
        of Small Molecules <https://arxiv.org/abs/1910.10685>`_.

    .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
        Oriol Vinyals, George E. Dahl.
        "Neural Message Passing for Quantum Chemistry." ICML 2017.

    Notes
    -----
    This class requires DGL (https://github.com/dmlc/dgl)
    and DGL-LifeSci (https://github.com/awslabs/dgl-lifesci)
    to be installed.
    """

    def __init__(self,
                 n_tasks: int,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'classification',
                 number_atom_features: int = 134,
                 number_bond_features: int = 6,
                 n_classes: int = 1,
                 nfeat_name: str = 'x',
                 efeat_name: str = 'edge_attr',
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: List = [300],
                 ffn_embeddings: Optional[int] = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True):
        """
        Parameters
        ----------
        n_tasks: int
            Number of tasks.
        node_out_feats: int
            The length of the final node representation vectors
            before readout. Default to 64.
        edge_hidden_feats: int
            The length of the hidden edge representation vectors
            for mpnn edge network. Default to 128.
        edge_out_feats: int
            The length of the final edge representation vectors
            before readout. Default to 64.
        num_step_message_passing: int
            The number of rounds of message passing. Default to 3.
        mpnn_residual: bool
            If true, adds residual layer to mpnn layer. Default to True.
        message_aggregator_type: str
            MPNN message aggregator type, 'sum', 'mean' or 'max'.
            Default to 'sum'.
        mode: str
            The model type, 'classification' or 'regression'.
            Default to 'classification'.
        number_atom_features: int
            The length of the initial atom feature vectors. Default to 134.
        number_bond_features: int
            The length of the initial bond feature vectors. Default to 6.
        n_classes: int
            The number of classes to predict per task
            (only used when ``mode`` is 'classification'). Default to 1.
        nfeat_name: str
            For an input graph ``g``, the model assumes that it stores
            node features in ``g.ndata[nfeat_name]`` and will retrieve
            input node features from that. Default to 'x'.
        efeat_name: str
            For an input graph ``g``, the model assumes that it stores
            edge features in ``g.edata[efeat_name]`` and will retrieve
            input edge features from that. Default to 'edge_attr'.
        readout_type: str
            The Readout type, 'set2set' or 'global_sum_pooling'.
            Default to 'set2set'.
        num_step_set2set: int
            Number of steps in set2set readout.
            Used if, readout_type == 'set2set'.
            Default to 6.
        num_layer_set2set: int
            Number of layers in set2set readout.
            Used if, readout_type == 'set2set'.
            Default to 3.
        ffn_hidden_list: List
            List of sizes of hidden layer in the feed-forward network layer.
            The list is copied, never modified. Default to [300].
        ffn_embeddings: Optional[int]
            Size of penultimate layer in the feed-forward network layer.
            This determines the Principal Odor Map dimension.
            If None, no extra layer is appended and ``ffn_hidden_list``
            alone defines the hidden layers. Default to 256.
        ffn_activation: str
            Activation function to be used in feed-forward network layer.
            Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
            'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
            and 'elu' for ELU.
        ffn_dropout_p: float
            Dropout probability for the feed-forward network layer.
            Default to 0.0
        ffn_dropout_at_input_no_act: bool
            If true, dropout is applied on the input tensor.
            For single layer, it is not passed to an activation function.

        Raises
        ------
        ValueError
            If ``mode`` or ``readout_type`` is not one of the supported
            values.
        """
        # Validate configuration up front, before any sub-module is built.
        if mode not in ['classification', 'regression']:
            raise ValueError(
                "mode must be either 'classification' or 'regression'")
        if readout_type not in ['set2set', 'global_sum_pooling']:
            raise ValueError("readout_type invalid")

        super(MPNNPOM, self).__init__()

        self.n_tasks: int = n_tasks
        self.mode: str = mode
        self.n_classes: int = n_classes
        self.nfeat_name: str = nfeat_name
        self.efeat_name: str = efeat_name
        self.readout_type: str = readout_type
        self.ffn_embeddings: Optional[int] = ffn_embeddings
        self.ffn_activation: str = ffn_activation
        self.ffn_dropout_p: float = ffn_dropout_p

        if mode == 'classification':
            self.ffn_output: int = n_tasks * n_classes
        else:
            self.ffn_output = n_tasks

        self.mpnn: nn.Module = CustomMPNNGNN(
            node_in_feats=number_atom_features,
            node_out_feats=node_out_feats,
            edge_in_feats=number_bond_features,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
            residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type)

        self.project_edge_feats: nn.Module = nn.Sequential(
            nn.Linear(number_bond_features, edge_out_feats), nn.ReLU())

        # Width of the per-node vector produced by the radius 0 combination.
        readout_in_feats: int = node_out_feats + edge_out_feats
        if self.readout_type == 'set2set':
            self.readout_set2set: nn.Module = Set2Set(
                input_dim=readout_in_feats,
                n_iters=num_step_set2set,
                n_layers=num_layer_set2set)
            # the FFN input is sized as twice the set2set input dimension
            ffn_input: int = 2 * readout_in_feats
        else:  # 'global_sum_pooling'
            ffn_input = readout_in_feats

        # Copy so the caller's list (or the shared default) is never touched,
        # and so the hidden sizes are defined even if ffn_embeddings is None.
        d_hidden_list: List = list(ffn_hidden_list)
        if ffn_embeddings is not None:
            d_hidden_list.append(ffn_embeddings)

        self.ffn: nn.Module = CustomPositionwiseFeedForward(
            d_input=ffn_input,
            d_hidden_list=d_hidden_list,
            d_output=self.ffn_output,
            activation=ffn_activation,
            dropout_p=ffn_dropout_p,
            dropout_at_input_no_act=ffn_dropout_at_input_no_act)

    def _readout(self, g: "DGLGraph", node_encodings: torch.Tensor,
                 edge_feats: torch.Tensor) -> torch.Tensor:
        """
        Method to execute the readout phase.
        (compute molecules encodings from atom hidden states)

        Readout phase consists of radius 0 combination to fold atom
        and bond embeddings together,
        followed by:
            - a reduce-sum across atoms
                if `self.readout_type == 'global_sum_pooling'`
            - set2set pooling
                if `self.readout_type == 'set2set'`

        Parameters
        ----------
        g: DGLGraph
            A DGLGraph for a batch of graphs. This method writes
            ``g.ndata['node_emb']``, ``g.edata['edge_emb']`` and
            ``g.ndata['src_msg_sum']`` on it.

        node_encodings: torch.Tensor
            Tensor containing node hidden states.

        edge_feats: torch.Tensor
            Tensor containing edge features.

        Returns
        -------
        batch_mol_hidden_states: torch.Tensor
            Tensor containing batchwise molecule encodings.

        Raises
        ------
        ValueError
            If ``self.readout_type`` is not a supported readout.
        """

        g.ndata['node_emb'] = node_encodings
        g.edata['edge_emb'] = self.project_edge_feats(edge_feats)

        def message_func(edges) -> Dict:
            """
            The message function to generate messages
            along the edges for DGLGraph.send_and_recv()
            """
            src_msg: torch.Tensor = torch.cat(
                (edges.src['node_emb'], edges.data['edge_emb']), dim=1)
            return {'src_msg': src_msg}

        def reduce_func(nodes) -> Dict:
            """
            The reduce function to aggregate the messages
            for DGLGraph.send_and_recv()
            """
            src_msg_sum: torch.Tensor = torch.sum(nodes.mailbox['src_msg'],
                                                  dim=1)
            return {'src_msg_sum': src_msg_sum}

        # radius 0 combination to fold atom and bond embeddings together
        g.send_and_recv(g.edges(),
                        message_func=message_func,
                        reduce_func=reduce_func)

        if self.readout_type == 'set2set':
            batch_mol_hidden_states: torch.Tensor = self.readout_set2set(
                g, g.ndata['src_msg_sum'])
        elif self.readout_type == 'global_sum_pooling':
            batch_mol_hidden_states = dgl.sum_nodes(g, 'src_msg_sum')
        else:
            # guards against readout_type being changed after construction
            raise ValueError("readout_type invalid")

        return batch_mol_hidden_states

    def forward(
        self, g: Tuple["DGLGraph", List[str]]
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Forward pass for MPNNPOM class. It also returns embeddings for POM.

        Parameters
        ----------
        g: Tuple[DGLGraph, List[str]]
            ``g[0]`` is a DGLGraph for a batch of graphs. It stores the node
            features in ``dgl_graph.ndata[self.nfeat_name]`` and edge
            features in ``dgl_graph.edata[self.efeat_name]``.
            ``g[1]`` carries the SMILES strings of the batch; it is not
            used by the computation yet.

        Returns
        -------
        Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
            The model output.

        * When self.mode = 'regression',
            a single tensor: the raw feed-forward output.
        * When self.mode = 'classification',
            a tuple ``(proba, logits, embeddings)``.
            ``logits`` is the feed-forward output viewed as
            ``(-1, self.n_tasks, self.n_classes)`` if self.n_tasks > 1,
            or ``(-1, self.n_classes)`` if self.n_tasks is 1.
            ``proba`` is the sigmoid of ``logits`` with the last
            dimension squeezed when self.n_classes is 1.
            ``embeddings`` is the embedding tensor returned by the
            feed-forward layer.
        """
        graph = g[0]
        # NOTE: g[1] (SMILES list) is reserved for a planned 3D
        # representation layer and is intentionally not read here.

        node_feats: torch.Tensor = graph.ndata[self.nfeat_name]
        edge_feats: torch.Tensor = graph.edata[self.efeat_name]

        node_encodings: torch.Tensor = self.mpnn(graph, node_feats,
                                                 edge_feats)

        molecular_encodings: torch.Tensor = self._readout(
            graph, node_encodings, edge_feats)
        if self.readout_type == 'global_sum_pooling':
            molecular_encodings = F.softmax(molecular_encodings, dim=1)

        embeddings: torch.Tensor
        out: torch.Tensor
        embeddings, out = self.ffn(molecular_encodings)

        if self.mode != 'classification':
            return out

        if self.n_tasks == 1:
            logits: torch.Tensor = out.view(-1, self.n_classes)
        else:
            logits = out.view(-1, self.n_tasks, self.n_classes)
        # torch.sigmoid replaces the deprecated F.sigmoid alias
        proba: torch.Tensor = torch.sigmoid(logits)
        if self.n_classes == 1:
            proba = proba.squeeze(-1)
        # logits (pre-sigmoid) are returned as well, for the loss
        return proba, logits, embeddings


class MPNNPOMModel(TorchModel):
    """
    MPNNPOMModel for obtaining a principal odor map
    using multilabel-classification based on the pre-print:
    "A Principal Odor Map Unifies Diverse Tasks in Human
        Olfactory Perception" [1]

    * Combine latest node representations and edge features in
        updating node representations, which involves multiple
        rounds of message passing.
    * For each graph, compute its representation by radius 0 combination
        to fold atom and bond embeddings together, followed by
        'set2set' or 'global_sum_pooling' readout.
    * Perform the final prediction using a feed-forward layer.

    References
    ----------
    .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
        Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
        Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
        Joel D. Mainland, Alexander B. Wiltschko
        `A Principal Odor Map Unifies Diverse Tasks
        in Human Olfactory Perception preprint
        <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.

    .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
        Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko
        `Machine Learning for Scent:
        Learning Generalizable Perceptual Representations
        of Small Molecules <https://arxiv.org/abs/1910.10685>`_.

    .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
        Oriol Vinyals, George E. Dahl.
        "Neural Message Passing for Quantum Chemistry." ICML 2017.

    Notes
    -----
    This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
    (https://github.com/awslabs/dgl-lifesci) to be installed.

    Each input sample is expected to be a pair whose first element is a
    Deepchem GraphData object (with both 'edge' and 'node' features) and
    whose second element is the molecule's SMILES string; see
    ``_prepare_batch``.
    """

    def __init__(self,
                 n_tasks: int,
                 class_imbalance_ratio: Optional[List] = None,
                 sample_weights: Optional[List] = None,
                 loss_aggr_type: str = 'sum',
                 learning_rate: Union[float, LearningRateSchedule] = 0.001,
                 batch_size: int = 100,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'regression',
                 number_atom_features: int = 134,
                 number_bond_features: int = 6,
                 n_classes: int = 1,
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: List = [300],
                 ffn_embeddings: int = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True,
                 weight_decay: float = 1e-5,
                 self_loop: bool = False,
                 optimizer_name: str = 'adam',
                 device_name: Optional[str] = None,
                 **kwargs):
        """
        Parameters
        ----------
        n_tasks: int
            Number of tasks.
        class_imbalance_ratio: Optional[List]
            List of imbalance ratios per task. When given (and non-empty)
            its length must equal ``n_tasks``.
        sample_weights: Optional[List]
            Forwarded unchanged to CustomMultiLabelLoss
            (classification mode only).
        loss_aggr_type: str
            loss aggregation type; 'sum' or 'mean'. Default to 'sum'.
            Only applies to CustomMultiLabelLoss for classification
        learning_rate: Union[float, LearningRateSchedule]
            Learning rate value or scheduler object. Default to 0.001.
        batch_size: int
            Batch size for training. Default to 100.
        node_out_feats: int
            The length of the final node representation vectors
            before readout. Default to 64.
        edge_hidden_feats: int
            The length of the hidden edge representation vectors
            for mpnn edge network. Default to 128.
        edge_out_feats: int
            The length of the final edge representation vectors
            before readout. Default to 64.
        num_step_message_passing: int
            The number of rounds of message passing. Default to 3.
        mpnn_residual: bool
            If true, adds residual layer to mpnn layer. Default to True.
        message_aggregator_type: str
            MPNN message aggregator type, 'sum', 'mean' or 'max'.
            Default to 'sum'.
        mode: str
            The model type, 'classification' or 'regression'.
            Default to 'regression'.
        number_atom_features: int
            The length of the initial atom feature vectors. Default to 134.
        number_bond_features: int
            The length of the initial bond feature vectors. Default to 6.
        n_classes: int
            The number of classes to predict per task
            (only used when ``mode`` is 'classification'). Default to 1.
        readout_type: str
            The Readout type, 'set2set' or 'global_sum_pooling'.
            Default to 'set2set'.
        num_step_set2set: int
            Number of steps in set2set readout.
            Used if, readout_type == 'set2set'.
            Default to 6.
        num_layer_set2set: int
            Number of layers in set2set readout.
            Used if, readout_type == 'set2set'.
            Default to 3.
        ffn_hidden_list: List
            List of sizes of hidden layer in the feed-forward network layer.
            Default to [300].
        ffn_embeddings: int
            Size of penultimate layer in the feed-forward network layer.
            This determines the Principal Odor Map dimension.
            Default to 256.
        ffn_activation: str
            Activation function to be used in feed-forward network layer.
            Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
            'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
            and 'elu' for ELU.
        ffn_dropout_p: float
            Dropout probability for the feed-forward network layer.
            Default to 0.0
        ffn_dropout_at_input_no_act: bool
            If true, dropout is applied on the input tensor.
            For single layer, it is not passed to an activation function.
        weight_decay: float
            weight decay value for L1 and L2 regularization. Default to 1e-5.
        self_loop: bool
            Whether to add self loops for the nodes, i.e. edges
            from nodes to themselves. Generally, an MPNNPOMModel
            does not require self loops. Default to False.
        optimizer_name: str
            Name of optimizer to be used from
            [adam, adagrad, adamw, sparseadam, rmsprop, sgd, kfac]
            Default to 'adam'.
        device_name: Optional[str]
            The device on which to run computations. If None, a device is
            chosen automatically.
        kwargs
            This can include any keyword argument of TorchModel.

        Raises
        ------
        ValueError
            If ``class_imbalance_ratio`` is given and its length differs
            from ``n_tasks``.
        """
        # Cheap argument check first, before any network is allocated.
        if class_imbalance_ratio and (len(class_imbalance_ratio) != n_tasks):
            raise ValueError(
                "size of class_imbalance_ratio should be equal to n_tasks")

        model: nn.Module = MPNNPOM(
            n_tasks=n_tasks,
            node_out_feats=node_out_feats,
            edge_hidden_feats=edge_hidden_feats,
            edge_out_feats=edge_out_feats,
            num_step_message_passing=num_step_message_passing,
            mpnn_residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type,
            mode=mode,
            number_atom_features=number_atom_features,
            number_bond_features=number_bond_features,
            n_classes=n_classes,
            readout_type=readout_type,
            num_step_set2set=num_step_set2set,
            num_layer_set2set=num_layer_set2set,
            ffn_hidden_list=ffn_hidden_list,
            ffn_embeddings=ffn_embeddings,
            ffn_activation=ffn_activation,
            ffn_dropout_p=ffn_dropout_p,
            ffn_dropout_at_input_no_act=ffn_dropout_at_input_no_act)

        if mode == 'regression':
            loss: Loss = L2Loss()
            output_types: List = ['prediction']
        else:
            loss = CustomMultiLabelLoss(
                class_imbalance_ratio=class_imbalance_ratio,
                loss_aggr_type=loss_aggr_type,
                sample_weights=sample_weights,
                device=device_name)
            # matches the (proba, logits, embeddings) tuple returned by
            # MPNNPOM.forward in classification mode
            output_types = ['prediction', 'loss', 'embedding']

        # get_optimizer builds the optimizer with its own default settings,
        # so the requested learning rate is set on it explicitly.
        optimizer: Optimizer = get_optimizer(optimizer_name)
        optimizer.learning_rate = learning_rate

        device: Optional[torch.device] = (torch.device(device_name)
                                          if device_name is not None else
                                          None)

        super(MPNNPOMModel, self).__init__(model,
                                           loss=loss,
                                           output_types=output_types,
                                           optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           batch_size=batch_size,
                                           device=device,
                                           **kwargs)

        self.weight_decay: float = weight_decay
        self._self_loop: bool = self_loop
        self.regularization_loss: Callable = self._regularization_loss

    def _regularization_loss(self) -> torch.Tensor:
        """
        L1 and L2-norm losses for regularization

        The norms of every parameter whose name does not contain 'bias'
        are accumulated and scaled by ``self.weight_decay``.

        Returns
        -------
        torch.Tensor
            sum of l1_norm and l2_norm
        """
        l1_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        l2_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        for name, param in self.model.named_parameters():
            if 'bias' in name:
                continue
            l1_regularization = l1_regularization + torch.norm(param, p=1)
            l2_regularization = l2_regularization + torch.norm(param, p=2)
        l1_norm: torch.Tensor = self.weight_decay * l1_regularization
        l2_norm: torch.Tensor = self.weight_decay * l2_regularization
        return l1_norm + l2_norm

    def _prepare_batch(
        self, batch: Tuple[List, List, List]
    ) -> Tuple[Tuple[DGLGraph, List[str]], List[torch.Tensor],
               List[torch.Tensor]]:
        """Create batch data for MPNN.

        Parameters
        ----------
        batch: Tuple[List, List, List]
            The tuple is ``(inputs, labels, weights)``. Every element of
            ``inputs[0]`` is a pair ``(graph_data, smiles)``.

        Returns
        -------
        (g, smiles): Tuple[DGLGraph, List[str]]
            ``g`` is the DGLGraph for the batch of graphs, moved to
            ``self.device``; ``smiles`` holds the second element of each
            sample, in batch order.
        labels: list of torch.Tensor or None
            The graph labels.
        weights: list of torch.Tensor or None
            The weights for each sample or
            sample/task pair converted to torch.Tensor.
        """
        inputs: List
        labels: List
        weights: List

        inputs, labels, weights = batch
        samples = inputs[0]

        # each sample is (graph_data, smiles); the SMILES string is kept so
        # the model can later look the molecule up in another dataset
        smiles: List[str] = [sample[1] for sample in samples]
        dgl_graphs: List[DGLGraph] = [
            sample[0].to_dgl_graph(self_loop=self._self_loop)
            for sample in samples
        ]
        g: DGLGraph = dgl.batch(dgl_graphs).to(self.device)

        # let the parent convert labels and weights only; the inputs were
        # already handled above
        _, labels, weights = super(MPNNPOMModel, self)._prepare_batch(
            ([], labels, weights))
        return (g, smiles), labels, weights


### Training

In [30]:
# Odor descriptor labels used as prediction tasks, one group per initial
# letter; the order is the same as the label columns expected by the loader.
TASKS = [
    # a
    'alcoholic', 'aldehydic', 'alliaceous', 'almond', 'amber', 'animal',
    'anisic', 'apple', 'apricot', 'aromatic',
    # b
    'balsamic', 'banana', 'beefy', 'bergamot', 'berry', 'bitter',
    'black currant', 'brandy', 'burnt', 'buttery',
    # c
    'cabbage', 'camphoreous', 'caramellic', 'cedar', 'celery', 'chamomile',
    'cheesy', 'cherry', 'chocolate', 'cinnamon', 'citrus', 'clean', 'clove',
    'cocoa', 'coconut', 'coffee', 'cognac', 'cooked', 'cooling', 'cortex',
    'coumarinic', 'creamy', 'cucumber',
    # d, e
    'dairy', 'dry',
    'earthy', 'ethereal',
    # f
    'fatty', 'fermented', 'fishy', 'floral', 'fresh', 'fruit skin', 'fruity',
    # g
    'garlic', 'gassy', 'geranium', 'grape', 'grapefruit', 'grassy', 'green',
    # h
    'hawthorn', 'hay', 'hazelnut', 'herbal', 'honey', 'hyacinth',
    # j, k
    'jasmin', 'juicy',
    'ketonic',
    # l
    'lactonic', 'lavender', 'leafy', 'leathery', 'lemon', 'lily',
    # m
    'malty', 'meaty', 'medicinal', 'melon', 'metallic', 'milky', 'mint',
    'muguet', 'mushroom', 'musk', 'musty',
    # n
    'natural', 'nutty',
    # o
    'odorless', 'oily', 'onion', 'orange', 'orangeflower', 'orris', 'ozone',
    # p
    'peach', 'pear', 'phenolic', 'pine', 'pineapple', 'plum', 'popcorn',
    'potato', 'powdery', 'pungent',
    # r
    'radish', 'raspberry', 'ripe', 'roasted', 'rose', 'rummy',
    # s
    'sandalwood', 'savory', 'sharp', 'smoky', 'soapy', 'solvent', 'sour',
    'spicy', 'strawberry', 'sulfurous', 'sweaty', 'sweet',
    # t
    'tea', 'terpenic', 'tobacco', 'tomato', 'tropical',
    # v
    'vanilla', 'vegetable', 'vetiver', 'violet',
    # w
    'warm', 'waxy', 'weedy', 'winey', 'woody',
]
print("No of tasks: ", len(TASKS))

No of tasks:  138


In [31]:
import csv
# download curated dataset
# !wget https://raw.githubusercontent.com/ARY2260/openpom/main/openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv

# The curated dataset can also be found at `openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv` in the repo.

# Path to the curated dataset CSV. NOTE(review): this is a machine-specific absolute path -- point it at your local copy or the freshly downloaded file.
input_file = '/home/stephen/openpom/openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv' # or new downloaded file path
# dataset3d ={}
# with open(input_file, mode='r', newline='') as file:
#     csv_reader = csv.DictReader(file)
#     for row in csv_reader:
#         smiles = row['nonStereoSMILES']
#         # Parse the coordinates string into a list of lists of floats
#         coordinates = [] 
#         # [list(map(float, coord.strip().split(','))) for coord in row['Coordinates'].split('\n')]
#         for coord in row['Coordinates'].split('\n'):
#             try:
#                 coordinates.append(list(map(float,coord.strip().split(','))))
#             except:
#                 for atom in Chem.MolFromSmiles(smiles).GetAtoms():
#                     coordinates.append([0,0,0])
            
#         # Add the SMILES string as the key and coordinates as the value
#         dataset3d[smiles] = coordinates
# def getAtom3dInfo(smiles,atomIdx):
#     return dataset3d[smiles][atomIdx]
# print(len(dataset3d['CC(O)CN']))

# print(len(dataset3d['CCCCCCCCC(=O)NCc1ccc(O)c(OC)c1']))


In [32]:
# Build the dataset: featurize the SMILES column of the curated CSV and
# attach one label column per entry in TASKS.
smiles_field = 'nonStereoSMILES'
featurizer = GraphFeaturizer()
loader = dc.data.CSVLoader(
    tasks=TASKS,
    feature_field=smiles_field,
    featurizer=featurizer,
)
dataset = loader.create_dataset(inputs=[input_file])

# number of tasks as reported by the loaded dataset
n_tasks = len(dataset.tasks)

In [33]:
# Sanity check: show the task count of the loaded dataset.
print(n_tasks)
# Second element of the first sample; the recorded cell output shows it is
# the SMILES string. Left as a bare expression so the notebook displays it.
dataset.X[0][1]

138


'CC(O)CN'

In [34]:
# Split into train / validation / test sets with a stratified random
# splitter. The test fraction is 0, so the test split comes back empty.
randomstratifiedsplitter = dc.splits.RandomStratifiedSplitter()
train_dataset, valid_dataset, test_dataset = (
    randomstratifiedsplitter.train_valid_test_split(
        dataset,
        frac_train=0.984,
        frac_valid=0.016,
        frac_test=0,
        seed=3,
    ))

In [35]:
# Report the number of samples in each split.
for split_label, split_dataset in (
    ("train_dataset: ", train_dataset),
    ("valid_dataset: ", valid_dataset),
    ("test_dataset: ", test_dataset),
):
    print(split_label, len(split_dataset))

train_dataset:  4895
valid_dataset:  88
test_dataset:  0


In [36]:
import requests
# def get_smiles(cid):
#     url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
#     response = requests.get(url)
#     if response.status_code == 200:
#         data = response.json()
#         try:
#             smiles = data['PropertyTable']['Properties'][0]['CanonicalSMILES']
#             print(smiles)
#             return smiles
#         except (IndexError, KeyError):
#             return None
#     else:
#         return None
# intersection_cid_list = [62465, 5634, 7685, 1031, 1032, 31244, 527, 31249, 31252, 15380, 7710, 31265, 7714, 31266, 1060, 4133, 6184, 31272, 31276, 1068, 14896, 7731, 12348, 7749, 12367, 6736, 7762, 61016, 7770, 8797, 1140, 7284, 7799, 1146, 10882, 6276, 10890, 8842, 8857, 1183, 7848, 176, 7344, 177, 180, 22201, 8892, 10430, 3776, 7363, 2758, 17100, 7888, 10448, 61138, 8918, 7895, 26331, 10976, 240, 7921, 3314, 7410, 12020, 11509, 11002, 261, 263, 264, 61199, 798, 7966, 11552, 7460, 22311, 6448, 307, 7991, 2879, 10560, 323, 326, 6989, 24915, 8025, 22873, 8030, 7519, 6501, 91497, 19310, 8048, 22386, 8051, 11124, 20859, 8063, 62336, 9609, 18827, 6544, 12178, 7059, 12180, 6549, 2969, 8091, 7583, 6561, 11173, 6054, 8103, 7590, 6569, 7600, 8118, 6584, 8635, 957, 6590, 8129, 8130, 12741, 454, 12232, 460, 14286, 7119, 7632, 62433, 10722, 997, 999, 62444, 7150, 8186, 7165, 6654]
# data = {'cid': intersection_cid_list, 'smiles': [get_smiles(cid) for cid in intersection_cid_list]}
# df = pd.DataFrame(data)

# # Save the DataFrame to a new CSV file
# output_file_path = 'cids_with_smiles.csv'  # Replace with your desired output file path
# df.to_csv(output_file_path, index=False)
# CSV that provides a 'CID' column and a 'nonStereoSMILES' column (both are read
# by the filtering code that follows).
file_path = '162_with_smiles.csv'  # Replace with your actual file path

# Read the CSV file into a pandas DataFrame
df = pd.read_csv(file_path)

# Compound IDs to select from `df` (presumably PubChem CIDs, as in the
# commented-out PubChem lookup above -- TODO confirm). Only used to filter `df`.
intersection_cid_list = [62465, 5634, 7685, 1031, 1032, 31244, 527, 31249, 31252, 15380, 7710, 31265, 7714, 31266, 1060, 4133, 6184, 31272, 31276, 1068, 7731, 12348, 7749, 12367, 6736, 7762, 61016, 7770, 8797, 612, 1127, 1140, 7284, 7799, 1146, 10882, 6276, 10890, 8842, 8857, 8858, 1183, 7848, 176, 7344, 177, 179, 180, 22201, 8892, 10430, 3776, 7363, 2758, 17100, 7888, 10448, 61138, 8918, 7895, 26331, 10976, 240, 7921, 3314, 7410, 12020, 11509, 11002, 261, 263, 264, 61199, 798, 7966, 11552, 7460, 7463, 307, 7991, 2879, 10560, 323, 326, 6989, 24915, 8025, 22873, 8030, 7519, 6501, 91497, 19310, 8048, 22386, 8051, 11124, 20859, 8063, 62336, 9609, 18827, 6544, 12178, 7059, 12180, 2969, 8091, 7583, 6561, 11173, 6054, 8103, 6569, 7600, 8118, 6584, 8635, 957, 6590, 8129, 8130, 12741, 454, 12232, 460, 14286, 7119, 7632, 62433, 10722, 7654, 999, 62444, 7150, 8186, 7165, 6654, 650]

# Filter the DataFrame to only include rows with CIDs in the list
filtered_df = df[df['CID'].isin(intersection_cid_list)]

# Create a set of SMILES strings from the filtered DataFrame
intersection_smiles_set = set(filtered_df['nonStereoSMILES'])

# Overlap check: print True once for every validation molecule that also belongs
# to the CID-selected set.
# valid_dataset.X[:, 1] holds SMILES strings, so membership has to be tested
# against the SMILES set; the CID list contains integers and could never match.
# The printed value must be the Python literal `True` (`true` is an undefined name).
for elt in valid_dataset.X[:,1]:
    if elt in intersection_smiles_set:
        print(True)

In [37]:
# Class-imbalance ratios computed from the training split; one entry per task is
# expected (checked by the assert below).
train_ratios = get_class_imbalance_ratio(train_dataset)

# Sample weights for the training split. get_sample_weights is defined elsewhere
# in this notebook; presumably it treats molecules whose SMILES are in
# intersection_smiles_set differently -- TODO confirm against its definition.
sample_weights = get_sample_weights(train_dataset, intersection_smiles_set)
assert len(train_ratios) == n_tasks

In [38]:
# A constant learning rate is used for the single-model run; the decaying
# schedule is kept below, commented out, for reference.
# learning_rate = ExponentialDecay(initial_rate=0.001, decay_rate=0.5, decay_steps=32*15, staircase=True)
learning_rate = 0.001

In [44]:
# Build the MPNN-POM classifier for the single-model run. Checkpoints go to
# `model_dir`; training is placed on the CUDA device.
model = MPNNPOMModel(
    # tasks, batching, loss weighting
    n_tasks=n_tasks,
    batch_size=128,
    learning_rate=learning_rate,
    class_imbalance_ratio=train_ratios,
    sample_weights=sample_weights,
    loss_aggr_type='sum',
    # MPNN settings
    node_out_feats=100,
    edge_hidden_feats=75,
    edge_out_feats=100,
    num_step_message_passing=5,
    mpnn_residual=True,
    message_aggregator_type='sum',
    mode='classification',
    number_atom_features=GraphConvConstants.ATOM_FDIM,
    number_bond_features=GraphConvConstants.BOND_FDIM,
    n_classes=1,
    # set2set readout settings
    readout_type='set2set',
    num_step_set2set=3,
    num_layer_set2set=2,
    # FFN settings
    ffn_hidden_list=[392, 392],
    ffn_embeddings=256,
    ffn_activation='relu',
    ffn_dropout_p=0.12,
    ffn_dropout_at_input_no_act=False,
    # optimizer, logging, storage, device
    weight_decay=1e-5,
    self_loop=False,
    optimizer_name='adam',
    log_frequency=32,
    model_dir='./examples/experiments',
    device_name='cuda',
)

In [45]:
# Number of epochs for the per-epoch training loop that follows.
nb_epoch = 62

In [46]:
# ROC-AUC wrapped as a DeepChem Metric; model.evaluate reports it under the
# 'roc_auc_score' key.
metric = dc.metrics.Metric(dc.metrics.roc_auc_score)

In [47]:
# Train one epoch per fit() call so that train/validation ROC-AUC can be printed
# after every epoch; a final checkpoint is written once the loop is done.
start_time = datetime.now()
for epoch in range(1, nb_epoch + 1):
    # From the second epoch on, ask fit() to restore from the latest checkpoint
    # before continuing.
    resume_from_checkpoint = epoch > 1
    loss = model.fit(train_dataset,
                     nb_epoch=1,
                     max_checkpoints_to_keep=1,
                     deterministic=False,
                     restore=resume_from_checkpoint)
    train_scores = model.evaluate(train_dataset, [metric])['roc_auc_score']
    valid_scores = model.evaluate(valid_dataset, [metric])['roc_auc_score']
    print(f"epoch {epoch}/{nb_epoch} ; loss = {loss}; train_scores = {train_scores}; valid_scores = {valid_scores}")
model.save_checkpoint()
end_time = datetime.now()

epoch 1/62 ; loss = 3.5446649278913225; train_scores = 0.6615641030279492; valid_scores = 0.6750207066533
epoch 2/62 ; loss = 3.403613771711077; train_scores = 0.7614383626863442; valid_scores = 0.749537854255211
epoch 3/62 ; loss = 3.355686369396391; train_scores = 0.796350771055595; valid_scores = 0.8005994142277463
epoch 4/62 ; loss = 3.286780221121652; train_scores = 0.8165318979458897; valid_scores = 0.798987610950277
epoch 5/62 ; loss = 2.654901663462321; train_scores = 0.8212569157806578; valid_scores = 0.813355065731547
epoch 6/62 ; loss = 2.9323192596435548; train_scores = 0.8530324564436446; valid_scores = 0.848380044966851
epoch 7/62 ; loss = 3.049944484935087; train_scores = 0.8612312968466114; valid_scores = 0.8525586500703933
epoch 8/62 ; loss = 3.037635167439779; train_scores = 0.8584524864998806; valid_scores = 0.8443205980565839
epoch 9/62 ; loss = 3.0049480315177672; train_scores = 0.8719402762248631; valid_scores = 0.858350429841117
epoch 10/62 ; loss = 2.78899828592

In [48]:

# test_scores = model.evaluate(test_dataset, [metric])['roc_auc_score']
# pred = model.predict(test_dataset)
# print(pred.shape)

# challenge_dataset_input = "/home/stephen/openpom/pred_percept_single160.csv"


# # Load the CSV file into a pandas DataFrame
# file_path = challenge_dataset_input  # Replace with your actual file path
# df = pd.read_csv(file_path)

# # Function to get SMILES from CID using PubChem API
# def get_smiles_from_cid(cid):
#     url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
#     response = requests.get(url)
#     if response.status_code == 200:
#         data = response.json()
#         try:
#             smiles = data['PropertyTable']['Properties'][0]['CanonicalSMILES']
#             return smiles
#         except (IndexError, KeyError):
#             return None
#     else:
#         return None

# # Apply the function to each CID in the DataFrame and create a new column for SMILES
# # print(df.columns)
# df['nonStereoSMILES'] = df['Unnamed: 0'].apply(get_smiles_from_cid)

# # Save the DataFrame to a new CSV file
# NOTE: despite its name, this path is only read from in the active code below:
# it is the CSV (with a 'nonStereoSMILES' column) that the challenge dataset is
# loaded from and that the prediction exports re-read before adding a column.
output_file_path = '/home/stephen/openpom/162_with_smiles.csv'  # Replace with your desired output file path
# df.to_csv(output_file_path, index=False)
# print(model.get_model().get_ffn().last2Layer_embeddings()[0].shape)
# print("time_taken: ", str(end_time-start_time))
# print("test_score: ", test_scores)

In [50]:

# Label columns of the challenge CSV.
challenge_TASKS = [
    'INTENSITY', 'PLEASANTNESS', 'BAKERY',
    'SWEET', 'FRUIT', 'FISH',
    'GARLIC', 'SPICES', 'COLD',
    'SOUR', 'BURNT', 'ACID',
    'WARM', 'MUSKY', 'SWEATY',
    'AMMONIA', 'DECAYED', 'WOOD',
    'GRASS', 'FLOWER', 'CHEMICAL',
]
print('No. of TASKS: ', len(challenge_TASKS))

# Featurize the challenge molecules the same way as the training data.
smiles_field = 'nonStereoSMILES'
featurizer = GraphFeaturizer()
cloader = dc.data.CSVLoader(
    tasks=challenge_TASKS,
    feature_field=smiles_field,
    featurizer=featurizer,
)
challenge_dataset = cloader.create_dataset(inputs=[output_file_path])

# Run the trained single model on the challenge molecules.
pred_160 = model.predict(challenge_dataset)
# print(pred_160[0])
# print(challenge_dataset.X)


No. of TASKS:  21


In [51]:
import json

# Re-read the challenge CSV and attach, per row, the model output serialized as
# a JSON list, then write the augmented table to a new CSV.
df = pd.read_csv(output_file_path)
prediction_column = [json.dumps(task_scores.tolist()) for task_scores in pred_160]
df['prediction'] = prediction_column

output_path = 'output_with_predictions_weighted.csv'  # Replace with your desired output file path
df.to_csv(output_path, index=False)

In [58]:
# import deepchem as dc
# from openpom.feat.graph_featurizer import GraphFeaturizer, GraphConvConstants
# from openpom.utils.data_utils import get_class_imbalance_ratio, IterativeStratifiedSplitter
# from openpom.models.mpnn_pom import MPNNPOMModel
# from datetime import datetime
# from tqdm import tqdm
# import torch
# import numpy as np
# from sklearn.metrics import roc_auc_score

In [59]:
# Label column names used as the prediction tasks when the curated CSV is loaded
# (passed as `tasks=` to the CSV loader).
TASKS = [
'alcoholic', 'aldehydic', 'alliaceous', 'almond', 'amber', 'animal',
'anisic', 'apple', 'apricot', 'aromatic', 'balsamic', 'banana', 'beefy',
'bergamot', 'berry', 'bitter', 'black currant', 'brandy', 'burnt',
'buttery', 'cabbage', 'camphoreous', 'caramellic', 'cedar', 'celery',
'chamomile', 'cheesy', 'cherry', 'chocolate', 'cinnamon', 'citrus', 'clean',
'clove', 'cocoa', 'coconut', 'coffee', 'cognac', 'cooked', 'cooling',
'cortex', 'coumarinic', 'creamy', 'cucumber', 'dairy', 'dry', 'earthy',
'ethereal', 'fatty', 'fermented', 'fishy', 'floral', 'fresh', 'fruit skin',
'fruity', 'garlic', 'gassy', 'geranium', 'grape', 'grapefruit', 'grassy',
'green', 'hawthorn', 'hay', 'hazelnut', 'herbal', 'honey', 'hyacinth',
'jasmin', 'juicy', 'ketonic', 'lactonic', 'lavender', 'leafy', 'leathery',
'lemon', 'lily', 'malty', 'meaty', 'medicinal', 'melon', 'metallic',
'milky', 'mint', 'muguet', 'mushroom', 'musk', 'musty', 'natural', 'nutty',
'odorless', 'oily', 'onion', 'orange', 'orangeflower', 'orris', 'ozone',
'peach', 'pear', 'phenolic', 'pine', 'pineapple', 'plum', 'popcorn',
'potato', 'powdery', 'pungent', 'radish', 'raspberry', 'ripe', 'roasted',
'rose', 'rummy', 'sandalwood', 'savory', 'sharp', 'smoky', 'soapy',
'solvent', 'sour', 'spicy', 'strawberry', 'sulfurous', 'sweaty', 'sweet',
'tea', 'terpenic', 'tobacco', 'tomato', 'tropical', 'vanilla', 'vegetable',
'vetiver', 'violet', 'warm', 'waxy', 'weedy', 'winey', 'woody'
]

# n_tasks is a module-level name that later cells read directly.
print("No of tasks: ", len(TASKS))
n_tasks = len(TASKS)

No of tasks:  138


In [60]:
# Run this cell only when the k-fold splits have not been written to disk yet.
#
# The curated dataset can be downloaded with
# !wget https://raw.githubusercontent.com/ARY2260/openpom/main/openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv
# or taken from `openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv` in the repo.

input_file = '/home/stephen/openpom/openpom/data/curated_datasets/curated_GS_LF_merged_4983.csv' # or new downloaded file path

# Featurize the curated CSV into a DeepChem dataset.
smiles_field = 'nonStereoSMILES'
featurizer = GraphFeaturizer()
loader = dc.data.CSVLoader(
    tasks=TASKS,
    feature_field=smiles_field,
    featurizer=featurizer,
)
dataset = loader.create_dataset(inputs=[input_file])
n_tasks = len(dataset.tasks)

# Split into k folds. `directories` interleaves the output paths per fold:
# index 2*fold is the training data, index 2*fold + 1 the held-out (cv) data.
k = 5
splitter = IterativeStratifiedSplitter(order=2)
directories = []
for fold in range(k):
    directories.append(f'./ensemble_cv_exp/fold_{fold+1}/train_data')
    directories.append(f'./ensemble_cv_exp/fold_{fold+1}/cv_data')
folds_list = splitter.k_fold_split(dataset=dataset, k=k, directories=directories)

In [61]:
def _ensemble_member_dir(fold, member):
    """Return the model directory of ensemble member ``member`` for ``fold`` (both 0-based)."""
    return f'./ensemble_cv_exp/ensemble_fold_{fold+1}/experiments_{member+1}'


def _build_ensemble_member(fold, member, learning_rate, train_ratios):
    """
    Construct one MPNNPOMModel of the ensemble.

    Both the training loop and the restore loop build their models through this
    helper, so the two phases cannot drift apart in hyper-parameters or in the
    model directory. Reads the module-level ``n_tasks``.
    """
    return MPNNPOMModel(n_tasks = n_tasks,
                        batch_size=128,
                        learning_rate=learning_rate,
                        class_imbalance_ratio = train_ratios,
                        loss_aggr_type = 'sum',
                        node_out_feats = 100,
                        edge_hidden_feats = 75,
                        edge_out_feats = 100,
                        num_step_message_passing = 5,
                        mpnn_residual = True,
                        message_aggregator_type = 'sum',
                        mode = 'classification',
                        number_atom_features = GraphConvConstants.ATOM_FDIM,
                        number_bond_features = GraphConvConstants.BOND_FDIM,
                        n_classes = 1,
                        readout_type = 'set2set',
                        num_step_set2set = 3,
                        num_layer_set2set = 2,
                        ffn_hidden_list= [392, 392],
                        ffn_embeddings = 256,
                        ffn_activation = 'relu',
                        ffn_dropout_p = 0.12,
                        ffn_dropout_at_input_no_act = False,
                        weight_decay = 1e-5,
                        self_loop = False,
                        optimizer_name = 'adam',
                        log_frequency = 32,
                        model_dir = _ensemble_member_dir(fold, member),
                        device_name='cuda')


def benchmark_ensemble(fold, train_dataset, test_dataset, n_models, nb_epoch,challenge_dataset):
    """
    Train an ensemble of MPNN-POM models on one fold and score the averaged
    predictions.

    Each of the ``n_models`` models is trained from scratch on ``train_dataset``
    and checkpointed in its own directory. The models are then restored one by
    one, their predictions on ``test_dataset`` and ``challenge_dataset`` are
    collected, and the per-model predictions are averaged.

    Parameters
    ---------
    fold: int
      Zero-based fold index; only used to build the checkpoint directories.
    train_dataset
      Dataset the models are fitted on.
    test_dataset
      Held-out dataset; its ``y`` is compared with the averaged predictions.
    n_models: int
      Number of ensemble members, must be at least 1.
    nb_epoch: int
      Number of epochs passed to each ``model.fit`` call.
    challenge_dataset
      Additional dataset that is only predicted on, never scored.

    Returns
    -------
    tuple
      ``(score, ensemble_cpreds)``: the macro-averaged ROC-AUC of the mean
      prediction on ``test_dataset``, and the mean prediction on
      ``challenge_dataset``.

    Raises
    ------
    ValueError
      If ``n_models`` is smaller than 1.

    Notes
    -----
    Reads the module-level ``n_tasks`` and builds the models for the 'cuda' device.
    """
    if n_models < 1:
        # Without any member there is nothing to average; fail before training
        # instead of producing a mean over an empty array.
        raise ValueError(f"n_models must be at least 1, got {n_models}")

    train_ratios = get_class_imbalance_ratio(train_dataset)
    # NOTE(review): sample_weights is computed but never passed to the models
    # built below -- confirm whether the ensemble is meant to use sample weighting.
    sample_weights = get_sample_weights(train_dataset, seed=42)
    assert len(train_ratios) == n_tasks

    # learning_rate = 0.001
    learning_rate = dc.models.optimizers.ExponentialDecay(initial_rate=0.001, decay_rate=0.5, decay_steps=32*20, staircase=True)
    metric = dc.metrics.Metric(dc.metrics.roc_auc_score)

    # fit models
    for i in tqdm(range(n_models)):
        model = _build_ensemble_member(fold, i, learning_rate, train_ratios)

        start_time = datetime.now()

        # fit model
        loss = model.fit(
            train_dataset,
            nb_epoch=nb_epoch,
            max_checkpoints_to_keep=1,
            deterministic=False,
            restore=False)
        end_time = datetime.now()

        train_scores = model.evaluate(train_dataset, [metric])['roc_auc_score']
        test_scores = model.evaluate(test_dataset, [metric])['roc_auc_score']
        print(f"loss = {loss}; train_scores = {train_scores}; test_scores = {test_scores}; time_taken = {str(end_time-start_time)}")
        # Per the original author's note, after this extra save the member's
        # directory contains `checkpoint2.pt`, which the restore loop below loads.
        model.save_checkpoint()
        # Drop the model and release cached GPU memory before training the next member.
        del model
        torch.cuda.empty_cache()

    # Get test score from the ensemble
    list_preds = []
    # challenge predicts
    list_cpreds = []
    for i in range(n_models):
        model = _build_ensemble_member(fold, i, learning_rate, train_ratios)
        model.restore(f"{_ensemble_member_dir(fold, i)}/checkpoint2.pt")
        # test_scores = model.evaluate(test_dataset, [metric])['roc_auc_score']
        # print("test_score: ", test_scores)
        list_preds.append(model.predict(test_dataset))
        list_cpreds.append(model.predict(challenge_dataset))

    # Average over the leading (ensemble-member) axis.
    ensemble_preds = np.mean(np.asarray(list_preds), axis=0)
    ensemble_cpreds = np.mean(np.asarray(list_cpreds), axis=0)
    return (roc_auc_score(test_dataset.y, ensemble_preds, average="macro"), ensemble_cpreds)

In [62]:
# Train and score one ensemble per fold, collecting each fold's ensemble score
# and its averaged challenge-set predictions.
n_models = 10
nb_epoch = 62
folds_results, folds_ensembles = [], []
for fold in tqdm(range(k)):
    print(f"Fold {fold+1} ensemble starting now.")
    # `directories` interleaves paths per fold: train data first, then cv data.
    train_dataset = dc.data.DiskDataset(directories[2 * fold])
    test_dataset = dc.data.DiskDataset(directories[2 * fold + 1])

    print("train_dataset: ", len(train_dataset))
    print("test_dataset: ", len(test_dataset))
    fold_outcome = benchmark_ensemble(
        fold=fold,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        n_models=n_models,
        nb_epoch=nb_epoch,
        challenge_dataset=challenge_dataset,
    )
    fold_result, fold_ensemble = fold_outcome

    print(f"Fold {fold+1} ensemble score: ", fold_result)
    folds_results.append(fold_result)
    folds_ensembles.append(fold_ensemble)

  0%|                                                                                             | 0/5 [00:00<?, ?it/s]

Fold 1 ensemble starting now.
train_dataset:  3963
test_dataset:  1020



  0%|                                                                                            | 0/10 [00:00<?, ?it/s][A
 10%|████████▍                                                                           | 1/10 [01:32<13:54, 92.73s/it][A

loss = 1.6435250043869019; train_scores = 0.9563002086224733; test_scores = 0.8716746449349252; time_taken = 0:01:31.122858



 20%|████████████████▊                                                                   | 2/10 [03:06<12:24, 93.10s/it][A

loss = 1.6794121265411377; train_scores = 0.9545676249378593; test_scores = 0.8735803697589037; time_taken = 0:01:31.666046



 30%|█████████████████████████▏                                                          | 3/10 [04:40<10:54, 93.54s/it][A

loss = 1.8955579996109009; train_scores = 0.9505791631345845; test_scores = 0.8733754826862802; time_taken = 0:01:32.353650



 40%|█████████████████████████████████▌                                                  | 4/10 [06:12<09:19, 93.23s/it][A

loss = 1.8951716423034668; train_scores = 0.9528689902341507; test_scores = 0.8701663686689994; time_taken = 0:01:31.057196



 50%|██████████████████████████████████████████                                          | 5/10 [07:46<07:46, 93.32s/it][A

loss = 1.5795682668685913; train_scores = 0.9575965692024742; test_scores = 0.8689327076087647; time_taken = 0:01:31.844231



 60%|██████████████████████████████████████████████████▍                                 | 6/10 [09:20<06:14, 93.50s/it][A

loss = 1.579249620437622; train_scores = 0.95666034663113; test_scores = 0.8687874606897277; time_taken = 0:01:32.218576



 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [10:54<04:41, 93.70s/it][A

loss = 1.7826263904571533; train_scores = 0.9528823901352466; test_scores = 0.8728401279928477; time_taken = 0:01:32.489951



 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [12:27<03:07, 93.56s/it][A

loss = 1.7993930578231812; train_scores = 0.9551544645652553; test_scores = 0.8698906797721485; time_taken = 0:01:31.649688



 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [14:01<01:33, 93.81s/it][A

loss = 1.7229435443878174; train_scores = 0.9517608772015606; test_scores = 0.8754341752248581; time_taken = 0:01:32.771783



100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [15:36<00:00, 93.67s/it][A

loss = 1.7918314933776855; train_scores = 0.9541813758083842; test_scores = 0.8718717819038466; time_taken = 0:01:32.994119



 20%|████████████████▍                                                                 | 1/5 [15:40<1:02:40, 940.20s/it]

Fold 1 ensemble score:  0.8873141219565731
Fold 2 ensemble starting now.
train_dataset:  3997
test_dataset:  986



  0%|                                                                                            | 0/10 [00:00<?, ?it/s][A
 10%|████████▍                                                                           | 1/10 [01:38<14:43, 98.21s/it][A

loss = 1.6341204643249512; train_scores = 0.9566379436122779; test_scores = 0.8696190985302328; time_taken = 0:01:36.544928



 20%|████████████████▊                                                                   | 2/10 [03:17<13:08, 98.55s/it][A

loss = 1.755035161972046; train_scores = 0.9522705567958488; test_scores = 0.8761822711749616; time_taken = 0:01:37.158389



 30%|█████████████████████████▏                                                          | 3/10 [04:57<11:34, 99.22s/it][A

loss = 1.6320730447769165; train_scores = 0.9560556418532109; test_scores = 0.87357930226913; time_taken = 0:01:38.409375



 40%|█████████████████████████████████▌                                                  | 4/10 [06:36<09:54, 99.15s/it][A

loss = 1.7861170768737793; train_scores = 0.9498253220627413; test_scores = 0.8760226845768484; time_taken = 0:01:37.271020



 50%|██████████████████████████████████████████                                          | 5/10 [08:15<08:16, 99.26s/it][A

loss = 1.7577598094940186; train_scores = 0.9498350240704226; test_scores = 0.8728842757054582; time_taken = 0:01:37.845969



 60%|██████████████████████████████████████████████████▍                                 | 6/10 [09:54<06:36, 99.24s/it][A

loss = 1.6499236822128296; train_scores = 0.955466705560324; test_scores = 0.874699185821393; time_taken = 0:01:37.394846



 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [11:32<04:56, 98.86s/it][A

loss = 1.7711378335952759; train_scores = 0.9499903862630855; test_scores = 0.8729214786366977; time_taken = 0:01:36.381684



 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [13:11<03:17, 98.85s/it][A

loss = 1.6363370418548584; train_scores = 0.9549666481294615; test_scores = 0.87032485397431; time_taken = 0:01:37.201251



 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [14:50<01:38, 98.71s/it][A

loss = 1.7258200645446777; train_scores = 0.953040548613671; test_scores = 0.8703533379223876; time_taken = 0:01:36.736451



100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [16:26<00:00, 98.69s/it][A

loss = 1.6715792417526245; train_scores = 0.956166311855072; test_scores = 0.8726977738560912; time_taken = 0:01:35.368046



 40%|█████████████████████████████████▌                                                  | 2/5 [32:10<48:28, 969.60s/it]

Fold 2 ensemble score:  0.8869184149371335
Fold 3 ensemble starting now.
train_dataset:  3984
test_dataset:  999



  0%|                                                                                            | 0/10 [00:00<?, ?it/s][A
 10%|████████▍                                                                           | 1/10 [01:36<14:30, 96.73s/it][A

loss = 1.6158984899520874; train_scores = 0.9542197791503823; test_scores = 0.8704401823902974; time_taken = 0:01:34.983657



 20%|████████████████▊                                                                   | 2/10 [03:11<12:46, 95.83s/it][A

loss = 1.7132742404937744; train_scores = 0.9527399348406904; test_scores = 0.8710960827088176; time_taken = 0:01:33.620875



 30%|█████████████████████████▏                                                          | 3/10 [04:50<11:18, 96.88s/it][A

loss = 1.5748958587646484; train_scores = 0.9564598694744729; test_scores = 0.8690067837653257; time_taken = 0:01:36.418503



 40%|█████████████████████████████████▌                                                  | 4/10 [06:32<09:55, 99.23s/it][A

loss = 1.6142781972885132; train_scores = 0.9552532496722691; test_scores = 0.8711998132023773; time_taken = 0:01:41.289263



 50%|██████████████████████████████████████████                                          | 5/10 [08:09<08:11, 98.34s/it][A

loss = 1.6607221364974976; train_scores = 0.9537125068464605; test_scores = 0.8700770463476449; time_taken = 0:01:35.085886



 60%|██████████████████████████████████████████████████▍                                 | 6/10 [09:45<06:30, 97.58s/it][A

loss = 1.6622167825698853; train_scores = 0.954279440466472; test_scores = 0.8718725066821373; time_taken = 0:01:34.562880



 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [11:22<04:51, 97.33s/it][A

loss = 1.5805692672729492; train_scores = 0.957719359784427; test_scores = 0.8708745750828832; time_taken = 0:01:35.279133



 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [12:59<03:14, 97.17s/it][A

loss = 1.5349888801574707; train_scores = 0.9589456347422424; test_scores = 0.8702914929496397; time_taken = 0:01:35.240934



 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [14:36<01:37, 97.12s/it][A

loss = 1.724890947341919; train_scores = 0.9516814180013313; test_scores = 0.875307130330076; time_taken = 0:01:35.437247



100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [16:15<00:00, 97.58s/it][A

loss = 1.6521098613739014; train_scores = 0.9539120511832755; test_scores = 0.8747763059209586; time_taken = 0:01:37.716822



 60%|██████████████████████████████████████████████████▍                                 | 3/5 [48:29<32:27, 973.94s/it]

Fold 3 ensemble score:  0.8870904952957991
Fold 4 ensemble starting now.
train_dataset:  4003
test_dataset:  980



  0%|                                                                                            | 0/10 [00:00<?, ?it/s][A
 10%|████████▍                                                                           | 1/10 [01:36<14:28, 96.52s/it][A

loss = 1.6287542581558228; train_scores = 0.9566525343184855; test_scores = 0.8678340895961376; time_taken = 0:01:34.915587



 20%|████████████████▊                                                                   | 2/10 [03:12<12:49, 96.13s/it][A

loss = 1.6418370008468628; train_scores = 0.956067884502413; test_scores = 0.8718582144604609; time_taken = 0:01:34.232663



 30%|█████████████████████████▏                                                          | 3/10 [04:49<11:17, 96.78s/it][A

loss = 1.6257741451263428; train_scores = 0.9563372157117667; test_scores = 0.8701788440898139; time_taken = 0:01:35.966281



 40%|█████████████████████████████████▌                                                  | 4/10 [06:26<09:40, 96.68s/it][A

loss = 1.7756750583648682; train_scores = 0.9514639545743597; test_scores = 0.872361831599786; time_taken = 0:01:34.963887



 50%|██████████████████████████████████████████                                          | 5/10 [08:02<08:01, 96.33s/it][A

loss = 1.6584330797195435; train_scores = 0.9546302636960031; test_scores = 0.8695750504085289; time_taken = 0:01:34.154954



 60%|██████████████████████████████████████████████████▍                                 | 6/10 [09:37<06:24, 96.11s/it][A

loss = 1.5894982814788818; train_scores = 0.958601357098166; test_scores = 0.8663982520847128; time_taken = 0:01:34.061915



 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [11:13<04:47, 95.89s/it][A

loss = 1.5524948835372925; train_scores = 0.9586191229373922; test_scores = 0.8679205070636705; time_taken = 0:01:33.871213



 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [12:51<03:13, 96.66s/it][A

loss = 1.6252131462097168; train_scores = 0.9572916312470844; test_scores = 0.8716109537817014; time_taken = 0:01:36.678240



 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [14:34<01:38, 98.48s/it][A

loss = 1.6472058296203613; train_scores = 0.9572436176331647; test_scores = 0.8697011205578145; time_taken = 0:01:40.878415



100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [16:11<00:00, 97.17s/it][A

loss = 1.715270757675171; train_scores = 0.9531591507244195; test_scores = 0.8706767801233244; time_taken = 0:01:35.998987



 80%|█████████████████████████████████████████████████████████████████▌                | 4/5 [1:04:44<16:14, 974.32s/it]

Fold 4 ensemble score:  0.8843187071298421
Fold 5 ensemble starting now.
train_dataset:  3985
test_dataset:  998



  0%|                                                                                            | 0/10 [00:00<?, ?it/s][A
 10%|████████▍                                                                           | 1/10 [01:34<14:12, 94.71s/it][A

loss = 1.702138900756836; train_scores = 0.9521387847690915; test_scores = 0.8751401994702805; time_taken = 0:01:32.909623



 20%|████████████████▊                                                                   | 2/10 [03:10<12:40, 95.08s/it][A

loss = 1.6775354146957397; train_scores = 0.9539742744000154; test_scores = 0.8763000567308926; time_taken = 0:01:33.659930



 30%|█████████████████████████▏                                                          | 3/10 [04:44<11:04, 94.98s/it][A

loss = 1.6171324253082275; train_scores = 0.9571025634130835; test_scores = 0.8746976067055074; time_taken = 0:01:33.221295



 40%|█████████████████████████████████▌                                                  | 4/10 [06:26<09:45, 97.61s/it][A

loss = 1.723649501800537; train_scores = 0.9513927504781439; test_scores = 0.8803128285549635; time_taken = 0:01:40.140731



 50%|██████████████████████████████████████████                                          | 5/10 [08:00<08:01, 96.39s/it][A

loss = 1.7287044525146484; train_scores = 0.9516555301012892; test_scores = 0.8748088326640047; time_taken = 0:01:32.664946



 60%|██████████████████████████████████████████████████▍                                 | 6/10 [09:35<06:23, 95.85s/it][A

loss = 1.730612874031067; train_scores = 0.9518967140907973; test_scores = 0.8795471764806473; time_taken = 0:01:33.198995



 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [11:10<04:46, 95.57s/it][A

loss = 1.713202953338623; train_scores = 0.9526710647569354; test_scores = 0.8756522977753823; time_taken = 0:01:33.499810



 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [12:45<03:10, 95.41s/it][A

loss = 1.7112946510314941; train_scores = 0.951269294538907; test_scores = 0.8759483398120069; time_taken = 0:01:33.426687



 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [14:28<01:37, 97.59s/it][A

loss = 1.5890095233917236; train_scores = 0.9567866520427777; test_scores = 0.8737843073789636; time_taken = 0:01:40.392035



100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [16:03<00:00, 96.35s/it][A

loss = 1.71857750415802; train_scores = 0.952078141450469; test_scores = 0.874420710444567; time_taken = 0:01:33.950473



100%|██████████████████████████████████████████████████████████████████████████████████| 5/5 [1:20:51<00:00, 970.23s/it]

Fold 5 ensemble score:  0.8908339706789512





In [66]:
# Mean of the per-fold ensemble ROC-AUC scores.
cv_mean_result = np.mean(folds_results)
# Element-wise mean over folds of the per-fold challenge-set outputs. Despite the
# "embedding" in the name, folds_ensembles holds averaged model.predict() outputs
# returned by benchmark_ensemble.
cv_mean_ensemble_embedding = np.mean(folds_ensembles, axis = 0)
# Show the averaged output for the first challenge molecule, then the CV score.
print(cv_mean_ensemble_embedding[0])
print(round(cv_mean_result, 4))


[-2.5910535e+00 -4.2015090e+00 -3.5298297e+00 -2.2822831e+00
 -2.2648399e+00 -1.2477567e+00 -1.2646991e+00 -4.5026774e+00
 -4.0482874e+00 -1.1327337e+00 -2.6242426e-01 -4.2799354e+00
 -2.4552250e+00 -4.0020142e+00 -3.2093384e+00 -2.0431437e+00
 -3.5969315e+00 -3.9479260e+00  1.2841821e-03 -3.2952323e+00
 -3.5902710e+00 -1.8022617e+00 -1.1232203e+00 -3.1083984e+00
 -2.2587631e+00 -3.4375362e+00 -3.3785198e+00 -2.9139245e+00
 -1.5028133e+00 -1.7282383e+00 -3.5096123e+00 -2.9601483e+00
  2.6644370e-01 -2.3787553e+00 -2.0670533e+00 -1.5703940e+00
 -4.2702436e+00 -2.8365774e+00 -2.3062983e+00 -3.0165160e+00
 -2.2590470e+00 -9.4517040e-01 -4.1948891e+00 -1.8989738e+00
 -1.0459563e+00 -1.1468416e+00 -2.9295864e+00 -3.2843940e+00
 -3.3001697e+00 -3.0441067e+00 -2.1865180e+00 -2.5898354e+00
 -3.9308884e+00 -2.7995763e+00 -3.4049048e+00 -2.7780671e+00
 -3.5422177e+00 -4.4791026e+00 -4.3882079e+00 -3.4213996e+00
 -3.2012286e+00 -2.0092456e+00 -1.6903414e+00 -2.8125067e+00
 -1.5153580e+00 -2.30224

In [65]:
import json

# Re-read the challenge CSV and attach, per row, the fold-averaged ensemble output
# serialized as a JSON list, then write the augmented table to a new CSV.
df = pd.read_csv(output_file_path)
prediction_column = [json.dumps(task_scores.tolist()) for task_scores in cv_mean_ensemble_embedding]
df['prediction'] = prediction_column

output_path = 'ensemble_embeddings.csv'  # Replace with your desired output file path
df.to_csv(output_path, index=False)