In [1]:
import os
# Cap the native thread pools used by OpenMP and OpenBLAS. These are set in
# the first cell, before NumPy / PyTorch are imported in the next one.
os.environ['OMP_NUM_THREADS'] = '2'
# The key must be the bare variable name: including the shell keyword
# `export` creates a variable literally called 'export OPENBLAS_NUM_THREADS',
# which OpenBLAS never reads.
os.environ['OPENBLAS_NUM_THREADS'] = '2'

In [2]:
import torch
import torch.nn as nn
import torch.fft
import numpy as np
from scipy.io import loadmat, savemat
import math
import os
import h5py
import matplotlib.pyplot as plt
from functools import partial
from models.models import MWT2d
from models.utils import train, test, LpLoss, get_filter, UnitGaussianNormalizer

In [3]:
# Record the PyTorch build used for this run (the saved output shows 1.13.1+cu117).
print(torch.__version__)

1.13.1+cu117


In [4]:
# Seed the global PyTorch and NumPy random number generators.
torch.manual_seed(0)
np.random.seed(0)

In [5]:
# Restrict this process to GPU 0.
# NOTE(review): torch is already imported at this point; this presumably still
# takes effect because no CUDA call has been made yet — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
def get_initializer(name):
    """Return the ``torch.nn.init`` weight initializer registered under ``name``.

    Args:
        name: One of ``'xavier_normal'``, ``'kaiming_uniform'`` or
            ``'kaiming_normal'``.

    Returns:
        A ``functools.partial`` wrapping the matching in-place ``nn.init``
        function. No arguments are pre-bound, so it is called exactly like
        the underlying initializer.

    Raises:
        ValueError: If ``name`` is not a supported initializer. (Previously an
            unknown name fell through every branch and surfaced as an
            ``UnboundLocalError`` on the ``return``.)
    """
    initializers = {
        'xavier_normal': nn.init.xavier_normal_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
    }
    if name not in initializers:
        raise ValueError(
            'Unknown initializer {!r}; expected one of {}'.format(
                name, sorted(initializers)))
    return partial(initializers[name])

In [8]:
data_path = 'Data/temps_train.npy'
ntrain = 3800
ntest = 616

# Spatial subsampling stride applied to both grid axes (1 keeps every point).
r = 1

dataloader = np.load(data_path)
print(dataloader.shape)
u_data = dataloader.astype(np.float32)

# Training samples are taken from the front of the array and test samples from
# the back, so the two splits must not be allowed to overlap.
if ntrain + ntest > u_data.shape[0]:
    raise ValueError(
        'ntrain + ntest = {} exceeds the {} available samples; '
        'the train and test splits would overlap'.format(ntrain + ntest, u_data.shape[0]))

# Points per side after subsampling, derived from the loaded array instead of
# a hard-coded 64 so the cell keeps working if the data resolution changes.
h = int(((u_data.shape[1] - 1)/r) + 1)
s = h

# Channel 0 of the last axis is used as the model input, channel 1 as the target.
x_train = torch.from_numpy(u_data[:ntrain, ::r,::r, 0])
y_train = torch.from_numpy(u_data[:ntrain, ::r,::r, 1])

x_test = torch.from_numpy(u_data[-ntest:, ::r,::r, 0])
y_test = torch.from_numpy(u_data[-ntest:, ::r,::r, 1])

(4416, 64, 64, 2)


In [9]:
# Build the input normalizer from the training inputs and encode both splits
# with it (presumably a per-point standardisation — see UnitGaussianNormalizer).
x_normalizer = UnitGaussianNormalizer(x_train)
x_train, x_test = x_normalizer.encode(x_train), x_normalizer.encode(x_test)

# Only the training targets are encoded; y_test is left exactly as loaded.
y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

# Coordinate features: a (1, s, s, 2) tensor holding the two meshgrid
# coordinates of the unit square at every grid point.
grids = [np.linspace(0, 1, s) for _ in range(2)]
grid = np.stack(np.meshgrid(*grids), axis=-1).reshape(1, s, s, 2)
grid = torch.tensor(grid, dtype=torch.float)


def _append_grid(field, count):
    """Concatenate the coordinate grid to ``field`` as two extra trailing channels."""
    tiled_grid = grid.repeat(count, 1, 1, 1)
    return torch.cat([field.reshape(count, s, s, 1), tiled_grid], dim=3)


x_train = _append_grid(x_train, ntrain)
x_test = _append_grid(x_test, ntest)

In [10]:
# # data_path = 'Data/melt64/melting.npy'
# data_path = 'Data/melt64/melting.npy'
# dataloader = np.load(data_path)

# ntrain = 170
# ntest = 30

# r = 1
# h = int(((64 - 1)/r) + 1)
# s = h

In [11]:
batch_size = 20
# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling, and keeping dataset order means
# any per-sample output collected from this loader lines up with y_test.
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

In [12]:
# Optimisation settings consumed by the optimizer / scheduler / training cells.
learning_rate = 0.001
epochs = 5000
step_size = 100
gamma = 0.5

# Three input channels, matching the last dimension built for x_train
# (one field channel plus the two grid-coordinate channels).
ich = 3
initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform

# Re-seed right before the model is constructed.
np.random.seed(0)
torch.manual_seed(0)

