In [13]:
# baseline GNN model for node-level regression

from copy import deepcopy
import torch
from torch.optim import lr_scheduler
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    EarlyStopping,
    ModelCheckpoint,
)

import dgl.nn.pytorch as dglnn
from torchmetrics.wrappers import MultioutputWrapper
import torchmetrics



from qtaim_embed.models.graph_level.base_gcn import GCNGraphPred
from qtaim_embed.utils.data import get_default_graph_level_config
from qtaim_embed.models.layers import (
    GraphConvDropoutBatch,
    ResidualBlock,
    UnifySize,
    Set2SetThenCat,
    SumPoolingThenCat,
    WeightAndSumThenCat,
    GlobalAttentionPoolingThenCat,
)

from qtaim_embed.utils.models import (
    get_layer_args,
    link_fmt_to_node_fmt,
    _split_batched_output,
)


In [20]:
class GCNGraphPred(pl.LightningModule):
    """
    Basic GNN model for graph-level regression
    Takes
        atom_input_size: int, dimension of atom features
        bond_input_size: int, dimension of bond features
        global_input_size: int, dimension of global features
        target_dict: dict, dictionary of targets
        n_conv_layers: int, number of convolution layers
        conv_fn: str "GraphConvDropoutBatch"
        dropout: float, dropout rate
        batch_norm: bool, whether to use batch norm
        activation: str, activation function
        bias: bool, whether to use bias
        norm: str, normalization type
        aggregate: str, aggregation type
        lr: float, learning rate
        scheduler_name: str, scheduler type
        weight_decay: float, weight decay
        lr_plateau_patience: int, patience for lr scheduler
        lr_scale_factor: float, scale factor for lr scheduler
        loss_fn: str, loss function
        resid_n_graph_convs: int, number of graph convolutions per residual block
        scalers: list, list of scalers applied to each node type
        embedding_size: int, size of embedding layer
        global_pooling: str, type of global pooling

    """

    def __init__(
        self,
        atom_input_size=12,
        bond_input_size=8,
        global_input_size=3,
        n_conv_layers=3,
        target_dict={"atom": "E"},
        conv_fn="GraphConvDropoutBatch",
        global_pooling="WeightAndSumThenCat",
        resid_n_graph_convs=None,
        dropout=0.2,
        batch_norm=True,
        activation="ReLU",
        bias=True,
        norm="both",
        aggregate="sum",
        lr=1e-3,
        scheduler_name="reduce_on_plateau",
        weight_decay=0.0,
        lr_plateau_patience=5,
        lr_scale_factor=0.5,
        loss_fn="mse",
        embedding_size=128,
        fc_layer_size=[128, 64],
        fc_dropout=0.0,
        fc_batch_norm=True,
        lstm_iters=3,
        lstm_layers=1,
        pooling_ntypes=["atom", "bond"],
        pooling_ntypes_direct=["global"],
    ):
        super().__init__()
        self.learning_rate = lr

        # output_dims = 0
        # for k, v in target_dict.items():
        #    output_dims += len(v)

        assert conv_fn == "GraphConvDropoutBatch" or conv_fn == "ResidualBlock", (
            "conv_fn must be either GraphConvDropoutBatch or ResidualBlock"
            + f"but got {conv_fn}"
        )

        if conv_fn == "ResidualBlock":
            assert resid_n_graph_convs is not None, (
                "resid_n_graph_convs must be specified for ResidualBlock"
                + f"but got {resid_n_graph_convs}"
            )

        assert global_pooling in [
            "WeightAndSumThenCat",
            "SumPoolingThenCat",
            "GlobalAttentionPoolingThenCat",
            "Set2SetThenCat",
        ], (
            "global_pooling must be either WeightAndSumThenCat, SumPoolingThenCat, or GlobalAttentionPoolingThenCat"
            + f"but got {global_pooling}"
        )

        params = {
            "atom_input_size": atom_input_size,
            "bond_input_size": bond_input_size,
            "global_input_size": global_input_size,
            "conv_fn": conv_fn,
            "target_dict": target_dict,
            "dropout": dropout,
            "batch_norm_tf": batch_norm,
            "activation": activation,
            "bias": bias,
            "norm": norm,
            "aggregate": aggregate,
            "n_conv_layers": n_conv_layers,
            "lr": lr,
            "weight_decay": weight_decay,
            "lr_plateau_patience": lr_plateau_patience,
            "lr_scale_factor": lr_scale_factor,
            "scheduler_name": scheduler_name,
            "loss_fn": loss_fn,
            "resid_n_graph_convs": resid_n_graph_convs,
            "embedding_size": embedding_size,
            "fc_layer_size": fc_layer_size,
            "fc_dropout": fc_dropout,
            "fc_batch_norm": fc_batch_norm,
            "n_fc_layers": len(fc_layer_size),
            "global_pooling": global_pooling,
            "ntypes_pool": pooling_ntypes,
            "ntypes_pool_direct_cat": pooling_ntypes_direct,
            "lstm_iters": lstm_iters,
            "lstm_layers": lstm_layers,
            "ntasks": len(target_dict["global"]),
        }

        self.hparams.update(params)
        self.save_hyperparameters()

        # convert string activation to function
        if self.hparams.activation is not None:
            self.activation = getattr(torch.nn, self.hparams.activation)()
        else: 
            self.activation = None

        input_size = {
            "atom": self.hparams.atom_input_size,
            "bond": self.hparams.bond_input_size,
            "global": self.hparams.global_input_size,
        }
        # print("input size", input_size)
        self.embedding = UnifySize(
            input_dim=input_size,
            output_dim=self.hparams.embedding_size,
        )
        # self.embedding_output_size = self.hparams.embedding_size

        self.conv_layers = nn.ModuleList()

        if self.hparams.conv_fn == "GraphConvDropoutBatch":
            for i in range(self.hparams.n_conv_layers):
                # embedding_in = False
                # if i == 0:
                embedding_in = True

                layer_args = get_layer_args(self.hparams, i, activation=self.activation, embedding_in=embedding_in)
                # print("resid layer args", layer_args)

                self.conv_layers.append(
                    dglnn.HeteroGraphConv(
                        {
                            "a2b": GraphConvDropoutBatch(**layer_args["a2b"]),
                            "b2a": GraphConvDropoutBatch(**layer_args["b2a"]),
                            "a2g": GraphConvDropoutBatch(**layer_args["a2g"]),
                            "g2a": GraphConvDropoutBatch(**layer_args["g2a"]),
                            "b2g": GraphConvDropoutBatch(**layer_args["b2g"]),
                            "g2b": GraphConvDropoutBatch(**layer_args["g2b"]),
                            "a2a": GraphConvDropoutBatch(**layer_args["a2a"]),
                            "b2b": GraphConvDropoutBatch(**layer_args["b2b"]),
                            "g2g": GraphConvDropoutBatch(**layer_args["g2g"]),
                        },
                        aggregate=self.hparams.aggregate,
                    )
                )

        elif self.hparams.conv_fn == "ResidualBlock":
            layer_tracker = 0
            embedding_in = True

            while layer_tracker < self.hparams.n_conv_layers:
                if (
                    layer_tracker + self.hparams.resid_n_graph_convs
                    > self.hparams.n_conv_layers - 1
                ):
                    # print("triggered output_layer args")
                    layer_ind = self.hparams.n_conv_layers - layer_tracker - 1
                else:
                    layer_ind = -1

                layer_args = get_layer_args(
                    self.hparams, layer_ind, embedding_in=embedding_in, activation=self.activation
                )
                # print("resid layer args", layer_args)
                # for k, v in layer_args.items():
                #    print(k, v["in_feats"], v["out_feats"])

                # embedding_in = False
                output_block = False
                if layer_ind != -1:
                    output_block = True

                self.conv_layers.append(
                    ResidualBlock(
                        layer_args,
                        resid_n_graph_convs=self.hparams.resid_n_graph_convs,
                        aggregate=self.hparams.aggregate,
                        output_block=output_block,
                    )
                )

                layer_tracker += self.hparams.resid_n_graph_convs

        self.conv_layers = nn.ModuleList(self.conv_layers)
        # print("conv layer out modes", self.conv_layers[-1].mods)

        # print("conv layer out feats", self.conv_layers[-1].out_feats)
        # conv_out_size = self.conv_layers[-1].out_feats

        if self.hparams.conv_fn == "GraphConvDropoutBatch":
            conv_out_size = {}
            for k, v in self.conv_layers[-1].mods.items():
                conv_out_size[k] = v.out_feats
        elif self.hparams.conv_fn == "ResidualBlock":
            conv_out_size = self.conv_layers[-1].out_feats

        # print("conv out raw", conv_out_size)
        self.conv_out_size = link_fmt_to_node_fmt(conv_out_size)
        # print("conv out size: ", self.conv_out_size)

        if self.hparams.global_pooling == "WeightAndSumThenCat":
            readout_fn = WeightAndSumThenCat
        elif self.hparams.global_pooling == "SumPoolingThenCat":
            readout_fn = SumPoolingThenCat
        elif self.hparams.global_pooling == "GlobalAttentionPoolingThenCat":
            readout_fn = GlobalAttentionPoolingThenCat
        elif self.hparams.global_pooling == "Set2SetThenCat":
            readout_fn = Set2SetThenCat

        list_in_feats = []
        for type_feat in self.hparams.pooling_ntypes:
            list_in_feats.append(self.conv_out_size[type_feat])

        self.readout_out_size = 0
        if self.hparams.global_pooling == "Set2SetThenCat":
            # print("using set2setthencat")

            self.readout = readout_fn(
                n_iters=self.hparams.lstm_iters,
                n_layers=self.hparams.lstm_layers,
                in_feats=list_in_feats,
                ntypes=self.hparams.pooling_ntypes,
                ntypes_direct_cat=self.hparams.ntypes_pool_direct_cat,
            )
            for i in self.hparams.pooling_ntypes:
                if i not in self.hparams.ntypes_pool_direct_cat:
                    self.readout_out_size += self.conv_out_size[i] * 2
                else:
                    self.readout_out_size += self.conv_out_size[i]

        else:
            # print("other readout used")
            self.readout = readout_fn(
                ntypes=self.hparams.pooling_ntypes,
                in_feats=list_in_feats,
                ntypes_direct_cat=self.hparams.ntypes_pool_direct_cat,
            )

            for i in self.hparams.pooling_ntypes:
                if i in self.hparams.ntypes_pool_direct_cat:
                    self.readout_out_size += self.conv_out_size[i]
                else:
                    self.readout_out_size += self.conv_out_size[i]

        # print("readout out size", self.readout_out_size)
        # self.readout_out_size = readout_out_size
        self.loss = self.loss_function()

        self.fc_layers = nn.ModuleList()

        input_size = self.readout_out_size
        for i in range(self.hparams.n_fc_layers):
            out_size = self.hparams.fc_layer_size[i]
            self.fc_layers.append(nn.Linear(input_size, out_size))
            if self.hparams.fc_batch_norm:
                self.fc_layers.append(nn.BatchNorm1d(out_size))
            if self.activation is not None:
                self.fc_layers.append(self.activation)
            if self.hparams.fc_dropout > 0:
                self.fc_layers.append(nn.Dropout(self.hparams.fc_dropout))
            input_size = out_size

        self.fc_layers.append(nn.Linear(input_size, self.hparams.ntasks))

        # print("number of output dims", output_dims)
        print("... > number of tasks:", self.hparams.ntasks)

        # create multioutput wrapper for metrics
        self.train_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.ntasks
        )
        self.train_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.ntasks
        )
        self.train_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.ntasks,
        )
        self.val_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.ntasks
        )
        self.val_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.ntasks
        )
        self.val_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.ntasks,
        )
        self.test_r2 = MultioutputWrapper(
            torchmetrics.R2Score(), num_outputs=self.hparams.ntasks
        )
        self.test_torch_l1 = MultioutputWrapper(
            torchmetrics.MeanAbsoluteError(), num_outputs=self.hparams.ntasks
        )
        self.test_torch_mse = MultioutputWrapper(
            torchmetrics.MeanSquaredError(squared=False),
            num_outputs=self.hparams.ntasks,
        )

    def forward(self, graph, inputs):
        """
        Forward pass
        """

        feats = self.embedding(inputs)
        for ind, conv in enumerate(self.conv_layers):
            # print("conv layer", ind)
            feats = conv(graph, feats)

        readout_feats = self.readout(graph, feats)
        for ind, layer in enumerate(self.fc_layers):
            readout_feats = layer(readout_feats)

        #print("preds shape:", readout_feats.shape)
        return readout_feats
    
    def loss_function(self):
        """
        Initialize loss function
        """
        if self.hparams.ntasks > 1:
            loss_fn = nn.ModuleList()
            for i in range(self.hparams.ntasks):
                if self.hparams.loss_fn == "mse":
                    loss_fn.append(torchmetrics.MeanSquaredError())
                elif self.hparams.loss_fn == "smape":
                    loss_fn.append(torchmetrics.SymmetricMeanAbsolutePercentageError())
                elif self.hparams.loss_fn == "mae":
                    loss_fn.append(torchmetrics.MeanAbsoluteError())
                else:
                    loss_fn.append(torchmetrics.MeanSquaredError())

        else: 
            if self.hparams.loss_fn == "mse":
                loss_fn = torchmetrics.MeanSquaredError()
                
            elif self.hparams.loss_fn == "smape":
                loss_fn = torchmetrics.SymmetricMeanAbsolutePercentageError()
            elif self.hparams.loss_fn == "mae":
                loss_fn = torchmetrics.MeanAbsoluteError()
            else:
                loss_fn = torchmetrics.MeanSquaredError()

        return loss_fn

