In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import meshgrid
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.autograd import gradcheck
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import shift
from math import sqrt
from scipy.sparse import diags
from scipy.linalg import cholesky

import os
inpath = "./data/Wildfire/Poly/input/"   # Path for the input data
impath = "./data/Wildfire/Poly/mac/"     # where arrays / model weights are written
immpath = "./plots/Wildfire/Poly/mac/"   # where figures are written
# Create both output directories up front (no-op if they already exist).
for _out_dir in (impath, immpath):
    os.makedirs(_out_dir, exist_ok=True)

# 1D wildland fire example

In [None]:
def _load_input(file_name):
    """Load ``inpath + file_name`` with numpy.

    NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
    execute arbitrary code; only use it on trusted data files.
    """
    return np.load(inpath + file_name, allow_pickle=True)


Q_wf = _load_input('SnapShotMatrix558.49.npy')
t = _load_input('Time.npy')
x_grid = _load_input('1D_Grid.npy')
x = x_grid[0]
# Keep the first len(x) rows of the snapshot matrix (presumably the first stacked
# state variable on the x grid — confirm against the data generator).
T = Q_wf[:len(x), :]
seed = 133

In [None]:
np.random.seed(seed=seed)     # numpy global RNG
torch.manual_seed(seed=seed)  # torch RNG (drives nn.Linear weight init below)

In [None]:
# Scale the snapshot matrix so its largest entry is 1.
T_peak = T.max()
Q = torch.tensor(T / T_peak)

In [None]:
Nx, Nt = len(x), len(t)
# np.meshgrid(x, t) returns (len(t), len(x)) arrays; the .T below gives the
# (Nx, Nt) layout used for Q.
xx, tt = np.meshgrid(x, t)

snapshot_fig, snapshot_ax = plt.subplots()
snapshot_mesh = snapshot_ax.pcolormesh(xx.T, tt.T, Q, cmap='hot')
snapshot_fig.colorbar(snapshot_mesh, ax=snapshot_ax)
snapshot_ax.set_xlabel("x")
snapshot_ax.set_ylabel("t")
snapshot_ax.set_title('Snapshot Matrix Q')
plt.show()

## Define inputs

In [None]:
# Row k = i*Nt + j holds (x_i, t_j), so an (Nx*Nt, 1) network output can be
# reshaped with .view(Nx, Nt) to line up with Q.
x_column = np.repeat(x, Nt)
t_column = np.tile(t, Nx)
inputs = np.column_stack((x_column, t_column))
inputs_tensor = torch.tensor(inputs, dtype=torch.float32)

## Define a model

In [None]:
class NuclearNormAutograd(torch.autograd.Function):
    """Nuclear norm ``||A||_* = sum of singular values`` with a hand-written backward.

    The backward pass returns the subgradient ``U_r @ Vh_r``, where ``U_r`` / ``Vh_r``
    are the left / right singular vectors whose singular values are strictly
    positive, scaled by the upstream gradient. The input is expected to be a
    single 2-D matrix (the backward indexes columns/rows as a 2-D matrix).
    """

    @staticmethod
    def forward(ctx, input_matrix):
        ctx.save_for_backward(input_matrix)
        return torch.linalg.matrix_norm(input_matrix, ord="nuc")

    @staticmethod
    def backward(ctx, grad_output):
        input_matrix, = ctx.saved_tensors
        # The reduced SVD is enough: at most min(m, n) singular values can be
        # nonzero, so the full (m x m) / (n x n) factors are never needed.
        u, s, vh = torch.linalg.svd(input_matrix, full_matrices=False)
        # Singular values are returned in descending order, so the positive ones
        # form a prefix of length `rank`.
        rank = int((s > 0).sum())
        grad_input = u[:, :rank] @ vh[:rank, :]
        # grad_output is the (scalar) upstream gradient of the norm.
        return grad_input * grad_output

