<h2>IMPORTS</h2>

In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import shutil
import time
from argparse import ArgumentParser
import numpy as np
import sys
import json

import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from PIL import Image

import models
from utils.train_utils import train_xent, test_acc
from utils import loaders
from utils.model_utils import get_num_parameters
from utils.misc import dump_list_element_1line
from torchvision import datasets, transforms

from torch.utils import data
from torchsummary import summary

from tqdm import tqdm

<h2>ARGUMENTS</h2>

In [None]:
# ---- Data / batching -------------------------------------------------------
batch_size = 8
data_dir = 'simulated_data'
# data_dir = 'datasets/MNIST_scale/seed_0/scale_0.3_1.0/'

# ---- Training length -------------------------------------------------------
# NOTE: the training loop iterates over `n_epochs` (set in a later cell), not
# over this value.
epochs = 1

# ---- Optimizer -------------------------------------------------------------
# `optimizer` holds the optimizer *name*; the optimizer cell later rebinds this
# variable to the optimizer object itself.
optimizer = 'adam'
lr = 0.01
decay = 0.0001      # weight decay setting
momentum = 0.9      # not used when optimizer == 'adam'
nesterov = False    # not used when optimizer == 'adam'

# ---- Learning-rate schedule (fed to MultiStepLR as milestones / gamma) ------
lr_steps = [20, 40]
lr_gamma = 0.1

# ---- Model / bookkeeping (these appear unused elsewhere in this notebook) ---
# model = model_names
extra_scaling = 1.0
save_model_path = 'saved_models/trial_1.pt'
tag = ''

# ---- Hardware --------------------------------------------------------------
use_cuda = True

In [None]:
# Downgrade the CUDA request to the CPU when no GPU is visible, then pick the
# matching torch device.
use_cuda = torch.cuda.is_available() if use_cuda else use_cuda
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Device: {}'.format(device))

