### This notebook evaluates UFR models and trains additional molecules onto existing models

**Oliver Xie - Olsen Lab, Massachusetts Institute of Technology, 2025**

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import warnings
import ufr.util.data_processing as data_processing
import ufr.util.model_launch as model_launch
from ufr.model.idac_model_add_molecules import UNIQUAC_add_molecules, mod_UNIQUAC_add_molecules, Wilson_add_molecules, NRTL_add_molecules, EarlyStopper

# Silence UserWarnings (e.g. from prototype PyTorch features) ...
warnings.filterwarnings(action='ignore', category=UserWarning)
# ... and FutureWarning deprecation notices
warnings.simplefilter(action='ignore', category=FutureWarning)

# Use the default CUDA device when PyTorch can see a GPU; otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 1. Evaluate a trained model on IDAC data

In [None]:
# Input files: the IDAC data to evaluate and the molecular property table
eval_data_file = "./data/opensource_IDAC_data.csv"
prop_file = "./data/small_molecule_prop.csv"

df_eval = pd.read_csv(eval_data_file, index_col=0)
df_prop = pd.read_csv(prop_file, index_col=0)

# Names of the df_eval columns that hold the solute and solvent SMILES
solute_smiles = 'Solute SMILES'
solvent_smiles = 'Solvent SMILES'

# Architecture of the trained model to evaluate; together these select the saved model file
combinatorial_layer = 'mod_FH'  # FH, FV, SG, GK-FV, mod_FH or mod_FV
residual_layer = 'mod_UNIQUAC'  # UNIQUAC, mod_UNIQUAC, Wilson or NRTL
association_layer = 'wertheim'  # none or wertheim
dimension, temp_type, temp_dim = 12, 'invT', 2
sobolev, trial = 0, 5

model_name = f'UFR_{combinatorial_layer}_{residual_layer}_{association_layer}_{dimension}D_{temp_type}_{temp_dim}_sobolev_{sobolev}_{trial}'
model_file = f"./trained_models/models_from_paper/{model_name}.h5"  # point this at the model file to evaluate

# Column with ln(gamma) after removing the combinatorial contribution; used for the slope with 1/T
ln_y_label = f'ln_gamma_res_{combinatorial_layer}'
# Column with the calculated combinatorial contribution
comb_label = f'comb_{combinatorial_layer}'
comb_label = f'comb_{combinatorial_layer}' # This is name of the column containing the calculated combinatorial value

**Clean data (if desired)**

In [None]:
# Certain datasets report deuterated water, but RDKit converts it to ordinary water. IDAC values in D2O
# differ slightly, so give D2O its own SMILES string.
deuterated_water = 'Deuterium oxide <Heavy water>'
df_eval.loc[df_eval['Solute'] == deuterated_water, solute_smiles] = '[2H]O[2H]'
df_eval.loc[df_eval['Solvent'] == deuterated_water, solvent_smiles] = '[2H]O[2H]'

# Drop self-edges (solute identical to solvent). A boolean mask is used rather than dropping by
# index label: a label-based drop would also remove any unrelated row sharing a label with a self-edge.
self_edge_mask = df_eval[solute_smiles] == df_eval[solvent_smiles]
self_edges = df_eval.index[self_edge_mask]  # labels of the removed self-edges, kept for inspection
df_eval = df_eval.loc[~self_edge_mask]

# Drop duplicates
df_eval = data_processing.drop_duplicates(df_eval)

# Set flag on whether to clean the data or not
clean_temp_outliers = True  # False or True

# If True, clean the data
if clean_temp_outliers:
    # Load the cleaning rules
    cleaning_rules_file = './data/cleaning_rules.xlsx'
    df_clean_rules = pd.read_excel(cleaning_rules_file, sheet_name = 'cleaning rules')

    # Clean according to the cleaning rules
    df_eval = data_processing.apply_cleaning(df_eval, df_clean_rules)

    # Drop temperature outliers; std_dev sets how many standard deviations away counts as an outlier (default 3).
    # df_dropped holds the rows that were removed.
    df_eval, df_dropped = data_processing.remove_temperature_outliers(df_eval, std_dev = 3)

