``` 
pip install --upgrade e3nn
conda install conda-forge::ase
conda install anaconda::scikit-learn 
pip3 install torch==2.5.1 torchvision torchaudio  # must match the torch-2.5.1 wheel index used for torch-scatter below
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.5.1+cpu.html
pip install torch_geometric
```



# In [2]:
import os
import json
from io import StringIO
from typing import Dict
import warnings
import math
import time
import logging
import datetime
import sys

import torch
import torch.nn as nn
import torch_geometric as tg
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter_mean

from sklearn.model_selection import train_test_split, KFold
from ase.io import read
from ase import Atoms

from e3nn.o3 import Irreps, spherical_harmonics, Linear as E3NNLinear
from e3nn.nn import Gate
from e3nn.nn.models.gate_points_2101 import Convolution, smooth_cutoff, tp_path_exists
from e3nn.math import soft_one_hot_linspace

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
sns.set_theme(style="whitegrid")

warnings.filterwarnings('ignore', module='ase')

# --- Setup log files and plot files directories ---
# Timestamp shared by the log file, the summary file and every saved figure name.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Output directories, created relative to the current working directory.
log_dir = "logs"
figures_dir = "figures"
os.makedirs(log_dir, exist_ok=True)
os.makedirs(figures_dir, exist_ok=True)

log_filename = os.path.join(log_dir, "log_" + current_time + ".log")

# Root logger writes INFO and above to a fresh (truncated) file for this run.
logging.basicConfig(
    filename=log_filename,
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)

# File-like adapter that forwards print() output to the logging module (installed as sys.stdout below).
class LoggerWriter:
    """Minimal file-like object that forwards written text to a logging callable.

    It is meant to replace ``sys.stdout`` so that ``print`` output ends up in
    the log file instead of on the console.

    Args:
        level: callable taking one string, e.g. ``logging.info``.
    """

    def __init__(self, level):
        self.level = level
        # Kept for backward compatibility; not used by this class.
        self.buffer = ''

    def write(self, message):
        """Forward *message* (whitespace-stripped) to ``level``; ignore blank chunks.

        ``print`` writes the text and the trailing newline as separate calls,
        so whitespace-only chunks are dropped to avoid empty log records.
        Returns the number of characters received, like ``io.TextIOBase.write``.
        """
        text = message.strip()
        if text:
            self.level(text)
        return len(message)

    def flush(self):
        """No-op: buffering is handled by the logging handlers."""

    def isatty(self):
        """Report a non-interactive stream.

        Libraries routinely probe ``sys.stdout.isatty()``; without this method
        they would raise ``AttributeError`` once this object replaces stdout.
        """
        return False

# Send everything written to stdout (e.g. print()) to logging.info, i.e. into the
# log file configured above; from here on stdout output no longer reaches the console.
sys.stdout = LoggerWriter(logging.info)

# CustomCompose
##############################################
class CustomCompose(nn.Module):
    """Chain two e3nn layers as ``second(first(*args))`` and remember both outputs.

    The composite advertises ``irreps_in`` taken from ``first`` and ``irreps_out``
    taken from ``second``, so it can stand in for a single layer. Clones of the
    intermediate and final outputs of the most recent call are kept on
    ``first_out`` and ``second_out``.
    """

    def __init__(self, first: nn.Module, second: nn.Module):
        super().__init__()
        self.first, self.second = first, second
        self.irreps_in = first.irreps_in
        self.irreps_out = second.irreps_out

    def forward(self, *input):
        # `first` receives every argument; `second` only sees the output of `first`.
        hidden = self.first(*input)
        self.first_out = hidden.clone()
        result = self.second(hidden)
        self.second_out = result.clone()
        return result