# Keyword arguments forwarded unchanged to MWT2d; their meaning is defined in
# models.models (not visible from this notebook).
mwt_kwargs = dict(
    alpha=12,
    c=4,
    k=4,
    base='legendre',  # 'chebyshev'
    nCZ=4,
    L=0,
    initializer=initializer,
)
model = MWT2d(ich, **mwt_kwargs).to(device)

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# Multiply the learning rate by `gamma` after every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)
# `device` falls back to the CPU when CUDA is unavailable, and an unconditional
# .cuda() call would raise in that case. Only move the target normalizer to the
# GPU when the model is actually there.
if device.type == 'cuda':
    y_normalizer.cuda()

In [14]:
train_loss = []
test_loss = []

# Create the output directories up front: otherwise the first save at epoch 100
# fails with a missing-directory error after 100 epochs of training.
os.makedirs('NS_models/temps_new', exist_ok=True)
os.makedirs('visual', exist_ok=True)

# NOTE(review): in the saved output both losses stop changing after roughly
# epoch 2000 — presumably StepLR has shrunk the learning rate to a negligible
# value by then, so `epochs = 5000` is likely more than this schedule can use.
for epoch in range(1, epochs+1):
    # One pass over the training set; predictions are decoded with
    # y_normalizer.decode (passed as post_proc) before the loss is computed.
    train_l2 = train(model, train_loader, optimizer, epoch, device,
        lossFn = myloss, lr_schedule = scheduler,
        post_proc = y_normalizer.decode)
    train_loss.append(train_l2)
    test_l2 = test(model, test_loader, device, lossFn=myloss, post_proc=y_normalizer.decode)
    print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')
    test_loss.append(test_l2)

    if epoch%100 == 0:
        PATH = 'NS_models/temps_new/temps_new{}.pt'.format(epoch)
        # 'loss' holds the latest training loss value (it previously stored the
        # loss-function object itself). The scheduler state is saved as well so
        # a resumed run continues with the same learning-rate schedule.
        torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': train_l2,
                'test_loss': test_l2}, PATH)

        # Persist the full loss history alongside each checkpoint.
        np.save('visual/train_loss_temps_much_data.npy', train_loss)
        np.save('visual/test_loss_temps_much_data.npy', test_loss)

epoch: 1, train l2 = 0.14023416360742166, test l2 = 0.08355708768615475
epoch: 2, train l2 = 0.07588366008118579, test l2 = 0.08258027438219491
epoch: 3, train l2 = 0.06464024584544333, test l2 = 0.07354230617547965
epoch: 4, train l2 = 0.06070535225303549, test l2 = 0.05038590119643645
epoch: 5, train l2 = 0.056148563905766134, test l2 = 0.061369071436392794
epoch: 6, train l2 = 0.05577198911654322, test l2 = 0.046795936277160395
epoch: 7, train l2 = 0.053693953579977936, test l2 = 0.0494527192665385
epoch: 8, train l2 = 0.049702709254465606, test l2 = 0.044418110766194084
epoch: 9, train l2 = 0.051230339737314924, test l2 = 0.04260125698207261
epoch: 10, train l2 = 0.05354311256032241, test l2 = 0.04556868273716468
epoch: 11, train l2 = 0.05150753734927428, test l2 = 0.04985572088074375
epoch: 12, train l2 = 0.04855660590686296, test l2 = 0.04056264827777813
epoch: 13, train l2 = 0.04494076333547893, test l2 = 0.04137165631566729
epoch: 14, train l2 = 0.0461248777414623, test l2 = 0.

epoch: 113, train l2 = 0.037246450891620236, test l2 = 0.03848061942821973
epoch: 114, train l2 = 0.03815368242953953, test l2 = 0.037260713660484786
epoch: 115, train l2 = 0.03734507581905315, test l2 = 0.04068565146102534
epoch: 116, train l2 = 0.03704279485501741, test l2 = 0.038250515406781975
epoch: 117, train l2 = 0.036991114506596014, test l2 = 0.03708292031055921
epoch: 118, train l2 = 0.03705451263408912, test l2 = 0.036839434175522295
epoch: 119, train l2 = 0.03722797536536267, test l2 = 0.03809897712879367
epoch: 120, train l2 = 0.03759795093222668, test l2 = 0.037386746665874084
epoch: 121, train l2 = 0.03720612743967458, test l2 = 0.03881070250040525
epoch: 122, train l2 = 0.03857629087410475, test l2 = 0.03857764563003144
epoch: 123, train l2 = 0.03796795824640676, test l2 = 0.03856435451995243
epoch: 124, train l2 = 0.03720957622716301, test l2 = 0.03750046400667785
epoch: 125, train l2 = 0.0382795049799116, test l2 = 0.04463109519187506
epoch: 126, train l2 = 0.03773752

