In [None]:
import os
import gc
import h5py
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
# from sklearn import model_selection
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from sklearn.preprocessing import StandardScaler
# from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Working directory at start-up. NOTE(review): `cwd` does not appear to be used by later
# cells — confirm before removing.
cwd = os.getcwd()

# Prefer CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # cuda is still faster
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") # Apple Silicon acceleration for torch 2.1.2 or later

In [None]:
# Directory holding the HDF5 EOS tables and the generated C2P dataset.
# Override it with the NNC2P_EOS_TABLES_DIR environment variable instead of editing the
# notebook. NOTE: the default is an absolute path, whereas the model-export cells use
# relative paths ("NNC2P Workspace/models").
eos_tables_dir = os.environ.get("NNC2P_EOS_TABLES_DIR", "/NNC2P Workspace/eos_tables")
# eos_table_filename = "LS180_234r_136t_50y_analmu_20091212_SVNr26.h5"
eos_table_filename = "LS220_234r_136t_50y_analmu_20091212_SVNr26.h5"

In [None]:
def _close_open_h5_files():
    """
    Best-effort close of every h5py.File object still alive in this kernel (e.g. handles
    left open by an interrupted cell), so the EOS/dataset files can be reopened or rewritten.
    """
    for obj in gc.get_objects():
        if isinstance(obj, h5py.File):
            try:
                obj.close()
            except Exception:
                # An already-invalid handle may refuse to close; skip it and keep going.
                # Unlike a bare `except:`, this does not swallow KeyboardInterrupt.
                pass


# Running the loop inside a function means no loop variable keeps the last object alive.
_close_open_h5_files()
gc.collect()

In [None]:
def read_eos_table(filename):
    """Open the HDF5 EOS table at ``filename`` read-only and return the ``h5py.File``.

    The caller owns the handle and is responsible for closing it (or using it in a
    ``with`` block).
    """
    return h5py.File(filename, mode='r')

In [None]:
# Path of a training EOS table file. NOTE(review): `filename` does not appear to be used
# by later cells (the generator below writes c2p_full_dataset.h5) — confirm before removing.
filename = os.path.join(eos_tables_dir, "train_eos_table.h5")

In [None]:
# This cell is inspired by, and partially adapted from: https://github.com/ThibeauWouters/master-thesis-AI

# The (scalar) fluid speed is sampled uniformly from [V_MIN, V_MAX] when generating data.
V_MIN = 0
V_MAX = 0.721

def W(rho, eps, v, p = None):
    """
    Lorentz factor W = 1 / sqrt(1 - v^2).

    ``rho``, ``eps`` and ``p`` are accepted for call-site compatibility but unused.
    ``v`` is either a Python float (a speed) or an array of velocity components, in
    which case v^2 is the sum of the squared components.
    """
    v_sqr = v ** 2 if isinstance(v, float) else np.sum(v ** 2)
    return (1 - v_sqr) ** (-1 / 2)

def generate_training_data_c2p(eos_table, number_of_points, save_name, seed=None):
    """
    Generate conserved-to-primitive (C2P) training data from a tabulated EOS.

    For each sample, a grid point (ye, logtemp, logrho) of ``eos_table`` and a speed
    v ~ U(V_MIN, V_MAX) are drawn uniformly at random. The primitive state is then
    mapped to conserved variables (P2C):

        W = 1 / sqrt(1 - v^2),   h = 1 + eps + p / rho,
        D = rho W,   S = rho h W^2 v,   tau = rho h W^2 - p - D.

    The result is written to the HDF5 file ``save_name`` as two datasets:
      - "features": shape (number_of_points, 4), columns [log10 D, log10 S, log10 tau, ye]
      - "labels":   shape (number_of_points, 1), column [log10 p]

    Args:
        eos_table: open h5py.File (or mapping with the same datasets) providing "ye",
            "logtemp", "logrho", "logenergy", "logpress" (the last two indexed as
            [ye, temp, rho]) and "pointsye", "pointstemp", "pointsrho".
            The table is CLOSED before this function returns.
        number_of_points: number of samples to generate.
        save_name: path of the HDF5 file to write (overwritten if it exists).
        seed: optional seed for the random generator, for a reproducible dataset.

    NOTE(review): eps and p are taken as 10**logenergy and 10**logpress directly. No
    energy shift or unit conversion is applied here — confirm this matches the table's
    conventions.
    """
    ye_table = eos_table["ye"][()]
    temp_table = eos_table["logtemp"][()]
    rho_table = eos_table["logrho"][()]
    eps_table = eos_table["logenergy"][()]
    p_table = eos_table["logpress"][()]
    len_ye = eos_table["pointsye"][()][0]
    len_temp = eos_table["pointstemp"][()][0]
    len_rho = eos_table["pointsrho"][()][0]

    # Draw all speeds and grid indices at once instead of looping number_of_points times.
    rng = np.random.default_rng(seed)
    v = rng.uniform(V_MIN, V_MAX, size=number_of_points)
    ye_idx = rng.integers(0, len_ye, size=number_of_points)
    temp_idx = rng.integers(0, len_temp, size=number_of_points)
    rho_idx = rng.integers(0, len_rho, size=number_of_points)

    ye = ye_table[ye_idx]
    logp = p_table[ye_idx, temp_idx, rho_idx]
    rho = 10 ** rho_table[rho_idx]
    eps = 10 ** eps_table[ye_idx, temp_idx, rho_idx]
    p = 10 ** logp

    # The Lorentz factor is computed inline, once per sample, rather than via W().
    # This keeps the function independent of the global name W, which other cells in
    # this notebook may rebind.
    lorentz = 1.0 / np.sqrt(1.0 - v ** 2)
    h = 1 + eps + p / rho
    D = rho * lorentz
    rho_h_w2 = rho * h * lorentz ** 2
    S = rho_h_w2 * v
    tau = rho_h_w2 - p - D

    features = np.column_stack([np.log10(D), np.log10(S), np.log10(tau), ye])
    labels = logp[:, np.newaxis]

    with h5py.File(save_name, 'w') as f:
        f.create_dataset("features", data=features)
        f.create_dataset("labels", data=labels)

    eos_table.close()

In [None]:
number_of_points = 1000000 # For speed 100000

# `with` closes the EOS table even if generation fails part-way. generate_training_data_c2p
# also closes it on success; closing an already-closed h5py file is tolerated, as the
# original cell (which closed it a second time) relied on.
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    generate_training_data_c2p(
        eos_table,
        number_of_points=number_of_points,
        save_name=os.path.join(eos_tables_dir, "c2p_full_dataset.h5"),
    )

In [None]:
# # On DTAI only
# def standard_scaler(input_tensor):
#     mean = torch.mean(input_tensor, dim=0)
#     std = torch.std(input_tensor, dim=0, unbiased=False)  # Match scikit-learn's behavior

