In [62]:
#!git clone https://github.com/valdemarskou/PINNs

In [63]:
# @title imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d



import cudatorch_mpfd_solver as torchsolver

from training_data import generateData
import pandas as pd
from torch.utils.data import Dataset, DataLoader, Subset
import random

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import ast
import re

In [64]:
# @title clean and convert data
def clean_and_convert_t(s):
    """
    Parse a CSV cell holding the text form of a tensor, e.g. ``tensor([0., 1.])``,
    into a float32 ``torch.Tensor``.

    Surrounding whitespace and any embedded ``\\r`` / ``\\n`` characters are
    removed before matching, so wrapped cells parse the same as single-line ones.

    Raises:
        ValueError: if no ``tensor([...])`` pattern is found in the cleaned text.
    """
    flattened = s.strip().replace("\r", "").replace("\n", "")

    # Capture only the bracketed list between "tensor(" and the closing ")".
    found = re.search(r"tensor\((\[.*?\])\)", flattened)
    if found is None:
        raise ValueError(f"Unexpected format for t: {flattened}")

    values = ast.literal_eval(found.group(1))
    return torch.tensor(values, dtype=torch.float32)




def clean_and_convert_output(s):
    """
    Parse a CSV cell holding the text form of a list of ``array([...])`` entries
    into a list of float32 tensors, one tensor per ``array(...)`` occurrence.

    Surrounding whitespace and embedded ``\\r`` / ``\\n`` characters are removed
    first. The pattern captures up to the first ``]`` after each ``array(``.

    Raises:
        ValueError: if the cleaned text contains no ``array([...])`` group.
    """
    flattened = s.strip().replace("\r", "").replace("\n", "")

    array_literals = re.findall(r"array\(\s*(\[.*?\])", flattened)
    if not array_literals:
        raise ValueError(f"Unexpected format for output: {flattened}")

    states = []
    for literal in array_literals:
        states.append(torch.tensor(ast.literal_eval(literal), dtype=torch.float32))
    return states

In [65]:
# @title interpolation function (with device conversion)
def interpolate_at_time(s, t, v, device=None):
    """
    Linearly interpolate the trajectory ``v`` (a sequence of states) at time ``s``
    using the sample times ``t``.

    Args:
        s: Query time (Python number or 0-d tensor).
        t: Sample times, indexable and comparable with ``s``; expected to be
            sorted in increasing order (the segment search relies on it).
        v: Sequence of states, one per entry of ``t``. Interior queries call
            ``torch.lerp`` and therefore need tensor states.
        device: Optional device to move the returned state to. When ``None``
            (the default) the states are used on whatever device they already
            live on, so the function no longer depends on a notebook-global
            ``device`` being defined by a later cell.

    Returns:
        ``v[0]`` for ``s <= t[0]``, ``v[-1]`` for ``s >= t[-1]``, otherwise the
        linear blend of the two states bracketing ``s``.

    Raises:
        ValueError: if no bracketing segment is found for ``s``.
    """
    def _place(state):
        # Only tensors can be moved; anything else is handed back untouched.
        if device is not None and isinstance(state, torch.Tensor):
            return state.to(device)
        return state

    # Clamp queries outside the sampled range to the first / last state.
    if s <= t[0]:
        return _place(v[0])
    if s >= t[-1]:
        return _place(v[-1])

    # Find the first segment with t[i] <= s <= t[i+1].
    for i in range(len(t) - 1):
        if t[i] <= s <= t[i + 1]:
            start, end = _place(v[i]), _place(v[i + 1])
            # alpha is 0 at t[i] and 1 at t[i+1].
            alpha = (s - t[i]) / (t[i + 1] - t[i])
            if isinstance(alpha, torch.Tensor) and isinstance(start, torch.Tensor):
                # A tensor weight must match the states' device and dtype.
                alpha = alpha.to(device=start.device, dtype=start.dtype)
            return torch.lerp(start, end, alpha)
    raise ValueError("The timepoint s is not within the range of t.")

