# import packages

In [None]:
from torch import nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

import xarray as xr
import time

# define UNet

In [8]:
class UNet(nn.Module):
    """Three-level U-Net for dense image-to-image mapping.

    Architecture (channel counts in brackets):
      encoder    : contracting blocks [32] -> [64] -> [128], each followed by
                   2x2 max-pooling;
      bottleneck : two [128] convs, 2x nearest-neighbour upsampling, one [128] conv;
      decoder    : expansive blocks [256 -> 64] and [128 -> 32] (each ending in
                   2x upsampling), then a final block [64 -> out_channel].
    Each decoder stage consumes its predecessor's output concatenated (on the
    channel axis) with the matching encoder feature map (skip connection).

    Every convolution is followed by SELU then BatchNorm2d -- including the
    last one, so the network output itself passes through BatchNorm2d.
    NOTE(review): for a regression target a plain final conv is the more usual
    choice; the layer layout is kept as-is so existing state_dicts still load.

    Spatial size: for H, W >= 8 an (H, W) input yields an
    (8 * (H // 8), 8 * (W // 8)) output, i.e. the same size as the input when
    H and W are multiples of 8.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Encoder stage: two (Conv2d -> SELU -> BatchNorm2d) stages.

        ``kernel_size`` should be odd; padding is ``kernel_size // 2`` so height
        and width are preserved.
        """
        padding = kernel_size // 2
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(out_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(out_channels),
                )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Decoder stage: two conv stages, 2x nearest upsampling, one more conv stage.

        The output has ``out_channels`` channels and twice the input height and
        width. ``kernel_size`` (odd) is used for all three convolutions.
        """
        padding = kernel_size // 2
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(out_channels),
                    torch.nn.Upsample(scale_factor=2, mode='nearest'),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(out_channels),
                    )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Output stage: three (Conv2d -> SELU -> BatchNorm2d) stages, size preserved."""
        padding = kernel_size // 2
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=padding),
                    torch.nn.SELU(),
                    torch.nn.BatchNorm2d(out_channels),
                    )
        return block

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=32)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(32, 64)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(64, 128)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck (ends with 2x upsampling back to the encode3 resolution)
        self.bottleneck = torch.nn.Sequential(
                            torch.nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
                            torch.nn.SELU(),
                            torch.nn.BatchNorm2d(128),
                            torch.nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
                            torch.nn.SELU(),
                            torch.nn.BatchNorm2d(128),
                            torch.nn.Upsample(scale_factor=2, mode='nearest'),
                            torch.nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
                            torch.nn.SELU(),
                            torch.nn.BatchNorm2d(128),
                            )
        # Decode: input channels = upsampled channels + skip-connection channels
        self.conv_decode3 = self.expansive_block(256, 128, 64)   # 128 + 128
        self.conv_decode2 = self.expansive_block(128, 64, 32)    # 64 + 64
        self.final_layer = self.final_block(64, 32, out_channel)  # 32 + 32

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate ``upsampled`` and ``bypass`` along the channel axis.

        Args:
            upsampled: decoder feature map, shape (N, C1, H, W).
            bypass: encoder feature map, shape (N, C2, H', W').
            crop: if True, centre-crop ``bypass`` to (H, W) first. Height and
                width are handled independently; when a size difference is odd
                the extra row/column is removed from the bottom/right. (A
                negative difference zero-pads instead, as ``F.pad`` does.)

        Returns:
            Tensor of shape (N, C1 + C2, H, W), ``upsampled`` channels first.
        """
        if crop:
            dh = bypass.size(2) - upsampled.size(2)
            dw = bypass.size(3) - upsampled.size(3)
            top, left = dh // 2, dw // 2
            bottom, right = dh - top, dw - left
            # Negative padding in F.pad crops; order is (left, right, top, bottom).
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode. Cropping only has an effect when H or W is not a multiple of 8:
        # then max-pooling floors odd sizes and the upsampled maps come out
        # smaller than the encoder maps they are concatenated with.
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer

# initialize model and optimizer

In [9]:
# One input channel and one output channel per sample.
model = UNet(in_channel=1, out_channel=1)

# Adam over all network parameters with a fixed learning rate.
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# load training and testing data

In [None]:
# Vort_0_* arrays are used as targets and Eta_* arrays as inputs further down;
# each comes pre-split into train and test sets on disk.
Vort_0_train, Vort_0_test = (np.load(fname) for fname in ('vort_train.npy', 'vort_test.npy'))
Eta_train, Eta_test = (np.load(fname) for fname in ('Eta_train.npy', 'Eta_test.npy'))


In [None]:
## PyTorch Conv2d expects input laid out as (batch, channel, height, width).
## Each sample has a single channel, but that channel axis must still exist,
## so a size-1 axis is inserted at position 1
## (assumes the saved arrays are (N, H, W) -- TODO confirm).

xTrain, xTest = (np.expand_dims(arr, axis=1) for arr in (Eta_train, Eta_test))

yTrain, yTest = (np.expand_dims(arr, axis=1) for arr in (Vort_0_train, Vort_0_test))


# wrap the data in PyTorch Dataset and DataLoader objects

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over paired input/target NumPy arrays.

    Despite their names, ``X_path`` and ``y_path`` are arrays, not file paths.
    Both are converted to float32 tensors once, at construction; indexing then
    returns the ``(input, target)`` pair for one sample.
    """

    def __init__(self, X_path, y_path):
        self.X, self.y = (torch.from_numpy(arr).to(torch.float32) for arr in (X_path, y_path))

    def __len__(self):
        # Number of samples = size of the leading axis of the inputs.
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label