epoch: 223, train l2 = 0.035972677265342916, test l2 = 0.03701653563744062
epoch: 224, train l2 = 0.036101119204571375, test l2 = 0.03686819380366957
epoch: 225, train l2 = 0.036088741531497555, test l2 = 0.03848907728860905
epoch: 226, train l2 = 0.0360955617537624, test l2 = 0.036863112217420106
epoch: 227, train l2 = 0.03580434794488706, test l2 = 0.037162836979735984
epoch: 228, train l2 = 0.03611980177854237, test l2 = 0.03740678505076991
epoch: 229, train l2 = 0.03605048534117247, test l2 = 0.03670611503449353
epoch: 230, train l2 = 0.03597095040114302, test l2 = 0.03777651227526851
epoch: 231, train l2 = 0.036600459602318315, test l2 = 0.04171906556788977
epoch: 232, train l2 = 0.03680959782318065, test l2 = 0.03798094382146736
epoch: 233, train l2 = 0.036143892893665715, test l2 = 0.03723554371239303
epoch: 234, train l2 = 0.035933170977391696, test l2 = 0.03716547105025936
epoch: 235, train l2 = 0.03582658292431581, test l2 = 0.036599120052603934
epoch: 236, train l2 = 0.03611

epoch: 333, train l2 = 0.03532417889488371, test l2 = 0.036339994955372504
epoch: 334, train l2 = 0.035412205183192304, test l2 = 0.037627109749750656
epoch: 335, train l2 = 0.03539281498444708, test l2 = 0.036454328275346136
epoch: 336, train l2 = 0.03578735895062748, test l2 = 0.03665118064586218
epoch: 337, train l2 = 0.03549812782751886, test l2 = 0.03635101184829489
epoch: 338, train l2 = 0.03518662568769957, test l2 = 0.036515504121780396
epoch: 339, train l2 = 0.03531960056800591, test l2 = 0.037301430164219496
epoch: 340, train l2 = 0.03552633355322637, test l2 = 0.037539666155716044
epoch: 341, train l2 = 0.03532435137190317, test l2 = 0.03667046400633725
epoch: 342, train l2 = 0.03528787040396741, test l2 = 0.03647717852871139
epoch: 343, train l2 = 0.0357334099396279, test l2 = 0.03651296705394596
epoch: 344, train l2 = 0.035812527262850814, test l2 = 0.0368668031576392
epoch: 345, train l2 = 0.035530934631824496, test l2 = 0.037555757764872015
epoch: 346, train l2 = 0.03535

epoch: 444, train l2 = 0.035065390781352394, test l2 = 0.03655675111652969
epoch: 445, train l2 = 0.035108587749694525, test l2 = 0.03700052559762806
epoch: 446, train l2 = 0.035092600952637824, test l2 = 0.03677010332996195
epoch: 447, train l2 = 0.03518327800851119, test l2 = 0.03697612717167124
epoch: 448, train l2 = 0.03513821691274643, test l2 = 0.036348885384859975
epoch: 449, train l2 = 0.03506395138408008, test l2 = 0.03663142131907599
epoch: 450, train l2 = 0.03509721979498863, test l2 = 0.036361299406785466
epoch: 451, train l2 = 0.03509780397540645, test l2 = 0.0367640746007492
epoch: 452, train l2 = 0.0350877557694912, test l2 = 0.03637914224104448
epoch: 453, train l2 = 0.03510356501529091, test l2 = 0.03641752347156599
epoch: 454, train l2 = 0.03505546607469258, test l2 = 0.03646830943497745
epoch: 455, train l2 = 0.03502698345403922, test l2 = 0.03635831511059365
epoch: 456, train l2 = 0.03511263346201495, test l2 = 0.037390356714075264
epoch: 457, train l2 = 0.035216968

epoch: 555, train l2 = 0.03491106091361297, test l2 = 0.03645701038760024
epoch: 556, train l2 = 0.03501404061129219, test l2 = 0.0363465001830807
epoch: 557, train l2 = 0.03487793695769812, test l2 = 0.036529295339986875
epoch: 558, train l2 = 0.03493584572484619, test l2 = 0.03632769007961471
epoch: 559, train l2 = 0.0349077707058505, test l2 = 0.036510028912649525
epoch: 560, train l2 = 0.03495706060999318, test l2 = 0.036347153608675124
epoch: 561, train l2 = 0.03489306043637426, test l2 = 0.03646196386256775
epoch: 562, train l2 = 0.03489437099350126, test l2 = 0.03671237255458708
epoch: 563, train l2 = 0.03493736288265178, test l2 = 0.0365771208103601
epoch: 564, train l2 = 0.03494784628874377, test l2 = 0.036317807416637225
epoch: 565, train l2 = 0.03489288865735656, test l2 = 0.036337013271721924
epoch: 566, train l2 = 0.03489892902342897, test l2 = 0.03648730333555828
epoch: 567, train l2 = 0.03493775632820632, test l2 = 0.03631818400962012
epoch: 568, train l2 = 0.03493918573