In [None]:
class ShapeShiftNet(nn.Module):
    """Three-frame model: frame k is an MLP ``f^k(x, t)`` evaluated at ``x + shift_k(t)``.

    Args:
        p_init_coeffs1, p_init_coeffs2, p_init_coeffs3: initial coefficients of
            each frame's shift polynomial in ``t``, highest degree first (the
            ``np.polyval`` convention): ``[a, b]`` gives ``shift(t) = a*t + b``.

    Parameter names (``alphas{k}.{i}``, ``f{k}_fc{j}.weight`` / ``.bias``) form the
    saved ``state_dict`` format and must stay stable. Layers are also created in
    the same order as before, so a given torch seed yields the same initial weights.
    """

    # Number of frames, i.e. number of (shift polynomial, subnetwork) pairs.
    _N_FRAMES = 3
    # (in_features, out_features) of the Linear layers of every subnetwork.
    _LAYER_SIZES = ((2, 5), (5, 10), (10, 5), (5, 1))

    def __init__(self, p_init_coeffs1, p_init_coeffs2, p_init_coeffs3):
        super().__init__()

        init_coeffs = (p_init_coeffs1, p_init_coeffs2, p_init_coeffs3)
        for k, coeffs in enumerate(init_coeffs, start=1):
            setattr(self, f"alphas{k}", nn.ParameterList(
                [nn.Parameter(torch.tensor([coeff], dtype=torch.float32), requires_grad=True)
                 for coeff in coeffs]
            ))

        self.elu = nn.ELU()

        # Subnetworks f^1, f^2, f^3 (registered as f{k}_fc{j}).
        for k in range(1, self._N_FRAMES + 1):
            for j, (n_in, n_out) in enumerate(self._LAYER_SIZES, start=1):
                setattr(self, f"f{k}_fc{j}", nn.Linear(n_in, n_out))

    @staticmethod
    def _shift(alphas, t):
        """Evaluate the polynomial with coefficients ``alphas`` (highest degree first) at ``t``."""
        degree = len(alphas) - 1
        # Exponent degree - i matches np.polyval, which later cells use to rebuild
        # the shifts from the saved coefficients.
        return sum(coeff * t ** (degree - i) for i, coeff in enumerate(alphas))

    def _mlp(self, k, x, t):
        """Run subnetwork f^k on ``cat(x, t)``: ELU after every layer except the last."""
        hidden = torch.cat((x, t), dim=1)
        n_layers = len(self._LAYER_SIZES)
        for j in range(1, n_layers):
            hidden = self.elu(getattr(self, f"f{k}_fc{j}")(hidden))
        return getattr(self, f"f{k}_fc{n_layers}")(hidden)

    def forward(self, x, t):
        """Evaluate all frames.

        ``x`` and ``t`` must each have one column (they are concatenated along
        dim 1 into the 2 input features).

        Returns:
            ``(f1, f2, f3, f1_without_shift, f2_without_shift, f3_without_shift)``
            where ``fk = f^k(x + shift_k(t), t)`` and
            ``fk_without_shift = f^k(x, t)``.
        """
        shifted, unshifted = [], []
        for k in range(1, self._N_FRAMES + 1):
            shift = self._shift(getattr(self, f"alphas{k}"), t)
            shifted.append(self._mlp(k, x + shift, t))
            unshifted.append(self._mlp(k, x, t))
        return (*shifted, *unshifted)


In [None]:
def save_fig(filepath, figure=None, **kwargs):
    """Write ``figure`` as a 300-dpi transparent PNG and as a tikzplotlib ``.tex`` file.

    Any extension on ``filepath`` is stripped; ``.png`` and ``.tex`` are appended
    to the remaining stem. When ``figure`` is None the current matplotlib figure
    is used. Extra keyword arguments are forwarded to ``tikzplotlib.save``. The
    ``.tex`` axes are sized by the LaTeX macros figureheight / figurewidth,
    presumably defined by the including document.
    """
    import tikzplotlib
    import os
    import matplotlib.pyplot as plt

    stem, _extension = os.path.splitext(filepath)
    target = plt.gcf() if figure is None else figure

    target.savefig(stem + ".png", dpi=300, transparent=True)
    tikzplotlib.save(
        figure=target,
        filepath=stem + ".tex",
        axis_height='\\figureheight',
        axis_width='\\figurewidth',
        override_externals=True,
        **kwargs
    )

In [None]:
# Initial shift-polynomial coefficients per frame, highest power of t first:
# ShapeShiftNet's shift for [c0, c1] is c0*t + c1.
init_coefficients1, init_coefficients2, init_coefficients3 = (
    [1, -1],
    [0, 0],
    [-1, 1],
)