# 1. Molecular graph construction: builds a PyG Data object from an ASE Atoms object.
##############################################
class MolecularGraph:
    """Turn an ASE ``Atoms`` object into a PyTorch Geometric ``Data`` graph.

    Every ordered pair of distinct atoms closer than ``cutoff`` (as found by
    ``ase.neighborlist.neighbor_list``) becomes a directed edge.
    """

    def __init__(self, cutoff: float = 5.0):
        self.cutoff = cutoff

    def build_molecular_graph(self, atoms: Atoms, type_encoding: Dict[str, int], type_onehot: torch.Tensor) -> Data:
        """Build the graph for *atoms*.

        Args:
            atoms: structure to convert.
            type_encoding: element symbol -> row index into ``type_onehot``.
            type_onehot: one-hot rows, one per element symbol.

        Returns:
            ``Data`` with
            - ``x``: one-hot element row concatenated with atomic mass / 100 (float32);
            - ``pos``: atomic positions (float32);
            - ``edge_index``: ``[2, E]`` long tensor of (source i, neighbour j);
            - ``edge_vec``: vector from atom i to neighbour j, including the
              lattice translation when the neighbour is a periodic image.

        Raises:
            KeyError: if an element symbol is missing from ``type_encoding``.
        """
        from ase.neighborlist import neighbor_list

        symbols = [sym.strip() for sym in atoms.get_chemical_symbols()]
        onehot = torch.stack([type_onehot[type_encoding[sym]] for sym in symbols], dim=0)
        # Masses are divided by 100 so they are on a scale comparable to the one-hot entries.
        mass_scale = 100.0
        masses = torch.tensor(atoms.get_masses(), dtype=torch.float32).unsqueeze(1) / mass_scale
        x = torch.cat([onehot, masses], dim=1).contiguous()

        pos = torch.tensor(atoms.get_positions(), dtype=torch.float32)
        i_idx, j_idx, shifts = neighbor_list("ijS", atoms, self.cutoff)
        edge_src = torch.as_tensor(i_idx, dtype=torch.long)
        edge_dst = torch.as_tensor(j_idx, dtype=torch.long)
        edge_index = torch.stack([edge_src, edge_dst], dim=0)

        # Neighbour j may be a periodic image: add the lattice translation S @ cell.
        # For non-periodic molecules every shift is zero and this term vanishes.
        cell = torch.as_tensor(atoms.get_cell().array, dtype=torch.float32)
        shift_vec = torch.as_tensor(shifts, dtype=torch.float32) @ cell
        edge_vec = pos[edge_dst] - pos[edge_src] + shift_vec

        data = Data(x=x, pos=pos, edge_index=edge_index)
        data.edge_vec = edge_vec
        return data

# 2. Dataset class: builds a dataset from a JSON file.
##############################################
class MolecularDataset(Dataset):
    """PyG dataset of molecules read from a JSON file, with a fixed train/test split.

    The JSON file must hold a list of records. Each record provides an XYZ string
    under ``"opt_molecule_S0"`` and a target value under ``target_key``; records
    whose target is missing or ``None`` are dropped. The element vocabulary for
    the one-hot node features is built from all kept records.

    Args:
        database_path: path to the JSON file.
        target_key: record key holding the regression target.
        cutoff: neighbour cutoff passed to ``MolecularGraph``.
        test_size: fraction of records placed in the test split
            (``train_test_split`` with ``random_state=42``).

    Raises:
        FileNotFoundError: if ``database_path`` does not exist.
        ValueError: if no record has a non-``None`` target.
    """

    def __init__(self,
                 database_path: str = 'fd.json',
                 target_key: str = 'reduction_potential_S1 (eV)',
                 cutoff: float = 5.0,
                 test_size: float = 0.2):
        super().__init__()
        if not os.path.exists(database_path):
            raise FileNotFoundError(f"File '{database_path}' not found.")
        with open(database_path, 'r', encoding='utf-8') as f:
            raw_db = json.load(f)
        self.database = [entry for entry in raw_db if entry.get(target_key) is not None]
        if len(self.database) == 0:
            raise ValueError(f"No valid entries found with target '{target_key}' not None.")
        # Parse every XYZ string exactly once. get() runs for every molecule each time
        # train_dataset / test_dataset is accessed, so re-parsing there was wasted work.
        self._atoms = [read(StringIO(entry["opt_molecule_S0"]), format='xyz')
                       for entry in self.database]
        all_symbols = set()
        for atoms in self._atoms:
            all_symbols.update(atoms.get_chemical_symbols())
        self.type_encoding = {sym: i for i, sym in enumerate(sorted(all_symbols))}
        self.type_onehot = torch.eye(len(self.type_encoding), dtype=torch.float32)
        self.train_idx, self.test_idx = train_test_split(range(len(self.database)), test_size=test_size, random_state=42)
        self.graph_builder = MolecularGraph(cutoff=cutoff)
        self.target_key = target_key

    def len(self):
        """Number of kept records."""
        return len(self.database)

    def get(self, idx: int) -> Data:
        """Build the graph for record *idx* and attach its target as ``data.y`` (shape ``[1]``, float32)."""
        data = self.graph_builder.build_molecular_graph(self._atoms[idx], self.type_encoding, self.type_onehot)
        data.y = torch.tensor([self.database[idx][self.target_key]], dtype=torch.float32)
        return data

    @property
    def train_dataset(self):
        """Freshly built graphs for the training indices."""
        return [self.get(i) for i in self.train_idx]

    @property
    def test_dataset(self):
        """Freshly built graphs for the test indices."""
        return [self.get(i) for i in self.test_idx]