<h2>DATA LOADER</h2>

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset over paired ``h_<ID>.pt`` / ``T_<ID>.pt`` tensor files.

    Item ``index`` resolves to file ID ``indices[index]``. The input is loaded
    from ``h_<ID>.pt`` and the target from ``T_<ID>.pt`` inside ``direc``.
    Presumably these are height and temperature fields, judging by the plot
    titles used later; confirm with the data generator. Both tensors are
    returned as float32.

    Args:
        indices: Sequence of file IDs, one per sample.
        direc: Directory holding the ``.pt`` files. A trailing path separator
            is optional.
    """

    def __init__(self, indices, direc):
        self.list_IDs = indices
        self.direc = direc

    def __len__(self):
        return len(self.list_IDs)

    def _path(self, prefix, ID):
        # os.path.join tolerates `direc` with or without a trailing separator.
        return os.path.join(self.direc, prefix + str(ID) + '.pt')

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        # Load onto the CPU so files saved from a GPU process still load on
        # CPU-only machines / in DataLoader workers. The training code moves
        # each batch to `device` itself.
        x = torch.load(self._path('h_', ID), map_location='cpu')
        y = torch.load(self._path('T_', ID), map_location='cpu')

        return x.float(), y.float()

In [None]:
# Batch size comes from the ARGUMENTS cell. Re-assigning it here would silently
# override that setting.
print(batch_size)

# All splits read from the configured `data_dir`.
# NOTE(review): train, valid and test all use the same single sample (ID 0),
# so the reported validation/test errors are not held-out metrics. This is
# fine for a smoke test only.
train_direc = valid_direc = test_direc = data_dir + '/'

train_indices = list(range(0, 1))
valid_indices = list(range(0, 1))
test_indices = list(range(0, 1))

# Load data
train_set = Dataset(train_indices, train_direc)
valid_set = Dataset(valid_indices, valid_direc)
test_set = Dataset(test_indices, test_direc)

# Only training benefits from shuffling. Evaluation loaders keep a fixed
# order so validation/test predictions come back reproducibly.
train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

<h2>LOAD MODEL</h2>

In [None]:
# Look the constructor up by name in the project's `models` package and
# instantiate it, then echo the architecture.
model = models.__dict__['mnist_ses_scalar_200']()
print('\nModel:', model, sep='\n')
print()

In [None]:
if use_cuda:
    # Turn cuDNN on and let it benchmark/auto-select convolution algorithms.
    cudnn.enabled, cudnn.benchmark = True, True
    print('CUDNN is enabled. CUDNN benchmark is enabled')
    # nn.Module.cuda() moves parameters in place and returns the same module.
    model = model.cuda()

In [None]:
# Report the parameter count followed by a blank line, flushing so it shows up
# before the (long) summary output that follows.
print('num_params:', get_num_parameters(model), end='\n\n', flush=True)

In [None]:
# Layer-by-layer summary for a single-channel 200x200 input.
# torchsummary's `summary` builds its dummy input according to a `device`
# string that defaults to 'cuda'. Pass the device the model actually lives on
# so the cell also works with use_cuda=False on a GPU machine.
summary(model, (1, 200, 200), device='cuda' if use_cuda else 'cpu')

In [None]:
# Derive the data device from the same flag that decided whether the model was
# moved to CUDA. Checking GPU availability alone would send batches to CUDA
# even when use_cuda=False left the model on the CPU.
device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
print('Device: {}'.format(device))

In [None]:
# Trainable parameters only. Materialised as a list because `filter` returns a
# one-shot iterator: once consumed (e.g. by a first optimizer construction),
# any further use would see an empty parameter list.
parameters = [p for p in model.parameters() if p.requires_grad]
len(parameters)

In [None]:
# Build the optimizer named in the ARGUMENTS cell. The first run rebinds
# `optimizer` from the name string to the optimizer object, so the isinstance
# guard makes re-running this cell a no-op instead of a silent mismatch.
if isinstance(optimizer, str):
    optimizer_name = optimizer.lower()
    if optimizer_name == 'adam':
        # Apply the configured `decay` as weight decay.
        optimizer = optim.Adam(parameters, lr=lr, weight_decay=decay)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(parameters, lr=lr, momentum=momentum,
                              weight_decay=decay, nesterov=nesterov)
    else:
        # Fail loudly here rather than leaving a str for MultiStepLR to choke on.
        raise ValueError('Unknown optimizer {!r}; expected "adam" or "sgd"'.format(optimizer))
optimizer

In [None]:
# Multiply the learning rate by `lr_gamma` at each milestone epoch listed in
# `lr_steps`. This only happens if lr_scheduler.step() is called once per epoch.
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=lr_steps,
    gamma=lr_gamma,
)
lr_scheduler

In [None]:
# Element-wise squared error averaged over all elements (PyTorch's default
# reduction, spelled out).
loss_fun = torch.nn.MSELoss(reduction='mean')

In [None]:
# Banner separating setup output from the training log.
print('\nTraining\n', '-' * 30, sep='')

In [None]:
# Number of passes over the training set, and an empty per-epoch history.
n_epochs, train_mse = 60, []

In [None]:
# Train epoch function

def train_epoch(train_loader, model, optimizer, loss_function):
    """Run one optimisation pass over ``train_loader``.

    Each batch ``(xx, yy)`` is moved to the module-level ``device`` and given a
    channel dim at axis 1. It is then rolled out step by step over dim 1 of
    ``yy``; after the unsqueeze that dim has size 1, so this is a single step.
    The per-step losses are summed, and one optimizer step is taken per batch.

    Assumes loader batches are (B, H, W) for both input and target. TODO
    confirm against the data files.

    Args:
        train_loader: Iterable of ``(input, target)`` tensor pairs.
        model: Network mapping a (B, C, H, W) input to a prediction that
            squeezes/unsqueezes to (B, 1, H, W).
        optimizer: Optimizer over ``model``'s parameters.
        loss_function: Callable ``(prediction, target) -> scalar tensor``.

    Returns:
        Square root of the mean per-batch loss (an RMSE when the loss is MSE),
        rounded to 5 decimals. An empty loader yields NaN.
    """
    batch_losses = []
    for xx, yy in train_loader:
        xx = xx.to(device).unsqueeze(1)
        yy = yy.to(device).unsqueeze(1)

        loss = 0
        for y in yy.transpose(0, 1):
            # Normalise the prediction to (B, 1, H, W).
            im = model(xx).squeeze(1).unsqueeze(1)
            # Autoregressive feedback: drop the two oldest channels, append
            # the prediction.
            xx = torch.cat([xx[:, 2:], im], 1)
            # `y` is (B, H, W). Reshape it to the prediction's shape so the
            # loss compares sample i with target i. Letting (B,1,H,W) vs
            # (B,H,W) broadcast would compare every prediction with every
            # target whenever B > 1.
            loss += loss_function(im, y.reshape_as(im))
        batch_losses.append(loss.item() / yy.shape[1])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return round(np.sqrt(np.mean(batch_losses)), 5)

In [None]:
# Eval epoch function

def eval_epoch(valid_loader, model, loss_function):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Uses the same rollout as ``train_epoch`` (one step per entry of dim 1 of
    the unsqueezed target) but never updates weights.

    Assumes loader batches are (B, H, W) for both input and target. TODO
    confirm.

    Args:
        valid_loader: Iterable of ``(input, target)`` tensor pairs.
        model: Network to evaluate. The caller is responsible for ``eval()``.
        loss_function: Callable ``(prediction, target) -> scalar tensor``.

    Returns:
        ``(rmse, preds, trues)``:
        - ``rmse``: square root of the mean per-batch loss, rounded to 5
          decimals.
        - ``preds``: NumPy array of predictions shaped (N, steps, 1, H, W).
        - ``trues``: NumPy array of targets shaped (N, 1, H, W).
    """
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        for xx, yy in valid_loader:
            xx = xx.to(device).unsqueeze(1)
            yy = yy.to(device).unsqueeze(1)

            loss = 0
            ims = []
            for y in yy.transpose(0, 1):
                im = model(xx).squeeze(1).unsqueeze(1)
                xx = torch.cat([xx[:, 2:], im], 1)
                # Match the target to the prediction's (B, 1, H, W) shape.
                # Broadcasting (B,1,H,W) against (B,H,W) would mix samples.
                loss += loss_function(im, y.reshape_as(im))
                ims.append(im.unsqueeze(1).cpu().numpy())

            preds.append(np.concatenate(ims, axis=1))
            trues.append(yy.cpu().numpy())
            valid_mse.append(loss.item() / yy.shape[1])
        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
    return valid_mse, preds, trues