In [15]:
# from qtaim_embed.utils.grapher import get_grapher
# from qtaim_embed.data.molwrapper import mol_wrappers_from_df

from qtaim_embed.core.dataset import (
    HeteroGraphNodeLabelDataset,
    Subset,
    HeteroGraphGraphLabelDataset,
)

In [21]:
config = get_default_graph_level_config()
config["log_scale_features"] = True
config["log_scale_targets"] = False
config["standard_scale_features"] = True
config["standard_scale_targets"] = True
config["debug"] = False
config[
    "train_dataset_loc"
] = "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim_labelled.pkl"

In [22]:
train_dataset = HeteroGraphGraphLabelDataset(
    file=config["dataset"]["train_dataset_loc"],
    allowed_ring_size=config["dataset"]["allowed_ring_size"],
    allowed_charges=config["dataset"]["allowed_charges"],
    self_loop=True,
    extra_keys=config["dataset"]["extra_keys"],
    target_list=config["dataset"]["target_list"],
    extra_dataset_info=config["dataset"]["extra_dataset_info"],
    debug=config["debug"],
    standard_scale_features=config["standard_scale_features"],
    standard_scale_targets=config["standard_scale_targets"],
    log_scale_features=config["log_scale_features"],
    log_scale_targets=config["log_scale_targets"],
)

