In [1]:
import wandb, argparse, torch, json
import numpy as np
from copy import deepcopy
from qtaim_embed.models.utils import LogParameters, load_graph_level_model_from_config
from qtaim_embed.utils.data import get_default_graph_level_config
from qtaim_embed.core.datamodule import QTAIMGraphTaskDataModule
import pytorch_lightning as pl
from qtaim_embed.core.dataset import HeteroGraphGraphLabelDataset

import pytorch_lightning as pl
from qtaim_embed.models.graph_level.base_gcn import GCNGraphPred
from qtaim_embed.utils.data import get_default_graph_level_config
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    EarlyStopping,
    ModelCheckpoint,
)

from qtaim_embed.models.utils import load_graph_level_model_from_config
from qtaim_embed.core.datamodule import QTAIMGraphTaskDataModule

torch.set_float32_matmul_precision("high")  # might have to disable on older GPUs

In [2]:
# torch.multiprocessing.set_sharing_strategy("file_system")

import pandas as pd

libe_loc = "../../../data/libe_qtaim_1029_labelled.pkl"
df = pd.read_pickle(libe_loc)
df["shifted_rrho_ev_free_energy"]

0       -12391.737598
1        -9956.767902
2       -16909.526727
3        -9539.808335
4        -7278.828211
             ...     
17152   -30509.130168
17153    -8436.280139
17154    -6406.817348
17155    -9954.759794
17156   -26718.224922
Name: shifted_rrho_ev_free_energy, Length: 17157, dtype: float64

In [35]:
config = get_default_graph_level_config()
on_gpu = bool(True)
debug = bool(False)

In [36]:
dataset_loc = "../../../data/libe_qtaim_1029_labelled.pkl"
log_save_dir = "./libe_dev/"

config["dataset"]["log_scale_features"] = True
config["dataset"]["debug"] = False
config["dataset"]["standard_scale_features"] = True
config["dataset"]["target_list"] = ["shifted_rrho_ev_free_energy"]
config["dataset"]["train_batch_size"] = 512
config["dataset"]["allowed_charges"] = [-1, 0, 1]
config["dataset"]["allowed_spins"] = [1, 2, 3]
config["dataset"]["extra_keys"] = {
    "atom": ["extra_feat_atom_esp_total"],
    "bond": [
        "extra_feat_bond_esp_total",
        "bond_length",
    ],
    "global": ["shifted_rrho_ev_free_energy", "charge", "spin"],
}
# set log save dir/
config["dataset"]["log_save_dir"] = log_save_dir
# dataset
if dataset_loc is not None:
    config["dataset"]["train_dataset_loc"] = dataset_loc
extra_keys = config["dataset"]["extra_keys"]
if debug:
    config["dataset"]["debug"] = debug


if config["optim"]["precision"] == "16" or config["optim"]["precision"] == "32":
    config["optim"]["precision"] = int(config["optim"]["precision"])

dm = QTAIMGraphTaskDataModule(config=config)
feature_names, feature_size = dm.prepare_data(stage="fit")

... > creating MoleculeWrapper objects


100%|██████████| 17157/17157 [00:01<00:00, 13033.01it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'H', 'F', 'O', 'P', 'S', 'Li', 'N', 'C'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'bond_length']
selected global keys ['shifted_rrho_ev_free_energy', 'charge', 'spin']
... > Building graphs and featurizing


100%|██████████| 17157/17157 [00:26<00:00, 655.46it/s]