In [66]:
# @title load dataset
class PDETrajectoryDataset(Dataset):
    """
    Dataset of PDE trajectories read from a CSV file with a ``t`` column and an
    ``output`` column, both stored as text.

    The text cells are parsed once at construction time with
    ``clean_and_convert_t`` and ``clean_and_convert_output``. Each item is a
    ``(t, trajectory)`` pair where ``trajectory`` is a list of state tensors.

    Args:
        csv_file: Path (or buffer) handed to ``pandas.read_csv``.
        device: Optional device items are moved to in ``__getitem__``. When
            omitted, the notebook-level ``device`` variable is looked up at
            access time, which keeps existing ``PDETrajectoryDataset(csv_file)``
            callers working unchanged.
    """

    def __init__(self, csv_file, device=None):
        self.device = device
        self.df = pd.read_csv(csv_file)
        # Convert the stored strings to tensors / lists of tensors.
        self.df["t"] = self.df["t"].apply(clean_and_convert_t)
        self.df["output"] = self.df["output"].apply(clean_and_convert_output)

    def __len__(self):
        return len(self.df)

    def _target_device(self):
        """Return the explicit device if one was given, else the notebook global."""
        explicit = getattr(self, "device", None)
        if explicit is not None:
            return explicit
        # Falls back to the module-level `device`; it must exist by the time
        # items are first accessed.
        return device

    def __getitem__(self, idx):
        target = self._target_device()
        row = self.df.iloc[idx]
        # 'output' holds the list of state tensors making up the trajectory.
        trajectory = [state.to(target) for state in row["output"]]
        t = row["t"]
        # The time points are moved as well when they are stored as a tensor.
        if isinstance(t, torch.Tensor):
            t = t.to(target)
        return t, trajectory

In [67]:
# @title define cnn
class CorrectionCNN(nn.Module):
    """
    Four-layer 1-D convolutional network mapping a ``(batch, length)`` input to
    an output of the same shape.

    Channel widths are 1 -> 8 -> 16 -> 12 -> 1. Every convolution uses
    ``kernel_size=3`` with ``padding=1``; ReLU follows the first three layers
    and the last layer is linear.
    """

    def __init__(self):
        super().__init__()
        shared = dict(kernel_size=3, padding=1)
        # Layer names and creation order are kept stable so saved state_dicts
        # continue to load.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, **shared)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, **shared)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=12, **shared)
        self.conv4 = nn.Conv1d(in_channels=12, out_channels=1, **shared)
        self.relu = nn.ReLU()

    def forward(self, x):
        # (batch, length) -> (batch, 1, length): Conv1d wants a channel axis.
        features = x.unsqueeze(1)
        for hidden_conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(hidden_conv(features))
        # Final layer has a single channel, which is squeezed away again.
        return self.conv4(features).squeeze(1)

# Utility to count parameters (for verification)
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters
    (those with ``requires_grad`` set)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [68]:
# @title SOL hybrid solver

def SOL_hybridSolver(tN, psiInitial, Cfun, Kfun, thetafun, sink, correction_net):
    """
    Step the solver forward and apply ``correction_net`` after every step.

    Each iteration advances the previous *corrected* state by one step with
    ``torchsolver.dirichletOneStepModelRun``, adds the network's correction and
    stores the sum, so corrections feed into all later steps.

    Uses a fixed ``dt = 120.`` and ``zN = 40.`` for ``torchsolver.setup`` and
    relies on the notebook-level ``device``.

    Returns:
        ``(psiList, t)``: the list of states (initial state first) and the time
        grid returned by ``torchsolver.setup``.
    """
    dt = 120.
    zN = 40.

    # Get the traditional solver information ready.
    _z, t, dts, dz, n, nt, zN, psi, psiB, psiT, pars = torchsolver.setup(
        dt, tN, zN, psiInitial, torchsolver.havercampSetpars
    )
    # Make sure the state and boundary tensors sit on the working device.
    psi, psiB, psiT = psi.to(device), psiB.to(device), psiT.to(device)

    psiList = [psi]
    for j in range(1, nt):
        raw_state = torchsolver.dirichletOneStepModelRun(
            dts[j-1], dz, n, psiList[j-1], psiB[j-1], psiT[j-1],
            pars, Cfun, Kfun, thetafun, sink,
        )
        raw_state = raw_state.to(device)
        # The network expects a leading batch axis; strip it off again afterwards.
        correction = correction_net(raw_state.unsqueeze(0))
        psiList.append(raw_state + correction.squeeze(0))
    return psiList, t


