In [9]:
from torch_geometric.data import Data as PyGData, Dataset as PyGDataset
from torch_geometric.loader import DataLoader
from typing import Literal, List, Union
import pandas as pd
from abc import ABC, abstractmethod
from app.features.model.schema.configs import ModelConfig
import networkx as nx
import yaml



## Loading config for the Model

In [10]:
# YAML description of the example model: the dataset columns, one featurizer
# fed by `smiles`, and a two-branch layer graph (a GCN stack pooled by
# `AddPool`, plus a linear layer on `mwt`) merged by `Combiner`.
# NOTE(review): the network ends in a Sigmoid while `target_column` is `tpsa`;
# unless the target is rescaled to (0, 1) somewhere else, the output range
# cannot match it — confirm.
config = """name: GNNExample
dataset:
  name: Small Zinc dataset
  target_column: tpsa
  feature_columns:
    - smiles
    - mwt

featurizers:
  - name: MolToGraphFeaturizer
    type: app.features.model.featurizers.MoleculeFeaturizer
    input:
      - smiles
    args:
      allow_unknown: false
      sym_bond_list: true
      per_atom_fragmentation: false

layers:

  # Start fst branch (from featurizer)
  - name: GCN1
    type: torch_geometric.nn.GCNConv
    args:
      in_channels: 26
      out_channels: 64
    input: MolToGraphFeaturizer

  - name: GCN1_Activation
    type: torch.nn.ReLU
    input: GCN1

  - name: GCN2 
    type: torch_geometric.nn.GCNConv
    input: GCN1_Activation
    args:
      in_channels: 64
      out_channels: 64

  - name: GCN2_Activation
    type: torch.nn.ReLU
    input: GCN2

  - name: GCN3
    type: torch_geometric.nn.GCNConv
    input: GCN2_Activation
    args:
      in_channels: 64
      out_channels: 64

  - name: GCN3_Activation
    type: torch.nn.ReLU
    input: GCN3

  - name: AddPool
    type: app.features.model.layers.GlobalPooling
    input: GCN3_Activation
    args:
      aggr: 'sum'
  # End of fst branch

  # Second branch would simply be linear layers in mwt
  - name: Linear1
    type: torch.nn.Linear
    args:
      in_features: 1
      out_features: 10
    input: mwt

  - name: Combiner
    type: app.features.model.layers.Concat
    input: ['AddPool', 'Linear1']

  - name: LinearJoined
    type: torch.nn.Linear
    input: Combiner
    args:
      in_features: 74
      out_features: 1

  - name: OutputSigmoid
    type: torch.nn.Sigmoid
    input: LinearJoined
"""

# Build the model configuration object from the YAML text above.
model = ModelConfig.from_yaml(config)

In [11]:
model.featurizers

[AppmoleculefeaturizerLayerConfig(name='MolToGraphFeaturizer', input=['smiles'], type='app.features.model.featurizers.MoleculeFeaturizer', args=AppmoleculefeaturizerArgs(allow_unknown=False, sym_bond_list=True, per_atom_fragmentation=False))]

In [12]:
model.dataset

DatasetConfig(name='Small Zinc dataset', target_column='tpsa', feature_columns=['smiles', 'mwt'])

## Exploring the dataset manually

In [13]:
# Load the Zinc sample from a CSV file located next to the notebook (relative
# path) and preview its first rows.
zinc_example = pd.read_csv("./zinc.csv")
zinc_example.head()

