In [None]:
''' Imports '''

import warnings
warnings.filterwarnings('ignore')

import os
import json
import time
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu" if torch.cuda.is_available() else "cpu")

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils import data
import itertools
import re
import random
import time
from torch.autograd import Variable
import math
from scipy.ndimage import gaussian_filter
from torch.autograd import Variable
from tqdm import tqdm

In [None]:
''' Train epoch function '''

def train_epoch(train_loader, model, optimizer, loss_function, device=None):
    """Run one training epoch and return the epoch's root-mean-square loss.

    Each batch ``(xx, yy)`` gets a channel axis inserted at dim 1. The model is then
    rolled forward once per entry along ``yy``'s dim 1 (after the ``unsqueeze(1)``
    this is a single step). Each prediction is appended to the input window, and as
    many leading channels as the prediction has are dropped, so the window keeps its
    channel count.

    Args:
        train_loader: iterable of ``(xx, yy)`` tensor pairs.
        model: network mapping the input window to one prediction step.
        optimizer: optimizer stepped once per batch.
        loss_function: criterion called as ``loss_function(pred, target)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), so the function does
            not depend on a notebook-global ``device``.

    Returns:
        ``sqrt(mean(per-batch loss / n_steps))`` rounded to 5 decimals. With
        ``MSELoss`` this is an RMSE, even though callers store it as ``train_mse``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batch_losses = []
    for xx, yy in train_loader:
        xx = xx.to(device).unsqueeze(1)
        yy = yy.to(device).unsqueeze(1)
        n_steps = yy.shape[1]

        loss = 0
        for y in yy.transpose(0, 1):
            y = y.unsqueeze(1)
            im = model(xx)
            # Slide the window: drop as many leading channels as the model just produced.
            xx = torch.cat([xx[:, im.shape[1]:], im], 1)
            loss += loss_function(im, y)

        batch_losses.append(loss.item() / n_steps)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return round(np.sqrt(np.mean(batch_losses)), 5)

In [None]:
''' Eval epoch function '''

def eval_epoch(valid_loader, model, loss_function, device=None):
    """Evaluate ``model`` without gradients and collect its predictions.

    The rollout matches ``train_epoch``. This function does not switch the model's
    train/eval mode; the caller is expected to do that.

    Args:
        valid_loader: iterable of ``(xx, yy)`` tensor pairs.
        model: network mapping the input window to one prediction step.
        loss_function: criterion called as ``loss_function(pred, target)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        ``(score, preds, trues)``:

        - ``score``: ``sqrt(mean(per-batch loss / n_steps))`` rounded to 5 decimals.
          Every batch has equal weight, including a smaller final batch.
        - ``preds``: per-batch ``im.unsqueeze(1)`` arrays, concatenated over steps
          (axis 1) and then over batches (axis 0).
        - ``trues``: ``yy`` after ``unsqueeze(1)``, concatenated over batches.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batch_losses = []
    preds = []
    trues = []
    with torch.no_grad():
        for xx, yy in valid_loader:
            xx = xx.to(device).unsqueeze(1)
            yy = yy.to(device).unsqueeze(1)

            loss = 0
            ims = []
            for y in yy.transpose(0, 1):
                y = y.unsqueeze(1)
                im = model(xx)
                # Slide the window: drop as many leading channels as the model just produced.
                xx = torch.cat([xx[:, im.shape[1]:], im], 1)
                loss += loss_function(im, y)
                ims.append(im.unsqueeze(1).detach().cpu().numpy())

            preds.append(np.concatenate(ims, axis=1))
            trues.append(yy.detach().cpu().numpy())
            batch_losses.append(loss.item() / yy.shape[1])

    if not batch_losses:
        raise ValueError("eval_epoch: the loader yielded no batches")

    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    score = round(np.sqrt(np.mean(batch_losses)), 5)
    return score, preds, trues

In [None]:
''' Test epoch function '''

def test_epoch(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        loss_curve = []
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)

            xx = xx.unsqueeze(1)
            yy = yy.unsqueeze(1)

            loss = 0
            ims = []
            
            for y in yy.transpose(0, 1):
                y = y.unsqueeze(1)
                im = model(xx)
                xx = torch.cat([xx[:, 2:], im], 1)
                mse = loss_function(im, y)
                loss += mse
                loss_curve.append(mse.item())
                ims.append(im.unsqueeze(1).cpu().data.numpy())
           
            ims = np.concatenate(ims, axis = 1)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            valid_mse.append(loss.item()/yy.shape[1])
            
        loss_curve = np.array(loss_curve).reshape(-1,yy.shape[1])
        preds = np.concatenate(preds, axis = 0)  
        trues = np.concatenate(trues, axis = 0)  
        valid_mse = np.mean(valid_mse)
        loss_curve = np.sqrt(np.mean(loss_curve, axis = 0))
    return valid_mse, preds, trues, loss_curve

In [None]:
''' Data Loader '''