**Add molecular properties and slope with temperature**

In [None]:
# Attach molecular properties to every IDAC row. The helper also returns v0 and s0, which are used
# later as the reference volume and area for normalising the VDW volumes/areas.
df_eval, v0, s0 = data_processing.molecular_property_addition(
    df_eval,
    df_prop,
    mode='FH',  # 'FH' or 'FV'; note that 'FV' requires the DIPPR correlations
    solute_smiles=solute_smiles,
    solvent_smiles=solvent_smiles,
)

# Slope of ln_y_label with respect to 1/T (needs >= 4 temperatures spanning >= 30 K per pair)
print('Starting calculation for Sobolev regularization, this might take a while')
df_eval = data_processing.invtemp_gradient_calc(df_eval, ln_y_label, min_temps=4, min_delta_T=30)

**Extract the model parameters and load the model**

In [None]:
# Read the four tables of the saved UFR model from the HDF5 file
with pd.HDFStore(model_file) as store:
    df_results, df_temp, df_loss, df_distance = (
        store[key] for key in ('df_chemical_parameters', 'df_temp_parameters', 'df_loss', 'df_distance')
    )

# The temperature exponents are stored as the column labels of the temperature table
temp_exponents = np.array(df_temp.columns, dtype = np.float64)

# Raw chemical-parameter matrix (values as stored; no NaN filling is done here)
x_results = df_results.to_numpy()

# Temperature and distance parameters as numpy arrays for the model
T = df_temp.to_numpy()
O = df_distance.to_numpy()

# Split the chemical parameters into groups by column-name pattern
A = df_results.filter(like = 'ua_').to_numpy()
Q = df_results.filter(like = 'q_').to_numpy()
Alpha = df_results.filter(like = 'alpha_').to_numpy()
association_cols = df_results.columns.str.contains('acceptor') | df_results.columns.str.contains('donor')
Delta = df_results.loc[:, association_cols].to_numpy()

# Non-empty parameter groups only
arrays = list(filter(lambda arr: arr.size > 0, (A, Alpha, Q, Delta)))

# Build the PyTorch model from the loaded parameters and move it to the selected device
model = model_launch.model_selector(residual_layer, A, Alpha, Q, T, O, Delta, association_layer, temp_exponents, device).to(device)

**Truncate the evaluated data to retain only IDAC with solutes and solvents present in the UFR model**

In [None]:
# The model only has parameters for the chemicals in df_results' index
(node_size, model_dim) = df_results.shape
index = df_results.index

# Keep only IDAC rows whose solute AND solvent both have regressed parameters.
# A boolean mask keeps each row exactly once even if df_eval repeats index labels (collecting labels
# and re-selecting them with .loc would duplicate such rows). .copy() makes the columns added below
# go into an explicit copy rather than a slice of df_eval (no SettingWithCopy ambiguity).
in_model = df_eval[solute_smiles].isin(index) & df_eval[solvent_smiles].isin(index)
df_eval_filtered = df_eval.loc[in_model].copy()

print(f"We retain only {df_eval_filtered.shape[0]} rows from the original {df_eval.shape[0]} representing {node_size} chemicals")

# Row position of each solute / solvent in the parameter matrices (df_results' index must be unique)
solute_positions = index.get_indexer(df_eval_filtered[solute_smiles])
solvent_positions = index.get_indexer(df_eval_filtered[solvent_smiles])