epoch: 666, train l2 = 0.034815315761064225, test l2 = 0.03639931023701445
epoch: 667, train l2 = 0.0348327851687607, test l2 = 0.0363342859721803
epoch: 668, train l2 = 0.03481201022863388, test l2 = 0.03635597325764693
epoch: 669, train l2 = 0.034830785487827504, test l2 = 0.036503116805832105
epoch: 670, train l2 = 0.03482092171907425, test l2 = 0.03629899179780638
epoch: 671, train l2 = 0.034804636938007255, test l2 = 0.036326169145184675
epoch: 672, train l2 = 0.03478221734103404, test l2 = 0.03637721645948175
epoch: 673, train l2 = 0.034839037241120085, test l2 = 0.0363445949148048
epoch: 674, train l2 = 0.03482565428081311, test l2 = 0.036472681377615244
epoch: 675, train l2 = 0.03480521288357283, test l2 = 0.03633751923387701
epoch: 676, train l2 = 0.034827433040267544, test l2 = 0.03641131225150901
epoch: 677, train l2 = 0.03481339337794404, test l2 = 0.03629323820789139
epoch: 678, train l2 = 0.03480309680104256, test l2 = 0.03630319590885918
epoch: 679, train l2 = 0.03482803

epoch: 777, train l2 = 0.0347612842290025, test l2 = 0.03636500790908739
epoch: 778, train l2 = 0.03479076801946289, test l2 = 0.03635198318145492
epoch: 779, train l2 = 0.034768380248232894, test l2 = 0.03637535289510504
epoch: 780, train l2 = 0.03477962565265204, test l2 = 0.03644695045886102
epoch: 781, train l2 = 0.034774527565429085, test l2 = 0.03636857344732656
epoch: 782, train l2 = 0.03477458783670476, test l2 = 0.03638715684026867
epoch: 783, train l2 = 0.0348023626992577, test l2 = 0.03630870358123408
epoch: 784, train l2 = 0.0347666471255453, test l2 = 0.03643516567233321
epoch: 785, train l2 = 0.03479730979392403, test l2 = 0.03629684148283748
epoch: 786, train l2 = 0.03476703297150763, test l2 = 0.03632441172739128
epoch: 787, train l2 = 0.03482977983198668, test l2 = 0.03630169113348057
epoch: 788, train l2 = 0.034744368117106586, test l2 = 0.03630895119208794
epoch: 789, train l2 = 0.03476414967524378, test l2 = 0.03629290031922328
epoch: 790, train l2 = 0.0347749641616

epoch: 887, train l2 = 0.03474246457219124, test l2 = 0.03635670273721992
epoch: 888, train l2 = 0.034738892450144415, test l2 = 0.036334487911942714
epoch: 889, train l2 = 0.03474849538583505, test l2 = 0.036333874545314095
epoch: 890, train l2 = 0.034736804099459395, test l2 = 0.03641122008685942
epoch: 891, train l2 = 0.034745664486759587, test l2 = 0.03632113637475224
epoch: 892, train l2 = 0.03474219329263035, test l2 = 0.03634300871522396
epoch: 893, train l2 = 0.03474805838967625, test l2 = 0.03639621800416476
epoch: 894, train l2 = 0.034748526207710566, test l2 = 0.036353960923560254
epoch: 895, train l2 = 0.03475319821583597, test l2 = 0.0363557885413046
epoch: 896, train l2 = 0.03474489077925682, test l2 = 0.036354754749056584
epoch: 897, train l2 = 0.03476281802905233, test l2 = 0.036375127055428245
epoch: 898, train l2 = 0.03475480178469106, test l2 = 0.03633929890665141
epoch: 899, train l2 = 0.034744555048252405, test l2 = 0.03635507725275956
epoch: 900, train l2 = 0.0347

epoch: 998, train l2 = 0.03473123795892063, test l2 = 0.03633733611408766
epoch: 999, train l2 = 0.03473077481514529, test l2 = 0.036341532435897106
epoch: 1000, train l2 = 0.03473559011754237, test l2 = 0.03635366367442267
epoch: 1001, train l2 = 0.034723680599739674, test l2 = 0.036345361066716056
epoch: 1002, train l2 = 0.03472582533171303, test l2 = 0.03634691393220579
epoch: 1003, train l2 = 0.03472689131372853, test l2 = 0.03634064064010397
epoch: 1004, train l2 = 0.034723546724570425, test l2 = 0.0363352932519727
epoch: 1005, train l2 = 0.034722841774162495, test l2 = 0.036327550937603044
epoch: 1006, train l2 = 0.03472783267498016, test l2 = 0.03633249231747219
epoch: 1007, train l2 = 0.03472678437044746, test l2 = 0.03633637798877506
epoch: 1008, train l2 = 0.03472660560356943, test l2 = 0.036347527831018744
epoch: 1009, train l2 = 0.034723055323487836, test l2 = 0.036335427894607766
epoch: 1010, train l2 = 0.034724916745173307, test l2 = 0.03635562046781763
epoch: 1011, train

epoch: 1107, train l2 = 0.03472134792491009, test l2 = 0.03634854957654879
epoch: 1108, train l2 = 0.03471963098958919, test l2 = 0.036344646246402296
epoch: 1109, train l2 = 0.03472081214189529, test l2 = 0.036349314422189416
epoch: 1110, train l2 = 0.03471972696875271, test l2 = 0.0363489988368827
epoch: 1111, train l2 = 0.03472195424531636, test l2 = 0.03635476413485292
epoch: 1112, train l2 = 0.03472154760831281, test l2 = 0.036346286148219914
epoch: 1113, train l2 = 0.03472057342529297, test l2 = 0.036349272476388264
epoch: 1114, train l2 = 0.03472045312586584, test l2 = 0.0363447806471354
epoch: 1115, train l2 = 0.03472060001994434, test l2 = 0.036357073814837965
epoch: 1116, train l2 = 0.0347202468388959, test l2 = 0.0363466932692311
epoch: 1117, train l2 = 0.03471906803940472, test l2 = 0.03634717915352289
epoch: 1118, train l2 = 0.034719182833244926, test l2 = 0.03634409500019891
epoch: 1119, train l2 = 0.03472132527514508, test l2 = 0.036348542658152516
epoch: 1120, train l2 

