# Parareal Iterations

# In [1]:
from generate_data.initial_conditions import initial_condition_gaussian
from notebook_workflow.utils import get_velocity_model, pseudo_spectral_solutions
import sys
from models.model_end_to_end import get_model
sys.path.append("..")
from generate_data.utils_wave_propagate import one_iteration_pseudo_spectral_tensor
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from parareal.param_settings import get_paths, get_params


def get_parareal_solutions(
        vel_data_path = "../data/velocity_profiles/crops_bp_m_200_256.npz",
        res = 256,
        res_padded = 256,
        n_it = 10
):
    '''
    Set up a Parareal experiment and compute the pseudo-spectral reference solution.

    Parameters
    ----------
    vel_data_path : (str) path of the velocity profiles, passed to get_velocity_model
    res : (int) resolution of the Gaussian initial condition
    res_padded : (int) padded resolution, passed to initial_condition_gaussian
    n_it : (int) number of iterations, passed to pseudo_spectral_solutions

    Returns
    -------
    ps_sol_tensor : result of pseudo_spectral_solutions for the Gaussian initial condition
    '''
    param_dict = get_params("0")
    # NOTE(review): the end-to-end model is loaded but not used yet in this function;
    # presumably it is meant to be passed to parareal_scheme — confirm.
    model = get_model(
        param_dict = param_dict,
        res_scaler = 2,
        model_res = 256,
    )
    vel = get_velocity_model(vel_data_path)

    # computing initial condition using gaussian pulse (switch to pytorch tensor if needed)
    u, ut = initial_condition_gaussian(
        vel,
        resolution=res,
        boundary_condition="periodic",
        mode="physical_components",
        optimization="none",
        res_padded=res_padded,  # previously passed `res`, silently ignoring this argument
    )

    ps_sol_tensor = pseudo_spectral_solutions(u, ut, vel, n_it,
                                              param_dict["f_delta_x"],
                                              param_dict["f_delta_t"],
                                              param_dict["delta_t_star"])
    # previously computed and discarded; return it so callers can use it
    return ps_sol_tensor

if __name__ == "__main__":
    # Only run the experiment when executed directly (script / notebook kernel),
    # not when this module is imported: it calls get_model and reads data files.
    get_parareal_solutions()


def parareal_scheme(
        model,
        u_0,
        n_parareal,
        n_snapshots
):

    '''
    Run the Parareal scheme with `model` as coarse propagator and the correction
    terms of get_parareal_terms (fine solver minus model) as fine correction.

    Parameters
    ----------
    model : (pytorch.Model) end-to-end model to advance a wave front
    u_0 : (pytorch tensor) input wave field of shape b x c x w x h; the last channel is
          the velocity profile, the first c - 1 channels are the wave components
    n_parareal : (int) number of parareal iterations
    n_snapshots : (int) number of time steps with length dt_star (including the initial one)

    Returns
    -------
    big_tensor : (pytorch tensor) n_snapshots x b x (c - 1) x w x h, the final Parareal
                 iterate (just the coarse initial guess if n_parareal == 0).
                 After at least one Parareal iteration the result is detached.

    Raises
    ------
    ValueError : if n_snapshots < 1
    '''
    if n_snapshots < 1:
        raise ValueError("n_snapshots must be >= 1")

    batch_size, channel, width, height = u_0.shape
    n_wave = channel - 1  # wave channels; the last channel holds the velocity

    # move model and input once, so the initial guess and the corrections run on the same device
    model.to(device)
    u_0 = u_0.to(device)
    vel = u_0[:, n_wave:].clone()       # b x 1 x w x h
    u_start = u_0[:, :n_wave].clone()   # b x (c - 1) x w x h

    big_tensor = torch.zeros([n_snapshots, batch_size, n_wave, width, height])
    big_tensor[0] = u_start  # set before the initial guess so it is valid even if n_parareal == 0

    # initial guess, first iteration without parareal (coarse model only)
    u_n = u_0.clone()
    for n in range(n_snapshots - 1):
        u_n1 = model(u_n)
        big_tensor[n + 1] = u_n1
        u_n = torch.cat((u_n1, vel), dim=1)

    # parareal iterations: k = 1, ..., n_parareal
    for k in range(1, n_parareal + 1):
        print(k)

        # parareal_terms[n] = fine(U_n^k) - model(U_n^k); n_snapshots x b x c x w x h
        parareal_terms = get_parareal_terms(model, big_tensor.detach().to(device), n_snapshots, vel.detach())
        new_big_tensor = torch.zeros([n_snapshots, batch_size, n_wave, width, height])
        new_big_tensor[0] = u_start

        # U_{n+1}^{k+1} = model(U_n^{k+1}) + fine(U_n^k) - model(U_n^k)
        for n in range(n_snapshots - 1):
            u_n_k1 = torch.cat((new_big_tensor[n].to(device), vel), dim=1)
            new_big_tensor[n + 1] = model(u_n_k1) + parareal_terms[n].to(device)

        big_tensor = new_big_tensor.clone().detach()

    return big_tensor


def get_parareal_terms(
        model,
        big_pseudo_tensor,
        n_snapshots,
        vel
):
    '''
    Compute the Parareal correction term for every snapshot of the previous iterate.

    Parameters
    ----------
    model : (pytorch.Model) end-to-end model to advance a wave front
    big_pseudo_tensor : (pytorch tensor) tensor containing previous solution (high resolution due to pseudo-spectral cropping)
    n_snapshots : (int) number of iterations (number of iterations with length dt_star)
    vel : (pytorch tensor) velocity profile, concatenated to each snapshot along dim 1

    Returns
    -------
    parareal_terms : (pytorch tensor) same shape as big_pseudo_tensor, allocated with
                     torch.zeros (default device/dtype); entry s is compute_parareal_term
                     applied to snapshot s. The snapshots are independent, so these terms
                     could be computed in parallel.
    '''
    # remember the caller's mode so it can be restored exactly (previously train() was forced)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            parareal_terms = torch.zeros(big_pseudo_tensor.shape)
            for s in range(n_snapshots):
                u_s = torch.cat([big_pseudo_tensor[s], vel], dim=1)
                parareal_terms[s] = compute_parareal_term(model, u_s)
    finally:
        # restore mode even if the fine solver or the model raises
        model.train(was_training)
    return parareal_terms


def compute_parareal_term(
        model,
        u_n_k
):
    '''
    Difference between one fine-solver step and one model step for the same input.

    Parameters
    ----------
    model : (pytorch.Model) end-to-end model used as the cheap propagator
    u_n_k : (pytorch tensor) current wave field (as built by get_parareal_terms:
            a snapshot concatenated with the velocity channel)

    Returns
    -------
    (pytorch tensor) one_iteration_pseudo_spectral_tensor(u_n_k) - model(u_n_k),
    with both operands moved to `device` before subtracting
    '''
    # Fine step: one pseudo-spectral iteration
    # (alternative kept from the original: one_iteration_velocity_verlet(u_n_k)).
    fine_step = one_iteration_pseudo_spectral_tensor(u_n_k)
    # Coarse step: one application of the learned model
    # (alternative kept from the original: procrustes_optimization(model(u_n_k), fine_step)).
    coarse_step = model(u_n_k)

    fine_on_device = fine_step.to(device)
    coarse_on_device = coarse_step.to(device)
    return fine_on_device - coarse_on_device

# Recorded output of cell [1]:
# FileNotFoundError: [Errno 2] No such file or directory: '../results/run_3/good/saved_model_Interpolation_UNet3_AdamW_SmoothL1Loss_2_128_False_15.pt'