#     standardized_tensor = (input_tensor - mean) / std
    
#     return standardized_tensor

In [None]:
# Load the generated dataset; `with` closes the HDF5 file even if reading fails.
with h5py.File(os.path.join(eos_tables_dir, "c2p_full_dataset.h5"), 'r') as train_c2p_table:
    features = train_c2p_table["features"][:].astype(np.float32)
    labels = train_c2p_table["labels"][:].astype(np.float32)

features_tensor = torch.from_numpy(features)
labels_tensor = torch.from_numpy(labels)

# 90 / 5 / 5 train/val/test split. Sizes come from the file contents rather than
# `number_of_points`, so this cell also works when the dataset was generated in another
# session or with a different size (random_split requires the sizes to sum to len(dataset)).
n_samples = len(features_tensor)
val_size = n_samples // 20
test_size = n_samples // 20
train_size = n_samples - val_size - test_size

dataset = TensorDataset(features_tensor, labels_tensor)
# A seeded generator gives the same partition on every run, so metrics and exports are reproducible.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Materialize each split as (features, labels) tensors.
train_features_unscaled, train_labels_unscaled = train_dataset[:]
val_features_unscaled, val_labels_unscaled = val_dataset[:]
test_features_unscaled, test_labels_unscaled = test_dataset[:]

# Standardize inputs and targets with statistics fitted on the training split only.
input_scaler = StandardScaler()
train_features = torch.from_numpy(input_scaler.fit_transform(train_features_unscaled))
val_features = torch.from_numpy(input_scaler.transform(val_features_unscaled))
test_features = torch.from_numpy(input_scaler.transform(test_features_unscaled))

output_scaler = StandardScaler()
train_labels = torch.from_numpy(output_scaler.fit_transform(train_labels_unscaled))
val_labels = torch.from_numpy(output_scaler.transform(val_labels_unscaled))
test_labels = torch.from_numpy(output_scaler.transform(test_labels_unscaled))

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 32

# Re-wrap the standardized tensors; this rebinds the raw Subset splits created earlier.
train_dataset, val_dataset = (
    TensorDataset(train_features, train_labels),
    TensorDataset(val_features, val_labels),
)

# Only the training data is shuffled; validation order does not affect the mean loss.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class NNC2P_Tabulated(nn.Module):
    """
    Fully connected regressor 4 -> 1024 -> 512 -> 256 -> 128 -> 64 -> 1.

    ReLU follows every hidden layer; the output layer is linear. In this notebook the
    input is the standardized [log10 D, log10 S, log10 tau, ye] and the output is the
    standardized log10 p (one value per row).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in order as fc1..fc6, so state_dict keys (and therefore
        # saved checkpoints) keep the same names.
        widths = (4, 1024, 512, 256, 128, 64, 1)
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{idx}", nn.Linear(n_in, n_out))

    def forward(self, x):
        # ReLU after each hidden layer; the output layer stays linear (regression).
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        hidden = torch.relu(self.fc3(hidden))
        hidden = torch.relu(self.fc4(hidden))
        hidden = torch.relu(self.fc5(hidden))
        return self.fc6(hidden)

In [None]:
# Build the network on `device` and wrap it in nn.DataParallel, which splits each batch
# across the visible GPUs. NOTE: the wrapper prefixes state_dict keys with "module.", so
# checkpoints saved from this object carry that prefix.
model = NNC2P_Tabulated().to(device)
model = nn.DataParallel(model)
print(f"Using GPUs: {model.device_ids}")

# Regression on the standardized label: mean-squared-error loss, Adam with lr = 3e-4.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [None]:
os.makedirs('checkpoints', exist_ok=True)

# Halve the learning rate after 5 epochs without validation-loss improvement.
# `verbose` is deprecated in recent PyTorch releases, so it is not passed; the current
# learning rate is printed in the per-epoch log line below instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
n_epochs = 250
checkpoint_every = 125  # save a checkpoint every `checkpoint_every` epochs

train_losses = []
val_losses = []

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    return optimizer.param_groups[0]['lr']

def _run_epoch(loader, train):
    """
    Make one pass over ``loader`` with the global ``model`` and return the mean per-sample loss.

    train=True: training mode, gradients on, ``optimizer`` steps after every batch.
    train=False: eval mode, no gradient tracking, no parameter updates.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch_inputs, batch_outputs in loader:
            batch_inputs = batch_inputs.float().to(device)
            batch_outputs = batch_outputs.float().to(device)
            if train:
                optimizer.zero_grad()
            loss = criterion(model(batch_inputs), batch_outputs)
            if train:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += loss.item() * batch_inputs.size(0)
    return total_loss / len(loader.dataset)

# early_stop_patience = 15  # Number of epochs to wait for improvement
# min_delta_percent = 0.01  # Minimum % improvement (e.g., 1%)
# best_val_loss = float('inf')
# early_stop_counter = 0

# Training loop
for epoch in range(n_epochs):
    train_loss = _run_epoch(train_loader, train=True)
    train_losses.append(train_loss)

    val_loss = _run_epoch(val_loader, train=False)
    val_losses.append(val_loss)

    # # Early stopping logic
    # if best_val_loss == float('inf') or (best_val_loss - val_loss) / best_val_loss > min_delta_percent:
    #     best_val_loss = val_loss
    #     early_stop_counter = 0
    # else:
    #     early_stop_counter += 1
    #     if early_stop_counter >= early_stop_patience:
    #         print(f"Early stopping at epoch {epoch+1} due to insufficient validation loss improvement.")
    #         break

    current_lr = get_lr(optimizer)
    print(f'Epoch {epoch+1}/{n_epochs} Train Loss: {train_loss:.16f} Val Loss: {val_loss:.16f} LR: {current_lr:.6f}')

    scheduler.step(val_loss)

    # Report GPU memory allocation (the loop is empty on CPU-only machines).
    for i in range(torch.cuda.device_count()):
        print(f"Memory allocated on GPU {i}: {torch.cuda.memory_allocated(i)} bytes")

    # Periodic checkpoint (model weights carry the DataParallel "module." prefix).
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_path = f'checkpoints/checkpoint_tabulated_epoch_{epoch+1}.pth'
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
        }, checkpoint_path)
        print(f'Checkpoint saved at {checkpoint_path}')