# 3. PeriodicNetwork with e3nn
##############################################
class PeriodicNetwork(nn.Module):
    """E(3)-equivariant message-passing network built from e3nn ``Convolution`` + ``Gate`` layers.

    Node features ``data.x`` (``in_dim`` scalars per node) are embedded to
    ``em_dim`` scalars by an e3nn ``Linear``. ``layers`` convolution+gate blocks
    follow, then a final convolution maps to ``irreps_out``. Edge attributes are
    spherical harmonics (up to ``lmax``) of ``data.edge_vec`` multiplied by a
    smooth cutoff. Edge lengths are embedded with ``number_of_basis`` Gaussians
    on ``[0, max_radius]``.

    Args:
        in_dim: number of scalar input features per node.
        em_dim: number of scalars produced by the embedding; should equal the
            dimension of ``irreps_in``.
        irreps_in: irreps of the embedded node features (``"0e"`` if None).
        irreps_out: irreps of the output.
        irreps_node_attr: irreps of node attributes; ``data.z`` is used if present,
            otherwise a tensor of ones (``"0e"`` if None).
        layers: number of convolution+gate blocks.
        mul: multiplicity of every hidden irrep.
        lmax: maximum degree for hidden irreps and edge spherical harmonics.
        number_of_basis: size of the radial basis.
        radial_layers, radial_neurons: radial network configuration, forwarded to ``Convolution``.
        max_radius: cutoff radius for the radial basis and smooth cutoff.
        num_neighbors: forwarded to ``Convolution``.
        reduce_output: if True, average node outputs per graph (``scatter_mean``
            over ``batch``); otherwise return per-node outputs.
    """

    def __init__(self,
                 in_dim: int,
                 em_dim: int,
                 irreps_in: str,
                 irreps_out: str,
                 irreps_node_attr: str,
                 layers: int,
                 mul: int,
                 lmax: int,
                 number_of_basis: int,
                 radial_layers: int,
                 radial_neurons: int,
                 max_radius: float,
                 num_neighbors: float,
                 reduce_output: bool = True):
        super().__init__()
        self.em = E3NNLinear(Irreps(f"{in_dim}x0e"), Irreps(f"{em_dim}x0e"))
        self.irreps_in = Irreps(irreps_in) if irreps_in is not None else Irreps("0e")
        self.irreps_node_attr = Irreps(irreps_node_attr) if irreps_node_attr is not None else Irreps("0e")
        self.irreps_out = Irreps(irreps_out)
        # Hidden features: `mul` copies of every irrep with l <= lmax, both parities.
        self.irreps_hidden = Irreps([(mul, (l, p)) for l in range(lmax + 1) for p in [-1, 1]])
        self.irreps_edge_attr = Irreps.spherical_harmonics(lmax)
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.radial_layers = radial_layers
        self.radial_neurons = radial_neurons
        self.num_neighbors = num_neighbors
        self.reduce_output = reduce_output

        irreps = self.irreps_in
        # Activations keyed by parity of the scalar they act on (1 = even, -1 = odd).
        act = {1: torch.nn.functional.silu, -1: torch.tanh}
        act_gates = {1: torch.sigmoid, -1: torch.tanh}
        self.layers = nn.ModuleList()
        for _ in range(layers):
            # Keep only hidden irreps reachable from the current irreps via the edge SH.
            irreps_scalars = Irreps([
                (m, ir) for m, ir in self.irreps_hidden
                if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
            ])
            irreps_gated = Irreps([
                (m, ir) for m, ir in self.irreps_hidden
                if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
            ])
            # One gate scalar per gated irrep: even if an even scalar can be produced, else odd.
            gate_ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = Irreps([(m, gate_ir) for m, _ in irreps_gated])
            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],
                irreps_gated
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors
            )
            irreps = gate.irreps_out
            self.layers.append(CustomCompose(conv, gate))
        # Final convolution projects to the requested output irreps.
        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors
            )
        )

    def preprocess(self, data: Data):
        """Return ``(batch, edge_src, edge_dst, edge_vec)`` for *data*.

        A single un-batched graph gets an all-zero batch vector. PyG ``Data``
        exposes ``batch`` as a property that is ``None`` when unset, so the check
        must be on the value, not on ``hasattr``.
        """
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = data.pos.new_zeros(data.pos.shape[0], dtype=torch.long)
        edge_src = data.edge_index[0]
        edge_dst = data.edge_index[1]
        edge_vec = data.edge_vec
        return batch, edge_src, edge_dst, edge_vec

    def forward(self, data: Data) -> torch.Tensor:
        """Run the network; returns one row per graph if ``reduce_output`` else one per node."""
        batch, edge_src, edge_dst, edge_vec = self.preprocess(data)
        edge_sh = spherical_harmonics(
            self.irreps_edge_attr, edge_vec,
            normalize=True,
            normalization='component'
        )
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False
        ).mul(self.number_of_basis ** 0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        num_nodes = data.pos.shape[0]
        node_features = getattr(data, 'x', None)
        if node_features is not None:
            x = self.em(node_features)
        else:
            # No node features: start from ones of the embedding width.
            # (e3nn's Linear exposes irreps_out; it has no torch-style out_features.)
            x = data.pos.new_ones((num_nodes, self.em.irreps_out.dim))
        z = getattr(data, 'z', None)
        if z is None:
            z = data.pos.new_ones((num_nodes, self.irreps_node_attr.dim))

        out = x
        for layer in self.layers:
            out = layer(out, z, edge_src, edge_dst, edge_attr, edge_length_embedded)
        if self.reduce_output:
            out = scatter_mean(out, batch, dim=0)
        return out

# 4. Training, evaluation and cross-validation.
##############################################
def train_periodic_network(dataset,
                           num_epochs=30,
                           batch_size=32,
                           device='cpu',
                           patience=10,
                           **kwargs):
    """Train a ``PeriodicNetwork`` with early stopping on the test split.

    Uses L1 loss (MAE), AdamW (lr 1e-3, weight decay 1e-4), gradient clipping at
    norm 1.0, and ``ReduceLROnPlateau`` (patience 2) driven by the test loss.
    Reported losses are averages over batches.

    Args:
        dataset: object exposing ``train_dataset`` / ``test_dataset`` lists of PyG
            ``Data`` with a ``y`` target (e.g. ``MolecularDataset``).
        num_epochs: maximum number of epochs.
        batch_size: mini-batch size for both loaders.
        device: torch device for the model and batches.
        patience: stop after this many consecutive epochs without a lower test loss.
        **kwargs: forwarded to ``PeriodicNetwork``.

    Returns:
        ``(model, history)``: the model with the weights from the lowest-test-loss
        epoch loaded, and one dict per epoch with ``epoch``, ``train_loss``,
        ``test_loss`` and ``duration`` (seconds).
    """
    # NOTE(review): the test split drives LR scheduling, early stopping and model
    # selection, so the test loss reported afterwards is optimistically biased.
    model = PeriodicNetwork(**kwargs).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
    criterion = nn.L1Loss()

    train_loader = tg.loader.DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = tg.loader.DataLoader(dataset.test_dataset, batch_size=batch_size, shuffle=False)

    best_loss = float('inf')
    no_improve = 0
    history = []
    best_state = None
    best_epoch = 1

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            if getattr(batch, 'batch', None) is None:
                batch.batch = torch.zeros(batch.num_nodes, dtype=torch.long)
            batch = batch.to(device)
            optimizer.zero_grad()
            pred = model(batch).squeeze(-1)
            loss = criterion(pred, batch.y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                if getattr(batch, 'batch', None) is None:
                    batch.batch = torch.zeros(batch.num_nodes, dtype=torch.long)
                batch = batch.to(device)
                pred = model(batch).squeeze(-1)
                test_loss += criterion(pred, batch.y).item()
        test_loss /= len(test_loader)
        scheduler.step(test_loss)
        epoch_time = time.time() - start_time
        logging.info(f"Epoch {epoch+1:03d}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}, Duration = {epoch_time:.2f}s")
        history.append({'epoch': epoch+1, 'train_loss': train_loss, 'test_loss': test_loss, 'duration': epoch_time})
        if test_loss < best_loss:
            best_loss = test_loss
            best_epoch = epoch+1
            # state_dict() returns references to the live parameters, which the optimizer
            # keeps updating in place; clone them so this epoch's weights are really frozen.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                logging.info("Early stopping activated")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    logging.info(f"Training finished at epoch {best_epoch} with Test Loss = {best_loss:.4f}")
    return model, history

def evaluate_periodic_network(model, dataset, device='cpu'):
    """Score *model* on ``dataset.test_dataset``, one molecule per batch.

    Returns:
        ``(preds, truths, test_loss)``: per-molecule predictions and targets as
        Python floats, and the mean L1 loss (MAE) over the test molecules.
        The MAE is also logged.
    """
    model.eval()
    l1 = nn.L1Loss()
    loader = tg.loader.DataLoader(dataset.test_dataset, batch_size=1, shuffle=False)
    preds, truths = [], []
    running_loss = 0.0
    with torch.no_grad():
        for sample in loader:
            if not hasattr(sample, 'batch'):
                sample.batch = torch.zeros(sample.num_nodes, dtype=torch.long)
            sample = sample.to(device)
            output = model(sample)
            running_loss += l1(output.squeeze(-1), sample.y).item()
            preds.append(output.item())
            truths.append(sample.y.item())
    test_loss = running_loss / len(loader)
    logging.info(f"MAE on test set: {test_loss:.4f}")
    return preds, truths, test_loss

def cross_validate_periodic_network(dataset, k=5, **train_kwargs):
    """Run k-fold cross-validation, training a fresh model on each fold.

    Folds come from ``KFold(n_splits=k, shuffle=True, random_state=42)`` over all
    records in ``dataset.database``. Each fold is run by temporarily overwriting
    ``dataset.train_idx`` / ``dataset.test_idx``. The original split is restored
    afterwards, even if training raises.

    Args:
        dataset: a ``MolecularDataset`` (needs ``database``, ``train_idx``, ``test_idx``).
        k: number of folds.
        **train_kwargs: forwarded to ``train_periodic_network``; ``device`` is
            also used for evaluation (default ``"cpu"``).

    Returns:
        ``(all_metrics, avg_metrics)``: per-fold metric dicts (including ``"Test Loss"``
        and ``"Fold"``), and their mean over folds (excluding ``"Fold"``).
    """
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    all_metrics = []
    indices = list(range(len(dataset.database)))
    device = train_kwargs.get("device", "cpu")
    original_train_idx, original_test_idx = dataset.train_idx, dataset.test_idx
    try:
        for fold, (train_idx, test_idx) in enumerate(kfold.split(indices), 1):
            logging.info(f"--- Fold {fold} ---")
            dataset.train_idx = train_idx.tolist()
            dataset.test_idx = test_idx.tolist()
            model, history = train_periodic_network(dataset, **train_kwargs)
            preds, truths, test_loss = evaluate_periodic_network(model, dataset, device=device)
            metrics = compute_metrics(preds, truths)
            metrics["Test Loss"] = test_loss
            metrics["Fold"] = fold
            logging.info(f"Fold {fold} metrics: {metrics}")
            all_metrics.append(metrics)
    finally:
        # Callers (plots, summaries) must keep seeing the split they set up.
        dataset.train_idx = original_train_idx
        dataset.test_idx = original_test_idx
    avg_metrics = {
        key: sum(m[key] for m in all_metrics) / len(all_metrics)
        for key in all_metrics[0] if key != "Fold"
    }
    logging.info("=== Cross-Validation Results ===")
    for key, value in avg_metrics.items():
        logging.info(f"{key}: {value:.4f}")
    logging.info("================================")
    return all_metrics, avg_metrics

# 5. Metrics and plots
##############################################
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

def compute_metrics(preds, truths):
    """Regression metrics for predictions against reference values.

    Returns a dict with keys ``'MAE'``, ``'RMSE'``, ``'R2'`` and ``'MAPE (%)'``.
    R2 adds 1e-9 to the total sum of squares, and MAPE adds 1e-9 to ``|truth|``,
    so constant or zero targets do not divide by zero.
    """
    y_hat = np.array(preds)
    y = np.array(truths)
    residual = y_hat - y
    squared = residual ** 2
    ss_res = np.sum(squared)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return {
        'MAE': np.mean(np.abs(residual)),
        'RMSE': np.sqrt(squared.mean()),
        'R2': 1 - ss_res / (ss_tot + 1e-9),
        'MAPE (%)': np.mean(np.abs(residual / (np.abs(y) + 1e-9))) * 100,
    }

def save_and_show_plot(fig, filename):
    """Write *fig* to ``<figures_dir>/<filename>_<current_time>.png``, close it and log the path.

    Despite the name, nothing is displayed; the figure is only saved to disk.
    """
    target_path = os.path.join(figures_dir, f"{filename}_{current_time}.png")
    fig.savefig(target_path)
    plt.close(fig)
    logging.info(f"Figure saved as {target_path}")

def plot_loss_history(history):
    """Plot per-epoch train and test loss from ``train_periodic_network``'s history.

    Saved as "loss_history".
    """
    epochs, train_loss, test_loss = [], [], []
    for record in history:
        epochs.append(record['epoch'])
        train_loss.append(record['train_loss'])
        test_loss.append(record['test_loss'])
    fig, ax = plt.subplots(figsize=(8, 6))
    for values, marker, label in ((train_loss, "o", "Train Loss"), (test_loss, "s", "Test Loss")):
        sns.lineplot(x=epochs, y=values, marker=marker, label=label, ax=ax)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MAE)')
    ax.set_title('Loss Curves')
    ax.legend()
    fig.tight_layout()
    save_and_show_plot(fig, "loss_history")

def plot_parity(preds, truths):
    """Scatter predicted vs. true values with a red dashed y = x line; saved as "parity_plot"."""
    predicted = np.array(preds)
    actual = np.array(truths)
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.scatterplot(x=actual, y=predicted, edgecolor='k', s=80, alpha=0.7, ax=ax)
    # Draw the identity line across the full range covered by either array.
    lower = min(actual.min(), predicted.min())
    upper = max(actual.max(), predicted.max())
    ax.plot([lower, upper], [lower, upper], 'r--', label='Identity')
    ax.set_xlabel('True Value')
    ax.set_ylabel('Predicted Value')
    ax.set_title('Parity Plot')
    ax.legend()
    fig.tight_layout()
    save_and_show_plot(fig, "parity_plot")

def plot_error_histogram(preds, truths, bins=20):
    """Histogram with KDE of the signed errors ``pred - truth``; saved as "error_histogram"."""
    residuals = np.subtract(np.array(preds), np.array(truths))
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.histplot(residuals, bins=bins, kde=True, color='purple', edgecolor='black', ax=ax)
    ax.set_xlabel('Error (Predicted - True)')
    ax.set_ylabel('Frequency')
    ax.set_title('Error Histogram')
    fig.tight_layout()
    save_and_show_plot(fig, "error_histogram")

def plot_target_distribution(dataset):
    """Side-by-side histograms with KDE of the targets in the current train and test splits.

    Targets are read directly from ``dataset.database`` using ``train_idx`` /
    ``test_idx``. Going through ``train_dataset`` / ``test_dataset`` would rebuild
    every molecular graph just to read one number per molecule.
    Saved as "target_distribution".
    """
    def _targets(indices):
        # Round through float32 so the values equal data.y.item() exactly.
        return [float(np.float32(dataset.database[i][dataset.target_key])) for i in indices]

    train_targets = _targets(dataset.train_idx)
    test_targets = _targets(dataset.test_idx)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    sns.histplot(train_targets, bins=20, color='blue', edgecolor='black', kde=True, ax=ax1)
    ax1.set_title('Training Target Distribution')
    ax1.set_xlabel('Target Value')
    ax1.set_ylabel('Count')
    sns.histplot(test_targets, bins=20, color='green', edgecolor='black', kde=True, ax=ax2)
    ax2.set_title('Test Target Distribution')
    ax2.set_xlabel('Target Value')
    ax2.set_ylabel('Count')
    fig.tight_layout()
    save_and_show_plot(fig, "target_distribution")

def plot_cv_line_metrics(all_metrics):
    """Plot each metric across CV folds (one subplot per metric) with its mean as a red dashed line.

    *all_metrics* is the per-fold list returned by ``cross_validate_periodic_network``.
    Saved as "cv_line_metrics".
    """
    frame = pd.DataFrame(all_metrics)
    metric_names = [name for name in frame.columns if name != "Fold"]
    n_panels = len(metric_names)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5))
    # plt.subplots returns a bare Axes (not an array) for a single column.
    panel_axes = [axes] if n_panels == 1 else axes
    for metric, ax in zip(metric_names, panel_axes):
        sns.pointplot(x="Fold", y=metric, data=frame, ax=ax, markers="o", linestyles="--")
        ax.set_title(f"{metric} per Fold")
        ax.set_xlabel("Fold")
        ax.set_ylabel(metric)
        average = frame[metric].mean()
        ax.axhline(average, color="red", linestyle="--", label=f"Average = {average:.2f}")
        ax.legend()
    fig.tight_layout()
    save_and_show_plot(fig, "cv_line_metrics")