included in labels
{'global': ['shifted_rrho_ev_free_energy']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_H', 'chemical_symbol_F', 'chemical_symbol_O', 'chemical_symbol_P', 'chemical_symbol_S', 'chemical_symbol_Li', 'chemical_symbol_N', 'chemical_symbol_C', 'extra_feat_atom_esp_total'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_total'], 'global': ['num atoms', 'num bonds', 'molecule weight', 'charge one hot', 'charge one hot', 'charge one hot', 'spin one hot', 'spin one hot', 'spin one hot']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 17157/17157 [00:00<00:00, 34776.61it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [8.18990584e-01 1.22892184e-01 7.86921816e-02 6.38917142e-03
 4.93583899e-02 1.25445927e-02 8.69699060e-03 1.70303698e-03
 2.12144979e-01 4.27985438e-02 1.69057031e-01 1.17913977e-02
 4.44496363e-03 4.64198163e-02 1.90339427e-03 2.04587057e-01
 1.10351290e+01]
std [0.52944732 0.31076984 0.21989271 0.06624058 0.1782593  0.09240066
 0.07715346 0.03431552 0.31944045 0.16683517 0.29765941 0.08963335
 0.05532862 0.17326559 0.03627271 0.31615357 5.02713615]
mean [0.         0.09399185 0.00773844 0.05891471 0.01545498 0.01059545
 0.00226983 0.85905631 0.74908361]
std [0.         0.23730933 0.07282854 0.19330189 0.10234118 0.08504081
 0.03960015 0.34912601 0.31340345]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [2.40839803 2.24223

In [38]:
config["model"] = {
    "n_conv_layers": 4,
    "resid_n_graph_convs": 1,
    "conv_fn": "GraphConvDropoutBatch",
    "global_pooling_fn": "SumPoolingThenCat",
    "dropout": 0.2,
    "batch_norm": True,
    "activation": "ReLU",
    "bias": True,
    "norm": "both",
    "aggregate": "sum",
    "lr": 0.0002,
    "scheduler_name": "reduce_on_plateau",
    "weight_decay": 0.00001,
    "lr_plateau_patience": 50,
    "lr_scale_factor": 0.75,
    "loss_fn": "mse",
    "embedding_size": 50,
    "shape_fc": "cone",
    "fc_hidden_size_1": 512,
    "fc_num_layers": 1,
    "fc_dropout": 0.2,
    "fc_batch_norm": True,
    "lstm_iters": 3,
    "lstm_layers": 2,
    "output_dims": 1,
    "pooling_ntypes": ["atom", "bond", "global"],
    "pooling_ntypes_direct": ["global"],
    "restore": False,
    "max_epochs": 10,
    "classifier": False,
    "target_dict": {},
}


print(">" * 40 + "config_settings" + "<" * 40)
# for k, v in config.items():
#    print("{}\t\t\t{}".format(str(k).ljust(20), str(v).ljust(20)))

config["model"]["atom_feature_size"] = feature_size["atom"]
config["model"]["bond_feature_size"] = feature_size["bond"]
config["model"]["global_feature_size"] = feature_size["global"]
config["model"]["target_dict"]["global"] = config["dataset"]["target_list"]
model = load_graph_level_model_from_config(config["model"])

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>config_settings<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
:::REGRESSION MODEL:::


In [26]:
libe_loc = "../../../data/libe_qtaim_1029_labelled.pkl"
full_dataset = HeteroGraphGraphLabelDataset(
    file=libe_loc,
    standard_scale_features=config["dataset"]["standard_scale_features"],
    log_scale_features=config["dataset"]["log_scale_features"],
    allowed_ring_size=config["dataset"]["allowed_ring_size"],
    allowed_charges=config["dataset"]["allowed_charges"],
    allowed_spins=config["dataset"]["allowed_spins"],
    self_loop=True,
    debug=debug,
    extra_keys=config["dataset"]["extra_keys"],
    target_list=config["dataset"]["target_list"],
    extra_dataset_info=config["dataset"]["extra_dataset_info"],
)

... > running in debug mode
... > creating MoleculeWrapper objects


100%|██████████| 100/100 [00:00<00:00, 11085.48it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'H', 'F', 'O', 'P', 'S', 'Li', 'N', 'C'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'bond_length']
selected global keys ['shifted_rrho_ev_free_energy', 'charge']
... > Building graphs and featurizing


100%|██████████| 100/100 [00:00<00:00, 656.06it/s]


included in labels
{'global': ['shifted_rrho_ev_free_energy']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_H', 'chemical_symbol_F', 'chemical_symbol_O', 'chemical_symbol_P', 'chemical_symbol_S', 'chemical_symbol_Li', 'chemical_symbol_N', 'chemical_symbol_C', 'extra_feat_atom_esp_total'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_total'], 'global': ['num atoms', 'num bonds', 'molecule weight', 'charge one hot', 'charge one hot', 'charge one hot']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 100/100 [00:00<00:00, 32135.34it/s]

original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [8.34848989e-01 1.23115579e-01 9.71204306e-02 5.98687586e-03
 5.92035501e-02 2.32822950e-02 3.99125057e-03 4.65645900e-03
 2.04884196e-01 4.05777141e-02 1.72954191e-01 1.06433349e-02
 3.32604214e-03 5.18862574e-02 1.99562529e-03 2.06879821e-01
 1.10942913e+01]
std [0.51957164 0.30872377 0.24059587 0.06414003 0.19373103 0.12488391
 0.05244611 0.05662092 0.31628685 0.16272608 0.29994926 0.08522979
 0.04789963 0.18240787 0.03713865 0.3171733  4.94329215]
mean [0.         0.11384652 0.0069702  0.07125088 0.02710631 0.0046468
 0.00542126 0.8582797  0.75317434]
std [0.         0.25680998 0.0691577  0.21050098 0.13436485 0.05656254
 0.06106016 0.34376201 0.31424738]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [2.36313864 2.208579




In [39]:
"""
import numpy as np
import torch.optim as optim
from torch import nn

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# import softmax

dataloader = dm.train_dataloader()
correct_list = []

for epoch in range(10):  # loop over the dataset multiple times
    loss = 0.0
    count = 0
    correct_count = 0
    # iterate
    for ind, batch in enumerate(dm.train_dataloader()):
        loss += model.shared_step(batch, "train")
    print("Epoch: {} Loss: {}".format(epoch, loss))
    """

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True,
    gradient_clip_val=3.0,
    default_root_dir="./test/",
    precision="32",
    log_every_n_steps=10,
)

# trainer.fit(model, dm)
from qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask

data_loader_manual = DataLoaderMoleculeGraphTask(
    dataset=full_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)
trainer.fit(model, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type               | Params
--------------------------------------------------------
0  | embedding       | UnifySize          | 1.8 K 
1  | conv_layers     | ModuleList         | 95.4 K
2  | readout         | SumPoolingThenCat  | 0     
3  | loss            | MultioutputWrapper | 0     
4  | fc_layers       | ModuleList         | 78.8 K
5  | train_r2        | MultioutputWrapper | 0     
6  | train_torch_l1  | MultioutputWrapper | 0     
7  | train_torch_mse | MultioutputWrapper | 0     
8  | val_r2          | MultioutputWrapper | 0     
9  | val_torch_l1    | MultioutputWrapper | 0     
10 | val_torch_mse   | MultioutputWrapper | 0     
11 | test_r2         | MultioutputWrapper | 0     
12 | test_torch_l1   | MultioutputWrapper | 0     
13 | test_torch_mse  | Multioutp

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [15]:
# basic training loop
from torch.nn import functional as F
from sklearn.metrics import r2_score
from tqdm import tqdm
import tqdm.notebook as tq
import numpy as np

# move model to cpu
model = model.cpu()

opt = torch.optim.Adam(model.parameters(), lr=0.01)
dataloader = dm.train_dataloader()

for epoch in range(10):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label["global"]

            logits = model.forward(batch_graph, batch_graph.ndata["feat"])
            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            # r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            # r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()

  loss = F.mse_loss(logits, labels)
Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  4.27it/s]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


nan 0.1493707299232483


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  4.20it/s]


nan 3.065988302230835


Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  4.31it/s]