epoch: 1216, train l2 = 0.03471800344555002, test l2 = 0.03634742463563944
epoch: 1217, train l2 = 0.03471851365346658, test l2 = 0.03634650356970824
epoch: 1218, train l2 = 0.034717845752051, test l2 = 0.036348995063212014
epoch: 1219, train l2 = 0.03471947924086922, test l2 = 0.03634884537427456
epoch: 1220, train l2 = 0.03471804267481754, test l2 = 0.03634877735144132
epoch: 1221, train l2 = 0.03471792641438936, test l2 = 0.036348582184933996
epoch: 1222, train l2 = 0.03471825238121183, test l2 = 0.036347978687905645
epoch: 1223, train l2 = 0.034717872040836435, test l2 = 0.03634901547973806
epoch: 1224, train l2 = 0.03471792894758676, test l2 = 0.036348908655829244
epoch: 1225, train l2 = 0.03471836965335043, test l2 = 0.036348940199845796
epoch: 1226, train l2 = 0.034717822561138555, test l2 = 0.036348501244535696
epoch: 1227, train l2 = 0.03471889786030117, test l2 = 0.03634903986345638
epoch: 1228, train l2 = 0.03471952729319271, test l2 = 0.03634780253489296
epoch: 1229, train 

epoch: 1325, train l2 = 0.034717127782733816, test l2 = 0.03634797704297227
epoch: 1326, train l2 = 0.03471675011672472, test l2 = 0.03634759299940877
epoch: 1327, train l2 = 0.034717075181634804, test l2 = 0.03634837690692443
epoch: 1328, train l2 = 0.03471701062039325, test l2 = 0.03634845417041283
epoch: 1329, train l2 = 0.034717470964318826, test l2 = 0.03634873463155387
epoch: 1330, train l2 = 0.034717427311759246, test l2 = 0.036347448438793034
epoch: 1331, train l2 = 0.03471722686761304, test l2 = 0.03634704214024853
epoch: 1332, train l2 = 0.03471704472836695, test l2 = 0.036347389898516914
epoch: 1333, train l2 = 0.03471711930475737, test l2 = 0.036348152857322194
epoch: 1334, train l2 = 0.03471732930917489, test l2 = 0.036347261109909455
epoch: 1335, train l2 = 0.03471684026874994, test l2 = 0.03634818749768393
epoch: 1336, train l2 = 0.034716944223956055, test l2 = 0.03634837671340286
epoch: 1337, train l2 = 0.03471704538715513, test l2 = 0.0363483533940532
epoch: 1338, trai

epoch: 1434, train l2 = 0.03471638402656505, test l2 = 0.03634838440588543
epoch: 1435, train l2 = 0.0347162709816506, test l2 = 0.03634821410690035
epoch: 1436, train l2 = 0.03471622700753965, test l2 = 0.03634847511912321
epoch: 1437, train l2 = 0.034716386661717766, test l2 = 0.03634877367453142
epoch: 1438, train l2 = 0.034716177974876604, test l2 = 0.03634858334606344
epoch: 1439, train l2 = 0.03471620813012123, test l2 = 0.03634876332112721
epoch: 1440, train l2 = 0.034716175543634514, test l2 = 0.03634837768101073
epoch: 1441, train l2 = 0.034716315292998366, test l2 = 0.036348771303892136
epoch: 1442, train l2 = 0.03471661660232042, test l2 = 0.03634810302551691
epoch: 1443, train l2 = 0.03471636350217618, test l2 = 0.036348423110200215
epoch: 1444, train l2 = 0.03471639326528499, test l2 = 0.03634865098185353
epoch: 1445, train l2 = 0.03471639593965129, test l2 = 0.03634851116251636
epoch: 1446, train l2 = 0.034716376403444694, test l2 = 0.036348529450305096
epoch: 1447, train

epoch: 1543, train l2 = 0.03471591969069682, test l2 = 0.036348483198648925
epoch: 1544, train l2 = 0.03471599502783072, test l2 = 0.03634860724597782
epoch: 1545, train l2 = 0.0347159464343598, test l2 = 0.036348494616421784
epoch: 1546, train l2 = 0.03471591908680765, test l2 = 0.03634853457862681
epoch: 1547, train l2 = 0.03471602772411547, test l2 = 0.0363483763747401
epoch: 1548, train l2 = 0.034716006438983114, test l2 = 0.036348391904846414
epoch: 1549, train l2 = 0.034715896891920194, test l2 = 0.03634867333359532
epoch: 1550, train l2 = 0.03471588240642297, test l2 = 0.036348433850647566
epoch: 1551, train l2 = 0.034715912098947324, test l2 = 0.03634855228585082
epoch: 1552, train l2 = 0.034716078795884786, test l2 = 0.03634817656371501
epoch: 1553, train l2 = 0.03471595735926377, test l2 = 0.036348692879274294
epoch: 1554, train l2 = 0.0347159948239201, test l2 = 0.03634866365751663
epoch: 1555, train l2 = 0.034715874634290995, test l2 = 0.036348433511984815
epoch: 1556, trai