... > creating MoleculeWrapper objects


100%|██████████| 100/100 [00:00<00:00, 5314.63it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'N', 'C', 'O', 'H'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'bond_length']
selected global keys ['extra_feat_global_E1_CAM']
... > Building graphs and featurizing


  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:00<00:00, 326.41it/s]


included in labels
{'global': ['extra_feat_global_E1_CAM']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_C', 'chemical_symbol_O', 'chemical_symbol_H', 'extra_feat_atom_esp_total'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_total'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 100/100 [00:00<00:00, 32924.91it/s]

original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [1.03348744 0.29579238 0.1931437  0.02300363 0.03993083 0.06814284
 0.04296905 0.01909736 0.04557323 0.23654678 0.05555594 0.35547122
 9.03963532]
std [0.3914866  0.46054609 0.3107612  0.12416012 0.16150379 0.2063724
 0.16714525 0.11345734 0.17179068 0.32864473 0.18820728 0.34645936
 5.66985947]
mean [0.         0.19579709 0.02420918 0.04417114 0.07007922 0.04587003
 0.0208114  0.81517279 0.66599706]
std [0.         0.31205721 0.12725739 0.16931041 0.2089596  0.17230968
 0.11828885 0.09367724 0.18940919]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [2.81851833 2.8375313  4.69764359]
std [0.16318695 0.17198598 0.06353343]
... > Scaling features complete
... > feature mean(s): 
 {'atom': tensor([1.0335, 0.2958, 0.1931, 0.0230




In [23]:
from qtaim_embed.core.datamodule import QTAIMGraphTaskDataModule

dm = QTAIMGraphTaskDataModule(config=config,)
feat_names, feat_size = dm.prepare_data(stage="fit")

... > creating MoleculeWrapper objects


100%|██████████| 100/100 [00:00<00:00, 9158.47it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'N', 'C', 'O', 'H'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'bond_length']
selected global keys ['extra_feat_global_E1_CAM']
... > Building graphs and featurizing


  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:00<00:00, 325.70it/s]


included in labels
{'global': ['extra_feat_global_E1_CAM']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_C', 'chemical_symbol_O', 'chemical_symbol_H', 'extra_feat_atom_esp_total'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_total'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 100/100 [00:00<00:00, 34433.17it/s]

original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Scaling features
mean [2.04383219e+00 5.20350657e-01 2.78647464e-01 3.31872260e-02
 5.76080150e-02 9.83093300e-02 6.19912336e-02 2.75516594e-02
 6.57482780e-02 3.41264872e-01 8.01502818e-02 5.12836569e-01
 2.08447046e+06]
std [1.25340657e+00 8.73526527e-01 4.48333642e-01 1.79125191e-01
 2.33000712e-01 2.97732440e-01 2.41139629e-01 1.63684347e-01
 2.47841566e-01 4.74134115e-01 2.71525715e-01 4.99835195e-01
 5.20377426e+06]
mean [0.         0.28247549 0.03492647 0.06372549 0.10110294 0.06617647
 0.03002451 1.26981778 0.98419271]
std [0.         0.45020338 0.18359361 0.24426328 0.30146498 0.24859032
 0.17065474 0.22397461 0.42073796]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [ 15.97        16.32       108.90153938]
std [2.67003745 2.85615126 6.59715147]
... > Scaling features complete
... > feature mean(s):




In [28]:
model = GCNGraphPred(
    atom_input_size=feat_size["atom"],
    bond_input_size=feat_size["bond"],
    global_input_size=feat_size["global"],
    n_conv_layers=3,
    resid_n_graph_convs=2,
    target_dict={"global": ["extra_feat_global_E1_CAM"]},
    conv_fn="GraphConvDropoutBatch",
    global_pooling="Set2SetThenCat",
    dropout=0.2,
    batch_norm=False,
    activation="ReLU",
    bias=True,
    norm="both",
    aggregate="sum",
    lr=0.01,
    scheduler_name="reduce_on_plateau",
    weight_decay=0.00001,
    lr_plateau_patience=25,
    lr_scale_factor=0.8,
    loss_fn="mae",
    embedding_size=10,
    fc_layer_size=[256, 128, 128],
    fc_dropout=0.2,
    fc_batch_norm=True,
    lstm_iters=3,
    lstm_layers=2,
    pooling_ntypes=["atom", "bond", "global"],
    pooling_ntypes_direct=["global"],
)

... > number of tasks: 1


In [None]:
# basic training loop


from torch.nn import functional as F
from sklearn.metrics import r2_score
from tqdm import tqdm
import tqdm.notebook as tq
import numpy as np

# move model to cpu
model = model.cpu()

opt = torch.optim.Adam(model.parameters(), lr=0.01)
dataloader = dm.train_dataloader()

for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label["global"]
            logits = model.forward(batch_graph, batch_graph.ndata["feat"])
            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()

In [26]:
# to test migration to pytorch lightning
"""
import torch
from qtaim_embed.models.utils import LogParameters
import wandb

with wandb.init(project="qtaim_embed_test") as run:
    logger_tb = TensorBoardLogger("./test_logs", name="test_logs")
    torch.set_float32_matmul_precision("high")

    checkpoint_callback = ModelCheckpoint(
        dirpath="test_logs",
        filename="model_lightning_{epoch:02d}-{val_mae:.2f}",
        monitor="val_mae",
        mode="min",
        auto_insert_metric_name=True,
        save_last=True,
    )

    early_stopping_callback = EarlyStopping(
        monitor="val_mae",
        min_delta=0.00,
        patience=500,
        verbose=False,
        mode="min",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger_wb = WandbLogger(name="test_logs")
    log_parameters = LogParameters()
    trainer_transfer = pl.Trainer(
        max_epochs=100,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        gradient_clip_val=3.0,
        default_root_dir="./test/",
        precision="32",
        log_every_n_steps=10,
        callbacks=[
            early_stopping_callback,
            lr_monitor,
            log_parameters,
            checkpoint_callback,
        ],
        enable_checkpointing=True,
        logger=[logger_tb, logger_wb],
    )

    # move model to gpu
    # model = model.cuda()

    trainer_transfer.fit(model, dm)"""

'import torch\nfrom qtaim_embed.models.utils import LogParameters\nimport wandb\n\nwith wandb.init(project="qtaim_embed_test") as run:\n    logger_tb = TensorBoardLogger("./test_logs", name="test_logs")\n    torch.set_float32_matmul_precision("high")\n\n    checkpoint_callback = ModelCheckpoint(\n        dirpath="test_logs",\n        filename="model_lightning_{epoch:02d}-{val_mae:.2f}",\n        monitor="val_mae",\n        mode="min",\n        auto_insert_metric_name=True,\n        save_last=True,\n    )\n\n    early_stopping_callback = EarlyStopping(\n        monitor="val_mae",\n        min_delta=0.00,\n        patience=500,\n        verbose=False,\n        mode="min",\n    )\n    lr_monitor = LearningRateMonitor(logging_interval="step")\n    logger_wb = WandbLogger(name="test_logs")\n    log_parameters = LogParameters()\n    trainer_transfer = pl.Trainer(\n        max_epochs=100,\n        accelerator="gpu",\n        devices=1,\n        enable_progress_bar=True,\n        gradient_clip

Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  7.21it/s]


-8.706716169536922 1.2563227415084839


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  6.76it/s]


-0.3983241182551056 2.0212724208831787


Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  7.79it/s]


-3.7148510374074393 1.2483052015304565


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  7.14it/s]