In [None]:
# Plot training and validation loss per epoch (run tag: "1M 1M").
# The x-axis is built from the recorded losses rather than `n_epochs`, so the plot still
# works when training stopped early or was interrupted.
fig, ax = plt.subplots(figsize=(10, 5))
ax.semilogy(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.semilogy(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training and Validation Loss vs. Epochs')
plt.show()

In [None]:
# (index of the lowest validation loss, its value). NOTE: np.argmin is 0-based, while the
# training log numbers epochs from 1, so the matching log line is "Epoch <index + 1>".
(np.argmin(val_losses), np.min(val_losses))

In [None]:
# Drop the nn.DataParallel wrapper (if present) and move the bare network to the CPU
# for test-set evaluation.
model = (model.module if isinstance(model, torch.nn.DataParallel) else model).to('cpu')

In [None]:
model.eval()

# Perform inference: standardized predictions for the held-out test inputs
# (model and test_features are both on the CPU at this point).
with torch.no_grad():
    predictions = model(test_features.float())

In [None]:
# Standardized test-set predictions, one column (the network has a single output).
predictions

In [None]:
# Standardized test-set targets, to compare by eye with the predictions shown above.
test_labels

In [None]:
def inverse_standard_scaler(standardized_tensor, mean, std):
    """
    Undo a standardization: return ``standardized_tensor * std + mean``.

    ``mean`` and ``std`` broadcast against ``standardized_tensor`` (scalars or
    per-column tensors both work).
    """
    return standardized_tensor * std + mean

# Map standardized predictions back to the original label scale using exactly the
# statistics output_scaler was fitted with (per-column mean, population std). The
# inversion therefore matches output_scaler.transform applied to the training labels,
# and stays correct for labels with more than one column.
inverted_predictions = inverse_standard_scaler(
    predictions,
    torch.as_tensor(output_scaler.mean_, dtype=predictions.dtype),
    torch.as_tensor(output_scaler.scale_, dtype=predictions.dtype),
)

In [None]:
# Test-set errors on the original (de-standardized) label scale:
# mean absolute error (L1) and maximum absolute error (L-infinity).
l1_loss = nn.L1Loss()
l1_error = l1_loss(inverted_predictions, test_labels_unscaled)
print(f'L1 Error: {l1_error.item():.2e}')
linf_error = torch.max(torch.abs(inverted_predictions - test_labels_unscaled))
print(f'L-infinity Error: {linf_error.item():.2e}')

In [None]:
# Side by side: de-standardized test predictions and the true test labels.
(inverted_predictions, test_labels_unscaled)

In [None]:
model_path = os.path.join("NNC2P Workspace/models", "NNC2P_Tabulated.pth")
# torch.jit.save does not create missing parent directories.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.eval()
# The model is moved back to `device`, so the exported TorchScript module holds its
# parameters there; the load cell passes map_location to handle this.
model.to(device)
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, model_path)

In [None]:
# Export the standardization statistics for the external inference code. Each file holds
# row 0 = mean and row 1 = std.
# The values come from the fitted StandardScalers, so they are exactly what training used
# (per-column mean, population std with ddof=0). Recomputing the input std with
# Tensor.std(dim=0) would give the unbiased (ddof=1) estimate instead, which does not
# match the scaling the network was trained with.
os.makedirs("./speed_test/gpu", exist_ok=True)

train_mean = torch.from_numpy(output_scaler.mean_)
train_std = torch.from_numpy(output_scaler.scale_)

train_mean_out_np = train_mean.numpy()
train_std_out_np = train_std.numpy()

np.savetxt("./speed_test/gpu/mean_std_out_tabulated.txt", np.vstack((train_mean_out_np, train_std_out_np)), fmt="%.17f")

train_features_unscaled = train_features_unscaled.to("cpu")

# Shape (1, n_features), matching the keepdim=True layout written previously.
mean = torch.from_numpy(input_scaler.mean_).unsqueeze(0)
std = torch.from_numpy(input_scaler.scale_).unsqueeze(0)

train_mean_in_np = mean.numpy()
train_std_in_np = std.numpy()
np.savetxt("./speed_test/gpu/mean_std_in_tabulated.txt", np.vstack((train_mean_in_np, train_std_in_np)), fmt="%.17f")

In [None]:
# Convert every exported tensor to a CPU NumPy array.
(
    inputs_test_unscaled_np,
    outputs_test_unscaled_np,
    inputs_test_np,
    outputs_test_np,
    preds_test_np,
    inverted_preds_np,
    inputs_train_unscaled_np,
    outputs_train_unscaled_np,
) = (
    tensor.cpu().numpy()
    for tensor in (
        test_features_unscaled,
        test_labels_unscaled,
        test_features,
        test_labels,
        predictions,
        inverted_predictions,
        train_features_unscaled,
        train_labels_unscaled,
    )
)

# Save as text with 9 significant digits ('%.9g'), enough to round-trip float32 values.
for _path, _array in (
    ('./speed_test/gpu/inputs_train_unscaled_tabulated.txt', inputs_train_unscaled_np),
    ('./speed_test/gpu/outputs_train_unscaled_tabulated.txt', outputs_train_unscaled_np),
    ('./speed_test/gpu/inputs_test_unscaled_tabulated.txt', inputs_test_unscaled_np),
    ('./speed_test/gpu/outputs_test_unscaled_tabulated.txt', outputs_test_unscaled_np),
    ('./speed_test/gpu/preds_test_tabulated.txt', preds_test_np),
    ('./speed_test/gpu/inputs_test_scaled_tabulated.txt', inputs_test_np),
    ('./speed_test/gpu/outputs_test_scaled_tabulated.txt', outputs_test_np),
    ('./speed_test/gpu/inverted_preds_tabulated.txt', inverted_preds_np),
):
    np.savetxt(_path, _array, fmt='%.9g')

In [None]:
# Raw (unstandardized) test inputs: columns [log10 D, log10 S, log10 tau, ye].
test_features_unscaled

In [None]:
# LOAD MODEL
# Reload the exported TorchScript model and the saved standardization statistics
# (row 0 = mean, row 1 = std) so the evaluation cells below work without retraining.
model_path = "NNC2P Workspace/models/NNC2P_Tabulated.pth"
model = torch.jit.load(model_path, map_location=device).to(device)
model.eval()

loaded_stats_in = np.loadtxt("./speed_test/gpu/mean_std_in_tabulated.txt")
# Per-feature mean / std of the training inputs (one entry per input feature).
train_mean_in, train_std_in = loaded_stats_in[0], loaded_stats_in[1]

loaded_stats_out = np.loadtxt("./speed_test/gpu/mean_std_out_tabulated.txt")
# Mean / std of the training labels.
train_mean_out, train_std_out = loaded_stats_out[0], loaded_stats_out[1]

In [None]:
def l1_norm(predictions, y, reduction = True):
    """
    Absolute (L1) error between network predictions and reference values.

    :param predictions: Predictions made by the neural network (array supporting ``-``).
    :param y: Reference values; must have the same length as ``predictions``.
    :param reduction: If True, return the mean absolute error along axis 0; otherwise
        return the element-wise absolute differences.
    :return: The mean (or element-wise) absolute error. Returns ``0`` after printing a
        message when ``predictions`` is empty or the lengths differ.
    """
    n = len(predictions)
    problem = None
    if n == 0:
        problem = "Predictions is empty list"
    elif n != len(y):
        problem = "Predictions and y must have same size"
    if problem is not None:
        print(problem)
        return 0

    abs_diff = abs(predictions - y)
    return np.sum(abs_diff, axis=0) / n if reduction else abs_diff

In [None]:
# Requires the model and standardization statistics loaded by the "LOAD MODEL" cell above.
from matplotlib.colors import LogNorm
fs = 16
k_B = 8.617333262145e-11  # Boltzmann constant in MeV/K

# Read the EOS arrays into memory; the file is closed as soon as they are loaded.
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    ye = eos_table["ye"][()]
    logrho = eos_table["logrho"][()]
    logtemp = eos_table["logtemp"][()]
    logpress = eos_table["logpress"][()]
    logenergy = eos_table["logenergy"][()]

# ye_index = len(ye) // 2
ye_index = 6
ye_value = ye[ye_index]

# Fixed-Ye slices have shape (n_logtemp, n_logrho).
logeps, logp = logenergy[ye_index], logpress[ye_index]
n_logtemp, n_logrho = logeps.shape

# Fixed Lorentz factor for this map. It is named lorentz_W so it does not shadow the W() helper.
lorentz_W = 1.2
v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)