nan 0.5086823105812073


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  4.20it/s]


nan 0.5653993487358093


Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  4.35it/s]


nan 1.0022978782653809


Epoch 6: 100%|██████████| 1/1 [00:00<00:00,  4.27it/s]


nan 0.6530622243881226


Epoch 7: 100%|██████████| 1/1 [00:00<00:00,  4.35it/s]


nan 0.16519446671009064


Epoch 8: 100%|██████████| 1/1 [00:00<00:00,  4.19it/s]


nan 0.23297515511512756


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  4.22it/s]


nan 0.3134811818599701


Epoch 10: 100%|██████████| 1/1 [00:00<00:00,  3.99it/s]


nan 0.28704142570495605


Epoch 11: 100%|██████████| 1/1 [00:00<00:00,  4.29it/s]


nan 0.20579640567302704


Epoch 12: 100%|██████████| 1/1 [00:00<00:00,  4.05it/s]


nan 0.16815496981143951


Epoch 13: 100%|██████████| 1/1 [00:00<00:00,  4.26it/s]


nan 0.21025310456752777


Epoch 14:   0%|          | 0/1 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f5e9ac922a0>
Traceback (most recent call last):
  File "/home/santiagovargas/anaconda3/envs/qtaim_embed/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/santiagovargas/anaconda3/envs/qtaim_embed/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/santiagovargas/anaconda3/envs/qtaim_embed/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/santiagovargas/anaconda3/envs/qtaim_embed/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/santiagovargas/anaconda3/envs/qtaim_embed/lib/py

nan 0.11412043124437332


Epoch 15: 100%|██████████| 1/1 [00:00<00:00,  4.24it/s]


nan 0.20004548132419586


Epoch 16: 100%|██████████| 1/1 [00:00<00:00,  4.32it/s]


nan 0.14991675317287445


Epoch 17: 100%|██████████| 1/1 [00:00<00:00,  4.30it/s]


nan 0.0520099513232708


Epoch 18: 100%|██████████| 1/1 [00:00<00:00,  4.31it/s]


nan 0.07179593294858932


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  3.74it/s]


nan 0.08888241648674011


Epoch 20: 100%|██████████| 1/1 [00:00<00:00,  4.30it/s]


nan 0.08803244680166245


