## Analyse dataset

In [None]:
import os

from pathlib import Path

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.reaction import ReactionDataset

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 200

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

from model_functions import construct_model_name

In [None]:
# Paths to the JSON files for each data split.
input_foldername = Path("input")
train_json_filename = input_foldername / "train.json"
validate_json_filename = input_foldername / "validate.json"

dataset_name = "rudorff_lilienfeld_sn2_dataset"
debug_model = False

# Model/dataset configuration loaded from YAML.
model_config = Path("config") / "grambow_green_message_passing.yaml"
inputs = read_inputs_yaml(model_config)

# The model name is passed to the get_*_data_path helpers below, which give
# each split its own dataset root directory.
model_name = construct_model_name(
    dataset_name=dataset_name,
    debug=debug_model,
)

# Selects which "<type>_basis" entry of inputs["dataset_options"] is used.
basis_set_type = "full"
# Look the options up once so all three splits are guaranteed to be built
# with the same settings.
dataset_options = inputs["dataset_options"][f"{basis_set_type}_basis"]

train_dataset = ReactionDataset(
    root=get_train_data_path(model_name),
    filename=train_json_filename,
    **dataset_options,
)
validate_dataset = ReactionDataset(
    root=get_validation_data_path(model_name),
    filename=validate_json_filename,
    **dataset_options,
)
# NOTE(review): the test split is built from validate.json with the same
# options as the validation split, so presumably it contains the same
# reactions (only the root differs). If a separate test.json is meant to be
# used, change `filename` here — confirm.
test_dataset = ReactionDataset(
    root=get_test_data_path(model_name),
    filename=validate_json_filename,
    **dataset_options,
)