In [1]:
import os
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['OPENBLAS_NUM_THREADS'] = '1'

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import math
import h5py
import cv2
import glob
from functools import partial

# from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer
from models.utils_3d import train, test, LpLoss, get_filter, UnitGaussianNormalizer

In [3]:
# # set fraction of GPU you want to use (out of 24GB). 0 denotes the physical GPU number 0 (the only one).
# torch.cuda.set_per_process_memory_fraction(0.6, 0)
# torch.cuda.empty_cache()
# total_memory = torch.cuda.get_device_properties(0).total_memory
# print('totoal_memory = ', total_memory/10**9)
# #anything greater or equal to 0.1 will return an error.
# tmp_tensor = torch.empty(int(total_memory * 0.0999), dtype=torch.int8, device='cuda')
# del tmp_tensor
# torch.cuda.empty_cache()

In [4]:
# Seed the global PyTorch and NumPy random number generators with a fixed value (0).
torch.manual_seed(0)
np.random.seed(0)

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def get_initializer(name):
    """Return the in-place weight initializer selected by ``name``.

    Args:
        name: One of ``'xavier_normal'``, ``'kaiming_uniform'`` or
            ``'kaiming_normal'``.

    Returns:
        A ``functools.partial`` wrapping the matching ``torch.nn.init``
        function; call it with the tensor to initialize.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    initializers = {
        'xavier_normal': nn.init.xavier_normal_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
    }
    if name not in initializers:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            'Unknown initializer {!r}; expected one of {}'.format(
                name, sorted(initializers)))
    return partial(initializers[name])

In [7]:
class sparseKernel(nn.Module):
    """Physical-space kernel: a 3x3x3 ``Conv3d`` + ReLU followed by a pointwise linear map.

    Args:
        k: Basis size; stored on the module and used to size the layers.
        alpha: Width multiplier; the conv block has ``alpha*k**2`` channels.
        c: Channel multiplier; the output map produces ``c*k**2`` channels.
        nl, initializer, **kwargs: Accepted for signature compatibility, unused here.

    Note:
        ``forward`` flattens the ``c`` and ``ich`` axes into ``c*ich`` channels,
        while the layers are sized from ``alpha*k**2``; the shapes therefore
        only line up when ``alpha*k**2 == c*ich``.
    """
    def __init__(self,
                 k, alpha, c=1, 
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernel,self).__init__()
        
        self.k = k
        self.conv = self.convBlock(alpha*k**2, alpha*k**2)
        self.Lo = nn.Conv1d(alpha*k**2, c*k**2, 1)
        
    def forward(self, x):
        """Apply the kernel to ``x`` of shape (B, c, ich, Nx, Ny, T); the output has the same shape."""
        B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, Nx, Ny, T)
        # Merge the c and ich axes into a single channel axis for Conv3d.
        x = x.reshape(B, -1, Nx, Ny, T)
        x = self.conv(x)
        # Pointwise (kernel size 1) mixing over channels on the flattened grid.
        # reshape (not view) so a non-contiguous conv output cannot raise.
        x = self.Lo(x.reshape(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T)
        return x
        
        
    def convBlock(self, ich, och):
        """Build a ``Conv3d(ich -> och, kernel 3, stride 1, padding 1)`` + in-place ReLU block."""
        net = nn.Sequential(
            # Use ich as the input width (it was previously ignored in favour of och).
            nn.Conv3d(ich, och, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        return net 



# fft conv taken from: https://github.com/zongyi-li/fourier_neural_operator
def compl_mul3d(a, b):
    """Contract the input-channel axis of ``a`` with ``b`` at every (x, y, t) position.

    a: (batch, in_channel, x, y, t)
    b: (in_channel, out_channel, x, y, t)
    returns: (batch, out_channel, x, y, t)
    """
    contraction = "bixyz,ioxyz->boxyz"
    product = torch.einsum(contraction, a, b)
    return product



class sparseKernelFT(nn.Module):
    """Fourier-space kernel: learned complex weights on the low 3D Fourier modes.

    The input is transformed with a real 3D FFT over (Nx, Ny, T), the four
    low-frequency corners of the (x, y) spectrum are multiplied by learned
    complex weights, and the result is transformed back, passed through ReLU
    and mixed by a pointwise ``Conv1d``.

    Args:
        k: Basis size; the layer works on ``c*k**2`` channels.
        alpha: Maximum number of Fourier modes kept along each axis.
        c: Channel multiplier.
        nl, initializer, **kwargs: Accepted for signature compatibility, unused here.
    """
    def __init__(self,
                 k, alpha, c=1, 
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernelFT, self).__init__()        
        
        self.modes = alpha

        # One complex weight tensor per retained corner of the (x, y) spectrum.
        self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))        
        self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))        
        self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))        
        nn.init.xavier_normal_(self.weights1)
        nn.init.xavier_normal_(self.weights2)
        nn.init.xavier_normal_(self.weights3)
        nn.init.xavier_normal_(self.weights4)
        
        self.Lo = nn.Conv1d(c*k**2, c*k**2, 1)
#         self.Wo = nn.Conv1d(c*k**2, c*k**2, 1)
        self.k = k
        
        
    # fft conv taken from: https://github.com/zongyi-li/fourier_neural_operator
    def compl_mul3d(self, a, b):
        """Contract the input-channel axis of ``a`` with ``b`` mode by mode."""
        # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        return torch.einsum("bixyz,ioxyz->boxyz", a, b)

    def forward(self, x):
        """Apply the spectral kernel to ``x`` of shape (B, c, ich, Nx, Ny, T); the output has the same shape."""
        
        B, c, ich, Nx, Ny, T = x.shape # (B, c, ich, N, N, T)
        
        # Merge the c and ich axes into a single channel axis.
        x = x.reshape(B, -1, Nx, Ny, T)

        # Transform all three trailing axes: the slices below treat the x and y
        # axes as frequencies, which a last-axis-only rfft would not provide.
        x_fft = torch.fft.rfftn(x, dim=[-3, -2, -1])
        
        # Multiply relevant Fourier modes; clamp every axis to what the grid
        # actually offers so small grids do not cause a shape mismatch.
        l1 = min(self.modes, Nx//2+1)
        l2 = min(self.modes, Ny//2+1)
        l3 = min(self.modes, T//2+1)
        
        out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, dtype=torch.cfloat, device=x.device)

        # On very small grids the positive and negative corners overlap; later
        # assignments then overwrite earlier ones.
        out_ft[:, :, :l1, :l2, :l3] = self.compl_mul3d(
            x_fft[:, :, :l1, :l2, :l3], self.weights1[:, :, :l1, :l2, :l3])
        out_ft[:, :, -l1:, :l2, :l3] = self.compl_mul3d(
                x_fft[:, :, -l1:, :l2, :l3], self.weights2[:, :, :l1, :l2, :l3])
        out_ft[:, :, :l1, -l2:, :l3] = self.compl_mul3d(
                x_fft[:, :, :l1, -l2:, :l3], self.weights3[:, :, :l1, :l2, :l3])
        out_ft[:, :, -l1:, -l2:, :l3] = self.compl_mul3d(
                x_fft[:, :, -l1:, -l2:, :l3], self.weights4[:, :, :l1, :l2, :l3])
        
        #Return to physical space; passing s restores the exact grid, including odd T.
        x = torch.fft.irfftn(out_ft, s=(Nx, Ny, T))
        
        x = F.relu(x)
        # reshape (not view): the inverse FFT output is not guaranteed to be contiguous.
        x = self.Lo(x.reshape(B, c*ich, -1)).view(B, c, ich, Nx, Ny, T)
        return x


class SpectralConv3d(nn.Module):
    """3D Fourier layer. It does FFT, linear transform, and Inverse FFT."""

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, self.modes3)
        # One scaled uniform-random complex weight per spectrum corner,
        # registered in the order weights1 .. weights4.
        for corner in ("weights1", "weights2", "weights3", "weights4"):
            setattr(self, corner, nn.Parameter(
                self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)))

    def compl_mul3d(self, input, weights):
        """Complex multiplication that contracts the input-channel axis.

        (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        """
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        """Filter ``x`` (batch, in_channels, X, Y, T) in Fourier space; returns (batch, out_channels, X, Y, T)."""
        size_x, size_y, size_t = x.size(-3), x.size(-2), x.size(-1)
        m1, m2, m3 = self.modes1, self.modes2, self.modes3

        # Compute Fourier coefficients over the three trailing axes.
        spectrum = torch.fft.rfftn(x, dim=[-3, -2, -1])

        filtered = torch.zeros(x.shape[0], self.out_channels, size_x, size_y,
                               size_t // 2 + 1, dtype=torch.cfloat, device=x.device)

        # Low/high corners along x and y, each paired with its own weights.
        corners = (
            (slice(None, m1), slice(None, m2), self.weights1),
            (slice(-m1, None), slice(None, m2), self.weights2),
            (slice(None, m1), slice(-m2, None), self.weights3),
            (slice(-m1, None), slice(-m2, None), self.weights4),
        )
        for along_x, along_y, corner_weights in corners:
            filtered[:, :, along_x, along_y, :m3] = self.compl_mul3d(
                spectrum[:, :, along_x, along_y, :m3], corner_weights)

        # Return to physical space on the original grid.
        return torch.fft.irfftn(filtered, s=(size_x, size_y, size_t))



class MWT_CZ(nn.Module):
    """One multiwavelet layer: decompose over scales, apply kernels, reconstruct.

    Args:
        k: Basis size; every spatial location carries ``k**2`` coefficients.
        alpha: Number of Fourier modes passed to the ``sparseKernelFT`` kernels.
        L: Number of coarsest levels that are skipped; decomposition stops at a
            ``2**L x 2**L`` grid.
        c: Channel multiplier.
        base: Filter family name forwarded to ``get_filter``.
        initializer: Optional callable applied to the ``T0`` weight.
    """
    def __init__(self,
                 k = 3, alpha = 5, 
                 L = 0, c = 1,
                 base = 'legendre',
                 initializer = None,
                 **kwargs):
        super(MWT_CZ, self).__init__()
        
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0@PHI0
        G0r = G0@PHI0
        H1r = H1@PHI1
        G1r = G1@PHI1
        
        # Zero out numerically negligible filter entries.
        H0r[np.abs(H0r)<1e-8]=0
        H1r[np.abs(H1r)<1e-8]=0
        G0r[np.abs(G0r)<1e-8]=0
        G1r[np.abs(G1r)<1e-8]=0
        
        self.A = sparseKernelFT(k, alpha, c)
        self.B = sparseKernelFT(k, alpha, c)
        self.C = sparseKernelFT(k, alpha, c)
        
        self.T0 = nn.Conv1d(c*k**2, c*k**2, 1)

        if initializer is not None:
            self.reset_parameters(initializer)

        # Decomposition filters: map the 4*k^2 stacked child coefficients to
        # k^2 scaling (ec_s) and k^2 detail (ec_d) coefficients.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((np.kron(H0, H0).T, 
                            np.kron(H0, H1).T,
                            np.kron(H1, H0).T,
                            np.kron(H1, H1).T,
                           ), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((np.kron(G0, G0).T,
                            np.kron(G0, G1).T,
                            np.kron(G1, G0).T,
                            np.kron(G1, G1).T,
                           ), axis=0)))
        
        # Reconstruction filters: map the 2*k^2 stacked (scaling, detail)
        # coefficients to each of the four child positions.
        self.register_buffer('rc_ee', torch.Tensor(
            np.concatenate((np.kron(H0r, H0r), 
                            np.kron(G0r, G0r),
                           ), axis=0)))
        self.register_buffer('rc_eo', torch.Tensor(
            np.concatenate((np.kron(H0r, H1r), 
                            np.kron(G0r, G1r),
                           ), axis=0)))
        self.register_buffer('rc_oe', torch.Tensor(
            np.concatenate((np.kron(H1r, H0r), 
                            np.kron(G1r, G0r),
                           ), axis=0)))
        self.register_buffer('rc_oo', torch.Tensor(
            np.concatenate((np.kron(H1r, H1r), 
                            np.kron(G1r, G1r),
                           ), axis=0)))
        
        
    def forward(self, x):
        """Apply the layer to ``x`` of shape (B, c, k^2, Nx, Ny, T).

        The coarsest-scale reshape uses a ``2**L x 2**L`` grid, so the input is
        expected to have ``Nx == Ny`` equal to a power of two.
        """
        
        B, c, ich, Nx, Ny, T = x.shape # (B, c, k^2, Nx, Ny, T)
        ns = math.floor(np.log2(Nx))

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

#         decompose
        for i in range(ns-self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x.reshape(B, c*ich, -1)).view(
            B, c, ich, 2**self.L, 2**self.L, T) # coarsest scale transform

#        reconstruct            
        for i in range(ns-1-self.L,-1,-1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), 2)
            x = self.evenOdd(x)

        return x

    
    def wavelet_transform(self, x):
        """Halve the (x, y) resolution; returns ``(detail, scaling)`` coefficients."""
        # Stack the four children of every 2x2 spatial cell along the coefficient axis.
        xa = torch.cat([x[:, :, :, ::2 , ::2 , :], 
                        x[:, :, :, ::2 , 1::2, :], 
                        x[:, :, :, 1::2, ::2 , :], 
                        x[:, :, :, 1::2, 1::2, :]
                       ], 2)
        waveFil = partial(torch.einsum, 'bcixyt,io->bcoxyt') 
        d = waveFil(xa, self.ec_d)
        s = waveFil(xa, self.ec_s)
        return d, s
        
        
    def evenOdd(self, x):
        """Double the (x, y) resolution from stacked (scaling, detail) coefficients.

        ``x`` has shape (B, c, 2*k^2, Nx, Ny, T); the result has shape
        (B, c, k^2, 2*Nx, 2*Ny, T) and the dtype of the filtered coefficients.
        """
        
        B, c, ich, Nx, Ny, T = x.shape # (B, c, 2*k^2, Nx, Ny)
        assert ich == 2*self.k**2
        evOd = partial(torch.einsum, 'bcixyt,io->bcoxyt')
        x_ee = evOd(x, self.rc_ee)
        x_eo = evOd(x, self.rc_eo)
        x_oe = evOd(x, self.rc_oe)
        x_oo = evOd(x, self.rc_oo)
        
        # Allocate in the coefficients' dtype so non-float32 inputs are not
        # silently downcast to the default dtype.
        x = torch.zeros(B, c, self.k**2, Nx*2, Ny*2, T,
            device = x.device, dtype = x_ee.dtype)
        # Interleave the four child grids back into the fine grid.
        x[:, :, :, ::2 , ::2 , :] = x_ee
        x[:, :, :, ::2 , 1::2, :] = x_eo
        x[:, :, :, 1::2, ::2 , :] = x_oe
        x[:, :, :, 1::2, 1::2, :] = x_oo
        return x
    
    def reset_parameters(self, initializer):
        """Apply ``initializer`` in place to the weight of the coarsest-scale map ``T0``."""
        initializer(self.T0.weight)
    
    
class MWT(nn.Module):
    """Multiwavelet operator network: lift, ``nCZ`` multiwavelet layers, project.

    Args:
        ich: Number of input features per grid point.
        k: Basis size; every channel carries ``k**2`` coefficients.
        alpha: Number of Fourier modes forwarded to each ``MWT_CZ`` layer.
        c: Channel multiplier; the hidden width is ``c*k**2``.
        nCZ: Number of stacked ``MWT_CZ`` layers.
        L: Coarsest-level offset forwarded to each ``MWT_CZ`` layer.
        base: Filter family name forwarded to each ``MWT_CZ`` layer.
        initializer: Optional callable applied to the projection weights and
            forwarded to each ``MWT_CZ`` layer.
    """
    def __init__(self,
                 ich = 1, k = 3, alpha = 2, c = 1,
                 nCZ = 3,
                 L = 0,
                 base = 'legendre',
                 initializer = None,
                 **kwargs):
        super(MWT,self).__init__()
        
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        # Lift the input features to c*k^2 hidden channels.
        self.Lk = nn.Linear(ich, c*k**2)
        
        self.MWT_CZ = nn.ModuleList(
            [MWT_CZ(k, alpha, L, c, base, 
            initializer) for _ in range(nCZ)]
        )
        self.BN = nn.ModuleList(
            [nn.BatchNorm3d(c*k**2) for _ in range(nCZ)]
        )
        # Two-layer pointwise projection back to one output value per grid point.
        self.Lc0 = nn.Linear(c*k**2, 128)
        self.Lc1 = nn.Linear(128, 1)
        
        if initializer is not None:
            self.reset_parameters(initializer)
        
    def forward(self, x):
        """Map ``x`` of shape (B, Nx, Ny, T, ich) to a prediction on the same grid.

        Returns:
            The output of the final linear layer with all size-1 dimensions
            squeezed: (B, Nx, Ny, T) in general. Because ``squeeze()`` is
            unrestricted, the batch axis is dropped as well when ``B == 1``.
        """
        
        B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d)
        # Use this module's own lifting layer rather than a global `model`.
        x = self.Lk(x)
        x = x.view(B, Nx, Ny, T, self.c, self.k**2)
        x = x.permute(0, 4, 5, 1, 2, 3)
    
        for i in range(self.nCZ):
            x = self.MWT_CZ[i](x)
            # BatchNorm3d expects (B, channels, Nx, Ny, T); restore the split afterwards.
            x = self.BN[i](x.view(B, -1, Nx, Ny, T)).view(
                B, self.c, self.k**2, Nx, Ny, T)
            # No activation after the last layer.
            if i < self.nCZ-1:
                x = F.relu(x)

        x = x.view(B, -1, Nx, Ny, T) # collapse c and k**2
        x = x.permute(0, 2, 3, 4, 1)
        x = self.Lc0(x)
        x = F.relu(x)
        x = self.Lc1(x)
        return x.squeeze()
    
    def reset_parameters(self, initializer):
        """Apply ``initializer`` in place to the weights of both projection layers."""
        initializer(self.Lc0.weight)
        initializer(self.Lc1.weight)

In [8]:
def makeVid(tensor, set, dirPath):
    """Render one sample of ``tensor`` frame by frame and assemble the frames into a video.

    Frames are written to ``<dirPath>/train<set>/<i>.png`` and the video to
    ``<dirPath>/videos/Vid<set>.mp4`` at 1 frame per second.

    Args:
        tensor: Indexable as ``tensor[set][:, :, i]``; the last dimension
            enumerates the frames.
        set: Index of the sample to render (the name shadows the builtin and is
            kept for caller compatibility).
        dirPath: Output root directory; created if missing.

    Raises:
        ValueError: If ``tensor`` has no frames along its last dimension.
    """
    
    newDir = 'train{}'.format(set)
    path = os.path.join(dirPath, newDir)
    os.makedirs(path, exist_ok=True)

    n_frames = tensor.shape[-1]
    if n_frames == 0:
        # Without frames the video size below would be undefined.
        raise ValueError('tensor has no frames along its last dimension')
    
    frame_paths = []
    for i in range(n_frames):
        f = tensor[set][:, :, i]
        fig = plt.figure()
        ax = fig.add_subplot(111)
        cax = ax.matshow(f, interpolation='nearest')
        cbar = fig.colorbar(cax)
        
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        cbar.set_label('Vorticity (RPM)', rotation=270)
        ax.set_title('Two-dimensional Vorticity')
        frame_path = path+'/{}.png'.format(i)
        fig.savefig(frame_path)
        # Close the figure so that long sequences do not accumulate open figures.
        plt.close(fig)
        frame_paths.append(frame_path)
    
    img_array = []
    
    # Read back exactly the frames written above, in frame order; a sorted glob
    # would order '10.png' before '2.png' and pick up stale files.
    for filename in frame_paths:
        img = cv2.imread(filename)
        print(filename)
        height, width, layers = img.shape
        size = (width,height)
        img_array.append(img)
    
    fps = 1
    vids = os.path.join(dirPath, 'videos')
    os.makedirs(vids, exist_ok=True)

    vidpath = vids+'/Vid{}.mp4'.format(set)
    out = cv2.VideoWriter(vidpath ,cv2.VideoWriter_fourcc(*'DIVX'), fps, size)

    try:
        for img in img_array:
            out.write(img)
    finally:
        # Always finalize the file, even if writing a frame fails.
        out.release()

In [9]:
# Path of the dataset file that is opened with h5py in the loading cell.
data_path = 'Data/ns_V1e-3_N5000_T50.mat'

# Number of samples taken from the start (train) and from the end (test) of the dataset.
ntrain = 1000
ntest = 200

# NOTE(review): this value is overwritten with 10 before the DataLoaders are built.
batch_size = 20

In [10]:
sub = 1        # spatial subsampling stride
S = 64 // sub  # grid points per spatial dimension after subsampling
T_in = 10      # number of leading time steps used as input
T = 40         # number of subsequent time steps to predict

# Open the file read-only and close it as soon as the slices are in memory.
# `dataloader` and `u_data` are only usable inside this block.
with h5py.File(data_path, 'r') as dataloader:
    u_data = dataloader['u']
    t_data = dataloader['t'][()]  # read into memory so it outlives the file handle

    # The slicing below treats 'u' as (time, x, y, sample); permute each
    # slice to (sample, x, y, time).
    train_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,:ntrain]
                ).permute(3, 1, 2, 0)
    train_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,:ntrain]
                ).permute(3, 1, 2, 0)

    test_a = torch.from_numpy(u_data[:T_in, ::sub,::sub,-ntest:]
                ).permute(3, 1, 2, 0)
    test_u = torch.from_numpy(u_data[T_in:T_in+T, ::sub,::sub,-ntest:]
                ).permute(3, 1, 2, 0)

print(train_a.shape)
print(train_u.shape)
print(test_a.shape)
print(test_u.shape)
# Sanity-check the spatial size and the prediction horizon.
assert(S == train_u.shape[-2])
assert(T == train_u.shape[-1])

torch.Size([1000, 64, 64, 10])
torch.Size([1000, 64, 64, 40])
torch.Size([200, 64, 64, 10])
torch.Size([200, 64, 64, 40])


In [11]:
# tns = train_a
# for set in range(tns.shape[0]):
#     makeVid(tns, set, 'animations/NS/ns_V1e-3_N5000_T50')

In [12]:
# Fit the input normalizer on the training inputs only and apply it to both splits.
a_normalizer = UnitGaussianNormalizer(train_a)
x_train0 = a_normalizer.encode(train_a)
x_test0 = a_normalizer.encode(test_a)


# Fit the target normalizer on the training targets. Only the training targets
# are encoded here; test_u is used unencoded by the test DataLoader.
y_normalizer = UnitGaussianNormalizer(train_u)
y_train = y_normalizer.encode(train_u)
print(y_train.shape)


print('x_train shape before = ', x_train0.shape)
print('x_test shape before = ', x_test0.shape)

# Insert a time axis of length 1 and repeat the T_in input steps along it,
# so that every one of the T output time steps sees the same input features.
x_train0 = x_train0.reshape(ntrain,S,S,1,T_in).repeat([1,1,1,T,1])
x_test0 = x_test0.reshape(ntest,S,S,1,T_in).repeat([1,1,1,T,1])

print('x_train shape after = ', x_train0.shape)
print('x_test shape after = ', x_test0.shape)

torch.Size([1000, 64, 64, 40])
x_train shape before =  torch.Size([1000, 64, 64, 10])
x_test shape before =  torch.Size([200, 64, 64, 10])
x_train shape after =  torch.Size([1000, 64, 64, 40, 10])
x_test shape after =  torch.Size([200, 64, 64, 40, 10])


In [13]:
# pad locations (x,y,t)
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])
print('gridx shape = ', gridx.shape)
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])
print('gridy shape = ', gridy.shape)
gridt = torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float)
gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])
print('gridt shape = ',gridt.shape)

print('x_train shape before = ', x_train0.shape)
x_train = torch.cat((gridx.repeat([ntrain,1,1,1,1]), gridy.repeat([ntrain,1,1,1,1]),
                       gridt.repeat([ntrain,1,1,1,1]), x_train0), dim=-1)
print('x_train shape after = ', x_train.shape)

print('x_test shape before = ', x_test0.shape)
x_test = torch.cat((gridx.repeat([ntest,1,1,1,1]), gridy.repeat([ntest,1,1,1,1]),
                       gridt.repeat([ntest,1,1,1,1]), x_test0), dim=-1)
print('x_test shape after = ', x_test.shape)

gridx shape =  torch.Size([1, 64, 64, 40, 1])
gridy shape =  torch.Size([1, 64, 64, 40, 1])
gridt shape =  torch.Size([1, 64, 64, 40, 1])
x_train shape before =  torch.Size([1000, 64, 64, 40, 10])
x_train shape after =  torch.Size([1000, 64, 64, 40, 13])
x_test shape before =  torch.Size([200, 64, 64, 40, 10])
x_test shape after =  torch.Size([200, 64, 64, 40, 13])


# Each dimension explained:

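$gridx$ $shape =  torch.Size([1, 64, 64, 40, 1])$: the leading $1$ is a singleton batch dimension (the grid is repeated $ntrain$ or $ntest$ times along it before concatenation), and the trailing $1$ is a singleton feature dimension, which is what allows the grid to be concatenated with the input features along the last axis. $64\times 64$ is the $x\times y$ grid and $40$ is the number of future timesteps used for prediction. The same holds for gridy and gridt.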

$gridy$ $shape =  torch.Size([1, 64, 64, 40, 1])$

$gridt$ $shape =  torch.Size([1, 64, 64, 40, 1])$

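$x\_train$ $shape$ $before =  torch.Size([1000, 64, 64, 40, 10])$: the last dimension holds the $10$ input timesteps. Each grid is repeated $ntrain = 1000$ times along the first (batch) dimension, then the three grids are concatenated with $x\_train$ along the last dimension, giving $3 + 10 = 13$ features.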


$x\_train$ $shape$ $after =  torch.Size([1000, 64, 64, 40, 13])$

$x\_test$ $shape$ $before =  torch.Size([200, 64, 64, 40, 10])$

$x\_test$ $shape$ $after =  torch.Size([200, 64, 64, 40, 13])$

In [14]:
# Overrides the batch size of 20 set in the configuration cell.
batch_size = 10
# Training pairs use the encoded targets (y_train); test pairs use the raw targets (test_u).
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, test_u), batch_size=batch_size, shuffle=False)

In [15]:
# for x, y in test_loader:
#     print(x.shape)
#     print(y.shape)

In [16]:
# Input features per grid point; matches the last dimension of x_train
# (3 grid coordinates + 10 input time steps).
ich = 13
initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform

# Re-seed immediately before building the model.
torch.manual_seed(0)
np.random.seed(0)

# Model hyperparameters, passed to MWT below.
alpha = 12
c = 4
k = 3
nCZ = 4
L = 0
model = MWT(ich, 
            alpha = alpha,
            c = c,
            k = k, 
            base = 'legendre', # chebyshev
            nCZ = nCZ,
            L = L,
            initializer = initializer,
            ).to(device)
learning_rate = 0.001

# Number of training epochs and StepLR schedule settings (step_size, gamma).
epochs = 2000
step_size = 100
gamma = 0.5

In [17]:
# Move the target normalizer to the GPU only when the model runs there, so the
# notebook keeps working with the CPU fallback chosen for `device`.
if device.type == 'cuda':
    y_normalizer.cuda()

# Adam with weight decay; the learning rate is scaled by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
myloss = LpLoss(size_average=False)

In [18]:
# PATH = 'NS_models/NS_model1600.pt'

# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

In [19]:
# Per-epoch loss histories, as returned by the train/test helpers.
train_loss = []
test_loss = []
for epoch in range(1, epochs+1):
    
    # Predictions are passed through y_normalizer.decode (post_proc) by the helpers.
    train_l2 = train(model, train_loader, optimizer, epoch, device,
        lossFn = myloss, lr_schedule = scheduler,
        post_proc = y_normalizer.decode)
    train_loss.append(train_l2)
    
    test_l2 = test(model, test_loader, device, 
        lossFn=myloss, 
        post_proc=y_normalizer.decode)
    test_loss.append(test_l2)
        
    # Checkpoint the model and dump the loss histories every 100 epochs.
    if epoch%100 == 0:
        PATH = 'NS_model{}.pt'.format(epoch)
        # Store the loss values, not the loss-function object: the checkpoint
        # then holds plain data and does not depend on the LpLoss class being
        # importable at load time.
        torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_l2,
                'test_loss': test_l2}, PATH)
        np.save('train_loss_no_experiment.npy', train_loss)
        np.save('test_loss_no_experiment.npy', test_loss)
    print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')

OutOfMemoryError: CUDA out of memory. Tried to allocate 60.00 MiB (GPU 0; 23.69 GiB total capacity; 1.83 GiB already allocated; 99.12 MiB free; 1.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Plot the per-epoch training and test loss histories on one set of axes.
plt.plot(train_loss, label = 'train loss')
plt.plot(test_loss, label = 'test loss')
plt.legend()
plt.show()

In [None]:
# PATH = 'NS_model2000.pt'
# torch.save({'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': myloss}, PATH)

In [None]:
# Predict every test sample one at a time and report its individual loss.
pred = torch.zeros(test_u.shape)
index = 0
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, test_u), batch_size=1, shuffle=False)
# Inference mode: BatchNorm must use its running statistics, not single-sample batch statistics.
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        # Follow the notebook-wide `device` instead of assuming CUDA.
        x, y = x.to(device), y.to(device)
        # Decode the output as the training/test helpers do via post_proc,
        # so it is comparable with the unencoded targets in test_u.
        out = y_normalizer.decode(model(x))
        pred[index] = out

        test_l2 = myloss(out.reshape(1, -1), y.reshape(1, -1)).item()
        print(index, test_l2)
        index = index + 1

In [None]:
# index = 0
# for x, y in test_loader:
    # Shapes is 1024 inputs (x) and 1024 outputs (y)
#     print(x[0].shape)
#     print(y[0].shape)
    # plt.matshow(x[0][:, :, 0])
    #plt.matshow(y[0][:, :, 0])

    # Create plots with pre-defined labels.    
#     fig, ax = plt.subplots()
#     fig.set_figwidth(14)
#     fig.set_figheight(8)
#     ax.plot(x[0], label='Initial conditions $u(x,t)$ where $t \in$ [0,1024]')
#     ax.plot(pred[index], label='Predicted evolution $(t \in [1025,2048])$')
#     ax.plot(y[0], 'k--', label='Real evolution $(t \in [1025,2048])$')
#     index += 1
#     legend = ax.legend(loc='lower right', shadow=True)
#     ax.set_xlabel("Initial condition index")
#     ax.set_ylabel("Initial condition value u(x,t)")
#     ax.set_title("Time evolution of initial conditions")
#     ax.grid()
    
    # Put a nicer background color on the legend.
    # legend.get_frame().set_facecolor('C0')
    # plt.savefig('/content/drive/MyDrive/IIB project/mwt/ip-op/fig{}.png'.format(index))
    # files.download("/content/drive/MyDrive/IIB project/mwt/ip-op/fig{}.png".format(index)) 
    # plt.clf()
    # plt.show()

In [None]:
# Print per-sample shapes of the inputs, the feature-averaged inputs, the targets and the predictions.
print(x_test[0].shape)
# Average over the last (feature) axis; note that this rebinds the name `x`.
x = x_test.mean(4)
print(x[0].shape)
print(test_u[0].shape)
print(pred[0].shape)

In [None]:
def plot_loss(prediction, test, index, timestep):
    """Print and plot ground truth, prediction and their difference for one sample and time step.

    Args:
        prediction: Predictions, indexable as ``prediction[index][:, :, timestep]``.
        test: Ground truth with the same layout as ``prediction``.
        index: Sample index.
        timestep: Index along the last (time) axis.
    """
    test = test[index][:, :, timestep]
    # Keep the prediction in the same orientation as the ground truth; with only
    # one of the two transposed, the difference would compare mismatched points.
    prediction = prediction[index][:, :, timestep]
    loss = torch.sub(test, prediction)
    
    print(test)
    print(prediction)
    print(loss)
    
    plt.matshow(test)
    plt.title('Ground truth')
    plt.matshow(prediction)
    plt.title('Prediction')
    plt.matshow(loss)
    plt.title('Ground truth - prediction')

In [None]:
# Compare prediction and ground truth for test sample 10 at time index 10.
plot_loss(pred, test_u, 10, 10)