In [1]:
# baseline GNN model for graph-level regression
# from copy import deepcopy
# import torch
# import torch.nn as nn
# from torch.optim import lr_scheduler

import pytorch_lightning as pl
from qtaim_embed.models.graph_level.base_gcn import GCNGraphPred
from qtaim_embed.utils.data import get_default_graph_level_config
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    EarlyStopping,
    ModelCheckpoint,
)

# import dgl.nn.pytorch as dglnn
# from torchmetrics.wrappers import MultioutputWrapper
# import torchmetrics

"""from qtaim_embed.models.utils import (
    _split_batched_output,
    get_layer_args,
    link_fmt_to_node_fmt,
)

from qtaim_embed.models.layers import (
    GraphConvDropoutBatch,
    ResidualBlock,
    UnifySize,
    Set2SetThenCat,
    SumPoolingThenCat,
    WeightAndSumThenCat,
    GlobalAttentionPoolingThenCat,
)
"""

'from qtaim_embed.models.utils import (\n    _split_batched_output,\n    get_layer_args,\n    link_fmt_to_node_fmt,\n)\n\nfrom qtaim_embed.models.layers import (\n    GraphConvDropoutBatch,\n    ResidualBlock,\n    UnifySize,\n    Set2SetThenCat,\n    SumPoolingThenCat,\n    WeightAndSumThenCat,\n    GlobalAttentionPoolingThenCat,\n)\n'

In [2]:
"""
def link_fmt_to_node_fmt(dict_feats):
    ret_dict = {}
    for k, v in dict_feats.items():
        assert k[-1] in ["g", "b", "a"], "key must end with g, b, or a"
        if k[-1] == "g":
            ret_dict["global"] = v
        elif k[-1] == "b":
            ret_dict["bond"] = v
        elif k[-1] == "a":
            ret_dict["atom"] = v

    return ret_dict
"""

'\ndef link_fmt_to_node_fmt(dict_feats):\n    ret_dict = {}\n    for k, v in dict_feats.items():\n        assert k[-1] in ["g", "b", "a"], "key must end with g, b, or a"\n        if k[-1] == "g":\n            ret_dict["global"] = v\n        elif k[-1] == "b":\n            ret_dict["bond"] = v\n        elif k[-1] == "a":\n            ret_dict["atom"] = v\n\n    return ret_dict\n'

class GCNGraphPred(pl.LightningModule):
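    """
    Basic GNN model for graph-level regression.

    Takes
        atom_input_size: int, dimension of atom features
        bond_input_size: int, dimension of bond features
        global_input_size: int, dimension of global features
        target_dict: dict, dictionary of targets
        n_conv_layers: int, number of convolution layers
        conv_fn: str, "GraphConvDropoutBatch" or "ResidualBlock"
        dropout: float, dropout rate
        batch_norm: bool, whether to use batch norm
        activation: str, name of a torch.nn activation class (or None)
        bias: bool, whether to use bias
        norm: str, normalization type
        aggregate: str, aggregation type
        lr: float, learning rate
        scheduler_name: str, scheduler type ("reduce_on_plateau" or "none")
        weight_decay: float, weight decay
        lr_plateau_patience: int, patience for lr scheduler
        lr_scale_factor: float, scale factor for lr scheduler
        loss_fn: str, loss function ("mse", "smape" or "mae")
        resid_n_graph_convs: int, number of graph convolutions per residual block
        scalers: list, list of scalers applied to each node type
            (listed here historically; not an argument of __init__)
        embedding_size: int, size of embedding layer
        global_pooling: str, type of global pooling
        fc_layer_size: list of int, output size of each hidden fully connected layer
        fc_dropout: float, dropout rate in the fully connected layers
        fc_batch_norm: bool, whether to use batch norm in the fully connected layers
        lstm_iters: int, passed as n_iters to the Set2SetThenCat readout
        lstm_layers: int, passed as n_layers to the Set2SetThenCat readout
        output_dims: int, number of values predicted per graph
        pooling_ntypes: list of str, node types pooled by the readout
        pooling_ntypes_direct: list of str, node types passed to the readout
            as ntypes_direct_cat

    """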

    def __init__(
        self,
        atom_input_size=12,
        bond_input_size=8,
        global_input_size=3,
        n_conv_layers=3,
        target_dict={"atom": "E"},
        conv_fn="GraphConvDropoutBatch",
        global_pooling="WeightAndSumThenCat",
        resid_n_graph_convs=None,
        dropout=0.2,
        batch_norm=True,
        activation="ReLU",
        bias=True,
        norm="both",
        aggregate="sum",
        lr=1e-3,
        scheduler_name="reduce_on_plateau",
        weight_decay=0.0,
        lr_plateau_patience=5,
        lr_scale_factor=0.5,
        loss_fn="mse",
        embedding_size=128,
        fc_layer_size=[128, 64],
        fc_dropout=0.0,
        fc_batch_norm=True,
        lstm_iters=3,
        lstm_layers=1,
        output_dims=1,
        pooling_ntypes=["atom", "bond"],
        pooling_ntypes_direct=["global"],
    ):
        """
        Validate the configuration, record the hyperparameters and build every
        sub-module: input embedding, heterograph convolution stack, graph-level
        readout, fully connected head and the train/val/test metric trackers.

        Raises:
            AssertionError: if conv_fn or global_pooling is not a supported
                name, or if conv_fn is "ResidualBlock" and resid_n_graph_convs
                is None.
        """
        super().__init__()
        self.learning_rate = lr

        # ---- argument validation ------------------------------------------
        assert conv_fn == "GraphConvDropoutBatch" or conv_fn == "ResidualBlock", (
            "conv_fn must be either GraphConvDropoutBatch or ResidualBlock, "
            + f"but got {conv_fn}"
        )

        if conv_fn == "ResidualBlock":
            assert resid_n_graph_convs is not None, (
                "resid_n_graph_convs must be specified for ResidualBlock, "
                + f"but got {resid_n_graph_convs}"
            )

        assert global_pooling in [
            "WeightAndSumThenCat",
            "SumPoolingThenCat",
            "GlobalAttentionPoolingThenCat",
            "Set2SetThenCat",
        ], (
            "global_pooling must be one of WeightAndSumThenCat, SumPoolingThenCat, "
            "GlobalAttentionPoolingThenCat, or Set2SetThenCat, "
            + f"but got {global_pooling}"
        )

        # ---- hyperparameters ----------------------------------------------
        # Some entries are stored under a different name than the constructor
        # argument (batch_norm -> batch_norm_tf, pooling_ntypes -> ntypes_pool,
        # pooling_ntypes_direct -> ntypes_pool_direct_cat).
        params = {
            "atom_input_size": atom_input_size,
            "bond_input_size": bond_input_size,
            "global_input_size": global_input_size,
            "conv_fn": conv_fn,
            "target_dict": target_dict,
            "output_dims": output_dims,
            "dropout": dropout,
            "batch_norm_tf": batch_norm,
            "activation": activation,
            "bias": bias,
            "norm": norm,
            "aggregate": aggregate,
            "n_conv_layers": n_conv_layers,
            "lr": lr,
            "weight_decay": weight_decay,
            "lr_plateau_patience": lr_plateau_patience,
            "lr_scale_factor": lr_scale_factor,
            "scheduler_name": scheduler_name,
            "loss_fn": loss_fn,
            "resid_n_graph_convs": resid_n_graph_convs,
            "embedding_size": embedding_size,
            "fc_layer_size": fc_layer_size,
            "fc_dropout": fc_dropout,
            "fc_batch_norm": fc_batch_norm,
            "n_fc_layers": len(fc_layer_size),
            "global_pooling": global_pooling,
            "ntypes_pool": pooling_ntypes,
            "ntypes_pool_direct_cat": pooling_ntypes_direct,
            "lstm_iters": lstm_iters,
            "lstm_layers": lstm_layers,
        }

        self.hparams.update(params)
        # save_hyperparameters() also records the constructor arguments under
        # their own names; the code below relies on that for
        # self.hparams.pooling_ntypes.
        self.save_hyperparameters()

        # Replace the activation *name* with an instantiated torch.nn module.
        # The same instance is reused for every fully connected layer below.
        if self.hparams.activation is not None:
            self.hparams.activation = getattr(torch.nn, self.hparams.activation)()

        # ---- input embedding ----------------------------------------------
        input_size = {
            "atom": self.hparams.atom_input_size,
            "bond": self.hparams.bond_input_size,
            "global": self.hparams.global_input_size,
        }
        self.embedding = UnifySize(
            input_dim=input_size,
            output_dim=self.hparams.embedding_size,
        )

        # ---- convolution stack --------------------------------------------
        self.conv_layers = nn.ModuleList()

        if self.hparams.conv_fn == "GraphConvDropoutBatch":
            for i in range(self.hparams.n_conv_layers):
                # only the first layer consumes the embedding output
                embedding_in = i == 0

                layer_args = get_layer_args(self.hparams, i, embedding_in=embedding_in)

                # one convolution per relation between atom/bond/global nodes
                self.conv_layers.append(
                    dglnn.HeteroGraphConv(
                        {
                            "a2b": GraphConvDropoutBatch(**layer_args["a2b"]),
                            "b2a": GraphConvDropoutBatch(**layer_args["b2a"]),
                            "a2g": GraphConvDropoutBatch(**layer_args["a2g"]),
                            "g2a": GraphConvDropoutBatch(**layer_args["g2a"]),
                            "b2g": GraphConvDropoutBatch(**layer_args["b2g"]),
                            "g2b": GraphConvDropoutBatch(**layer_args["g2b"]),
                            "a2a": GraphConvDropoutBatch(**layer_args["a2a"]),
                            "b2b": GraphConvDropoutBatch(**layer_args["b2b"]),
                            "g2g": GraphConvDropoutBatch(**layer_args["g2g"]),
                        },
                        aggregate=self.hparams.aggregate,
                    )
                )

        elif self.hparams.conv_fn == "ResidualBlock":
            layer_tracker = 0
            # NOTE(review): this flag is never reset inside the loop, so every
            # residual block is built with embedding_in=True — confirm that
            # this is intended and not only meant for the first block.
            embedding_in = True

            while layer_tracker < self.hparams.n_conv_layers:
                if (
                    layer_tracker + self.hparams.resid_n_graph_convs
                    > self.hparams.n_conv_layers - 1
                ):
                    # this block reaches the end of the stack
                    layer_ind = self.hparams.n_conv_layers - layer_tracker - 1
                else:
                    layer_ind = -1

                layer_args = get_layer_args(
                    self.hparams, layer_ind, embedding_in=embedding_in
                )

                # layer_ind == -1 marks an intermediate block
                output_block = layer_ind != -1

                self.conv_layers.append(
                    ResidualBlock(
                        layer_args,
                        resid_n_graph_convs=self.hparams.resid_n_graph_convs,
                        aggregate=self.hparams.aggregate,
                        output_block=output_block,
                    )
                )

                layer_tracker += self.hparams.resid_n_graph_convs

        # Output feature sizes of the last convolution, keyed by relation
        # name for HeteroGraphConv or taken from the block for ResidualBlock.
        if self.hparams.conv_fn == "GraphConvDropoutBatch":
            conv_out_size = {}
            for k, v in self.conv_layers[-1].mods.items():
                conv_out_size[k] = v.out_feats
        elif self.hparams.conv_fn == "ResidualBlock":
            conv_out_size = self.conv_layers[-1].out_feats

        self.conv_out_size = link_fmt_to_node_fmt(conv_out_size)

        # ---- readout ------------------------------------------------------
        if self.hparams.global_pooling == "WeightAndSumThenCat":
            readout_fn = WeightAndSumThenCat
        elif self.hparams.global_pooling == "SumPoolingThenCat":
            readout_fn = SumPoolingThenCat
        elif self.hparams.global_pooling == "GlobalAttentionPoolingThenCat":
            readout_fn = GlobalAttentionPoolingThenCat
        elif self.hparams.global_pooling == "Set2SetThenCat":
            readout_fn = Set2SetThenCat

        list_in_feats = [
            self.conv_out_size[type_feat] for type_feat in self.hparams.pooling_ntypes
        ]

        self.readout_out_size = 0
        if self.hparams.global_pooling == "Set2SetThenCat":
            self.readout = readout_fn(
                n_iters=self.hparams.lstm_iters,
                n_layers=self.hparams.lstm_layers,
                in_feats=list_in_feats,
                ntypes=self.hparams.pooling_ntypes,
                ntypes_direct_cat=self.hparams.ntypes_pool_direct_cat,
            )
            # pooled node types contribute twice their feature size here,
            # directly concatenated types contribute it once
            for i in self.hparams.pooling_ntypes:
                if i not in self.hparams.ntypes_pool_direct_cat:
                    self.readout_out_size += self.conv_out_size[i] * 2
                else:
                    self.readout_out_size += self.conv_out_size[i]

        else:
            self.readout = readout_fn(
                ntypes=self.hparams.pooling_ntypes,
                in_feats=list_in_feats,
                ntypes_direct_cat=self.hparams.ntypes_pool_direct_cat,
            )

            # every listed node type contributes its feature size once,
            # whether it is pooled or directly concatenated
            for i in self.hparams.pooling_ntypes:
                self.readout_out_size += self.conv_out_size[i]

        self.loss = self.loss_function()

        # ---- fully connected head -----------------------------------------
        self.fc_layers = nn.ModuleList()

        input_size = self.readout_out_size
        for i in range(self.hparams.n_fc_layers):
            out_size = self.hparams.fc_layer_size[i]
            self.fc_layers.append(nn.Linear(input_size, out_size))
            if self.hparams.fc_batch_norm:
                self.fc_layers.append(nn.BatchNorm1d(out_size))
            if self.hparams.activation is not None:
                self.fc_layers.append(self.hparams.activation)
            if self.hparams.fc_dropout > 0:
                self.fc_layers.append(nn.Dropout(self.hparams.fc_dropout))
            input_size = out_size

        # final projection to the requested number of outputs
        self.fc_layers.append(nn.Linear(input_size, self.hparams.output_dims))

        # ---- metrics ------------------------------------------------------
        # One tracker per output dimension for each of train/val/test.
        # MeanSquaredError(squared=False) reports the root of the MSE.
        self.train_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.output_dims
        )
        self.train_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.output_dims
        )
        self.train_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.output_dims,
        )
        self.val_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.output_dims
        )
        self.val_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.output_dims
        )
        self.val_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.output_dims,
        )
        self.test_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.output_dims
        )
        self.test_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.output_dims
        )
        self.test_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.output_dims,
        )

    def forward(self, graph, inputs):
        """
        Run the full network on one (batched) graph.

        Args:
            graph: graph object handed to every convolution and to the readout.
            inputs: node features handed to the embedding layer.

        Returns:
            Output of the last fully connected layer.
        """
        hidden = self.embedding(inputs)

        # message passing: each convolution consumes the previous features
        for conv_layer in self.conv_layers:
            hidden = conv_layer(graph, hidden)

        # collapse node features into one representation per graph
        out = self.readout(graph, hidden)

        # fully connected head, applied in registration order
        for fc_layer in self.fc_layers:
            out = fc_layer(out)

        return out

    def loss_function(self):
        """
        Build the loss object selected by ``self.hparams.loss_fn``.

        Recognised names are "mse", "smape" and "mae"; any other value falls
        back to mean squared error.

        Returns:
            A MultioutputWrapper around the chosen torchmetrics metric with
            ``self.hparams.output_dims`` outputs.
        """
        metric_by_name = {
            "mse": "MeanSquaredError",
            "smape": "SymmetricMeanAbsolutePercentageError",
            "mae": "MeanAbsoluteError",
        }
        # unknown loss names silently use MSE
        metric_name = metric_by_name.get(self.hparams.loss_fn, "MeanSquaredError")
        base_metric = getattr(torchmetrics, metric_name)()

        return MultioutputWrapper(base_metric, num_outputs=self.hparams.output_dims)

    def compute_loss(self, target, pred):
        """
        Evaluate the configured loss on one batch.

        The two arguments are forwarded to ``self.loss`` positionally and in
        the order given. Note that ``shared_step`` calls this method as
        ``compute_loss(logits, labels)``, i.e. with the predictions first.
        """
        criterion = self.loss
        return criterion(target, pred)

    def feature_at_each_layer(model, graph, feats):
        """
        Get the features at each layer before the final fully-connected layer.

        This is used for feature visualization to see how the model learns.

        The first parameter is the instance itself (conventionally ``self``);
        it is named ``model`` here.

        Args:
            graph: batched graph the features belong to.
            feats: raw node features, passed through the embedding first.

        Returns:
            tuple: ``(atom_feats, bond_feats, global_feats)``. Each element is
            a dict mapping the layer index (0 = embedding output, 1.. = output
            of successive convolution layers) to the result of
            ``_split_batched_output`` for that node type.
        """

        layer_idx = 0
        atom_feats, bond_feats, global_feats = {}, {}, {}

        # layer 0: output of the embedding
        feats = model.embedding(feats)
        bond_feats[layer_idx] = _split_batched_output(graph, feats["bond"], "bond")
        atom_feats[layer_idx] = _split_batched_output(graph, feats["atom"], "atom")
        global_feats[layer_idx] = _split_batched_output(
            graph, feats["global"], "global"
        )

        layer_idx += 1

        # NOTE(review): the last convolution layer is skipped ([:-1]);
        # confirm this is intended.
        for layer in model.conv_layers[:-1]:
            feats = layer(graph, feats)
            bond_feats[layer_idx] = _split_batched_output(graph, feats["bond"], "bond")

            atom_feats[layer_idx] = _split_batched_output(graph, feats["atom"], "atom")

            global_feats[layer_idx] = _split_batched_output(
                graph, feats["global"], "global"
            )
            layer_idx += 1

        # previously the collected features were dropped (implicit None)
        return atom_feats, bond_feats, global_feats

    def shared_step(self, batch, mode):
        """
        Forward pass, loss computation, logging and metric update shared by
        the train, validation and test steps.

        Args:
            batch: pair ``(batch_graph, batch_label)``. Node features are read
                from ``batch_graph.ndata["feat"]`` and the targets from
                ``batch_label["global"]``.
            mode: "train", "val" or "test"; used as the prefix of the logged
                loss and to select which metric trackers are updated.

        Returns:
            The summed loss over all outputs. Lightning needs this return
            value from ``training_step``; when a training step returns None
            the optimisation step for that batch is skipped.
        """
        batch_graph, batch_label = batch
        logits = self.forward(batch_graph, batch_graph.ndata["feat"])
        labels = batch_label["global"]

        # compute_loss forwards its arguments unchanged to self.loss, so the
        # predictions go first here
        all_loss = self.compute_loss(logits, labels)
        total_loss = all_loss.sum()

        # log loss
        self.log(
            f"{mode}_loss",
            total_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=len(labels),
            sync_dist=True,
        )
        self.update_metrics(logits, labels, mode)

        return total_loss

    def training_step(self, batch, batch_idx):
        """
        Lightning training hook: delegates to ``shared_step`` in "train" mode.
        ``batch_idx`` is not used.
        """
        step_output = self.shared_step(batch, mode="train")
        return step_output

    def validation_step(self, batch, batch_idx):
        """
        Lightning validation hook: delegates to ``shared_step`` in "val" mode.
        ``batch_idx`` is not used.
        """
        step_output = self.shared_step(batch, mode="val")
        return step_output

    def test_step(self, batch, batch_idx):
        # Todo
        return self.shared_step(batch, mode="test")

    def on_train_epoch_end(self):
        """
        Log the aggregated training metrics at the end of each training epoch.

        Logs the median R2 and the mean of the per-output MAE / MSE trackers.
        During epoch 0 a very large placeholder ``val_mae`` is logged as well.
        """
        r2_scores, mae_scores, mse_scores = self.compute_metrics(mode="train")

        # Placeholder on the first epoch — presumably so that callbacks and
        # the scheduler monitoring "val_mae" always find the key; TODO confirm.
        first_epoch = self.trainer.current_epoch == 0
        if first_epoch:
            self.log("val_mae", 10**10, prog_bar=False)

        epoch_summary = (
            ("train_r2", r2_scores.median(), False),
            ("train_mae", mae_scores.mean(), False),
            ("train_mse", mse_scores.mean(), True),
        )
        for metric_name, metric_value, on_bar in epoch_summary:
            self.log(metric_name, metric_value, prog_bar=on_bar, sync_dist=True)

    def on_validation_epoch_end(self):
        """
        Log the aggregated validation metrics at the end of each validation
        epoch: median R2 (cast to float32) and mean MAE / MSE.
        """
        r2_scores, mae_scores, mse_scores = self.compute_metrics(mode="val")

        # the median R2 is explicitly converted to float32 before logging
        median_r2 = r2_scores.median().type(torch.float32)

        epoch_summary = (
            ("val_r2", median_r2, True),
            ("val_mae", mae_scores.mean(), False),
            ("val_mse", mse_scores.mean(), True),
        )
        for metric_name, metric_value, on_bar in epoch_summary:
            self.log(metric_name, metric_value, prog_bar=on_bar, sync_dist=True)

    def on_test_epoch_end(self):
        """
        Log the aggregated test metrics at the end of the test epoch:
        median R2 and mean MAE / MSE, none of them on the progress bar.
        """
        r2_scores, mae_scores, mse_scores = self.compute_metrics(mode="test")

        epoch_summary = (
            ("test_r2", r2_scores.median()),
            ("test_mae", mae_scores.mean()),
            ("test_mse", mse_scores.mean()),
        )
        for metric_name, metric_value in epoch_summary:
            self.log(metric_name, metric_value, prog_bar=False, sync_dist=True)

    def update_metrics(self, pred, target, mode):
        """
        Feed one batch of predictions and targets to the metric trackers of
        the given split.

        Args:
            pred: predictions for the batch.
            target: matching ground-truth values.
            mode: "train", "val" or "test". Any other value is ignored and no
                tracker is updated.
        """
        if mode not in ("train", "val", "test"):
            return

        # trackers are stored as <mode>_r2, <mode>_torch_l1, <mode>_torch_mse
        for suffix in ("r2", "torch_l1", "torch_mse"):
            tracker = getattr(self, f"{mode}_{suffix}")
            tracker.update(pred, target)

    def compute_metrics(self, mode):
        """
        Compute the accumulated metrics of one split and reset its trackers.

        All three trackers are computed first and reset afterwards, so the
        next epoch starts from empty state.

        Args:
            mode: "train", "val" or "test".

        Returns:
            tuple: ``(r2, torch_l1, torch_mse)`` as returned by the
            ``compute()`` call of the respective tracker.

        Raises:
            ValueError: if ``mode`` is not one of the three supported splits
                (previously this surfaced as an UnboundLocalError).
        """
        if mode not in ("train", "val", "test"):
            raise ValueError(
                f"mode must be one of 'train', 'val' or 'test', but got {mode!r}"
            )

        # trackers are stored as <mode>_r2, <mode>_torch_l1, <mode>_torch_mse
        trackers = [
            getattr(self, f"{mode}_{suffix}")
            for suffix in ("r2", "torch_l1", "torch_mse")
        ]

        r2, torch_l1, torch_mse = [tracker.compute() for tracker in trackers]

        for tracker in trackers:
            tracker.reset()

        return r2, torch_l1, torch_mse

    def configure_optimizers(self):
        """
        Create the Adam optimizer and, if configured, its LR scheduler.

        Only parameters with ``requires_grad=True`` are optimised. The
        scheduler is monitored on the logged ``val_mae`` value.

        Returns:
            ``([optimizer], [lr_scheduler_config])`` when a scheduler is
            configured, otherwise just the optimizer. Returning a scheduler
            config whose "scheduler" entry is None is not a valid Lightning
            configuration, which is what happened for scheduler_name="none".
        """
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        scheduler = self._config_lr_scheduler(optimizer)

        if scheduler is None:
            # no scheduler requested: hand Lightning the bare optimizer
            return optimizer

        lr_scheduler = {"scheduler": scheduler, "monitor": "val_mae"}

        return [optimizer], [lr_scheduler]

    def _config_lr_scheduler(self, optimizer):
        scheduler_name = self.hparams["scheduler_name"].lower()

        if scheduler_name == "reduce_on_plateau":
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="max",
                factor=self.hparams.lr_scale_factor,
                patience=self.hparams.lr_plateau_patience,
                verbose=True,
            )

        elif scheduler_name == "none":
            scheduler = None
        else:
            raise ValueError(f"Not supported lr scheduler: {scheduler_name}")

        return scheduler

    def evaluate_manually(self, batch_graph, batched_label, scaler_list):
        """
        Evaluate one batch outside the Lightning loop and report metrics on
        un-scaled values.

        Args:
            batch_graph: batched graph; node features are read from
                ``batch_graph.ndata["feat"]`` exactly as in ``shared_step``.
            batched_label: labels for the batch, compared against the
                predictions after inverse scaling.
            scaler_list: list of scalers; ``inverse_feats`` of each one is
                applied, in order, to both labels and predictions.

        Returns:
            tuple: ``(r2, mae, mse)`` computed with torchmetrics.
        """
        # The network consumes node features, not labels: the labels used to
        # be passed here as the model input.
        preds = self.forward(batch_graph, batch_graph.ndata["feat"])

        # deepcopy is not supported for tensors still attached to the
        # autograd graph, so detach before copying
        preds_unscaled = deepcopy(preds.detach())
        labels_unscaled = deepcopy(batched_label)

        # NOTE(review): shared_step indexes labels as batch_label["global"];
        # whether inverse_feats / the metrics below expect that dict or the
        # bare tensor is not visible from this file — confirm against the
        # scaler implementation.
        for scaler in scaler_list:
            labels_unscaled = scaler.inverse_feats(labels_unscaled)
            preds_unscaled = scaler.inverse_feats(preds_unscaled)

        # manually compute metrics
        r2 = torchmetrics.R2Score()
        mae = torchmetrics.MeanAbsoluteError()
        mse = torchmetrics.MeanSquaredError()

        r2.update(preds_unscaled, labels_unscaled)
        mae.update(preds_unscaled, labels_unscaled)
        mse.update(preds_unscaled, labels_unscaled)

        r2 = r2.compute()
        mae = mae.compute()
        mse = mse.compute()

        return r2, mae, mse


In [3]:
# from qtaim_embed.utils.grapher import get_grapher
# from qtaim_embed.data.molwrapper import mol_wrappers_from_df

"""from qtaim_embed.core.dataset import (
    HeteroGraphNodeLabelDataset,
    Subset,
    HeteroGraphGraphLabelDataset,
)"""

'from qtaim_embed.core.dataset import (\n    HeteroGraphNodeLabelDataset,\n    Subset,\n    HeteroGraphGraphLabelDataset,\n)'

In [4]:
import os

# Location of the labelled QM8 pickle. Set the QTAIM_TRAIN_DATASET environment
# variable to point the notebook at your own copy; the default is the path
# this notebook was originally run with.
TRAIN_DATASET_LOC = os.environ.get(
    "QTAIM_TRAIN_DATASET",
    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim_labelled.pkl",
)

# Start from the library defaults and override the scaling / debug options.
config = get_default_graph_level_config()
config["log_scale_features"] = True
config["log_scale_targets"] = False
config["standard_scale_features"] = True
config["standard_scale_targets"] = True
config["debug"] = False
config["train_dataset_loc"] = TRAIN_DATASET_LOC

In [5]:
"""train_dataset = HeteroGraphGraphLabelDataset(
    file=config["train_dataset_loc"],
    allowed_ring_size=config["allowed_ring_size"],
    allowed_charges=config["allowed_charges"],
    self_loop=True,
    extra_keys=config["extra_keys"],
    target_list=config["target_list"],
    extra_dataset_info=config["extra_dataset_info"],
    debug=config["debug"],
    standard_scale_features=config["standard_scale_features"],
    standard_scale_targets=config["standard_scale_targets"],
    log_scale_features=config["log_scale_features"],
    log_scale_targets=config["log_scale_targets"],
)"""

'train_dataset = HeteroGraphGraphLabelDataset(\n    file=config["train_dataset_loc"],\n    allowed_ring_size=config["allowed_ring_size"],\n    allowed_charges=config["allowed_charges"],\n    self_loop=True,\n    extra_keys=config["extra_keys"],\n    target_list=config["target_list"],\n    extra_dataset_info=config["extra_dataset_info"],\n    debug=config["debug"],\n    standard_scale_features=config["standard_scale_features"],\n    standard_scale_targets=config["standard_scale_targets"],\n    log_scale_features=config["log_scale_features"],\n    log_scale_targets=config["log_scale_targets"],\n)'

In [6]:
"""
from qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask

dataloader = DataLoaderMoleculeGraphTask(train_dataset, batch_size=200, shuffle=True)
"""

'\nfrom qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask\n\ndataloader = DataLoaderMoleculeGraphTask(train_dataset, batch_size=200, shuffle=True)\n'

In [7]:
"""
batch_graph, batch_label = next(iter(dataloader))
"""

'\nbatch_graph, batch_label = next(iter(dataloader))\n'

In [8]:
"""feat_size = train_dataset.feature_size()"""

'feat_size = train_dataset.feature_size()'

In [9]:
"""
# "GraphConvDropoutBatch",
for function in [
    "GlobalAttentionPoolingThenCat",
    "Set2SetThenCat",
    "WeightAndSumThenCat",
    "SumPoolingThenCat",
]:
    for conv_fn in ["ResidualBlock", "GraphConvDropoutBatch"]:
        print(f"function: {function}, conv_fn: {conv_fn}")
        model = GCNGraphPred(
            atom_input_size=feat_size["atom"],
            bond_input_size=feat_size["bond"],
            global_input_size=feat_size["global"],
            n_conv_layers=5,
            resid_n_graph_convs=2,
            target_dict={"global": "extra_feat_global_E1_CAM"},
            conv_fn=conv_fn,
            global_pooling=function,
            dropout=0.2,
            batch_norm=True,
            activation=None,
            bias=True,
            norm="both",
            aggregate="sum",
            lr=1e-2,
            scheduler_name="reduce_on_plateau",
            weight_decay=0.0,
            lr_plateau_patience=5,
            lr_scale_factor=0.8,
            loss_fn="mse",
            embedding_size=24,
            fc_layer_size=[256, 128, 128],
            fc_dropout=0.2,
            fc_batch_norm=True,
            lstm_iters=3,
            lstm_layers=1,
            output_dims=1,
            pooling_ntypes=["atom", "bond", "global"],
            pooling_ntypes_direct=["global"],
        )
    batch_graph, batch_label = next(iter(dataloader))
    # out = model.forward(batch_graph, batch_graph.ndata["feat"])
    out = model.shared_step((batch_graph, batch_label), "train")
"""

'\n# "GraphConvDropoutBatch",\nfor function in [\n    "GlobalAttentionPoolingThenCat",\n    "Set2SetThenCat",\n    "WeightAndSumThenCat",\n    "SumPoolingThenCat",\n]:\n    for conv_fn in ["ResidualBlock", "GraphConvDropoutBatch"]:\n        print(f"function: {function}, conv_fn: {conv_fn}")\n        model = GCNGraphPred(\n            atom_input_size=feat_size["atom"],\n            bond_input_size=feat_size["bond"],\n            global_input_size=feat_size["global"],\n            n_conv_layers=5,\n            resid_n_graph_convs=2,\n            target_dict={"global": "extra_feat_global_E1_CAM"},\n            conv_fn=conv_fn,\n            global_pooling=function,\n            dropout=0.2,\n            batch_norm=True,\n            activation=None,\n            bias=True,\n            norm="both",\n            aggregate="sum",\n            lr=1e-2,\n            scheduler_name="reduce_on_plateau",\n            weight_decay=0.0,\n            lr_plateau_patience=5,\n            lr_scale_fact

In [10]:
from qtaim_embed.core.datamodule import QTAIMGraphTaskDataModule

# Build the data module from the notebook-level `config`.
dm = QTAIMGraphTaskDataModule(config=config)

# The value returned by setup() is indexed by node type ("atom", "bond",
# "global") further down to size the model inputs.
feat_size = dm.setup(stage="fit")

... > creating MoleculeWrapper objects


100%|██████████| 21786/21786 [00:02<00:00, 8861.69it/s]


... > bond_feats_error_count:  1
... > atom_feats_error_count:  1
element set {'F', 'O', 'C', 'N', 'H'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta', 'bond_length']
selected global keys ['extra_feat_global_E1_CAM']
... > Building graphs and featurizing


100%|██████████| 21785/21785 [01:05<00:00, 333.39it/s]


included in labels
{'global': ['extra_feat_global_E1_CAM']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_F', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_N', 'chemical_symbol_H', 'extra_feat_atom_esp_total'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 21785/21785 [00:00<00:00, 34472.57it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
... > Scaling features complete
... > feature mean(s): 
 {'atom': tensor([1.0344e+00, 2.9874e-01, 1.7718e-01, 3.5022e-02, 3.6926e-02, 6.3978e-02,
        3.5200e-02, 6.0529e-03, 1.0124e-03, 5.4512e-02, 2.3809e-01, 4.1223e-02,
        3.5831e-01, 8.9978e+00]), 'bond': tensor([0.0000, 0.1791, 0.0362, 0.0396, 0.0663, 0.0377, 0.0070, 0.8149, 0.6613,
        0.0727, 0.8958]), 'global': tensor([2.8232, 2.8453, 4.6960])}
... > feature std(s):  
 {'atom': tensor([0.3963, 0.4616, 0.3024, 0.1518, 0.1557, 0.2006, 0.1522, 0.0645, 0.0265,
        0.1866, 0.3292, 0.1639, 0.3464, 5.6739]), 'bond': tensor([0.0010, 0.3034, 0.1543, 0.1608, 0.2038, 0.1571, 0.0694, 0.0916, 0.1841,
        0.1234, 0.1575]), 'g

In [11]:
model = GCNGraphPred(
    # --- inputs and target ---
    atom_input_size=feat_size["atom"],
    bond_input_size=feat_size["bond"],
    global_input_size=feat_size["global"],
    target_dict={"global": "extra_feat_global_E1_CAM"},
    output_dims=1,
    # --- embedding and graph convolutions ---
    embedding_size=10,
    conv_fn="GraphConvDropoutBatch",
    n_conv_layers=3,
    resid_n_graph_convs=2,  # only used when conv_fn == "ResidualBlock"
    dropout=0.2,
    batch_norm=False,
    activation="ReLU",
    bias=True,
    norm="both",
    aggregate="sum",
    # --- graph-level readout ---
    global_pooling="Set2SetThenCat",
    pooling_ntypes=["atom", "bond", "global"],
    pooling_ntypes_direct=["global"],
    lstm_iters=3,
    lstm_layers=2,
    # --- fully connected head ---
    fc_layer_size=[256, 128, 128],
    fc_dropout=0.2,
    fc_batch_norm=True,
    # --- loss and optimisation ---
    loss_fn="mae",
    lr=0.01,
    weight_decay=0.00001,
    scheduler_name="reduce_on_plateau",
    lr_plateau_patience=25,
    lr_scale_factor=0.8,
)

In [12]:
import torch
from qtaim_embed.models.utils import LogParameters
import wandb

# Run the whole training session inside a Weights & Biases run context.
with wandb.init(project="qtaim_embed_test") as run:
    torch.set_float32_matmul_precision("high")

    # Loggers: metrics go to both TensorBoard and W&B.
    logger_tb = TensorBoardLogger("./test_logs", name="test_logs")
    logger_wb = WandbLogger(name="test_logs")

    # Callbacks. Checkpointing and early stopping both track "val_mae",
    # where lower is better.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_mae",
        mode="min",
        dirpath="test_logs",
        filename="model_lightning_{epoch:02d}-{val_mae:.2f}",
        auto_insert_metric_name=True,
        save_last=True,
    )
    # NOTE(review): patience (500) is larger than max_epochs (100) below, so
    # early stopping presumably can never trigger in this run -- confirm
    # whether that is intended.
    early_stopping_callback = EarlyStopping(
        monitor="val_mae",
        mode="min",
        min_delta=0.00,
        patience=500,
        verbose=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    log_parameters = LogParameters()

    trainer_transfer = pl.Trainer(
        # Hardware / numerics.
        accelerator="gpu",
        devices=1,
        precision="32",
        # Training length and optimisation safeguards.
        max_epochs=100,
        gradient_clip_val=3.0,
        # Logging, progress and checkpointing.
        log_every_n_steps=10,
        enable_progress_bar=True,
        enable_checkpointing=True,
        default_root_dir="./test/",
        logger=[logger_tb, logger_wb],
        callbacks=[
            early_stopping_callback,
            lr_monitor,
            log_parameters,
            checkpoint_callback,
        ],
    )

    # No manual `model.cuda()` call is made here; the Trainer above is
    # configured with accelerator="gpu".
    trainer_transfer.fit(model, dm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msanti[0m ([33mhydro_homies[0m). Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ./test_logs/test_logs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type               | Params
--------------------------------------------------------
0  | embedding       | UnifySize          | 280   
1  | conv_layers     | ModuleList         | 3.3 K 
2  | readout         | Set2SetThenCat     | 6.7 K 
3  | loss            | MultioutputWrapper | 0     
4  | fc_layers       | ModuleList         | 69.8 K
5  | train_r2        | MultioutputWrapper | 0     
6  | train_torch_l1  | MultioutputWrapper | 0     
7  | train_torch_mse | MultioutputWrapper | 0     
8  | val_r2          | MultioutputWrapper | 0     
9  | val_torch_l1    | MultioutputWrapper | 0     
10 | val_torch_mse   | MultioutputWrapper | 0

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")




VBox(children=(Label(value='0.004 MB of 0.015 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.297951…

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
lr-Adam,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▄▃▂▂▂▂▂▁▁▁▁▁
train_mae,█▄▃▂▂▂▂▂▁▁▁▁▁
train_mse,█▄▃▃▂▂▂▂▁▁▁▁▁
train_r2,▁▅▆▇▇▇▇██████
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▆▄▅▃▂▂▂▂▁▁▂▂
val_mae,▁█▁▁▁▁▁▁▁▁▁▁▁▁
val_mse,█▆▄▄▃▂▂▂▂▁▁▂▂

0,1
epoch,12.0
lr-Adam,0.01
train_loss,6.24149
train_mae,6.24149
train_mse,8.66831
train_r2,0.85057
trainer/global_step,1689.0
val_loss,5.39834
val_mae,5.39834
val_mse,7.62241


In [None]:
"""# basic training loop


from torch.nn import functional as F
from sklearn.metrics import r2_score
from tqdm import tqdm
import tqdm.notebook as tq
import numpy as np

# move model to cpu
model = model.cpu()

opt = torch.optim.Adam(model.parameters(), lr=0.01)
dataloader = dm.train_dataloader()

for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label["global"]
            logits = model.forward(batch_graph, batch_graph.ndata["feat"])
            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()"""

Epoch 1: 100%|██████████| 128/128 [00:06<00:00, 20.75it/s]


-33.24173132586083 249.59429794549942


Epoch 2: 100%|██████████| 128/128 [00:05<00:00, 23.60it/s]


0.5975651399649531 145.10985386371613


Epoch 3: 100%|██████████| 128/128 [00:05<00:00, 23.79it/s]


0.6941618842251431 116.58790373802185


Epoch 4: 100%|██████████| 128/128 [00:05<00:00, 23.61it/s]


0.7409113617603024 101.8263692855835


Epoch 5: 100%|██████████| 128/128 [00:05<00:00, 23.47it/s]


0.7836339133675434 90.58988931775093


Epoch 6: 100%|██████████| 128/128 [00:05<00:00, 23.42it/s]


0.7928580934080458 85.0909181535244


Epoch 7: 100%|██████████| 128/128 [00:05<00:00, 23.53it/s]


0.8063137417976194 82.7260967195034


Epoch 8: 100%|██████████| 128/128 [00:05<00:00, 23.62it/s]


0.8160909185164501 78.84068021178246


Epoch 9: 100%|██████████| 128/128 [00:05<00:00, 24.30it/s]


0.819367901339974 76.51211869716644


Epoch 10: 100%|██████████| 128/128 [00:05<00:00, 24.07it/s]


0.8249671264120335 75.31890374422073


Epoch 11: 100%|██████████| 128/128 [00:05<00:00, 23.66it/s]


0.8328095992253378 72.14961007237434


Epoch 12: 100%|██████████| 128/128 [00:05<00:00, 23.51it/s]


0.8268704146833308 73.46147549152374


Epoch 13: 100%|██████████| 128/128 [00:05<00:00, 23.58it/s]


0.8289035893325564 73.41576477885246


Epoch 14: 100%|██████████| 128/128 [00:05<00:00, 23.68it/s]


0.8252398722538921 74.98519903421402


Epoch 15: 100%|██████████| 128/128 [00:05<00:00, 23.79it/s]


0.8434109495799262 67.3860405087471


Epoch 16:  98%|█████████▊| 126/128 [00:05<00:00, 23.87it/s]


KeyboardInterrupt: 