In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%matplotlib notebook
from argparse import ArgumentParser
import yaml
import os
import math
import torch
from torch.utils.data import DataLoader

import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# from torch import vmap
from functorch import vmap, grad
from models import FNN2d, FNN3d
from train_utils import Adam
# from train_utils.datasets import BurgersLoader
# from train_utils.train_2d import train_2d_burger
# from train_utils.eval_2d import eval_burgers

from solver.WaveEq import WaveEq1D, WaveEq2D
import scipy.io
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import traceback

from tqdm import tqdm
from train_utils.utils import save_checkpoint, get_grid3d, convert_ic, torch2dgrid
from train_utils.losses import LpLoss, darcy_loss, PINO_loss, PINO_loss3d, get_forcing

from solver.random_fields import GaussianRF
from solver.my_random_fields import GRF_Mattern


from GRF import construct_grid, construct_points, RBF, dirichlet_matern, neumann_matern, periodic_matern, get_cholesky, generate_sample, generate_samples, plot_sample, setup_kernel
import GRF
from importlib import reload
GRF = reload(GRF)

try:
    import wandb
except ImportError:
    wandb = None


# Checkpoint Loading

In [2]:
def load_checkpoint(model, ckpt_path, optimizer=None):
    try:
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt['model'])
        print('Weights loaded from %s' % ckpt_path)
        if optimizer is not None:
            try:
                optimizer.load_state_dict(ckpt['optim'])
                print('Optimizer loaded from %s' % ckpt_path)
            except: traceback.print_exc()
            
    except:
        traceback.print_exc()

# Load/Update Config Functions:

In [3]:
def update_config(config, file):
    with open(file, 'w') as f:
        config_updated = yaml.dump(config, f)
        
def load_config(file):
    with open(file, 'r') as f:
        config = yaml.load(f, yaml.FullLoader)
    return config

# Define DataLoader Class for 2D Data

In [4]:
class DataLoader2D(object):
    def __init__(self, data, nx=128, nt=100, sub=1, sub_t=1):
#         dataloader = MatReader(datapath)
        self.sub = sub
        self.sub_t = sub_t            
        s = nx
        # if nx is odd
        if (s % 2) == 1:
            s = s - 1
        self.S = s // sub
        self.T = nt // sub_t
        self.T += 1
        data = data[:, 0:self.T:sub_t, 0:self.S:sub, 0:self.S:sub]
        self.data = data.permute(0, 2, 3, 1)
        
    def make_loader(self, n_sample, batch_size, start=0, train=True):
        a_data = self.data[start:start + n_sample, :, :, 0].reshape(n_sample, self.S, self.S)
        u_data = self.data[start:start + n_sample].reshape(n_sample, self.S, self.S, self.T)
#         Xs = self.x_data[start:start + n_sample].reshape(n_sample, self.S, self.S)
#         ys = self.y_data[start:start + n_sample].reshape(n_sample, self.S, self.S, self.T)
        gridx, gridy, gridt = get_grid3d(self.S, self.T)
        a_data = a_data.reshape(n_sample, self.S, self.S, 1, 1).repeat([1, 1, 1, self.T, 1])
        a_data = torch.cat((gridx.repeat([n_sample, 1, 1, 1, 1]),
                            gridy.repeat([n_sample, 1, 1, 1, 1]),
                            gridt.repeat([n_sample, 1, 1, 1, 1]),
                            a_data), dim=-1)
        dataset = torch.utils.data.TensorDataset(a_data, u_data)
        if train:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        else:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return loader

# Define Loss Function For Automatic Differentiation

In [5]:
def Autograd_Wave2D(u, grid, c=1.0):
    from torch.autograd import grad
    gridx, gridy, gridt = grid
    ut = grad(u.sum(), gridt, create_graph=True)[0]
    utt = grad(ut.sum(), gridt, create_graph=True)[0]
    ux = grad(u.sum(), gridx, create_graph=True)[0]
    uy = grad(u.sum(), gridy, create_graph=True)[0]
    uxx = grad(ux.sum(), gridx, create_graph=True)[0]
    uyy = grad(uy.sum(), gridy, create_graph=True)[0]
    Du = utt - c**2*(uxx + uyy)
    return Du, uxx, uyy, utt


def AD_loss_Wave2D(u, u0, grid, index_ic=None, p=None, q=None, c=1.0):
    batchsize = u.size(0)
    # lploss = LpLoss(size_average=True)

    Du, uxx, uyy, utt = Autograd_Wave(u, grid, c=c)

    if index_ic is None:
        # u in on a uniform grid
        nx = u.size(1)
        ny = u.size(2)
        nt = u.size(3)
        u = u.reshape(batchsize, nx, ny, nt)

        index_t = torch.zeros(nx,).long()
        index_x = torch.tensor(range(nx)).long()
        boundary_u = u[:, index_t, index_x]

        # loss_bc0 = F.mse_loss(u[:, :, 0], u[:, :, -1])
        # loss_bc1 = F.mse_loss(ux[:, :, 0], ux[:, :, -1])
    else:
        # u is randomly sampled, 0:p are BC, p:2p are ic, 2p:2p+q are interior
        boundary_u = u[:, :p]
        batch_index = torch.tensor(range(batchsize)).reshape(batchsize, 1).repeat(1, p)
        u0 = u0[batch_index, index_ic]

        # loss_bc0 = F.mse_loss(u[:, p:p+p//2], u[:, p+p//2:2*p])
        # loss_bc1 = F.mse_loss(ux[:, p:p+p//2], ux[:, p+p//2:2*p])

    loss_ic = F.mse_loss(boundary_u, u0)
    f = torch.zeros(Du.shape, device=u.device)
    loss_f = F.mse_loss(Du, f)
    return loss_ic, loss_f