temp, rho, eps, p = 10 ** logtemp, 10 ** logrho, 10 ** logeps, 10 ** logp

# P2C on the whole (T, rho) grid at once; rho broadcasts along the last (rho) axis.
# Rows are temperature-major, rho-minor, matching logp.ravel().
h, D = 1 + eps + p / rho, rho * lorentz_W
S, tau = rho * h * lorentz_W ** 2 * v, rho * h * lorentz_W ** 2 - p - D
input_values = np.column_stack([
    np.log10(np.broadcast_to(D, S.shape)).ravel(),
    np.log10(S).ravel(),
    np.log10(tau).ravel(),
    np.full(S.size, ye_value),
])

input_values = (input_values - train_mean_in) / train_std_in
with torch.no_grad():
    predictions = model(torch.from_numpy(input_values).float().to(device)).cpu().numpy()

predictions = predictions * train_std_out + train_mean_out
norm_function = l1_norm
sliced_predictions = predictions[:, 0]

target_values = logp.ravel()

# Relative error |pred - true| / |true| of log10(p): divide by the target exactly once.
delta_vals = norm_function(target_values, sliced_predictions, reduction=False)
delta_vals = (delta_vals / np.abs(target_values)).reshape((n_logtemp, n_logrho))

fig, ax = plt.subplots(figsize=(4, 3))
im = ax.imshow(delta_vals, origin="lower", norm=LogNorm(vmin=np.min(delta_vals), vmax=np.max(delta_vals)))

# Labels with a smaller font size
ax.set_xlabel(r"$\log \rho$", fontsize=fs - 4)
ax.set_ylabel(r"$\log T$", fontsize=fs - 4)

# Ticks are pixel indices, so the last one sits on pixel n - 1, not n.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round(logrho[xt], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
# Temperature labels in log10(K): 10**logtemp / k_B with k_B in MeV/K (i.e. logtemp is
# taken to be log10 of T in MeV).
yl_K = np.round(np.log10(10 ** logtemp[yt] / k_B), 2)

ax.set_xticks(xt)
ax.set_xticklabels(xl, fontsize=fs - 6)
ax.set_yticks(yt)
ax.set_yticklabels(yl_K, fontsize=fs - 6)

# Color bar with a reduced label font size
cbar = fig.colorbar(im, ax=ax, shrink=0.7)
cbar.set_label('Relative Error', rotation=270, labelpad=20, fontsize=fs - 4)
plt.show()

In [None]:
# Read the EOS arrays into memory; the file is closed as soon as they are loaded.
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    ye = eos_table["ye"][()]
    logrho = eos_table["logrho"][()]
    logtemp = eos_table["logtemp"][()]
    logpress = eos_table["logpress"][()]
    logenergy = eos_table["logenergy"][()]

# ye_index = len(ye) // 2
ye_index = 6
ye_value = ye[ye_index]

# Fixed-Ye slices have shape (n_logtemp, n_logrho).
logeps, logp = logenergy[ye_index], logpress[ye_index]
n_logtemp, n_logrho = logeps.shape

# Fixed Lorentz factor (named lorentz_W so it does not shadow the W() helper).
lorentz_W = 1.2
v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)

temp, rho, eps, p = 10 ** logtemp, 10 ** logrho, 10 ** logeps, 10 ** logp

# P2C on the whole grid; rows are temperature-major, rho-minor (same order as logp.ravel()).
h, D = 1 + eps + p / rho, rho * lorentz_W
S, tau = rho * h * lorentz_W ** 2 * v, rho * h * lorentz_W ** 2 - p - D
input_values = np.column_stack([
    np.log10(np.broadcast_to(D, S.shape)).ravel(),
    np.log10(S).ravel(),
    np.log10(tau).ravel(),
    np.full(S.size, ye_value),
])

input_values = (input_values - train_mean_in) / train_std_in
with torch.no_grad():
    predictions = model(torch.from_numpy(input_values).float().to(device)).cpu().numpy()

predictions = predictions * train_std_out + train_mean_out
norm_function = l1_norm
sliced_predictions = predictions[:, 0]

target_values = logp.ravel()

# Relative error |pred - true| / |true| of log10(p): divide by the target exactly once.
delta_vals = norm_function(target_values, sliced_predictions, reduction=False)
delta_vals = (delta_vals / np.abs(target_values)).reshape((n_logtemp, n_logrho))
print(delta_vals)

fig, ax = plt.subplots(figsize=(9, 4))
im = ax.imshow(delta_vals, origin="lower", norm=LogNorm(vmin=np.min(delta_vals), vmax=np.max(delta_vals)))

ax.set_xlabel(r"$\log \rho$", fontsize=fs)
ax.set_ylabel(r"$\log T$", fontsize=fs)

# Ticks are pixel indices, so the last one sits on pixel n - 1, not n.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round(logrho[xt], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
yl = np.round(logtemp[yt], 2)  # raw table log-temperature values (no conversion to K here)

ax.set_xticks(xt)
ax.set_xticklabels(xl)
ax.set_yticks(yt)
ax.set_yticklabels(yl)

cbar = fig.colorbar(im, ax=ax, shrink=0.7)
cbar.set_label('Relative Error', rotation=270, labelpad=20)
plt.show()

In [None]:
# Read the EOS arrays into memory; the file is closed as soon as they are loaded.
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    ye = eos_table["ye"][()]
    logrho = eos_table["logrho"][()]
    logtemp = eos_table["logtemp"][()]
    logpress = eos_table["logpress"][()]
    logenergy = eos_table["logenergy"][()]

ye_index = 6
ye_value = ye[ye_index]

k_B = 8.617333262145e-11  # Boltzmann constant in MeV/K