Epoch 21: 100%|██████████| 1/1 [00:00<00:00,  4.09it/s]


nan 0.1588078737258911


Epoch 22: 100%|██████████| 1/1 [00:00<00:00,  4.08it/s]


nan 0.17506250739097595


Epoch 23: 100%|██████████| 1/1 [00:00<00:00,  4.24it/s]


nan 0.06662168353796005


Epoch 24: 100%|██████████| 1/1 [00:00<00:00,  4.28it/s]


nan 0.03884604200720787


Epoch 25: 100%|██████████| 1/1 [00:00<00:00,  4.07it/s]


nan 0.06486049294471741


Epoch 26: 100%|██████████| 1/1 [00:00<00:00,  4.32it/s]


nan 0.08747197687625885


Epoch 27: 100%|██████████| 1/1 [00:00<00:00,  4.05it/s]


nan 0.06181082874536514


Epoch 28: 100%|██████████| 1/1 [00:00<00:00,  4.22it/s]


nan 0.07576818764209747


Epoch 29: 100%|██████████| 1/1 [00:00<00:00,  4.33it/s]


nan 0.08073711395263672


Epoch 30: 100%|██████████| 1/1 [00:00<00:00,  4.47it/s]


nan 0.09973303973674774


Epoch 31: 100%|██████████| 1/1 [00:00<00:00,  4.06it/s]


nan 0.09223935753107071


Epoch 32: 100%|██████████| 1/1 [00:00<00:00,  4.16it/s]


nan 0.10336043685674667


Epoch 33: 100%|██████████| 1/1 [00:00<00:00,  4.15it/s]


nan 0.1202014908194542


Epoch 34: 100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


nan 0.060065414756536484


Epoch 35: 100%|██████████| 1/1 [00:00<00:00,  4.12it/s]


nan 0.11973586678504944


Epoch 36: 100%|██████████| 1/1 [00:00<00:00,  4.27it/s]


nan 0.13905350863933563


Epoch 37: 100%|██████████| 1/1 [00:00<00:00,  4.18it/s]


nan 0.1161910891532898


Epoch 38: 100%|██████████| 1/1 [00:00<00:00,  4.00it/s]


nan 0.03109150566160679


Epoch 39: 100%|██████████| 1/1 [00:00<00:00,  4.23it/s]


nan 0.06885021179914474


Epoch 40: 100%|██████████| 1/1 [00:00<00:00,  4.38it/s]


nan 0.08317699283361435


Epoch 41: 100%|██████████| 1/1 [00:00<00:00,  4.19it/s]


nan 0.11143694818019867


Epoch 42: 100%|██████████| 1/1 [00:00<00:00,  3.98it/s]


nan 0.05883011594414711


Epoch 43: 100%|██████████| 1/1 [00:00<00:00,  4.34it/s]


nan 0.03514017537236214


Epoch 44: 100%|██████████| 1/1 [00:00<00:00,  4.00it/s]


nan 0.04997489973902702


Epoch 45: 100%|██████████| 1/1 [00:00<00:00,  4.13it/s]


nan 0.043268490582704544


Epoch 46: 100%|██████████| 1/1 [00:00<00:00,  4.27it/s]


nan 0.04349842295050621


Epoch 47: 100%|██████████| 1/1 [00:00<00:00,  4.23it/s]


nan 0.08937433362007141


Epoch 48: 100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


nan 0.04116618633270264


Epoch 49: 100%|██████████| 1/1 [00:00<00:00,  4.01it/s]


nan 0.03423915058374405


Epoch 50: 100%|██████████| 1/1 [00:00<00:00,  4.11it/s]

nan 0.049261752516031265





In [18]:
# basic training loop
from torch.nn import functional as F
from sklearn.metrics import r2_score
from tqdm import tqdm
import tqdm.notebook as tq
import numpy as np

# move model to cpu
model = model.cpu()

opt = torch.optim.Adam(model.parameters(), lr=0.01)
dataloader = dm.train_dataloader()

for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        for step, (batch_graph, batch_label) in enumerate(tq):
            # loss = model.shared_step((batch_graph, batch_label), "train")

            logits = model.forward(
                batch_graph, batch_graph.ndata["feat"]
            )  # returns a dict of node types
            labels = batch_label["global"]
            logits = logits.view(-1, model.hparams.output_dims)
            labels = labels.view(-1, model.hparams.output_dims)
            all_loss = model.compute_loss(logits, labels)

            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            # r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            # r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: shape '[-1, 2]' is invalid for input of size 75

In [19]:
print(model.hparams.output_dims)

2


TypeError: 'DataLoaderMoleculeGraphTask' object is not an iterator