epoch: 1652, train l2 = 0.034715733418339174, test l2 = 0.0363482627775762
epoch: 1653, train l2 = 0.034715760083575, test l2 = 0.03634828812890239
epoch: 1654, train l2 = 0.034715777753215084, test l2 = 0.03634834637889614
epoch: 1655, train l2 = 0.034715800551991714, test l2 = 0.03634839417872491
epoch: 1656, train l2 = 0.03471572620304007, test l2 = 0.03634843085106317
epoch: 1657, train l2 = 0.034715778035552876, test l2 = 0.03634845610562857
epoch: 1658, train l2 = 0.03471576197366966, test l2 = 0.036348441204467376
epoch: 1659, train l2 = 0.03471583444036935, test l2 = 0.036348305449083254
epoch: 1660, train l2 = 0.034715749684133025, test l2 = 0.03634840046817606
epoch: 1661, train l2 = 0.03471573837493595, test l2 = 0.036348375649034204
epoch: 1662, train l2 = 0.03471576498527276, test l2 = 0.03634845063864411
epoch: 1663, train l2 = 0.03471576938503667, test l2 = 0.03634836089301419
epoch: 1664, train l2 = 0.03471583605596894, test l2 = 0.03634864266042585
epoch: 1665, train l

epoch: 1761, train l2 = 0.03471569814964345, test l2 = 0.03634894261886547
epoch: 1762, train l2 = 0.03471568054274509, test l2 = 0.0363489892091844
epoch: 1763, train l2 = 0.03471569323226025, test l2 = 0.0363489899348903
epoch: 1764, train l2 = 0.03471569582819939, test l2 = 0.036348992934474696
epoch: 1765, train l2 = 0.03471568695808712, test l2 = 0.03634898533875292
epoch: 1766, train l2 = 0.034715684056282045, test l2 = 0.03634898296811364
epoch: 1767, train l2 = 0.034715682056389356, test l2 = 0.03634895035972843
epoch: 1768, train l2 = 0.03471568668359205, test l2 = 0.03634900575527897
epoch: 1769, train l2 = 0.03471569227544885, test l2 = 0.036348992353909974
epoch: 1770, train l2 = 0.03471569060495025, test l2 = 0.03634894136097524
epoch: 1771, train l2 = 0.03471569114609768, test l2 = 0.03634890333398596
epoch: 1772, train l2 = 0.0347157037650284, test l2 = 0.03634895771354824
epoch: 1773, train l2 = 0.034715682472053325, test l2 = 0.036348944747602786
epoch: 1774, train l2 

epoch: 1870, train l2 = 0.03471565288148428, test l2 = 0.03634899806279641
epoch: 1871, train l2 = 0.03471565860666727, test l2 = 0.036349013738043896
epoch: 1872, train l2 = 0.03471565750084425, test l2 = 0.036349013689663506
epoch: 1873, train l2 = 0.034715659681119414, test l2 = 0.03634900783563589
epoch: 1874, train l2 = 0.0347156554225244, test l2 = 0.03634899583729831
epoch: 1875, train l2 = 0.03471566247312646, test l2 = 0.03634898688692551
epoch: 1876, train l2 = 0.0347156555636933, test l2 = 0.03634898790291378
epoch: 1877, train l2 = 0.03471566070067255, test l2 = 0.036348993176376666
epoch: 1878, train l2 = 0.03471565529704094, test l2 = 0.03634898780615299
epoch: 1879, train l2 = 0.03471565921055643, test l2 = 0.03634898166184301
epoch: 1880, train l2 = 0.03471566103790936, test l2 = 0.036348974211262416
epoch: 1881, train l2 = 0.034715663610320344, test l2 = 0.036348981516701837
epoch: 1882, train l2 = 0.034715659265455445, test l2 = 0.036348971550340774
epoch: 1883, train

epoch: 1979, train l2 = 0.03471565148548076, test l2 = 0.03634895161761866
epoch: 1980, train l2 = 0.034715652136426224, test l2 = 0.03634895331093243
epoch: 1981, train l2 = 0.03471565354811518, test l2 = 0.03634895466558345
epoch: 1982, train l2 = 0.03471565418337521, test l2 = 0.036348954326920695
epoch: 1983, train l2 = 0.03471565165802052, test l2 = 0.03634895752002666
epoch: 1984, train l2 = 0.03471565302265318, test l2 = 0.03634895423015991
epoch: 1985, train l2 = 0.034715652748158105, test l2 = 0.036348951230575514
epoch: 1986, train l2 = 0.03471565194035831, test l2 = 0.03634895316579125
epoch: 1987, train l2 = 0.034715652654045505, test l2 = 0.03634895437530109
epoch: 1988, train l2 = 0.03471565164233509, test l2 = 0.036348959164960046
epoch: 1989, train l2 = 0.0347156507404227, test l2 = 0.0363489624548268
epoch: 1990, train l2 = 0.03471565125019927, test l2 = 0.03634895423015991
epoch: 1991, train l2 = 0.03471565158743607, test l2 = 0.03634895108543433
epoch: 1992, train l2