# Fixed-Ye slices have shape (n_logtemp, n_logrho).
logeps, logp = logenergy[ye_index], logpress[ye_index]
n_logtemp, n_logrho = logeps.shape

temp, rho, eps, p = 10 ** logtemp, 10 ** logrho, 10 ** logeps, 10 ** logp
h = 1 + eps + p / rho  # independent of the Lorentz factor
target_values = logp.ravel()  # temperature-major, same row order as the inputs below

# Ticks are pixel indices, so the last one sits on pixel n - 1, not n.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round(logrho[xt], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
# Temperature labels in log10(K): 10**logtemp / k_B with k_B in MeV/K.
yl_K = np.round(np.log10(10 ** logtemp[yt] / k_B), 2)

# W_values = [1.6, 1.8, 2.0, 2.2, 2.4]
W_values = [1.1, 1.2, 1.3, 1.4]

# The loop variable is lorentz_W so it does not shadow the W() helper.
for lorentz_W in W_values:
    v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)

    D = rho * lorentz_W
    S, tau = rho * h * lorentz_W ** 2 * v, rho * h * lorentz_W ** 2 - p - D
    input_values = np.column_stack([
        np.log10(np.broadcast_to(D, S.shape)).ravel(),
        np.log10(S).ravel(),
        np.log10(tau).ravel(),
        np.full(S.size, ye_value),
    ])
    input_values = (input_values - train_mean_in) / train_std_in

    with torch.no_grad():
        predictions = model(torch.from_numpy(input_values).float().to(device)).cpu().numpy()

    predictions = predictions * train_std_out + train_mean_out
    norm_function = l1_norm
    sliced_predictions = predictions[:, 0]

    # Relative error |pred - true| / |true| of log10(p): divide by the target exactly once.
    delta_vals = norm_function(target_values, sliced_predictions, reduction=False)
    delta_vals = (delta_vals / np.abs(target_values)).reshape((n_logtemp, n_logrho))

    fig, ax = plt.subplots(figsize=(9, 4))
    im = ax.imshow(delta_vals, origin="lower", norm=LogNorm(vmin=np.min(delta_vals), vmax=np.max(delta_vals)))

    ax.set_xlabel(r"$\log \rho$", fontsize=fs)
    ax.set_ylabel(r"$\log T$", fontsize=fs)
    ax.set_xticks(xt)
    ax.set_xticklabels(xl)
    ax.set_yticks(yt)
    ax.set_yticklabels(yl_K)

    cbar = fig.colorbar(im, ax=ax, shrink=0.7)
    cbar.set_label('Relative Error', rotation=270, labelpad=20)
    ax.set_title(f'W = {lorentz_W}')
    plt.show()

In [None]:
# Read the EOS arrays into memory; the file is closed as soon as they are loaded.
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    ye = eos_table["ye"][()]
    logrho = eos_table["logrho"][()]
    logtemp = eos_table["logtemp"][()]
    logpress = eos_table["logpress"][()]
    logenergy = eos_table["logenergy"][()]

ye_index = 6
ye_value = ye[ye_index]

k_B = 8.617333262145e-11  # Boltzmann constant in MeV/K

# Fixed-Ye slices have shape (n_logtemp, n_logrho).
logeps, logp = logenergy[ye_index], logpress[ye_index]
n_logtemp, n_logrho = logeps.shape

temp, rho, eps, p = 10 ** logtemp, 10 ** logrho, 10 ** logeps, 10 ** logp
h = 1 + eps + p / rho  # independent of the Lorentz factor
target_values = logp.ravel()  # temperature-major, same row order as the inputs below

# W_values = [1.6, 1.8, 2.0, 2.2, 2.4]
W_values = [1.02, 1.1, 1.25, 1.4]

# Pass 1: a relative-error map for every Lorentz factor (loop variable named lorentz_W
# so it does not shadow the W() helper).
delta_maps = []
for lorentz_W in W_values:
    v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)
    D = rho * lorentz_W
    S, tau = rho * h * lorentz_W ** 2 * v, rho * h * lorentz_W ** 2 - p - D
    input_values = np.column_stack([
        np.log10(np.broadcast_to(D, S.shape)).ravel(),
        np.log10(S).ravel(),
        np.log10(tau).ravel(),
        np.full(S.size, ye_value),
    ])
    input_values = (input_values - train_mean_in) / train_std_in

    with torch.no_grad():
        predictions = model(torch.from_numpy(input_values).float().to(device)).cpu().numpy()
    predictions = predictions * train_std_out + train_mean_out

    # Relative error |pred - true| / |true| of log10(p).
    delta_vals = l1_norm(target_values, predictions[:, 0], reduction=False) / np.abs(target_values)
    delta_maps.append(delta_vals.reshape((n_logtemp, n_logrho)))

# A single LogNorm is shared by all panels, so the one colorbar is valid for every panel.
# With a separate norm per panel, it would only describe the last one.
shared_norm = LogNorm(vmin=min(m.min() for m in delta_maps), vmax=max(m.max() for m in delta_maps))

# Ticks are pixel indices, so the last one sits on pixel n - 1, not n.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round(logrho[xt], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
# Temperature labels in log10(K): 10**logtemp / k_B with k_B in MeV/K.
yl_K = np.round(np.log10(10 ** logtemp[yt] / k_B), 2)

# Pass 2: plotting.
fig, axs = plt.subplots(1, 4, figsize=(15, 4), sharey=True)
for panel, (ax, lorentz_W, delta_vals) in enumerate(zip(axs, W_values, delta_maps)):
    im = ax.imshow(delta_vals, origin="lower", norm=shared_norm)

    ax.set_xlabel(r"$\log \rho$", fontsize=fs-6)
    if panel == 0:
        ax.set_ylabel(r"$\log T$", fontsize=fs-6)

    ax.set_xticks(xt)
    ax.set_xticklabels(xl)
    ax.set_yticks(yt)
    ax.set_yticklabels(yl_K)

    ax.set_title(f'W = {lorentz_W}', fontsize=fs-6)

fig.colorbar(im, ax=axs.ravel().tolist(), shrink=0.4, label='Relative Error')
os.makedirs('../images', exist_ok=True)  # savefig does not create directories
plt.savefig('../images/plot.pdf')  # vector output
plt.savefig('../images/plot.png', dpi=300)  # raster at a higher DPI for better resolution
plt.savefig('../images/plot.svg')
plt.show()

In [None]:
import tensorrt as trt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os
import pycuda.driver as cuda
import pycuda.autoinit

# Base font size for the TensorRT relative-error figure.
fs = 16