-4.539191972896186 1.0911214351654053


Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  7.63it/s]


-1.723811502874459 0.9582868218421936


Epoch 6: 100%|██████████| 1/1 [00:00<00:00,  8.36it/s]


-0.639770496413697 0.8587819337844849


Epoch 7: 100%|██████████| 1/1 [00:00<00:00,  7.82it/s]


-1.1261253613246844 0.9117791652679443


Epoch 8: 100%|██████████| 1/1 [00:00<00:00,  7.23it/s]


-0.95253979835489 0.8274208307266235


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  7.30it/s]


-1.4056793121236808 0.8707172870635986


Epoch 10: 100%|██████████| 1/1 [00:00<00:00,  7.12it/s]


-4.791804759288356 1.000211238861084


Epoch 11: 100%|██████████| 1/1 [00:00<00:00,  7.36it/s]


-4.491617478234931 0.8373774886131287


Epoch 12: 100%|██████████| 1/1 [00:00<00:00,  7.38it/s]


-3.7458898657745694 0.7644603252410889


Epoch 13: 100%|██████████| 1/1 [00:00<00:00,  7.60it/s]


-2.0861239177874613 0.7337958216667175


Epoch 14: 100%|██████████| 1/1 [00:00<00:00,  6.82it/s]


-1.1962182765540614 0.7722698450088501