# Model inputs as numpy arrays (VDW volumes/areas normalised by the reference values v0 / s0)
i_data = solute_positions
j_data = solvent_positions
invT_data = 1 / df_eval_filtered['Temp (K)'].to_numpy()
rA_data = df_eval_filtered['Solute_VDW_Volumes'].to_numpy() / v0
rB_data = df_eval_filtered['Solvent_VDW_Volumes'].to_numpy() / v0
qA_data = df_eval_filtered['Solute_VDW_Area'].to_numpy() / s0
qB_data = df_eval_filtered['Solvent_VDW_Area'].to_numpy() / s0
weight_data = 1 / df_eval_filtered['Solute_Solvent_Count'].to_numpy()
N_d_A = df_eval_filtered['Solute_H_Donor_Sites'].astype(float).to_numpy()
N_a_A = df_eval_filtered['Solute_H_Acceptor_Sites'].astype(float).to_numpy()
N_d_B = df_eval_filtered['Solvent_H_Donor_Sites'].astype(float).to_numpy()
N_a_B = df_eval_filtered['Solvent_H_Acceptor_Sites'].astype(float).to_numpy()
# Site counts per unit normalised volume
rhoBam = N_a_B / rB_data
rhoBdm = N_d_B / rB_data
rhoAap = N_a_A / rA_data
rhoAdp = N_d_A / rA_data

N_data = np.column_stack((N_a_A, N_d_A))
rho_data = np.column_stack((rhoAap, rhoAdp, rhoBam, rhoBdm))

# Push to GPU
i = torch.tensor(i_data, dtype = torch.int32).to(device)
j = torch.tensor(j_data, dtype = torch.int32).to(device)
invT = torch.tensor(invT_data[:, None]).to(device)
rA = torch.tensor(rA_data[:, None]).to(device)
qA = torch.tensor(qA_data[:, None]).to(device)
qB = torch.tensor(qB_data[:, None]).to(device)
N = torch.tensor(N_data, dtype = torch.float32).to(device)
rho = torch.tensor(rho_data, dtype = torch.float32).to(device)

# Evaluate the model with gradients enabled on 1/T to get d(ln gamma_res)/d(1/T).
# Using ones as grad_outputs gives per-row derivatives, assuming each output row depends only on its own 1/T.
invT.requires_grad = True
output = model(i, j, invT, rA, qA, qB, N, rho)  # Pass q or r depending on the model
# Only first derivatives are needed, so no graph for higher-order gradients is built
output_dinvT = torch.autograd.grad(outputs=output, inputs=invT, grad_outputs=torch.ones_like(output))[0]
invT.requires_grad = False

# Push to CPU
ln_y_hat = output.detach().cpu().double().numpy().flatten()
y_hat = np.exp(ln_y_hat)

# Save the predictions; the full ln(gamma) adds back the combinatorial contribution
df_eval_filtered['UFR_Model_ln_gamma_res'] = ln_y_hat
df_eval_filtered['UFR_Model_derivative_1/T'] = output_dinvT.detach().cpu().double().numpy().flatten()
df_eval_filtered['UFR_Model_ln_gamma'] = ln_y_hat + df_eval_filtered[comb_label]

# Calculate the absolute errors of ln(gamma) and of its slope with 1/T
df_eval_filtered['AE_ln_gamma'] = np.absolute(df_eval_filtered['ln gamma'] - df_eval_filtered['UFR_Model_ln_gamma'])
df_eval_filtered['AE_derivative_1/T'] = np.absolute(df_eval_filtered['dlny_dinvT'] - df_eval_filtered['UFR_Model_derivative_1/T'])

**Model Performance Statistics**

In [None]:
# Summary statistics of the absolute error in ln(gamma)
mean_AE = df_eval_filtered['AE_ln_gamma'].mean()
median_AE = df_eval_filtered['AE_ln_gamma'].median()
print(f"The mean AE is {mean_AE} and the median AE is {median_AE}")

# Box-and-whisker plot of the absolute error (fliers hidden, white median line)
plt.figure(figsize=(8, 8))
sns.set(font='Arial', font_scale=1.5)
sns.set_style("whitegrid")  # white background with grid lines
box_style = {"palette": "viridis", "showfliers": False, "medianprops": {"color": "white", "linewidth": 2}}
ax = sns.boxplot(y=df_eval_filtered['AE_ln_gamma'], data=df_eval_filtered, **box_style)

# Axis limits and labels
ax.set_ylim(-0.01, 0.3)
ax.set_ylabel('Absolute Error')
ax.set_xlabel('Model')

# Black frame around the plot area
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('black')