def load_engine(engine_path):
    """Load TensorRT engine

    Deserialize the serialized TensorRT engine stored at ``engine_path`` and return it.
    TensorRT messages at WARNING level and above are logged.

    NOTE(review): deserialize_cuda_engine may return None (rather than raise) on failure;
    callers do not check for that.
    NOTE(review): the Runtime is released when the `with` block exits, while the engine
    is still in use. Confirm that the installed TensorRT version allows the runtime to
    be destroyed before engines deserialized from it.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

def allocate_buffers(engine, batch_size=1):
    """Allocate device buffers and return input/output bindings

    For every binding of ``engine``, allocates a page-locked host buffer and a device
    buffer of ``trt.volume(shape) * batch_size`` float32 elements, and appends the
    device pointer to ``bindings`` in binding order.

    Returns:
        (inputs, outputs, bindings, stream):
        - ``inputs`` / ``outputs``: lists of dicts with keys 'host', 'device', 'shape'.
        - ``bindings``: device pointers (ints) in binding order, as passed to
          ``execute_async_v2``.
        - ``stream``: a newly created CUDA stream.

    The device buffers are not freed here; the caller frees them with ``['device'].free()``.

    NOTE(review): every binding is assumed to be float32 — confirm this for the FP16 engine.
    NOTE(review): get_binding_shape / binding_is_input belong to TensorRT's binding-index
    API, which newer releases deprecate or remove — confirm against the installed version.
    NOTE(review): a dynamic dimension (-1) in a binding shape would make the volume negative.
    """
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    
    for binding in engine:
        shape = engine.get_binding_shape(binding)
        size = trt.volume(shape) * batch_size
        dtype = np.float32
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to bindings
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append({'host': host_mem, 'device': device_mem, 'shape': shape})
        else:
            outputs.append({'host': host_mem, 'device': device_mem, 'shape': shape})
    return inputs, outputs, bindings, stream

def process_batch(context, input_data, inputs, outputs, bindings, stream):
    """Process a batch of data

    ``input_data`` (rows x features) is zero-padded up to the engine's fixed input shape
    ``inputs[0]['shape']`` and copied to the device. The engine is then executed on
    ``stream``, and the first ``input_data.shape[0]`` values of the flattened first
    output are returned. This assumes one output value per input row.

    NOTE: the returned array is a view into the pre-allocated host buffer, so it is
    overwritten by the next call.

    Raises:
        ValueError: if ``input_data`` has more rows than the engine input can hold.
    """
    n_rows = input_data.shape[0]
    engine_rows = inputs[0]['shape'][0]
    if n_rows > engine_rows:
        raise ValueError(
            f"input_data has {n_rows} rows but the engine input holds only {engine_rows}; "
            "split the data into batches or rebuild the engine with a larger batch size"
        )

    # Prepare input data: zero-pad to the engine's fixed input shape.
    input_buffer = np.zeros(inputs[0]['shape'], dtype=np.float32)
    input_buffer[:n_rows] = input_data

    np.copyto(inputs[0]['host'], input_buffer.ravel())
    cuda.memcpy_htod_async(inputs[0]['device'], inputs[0]['host'], stream)

    # Run inference
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)

    # Transfer predictions back and wait for all queued work on the stream.
    cuda.memcpy_dtoh_async(outputs[0]['host'], outputs[0]['device'], stream)
    stream.synchronize()

    # Only return the relevant (non-padding) portion of the output
    return outputs[0]['host'][:n_rows]

# Load your EOS table and prepare data (the file is closed once the arrays are read)
with read_eos_table(os.path.join(eos_tables_dir, eos_table_filename)) as eos_table:
    ye = eos_table["ye"][()]
    logrho = eos_table["logrho"][()]
    logtemp = eos_table["logtemp"][()]
    logpress = eos_table["logpress"][()]
    logenergy = eos_table["logenergy"][()]
ye_index = 6
ye_value = ye[ye_index]
k_B = 8.617333262145e-11  # Boltzmann constant in MeV/K

# Fixed-Ye slices have shape (n_logtemp, n_logrho).
logeps, logp = logenergy[ye_index], logpress[ye_index]
n_logtemp, n_logrho = logeps.shape
temp, rho, eps, p = 10 ** logtemp, 10 ** logrho, 10 ** logeps, 10 ** logp
h = 1 + eps + p / rho  # independent of the Lorentz factor
target_values = logp.ravel()  # temperature-major, same row order as the inputs below

# Load TensorRT engine
engine_path = "../models/NNC2P_Tabulated_FP16.engine"  # Replace with your engine path
engine = load_engine(engine_path)
context = engine.create_execution_context()
inputs, outputs, bindings, stream = allocate_buffers(engine)

W_values = [1.02, 1.1, 1.25, 1.4]

# Pass 1: relative-error maps (loop variable named lorentz_W so it does not shadow W()).
delta_maps = []
try:
    for lorentz_W in W_values:
        v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)
        D = rho * lorentz_W
        S, tau = rho * h * lorentz_W ** 2 * v, rho * h * lorentz_W ** 2 - p - D

        input_values = np.column_stack([
            np.log10(np.broadcast_to(D, S.shape)).ravel(),
            np.log10(S).ravel(),
            np.log10(tau).ravel(),
            np.full(S.size, ye_value),
        ]).astype(np.float32)
        input_values = (input_values - train_mean_in) / train_std_in

        # Process the data
        predictions = process_batch(context, input_values, inputs, outputs, bindings, stream)
        predictions = predictions.reshape(-1, 1) * train_std_out + train_mean_out

        # Relative error |pred - true| / |true| of log10(p).
        delta_vals = l1_norm(target_values, predictions[:, 0], reduction=False) / np.abs(target_values)
        delta_maps.append(delta_vals.reshape((n_logtemp, n_logrho)))
finally:
    # Clean up: release the device buffers even if inference fails part-way.
    for buffer in inputs + outputs:
        buffer['device'].free()

# A single LogNorm is shared by all panels, so the one colorbar is valid for every panel.
shared_norm = LogNorm(vmin=min(m.min() for m in delta_maps), vmax=max(m.max() for m in delta_maps))

# Ticks are pixel indices, so the last one sits on pixel n - 1, not n.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round(logrho[xt], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
# Temperature labels in log10(K): 10**logtemp / k_B with k_B in MeV/K.
yl_K = np.round(np.log10(10 ** logtemp[yt] / k_B), 2)

# Pass 2: plotting.
fig, axs = plt.subplots(1, 4, figsize=(15, 4), sharey=True)
for panel, (ax, lorentz_W, delta_vals) in enumerate(zip(axs, W_values, delta_maps)):
    im = ax.imshow(delta_vals, origin="lower", norm=shared_norm)
    ax.set_xlabel(r"$\log \rho$", fontsize=fs-6)
    if panel == 0:
        ax.set_ylabel(r"$\log T$", fontsize=fs-6)
    ax.set_xticks(xt)
    ax.set_xticklabels(xl)
    ax.set_yticks(yt)
    ax.set_yticklabels(yl_K)
    ax.set_title(f'W = {lorentz_W}', fontsize=fs-6)

fig.colorbar(im, ax=axs.ravel().tolist(), shrink=0.4, label='Relative Error')
# plt.savefig('../images/plot.pdf')
# plt.savefig('../images/plot.png', dpi=300)
# plt.savefig('../images/plot.svg')
plt.show()

In [None]:
# Shape of the last standardized input batch: one row per (T, rho) grid point, 4 features.
input_values.shape

In [None]:
import tensorrt as trt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os
import pycuda.driver as cuda
import pycuda.autoinit
import torch

# Larger base font size for the FP32 / FP16 comparison figure built in this cell.
fs = 22

# [Previous helper functions remain the same: load_engine, allocate_buffers, process_batch]
# They are redefined here so this cell can run on its own; these definitions shadow the
# ones from the earlier TensorRT cell.
def load_engine(engine_path):
    """Load TensorRT engine

    Deserialize the serialized TensorRT engine stored at ``engine_path`` and return it.
    TensorRT messages at WARNING level and above are logged.

    NOTE(review): deserialize_cuda_engine may return None (rather than raise) on failure.
    NOTE(review): the Runtime is released when the `with` block exits while the engine is
    still in use — confirm that this is allowed by the installed TensorRT version.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