In [None]:
''' Test epoch function '''

def test_epoch(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        loss_curve = []
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)

            xx = xx.unsqueeze(1)
            yy = yy.unsqueeze(1)

            loss = 0
            ims = []
            
            for y in yy.transpose(0, 1):
                # y = y.unsqueeze(1)
                im = model(xx)
                im = im.squeeze(1)
                im = im.unsqueeze(1)
                xx = torch.cat([xx[:, 2:], im], 1)
                mse = loss_function(im, y)
                loss += mse
                loss_curve.append(mse.item())
                ims.append(im.unsqueeze(1).cpu().data.numpy())
           
            ims = np.concatenate(ims, axis = 1)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            valid_mse.append(loss.item()/yy.shape[1])
            
        loss_curve = np.array(loss_curve).reshape(-1,yy.shape[1])
        preds = np.concatenate(preds, axis = 0)  
        trues = np.concatenate(trues, axis = 0)  
        valid_mse = np.mean(valid_mse)
        loss_curve = np.sqrt(np.mean(loss_curve, axis = 0))
    return valid_mse, preds, trues, loss_curve

In [None]:
# Per-epoch histories filled in by the training loop.
train_mse = []
valid_mse = []
test_mse = []
times = []

# Best validation score so far. Start at +inf so the first finite score always
# registers a best model. A finite sentinel such as 100 leaves `best_model`
# undefined if the validation error never drops below it.
min_mse = float('inf')

n_epochs = 60

In [None]:
import copy

for i in tqdm(range(n_epochs)):

    print('EPOCH: ', i+1)

    start = time.time()

    # Weight updates happen only inside train_epoch. A bare optimizer.step()
    # here would re-apply the previous epoch's last-batch gradients.
    model.train()
    train_mse.append(train_epoch(train_loader, model, optimizer, loss_fun))
    print('Model trained')

    model.eval()
    mse, _, _ = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    # Advance the MultiStepLR schedule once per epoch (after the optimizer
    # steps) so the lr_steps milestones actually take effect.
    lr_scheduler.step()

    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        # Snapshot the weights. Plain `best_model = model` would only alias
        # the live model, so "best" would silently become "last".
        best_model = copy.deepcopy(model)

    end = time.time()

    times.append(end-start)

    # Early Stopping but train at least for 50 epochs
    # if (len(train_mse) > 50 and np.mean(valid_mse[-5:]) >= np.mean(valid_mse[-10:-5])):
    #         break

    print('TRAIN MSE: ', train_mse[-1])
    print('VALID MSE: ', valid_mse[-1])
    print('TIME: ', end - start)
    print('----------------------------------')

test_mse, preds, trues, loss_curve = test_epoch(test_loader, best_model, loss_fun)

In [None]:
import matplotlib.pyplot as plt

# Plot Loss Curves.
# train_epoch / eval_epoch return the square root of the mean per-batch MSE,
# i.e. an RMSE, so the axis is labelled accordingly.
plt.plot(train_mse, label='Train')
plt.plot(valid_mse, label='Valid')
plt.xlabel('Epoch #')
plt.ylabel('RMSE')
plt.title('RMSE')
plt.legend()
plt.grid()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check on one training batch: the input ("Height"), the target
# ("Orig Temperature") and the model's prediction for the first sample.
xx, yy = next(iter(train_loader))

# Send the input to wherever the model's weights live instead of hard-coding
# "cuda", so this cell also runs when training happened on the CPU.
model_device = next(model.parameters()).device
xx = xx.unsqueeze(1).to(model_device)

model.eval()
with torch.no_grad():  # inference only: no autograd graph, no .detach() needed
    pred = model(xx)

plt.imshow(xx[0][0].cpu().numpy(), cmap = 'hot')
plt.title('Height')
plt.colorbar()
plt.show()

plt.imshow(yy[0].cpu().numpy(), cmap = 'hot')
plt.title('Orig Temperature')
plt.colorbar()
plt.show()

plt.imshow(pred[0][0].cpu().numpy(), cmap = 'hot')
plt.title('Pred Temperature')
plt.colorbar()
plt.show()

In [None]:
# Display the test-set score that test_epoch computed for `best_model`: the
# mean of the per-batch MSEs. Unlike the train/valid history values, it is not
# square-rooted.
test_mse