class Dataset(data.Dataset):
    """Map-style dataset of paired ``(h, T)`` tensors stored as individual ``.pt`` files.

    Sample ``ID`` is read from ``<direc>/h_<ID>.pt`` (returned first) and
    ``<direc>/T_<ID>.pt`` (returned second).

    Args:
        indices: sequence of sample IDs; dataset index ``i`` maps to ``indices[i]``.
        direc: directory holding the files. Paths are built with ``os.path.join``,
            so it may be given with or without a trailing separator.
    """

    def __init__(self, indices, direc):
        self.list_IDs = indices
        self.direc = direc

    def __len__(self):
        return len(self.list_IDs)

    def _path(self, prefix, ID):
        # File path for one sample; prefix is 'h_' or 'T_'.
        return os.path.join(self.direc, prefix + str(ID) + ".pt")

    def __getitem__(self, index):
        ID = self.list_IDs[index]

        # NOTE(review): torch.load deserializes arbitrary pickles in older PyTorch; only use trusted files.
        x = torch.load(self._path('h_', ID))
        y = torch.load(self._path('T_', ID))

        # Both tensors are returned as float32 via .float().
        return x.float(), y.float()

In [None]:
''' Dataset Parameters '''

batch_size = 8
num_workers = 8
train_direc = "../simulated_data_reg/"
valid_direc = "../simulated_data_reg/"
test_direc = "../simulated_data_reg/"
# Disjoint, contiguous ID ranges: 300 train / 70 valid / 95 test samples.
train_indices = list(range(0, 300))
valid_indices = list(range(300, 370))
test_indices = list(range(370, 465))

''' Load Data '''

train_set = Dataset(train_indices, train_direc)
valid_set = Dataset(valid_indices, valid_direc)
test_set = Dataset(test_indices, test_direc)

# Only the training loader shuffles. The epoch scores average per-batch losses with
# equal weight, and 70 validation samples do not divide into batches of 8. Shuffling
# the validation set would therefore change which samples land in the smaller final
# batch, making the validation score (used for best-model selection) vary between
# epochs even for identical weights.
train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = num_workers)
valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False, num_workers = num_workers)
test_loader = data.DataLoader(test_set, batch_size = batch_size, shuffle = False, num_workers = num_workers)

In [None]:
''' Model Hyperparameters '''

n_epochs = 60
learning_rate = 0.001
lr_decay = 0.9  # gamma for the StepLR scheduler (multiplies the lr at each scheduler step)

# Best validation score seen so far. Starting at +inf guarantees that the first epoch
# sets `best_model`. A finite sentinel leaves `best_model` undefined whenever every
# validation score stays above it.
min_mse = float('inf')
train_mse = []
valid_mse = []
test_mse = []
times = []

In [None]:
''' ResNet '''

class Resblock(nn.Module):
    """Residual block: two conv -> BatchNorm -> LeakyReLU(0.5) stages plus a skip path.

    Convolutions use padding ``(kernel_size - 1) // 2``. When ``input_channels`` and
    ``hidden_dim`` differ, the skip path goes through ``upscale`` (conv + LeakyReLU)
    so that the channel counts match. Otherwise the input is added unchanged.
    """

    def __init__(self, input_channels, hidden_dim, kernel_size):
        super().__init__()
        pad = (kernel_size - 1) // 2

        # Attribute names and Sequential layouts determine the state_dict keys; keep them stable.
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, hidden_dim, kernel_size=kernel_size, padding=pad),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.5),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=pad),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.5),
        )
        self.input_channels = input_channels
        self.hidden_dim = hidden_dim

        # Projection for the skip path, only needed when the channel count changes.
        if self.input_channels != self.hidden_dim:
            self.upscale = nn.Sequential(
                nn.Conv2d(input_channels, hidden_dim, kernel_size=kernel_size, padding=pad),
                nn.LeakyReLU(0.5),
            )

    def forward(self, xx):
        needs_projection = self.input_channels != self.hidden_dim
        shortcut = self.upscale(xx) if needs_projection else xx
        return self.layer2(self.layer1(xx)) + shortcut
    

class ResNet(nn.Module):
    """Three stages of two ``Resblock``s (widths 64, 128, 300) followed by a final conv
    that maps to ``output_channels``.

    The final conv uses padding ``(kernel_size - 1) // 2``, like the blocks.
    """

    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        stages = []
        in_ch = input_channels
        for width in (64, 128, 300):
            # First block of a stage changes the width; the second keeps it.
            stages.append(Resblock(in_ch, width, kernel_size))
            stages.append(Resblock(width, width, kernel_size))
            in_ch = width
        stages.append(nn.Conv2d(in_ch, output_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2))
        # Registered as `model` in this exact order so state_dict keys (model.0 ... model.6) are unchanged.
        self.model = nn.Sequential(*stages)

    def forward(self, xx):
        return self.model(xx)

In [None]:
''' Model '''