# Global font defaults; these largely affect artists created after this point
plt.rcParams.update({'font.family': 'Arial', 'font.size': 18})
for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_fontsize(18)
    tick_label.set_fontname('Arial')
    tick_label.set_color('black')

ax.figure.set_size_inches(8, 6)
plt.show()

**Cumulative error plot**

In [None]:
# Cumulative distribution of the absolute error in ln(gamma)

# Data-noise floor on the MAE, keyed by model dimension. The 12-D value was determined on the
# experimental dataset used in training via the analysis in the SI.
mae_noise = {12: 0.0565}

# Empirical CDF: each sorted error is paired with its rank divided by the number of predictions
ae_values = df_eval_filtered['AE_ln_gamma'].to_numpy()
sort_ae_values = np.sort(ae_values)
n_points = len(sort_ae_values)
cumulative_prob = np.arange(1, n_points + 1) / n_points

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(sort_ae_values, cumulative_prob, color='black', label = 'Model')

axis_font = dict(fontsize=18, fontname='Arial', color='black')
ax.set_xlabel("Absolute Error", **axis_font)
ax.set_ylabel("Cumulative Probability", **axis_font)
ax.set_xlim(0, 1)
ax.set_yticks(np.arange(0, 1.1, 0.1))
ax.grid(True)

# Black frame around the plot area
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('black')

# Enable once the noise limit for this model dimension is known and set appropriately:
# ax.vlines(mae_noise[dimension], 0, 1, color='red', linestyle='--', label='Data Noise Limit')

plt.legend(loc = 'lower right', prop = {'family': 'Arial', 'size': 18})
plt.show()

**Parity Plot**

In [None]:
# Parity plot of experimental vs. predicted ln(gamma); pairs involving water ('O') are drawn in blue
fig, ax = plt.subplots(figsize=(6, 6))

# Per-point colours, computed vectorised. Named `point_colors` so the `matplotlib.colors` module
# (imported as `colors`) is not shadowed for later cells.
involves_water = (df_eval_filtered[solute_smiles] == 'O') | (df_eval_filtered[solvent_smiles] == 'O')
point_colors = np.where(involves_water, 'blue', 'green').tolist()

# Plot scatter points
ax.scatter(df_eval_filtered['ln gamma'], df_eval_filtered['UFR_Model_ln_gamma'], c=point_colors, s=4, alpha=0.5)

# y = x reference line spanning the combined range of both axes
min_val = min(df_eval_filtered['ln gamma'].min(), df_eval_filtered['UFR_Model_ln_gamma'].min())
max_val = max(df_eval_filtered['ln gamma'].max(), df_eval_filtered['UFR_Model_ln_gamma'].max())
ax.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect Agreement')

# Axis labels and tick fonts
ax.set_xlabel('Experimental Data', fontsize=18, fontname='Arial')
ax.set_ylabel('UFR Model Prediction', fontsize=18, fontname='Arial')
plt.xticks(fontsize=18, fontname='Arial')
plt.yticks(fontsize=18, fontname='Arial')

# Square plot without grid
ax.set_aspect('equal')
ax.grid(False)

# Add legend with Arial font
ax.legend(loc='upper left', prop={'family': 'Arial', 'size': 18})

# Black frame around the plot area
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('black')

plt.tight_layout()
plt.show()

### 2. Train additional chemicals onto the UFR model

The additional data must be IDAC data in the same format as the original dataset, and the properties of every new molecule must be included in the original property file.

In [None]:
# Input files: the new IDAC data to train on and the molecular property table
train_data_file = "./data/glycolide_IDAC_data.csv"
prop_file = "./data/small_molecule_prop.csv"

df_train = pd.read_csv(train_data_file, index_col=0)
df_prop = pd.read_csv(prop_file, index_col=0)

# Names of the df_train columns that hold the solute and solvent SMILES
solute_smiles = 'Solute SMILES'
solvent_smiles = 'Solvent SMILES'

# Architecture of the existing model to extend; together these select the saved model file
combinatorial_layer = 'mod_FH'  # FH, FV, SG, GK-FV, mod_FH or mod_FV
residual_layer = 'mod_UNIQUAC'  # UNIQUAC, mod_UNIQUAC, Wilson or NRTL
association_layer = 'wertheim'  # none or wertheim
dimension, temp_type, temp_dim = 12, 'invT', 2
sobolev, trial = 0, 5