def plot_model_statistics(model):
    """Histogram all trainable weights and, if present, their current gradients.

    Saved as "model_statistics". Only logs a message if the model has no
    trainable parameters.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    if not trainable:
        logging.info("No parameters found for visualization.")
        return
    gradient_chunks = [
        param.grad.detach().cpu().numpy().flatten()
        for param in trainable if param.grad is not None
    ]
    fig, (ax_weights, ax_grads) = plt.subplots(1, 2, figsize=(12, 5))
    all_weights = np.concatenate([param.detach().cpu().numpy().flatten() for param in trainable])
    sns.histplot(all_weights, bins=50, color='blue', kde=True, ax=ax_weights)
    ax_weights.set_title("Weights Distribution")
    ax_weights.set_xlabel("Value")
    ax_weights.set_ylabel("Frequency")
    if gradient_chunks:
        sns.histplot(np.concatenate(gradient_chunks), bins=50, color='red', kde=True, ax=ax_grads)
        ax_grads.set_title("Gradients Distribution")
        ax_grads.set_xlabel("Value")
        ax_grads.set_ylabel("Frequency")
    fig.tight_layout()
    save_and_show_plot(fig, "model_statistics")

# --- Model accuracy ---
def compute_model_accuracy(preds, truths, tol=0.1):
    """Percentage of predictions whose relative error is below *tol*.

    Relative error is ``|pred - truth| / (|truth| + 1e-9)``; the default
    ``tol=0.1`` means "within 10 %". Returns a value in [0, 100].
    """
    predicted = np.array(preds)
    reference = np.array(truths)
    relative_error = np.abs(predicted - reference) / (np.abs(reference) + 1e-9)
    within_tolerance = relative_error < tol
    return np.mean(within_tolerance) * 100

def plot_model_accuracy(accuracy, tol=0.1):
    """Single-bar chart of an accuracy percentage (0-100) for relative-error tolerance *tol*.

    Saved as "model_accuracy".
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(["Model Accuracy"], [accuracy], color='teal')
    ax.set_ylim([0, 100])
    ax.set_ylabel("Accuracy (%)")
    tol_percent = tol * 100
    ax.set_title(f"Model Accuracy (Relative Error < {tol_percent:.0f}%)")
    fig.tight_layout()
    save_and_show_plot(fig, "model_accuracy")