model = nn.DataParallel(ResNet(input_channels = 1, output_channels = 1, kernel_size = 3).to(device))

# Adam's betas are (beta1, beta2) decay rates for the gradient moment estimates.
# beta1 = 0.9 is PyTorch's default. A beta1 near zero (e.g. learning_rate = 0.001)
# would all but disable the first-moment (momentum) estimate.
optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas=(0.9, 0.999), weight_decay=4e-4)
# Multiply the learning rate by lr_decay on every scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay)
loss_fun = torch.nn.MSELoss()

In [None]:
# Total number of parameter elements, including BatchNorm affine weights.
print('Parameters: ', sum(map(torch.numel, model.parameters())))

In [None]:
import copy  # used to snapshot the best model's weights

for i in tqdm(range(n_epochs)):

    print('EPOCH: ', i+1)

    start = time.time()

    model.train()
    train_mse.append(train_epoch(train_loader, model, optimizer, loss_fun))
    print('Model trained')
    # The LR scheduler must step after optimizer.step(). Stepping it at the start of
    # the epoch skips the initial learning rate, so epoch 1 would already run at
    # learning_rate * lr_decay. PyTorch's warning about this is hidden by
    # warnings.filterwarnings('ignore') in the imports cell.
    scheduler.step()

    model.eval()
    mse, _, _ = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        # Take a real copy: `best_model = model` would only alias the live model,
        # and later epochs would keep changing the "best" weights.
        best_model = copy.deepcopy(model)

    end = time.time()
    times.append(end - start)

    print('TRAIN MSE: ', train_mse[-1])
    print('VALID MSE: ', valid_mse[-1])
    print('TIME: ', end - start)
    print('----------------------------------')

test_mse, preds, trues, loss_curve = test_epoch(test_loader, best_model, loss_fun)

In [None]:
''' Plot Loss Curves '''

# train_epoch / eval_epoch report sqrt(mean per-batch MSE), i.e. an RMSE, so label it as such.
fig, ax = plt.subplots()
ax.plot(train_mse, label='Train')
ax.plot(valid_mse, label='Valid')
ax.set_xlabel('Epoch #')
ax.set_ylabel('RMSE')
ax.set_title('RMSE')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Load sample 0 (h_0 then T_0) for the inspection cells that follow.
x, y = (torch.load(path) for path in ('../simulated_data_reg/h_0.pt', '../simulated_data_reg/T_0.pt'))

In [None]:
''' EDA '''

row = 100

# Inference only: eval mode so BatchNorm uses its running statistics, no autograd graph,
# and the input is cast to float32 the same way Dataset.__getitem__ casts training inputs.
model.eval()
with torch.no_grad():
    output = model(x.float().unsqueeze(0).unsqueeze(0).to(device))
output = output.squeeze(0).squeeze(0)

# Overlaid histograms of one row; alpha keeps later histograms from hiding earlier ones.
plt.hist(x[row].cpu().detach().numpy(), color = "blue", alpha = 0.5, label="height")
plt.hist(y[row].cpu().detach().numpy(), color = "red", alpha = 0.5, label="original temperature")
plt.hist(output[row].cpu().detach().numpy(), color = "green", alpha = 0.5, label="output temperature")
plt.title('Row ' + str(row))
plt.legend()
plt.show()

In [None]:
''' Visualize model output '''

def show_heatmap(image, title):
    """Display a 2-D array as a 'hot' heatmap with a colorbar."""
    plt.imshow(image, cmap='hot', interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    plt.show()

height = x.numpy()
show_heatmap(height, 'Height profile')

original_temperature = y.numpy()
show_heatmap(original_temperature, 'Original temperature profile')

# Inference only: eval mode, no autograd graph, and a float32 input as produced by Dataset.
model.eval()
with torch.no_grad():
    output = model(x.float().unsqueeze(0).unsqueeze(0).to(device))
output = output.squeeze(0).squeeze(0).cpu().numpy()
show_heatmap(output, 'Predicted temperature profile')

In [None]:
# Gather the per-epoch curves, the test score and the run settings in one dict.
model_and_metrics_dict = {
    'train_mse': train_mse,
    'valid_mse': valid_mse,
    'test_mse': test_mse,
    'epochs': n_epochs,
    'learning_rate': learning_rate,
    'model': 'cnn',
    'time': times,
}

print(model_and_metrics_dict)

In [None]:
# Save the weights the test metrics were computed on (best validation score, see test_epoch call),
# creating the output directory if it does not exist yet.
os.makedirs("../trained_models", exist_ok=True)
torch.save(best_model.state_dict(), "../trained_models/resnet_backward.pt")

In [None]:
# Create the metrics directory if needed so the dump does not fail on a fresh checkout.
os.makedirs('../metrics', exist_ok=True)
with open('../metrics/resnet_backward.json', 'w') as f:
    json.dump(model_and_metrics_dict, f)