Epoch 15: 100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


-0.31999188805274703 0.726051390171051


Epoch 16: 100%|██████████| 1/1 [00:00<00:00,  6.81it/s]


-0.12738258103716338 0.692328929901123


Epoch 17: 100%|██████████| 1/1 [00:00<00:00,  7.35it/s]


-0.31854787554838704 0.7864364385604858


Epoch 18: 100%|██████████| 1/1 [00:00<00:00,  7.89it/s]


-0.39762750609146735 0.5873829126358032


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  7.65it/s]


-0.6424235047680962 0.6314237713813782


Epoch 20: 100%|██████████| 1/1 [00:00<00:00,  8.50it/s]


-0.4283076501852314 0.6118025779724121


Epoch 21: 100%|██████████| 1/1 [00:00<00:00,  7.28it/s]


-0.8453563196924185 0.6320540904998779


Epoch 22: 100%|██████████| 1/1 [00:00<00:00,  7.57it/s]


-0.4260981578313654 0.5974157452583313


Epoch 23: 100%|██████████| 1/1 [00:00<00:00,  6.79it/s]


-0.4930299432271279 0.6182008385658264


Epoch 24: 100%|██████████| 1/1 [00:00<00:00,  7.15it/s]


-0.17325167270188224 0.5168570280075073