model_name = f'UFR_{combinatorial_layer}_{residual_layer}_{association_layer}_{dimension}D_{temp_type}_{temp_dim}_sobolev_{sobolev}_{trial}'
model_file = f"./trained_models/models_from_paper/{model_name}.h5"  # point this at the model file to extend

# Column with ln(gamma) after removing the combinatorial contribution; used for the slope with 1/T
ln_y_label = f'ln_gamma_res_{combinatorial_layer}'
# Column with the calculated combinatorial contribution
comb_label = f'comb_{combinatorial_layer}'

**Add molecular properties**

In [None]:
# Attach molecular properties to every training row. The helper also returns v0 and s0, which are
# used later as the reference volume and area for normalising the VDW volumes/areas.
df_train, v0, s0 = data_processing.molecular_property_addition(
    df_train,
    df_prop,
    mode='FH',  # 'FH' or 'FV'; note that 'FV' requires the DIPPR correlations
    solute_smiles=solute_smiles,
    solvent_smiles=solvent_smiles,
)

# Slope of ln_y_label with respect to 1/T (needs >= 4 temperatures spanning >= 30 K per pair)
print('Starting calculation for Sobolev regularization, this might take a while')
df_train = data_processing.invtemp_gradient_calc(df_train, ln_y_label, min_temps=4, min_delta_T=30)

**Set up the model**

Most of the model parameters are already defined by the loaded model; only the training settings are set here.

In [None]:
# Specify a model savepath
save_path = './trained_models/'

# Number of trials. NOTE(review): this value is not read by the training code in this notebook,
# which seeds numpy with a fixed value instead.
trials = 1

# Load the existing model that the new molecules are added to
with pd.HDFStore(model_file) as store:
    df_results = store['df_chemical_parameters']
    df_temp = store['df_temp_parameters']
    df_loss = store['df_loss']
    df_distance = store['df_distance']

# Get the exponents used for this model (stored as the temperature table's column labels)
temp_exponents = np.array(df_temp.columns, dtype = np.float64)

# Raw chemical-parameter matrix (values as stored; no NaN filling is performed)
x_results = df_results.to_numpy()

# Temperature and distance parameters as numpy arrays
T = df_temp.to_numpy()
O = df_distance.to_numpy()

# Parameter groups selected by column-name pattern, kept as DataFrames so rows can be looked up by SMILES
A = df_results.filter(like = 'ua_')
Q = df_results.filter(like = 'q_')
Alpha = df_results.filter(like = 'alpha_')
Delta = df_results.loc[:, df_results.columns.str.contains('acceptor') | df_results.columns.str.contains('donor')]

arrays = [arr for arr in (A, Alpha, Q, Delta) if arr.size > 0]

# Get the list of chemicals already contained in this model
regressed_smiles = df_results.index.values

# Width of each parameter group (0 when the model has no columns for it); each group is
# checked against its own size.
alpha_dim = Alpha.shape[1] if Alpha.size > 0 else 0
q_dim = Q.shape[1] if Q.size > 0 else 0
delta_dim = Delta.shape[1] if Delta.size > 0 else 0
u_dim = A.shape[1] if A.size > 0 else 0

# Sobolev loss - currently only 0 is supported when adding new molecules
sobolev = 0  # 0 turns this off. Any non-zero value turns on the Sobolev loss and becomes the weight for the term.

# Hyperparameters
lr = 0.01  # Learning rate
total_epochs = 30000  # Total number of epochs to run
up_epochs = 1000  # Number of epochs for ramping up the learning rate
hold_epochs = 20000  # Number of epochs for holding at the maximum learning rate
pre_train_epoch = 500  # Number of epochs for separately regressing the residual and association layers if both are used.

# Truncation to prevent overfitting:
#   'temp_connections'     - count each pairing's temperature as a unique entry
#   'chemical_connections' - count all temperatures of a pairing as one entry
#   'keep_all'             - do not truncate
truncation = 'chemical_connections'