epoch: 2088, train l2 = 0.0347156496424424, test l2 = 0.036348958342493354
epoch: 2089, train l2 = 0.034715649211093, test l2 = 0.03634896429328175
epoch: 2090, train l2 = 0.03471564879542903, test l2 = 0.036348956455658006
epoch: 2091, train l2 = 0.034715647438639086, test l2 = 0.036348960180948305
epoch: 2092, train l2 = 0.034715649195407566, test l2 = 0.036348958342493354
epoch: 2093, train l2 = 0.034715649195407566, test l2 = 0.0363489580038306
epoch: 2094, train l2 = 0.034715647791561326, test l2 = 0.03634895776192863
epoch: 2095, train l2 = 0.034715649344419175, test l2 = 0.03634895423015991
epoch: 2096, train l2 = 0.03471564887385619, test l2 = 0.03634895669755998
epoch: 2097, train l2 = 0.034715649375790046, test l2 = 0.03634896110017578
epoch: 2098, train l2 = 0.03471564769744873, test l2 = 0.03634895819735218
epoch: 2099, train l2 = 0.034715649093452254, test l2 = 0.0363489550042462
epoch: 2100, train l2 = 0.03471564982282488, test l2 = 0.03634895171437945
epoch: 2101, train 

epoch: 2197, train l2 = 0.034715649203250286, test l2 = 0.03634895655241879
epoch: 2198, train l2 = 0.034715649618914254, test l2 = 0.03634895597185407
epoch: 2199, train l2 = 0.03471564932873374, test l2 = 0.03634896114855617
epoch: 2200, train l2 = 0.03471565017574712, test l2 = 0.0363489609550346
epoch: 2201, train l2 = 0.034715649289520166, test l2 = 0.03634895848763454
epoch: 2202, train l2 = 0.034715649195407566, test l2 = 0.0363489587779169
epoch: 2203, train l2 = 0.03471564965028512, test l2 = 0.03634895814897178
epoch: 2204, train l2 = 0.03471564924246386, test l2 = 0.03634895519776778
epoch: 2205, train l2 = 0.034715648669945566, test l2 = 0.03634895964876398
epoch: 2206, train l2 = 0.034715648936597926, test l2 = 0.03634895781030903
epoch: 2207, train l2 = 0.03471565027770243, test l2 = 0.03634896158397972
epoch: 2208, train l2 = 0.03471565061493924, test l2 = 0.0363489572297443
epoch: 2209, train l2 = 0.034715648850328046, test l2 = 0.03634895539128935
epoch: 2210, train l2

epoch: 2306, train l2 = 0.03471564880327174, test l2 = 0.036348963228913095
epoch: 2307, train l2 = 0.034715649375790046, test l2 = 0.03634895819735218
epoch: 2308, train l2 = 0.034715648748372734, test l2 = 0.036348958294112964
epoch: 2309, train l2 = 0.034715649085609535, test l2 = 0.0363489580038306
epoch: 2310, train l2 = 0.034715648575832966, test l2 = 0.03634895839087375
epoch: 2311, train l2 = 0.034715649085609535, test l2 = 0.03634895572995211
epoch: 2312, train l2 = 0.03471564818369715, test l2 = 0.036348961003414997
epoch: 2313, train l2 = 0.034715648324866046, test l2 = 0.03634895824573257
epoch: 2314, train l2 = 0.034715649446374494, test l2 = 0.036348960761513026
epoch: 2315, train l2 = 0.034715648567990254, test l2 = 0.0363489557783325
epoch: 2316, train l2 = 0.0347156496894987, test l2 = 0.036348957665167846
epoch: 2317, train l2 = 0.0347156501208481, test l2 = 0.036348952778748105
epoch: 2318, train l2 = 0.034715649195407566, test l2 = 0.03634895626213643
epoch: 2319, t

epoch: 2415, train l2 = 0.03471564941500362, test l2 = 0.0363489565040384
epoch: 2416, train l2 = 0.034715648677788286, test l2 = 0.036348957955450206
epoch: 2417, train l2 = 0.034715649258149295, test l2 = 0.03634895926172083
epoch: 2418, train l2 = 0.03471564852877667, test l2 = 0.036348957132983514
epoch: 2419, train l2 = 0.03471564885817076, test l2 = 0.03634895993904634
epoch: 2420, train l2 = 0.03471564846603494, test l2 = 0.036348956842701154
epoch: 2421, train l2 = 0.034715648936597926, test l2 = 0.03634895606861486
epoch: 2422, train l2 = 0.03471564904639596, test l2 = 0.036348961390458144
epoch: 2423, train l2 = 0.034715648481720374, test l2 = 0.03634895495586581
epoch: 2424, train l2 = 0.034715649093452254, test l2 = 0.036348955875093285
epoch: 2425, train l2 = 0.03471564918756485, test l2 = 0.036348954326920695
epoch: 2426, train l2 = 0.034715649281677446, test l2 = 0.03634895824573257
epoch: 2427, train l2 = 0.03471564869347372, test l2 = 0.03634896182588169
epoch: 2428, t