# Define Loss for Fourier Derivatives

In [6]:
def FDM_Wave2D(u, D=1, c=1.0):
    batchsize = u.size(0)
    nx = u.size(1)
    ny = u.size(2)
    nt = u.size(3)
    u = u.reshape(batchsize, nx, ny, nt)
    dt = D / (nt-1)
    dx = D / (nx)

    u_h = torch.fft.fftn(u, dim=[1, 2])
    # Wavenumbers in y-direction
    k_max = nx//2
    N = nx
    k_x = torch.cat((torch.arange(start=0, end=k_max, step=1, device=device),
                     torch.arange(start=-k_max, end=0, step=1, device=device)), 0).reshape(N, 1).repeat(1, N).reshape(1,N,N,1)
    k_y = torch.cat((torch.arange(start=0, end=k_max, step=1, device=device),
                     torch.arange(start=-k_max, end=0, step=1, device=device)), 0).reshape(1, N).repeat(N, 1).reshape(1,N,N,1)
    ux_h = 2j *np.pi*k_x*u_h
    uxx_h = 2j *np.pi*k_x*ux_h
    uy_h = 2j *np.pi*k_y*u_h
    uyy_h = 2j *np.pi*k_y*uy_h
#     ux = torch.fft.irfft(ux_h[:, :, :k_max+1], dim=[1, 2])
    uxx = torch.fft.irfftn(uxx_h[:, :, :k_max+1], dim=[1, 2])
#     uy = torch.fft.irfft(uy_h[:, :, :k_max+1], dim=[1, 2])
    uyy = torch.fft.irfftn(uyy_h[:, :, :k_max+1], dim=[1, 2])

#     ut = (u[:, 2:, :] - u[:, :-2, :]) / (2 * dt)
    utt = (u[..., 2:] - 2.0*u[..., 1:-1] + u[..., :-2]) / (dt**2)
    Du = utt - c**2 * (uxx + uyy)[..., 1:-1]
    return Du


def PINO_loss_wave2D(u, u0, c=1.0):
    batchsize = u.size(0)
    nx = u.size(1)
    ny = u.size(2)
    nt = u.size(3)
    u = u.reshape(batchsize, nx, ny, nt)

    lploss = LpLoss(size_average=True)
    u_ic = u[..., 0]
    loss_ic = lploss(u_ic, u0)
#     index_t = torch.zeros(nx,).long()
#     index_x = torch.tensor(range(nx)).long()
#     boundary_u = u[:, index_t, index_x]
#     loss_u = F.mse_loss(boundary_u, u0)

#     Du = FDM_Wave(u, c=c)[:, :, :, :]
    Du = FDM_Wave2D(u, c=c)
    f = torch.zeros(Du.shape, device=u.device)
    loss_f = F.mse_loss(Du, f)

    # loss_bc0 = F.mse_loss(u[:, :, 0], u[:, :, -1])
    # loss_bc1 = F.mse_loss((u[:, :, 1] - u[:, :, -1]) /
    #                       (2/(nx)), (u[:, :, 0] - u[:, :, -2])/(2/(nx)))
    return loss_ic, loss_f

# Define Training Fuction

In [7]:
def train_wave2d(model,
                 dataset,
                 train_loader,
                 optimizer, scheduler,
                 config,
#                  c=1.0,
                 rank=0, log=False,
                 project='PINO-2d-default',
                 group='default',
                 tags=['default'],
                 use_tqdm=True):
    if rank == 0 and wandb and log:
        run = wandb.init(project=project,
                         entity='shawngr2',
                         group=group,
                         config=config,
                         tags=tags, reinit=True,
                         settings=wandb.Settings(start_method="fork"))

    data_weight = config['train']['xy_loss']
    f_weight = config['train']['f_loss']
    ic_weight = config['train']['ic_loss']
    ckpt_freq = config['train']['ckpt_freq']
    c = config['data']['c']
    model.train()
    myloss = LpLoss(size_average=True)
    S, T = dataset.S, dataset.T
    pbar = range(config['train']['epochs'])
    if use_tqdm:
        pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.1)

    for e in pbar:
        model.train()
        train_pino = 0.0
        train_ic = 0.0
        data_l2 = 0.0
        train_loss = 0.0
        

        for x, y in train_loader:
            x, y = x.to(rank), y.to(rank)
            out = model(x).reshape(y.shape)
            data_loss = myloss(out, y)
            loss_ic, loss_f = PINO_loss_wave2D(out, x[..., 0, -1], c=c)
            total_loss = loss_ic * ic_weight + loss_f * f_weight + data_loss * data_weight

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            data_l2 += data_loss.item()
            train_pino += loss_f.item()
            train_ic += loss_ic.item()
            train_loss += total_loss.item()
        scheduler.step()
        data_l2 /= len(train_loader)
        train_pino /= len(train_loader)
        train_ic /= len(train_loader)
        train_loss /= len(train_loader)
        if use_tqdm:
            pbar.set_description(
                (
                    f'Epoch {e}, train loss: {train_loss:.5f} '
                    f'train f error: {train_pino:.5f}; '
                    f'data l2 error: {data_l2:.5f}; '
                    f'train ic error: {train_ic:.5f}'
                )
            )
        if wandb and log:
            wandb.log(
                {
                    'Train f error': train_pino,
                    'Train L2 error': data_l2,
                    'Train ic error': loss_ic,
                    'Train loss': train_loss,
                }
            )

        if e % ckpt_freq == 0:
            save_checkpoint(config['train']['save_dir'],
                            config['train']['save_name'].replace('.pt', f'_{e}.pt'),
                            model, optimizer)
    save_checkpoint(config['train']['save_dir'],
                    config['train']['save_name'],
                    model, optimizer)
    print('Done!')