# Set up a savename for the model
model_name = f'UFR_{combinatorial_layer}_{residual_layer}_{association_layer}_{dimension}D_{temp_type}_{temp_exponents.size}D_sobolev_{sobolev}_Additional_Molecule'
save_name = save_path + model_name

# Set up dictionaries of model parameters for passing into the model
ln_y_data = f'ln_gamma_res_{combinatorial_layer}'  # Which IDAC (combinatorial removed) to regress on.
model_layer_options = {'ln_y_data': ln_y_data, 'combinatorial_layer': combinatorial_layer, 'residual_layer': residual_layer, 'association_layer': association_layer, 'temp_exponents': temp_exponents, 'reference_volume': v0, 'reference_area': s0}
model_opt_options = {'sobolev': sobolev, 'lr': lr, 'total_epochs': total_epochs, 'up_epochs': up_epochs, 'hold_epochs': hold_epochs}
model_run_options = {'truncation': truncation, 'smile_labels' : (solute_smiles, solvent_smiles), 'pre_train_epoch': pre_train_epoch, 'save_name': save_name}

**Run the model**

In [None]:
# Run the model
# Set up the optimization: mean absolute error (L1) between predicted and target values
criterion = nn.L1Loss()

# Unique chemicals (solutes and solvents) in the new dataset; each gets one row of parameters
full_node_list = list(pd.unique(df_train[[solute_smiles, solvent_smiles]].values.ravel('K')))
# Map each chemical to its row index in the parameter matrices
node_idx_dict = {node: idx for idx, node in enumerate(full_node_list)}
# True for chemicals NOT already in the loaded model, so only those rows are regressed.
# A set gives constant-time membership tests instead of scanning the array for every node.
regressed_set = set(regressed_smiles)
node_mask = np.array([node not in regressed_set for node in full_node_list], dtype=bool)

# Add index columns for solute and solvent
df_train['Solute_Idx'] = df_train[solute_smiles].map(node_idx_dict)
df_train['Solvent_Idx'] = df_train[solvent_smiles].map(node_idx_dict)

edges = df_train.shape
n_node = len(full_node_list)  # full graph node size

# Model inputs as numpy arrays (VDW volumes/areas normalised by the reference values v0 / s0)
i_data = df_train['Solute_Idx'].to_numpy()
j_data = df_train['Solvent_Idx'].to_numpy()
invT_data = 1 / df_train['Temp (K)'].to_numpy()
y_data = df_train[ln_y_data].to_numpy()
rA_data = df_train['Solute_VDW_Volumes'].to_numpy() / v0
rB_data = df_train['Solvent_VDW_Volumes'].to_numpy() / v0
qA_data = df_train['Solute_VDW_Area'].to_numpy() / s0
qB_data = df_train['Solvent_VDW_Area'].to_numpy() / s0
N_d_A = df_train['Solute_H_Donor_Sites'].astype(float).to_numpy()
N_a_A = df_train['Solute_H_Acceptor_Sites'].astype(float).to_numpy()
N_d_B = df_train['Solvent_H_Donor_Sites'].astype(float).to_numpy()
N_a_B = df_train['Solvent_H_Acceptor_Sites'].astype(float).to_numpy()
# Site counts per unit normalised volume
rhoBam = N_a_B / rB_data
rhoBdm = N_d_B / rB_data
rhoAap = N_a_A / rA_data
rhoAdp = N_d_A / rA_data
N_data = np.column_stack((N_a_A, N_d_A))
rho_data = np.column_stack((rhoAap, rhoAdp, rhoBam, rhoBdm))

np.random.seed(1)  # Fixed seed so the random initial guesses for new molecules are reproducible

early_stopping = EarlyStopper()  # Set up early stopper