def print_summary(history, metrics, dataset, hyperparams: dict, target_name="reduction_potential_S1 (eV)"):
    """Build a text training report, save it, log it and print it.

    The report lists dataset sizes, the target name, all hyperparameters, the
    epoch with the lowest test loss from *history*, and the final test *metrics*.
    It is written to ``<log_dir>/summary_<current_time>.txt``.

    Args:
        history: per-epoch dicts from ``train_periodic_network`` (must be non-empty).
        metrics: metric name -> number (e.g. from ``compute_metrics``).
        dataset: a ``MolecularDataset`` (uses ``len``, ``train_idx`` and ``test_idx``).
        hyperparams: settings to list in the report.
        target_name: label for the predicted quantity.
    """
    total_data = len(dataset)
    # Only the split sizes are needed; counting indices avoids building every graph.
    train_data = len(dataset.train_idx)
    test_data = len(dataset.test_idx)
    best_entry = min(history, key=lambda x: x['test_loss'])
    summary = (
        f"Execution Date: {datetime.datetime.now()}\n"
        "===== Training Summary =====\n"
        f"Total data points         : {total_data}\n"
        f"Training data points      : {train_data}\n"
        f"Test data points          : {test_data}\n"
        f"Target variable           : {target_name}\n"
        f"Hyperparameters:\n"
    )
    for key, value in hyperparams.items():
        summary += f"   {key}: {value}\n"
    summary += (
        f"Best Epoch                : {best_entry['epoch']} (Duration: {best_entry.get('duration', 0):.2f}s)\n"
        f"Train Loss (MAE)          : {best_entry['train_loss']:.4f}\n"
        f"Test Loss (MAE)           : {best_entry['test_loss']:.4f}\n"
        "----- Final Test Metrics -----\n"
    )
    for k, v in metrics.items():
        summary += f"{k:12}: {v:.4f}\n"
    summary += "================================\n"
    summary_filename = os.path.join(log_dir, f"summary_{current_time}.txt")
    with open(summary_filename, "w", encoding="utf-8") as f:
        f.write(summary)
    logging.info(summary)
    print(summary)