In [69]:
# @title PRE hybrid solver
def PRE_hybridSolver(tN, psiInitial, Cfun, Kfun, thetafun, sink, correction_net):
    """
    Run the full uncorrected solve first, then apply ``correction_net`` to every
    state except the first one.

    Unlike ``SOL_hybridSolver``, corrections are applied after the whole
    trajectory has been computed by ``torchsolver.fullModelRun`` and therefore
    do not influence later solver steps.

    Uses a fixed ``dt = 120.`` and ``zN = 40.`` for ``torchsolver.setup`` and
    relies on the notebook-level ``device`` for the returned corrected states.

    Returns:
        ``(psiList, t)``: the state list with entries 1.. replaced by their
        corrected versions, and the time grid returned by ``torchsolver.setup``.
    """
    dt = 120.
    zN = 40.
    flag = 0

    # Get the traditional solver information ready. The setup outputs are passed
    # to the solver as returned (no device move here).
    z, t, dts, dz, n, nt, zN, psi, psiB, psiT, pars = torchsolver.setup(dt, tN, zN, psiInitial, torchsolver.havercampSetpars)

    psiList = torchsolver.fullModelRun(dt, dts, dz, n, nt, psi, psiB, psiT, pars, Cfun, Kfun, thetafun, flag, sink)

    # The network input has to live on the same device as the network weights;
    # previously the solver state was passed through unmoved, which fails when
    # the two differ. For callables without parameters the state is passed as is.
    first_param = None
    if hasattr(correction_net, "parameters"):
        first_param = next(iter(correction_net.parameters()), None)

    corrected_states = []
    for h in psiList[1:]:
        net_input = h if first_param is None else h.to(first_param.device)
        correction = correction_net(net_input.unsqueeze(0))
        # Both summands end up on `device`; the batch axis is dropped last.
        corrected_states.append((h.to(device) + correction.to(device)).squeeze(0))
    psiList[1:] = corrected_states
    return psiList, t