Epoch 25: 100%|██████████| 1/1 [00:00<00:00,  8.34it/s]


-0.30197512141280525 0.6157616972923279


Epoch 26: 100%|██████████| 1/1 [00:00<00:00,  8.49it/s]


-0.06443522568908033 0.5305404663085938


Epoch 27: 100%|██████████| 1/1 [00:00<00:00,  7.78it/s]


0.050929932884290596 0.5510839223861694


Epoch 28: 100%|██████████| 1/1 [00:00<00:00,  7.27it/s]


0.025027063218400536 0.6333790421485901


Epoch 29: 100%|██████████| 1/1 [00:00<00:00,  7.40it/s]


0.027946459581460137 0.5735399723052979


Epoch 30: 100%|██████████| 1/1 [00:00<00:00,  6.63it/s]


0.03133004618599933 0.5084571838378906


Epoch 31: 100%|██████████| 1/1 [00:00<00:00,  7.82it/s]


0.21852559408565442 0.46218088269233704


Epoch 32: 100%|██████████| 1/1 [00:00<00:00,  7.95it/s]


0.04141748729422001 0.5380799174308777


Epoch 33: 100%|██████████| 1/1 [00:00<00:00,  7.00it/s]


0.048598987120388015 0.5072619915008545


Epoch 34: 100%|██████████| 1/1 [00:00<00:00,  7.54it/s]