Unnamed: 0,zinc_id,smiles,mwt,logp,heavy_atoms,n_rings,heteroatoms,tpsa,hacceptors,hdonors,rotatable_bonds
0,ZINC000000000007,C=CCc1ccc(OCC(=O)N(CC)CC)c(OC)c1,277.364,2.6709,20,1,4,38.77,3,0,8
1,ZINC000000000010,C[C@@]1(c2ccccc2)OC(C(=O)O)=CC1=O,218.208,1.4696,16,2,4,63.6,3,1,2
2,ZINC000000000011,COc1cc(Cc2cnc(N)nc2N)cc(OC)c1N(C)C,303.366,1.315,22,2,7,99.52,7,2,5
3,ZINC000000000012,O=C(C[S@@](=O)C(c1ccccc1)c1ccccc1)NO,289.356,2.0301,20,2,5,66.4,3,2,5
4,ZINC000000000014,CC[C@H]1[C@H](O)N2[C@H]3C[C@@]45c6ccccc6N(C)[C...,326.44,1.5545,24,12,4,46.94,4,2,1


## Building the CustomDataset

In [4]:
from collections.abc import MutableMapping, Mapping, Sequence
from typing import Any, Callable, Dict, Iterable, List, Optional
import weakref
import torch
from torch_sparse import SparseTensor
import numpy as np
import copy


def recursive_apply_(data: Any, function: Callable) -> None:
    """Apply ``function`` in place to every leaf found inside ``data``.

    ``data`` is walked recursively: tensors are handed straight to
    ``function``, while named tuples, non-string sequences and mappings
    (values only) are traversed element by element. Whatever ``function``
    returns is discarded, so it is expected to mutate its argument.

    Args:
        data: A tensor, a (possibly nested) container, or any other object.
        function: Callable invoked once per leaf.

    Returns:
        None.
    """
    if isinstance(data, torch.Tensor):
        function(data)
        return

    if isinstance(data, tuple) and hasattr(data, "_fields"):  # Named tuple
        for value in data:
            # Recurse into each field; recursing into ``data`` itself would
            # never terminate.
            recursive_apply_(value, function)
        return

    if isinstance(data, Sequence) and not isinstance(data, str):
        for value in data:
            recursive_apply_(value, function)
        return

    if isinstance(data, Mapping):
        # ``values`` must be called; iterating the bound method raises.
        for value in data.values():
            recursive_apply_(value, function)
        return

    # Best effort for any other leaf: objects ``function`` cannot handle are
    # skipped on purpose, but interpreter-level signals (KeyboardInterrupt,
    # SystemExit) are no longer swallowed.
    try:
        function(data)
    except Exception:
        pass


def recursive_apply(data: Any, function: Callable) -> Any:
    """Return a copy of ``data`` with ``function`` applied to every leaf.

    Tensors and ``PackedSequence`` objects are passed to ``function``
    directly. Named tuples are rebuilt with the same type, other non-string
    sequences are rebuilt as lists, and mappings are rebuilt as plain dicts
    with the same keys. Any other object is passed to ``function``; if that
    call fails, the object is returned unchanged.

    Args:
        data: A tensor, a (possibly nested) container, or any other object.
        function: Callable invoked once per leaf; its result replaces the
            leaf.

    Returns:
        The transformed structure.
    """
    if isinstance(data, torch.Tensor):
        return function(data)

    # ``PackedSequence`` is itself a named tuple, so it has to be handled
    # before the generic named-tuple branch below.
    if isinstance(data, torch.nn.utils.rnn.PackedSequence):
        return function(data)

    if isinstance(data, tuple) and hasattr(data, "_fields"):  # Named tuple
        return type(data)(*(recursive_apply(d, function) for d in data))

    if isinstance(data, Sequence) and not isinstance(data, str):
        # Recurse into each element; recursing into ``data`` itself would
        # never terminate.
        return [recursive_apply(d, function) for d in data]

    if isinstance(data, Mapping):
        return {key: recursive_apply(data[key], function) for key in data}

    # Best effort for any other leaf: keep the original object when
    # ``function`` cannot handle it, without swallowing interpreter-level
    # signals such as KeyboardInterrupt.
    try:
        return function(data)
    except Exception:
        return data


def size_repr(key: Any, value: Any, indent: int = 0) -> str:
    """Build a compact ``key=summary`` line describing ``value``.

    The summary depends on the type of ``value``: zero-dimensional tensors
    are shown by value, other tensors and NumPy arrays by shape, sparse
    tensors by their sizes plus the number of non-zero entries, strings
    quoted, other sequences by length, and mappings by recursively
    summarising their items. Anything else falls back to ``str(value)``.

    Args:
        key: Label printed in front of the summary (single quotes removed).
        value: Object to summarise.
        indent: Number of spaces placed in front of the line.

    Returns:
        The formatted line; the key is emitted in bold (ANSI escape codes)
        when ``value`` is a ``BaseStorage``.
    """
    padding = " " * indent

    if isinstance(value, torch.Tensor):
        # Scalars are displayed by value, everything else by shape.
        summary = (
            value.item() if value.dim() == 0 else str(list(value.size()))
        )
    elif isinstance(value, np.ndarray):
        summary = str(list(value.shape))
    elif isinstance(value, SparseTensor):
        # Drop the closing bracket of the sizes list to append ``nnz``.
        summary = str(value.sizes())[:-1] + f", nnz={value.nnz()}]"
    elif isinstance(value, str):
        summary = f"'{value}'"
    elif isinstance(value, Sequence):
        summary = str([len(value)])
    elif isinstance(value, Mapping):
        if len(value) == 0:
            summary = "{}"
        elif len(value) == 1 and not isinstance(
            list(value.values())[0], Mapping
        ):
            # A single, non-nested entry fits on one line.
            entries = [size_repr(k, v, 0) for k, v in value.items()]
            summary = "{ " + ", ".join(entries) + " }"
        else:
            # One entry per line, indented two spaces deeper than the key.
            entries = [size_repr(k, v, indent + 2) for k, v in value.items()]
            summary = "{\n" + ",\n".join(entries) + "\n" + padding + "}"
    else:
        summary = str(value)

    label = str(key).replace("'", "")
    if isinstance(value, BaseStorage):
        return f"{padding}\033[1m{label}\033[0m={summary}"
    return f"{padding}{label}={summary}"


class MappingView(object):
    """View over a mapping, optionally restricted to a set of keys.

    Args:
        mapping: The mapping being viewed; it is referenced, not copied.
        *args: Keys to expose. When none are given, every key of
            ``mapping`` is exposed; keys absent from ``mapping`` are
            ignored.
    """

    def __init__(self, mapping: Mapping, *args: str) -> None:
        self._mapping = mapping
        self._args = args

    def _keys(self) -> Iterable:
        """Return the keys exposed by this view."""
        if len(self._args) == 0:
            return self._mapping.keys()
        return [arg for arg in self._args if arg in self._mapping]

    def __len__(self) -> int:
        return len(self._keys())

    def __repr__(self) -> str:
        mapping = {key: self._mapping[key] for key in self._keys()}
        return f"{self.__class__.__name__}({mapping})"

    # Allow ``MappingView[...]`` in type hints. ``type([])`` is the built-in
    # ``list``, which made every subscription raise ``TypeError``; the class
    # of a ``typing`` alias builds a proper generic alias instead.
    __class_getitem__ = classmethod(type(List[int]))


class KeysView(MappingView):
    """View that iterates over the exposed keys of the wrapped mapping."""

    def __iter__(self) -> Iterable:
        for key in self._keys():
            yield key


class ValuesView(MappingView):
    """View that iterates over the values stored under the exposed keys."""

    def __iter__(self) -> Iterable:
        yield from (self._mapping[key] for key in self._keys())


class ItemsView(MappingView):
    """View that iterates over ``(key, value)`` pairs of the exposed keys."""

    def __iter__(self) -> Iterable:
        yield from ((key, self._mapping[key]) for key in self._keys())


class BaseStorage(MutableMapping):
    """Dictionary-like store whose entries are also reachable as attributes.

    Names without a leading underscore live in the internal ``_mapping``
    dict and can be read, written or deleted either as items or as
    attributes. Names starting with ``_`` are ordinary instance attributes
    kept in ``__dict__``; ``_parent`` is stored as a weak reference.
    Assigning ``None`` to an entry removes it instead of storing ``None``.

    Args:
        _mapping: Optional dict of initial entries.
        **kwargs: Additional initial entries/attributes.
    """

    def __init__(
        self, _mapping: Optional[Dict[str, Any]] = None, **kwargs
    ) -> None:
        super().__init__()
        self._mapping = {}

        # Set up all attributes that come from ``_mapping``.
        for key, value in (_mapping or {}).items():
            setattr(self, key, value)

        # Turn every keyword argument into a new attribute of this storage
        # instance.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def _key(self) -> Any:
        """Key of this storage; always ``None`` for the base class."""
        return None

    def __len__(self) -> int:
        return len(self._mapping)

    def __getattr__(self, key: str) -> Any:
        # Only called when regular attribute lookup fails. ``_mapping`` can
        # be missing on instances created through ``__new__`` (as the copy
        # helpers below do), so it is created lazily here instead of
        # recursing through ``self[key]``.
        if key == "_mapping":
            self._mapping = {}
            return self._mapping

        try:
            return self[key]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{key}'"
            )

    def __setattr__(self, key: str, value: Any) -> None:
        if key == "_parent":
            # Weak reference, so the storage does not keep its owner alive.
            self.__dict__[key] = weakref.ref(value)
        elif key[:1] == "_":
            self.__dict__[key] = value
        else:
            self[key] = value

    def __delattr__(self, key: str) -> None:
        if key[:1] == "_":
            del self.__dict__[key]
        else:
            del self[key]

    def __getitem__(self, key: str) -> Any:
        return self._mapping[key]

    def __setitem__(self, key: str, value: Any) -> None:
        # ``None`` means "remove": it is never stored as a value.
        if value is None and key in self._mapping:
            del self._mapping[key]
        elif value is not None:
            self._mapping[key] = value

    def __delitem__(self, key: str) -> None:
        # Deleting a missing key is silently ignored.
        if key in self._mapping:
            del self._mapping[key]

    def __iter__(self) -> Iterable:
        return iter(self._mapping)

    def __copy__(self):
        """Return a copy that shares values but owns its own mapping."""
        out = self.__class__.__new__(self.__class__)

        for key, value in self.__dict__.items():
            out.__dict__[key] = value

        out._mapping = copy.copy(out._mapping)
        return out

    def __deepcopy__(self, memo):
        """Return a copy whose mapping is deep-copied.

        The other private attributes (including ``_parent``) are copied by
        reference.
        """
        out = self.__class__.__new__(self.__class__)

        for key, value in self.__dict__.items():
            out.__dict__[key] = value

        out._mapping = copy.deepcopy(out._mapping, memo)
        return out

    def __getstate__(self) -> Dict[str, Any]:
        out = self.__dict__.copy()

        # Weak references cannot be pickled: store the referent instead.
        _parent = out.get("_parent", None)
        if _parent is not None:
            out["_parent"] = _parent()

        return out

    def __setstate__(self, mapping: Dict[str, Any]) -> None:
        for key, value in mapping.items():
            self.__dict__[key] = value

        # Restore the weak reference replaced in ``__getstate__``.
        _parent = self.__dict__.get("_parent", None)
        if _parent is not None:
            self.__dict__["_parent"] = weakref.ref(_parent)

    def __repr__(self) -> str:
        return repr(self._mapping)

    def keys(self, *args: str) -> KeysView:
        """Return a view of the keys, optionally limited to ``args``."""
        return KeysView(self._mapping, *args)

    def values(self, *args: str) -> ValuesView:
        """Return a view of the values, optionally limited to ``args``."""
        return ValuesView(self._mapping, *args)

    def items(self, *args: str) -> ItemsView:
        """Return a view of the items, optionally limited to ``args``."""
        return ItemsView(self._mapping, *args)

    def apply_(self, function: Callable, *args: str):
        """Apply ``function`` in place to the stored values.

        The values are visited with ``recursive_apply_`` and the return
        value of ``function`` is ignored, so entries are never reassigned
        or removed. (Previously this was a copy of ``apply`` and rebound
        every entry to the function's return value.)

        Args:
            function: Callable expected to mutate the leaves it receives.
            *args: Keys to restrict the operation to; all keys if omitted.

        Returns:
            This storage, to allow chaining.
        """
        for value in self.values(*args):
            recursive_apply_(value, function)
        return self

    def apply(self, function: Callable, *args: str):
        """Replace the stored values by ``function`` applied to them.

        Args:
            function: Callable whose result replaces each leaf (see
                ``recursive_apply``). A ``None`` result removes the entry.
            *args: Keys to restrict the operation to; all keys if omitted.

        Returns:
            This storage, to allow chaining.
        """
        # Iterate over a snapshot: assigning ``None`` deletes the key, which
        # must not happen while iterating the live mapping.
        for key, value in list(self.items(*args)):
            self[key] = recursive_apply(value, function)
        return self