In [70]:
# @title NN training procedure
#%% Cell: NN training procedure
def train_hybrid_solver(hybridSolver, correction_net, data_loader, optimizer, num_epochs=10, checkpoint_interval=1):
    """
    Train ``correction_net`` by running ``hybridSolver`` on each trajectory from
    ``data_loader`` and minimising the mean-squared error against the stored
    ground truth, averaged over the solver's time steps.

    Relies on the notebook-level names ``device``, ``torchsolver`` and
    ``interpolate_at_time``.

    Args:
        hybridSolver: Callable with the signature of ``SOL_hybridSolver`` /
            ``PRE_hybridSolver``; must return ``(trajectory, times)``.
        correction_net: Network being trained; put into train mode here.
        data_loader: Iterable yielding one ``(t, output)`` pair per batch, where
            ``output`` is a list of state tensors.
        optimizer: Optimizer over ``correction_net``'s parameters.
        num_epochs: Number of passes over ``data_loader``.
        checkpoint_interval: Save the weights every this many epochs to
            ``correction_net_epoch_<n>.pth``. ``0``, a negative value or
            ``None`` disables periodic checkpoints (previously ``0`` raised
            ``ZeroDivisionError``).

    Side effects:
        Prints the summed loss per epoch and writes checkpoint files plus
        ``correction_net_final.pth`` relative to the current working directory.
    """
    correction_net.train()
    # Guard the modulo below against a zero / missing interval.
    save_checkpoints = bool(checkpoint_interval) and checkpoint_interval > 0

    for epoch in range(num_epochs):
        total_loss = 0.0

        for batch in data_loader:
            # Each batch is assumed to be (t, output).
            t_batch, output_batch = batch

            t_instance = t_batch.squeeze(0)
            # Drop the first and last entry of every state (boundary values) and
            # make sure each tensor is on the working device.
            trajectory_gt = [state[1:-1].to(device) for state in output_batch]

            # Extract solver parameters.
            tN = t_instance[-1]
            # NOTE(review): this takes the *second* stored state as the initial
            # condition, while the loss compares the solver's first state with
            # trajectory_gt[0]. Confirm index 1 (not 0) is intended.
            psiInitial = output_batch[1]

            # --- Call the hybrid solver ---
            corrected_traj, solver_t = hybridSolver(
                tN,
                psiInitial,
                torchsolver.havercampCfun,
                torchsolver.havercampKfun,
                torchsolver.havercampthetafun,
                torchsolver.zeroFun,
                correction_net,
            )

            # --- Compute the loss ---
            loss = 0.0
            num_steps = len(solver_t)
            for j, s in enumerate(solver_t):
                # Ground truth at the solver's time s, interpolated from the stored samples.
                gt_state = interpolate_at_time(s, t_instance, trajectory_gt)
                loss += torch.mean((corrected_traj[j] - gt_state) ** 2)
            loss = loss / num_steps  # average over all time steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # The reported value is the sum of per-trajectory losses for the epoch.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

        # Save checkpoint.
        if save_checkpoints and (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = f'correction_net_epoch_{epoch+1}.pth'
            torch.save(correction_net.state_dict(), checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

    final_path = 'correction_net_final.pth'
    torch.save(correction_net.state_dict(), final_path)
    print(f"Saved final model weights as {final_path}")


In [71]:
# @title Data loader
def custom_collate_fn(batch):
    """
    Collate ``(t, trajectory)`` samples.

    A batch of exactly one sample is returned as that sample itself (not wrapped
    in a list). Larger batches return ``(ts, trajectories)`` where ``ts`` is the
    time tensors stacked along a new leading axis (they must share a shape) and
    ``trajectories`` is a list with one trajectory per sample.
    """
    if len(batch) != 1:
        time_tensors, trajectories = zip(*batch)
        return torch.stack(time_tensors, 0), list(trajectories)
    return batch[0]

# Location of the stored trajectories; replace with your CSV file path.
csv_file = "high_fidelity_training_data.csv"
dataset = PDETrajectoryDataset(csv_file)

# One trajectory per batch, visited in random order each epoch.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=custom_collate_fn,
    shuffle=True,
)

In [72]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the correction network and place it on the working device.
correction_net = CorrectionCNN().to(device)
# optional: load weights
#correction_net.load_state_dict(torch.load("correction_net_single_element_dataset.pth"))

optimizer = optim.Adam(correction_net.parameters(), lr=0.01)
#train_hybrid_solver(PRE_hybridSolver,correction_net, data_loader, optimizer, num_epochs=20, checkpoint_interval=1)

  t = torch.hstack([t, torch.tensor(tN, device=device)])
  psi = torch.tensor(psiInitial[1:-1], dtype=torch.float32, device=device)


Epoch 1/20, Loss: 764.9514
Saved checkpoint: correction_net_epoch_1.pth
Epoch 2/20, Loss: 753.0653
Saved checkpoint: correction_net_epoch_2.pth
Epoch 3/20, Loss: 759.0129
Saved checkpoint: correction_net_epoch_3.pth
Epoch 4/20, Loss: 756.1947
Saved checkpoint: correction_net_epoch_4.pth
Epoch 5/20, Loss: 754.3912
Saved checkpoint: correction_net_epoch_5.pth
Epoch 6/20, Loss: 757.3426
Saved checkpoint: correction_net_epoch_6.pth
Epoch 7/20, Loss: 755.8291
Saved checkpoint: correction_net_epoch_7.pth
Epoch 8/20, Loss: 754.6008
Saved checkpoint: correction_net_epoch_8.pth


KeyboardInterrupt: 

In [None]:
# @title Single element training


'''
correction_net.load_state_dict(torch.load("correction_net_single_element_dataset.pth"))

single_element_dataset = Subset(dataset, [0])
single_data_loader = DataLoader(single_element_dataset, batch_size=1,collate_fn=custom_collate_fn, shuffle=False)
#train_hybrid_solver(PRE_hybridSolver, correction_net, single_data_loader, optimizer, num_epochs=20, checkpoint_interval=1)
t_batch, output_batch=single_element_dataset[0]
t_instance = t_batch.squeeze(0)
trajectory_gt = output_batch
trajectory_gt = [state[1:-1].to(device) for state in trajectory_gt]


tN = t_instance[-1]
psiInitial = output_batch[1]


corrected_traj, solver_t = PRE_hybridSolver(tN,psiInitial,torchsolver.havercampCfun,torchsolver.havercampKfun,torchsolver.havercampthetafun,torchsolver.zeroFun,correction_net)

dt = 120.
zN = 40.
z,t,dts,dz,n,nt,zN,psi,psiB,psiT,pars = torchsolver.setup(dt,tN,zN,psiInitial,torchsolver.havercampSetpars)

psiList = torchsolver.fullModelRun(dt,dts,dz,n,nt,psi,psiB,psiT,pars, torchsolver.havercampCfun,torchsolver.havercampKfun,torchsolver.havercampthetafun,0,torchsolver.zeroFun)

print(torch.mean(torch.abs(corrected_traj[-1]-trajectory_gt[-1])))
print(torch.mean(torch.abs(psiList[-1]-trajectory_gt[-1])))


