## Analyse dataset 

In [None]:
import os

from pathlib import Path

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import pandas as pd

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.reaction import ReactionDataset

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 200

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

from model_functions import construct_model_name

In [None]:
# --- Which dataset / basis set / model configuration to analyse -------------
dataset_name = "rudorff_lilienfeld_sn2_dataset"
basis_set_type = "full"
basis_set = "def2-svp"
debug_model = False

# Input JSON files live under input/<dataset>/<basis set type>/<basis set>/.
input_foldername = Path("input", dataset_name, basis_set_type, basis_set)
train_json_filename = input_foldername / "train.json"
validate_json_filename = input_foldername / "validate.json"

# Model / dataset options are read from the YAML configuration file.
model_config = Path("config", "rudorff_lilienfeld_model.yaml")
inputs = read_inputs_yaml(model_config)

model_name = construct_model_name(dataset_name=dataset_name, debug=debug_model)

# The same keyword options are passed to every split; they are selected by
# the basis set type ("<basis_set_type>_basis" key of "dataset_options").
dataset_options = inputs["dataset_options"][f"{basis_set_type}_basis"]

train_dataset = ReactionDataset(
    root=get_train_data_path(model_name),
    filename=train_json_filename,
    **dataset_options,
)
validate_dataset = ReactionDataset(
    root=get_validation_data_path(model_name),
    filename=validate_json_filename,
    **dataset_options,
)
# NOTE(review): the test split is built from validate.json (only its root
# directory differs from the validation split) -- confirm this is intended.
test_dataset = ReactionDataset(
    root=get_test_data_path(model_name),
    filename=validate_json_filename,
    **dataset_options,
)

In [None]:
# Lookup tables translating the one-letter codes found in a reaction
# identifier into display labels (with mathtext-style subscripts).
R = {
    "A": "H",
    "B": r"NO$_2$",
    "C": "CN",
    "D": r"CH$_3$",
    "E": r"NH$_2$",
}
X = {"A": "F", "B": "Cl", "C": "Br"}
Y = {"A": "H", "B": "F", "C": "Cl", "D": "Br"}


def get_all_identifiers_from_dataset(dataset):
    """Decode the identifier of every reaction in ``dataset``.

    Each ``reaction.identifier`` is split on ``"_"`` into exactly six codes:
    the first four are looked up in ``R``, the fifth in ``X`` and the sixth
    in ``Y``.

    Args:
        dataset: Iterable of objects exposing a string ``identifier``
            attribute.

    Returns:
        A list with one ``[R1, R2, R3, R4, X, Y]`` list of labels per
        reaction, in iteration order.

    Raises:
        ValueError: If an identifier does not have exactly six fields.
        KeyError: If a code is missing from its lookup table.
    """

    def decode(identifier):
        # Strict six-way unpacking: malformed identifiers fail loudly.
        r1, r2, r3, r4, x_code, y_code = identifier.split("_")
        labels = [R[code] for code in (r1, r2, r3, r4)]
        labels.append(X[x_code])
        labels.append(Y[y_code])
        return labels

    return [decode(reaction.identifier) for reaction in dataset]

# Decode the identifiers of every split and store them as NumPy arrays
# (one row per reaction, one entry per R1..R4, X, Y label).
train_identifiers = np.array(get_all_identifiers_from_dataset(train_dataset))
validate_identifiers = np.array(get_all_identifiers_from_dataset(validate_dataset))
test_identifiers = np.array(get_all_identifiers_from_dataset(test_dataset))

identifiers = {
    "train": train_identifiers,
    "validate": validate_identifiers,
    "test": test_identifiers,
}

# Make into dataframe
species_columns = ["R1", "R2", "R3", "R4", "X", "Y"]

# `DataFrame.append` was removed in pandas 2.0, so the per-split frames are
# collected and concatenated once. The loop variable is deliberately NOT
# called `dataset_name`: that name holds the dataset configuration string
# and must not be overwritten by the split label.
split_frames = []
for split_name, split_identifiers in identifiers.items():
    df_split = pd.DataFrame(split_identifiers, columns=species_columns)
    df_split["dataset"] = split_name
    split_frames.append(df_split)

# Same layout as before: "dataset" first, then the species columns; each
# split keeps its own 0..n-1 row index.
df = pd.concat(split_frames)[["dataset", *species_columns]]

In [None]:
# For each species R1, R2 etc. count how many times it appears in each dataset
df_train = df[df["dataset"] == "train"]
df_validate = df[df["dataset"] == "validate"]
df_test = df[df["dataset"] == "test"]

# Column-wise value counts: the index is every distinct value seen in any
# column, and a cell is NaN when that value never occurs in that column.
# `pd.value_counts` is deprecated (pandas 2.1); the Series method is the
# supported equivalent.
df_train_counts = df_train.apply(pd.Series.value_counts)
df_validate_counts = df_validate.apply(pd.Series.value_counts)
df_test_counts = df_test.apply(pd.Series.value_counts)

# The "dataset" column contributes its own label as a row; remove that row so
# only species labels remain on the index.
df_train_counts = df_train_counts.drop("train", axis=0)
df_validate_counts = df_validate_counts.drop("validate", axis=0)
df_test_counts = df_test_counts.drop("test", axis=0)

# Generate a list of colours
colours = px.colors.qualitative.Plotly

# Make into a plot: one bar trace per (split, column), coloured by the
# column's position so the same column has the same colour in every split.
# Only the count frames are iterated, so the `dataset_name` configuration
# variable is no longer overwritten by a loop variable.
fig = go.Figure()
for df_counts in (df_train_counts, df_validate_counts, df_test_counts):
    for colour_index, column in enumerate(df_counts.columns):
        fig.add_trace(
            go.Bar(
                x=df_counts.index,
                y=df_counts[column],
                name=f"{column}",
                marker_color=colours[colour_index],
            )
        )

fig.update_layout(template="simple_white")
fig.show()

In [None]:
# One loader per split, each configured with batch_size=1 and shuffle=True.
train_loader, validation_loader, test_loader = (
    DataLoader(split, batch_size=1, shuffle=True)
    for split in (train_dataset, validate_dataset, test_dataset)
)

In [None]:
# Take a single batch from the (shuffled) training loader and compute the
# element-wise difference between its `x_transition_state` and `x` attributes.
data = next(iter(train_loader))
difference = data.x_transition_state - data.x

# Plot the difference
fig, ax = plt.subplots(1, 1)
cax = ax.imshow(difference, cmap="coolwarm")
# Attach a colourbar so the colours can be mapped back to values.
fig.colorbar(cax, ax=ax)
# `Figure.show()` warns on non-GUI (e.g. notebook inline) backends;
# `plt.show()` works on every backend.
plt.show()