def allocate_buffers(engine, batch_size=1):
    """Allocate device buffers and return input/output bindings

    Redefinition of the helper from the earlier TensorRT cell. For every binding it
    allocates a page-locked host buffer and a device buffer of
    ``trt.volume(shape) * batch_size`` float32 elements, and records the device pointer
    in ``bindings`` in binding order.

    Returns:
        (inputs, outputs, bindings, stream): ``inputs`` / ``outputs`` are lists of dicts
        with keys 'host', 'device', 'shape'; ``stream`` is a newly created CUDA stream.

    NOTE(review): all bindings are assumed to be float32, and get_binding_shape /
    binding_is_input are deprecated or removed in newer TensorRT releases — confirm
    against the installed version. Device buffers are not freed here.
    """
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    
    for binding in engine:
        shape = engine.get_binding_shape(binding)
        size = trt.volume(shape) * batch_size
        dtype = np.float32
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append({'host': host_mem, 'device': device_mem, 'shape': shape})
        else:
            outputs.append({'host': host_mem, 'device': device_mem, 'shape': shape})
    return inputs, outputs, bindings, stream

def process_batch(context, input_data, inputs, outputs, bindings, stream):
    """Run the engine once on ``input_data`` and return its first output.

    The steps are:
    1. Zero-pad ``input_data`` along its leading axis up to the input binding's shape.
    2. Copy it host->device on ``stream``.
    3. Execute with ``execute_async_v2``.
    4. Copy the first output binding back.

    The call blocks until ``stream`` has finished.

    Args:
        context: execution context of the engine the buffers were allocated for.
        input_data: array of input rows (leading axis = rows). It must not have more
            rows than the first dimension of ``inputs[0]['shape']``.
        inputs, outputs, bindings, stream: as returned by ``allocate_buffers``.

    Returns:
        A new 1-D array holding the first ``len(input_data)`` values of the
        flattened first output. This is one prediction per row only when the output
        has a single value per row. NOTE(review): confirm this holds for the model
        being evaluated.

    Raises:
        ValueError: if ``input_data`` has more rows than the input binding holds.
        RuntimeError: if TensorRT reports that enqueuing the inference failed.
    """
    n_rows = input_data.shape[0]
    input_buffer = np.zeros(inputs[0]['shape'], dtype=np.float32)
    if n_rows > input_buffer.shape[0]:
        raise ValueError(
            f"input_data has {n_rows} rows but the engine input binding holds only "
            f"{input_buffer.shape[0]}; rebuild the engine with a larger batch or split the input"
        )
    input_buffer[:n_rows] = input_data

    np.copyto(inputs[0]['host'], input_buffer.ravel())
    cuda.memcpy_htod_async(inputs[0]['device'], inputs[0]['host'], stream)

    # execute_async_v2 returns False when the kernels could not be enqueued.
    if not context.execute_async_v2(bindings=bindings, stream_handle=stream.handle):
        raise RuntimeError("TensorRT execute_async_v2 failed to enqueue the inference")

    cuda.memcpy_dtoh_async(outputs[0]['host'], outputs[0]['device'], stream)
    stream.synchronize()

    # Copy out of the page-locked host buffer. It is reused by every call, so
    # returning a view would let the next inference silently overwrite the result.
    return outputs[0]['host'][:n_rows].copy()

# --- EOS table slice used for the comparison -----------------------------------
# The HDF5 handle is not closed here; `eos_table` stays available to later cells.
eos_table = read_eos_table(os.path.join(eos_tables_dir, eos_table_filename))

ye_index = 6
ye = eos_table["ye"][()]
ye_value = ye[ye_index]  # electron fraction fed to the network

k_B = 8.617333262145e-11  # Boltzmann constant in MeV/K

logrho, logtemp, logpress, logenergy = (
    eos_table[key][()] for key in ("logrho", "logtemp", "logpress", "logenergy")
)

# Grid sizes are taken from the energy slice at this Ye after swapping its first two axes.
targets = np.swapaxes(logenergy[ye_index], 0, 1)
n_logrho, n_logtemp = targets.shape

logeps = logenergy[ye_index]
logp = logpress[ye_index]

# --- TensorRT engines (FP32 and FP16 builds, per the file names) ---------------
engine_fp32_path = "../models/NNC2P_Tabulated.engine"
engine_fp16_path = "../models/NNC2P_Tabulated_FP16.engine"

# For each precision: deserialize the engine, create an execution context, and
# allocate its host/device I/O buffers together with a dedicated CUDA stream.
engine_fp32 = load_engine(engine_fp32_path)
context_fp32 = engine_fp32.create_execution_context()
inputs_fp32, outputs_fp32, bindings_fp32, stream_fp32 = allocate_buffers(engine_fp32)

engine_fp16 = load_engine(engine_fp16_path)
context_fp16 = engine_fp16.create_execution_context()
inputs_fp16, outputs_fp16, bindings_fp16, stream_fp16 = allocate_buffers(engine_fp16)

# --- Figure layout: 3 model rows x one column per Lorentz factor ---------------
W_values = [1.02, 1.1, 1.25, 1.4]
fig = plt.figure(figsize=(16, 8))
# The gridspec below does not set hspace/wspace, so it inherits these
# figure-level values (tight vertical spacing).
fig.subplots_adjust(hspace=0.05, wspace=0.3)
gs = fig.add_gridspec(3, len(W_values), left=0.1, right=0.9, top=0.95, bottom=0.08)
axs = np.array([[fig.add_subplot(gs[r, c]) for c in range(len(W_values))]
                for r in range(3)])