KeyboardInterrupt: 

In [None]:
# PATH = 'NS_models/melt/melt.pt'

# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

# model.eval()

In [None]:
pred = torch.zeros(y_test.shape)
# One sample per batch, in dataset order, so `index` is the sample's position
# in y_test.
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
model.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader):
        # Follow `device` rather than hard-coding .cuda(), so the cell also
        # runs on the CPU fallback.
        x, y = x.to(device), y.to(device)
        # y_test was never encoded, so the model output has to be decoded
        # before it is stored or compared — the same post-processing the
        # training/test loop applies through post_proc.
        out = y_normalizer.decode(model(x))
        pred[index] = out

        test_l2 = myloss(out.reshape(1, -1), y.reshape(1, -1)).item()
        print(index, test_l2)

In [None]:
total_loss = 0.
post_proc=y_normalizer.decode

# Use a dedicated, unshuffled loader so predictions[i] always corresponds to
# y_test[i], regardless of which `test_loader` is currently in scope.
eval_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test, y_test),
    batch_size=batch_size, shuffle=False)

batch_outputs = []
model.eval()
with torch.no_grad():
    for data, target in eval_loader:
        bs = len(data)

        data, target = data.to(device), target.to(device)
        output = model(data)
        output = post_proc(output)

        loss = myloss(output.view(bs, -1), target.view(bs, -1))
        # Flatten to (bs, -1) so every batch has the same rank whatever shape
        # the model returns; the original layout is restored after the loop.
        batch_outputs.append(output.reshape(bs, -1).cpu())
        total_loss += loss.sum().item()

# Concatenating tensors avoids the slow list-of-ndarrays -> torch.Tensor
# conversion; the result is then reshaped to match the targets.
predictions = torch.cat(batch_outputs, dim=0).reshape(y_test.shape)

In [None]:
# Labels now name what is actually printed: the first line shows the values of
# one input field (not a shape), and the second shows the shape of y_test.
print('x_test[29, :, :, 0]', x_test[29, :, :, 0])
print('y_test.shape', y_test.shape)
print('predictions.shape', predictions.shape)
print('total loss = ', total_loss)

In [None]:
def _show_field(fig, ax, field, title):
    """Draw ``field`` on ``ax`` with ``matshow``, a title and its own colorbar."""
    image = ax.matshow(field)
    ax.set_title(title, fontsize=20)
    ax.xaxis.set_tick_params(labelsize=18)
    ax.yaxis.set_tick_params(labelsize=18)
    # Bind the colorbar to this panel explicitly rather than relying on
    # matplotlib to infer the parent axes.
    colorbar = fig.colorbar(image, ax=ax)
    colorbar.ax.tick_params(labelsize=18)


def plot_loss(initial, prediction, test, index,
              train_loss_path='visual/ice_train_loss_xyu.npy',
              test_loss_path='visual/ice_test_loss_xyu.npy'):
    """Plot target, prediction, their difference and the loss history for one sample.

    Args:
        initial: Batch of input fields. Accepted for interface compatibility;
            it is not drawn.
        prediction: Batch of predicted fields, indexed by ``index``.
        test: Batch of ground-truth fields, indexed by ``index``.
        index: Which sample of ``prediction`` / ``test`` to display.
        train_loss_path: ``.npy`` file with the per-epoch training loss.
        test_loss_path: ``.npy`` file with the per-epoch test loss.

    Returns:
        None. The 2x2 figure is shown via ``fig.show()``.

    NOTE(review): the default paths differ from the files written by the
    training cell ('visual/train_loss_temps_much_data.npy' and
    'visual/test_loss_temps_much_data.npy') — pass those explicitly if the
    history of this run is wanted.
    """
    target_field = test[index]
    predicted_field = prediction[index]
    abs_error = torch.abs(target_field - predicted_field)

    fig, axs = plt.subplots(2, 2, figsize=(20, 20))

    _show_field(fig, axs[0, 0], target_field, 'True velocity field')
    _show_field(fig, axs[0, 1], predicted_field, 'Predicted velocity field')
    # NOTE(review): the absolute difference is divided by 100 before plotting,
    # although the title calls it an absolute difference — confirm the scaling.
    _show_field(fig, axs[1, 0], abs_error/100,
                'Absolute difference between model \n prediction and true value')

    # Loss curves are loaded into their own names so the `test` argument is not
    # shadowed. Only entries 100..1099 are drawn, so x = 0 is history index 100.
    train_history = np.load(train_loss_path)
    test_history = np.load(test_loss_path)
    axs[1, 1].plot(train_history[100:1100], label = 'Train loss')
    axs[1, 1].plot(test_history[100:1100], label = 'Test loss')
    axs[1, 1].set_title('Train and test loss history', fontsize=20)
    axs[1, 1].tick_params(labelsize=18)
    axs[1, 1].legend(prop={'size': 18})
    fig.show()

In [None]:
# The predictions and targets come from the test split, so pass the matching
# test inputs as the initial field rather than the training inputs.
plot_loss(x_test, predictions, y_test, 0)

In [None]:
# Show the first encoded training target as an image with a colour bar.
cp1 = plt.matshow(y_train[0])
plt.colorbar(cp1)