def evaluate_full_dataset(model, dataset, device='cpu'):
    """Evaluate *model* on every molecule (train split followed by test split), one at a time.

    The report is emitted twice: first through ``logging.info``, then through
    ``print``.

    Returns:
        ``(preds, truths, total_loss, metrics)``: per-molecule predictions and
        targets, the mean L1 loss, and ``compute_metrics(preds, truths)``.
    """
    model.eval()
    all_data = dataset.train_dataset + dataset.test_dataset
    loader = tg.loader.DataLoader(all_data, batch_size=1, shuffle=False)
    preds, truths = [], []
    l1 = nn.L1Loss()
    summed_loss = 0.0
    with torch.no_grad():
        for sample in loader:
            if not hasattr(sample, 'batch'):
                sample.batch = torch.zeros(sample.num_nodes, dtype=torch.long)
            sample = sample.to(device)
            output = model(sample).squeeze(-1)
            summed_loss += l1(output, sample.y).item()
            preds.append(output.item())
            truths.append(sample.y.item())
    total_loss = summed_loss / len(loader)
    metrics = compute_metrics(preds, truths)

    report = [
        "\n===== Full Dataset Evaluation =====",
        f"Total data points: {len(all_data)}",
        f"Average MAE     : {total_loss:.4f}",
        *(f"{k:12}: {v:.4f}" for k, v in metrics.items()),
        "===================================\n",
    ]
    for line in report:
        logging.info(line)
    for line in report:
        print(line)
    return preds, truths, total_loss, metrics