# Push to GPU
i = torch.tensor(i_data).to(device)
j = torch.tensor(j_data).to(device)
invT = torch.tensor(invT_data[:, None]).to(device)  # Must be Ndata x 1
y = torch.tensor(y_data[:, None]).to(device)
rA = torch.tensor(rA_data[:, None]).to(device)
qA = torch.tensor(qA_data[:, None]).to(device)
qB = torch.tensor(qB_data[:, None]).to(device)
N = torch.tensor(N_data, dtype = torch.float32).to(device)
rho = torch.tensor(rho_data, dtype = torch.float32).to(device)

# Normalised van der Waals area of every chemical, in full_node_list order
q = df_prop.set_index('Canonical SMILES').loc[full_node_list, 'van der waals area (m2/kmol)'].to_numpy() / s0

loss_values = []
# Shapes of the parameter matrices of the extended model
Alpha_shape = (n_node, alpha_dim)
A_shape = (n_node, u_dim)
T_shape = (u_dim, temp_dim)
D_shape = (n_node, delta_dim)  # Always pass in Wertheim, won't always use
Q_shape = (n_node, 1)
O_shape = (u_dim, 2)  # For combining the u_dim parameters

# Random initial guesses for every row; rows of already-regressed chemicals are overwritten below
A_initial = 0.1 * 20 * np.random.standard_normal(size = A_shape) + 20  # normal, mean 20, std 2 (10% of the mean)
Alpha_initial = np.random.random_sample(size = Alpha_shape)  # uniform on [0, 1)
# ln(e - 1) is the inverse softplus of 1: presumably values land near 1 after a softplus in the model — TODO confirm
D_initial = 0.1 * np.random.standard_normal(size = D_shape) + np.log(np.exp(1) - 1) * np.ones(D_shape)
Q_initial = np.zeros(Q_shape)
Q_initial[:, 0] = q.copy()  # new molecules start from their normalised van der Waals area

# Overwrite with regressed values
for smiles, idx in node_idx_dict.items():
    if smiles in regressed_set:
        A_initial[idx, :] = A.loc[smiles].to_numpy()
        Alpha_initial[idx, :] = Alpha.loc[smiles].to_numpy()
        D_initial[idx, :] = Delta.loc[smiles].to_numpy()
        # Q_initial always has one column, but the loaded model only has q_ columns for some residual
        # layers; copying an empty row into a width-1 slice would raise a broadcast ValueError.
        if Q.shape[1] > 0:
            Q_initial[idx, :] = Q.loc[smiles].to_numpy()

# Push to gpu
A_initial = torch.tensor(A_initial, dtype=torch.float32).to(device)
Alpha_initial = torch.tensor(Alpha_initial, dtype=torch.float32).to(device)
D_initial = torch.tensor(D_initial, dtype=torch.float32).to(device)
T_initial = torch.tensor(df_temp.to_numpy(), dtype=torch.float32).to(device)  # From regressed results
O_initial = torch.tensor(df_distance.to_numpy(), dtype=torch.float32).to(device)  # From regressed results
Q_initial = torch.tensor(Q_initial, dtype = torch.float32).to(device)

node_mask_tensor = torch.tensor(node_mask).to(device)

# Select model here
if residual_layer == 'UNIQUAC':
    model = UNIQUAC_add_molecules(A_initial, T_initial, O_initial, D_initial, association_layer, temp_exponents, node_mask_tensor, device)
elif residual_layer == 'Wilson':
    model = Wilson_add_molecules(A_initial, T_initial, O_initial, D_initial, association_layer, temp_exponents, node_mask_tensor, device)
elif residual_layer == 'NRTL':
    model = NRTL_add_molecules(Alpha_initial, A_initial, T_initial, O_initial, D_initial, association_layer, temp_exponents, node_mask_tensor, device)
elif residual_layer == 'mod_UNIQUAC':
    model = mod_UNIQUAC_add_molecules(A_initial, Q_initial, T_initial, O_initial, D_initial, association_layer, temp_exponents, node_mask_tensor, device)
else:
    # Fail loudly: otherwise a `model` left over from an earlier cell would be trained silently
    raise ValueError(f'Model label of {residual_layer} is invalid')

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr = lr)