In [None]:
# Training split and held-out split, each as a dataset of float32 tensors.
train_dataset, val_dataset = (
    MyDataset(X_path=xTrain, y_path=yTrain),
    MyDataset(X_path=xTest, y_path=yTest),
)

In [None]:
batch_size = 32
log_interval = 500  # train() prints/records the loss every this many batches

# Only the training data is shuffled. A fixed validation order keeps the
# validation metric reproducible from epoch to epoch (with per-batch averaging,
# a shuffled loader changes which samples land in the smaller last batch).
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# training and validation functions

In [None]:
train_loss = []  # loss of each logged batch (one entry every `log_interval` batches)
def train(epoch):
    """Run one training epoch over ``train_loader`` with an MSE objective.

    Relies on the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``log_interval``. For every batch index divisible by ``log_interval``
    (including batch 0) the current batch loss is printed and appended to
    ``train_loss``.

    Args:
        epoch: epoch number; only used in the progress message.
    """
    model.train()
    for step, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = F.mse_loss(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, batch_loss.item()))
        train_loss.append(batch_loss.item())

In [None]:
validation_accuracy = []  # per-call mean validation MSE (a loss, despite the name)
def validation():
    """Evaluate ``model`` on ``val_loader`` and record the mean squared error.

    Relies on the module-level ``model`` and ``val_loader``. The squared error
    is summed over every element of the validation set and divided by the total
    element count. A smaller final batch therefore does not bias the result, as
    it would with an average of per-batch means. The value is printed, appended
    to ``validation_accuracy`` and returned.

    Returns:
        The mean squared error over the whole validation set (float).

    Raises:
        ValueError: if ``val_loader`` yields no data.
    """
    model.eval()
    total_sq_error = 0.0
    n_elements = 0
    # Evaluation needs no autograd graph; no_grad avoids building one per batch.
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            total_sq_error += F.mse_loss(output, target, reduction='sum').item()
            n_elements += target.numel()

    if n_elements == 0:
        raise ValueError('val_loader produced no data; cannot compute validation loss')
    validation_loss = total_sq_error / n_elements
    print('\nValidation set: Average loss: {:.4f}\n'.format(
        validation_loss))
    validation_accuracy.append(validation_loss)
    return validation_loss


# train the model

In [None]:
n_threads = 16
torch.set_num_threads(n_threads)  # threads PyTorch uses for intra-op CPU work
epochs = 60

# Epochs are numbered from 1; each one is a full training pass followed by a
# validation pass.
for epoch in range(1, epochs + 1):
    train(epoch)
    validation()


# save and load model

In [None]:
## save
import os

# torch.save does not create parent directories, so make sure ./models exists
# before writing the checkpoint; otherwise this cell fails on a fresh checkout.
os.makedirs('./models', exist_ok=True)
torch.save(model.state_dict(), './models/vort_cs')

In [None]:
## load
# Rebuild the architecture, then restore the saved weights into it.
# map_location='cpu' lets a checkpoint written from a GPU session load on a
# CPU-only machine; this notebook never moves the model to a GPU itself.
# NOTE: this rebinds `model`, so an existing `optimizer` still points at the
# previous model's parameters -- recreate it before training any further.
model = UNet(in_channel = 1, out_channel = 1)
model.load_state_dict(torch.load('./models/vort_cs', map_location='cpu'))