# 6. Molecular graph.
##############################################
def plot_structure_and_graph(dataset, index: int):
    """Plot record *index*: XY projection with element labels (left) and graph edges (right).

    The XYZ string is re-read from ``dataset.database``, and the graph is rebuilt
    with the dataset's own graph builder and element encoding.
    Saved as "structure_and_graph".
    """
    xyz_string = dataset.database[index]["opt_molecule_S0"]
    atoms = read(StringIO(xyz_string), format='xyz')
    graph = dataset.graph_builder.build_molecular_graph(atoms, dataset.type_encoding, dataset.type_onehot)
    fig, axs = plt.subplots(1, 2, figsize=(14, 6))
    ax_struct, ax_graph = axs[0], axs[1]

    atom_xyz = atoms.get_positions()
    ax_struct.scatter(atom_xyz[:, 0], atom_xyz[:, 1], s=100, color='skyblue', edgecolor='k')
    for symbol, (px, py) in zip(atoms.get_chemical_symbols(), atom_xyz[:, :2]):
        ax_struct.annotate(symbol, (px, py), textcoords="offset points", xytext=(5, 5))
    ax_struct.set_title("Molecular Structure (XY Projection)")
    ax_struct.set_xlabel("X")
    ax_struct.set_ylabel("Y")

    node_xyz = graph.pos.cpu().numpy()
    ax_graph.scatter(node_xyz[:, 0], node_xyz[:, 1], s=100, color='lightgreen', edgecolor='k')
    for src, dst in graph.edge_index.cpu().numpy().T:
        ax_graph.plot([node_xyz[src, 0], node_xyz[dst, 0]],
                      [node_xyz[src, 1], node_xyz[dst, 1]],
                      color='gray', alpha=0.7)
    ax_graph.set_title("Molecular Graph")
    ax_graph.set_xlabel("X")
    ax_graph.set_ylabel("Y")
    fig.tight_layout()
    save_and_show_plot(fig, "structure_and_graph")