# Intended usage:
#   instance = DataInstance()
#   instance.batch
#   instance.MolFeaturizer
#   instance.target
#   instance.any = ...
class DataInstance(BaseStorage):
    """A single sample whose fields live in an internal ``BaseStorage``.

    Item access, attribute access and deletion are all forwarded to the
    ``_store`` attribute, so fields can be used as ``instance["name"]`` or
    ``instance.name`` interchangeably.

    Args:
        y: Optional target value, stored under ``y`` when not ``None``.
        **kwargs: Additional fields to store.
    """

    def __init__(self, y=None, **kwargs):
        # Written straight into ``__dict__`` because ``__setattr__`` below
        # forwards to ``_store``, which does not exist yet at this point.
        self.__dict__["_store"] = BaseStorage(_parent=self)

        if y is not None:
            self.y = y

        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getitem__(self, key: str) -> Any:
        return self._store[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._store[key] = value

    def __delitem__(self, key: str) -> None:
        if key in self._store:
            del self._store[key]

    def __getattr__(self, key: str) -> Any:
        # Only called when regular lookup fails; guards against infinite
        # recursion on instances that never went through ``__init__``.
        if "_store" not in self.__dict__:
            raise RuntimeError(
                f"'{self.__class__.__name__}' object has no '_store'; it "
                f"was not initialized through '__init__'"
            )
        return getattr(self._store, key)

    def __setattr__(self, key: str, value: Any) -> None:
        setattr(self._store, key, value)

    def __delattr__(self, key: str) -> None:
        # Python calls ``__delattr__(name)`` with a single argument; the
        # former extra ``value`` parameter made ``del instance.name`` raise
        # ``TypeError``.
        delattr(self._store, key)

    def __copy__(self):
        """Return a shallow copy that owns its own store.

        The inherited implementation would leave the copy pointing at the
        original's ``_store``.
        """
        out = self.__class__.__new__(self.__class__)
        out.__dict__.update(self.__dict__)
        out.__dict__["_store"] = copy.copy(self._store)
        out._store._parent = out
        return out

    def __deepcopy__(self, memo):
        """Return a deep copy that owns its own, deep-copied store."""
        out = self.__class__.__new__(self.__class__)
        out.__dict__.update(self.__dict__)
        out.__dict__["_store"] = copy.deepcopy(self._store, memo)
        out._store._parent = out
        return out

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        info = [size_repr(k, v, indent=2) for k, v in self._store.items()]
        info = ",\n".join(info)
        return f"{cls}(\n{info}\n)"

In [34]:
from torch.utils.data import Dataset as TorchDataset


class CustomDataset(TorchDataset):
    """Map-style dataset that turns DataFrame rows into ``DataInstance``s.

    Each featurizer config is instantiated once in ``setup``. When a sample
    is requested, every featurizer is applied to its first input column and
    stored under the featurizer's name; the remaining feature columns are
    copied as they are.

    Args:
        data: Frame holding the feature and target columns.
        feature_columns: Names of the columns used as model inputs.
        featurizers_config: Featurizer configs exposing ``name``, ``input``
            and ``create()``.
        target: Name of the target column.

    Raises:
        AttributeError: If the target column has an unsupported dtype.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        feature_columns,
        featurizers_config,
        target: str,
    ) -> None:
        self.data = data
        self.target = target
        self.columns = feature_columns
        self._featurizers_config = featurizers_config

        self.setup()

    def _determine_task_type(self) -> str:
        """Infer the task type from the dtype of the target column.

        Returns:
            ``"regression"`` for float targets.

        Raises:
            AttributeError: For any other target dtype.
        """
        target_type = self.data.dtypes[self.target]

        if "float" in target_type.name:
            # A float target is treated as a regression task.
            return "regression"

        # The exception used to be returned instead of raised, so an
        # exception object silently ended up stored as the task type.
        raise AttributeError("Unsupported target type for prediction.")

    def setup(self):
        """Determine the task type and instantiate the featurizers."""
        # First we need to determine the type of the task.
        self._task_type = self._determine_task_type()
        # After that we can instantiate all of the featurizers used to
        # transform the columns.
        self._featurizers = {}
        for featurizer_config in self._featurizers_config:
            self._featurizers[
                featurizer_config.name
            ] = featurizer_config.create()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> "DataInstance":
        """Build the ``DataInstance`` for the row at position ``index``.

        NOTE(review): the target column is not attached to the returned
        instance — confirm whether it should be stored (e.g. as ``y``).
        """
        d = DataInstance()
        sample = dict(self.data.iloc[index, :])

        columns_to_include = list(self.columns)
        # Every column consumed by a featurizer is stored in its featurized
        # form, under the featurizer's name, instead of its raw value.
        # NOTE(review): only the first input of each featurizer is used —
        # confirm that multi-input featurizers are not expected here.
        for featurizer in self._featurizers_config:
            source_column = featurizer.input[0]
            d[featurizer.name] = self._featurizers[featurizer.name](
                sample[source_column]
            )
            # Guarded removal: two featurizers may share an input column, or
            # the input may not be listed among the feature columns.
            if source_column in columns_to_include:
                columns_to_include.remove(source_column)

        # The columns left over are included without any transformation.
        for column in columns_to_include:
            d[column] = sample[column]

        return d

In [31]:
model.featurizers

[AppmoleculefeaturizerLayerConfig(name='MolToGraphFeaturizer', input=['smiles'], type='app.features.model.featurizers.MoleculeFeaturizer', args=AppmoleculefeaturizerArgs(allow_unknown=False, sym_bond_list=True, per_atom_fragmentation=False))]

In [20]:
model.dataset

DatasetConfig(name='Small Zinc dataset', target_column='tpsa', feature_columns=['smiles', 'mwt'])

In [35]:
# Build the dataset from the loaded frame plus the feature columns,
# featurizer configs and target column declared in the model config.
dataset = CustomDataset(
    zinc_example,
    model.dataset.feature_columns,
    model.featurizers,
    model.dataset.target_column,
)
dataset

<__main__.CustomDataset at 0x7f1f4beae1c0>

In [36]:
# Fetch the first sample to inspect the DataInstance that is produced.
dataset[0]

DataInstance(
  MolToGraphFeaturizer=Data(x=[20, 26], edge_index=[2, 40], edge_attr=[40, 9]),
  mwt=277.3639999999999
)

In [38]:
from torch_geometric.loader import DataLoader

# Pull a single batch of four samples to check how the DataInstance objects
# are collated by the loader.
b = next(iter(DataLoader(dataset, batch_size=4)))
b

{'MolToGraphFeaturizer': DataBatch(x=[78, 26], edge_index=[2, 162], edge_attr=[162, 9], batch=[78], ptr=[5]),
 'mwt': tensor([277.3640, 218.2080, 303.3660, 289.3560])}

In [42]:
# Inspect the `batch` tensor of the collated graph; the output holds one
# index (0-3) per node, presumably the graph each node belongs to.
b["MolToGraphFeaturizer"].batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3])

In [None]:
from app.features.dataset.model import Dataset


class DataModule:
    """Data module meant to work like a Lightning data module, but with
    support for this project's data types.

    Work in progress: ``setup`` is not implemented yet and the dataloader
    and split methods are empty placeholders.
    """

    def __init__(self, dataset: Dataset, dataset_config, featurizers_config):
        self.featurizers_config = featurizers_config
        self.dataset_config = dataset_config

        metadata = dataset
        self.dataset_metadata = metadata
        self.dataset_file = metadata.get_dataframe()

        # NOTE(review): the two statements below only read attributes and
        # discard the result; `split_` looks like an unfinished attribute
        # name — confirm what was meant to be stored here.
        metadata.split_type
        metadata.split_

    def setup(self):
        """Prepare the data module (not implemented yet)."""
        raise NotImplementedError

    def train_dataloader(self):
        """Placeholder for the training loader; currently returns None."""

    def test_dataloader(self):
        """Placeholder for the test loader; currently returns None."""

    def val_dataloader(self):
        """Placeholder for the validation loader; currently returns None."""

    def _split_data(self):
        """Placeholder for the data split; currently returns None."""