# Playing with the data structure

In [1]:
# import malt
# import torch
# import seaborn as sns

In [5]:
from typing import Union
import functools
import dgl
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

class _MissingMetadataError(AttributeError, RuntimeError):
    """Raised when a name is neither a real attribute nor a metadata key.

    Deriving from ``AttributeError`` lets ``hasattr``, ``getattr`` with a
    default, ``copy`` and ``pickle`` treat the name as simply absent;
    deriving from ``RuntimeError`` keeps existing ``except RuntimeError``
    handlers working.
    """


class Molecule(object):
    """ Models information associated with a molecule.

    Parameters
    ----------
    smiles : str
        SMILES of the molecule.

    g : dgl.DGLGraph or None, default=None
        The DGL graph of the molecule.

    metadata : dict or None, default=None
        Metadata associated with the compound. ``None`` gives the instance
        its own fresh empty dict. Keys are also readable as attributes,
        e.g. ``molecule.y`` returns ``molecule.metadata['y']``.

    featurizer : callable, default=functools.partial(
        smiles_to_bigraph,
        node_featurizer=CanonicalAtomFeaturizer(atom_data_field="h"))
        The function which maps the SMILES string to a DGL graph.

    Methods
    -------
    featurize()
        Convert the SMILES string to a graph if there isn't one.

    is_featurized()
        Tell whether a graph is attached.

    erase_annotation()
        Reset the ``y`` metadata entry to ``None`` if it exists.

    """

    def __init__(
        self,
        smiles: str,
        g: Union[dgl.DGLGraph, None] = None,
        metadata: Union[dict, None] = None,
        featurizer: callable = functools.partial(
            smiles_to_bigraph,
            node_featurizer=CanonicalAtomFeaturizer(atom_data_field="h"),
        ),
    ) -> None:
        self.smiles = smiles
        self.g = g
        # A literal `{}` default would be a single dict shared by every
        # instance created without metadata; build a fresh one instead.
        self.metadata = {} if metadata is None else metadata
        self.featurizer = featurizer

    def __repr__(self):
        return self.smiles

    def __eq__(self, other):
        """Compare SMILES, graph and metadata with another Molecule.

        Returns ``NotImplemented`` for non-Molecule operands so Python can
        fall back to its default handling. Graph comparison defers to
        whatever ``==`` means for the stored graph objects. Because
        ``__eq__`` is defined without ``__hash__``, instances are unhashable.
        """
        if not isinstance(other, Molecule):
            return NotImplemented
        return (
            self.smiles == other.smiles
            and self.g == other.g
            and self.metadata == other.metadata
        )

    def __getattr__(self, name):
        """Expose metadata keys as read-only attributes.

        Only invoked when normal attribute lookup fails.

        Raises
        ------
        _MissingMetadataError
            (an ``AttributeError`` and a ``RuntimeError``) if `name` is not
            a metadata key.
        """
        # Go through __dict__ so a not-yet-initialised instance (no
        # `metadata` attribute, e.g. while being copied or unpickled)
        # does not re-enter __getattr__ and recurse forever.
        metadata = self.__dict__.get("metadata")
        if metadata is None or name not in metadata:
            raise _MissingMetadataError(
                f'`{name}` is not associated with this Molecule.'
            )
        return metadata[name]

    def featurize(self):
        """Featurize the SMILES string to get the graph.

        Returns
        -------
        Molecule : This instance, with ``g`` set to the resulting graph.

        Raises
        ------
        RuntimeError
            If a graph is already attached.

        """
        # if there is already a graph, raise an error
        if self.is_featurized():
            raise RuntimeError("Point is already featurized.")

        # featurize
        self.g = self.featurizer(self.smiles)

        return self

    def is_featurized(self):
        """Return True if a graph is attached to this molecule."""
        return self.g is not None

    def erase_annotation(self):
        """Set the ``y`` metadata entry to None, if present.

        Returns
        -------
        Molecule : This instance.

        """
        if 'y' in self.metadata:
            # Assigning `self.y` would only create an instance attribute and
            # leave metadata['y'] untouched. Rebind to an updated copy so a
            # dict shared with the caller (or another Molecule) is not
            # mutated behind its back.
            self.metadata = {**self.metadata, 'y': None}
        return self

In [6]:
# Build an example molecule from a SMILES string, attaching a label `y`
# through the metadata dict.
m = Molecule(smiles="NCCc1ccc(O)c(O)c1", metadata={"y": 0.0})

In [11]:
# Inspect the stored graph; it is None unless a graph was passed to the
# constructor or featurize() has been run on `m`.
m.g