In [None]:
import sys

# Add the sibling src/ directory to the import search path so the project
# packages imported below can be resolved. The path is relative, so this
# only works when the kernel's working directory is the notebook's folder.
sys.path.append("../src/")

from torch_geometric.data import Data
from data.datamodule import DataModule
from models.graph_mol import GraphModel
from lightning.pytorch.utilities.model_summary import ModelSummary
from tqdm import tqdm
import torch
import pickle

# Find atom types

In [None]:
# Gather every distinct value stored under the "z" key across the processed
# files, then print the resulting set.
# NOTE(review): presumably "z" holds per-atom atomic numbers — confirm against
# the preprocessing code. This also assumes the stored entries hash by value
# (plain ints / numpy ints); verify, otherwise the set will not deduplicate.
z = set()
indices = list(range(1, 133886))  # file ids 1..133885, zero-padded to six digits below
for idx in tqdm(indices):
    path = f"../data/processed/{idx:0>6}.pkl"
    with open(path, "rb") as handle:
        record = pickle.load(handle)
    z.update(record["z"])
print(z)

# Train / validation / test split

In [None]:
# Read the pickled split definition from disk into `partition`.
# pickle can execute arbitrary code while loading, so this file must come
# from a trusted source.
with open("../data/partition.pkl", "rb") as split_file:
    partition = pickle.load(split_file)

In [None]:
# Report how many entries each part of the split contains.
n_train = len(partition["train"])
n_val = len(partition["val"])
n_test = len(partition["test"])
print(f"Data:\nTrain: {n_train}\nValidation: {n_val}\nTest: {n_test}")

# Data loader

In [None]:
# Build the project's data module, pointing it at the top-level data folder.
datamodule = DataModule(folder="../data")

In [None]:
# Ask the data module for its training loader.
# NOTE(review): if DataModule builds its datasets in a separate setup step,
# that must already have happened by this point — confirm against
# data/datamodule.py.
trainloader = datamodule.train_dataloader()

In [None]:
# Take one item straight from the underlying dataset (index 1) and one item as
# yielded by the loader; both are reused by the cells below.
data = trainloader.dataset[1]

# next(iter(...)) pulls exactly the first item from the loader. Unlike
# `for ...: break`, it fails immediately (StopIteration) on an empty loader
# instead of silently leaving `batch` undefined for later cells.
batch = next(iter(trainloader))

In [None]:
# Display the single item taken from the training dataset.
data

In [None]:
# Display the first item yielded by the training loader.
batch

# Model

In [None]:
# Instantiate the model used in the rest of the notebook.
# NOTE(review): the four positional arguments are opaque from here; passing
# them by keyword would document their meaning — check GraphModel's signature.
graphmol = GraphModel(5, 32, 10, 1)

In [None]:
# Build Lightning's model summary for the model; it is shown as the cell output.
ModelSummary(graphmol)

In [None]:
# Forward pass on the item yielded by the loader; the result is unpacked into
# three values. The fourth argument is the batch's own `batch` attribute.
# NOTE(review): presumably the node-to-graph assignment vector — confirm.
G, gap, charge = graphmol(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

In [None]:
# Forward pass on the single dataset item, this time without a fourth argument.
# This rebinds G, gap and charge, so results of any earlier call are replaced.
G, gap, charge = graphmol(data.x, data.edge_index, data.edge_attr)

In [None]:
# Call the model's training step directly on the batch (no Trainer involved);
# the returned value is shown as the cell output.
graphmol.training_step(batch)

## On GPU

In [None]:
# Move the model and the batch to the first CUDA device, then run one
# training step there and keep the returned loss for the next cell.
device = torch.device("cuda:0")
graphmol = graphmol.to(device)
batch = batch.to(device)
loss = graphmol.training_step(batch)

In [None]:
# Backpropagate from the loss returned by the training step.
loss.backward()