# Linear learning-rate warm-up from 0.1 * lr to lr over `up_epochs` iterations.
# NOTE(review): hold_epochs is not used here; after the warm-up LinearLR keeps the rate at lr.
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor = 0.1, end_factor = 1, total_iters = up_epochs)

# Training loop
for epoch in range(total_epochs):
    # Put model into training
    model.train()
    # Reset the gradients
    optimizer.zero_grad()

    # Forward pass
    outputs = model(i, j, invT, rA, qA, qB, N, rho)

    # Compute the loss between predicted and target values without Sobolev regularization
    loss = criterion(torch.squeeze(outputs), torch.squeeze(y))

    # Backward pass, compute gradients
    loss.backward()

    # Update parameters
    optimizer.step()

    # Change learning rate
    scheduler.step()

    # Append loss to the array
    loss_values.append(loss.item())

    early_stopping(loss.item())
    # Print message every 500th iteration
    if (epoch + 1) % 500 == 0:
        print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item():.4f}')

    if early_stopping.early_stop:
        print(f"Early stopping occurred at Epoch [{epoch + 1}/{total_epochs}], Loss: {loss.item():.4f}")
        break

# Transfer tensors to arrays, taking learned rows for new chemicals and fixed rows for existing ones
# All models have A, O, T
A_learn = model.A.detach().cpu()
A_fixed = model.A_fixed.detach().cpu()
A_final = torch.where(node_mask_tensor.cpu().unsqueeze(1), A_learn, A_fixed).numpy()
# T and O are always fixed and should be original model values
T_final = model.T.detach().cpu().numpy()
O_final = model.O.detach().cpu().numpy()
Alpha_final = np.array([])
Q_final = np.array([])
D_final = np.array([])

# Extract alpha for NRTL, Q for mod-UNIQUAC
if residual_layer == 'NRTL':
    Alpha_learn = model.Alpha.detach().cpu()
    Alpha_fixed = model.Alpha_fixed.detach().cpu()
    Alpha_final = torch.where(node_mask_tensor.cpu().unsqueeze(1), Alpha_learn, Alpha_fixed).numpy()
elif residual_layer == 'mod_UNIQUAC':
    Q_learn = model.Q.detach().cpu()
    Q_fixed = model.Q_fixed.detach().cpu()
    Q_final = torch.where(node_mask_tensor.cpu().unsqueeze(1), Q_learn, Q_fixed).numpy()

# Extract Delta for Wertheim
if association_layer == 'wertheim':
    D_learn = model.D.detach().cpu()
    D_fixed = model.D_fixed.detach().cpu()
    D_final = torch.where(node_mask_tensor.cpu().unsqueeze(1), D_learn, D_fixed).numpy()

# Save the model parameters using state_dict (a .pt file)
model_filename = f'{save_name}_{trial}.pt'
torch.save(model.state_dict(), model_filename)

# Save the model parameters in HDF5 format; the parameter groups are merged column-wise first
param_array = [arr for arr in (A_final, Alpha_final, Q_final, D_final) if arr.size > 0]
merged_array = np.hstack(param_array)

# Row labels: the chemicals in node_idx_dict order
index_series = pd.Series(node_idx_dict)

df_results_add_mol = pd.DataFrame(merged_array, index = index_series.index, columns = df_results.columns)
df_distance_add_mol = pd.DataFrame(O)
df_temp_add_mol = pd.DataFrame(T, columns = temp_exponents)
df_loss_add_mol = pd.DataFrame(loss_values)

# Append only the newly regressed chemicals. Existing chemicals keep their original stored values;
# the model's fixed copies went through float32 and would otherwise replace them with rounded values.
df_results_final = pd.concat([df_results, df_results_add_mol.loc[node_mask]]).groupby(level=0).last()

# Save all DataFrames as an HDF5 file
h5_filename = f'{save_name}_{trial}.h5'
with pd.HDFStore(h5_filename) as store:
    store.put('df_chemical_parameters', df_results_final)
    store.put('df_temp_parameters', df_temp_add_mol)
    store.put('df_loss', df_loss_add_mol)
    store.put('df_distance', df_distance_add_mol)

print(f'Model saved as {model_filename} and {h5_filename}')