In [None]:
model = ShapeShiftNet(init_coefficients1, init_coefficients2, init_coefficients3)

# Optionally warm-start from the checkpoint written by the "Saving the results" cell.
# On a fresh checkout that file does not exist yet, so skip loading instead of crashing.
pretrained_load = True
pretrained_path = impath + "Wildfire.pth"
if pretrained_load and os.path.isfile(pretrained_path):
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    state_dict_original = torch.load(pretrained_path, map_location="cpu")
    state_dict_new = model.state_dict()
    # Keep entries whose name AND shape match the current model: load_state_dict
    # raises on shape mismatches even with strict=False (e.g. a checkpoint saved
    # with a different number of shift coefficients).
    compatible = {
        name: param
        for name, param in state_dict_original.items()
        if name in state_dict_new and state_dict_new[name].shape == param.shape
    }
    # To warm-start every subnetwork from one trained frame, remap keys here, e.g.
    # compatible["f1_fc1.weight"] = state_dict_original["f2_fc1.weight"].
    model.load_state_dict(compatible, strict=False)
elif pretrained_load:
    print(f"No checkpoint found at {pretrained_path}; training from scratch.")

optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
def TV(Q, Nx, Nt):
    """Normalised sum of squared forward differences of ``Q`` along axis 1.

    Despite the name, this is not an L1 total-variation term: it sums
    ``(Q[:, j+1] - Q[:, j])**2`` over all entries and divides by ``Nx * Nt``.
    Only the second axis is differenced; ``Nx`` and ``Nt`` are used solely for
    the normalisation.

    Args:
        Q: 2-D tensor.
        Nx, Nt: sizes used in the ``Nx * Nt`` normaliser.

    Returns:
        0-dim tensor (differentiable w.r.t. ``Q``).
    """
    column_steps = torch.diff(Q, dim=1)
    return column_steps.square().sum() / (Nx * Nt)

In [None]:
num_epochs = 20000
lambda_k = 1.5     # weight of the nuclear-norm penalty on the unshifted frames
lambda_TV = 0.5    # weight of the TV() smoothness penalty on the unshifted frames

# The network inputs never change, so slice them once instead of every epoch.
x_NN, t_NN = inputs_tensor[:, 0:1], inputs_tensor[:, 1:2]


def _shift_coeffs(alphas):
    """Snapshot a ParameterList of shift coefficients as a plain 1-D tensor."""
    return torch.tensor([p.item() for p in alphas])


for epoch in range(num_epochs + 1):
    optimizer.zero_grad()
    f1_full, f2_full, f3_full, f1_full_nos, f2_full_nos, f3_full_nos = model(x_NN, t_NN)

    # Reshape every frame to the (Nx, Nt) layout of Q once and reuse it below.
    frames_shifted = [f.view(Nx, Nt) for f in (f1_full, f2_full, f3_full)]
    frames_unshifted = [f.view(Nx, Nt) for f in (f1_full_nos, f2_full_nos, f3_full_nos)]

    # Squared Frobenius norm of the reconstruction residual.
    frobenius_loss = torch.norm(Q - frames_shifted[0] - frames_shifted[1] - frames_shifted[2], 'fro') ** 2
    nuclear_loss = lambda_k * sum(NuclearNormAutograd.apply(q) for q in frames_unshifted)
    TV_loss = lambda_TV * sum(TV(q, Nx, Nt) for q in frames_unshifted)
    total_loss = nuclear_loss + frobenius_loss + TV_loss

    # The graph is rebuilt every epoch, so there is no reason to retain it.
    total_loss.backward()
    optimizer.step()

    if frobenius_loss.item() < 1.0:
        print("Early stopping is triggered")
        break

    if epoch % 100 == 0:
        shift_coeffs1 = _shift_coeffs(model.alphas1)
        shift_coeffs2 = _shift_coeffs(model.alphas2)
        shift_coeffs3 = _shift_coeffs(model.alphas3)
        print(
            f'Epoch {epoch}/{num_epochs}, Frob Loss: {frobenius_loss.item()}, Nuclear Loss: {nuclear_loss.item()}, Total loss: {total_loss.item()},'
            f'Coefficients_1:{shift_coeffs1}, Coefficients_2:{shift_coeffs2}, Coefficients_3:{shift_coeffs3}')