# Define Eval Function

In [8]:
def eval_wave2D(model,
                dataloader,
                config,
                device,
#                 c=1.0,
                use_tqdm=True):
    c = config['data']['c']
    model.eval()
    myloss = LpLoss(size_average=True)
    if use_tqdm:
        pbar = tqdm(dataloader, dynamic_ncols=True, smoothing=0.05)
    else:
        pbar = dataloader

    test_err = []
    f_err = []
    with torch.no_grad():
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            out = model(x).reshape(y.shape)
            data_loss = myloss(out, y)

            loss_ic, f_loss = PINO_loss_wave2D(out, x[..., 0, -1], c=c)
            test_err.append(data_loss.item())
            f_err.append(f_loss.item())

    mean_f_err = np.mean(f_err)
    std_f_err = np.std(f_err, ddof=1) / np.sqrt(len(f_err))

    mean_err = np.mean(test_err)
    std_err = np.std(test_err, ddof=1) / np.sqrt(len(test_err))

    print(f'==Averaged relative L2 error mean: {mean_err}, std error: {std_err}==\n'
          f'==Averaged equation error mean: {mean_f_err}, std error: {std_f_err}==')



# Define Parameters

In [9]:
dim = 2
N = 128
Nx = 128
Ny = 128
l = 0.1
L = 1.0
sigma = 1.0
Nu = None
Nsamples = 50
dt = 1.0e-4
save_int = int(1e-2/dt)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Generate Random Fields

In [10]:
# grf = GaussianRF(dim, N, length=1.0, alpha=2.5, tau=5.0, device=device)
# U0 = grf.sample(Nsamples)

In [11]:
grf = GRF_Mattern(dim, N, length=L, nu=Nu, l=l, sigma=sigma, boundary="periodic", device=device)
U0 = grf.sample(Nsamples)

In [12]:
u0 = U0[0].cpu()
u0

tensor([[-0.1266, -0.1740, -0.2194,  ...,  0.0123, -0.0318, -0.0786],
        [-0.1404, -0.1885, -0.2344,  ...,  0.0008, -0.0441, -0.0917],
        [-0.1550, -0.2032, -0.2490,  ..., -0.0130, -0.0582, -0.1061],
        ...,
        [-0.0926, -0.1349, -0.1755,  ...,  0.0304, -0.0086, -0.0501],
        [-0.1023, -0.1468, -0.1895,  ...,  0.0273, -0.0138, -0.0575],
        [-0.1138, -0.1600, -0.2043,  ...,  0.0212, -0.0216, -0.0671]])

In [13]:
u0.shape

torch.Size([128, 128])

# Plot Random Fields

In [14]:
key = 0
x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
u0 = U0[key].cpu().numpy()
u0

array([[-0.12658249, -0.17403445, -0.21937762, ...,  0.01230361,
        -0.03181181, -0.07863276],
       [-0.14040141, -0.18852614, -0.2343867 , ...,  0.00082741,
        -0.04406614, -0.09168597],
       [-0.15497713, -0.20320867, -0.24903   , ..., -0.01302158,
        -0.05818198, -0.10605748],
       ...,
       [-0.09260599, -0.13486414, -0.17547265, ...,  0.03043726,
        -0.00859047, -0.0500548 ],
       [-0.10234658, -0.14683113, -0.18952121, ...,  0.02731996,
        -0.01381805, -0.05751839],
       [-0.11379065, -0.16001973, -0.20429882, ...,  0.02120497,
        -0.02164574, -0.0671479 ]], dtype=float32)

In [15]:
fig = plt.figure()
# fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# surf = ax.plot_surface(X, Y, u0, cmap='jet', linewidth=0, antialiased=True, vmin=-2, vmax=2)
c = plt.pcolormesh(X, Y, u0, cmap='jet', shading='gouraud')

# fig.colorbar(surf, shrink=0.5, aspect=5)
fig.colorbar(c)

plt.title('GRF 2D')
plt.axis('square')
plt.show()

<IPython.core.display.Javascript object>

In [16]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(X, Y, u0, cmap='jet', linewidth=0, antialiased=True,)
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.title('GRF 2D')
plt.show()



<IPython.core.display.Javascript object>

# Evolve the Wave Equation

In [17]:
wave_eq = WaveEq2D(Nx=Nx, Ny=Ny, dt=dt, device=device)
U = vmap(wave_eq.wave_driver, in_dims=(0, None))(U0, save_int)

In [18]:
a = U0.cpu().float()
u = U.cpu().float()
display(u.shape,a.shape)

torch.Size([50, 101, 128, 128])

torch.Size([50, 128, 128])

In [19]:
u