# 7. Main script
##############################################
if __name__ == "__main__":
    dataset = MolecularDataset(
        database_path='test.json',
        target_key='reduction_potential_S1 (eV)',
        cutoff=5.0,
        test_size=0.3
    )
    # One-hot element channels plus one scaled-mass channel (see MolecularGraph).
    in_dim_actual = len(dataset.type_encoding) + 1
    logging.info(f"Detected one-hot dimension (with mass): {in_dim_actual}")
    # Split sizes come from the index lists; no need to build every graph just to count.
    logging.info(f"Total data points: {len(dataset)} | Train: {len(dataset.train_idx)} | Test: {len(dataset.test_idx)}")

    hyperparams = {
        "num_epochs": 20,
        "batch_size": 32,
        "device": "cpu",
        "patience": 10,
        "in_dim": in_dim_actual,
        "em_dim": 64,
        "irreps_in": "64x0e",
        "irreps_out": "1x0e",
        "irreps_node_attr": "64x0e",
        "layers": 3,
        "mul": 32,
        "lmax": 1,
        "number_of_basis": 20,
        "radial_layers": 1,
        "radial_neurons": 100,
        "max_radius": 5.0,
        "num_neighbors": 12.0,
        "reduce_output": True
    }

    model, history = train_periodic_network(dataset, **hyperparams)
    torch.save(model.state_dict(), "trained_model.pt")
    logging.info("Trained model saved as 'trained_model.pt'.")

    preds, truths, test_mae = evaluate_periodic_network(model, dataset, device=hyperparams["device"])
    logging.info(f"Final MAE on test set: {test_mae:.4f}")
    metrics = compute_metrics(preds, truths)
    # Label the report with the quantity the model was actually trained on.
    print_summary(history, metrics, dataset, hyperparams, target_name=dataset.target_key)

    model_accuracy = compute_model_accuracy(preds, truths, tol=0.1)
    logging.info(f"Model Accuracy (relative error < 10%): {model_accuracy:.2f}%")
    plot_model_accuracy(model_accuracy, tol=0.1)

    # 5-fold cross-validation with fresh models: same architecture as the main run,
    # but more epochs/patience and fewer radial basis functions.
    # NOTE(review): confirm that number_of_basis=10 (vs 20 above) is intended.
    cv_hyperparams = {**hyperparams, "num_epochs": 30, "patience": 30, "number_of_basis": 10}
    main_split = (dataset.train_idx, dataset.test_idx)
    all_cv_metrics, avg_cv_metrics = cross_validate_periodic_network(dataset, k=5, **cv_hyperparams)
    # CV overwrites the dataset split fold by fold; restore the split the main model
    # was trained/evaluated on so the plots below describe that split.
    dataset.train_idx, dataset.test_idx = main_split
    plot_cv_line_metrics(all_cv_metrics)
    full_preds, full_truths, full_loss, full_metrics = evaluate_full_dataset(model, dataset, device=hyperparams["device"])

    plot_loss_history(history)
    plot_parity(preds, truths)
    plot_error_histogram(preds, truths, bins=20)
    plot_target_distribution(dataset)
    plot_model_statistics(model)
    plot_structure_and_graph(dataset, index=0)

    logging.info("Trained Model:")
    logging.info(str(model))
    print("Trained Model:")
    print(model)