# Shift coefficients after the final optimizer step (these are saved later on).
shift_coeffs1 = _shift_coeffs(model.alphas1)
shift_coeffs2 = _shift_coeffs(model.alphas2)
shift_coeffs3 = _shift_coeffs(model.alphas3)

# The outputs from the loop were computed before its last optimizer step; recompute
# them so the plotted / saved frames match the parameters saved with the model.
with torch.no_grad():
    f1_full, f2_full, f3_full, f1_full_nos, f2_full_nos, f3_full_nos = model(x_NN, t_NN)

In [None]:
# NN reconstruction: sum of the three shifted frames, as an (Nx, Nt) numpy array.
combined = sum((f1_full, f2_full, f3_full))
Q_tilde = combined.view(Nx, Nt).detach().numpy()

In [None]:
# Panels: the NN reconstruction, the shifted frames T^k Q^k, then the unshifted
# frames Q^k. All share the colour range of Q_tilde.
nn_panels = [
    (Q_tilde, r"$\mathbf{\tilde{Q}}$"),
    (f1_full.view(Nx, Nt).detach().numpy(), r"$\mathcal{T}^1\mathbf{Q}^1$"),
    (f2_full.view(Nx, Nt).detach().numpy(), r"$\mathcal{T}^2\mathbf{Q}^2$"),
    (f3_full.view(Nx, Nt).detach().numpy(), r"$\mathcal{T}^3\mathbf{Q}^3$"),
    (f1_full_nos.view(Nx, Nt).detach().numpy(), r"$\mathbf{Q}^1$"),
    (f2_full_nos.view(Nx, Nt).detach().numpy(), r"$\mathbf{Q}^2$"),
    (f3_full_nos.view(Nx, Nt).detach().numpy(), r"$\mathbf{Q}^3$"),
]

fig, axs = plt.subplots(1, len(nn_panels), figsize=(16, 4))
vmin = np.min(Q_tilde)
vmax = np.max(Q_tilde)