-0.004108584452037878 0.5699867010116577


Epoch 35: 100%|██████████| 1/1 [00:00<00:00,  7.06it/s]


0.012709744257602962 0.5088305473327637


Epoch 36: 100%|██████████| 1/1 [00:00<00:00,  7.97it/s]


0.23534292550090052 0.39804184436798096


Epoch 37: 100%|██████████| 1/1 [00:00<00:00,  6.77it/s]


0.1592162633566102 0.4461435377597809


Epoch 38: 100%|██████████| 1/1 [00:00<00:00,  7.27it/s]


0.13827649100004857 0.47478580474853516


Epoch 39: 100%|██████████| 1/1 [00:00<00:00,  6.97it/s]


0.4726753301462394 0.3412294387817383


Epoch 40: 100%|██████████| 1/1 [00:00<00:00,  7.14it/s]


0.4825239031743138 0.38162314891815186


Epoch 41: 100%|██████████| 1/1 [00:00<00:00,  7.21it/s]


0.4832985566867498 0.36891889572143555


Epoch 42: 100%|██████████| 1/1 [00:00<00:00,  7.60it/s]


0.568211023482711 0.3676278293132782


Epoch 43: 100%|██████████| 1/1 [00:00<00:00,  7.32it/s]


0.4922894158856249 0.4315352439880371


Epoch 44: 100%|██████████| 1/1 [00:00<00:00,  6.82it/s]


0.4810685229684698 0.40943554043769836


Epoch 45: 100%|██████████| 1/1 [00:00<00:00,  8.47it/s]


0.46177722633310403 0.35901570320129395


Epoch 46: 100%|██████████| 1/1 [00:00<00:00,  8.26it/s]


0.4603494783730592 0.3340469002723694


Epoch 47: 100%|██████████| 1/1 [00:00<00:00,  7.19it/s]


0.43464958825280176 0.34476080536842346


Epoch 48: 100%|██████████| 1/1 [00:00<00:00,  7.68it/s]


0.4961700937994765 0.3397064805030823


Epoch 49: 100%|██████████| 1/1 [00:00<00:00,  6.91it/s]


0.6123852892144399 0.28567981719970703


Epoch 50: 100%|██████████| 1/1 [00:00<00:00,  7.31it/s]

0.5212500520451349 0.308574378490448