# Rows of the comparison figure, top to bottom. Each entry is
# (label, PyTorch module, TRT engine, TRT context, TRT input buffers, TRT bindings);
# fields that do not apply to a backend are None.
models = [
    ("PyTorch", model) + (None,) * 4,
    ("TensorRT FP32", None, engine_fp32, context_fp32, inputs_fp32, bindings_fp32),
    ("TensorRT FP16", None, engine_fp16, context_fp16, inputs_fp16, bindings_fp16),
]

# Relative-error maps: one row per entry of `models`, one column per Lorentz factor.
# Model-independent work (network inputs, reference pressures, tick labels) is done
# once. Each panel is drawn exactly once, with the shared colour range. Previously
# each panel was drawn with a per-panel norm and then again on top with the global
# norm, which left two stacked images per axes.
# The Lorentz factor is named `lorentz_W` so this cell does not rebind the `W(...)`
# helper defined at the top of the notebook to a float.

# Reference log-pressures at this Ye. `targets` is indexed [i_rho, i_temp].
# Flattening its transpose orders rows temperature-major
# (row = i_temp * n_logrho + i_rho), matching the network inputs built below.
targets = np.swapaxes(logpress[ye_index], 0, 1)
assert targets.shape == (n_logrho, n_logtemp), targets.shape
target_values = targets.T.reshape(-1)


def _c2p_network_inputs(lorentz_W):
    """Return normalized network inputs [log D, log S, log tau, Ye] for one Lorentz factor.

    Conserved variables are evaluated on the full (temperature, density) grid of
    the selected Ye slice. Rows are ordered temperature-major, like `target_values`.
    """
    v = np.sqrt((lorentz_W ** 2 - 1) / lorentz_W ** 2)
    rho, eps, p = 10 ** logrho, 10 ** logeps, 10 ** logp
    h = 1 + eps + p / rho
    D = rho * lorentz_W
    S = rho * h * lorentz_W ** 2 * v
    tau = rho * h * lorentz_W ** 2 - p - D
    assert S.shape == (len(logtemp), len(logrho)), S.shape
    columns = [
        np.log10(np.broadcast_to(D, S.shape)),  # D depends on density only
        np.log10(S),
        np.log10(tau),
        np.full(S.shape, ye_value),
    ]
    features = np.stack(columns, axis=-1).reshape(-1, 4).astype(np.float32)
    return (features - train_mean_in) / train_std_in


# Inputs depend only on W, not on the model, so build them once per column.
network_inputs = [_c2p_network_inputs(lorentz_W) for lorentz_W in W_values]

# Output buffers and CUDA stream belonging to each TensorRT entry of `models`.
trt_io = {
    "TensorRT FP32": (outputs_fp32, stream_fp32),
    "TensorRT FP16": (outputs_fp16, stream_fp16),
}

norm_function = l1_norm
all_delta_vals = []
for row, (model_name, pytorch_model, engine, context, inputs, bindings) in enumerate(models):
    for col, lorentz_W in enumerate(W_values):
        input_values = network_inputs[col]

        if model_name == "PyTorch":
            with torch.no_grad():
                # Copy so the cached inputs cannot be modified by in-place ops in the model.
                input_values_torch = torch.from_numpy(input_values.copy()).float().to(device)
                predictions = pytorch_model(input_values_torch).cpu().numpy()
        else:
            trt_outputs, stream = trt_io[model_name]
            predictions = process_batch(context, input_values, inputs, trt_outputs, bindings, stream)
            predictions = predictions.reshape(-1, 1)

        # Undo the output standardization and keep the first output column.
        predictions = predictions * train_std_out + train_mean_out
        sliced_predictions = predictions[:, 0]

        # Error from `norm_function`, divided by the reference value, on the (T, rho) grid.
        delta_vals = norm_function(target_values, sliced_predictions, reduction=False)
        delta_vals = delta_vals.reshape((n_logtemp, n_logrho)) / target_values.reshape((n_logtemp, n_logrho))
        all_delta_vals.append(delta_vals)

# Shared colour range so panels are directly comparable.
global_vmin = min(np.min(dv) for dv in all_delta_vals)
global_vmax = max(np.max(dv) for dv in all_delta_vals)

# Tick positions/labels are the same for every panel. logtemp is taken to be
# log10(T in MeV); dividing T by k_B [MeV/K] gives the kelvin values shown.
xt = [0, n_logrho // 2, n_logrho - 1]
xl = np.round([logrho[0], logrho[n_logrho // 2], logrho[-1]], 2)
yt = [0, n_logtemp // 2, n_logtemp - 1]
yl_K = np.round(np.log10(10 ** np.array([logtemp[0], logtemp[n_logtemp // 2], logtemp[-1]]) / k_B), 2)
bottom_row = len(models) - 1

for row in range(len(models)):
    for col, lorentz_W in enumerate(W_values):
        ax = axs[row, col]
        delta_vals = all_delta_vals[row * len(W_values) + col]
        im = ax.imshow(delta_vals, origin="lower", norm=LogNorm(vmin=global_vmin, vmax=global_vmax))

        # x axis label/ticks only on the bottom row.
        if row == bottom_row:
            ax.set_xlabel(r"$\log \rho [g/cm^3]$", fontsize=fs-6)
            ax.set_xticks(xt)
            ax.set_xticklabels(xl)
        else:
            ax.set_xticks([])

        # y axis label/ticks only in the first column.
        if col == 0:
            ax.set_ylabel(r"$\log T [K]$", fontsize=fs-6)
            ax.set_yticks(yt)
            ax.set_yticklabels(yl_K)
        else:
            ax.set_yticks([])

        # Column titles only on the top row.
        if row == 0:
            ax.set_title(f'W = {lorentz_W}', fontsize=fs-6)

# Row captions to the left of each model's row of panels (figure coordinates).
row_captions = zip(('PyTorch Model', 'TensorRT FP32', 'TensorRT FP16'), (0.81, 0.52, 0.22))
for caption, y_pos in row_captions:
    fig.text(0.03, y_pos, caption, rotation=90, va='center', fontsize=fs-4)

# Dedicated colorbar axes on the right edge, shared by all panels.
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cbar = fig.colorbar(im, cax=cbar_ax, label='Relative Error')
cbar.ax.tick_params(labelsize=fs-4)
cbar.ax.set_ylabel('Relative Error', fontsize=fs-2)  # re-set to apply the larger font

# Release the device allocations made by allocate_buffers for both engines,
# in the same order as they are listed here.
for buffer_group in (inputs_fp32, outputs_fp32, inputs_fp16, outputs_fp16):
    for buffer in buffer_group:
        buffer['device'].free()

# Export the comparison figure in vector (PDF/SVG) and raster (PNG) form, then
# display it. The target directory is created first because savefig does not
# create missing directories. `fig` is used explicitly rather than relying on
# pyplot's current figure.
os.makedirs('../images', exist_ok=True)
fig.savefig('../images/plot.pdf')
fig.savefig('../images/plot.png', dpi=1200)
fig.savefig('../images/plot.svg', dpi=1200)
plt.show()