for ax, (field, title) in zip(axs, nn_panels):
    mesh = ax.pcolormesh(xx.T, tt.T, field, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    # pcolormesh(xx.T, tt.T, ...) puts x on the horizontal axis and t on the vertical one.
    ax.set_xlabel("x")
    ax.set_ylabel("t")
    ax.set_xticks([])
    ax.set_yticks([])

plt.colorbar(mesh, ax=axs.ravel().tolist(), orientation='vertical')

In [None]:
save_fig(figure=fig, filepath=immpath + "Wildfire_NN")  # writes .png and .tex next to each other

## Saving the results

In [None]:
torch.save(model.state_dict(), impath + 'Wildfire.pth')

# NOTE(review): the frame-2 / frame-3 file names are crossed on purpose(?):
# 'T3Q3.npy' / 'Q3.npy' hold network frame 2 (alphas2) and 'T2Q2.npy' / 'Q2.npy'
# hold network frame 3 (alphas3). The sPOD cells below pair 'T3Q3.npy' with
# np.polyval(shifts2, t) and 'T2Q2.npy' with a zero shift, so they rely on this
# mapping. The alphas3 coefficients are not saved. Confirm this is intended.
_npy_outputs = [
    ('Q.npy', Q),
    ('Q_tilde.npy', Q_tilde),
    ('T1Q1.npy', f1_full.view(Nx, Nt).detach().numpy()),
    ('T3Q3.npy', f2_full.view(Nx, Nt).detach().numpy()),
    ('T2Q2.npy', f3_full.view(Nx, Nt).detach().numpy()),
    ('Q1.npy', f1_full_nos.view(Nx, Nt).detach().numpy()),
    ('Q3.npy', f2_full_nos.view(Nx, Nt).detach().numpy()),
    ('Q2.npy', f3_full_nos.view(Nx, Nt).detach().numpy()),
    ('shifts1.npy', shift_coeffs1.detach().numpy()),
    ('shifts2.npy', shift_coeffs2.detach().numpy()),
]
for _file_name, _array in _npy_outputs:
    np.save(impath + _file_name, _array)

# Apply rsPOD, using the results above as initial guesses

In [None]:
import sys

sys.path.append("./sPOD/lib/")

import numpy as np
from numpy import meshgrid
import matplotlib.pyplot as plt
from sPOD_algo import (
    shifted_POD,
    sPOD_Param,
    give_interpolation_error,
)
from transforms import Transform
from plot_utils import save_fig

In [None]:
# Reload the NN results written above; each variable is named after its file.
(Q, Q_tilde, T1Q1, T2Q2, T3Q3, Q1, Q2, Q3, shifts1, shifts2) = (
    np.load(impath + file_name)
    for file_name in (
        'Q.npy', 'Q_tilde.npy',
        'T1Q1.npy', 'T2Q2.npy', 'T3Q3.npy',
        'Q1.npy', 'Q2.npy', 'Q3.npy',
        'shifts1.npy', 'shifts2.npy',
    )
)

In [None]:
# Relative (Frobenius-norm) error of the NN reconstruction w.r.t. the snapshot matrix.
residual_norm = np.linalg.norm(Q - Q_tilde)
err = residual_norm / np.linalg.norm(Q)
print("NN prediction error: %1.2e " % err)

In [None]:
# Prepare the transformations
L = x[-1]
dx = x[1] - x[0]
# np.polyval takes coefficients highest power first, matching the NN shift
# c0*t + c1. Frame 2 is unshifted; frame 3 uses the coefficients saved as
# 'shifts2.npy', consistent with how T2Q2 / T3Q3 were written above.
s1 = np.polyval(shifts1, t)
s2 = np.zeros_like(s1)
s3 = np.polyval(shifts2, t)

data_shape = [Nx, 1, 1, Nt]
transfos = [
    Transform(data_shape, [L], shifts=frame_shift, dx=[dx], interp_order=5)
    for frame_shift in (s1, s2, s3)
]

interp_errors = [give_interpolation_error(Q, trafo) for trafo in transfos]
interp_err = np.max(interp_errors)
print("interpolation error: %1.2e " % interp_err)

In [None]:
METHOD = "ALM"

# Parameters
myparams = sPOD_Param()
myparams.maxit = 100
# Nx*Nt / (4 * ||Q||_1): looks like the usual robust-PCA choice for the ALM penalty — confirm.
mu0 = Nx * Nt / (4 * np.sum(np.abs(Q)))
lambd0 = 4000  # NOTE(review): not referenced by any visible cell
param_alm = mu0 * 0.01 # adjust for case

In [None]:
# Call the ALM method
frame_ranks = [1, 1, 1]                  # presumably one rank per frame — confirm in sPOD_algo
initial_guesses = [T1Q1, T2Q2, T3Q3]     # NN frames used as the starting point
ret = shifted_POD(Q, transfos, frame_ranks, myparams, METHOD, param_alm, initial_guesses)

In [None]:
sPOD_frames = ret.frames
qtilde = ret.data_approx
rel_err = ret.rel_err_hist

# Each frame's field mapped through its transform, reshaped to data_shape and
# squeezed to a 2-D (Nx, Nt) array for plotting.
qf = []
for trafo, frame in zip(transfos, sPOD_frames):
    lab_field = trafo.apply(frame.build_field())
    qf.append(np.squeeze(np.reshape(lab_field, data_shape)))

In [None]:
# %% 1. visualize your results: sPOD frames
############################################

# Panels: data, sPOD reconstruction, then the three transformed frames,
# all on the colour range of Q.
spod_panels = [
    (Q, r"${\mathbf{Q}}$"),
    (qtilde, r"$\tilde{\mathbf{Q}}$"),
    (qf[0], r"$\mathcal{T}^1\mathbf{Q}^1$"),
    (qf[1], r"$\mathcal{T}^2\mathbf{Q}^2$"),
    (qf[2], r"$\mathcal{T}^3\mathbf{Q}^3$"),
]

fig, axs = plt.subplots(1, 5, figsize=(25, 6))
vmin = np.min(Q)
vmax = np.max(Q)

for ax, (field, title) in zip(axs, spod_panels):
    cax4 = ax.pcolormesh(xx.T, tt.T, field, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("t")
    ax.set_xticks([])
    ax.set_yticks([])

plt.colorbar(cax4, ax=axs.ravel().tolist(), orientation='vertical')

In [None]:
save_fig(figure=fig, filepath=immpath + "Wildfire_sPOD")  # plot_utils.save_fig at this point