## Setup

In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchani
from torchani.data import TransformableIterable
import pandas as pd
!pip install seaborn
import seaborn as sns

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")




In [None]:
def init_aev_computer():
    Rcr = 5.2
    Rca = 3.5
    EtaR = torch.tensor([16], dtype=torch.float, device=device)
    ShfR = torch.tensor([
        0.900000, 1.168750, 1.437500, 1.706250,
        1.975000, 2.243750, 2.512500, 2.781250,
        3.050000, 3.318750, 3.587500, 3.856250,
        4.125000, 4.393750, 4.662500, 4.931250
    ], dtype=torch.float, device=device)
    EtaA = torch.tensor([8], dtype=torch.float, device=device)
    Zeta = torch.tensor([32], dtype=torch.float, device=device)
    ShfA = torch.tensor([0.90, 1.55, 2.20, 2.85], dtype=torch.float, device=device)
    ShfZ = torch.tensor([
        0.19634954, 0.58904862, 0.9817477, 1.37444680,
        1.76714590, 2.15984490, 2.5525440, 2.94524300
    ], dtype=torch.float, device=device)
    return torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species=4)

aev_computer = init_aev_computer()

In [None]:
class AtomicNet(nn.Module):
    def __init__(self, dropout_rate=None):
        super().__init__()
        layers = [
            nn.Linear(384, 256), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate else nn.Identity(),
            nn.Linear(256, 192), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate else nn.Identity(),
            nn.Linear(192, 128), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate else nn.Identity(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate else nn.Identity(),
            nn.Linear(64, 1)
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


In [None]:
def load_ani_dataset(dspath): 
    energy_shifter = torchani.utils.EnergyShifter(None)
    species_order = ['H', 'C', 'N', 'O']
    dataset = torchani.data.load(dspath)
    dataset = dataset.subtract_self_energies(energy_shifter, species_order)
    dataset = dataset.species_to_indices(species_order)
    dataset = dataset.shuffle()
    return dataset

dataset = load_ani_dataset("./ani_gdb_s01_to_s04.h5")
train_data, val_data, test_data = dataset.split(0.8, 0.1, 0.1)


In [None]:
def get_data_subset(dataset, size):
    dataset = list(dataset)
    if size == 'small':
        return TransformableIterable(dataset[:5000])
    elif size == 'large':
        return TransformableIterable(dataset[:20000])
    else:
        return TransformableIterable(dataset)

train_subset = get_data_subset(train_data, 'small')
val_subset = get_data_subset(val_data, 'small')


In [None]:
def build_model(dropout, l2, device):
    nets = [AtomicNet(dropout).to(device) for _ in range(4)]
    ani_model = torchani.ANIModel(nets)
    return nn.Sequential(aev_computer, ani_model).to(device)


In [None]:
class ANITrainer:
    def __init__(self, model, batch_size=512, learning_rate=1e-3, epoch=100, l2=0.0):
        self.model = model
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2)
        self.epoch = epoch

    def train(self, train_data, val_data, early_stop=True, draw_curve=True):
        self.model.train()
        train_loader = train_data.collate(self.batch_size).cache()
        loss_func = nn.MSELoss()
        train_loss_list = []
        val_loss_list = []
        val_rmse_list = []
        best_model = None
        lowest_val_loss = float('inf')

        for ep in tqdm(range(self.epoch), desc="Epochs"):
            total_train_loss = 0.0
            for batch in train_loader:
                species = batch['species'].to(device)
                coords = batch['coordinates'].to(device)
                energies = batch['energies'].to(device).float()

                _, pred = self.model((species, coords))
                loss = loss_func(energies, pred)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_train_loss += loss.item()

            val_loss, val_rmse = self.evaluate(val_data)
            train_loss_list.append(total_train_loss)
            val_loss_list.append(val_loss)
            val_rmse_list.append(val_rmse)

            print(f"Epoch {ep+1} - Train Loss: {total_train_loss:.4f}, Val Loss: {val_loss:.4f}, RMSE: {val_rmse:.2f} kcal/mol")

            if early_stop and val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                best_model = self.model.state_dict()

        if draw_curve:
            fig, ax1 = plt.subplots()
            ax1.plot(train_loss_list, label="Train Loss")
            ax1.plot(val_loss_list, label="Val Loss")
            ax1.set_yscale("log")
            ax1.set_xlabel("Epoch")
            ax1.set_ylabel("MSE Loss")

            ax2 = ax1.twinx()
            ax2.plot(val_rmse_list, 'r--', label="Val RMSE")
            ax2.set_ylabel("RMSE (kcal/mol)", color='red')

            fig.legend(loc='upper right')
            plt.title("Training Curve")
            plt.show()

        if early_stop and best_model:
            self.model.load_state_dict(best_model)

        return train_loss_list, val_loss_list, val_rmse_list

    def evaluate(self, data, draw_plot=False):
        data_loader = data.collate(self.batch_size).cache()
        loss_func = nn.MSELoss()
        total_loss = 0.0
        true_all, pred_all = [], []

        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                species = batch['species'].to(device)
                coords = batch['coordinates'].to(device)
                true = batch['energies'].to(device).float()
                _, pred = self.model((species, coords))

                total_loss += loss_func(true, pred).item()
                true_all.append(true.cpu().numpy())
                pred_all.append(pred.cpu().numpy())

        true_all = np.concatenate(true_all).flatten()
        pred_all = np.concatenate(pred_all).flatten()
        rmse = np.sqrt(np.mean((true_all - pred_all) ** 2)) * 627.509

        if draw_plot:
            plt.scatter(true_all, pred_all, s=2)
            plt.plot([true_all.min(), true_all.max()], [true_all.min(), true_all.max()], 'r--')
            plt.xlabel("True Energy")
            plt.ylabel("Predicted Energy")
            plt.title(f"Validation RMSE: {rmse:.2f} kcal/mol")
            plt.show()

        return total_loss, rmse


In [None]:
results = []
dropout_list = [None, 0.1, 0.2, 0.3]
l2_list = [0.0, 1e-5, 1e-4]

for dropout in dropout_list:
    for l2 in l2_list:
        print(f"\nTraining model with dropout={dropout}, L2={l2}")
        model = build_model(dropout, l2, device)
        trainer = ANITrainer(model, epoch=30, l2=l2)
        _, _, val_rmse_list = trainer.train(train_subset, val_subset, early_stop=True, draw_curve=False)
        final_rmse = val_rmse_list[-1]
        results.append({'dropout': dropout if dropout else 0.0, 'l2': l2, 'val_rmse': final_rmse})

In [None]:
df = pd.DataFrame(results)
heatmap_data = df.pivot(index='dropout', columns='l2', values='val_rmse')
sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="viridis")
plt.title("Validation RMSE (kcal/mol)")
plt.xlabel("L2 Regularization")
plt.ylabel("Dropout Rate")
plt.show()