tensor([[[[-1.2658e-01, -1.7403e-01, -2.1938e-01,  ...,  1.2304e-02,
           -3.1812e-02, -7.8633e-02],
          [-1.4040e-01, -1.8853e-01, -2.3439e-01,  ...,  8.2741e-04,
           -4.4066e-02, -9.1686e-02],
          [-1.5498e-01, -2.0321e-01, -2.4903e-01,  ..., -1.3022e-02,
           -5.8182e-02, -1.0606e-01],
          ...,
          [-9.2606e-02, -1.3486e-01, -1.7547e-01,  ...,  3.0437e-02,
           -8.5905e-03, -5.0055e-02],
          [-1.0235e-01, -1.4683e-01, -1.8952e-01,  ...,  2.7320e-02,
           -1.3818e-02, -5.7518e-02],
          [-1.1379e-01, -1.6002e-01, -2.0430e-01,  ...,  2.1205e-02,
           -2.1646e-02, -6.7148e-02]],

         [[-1.2702e-01, -1.7270e-01, -2.1635e-01,  ...,  6.7941e-03,
           -3.5728e-02, -8.0837e-02],
          [-1.4054e-01, -1.8684e-01, -2.3094e-01,  ..., -4.5703e-03,
           -4.7813e-02, -9.3659e-02],
          [-1.5486e-01, -2.0123e-01, -2.4526e-01,  ..., -1.8283e-02,
           -6.1760e-02, -1.0782e-01],
          ...,
     

# Plot Data

In [20]:
# %matplotlib notebook
fig = plt.figure()
ax = fig.add_subplot(111)
plt.ion()

fig.show()
fig.canvas.draw()


x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
# u = U[key].cpu().numpy()
# u0

pcm = ax.pcolormesh(X, Y, u[key, 0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(pcm, ax=ax)
# ax.

plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Wave Equation')
plt.axis('square')
plt.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(len(u)):
    ax.clear()
    pcm = ax.pcolormesh(X, Y, u[key, i], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
#     plt.colorbar(pcm, ax=ax)
    plt.xlabel('$x$')
    plt.ylabel('$y$')
    plt.title(f'Wave Equation')
    plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)

<IPython.core.display.Javascript object>

In [21]:
# %matplotlib notebook
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

plt.ion()

fig.show()
fig.canvas.draw()


x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
# u = U[key].cpu().numpy()
# u0
surf = ax.plot_surface(X, Y, u[key, 0].cpu().numpy(), cmap='jet', vmin=-0.5, vmax=0.5, linewidth=0, antialiased=False)

# pcm = ax.pcolormesh(X, Y, u[0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)

# plt.colorbar(surf, ax=ax)
fig.colorbar(surf, shrink=0.5, aspect=5)

# ax.
# ax.set_zlim(-1, 1)
# display(ax.get_zlim())
zlim = ax.get_zlim()
ax.set_zlim(zlim)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Wave Equation')
# plt.axis('square')
plt.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(len(u)):
    ax.clear()
#     pcm = ax.pcolormesh(X, Y, u[i], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
    surf = ax.plot_surface(X, Y, u[key, i].cpu().numpy(), cmap='jet', vmin=-0.5, vmax=0.5, linewidth=0, antialiased=False)
    ax.set_zlim(zlim)
#     plt.colorbar(pcm, ax=ax)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
#     ax.set_zlabel('$u$')
    plt.title(f'Wave Equation')
#     plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)



<IPython.core.display.Javascript object>

# Load Configuration File

In [22]:
config_file = 'configs/custom/wave2D-0000.yaml'
config = load_config(config_file)
display(config)

{'data': {'name': 'Wave2D-0000',
  'total_num': 50,
  'n_train': 45,
  'n_test': 5,
  'nx': 128,
  'nt': 100,
  'sub': 1,
  'sub_t': 1,
  'c': 1.0},
 'model': {'layers': [64, 64, 64, 64, 64],
  'modes1': [8, 8, 8, 8],
  'modes2': [8, 8, 8, 8],
  'modes3': [8, 8, 8, 8],
  'fc_dim': 128,
  'activation': 'gelu',
  'padding': 5},
 'train': {'batchsize': 1,
  'epochs': 100,
  'milestones': [25, 50, 75, 100],
  'base_lr': 0.001,
  'scheduler_gamma': 0.5,
  'ic_loss': 5.0,
  'f_loss': 1.0,
  'xy_loss': 5.0,
  'save_dir': 'Wave2D',
  'save_name': 'Wave2D-0000.pt',
  'ckpt': 'checkpoints/Wave2D/Wave2D-0000.pt',
  'ckpt_freq': 25},
 'log': {'project': 'PINO-Wave', 'group': 'Wave2D-0000'},
 'test': {'batchsize': 1, 'ckpt': 'checkpoints/WaveD/Wave2D-0000.pt'}}

# Define the DataLoaders

In [23]:
dataset = DataLoader2D(u, config['data']['nx'], config['data']['nt'], config['data']['sub'], config['data']['sub_t'])
train_loader = dataset.make_loader(config['data']['n_train'], config['train']['batchsize'], start=0, train=True)
test_loader = dataset.make_loader(config['data']['n_test'], config['test']['batchsize'], start=config['data']['n_train'], train=False)

In [28]:
x,y = next(iter(train_loader))

In [33]:
x_in = F.pad(x, (0, 0, 0, 5, 0, 0, 0, 0), "constant", 0)

In [31]:
x_in.shape

torch.Size([1, 128, 128, 106, 4])

# Define the Model

In [24]:
log = False
# config = config_train
model = FNN3d(modes1=config['model']['modes1'],
              modes2=config['model']['modes2'],
              modes3=config['model']['modes3'],
              fc_dim=config['model']['fc_dim'],
              layers=config['model']['layers'], 
              activation=config['model']['activation'],
              pad_z=config['model']['padding']
             ).to(device)

optimizer = Adam(model.parameters(), betas=(0.9, 0.999), lr=config['train']['base_lr'])
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=config['train']['milestones'],
                                                 gamma=config['train']['scheduler_gamma'])

# Load from checkpoint

In [25]:
load_checkpoint(model, ckpt_path=config['train']['ckpt'], optimizer=None)

Weights loaded from checkpoints/Wave2D/Wave2D-0000.pt


# Train the Model

In [41]:
train_wave2d(model,
             dataset,
             train_loader,
             optimizer,
             scheduler,
             config,
#              c=config['data']['c'],
             rank=0,
             log=log,
             project=config['log']['project'],
             group=config['log']['group'])

Epoch 0, train loss: 0.09790 train f error: 0.00984; data l2 error: 0.01005; train ic error: 0.00756:   1%|          | 1/100 [00:30<50:58, 30.89s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_0.pt


Epoch 25, train loss: 0.06849 train f error: 0.00814; data l2 error: 0.00760; train ic error: 0.00447:  26%|██▌       | 26/100 [13:17<37:58, 30.79s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_25.pt


Epoch 50, train loss: 0.05066 train f error: 0.00875; data l2 error: 0.00644; train ic error: 0.00194:  51%|█████     | 51/100 [26:19<25:46, 31.56s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_50.pt


Epoch 75, train loss: 0.04970 train f error: 0.00876; data l2 error: 0.00638; train ic error: 0.00181:  76%|███████▌  | 76/100 [38:23<11:28, 28.67s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_75.pt


Epoch 99, train loss: 0.05044 train f error: 0.00862; data l2 error: 0.00641; train ic error: 0.00195: 100%|██████████| 100/100 [49:15<00:00, 29.55s/it]


Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000.pt
Done!


In [33]:
train_wave2d(model,
             dataset,
             train_loader,
             optimizer,
             scheduler,
             config,
             c=config['data']['c'],
             rank=0,
             log=log,
             project=config['log']['project'],
             group=config['log']['group'])

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 0, train loss: 1.64529 train f error: 0.93369; data l2 error: 0.06710; train ic error: 0.06445:   1%|          | 1/150 [00:22<54:43, 22.04s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_0.pt


Epoch 25, train loss: 0.18098 train f error: 0.01783; data l2 error: 0.01708; train ic error: 0.01461:  17%|█▋        | 26/150 [10:42<58:29, 28.30s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_25.pt


Epoch 50, train loss: 0.09589 train f error: 0.00708; data l2 error: 0.01176; train ic error: 0.00771:  34%|███▍      | 51/150 [23:36<50:49, 30.81s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_50.pt


Epoch 75, train loss: 0.05376 train f error: 0.00288; data l2 error: 0.00984; train ic error: 0.00410:  51%|█████     | 76/150 [36:25<38:15, 31.02s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_75.pt


Epoch 100, train loss: 0.03815 train f error: 0.00209; data l2 error: 0.00938; train ic error: 0.00267:  67%|██████▋   | 101/150 [49:21<25:42, 31.48s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_100.pt


Epoch 125, train loss: 0.02934 train f error: 0.00171; data l2 error: 0.00913; train ic error: 0.00185:  84%|████████▍ | 126/150 [1:01:59<12:01, 30.08s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_125.pt


Epoch 149, train loss: 0.02965 train f error: 0.00165; data l2 error: 0.00910; train ic error: 0.00189: 100%|██████████| 150/150 [1:13:38<00:00, 29.45s/it]


Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000.pt
Done!


### Old Model for Comparison

In [36]:
train_wave2d(model,
             dataset,
             train_loader,
             optimizer,
             scheduler,
             config,
             c=config['data']['c'],
             rank=0,
             log=log,
             project=config['log']['project'],
             group=config['log']['group'])

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 0, train loss: 12.16554 train f error: 1.29981; data l2 error: 1.00512:   0%|          | 1/500 [00:21<2:58:32, 21.47s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_0.pt


Epoch 100, train loss: 0.71877 train f error: 0.04601; data l2 error: 0.06258:  20%|██        | 101/500 [33:36<2:13:45, 20.11s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_100.pt


Epoch 200, train loss: 0.54671 train f error: 0.01872; data l2 error: 0.05360:  40%|████      | 201/500 [1:06:49<1:39:54, 20.05s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_200.pt


Epoch 300, train loss: 0.52136 train f error: 0.01023; data l2 error: 0.05306:  60%|██████    | 301/500 [1:40:10<1:09:43, 21.02s/it]

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_300.pt


Epoch 400, train loss: 0.50422 train f error: 0.00724; data l2 error: 0.05272:  80%|████████  | 401/500 [2:13:21<32:58, 19.99s/it]  

Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000_400.pt


Epoch 499, train loss: 0.50743 train f error: 0.00920; data l2 error: 0.05292: 100%|██████████| 500/500 [2:46:11<00:00, 19.94s/it]


Checkpoint is saved at checkpoints/Wave2D/Wave2D-0000.pt
Done!


# Evaluate Model

In [26]:
eval_wave2D(model,
            test_loader,
            config,
            device,
#             c=config['data']['c'],
            use_tqdm=True)

100%|██████████| 5/5 [00:00<00:00,  6.08it/s]

==Averaged relative L2 error mean: 0.0077342063188552855, std error: 0.0003990224445474794==
==Averaged equation error mean: 0.013093253411352635, std error: 0.0009625584048592725==





In [34]:
eval_wave2D(model,
            test_loader,
            config,
            device,
            c=config['data']['c'],
            use_tqdm=True)

100%|██████████| 5/5 [00:00<00:00,  5.29it/s]

==Averaged relative L2 error mean: 0.010337595082819461, std error: 0.00046738182780304237==
==Averaged equation error mean: 0.006359837856143713, std error: 0.0008212992923767788==





# Generate Test Predictions

In [55]:
Nx = config['data']['nx']
Ny = config['data']['nx']
Nt = config['data']['nt'] + 1
Ntest = config['data']['n_test']
model.eval()
test_x = np.zeros((Ntest,Nx,Ny,Nt,4))
preds_y = np.zeros((Ntest,Nx,Ny,Nt))
test_y = np.zeros((Ntest,Nx,Ny,Nt))
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data_x, data_y = data
        data_x, data_y = data_x.to(device), data_y.to(device)
        pred_y = model(data_x).reshape(data_y.shape)
        test_x[i] = data_x.cpu().numpy()
        test_y[i] = data_y.cpu().numpy()
        preds_y[i] = pred_y.cpu().numpy()
#     data_loss = myloss(out, y)

In [36]:
X

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
        [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
        ...,
        [0.9766, 0.9766, 0.9766,  ..., 0.9766, 0.9766, 0.9766],
        [0.9844, 0.9844, 0.9844,  ..., 0.9844, 0.9844, 0.9844],
        [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922]])

In [37]:
Y

tensor([[0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922],
        [0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922],
        [0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922],
        ...,
        [0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922],
        [0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922],
        [0.0000, 0.0078, 0.0156,  ..., 0.9766, 0.9844, 0.9922]])

In [38]:
T

NameError: name 'T' is not defined

In [39]:
key_t

NameError: name 'key_t' is not defined

# Plot Results

In [68]:
key = 0
key_t = (Nt - 1) 
pred = preds_y[key]
true = test_y[key]


a = test_x[key]
# Nt, Nx, _ = a.shape
u0 = a[..., 0, -1]
pred_t = pred[..., key_t]
true_t = true[..., key_t]
# T = a[:,:,2]
# X = a[:,:,1]
# x = X[0]
x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
t = a[0, 0, key_t, 2]
grid_x, grid_y, grid_t = get_grid3d(Nx, Nt)
# X = X.reshape(Nx, Nx)

In [69]:
fig = plt.figure(figsize=(24,5))
plt.subplot(1,4,1)

plt.pcolormesh(X, Y, u0, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Intial Condition $u(x,y)$')
plt.tight_layout()
plt.axis('square')

plt.subplot(1,4,2)
# plt.pcolor(XX,TT, S_test, cmap='jet')
plt.pcolormesh(X, Y, true_t, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Exact $u(x,y,t={t:.2f})$')
plt.tight_layout()
plt.axis('square')

plt.subplot(1,4,3)
# plt.pcolor(XX,TT, S_pred, cmap='jet')
plt.pcolormesh(X, Y, pred_t, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Predict $u(x,y,t={t:.2f})$')
plt.axis('square')

plt.tight_layout()

plt.subplot(1,4,4)
# plt.pcolor(XX,TT, S_pred - S_test, cmap='jet')
plt.pcolormesh(X, Y, pred_t - true_t, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Absolute error')
plt.tight_layout()
plt.axis('square')

# plt.show()

<IPython.core.display.Javascript object>

(0.0, 0.9921875, 0.0, 0.9921875)

In [33]:
# %matplotlib notebook
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5))

plt.ion()

fig.show()
fig.canvas.draw()


x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
# u = U[key].cpu().numpy()
# u0
# pcm1 = ax1.plot_pcmace(X, Y, true[..., 0], cmap='jet', label='true', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
pcm1 = ax1.pcolormesh(X, Y, true[..., 0], cmap='jet', label='true', shading='gouraud')

# plt.colorbar()
# pcm2 = ax2.plot_pcmace(X, Y, pred[..., 0], cmap='jet', label='pred', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
pcm2 = ax2.pcolormesh(X, Y, pred[..., 0], cmap='jet', label='pred', shading='gouraud')

# pcm = ax.pcolormesh(X, Y, u[0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax2)
# cax = divider.append_axes("right", size="5%", pad=0.05)


# fig.colorbar(pcm, shrink=0.5, aspect=5)

# ax.
# ax.set_clim(-1, 1)
# display(ax.get_clim())
clim = pcm1.get_clim()
pcm1.set_clim(clim)
ax1.set_xlabel('$x$')
ax1.set_ylabel('$y$')
pcm2.set_clim(clim)
ax2.set_xlabel('$x$')
ax2.set_ylabel('$y$')
ax1.set_title(f'Wave Equation Truth')
ax2.set_title(f'Wave Equation Prediction')

# plt.axis('square')

plt.colorbar(pcm1, ax=ax1)
plt.colorbar(pcm2, ax=ax2)
ax1.axis('square')
ax2.axis('square')
plt.tight_layout()
# ax2.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(pred.shape[-1]):
    ax1.clear()
    ax2.clear()
    pcm1 = ax1.pcolormesh(X, Y, true[..., i], cmap='jet', label='true', shading='gouraud')

    pcm2 = ax2.pcolormesh(X, Y, pred[..., i], cmap='jet', label='pred', shading='gouraud')
#     plt.axis('square')

    pcm1.set_clim(clim)
    ax1.set_xlabel('$x$')
    ax1.set_ylabel('$y$')
    pcm2.set_clim(clim)
    ax2.set_xlabel('$x$')
    ax2.set_ylabel('$y$')
    ax1.set_title(f'Wave Equation Truth')
    ax2.set_title(f'Wave Equation Prediction')
#     plt.title(f'Wave Equation')
    ax1.axis('square')
    ax2.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)



<IPython.core.display.Javascript object>

In [34]:
# %matplotlib notebook
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

plt.ion()

fig.show()
fig.canvas.draw()


# x = torch.linspace(0, 1, Nx + 1)[:-1]
# y = torch.linspace(0, 1, Nx + 1)[:-1]
# X, Y = torch.meshgrid(x, y, indexing='ij')
# u = U[key].cpu().numpy()
# u0
surf1 = ax.plot_surface(X, Y, true[..., 0], color='b', label='true', alpha=0.2, linewidth=0, antialiased=False)
surf2 = ax.plot_surface(X, Y, pred[..., 0], color='r', label='pred', alpha=0.2, linewidth=0, antialiased=False)

# pcm = ax.pcolormesh(X, Y, u[0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)

# plt.colorbar(surf, ax=ax)
# fig.colorbar(surf, shrink=0.5, aspect=5)

# ax.
# ax.set_zlim(-1, 1)
# display(ax.get_zlim())
zlim = ax.get_zlim()
ax.set_zlim(zlim)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Wave Equation')
# plt.axis('square')
plt.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(len(u)):
    ax.clear()
#     pcm = ax.pcolormesh(X, Y, u[i], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
    surf1 = ax.plot_surface(X, Y, true[..., i], color='b', label='true', alpha=0.8, linewidth=0, antialiased=False)
    surf2 = ax.plot_surface(X, Y, pred[..., i], color='r', label='pred', alpha=0.8, linewidth=0, antialiased=False)
    ax.set_zlim(zlim)
#     plt.colorbar(pcm, ax=ax)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
#     ax.set_zlabel('$u$')
    plt.title(f'Wave Equation')
#     plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)



<IPython.core.display.Javascript object>

In [35]:
# %matplotlib notebook
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5), subplot_kw={"projection": "3d"})

plt.ion()

fig.show()
fig.canvas.draw()


x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
# u = U[key].cpu().numpy()
# u0
# surf1 = ax1.plot_surface(X, Y, true[..., 0], cmap='jet', label='true', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
surf1 = ax1.plot_surface(X, Y, true[..., 0], cmap='jet', label='true', linewidth=0.0, antialiased=False)

# plt.colorbar()
# surf2 = ax2.plot_surface(X, Y, pred[..., 0], cmap='jet', label='pred', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
surf2 = ax2.plot_surface(X, Y, pred[..., 0], cmap='jet', label='pred',linewidth=0.0, antialiased=False)

# pcm = ax.pcolormesh(X, Y, u[0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax2)
# cax = divider.append_axes("right", size="5%", pad=0.05)


# fig.colorbar(surf, shrink=0.5, aspect=5)

# ax.
# ax.set_zlim(-1, 1)
# display(ax.get_zlim())
zlim = ax1.get_zlim()
ax1.set_zlim(zlim)
ax1.set_xlabel('$x$')
ax1.set_ylabel('$y$')
ax2.set_zlim(zlim)
ax2.set_xlabel('$x$')
ax2.set_ylabel('$y$')
ax1.set_title(f'Wave Equation Truth')
ax2.set_title(f'Wave Equation Prediction')
# plt.axis('square')

plt.colorbar(surf1, ax=ax1, shrink=0.5, aspect=5)
plt.colorbar(surf2, ax=ax2, shrink=0.5, aspect=5)
plt.tight_layout()
# ax2.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(pred.shape[-1]):
    ax1.clear()
    ax2.clear()
#     pcm = ax.pcolormesh(X, Y, u[i], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
#     surf1 = ax1.plot_surface(X, Y, true[..., i], cmap='jet', label='true', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
    surf1 = ax1.plot_surface(X, Y, true[..., i], cmap='jet', label='true', linewidth=0.0, antialiased=False)
#     surf2 = ax2.plot_surface(X, Y, pred[..., i], cmap='jet', label='pred', alpha=1.0, rstride=1, cstride=1, linewidth=0.0, antialiased=False)
    surf2 = ax2.plot_surface(X, Y, pred[..., i], cmap='jet', label='pred',linewidth=0.0, antialiased=False)

    ax1.set_zlim(zlim)
    ax1.set_xlabel('$x$')
    ax1.set_ylabel('$y$')
    ax2.set_zlim(zlim)
    ax2.set_xlabel('$x$')
    ax2.set_ylabel('$y$')
    ax1.set_title(f'Wave Equation Truth')
    ax2.set_title(f'Wave Equation Prediction')
#     plt.title(f'Wave Equation')
#     plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)



<IPython.core.display.Javascript object>

In [45]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

plt.ion()

fig.show()
fig.canvas.draw()


x = torch.linspace(0, 1, Nx + 1)[:-1]
y = torch.linspace(0, 1, Nx + 1)[:-1]
X, Y = torch.meshgrid(x, y, indexing='ij')
surf = ax.plot_surface(X, Y, pred[..., 0] - true[..., 0], cmap='jet', linewidth=0, antialiased=False)
# pcm = ax.pcolormesh(X, Y, u[0], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)

# plt.colorbar(surf, ax=ax)
fig.colorbar(surf, shrink=0.5, aspect=5)

# ax.
# ax.set_zlim(-1, 1)
# display(ax.get_zlim())
zlim = ax.get_zlim()
ax.set_zlim(zlim)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Wave Equation Error')
# plt.axis('square')
plt.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(len(u)):
    ax.clear()
#     pcm = ax.pcolormesh(X, Y, u[i], vmin=-0.5, vmax=0.5, cmap='jet', shading='gouraud')
    surf = ax.plot_surface(X, Y, pred[..., i] - true[..., i], cmap='jet', linewidth=0, antialiased=False)
    ax.set_zlim(zlim)
#     plt.colorbar(pcm, ax=ax)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
#     ax.set_zlabel('$u$')
    plt.title(f'Wave Equation Error')
#     plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)



<IPython.core.display.Javascript object>

In [46]:
# %matplotlib notebook
fig = plt.figure()
ax = fig.add_subplot(111)
plt.ion()

fig.show()
fig.canvas.draw()


# x = torch.linspace(0, 1, Nx + 1)[:-1]
# y = torch.linspace(0, 1, Nx + 1)[:-1]
# X, Y = torch.meshgrid(x, y, indexing='ij')
# # u = U[key].cpu().numpy()
# u0

pcm = ax.pcolormesh(X, Y, pred[..., 0] - true[..., 0], cmap='jet', shading='gouraud')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(pcm, ax=ax)
# ax.
clim = pcm.get_clim()
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title(f'Wave Equation')
plt.axis('square')
plt.tight_layout()

# movie_dir = "Wave2D_movie"
# movie_filename = "Wave2D_movie"
# movie_files = []
# os.makedirs(movie_dir, exist_ok=True)
for i in range(len(u)):
    ax.clear()
    pcm = ax.pcolormesh(X, Y, pred[..., i] - true[..., i], cmap='jet', shading='gouraud')
#     plt.colorbar(pcm, ax=ax)
    pcm.set_clim(clim)
    plt.xlabel('$x$')
    plt.ylabel('$y$')
    plt.title(f'Wave Equation')
    plt.axis('square')
    plt.tight_layout()
    fig.canvas.draw()
#     movie_path = os.path.join(movie_dir,f'{movie_filename}-{i:03}.png')
#     movie_files.append(movie_path)
#     plt.savefig(movie_path)

<IPython.core.display.Javascript object>

In [38]:
preds_y.shape

(5, 128, 128, 101)

# Save and Load Data

In [5]:
def save_data(data_path, test_x, test_y, preds_y, key_t):
    data_dir, data_filename = os.path.split(data_path)
    os.makedirs(data_dir, exist_ok=True)
    np.savez(data_path, test_x=test_x, test_y=test_y, preds_y=preds_y, key_t=key_t)

def load_data(data_path):
    data = np.load(data_path)
    test_x = data['test_x']
    test_y = data['test_y']
    preds_y = data['preds_y']
    key_t = int(data['key_t'])
    return test_x, test_y, preds_y, key_t

In [6]:
data_dir = 'data/Wave2D'
data_filename = 'data.npz'
data_path = os.path.join(data_dir, data_filename)
# os.makedirs(data_dir, exist_ok=True)


In [72]:
save_data(data_path, test_x, test_y, preds_y, key_t)

In [7]:
test_x, test_y, preds_y, key_t = load_data(data_path)

In [15]:
def plot_predictions(key, key_t, test_x, test_y, preds_y, print_index=False, save_path=None, font_size=None):
    
    if font_size is not None:
        plt.rcParams.update({'font.size': font_size})
        
    pred = preds_y[key]
    true = test_y[key]
    
    Nx, Ny, Nt = pred.shape


    a = test_x[key]
    # Nt, Nx, _ = a.shape
    u0 = a[..., 0, -1]
    pred_t = pred[..., key_t]
    true_t = true[..., key_t]
    # T = a[:,:,2]
    # X = a[:,:,1]
    # x = X[0]
    x = torch.linspace(0, 1, Nx + 1)[:-1]
    y = torch.linspace(0, 1, Ny + 1)[:-1]
    X, Y = torch.meshgrid(x, y, indexing='ij')
    t = a[0, 0, key_t, 2]

    # Plot
    fig = plt.figure(figsize=(24,5))
    plt.subplot(1,4,1)

    plt.pcolormesh(X, Y, u0, cmap='jet', shading='gouraud')
    plt.colorbar()
    plt.xlabel('$x$')
    plt.ylabel('$y$')
    plt.title('Intial Condition $u(x,y)$')
    plt.tight_layout()
    plt.axis('square')

    plt.subplot(1,4,2)
    # plt.pcolor(XX,TT, S_test, cmap='jet')
    plt.pcolormesh(X, Y, true_t, cmap='jet', shading='gouraud')
    plt.colorbar()
    plt.xlabel('$x$')
    plt.ylabel('$y$')
#     plt.title(f'Exact $u(x,y,t={t:.2f})$')
    plt.title(f'Exact $u(x,y,t={int(t)})$')
    plt.tight_layout()
    plt.axis('square')

    plt.subplot(1,4,3)
    # plt.pcolor(XX,TT, S_pred, cmap='jet')
    plt.pcolormesh(X, Y, pred_t, cmap='jet', shading='gouraud')
    plt.colorbar()
    plt.xlabel('$x$')
    plt.ylabel('$y$')
#     plt.title(f'Predict $u(x,y,t={t:.2f})$')
    plt.title(f'Predict $u(x,y,t={int(t)}$')

    plt.axis('square')

    plt.tight_layout()

    plt.subplot(1,4,4)
    # plt.pcolor(XX,TT, S_pred - S_test, cmap='jet')
    plt.pcolormesh(X, Y, pred_t - true_t, cmap='jet', shading='gouraud')
    plt.colorbar()
    plt.xlabel('$x$')
    plt.ylabel('$y$')
    plt.title('Absolute Error')
    plt.tight_layout()
    plt.axis('square')

    if save_path is not None:
        plt.savefig(f'{save_path}.png', bbox_inches='tight')
    plt.show()

    


In [16]:
figures_dir = 'Wave2D/figures/'
os.makedirs(figures_dir, exist_ok=True)
font_size = 12
for key in range(len(preds_y)):
    save_path = os.path.join(figures_dir, f'Wave{key}')
#     plot_predictions(key, key_t, test_x, test_y, preds_y, print_index=True, save_path=None)
    plot_predictions(key, key_t, test_x, test_y, preds_y, print_index=True, save_path=save_path, font_size=font_size)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>