In [None]:
import os
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchsummary import summary
from tqdm import tqdm

import e2cnn
from e2cnn.nn.modules.r2_conv.r2convolution import compute_basis_params
from e2cnn.nn.modules.r2_conv.basisexpansion_singleblock import block_basisexpansion

# NOTE(review): this silences *all* warnings, including shape-mismatch warnings
# that loss functions emit; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix RNG seeds so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Scale and Rotation Equivariant Layer
class ScaleRotEquivLayer(torch.nn.Module):
    """Convolution layer equivariant to C_N rotations, with a multi-scale basis.

    The filters are trainable linear combinations of a fixed kernel basis built
    with e2cnn for the layer's input/output field types. Every basis kernel is
    also resampled at each of ``scale_factors``, and the rescaled copies are
    concatenated along the basis dimension.

    Args:
        in_frames: number of input fields.
        out_frames: number of output fields.
        kernel_size: spatial size of the square kernels.
        N: order of the rotation group.
        scale_factors: scale factors applied to the equivariant basis.
        first_layer: if True the input fields use the trivial representation,
            otherwise the regular representation.
        last_layer: if True the output fields use the trivial representation and
            forward() returns the raw convolution. Otherwise the output uses the
            regular representation and forward() applies an e2cnn ReLU.
    """

    def __init__(self,
                 in_frames,  # number of input frames
                 out_frames,  # number of output frames
                 kernel_size,
                 N,  # Order of rotation group
                 scale_factors,  # scaling factors applied to equivariant basis
                 first_layer=False,  # whether it is the first layer
                 last_layer=False  # whether it is the last layer
                 ):
        super(ScaleRotEquivLayer, self).__init__()

        r2_act = e2cnn.gspaces.Rot2dOnR2(N=N)
        self.last_layer = last_layer
        self.first_layer = first_layer
        self.kernel_size = kernel_size

        # Every field of a FieldType uses one and the same representation here.
        # Keep it in hand for building the kernel basis, rather than reading it
        # back out of (and mutating) the FieldType's internal state.
        in_repr = r2_act.trivial_repr if first_layer else r2_act.regular_repr
        out_repr = r2_act.trivial_repr if last_layer else r2_act.regular_repr
        self.feat_type_in = e2cnn.nn.FieldType(r2_act, in_frames * [in_repr])
        self.feat_type_hid = e2cnn.nn.FieldType(r2_act, out_frames * [out_repr])

        if not last_layer:
            # NOTE(review): self.norm is built but never applied in forward().
            # It is kept so the module's parameters/state_dict keys stay the same.
            self.norm = e2cnn.nn.InnerBatchNorm(self.feat_type_hid)
            self.relu = e2cnn.nn.ReLU(self.feat_type_hid)

        # Obtain the equivariant basis kernels. The maximum_frequency returned by
        # compute_basis_params is not used; 5 is passed explicitly instead.
        grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(kernel_size=kernel_size)
        basis = r2_act.build_kernel_basis(in_repr, out_repr, sigma, rings, maximum_frequency=5)
        block_expansion = block_basisexpansion(basis, grid, basis_filter, recompute=False)
        basis_kernels = block_expansion.sampled_basis.to(device)

        # Move the basis axis to position 2 and unflatten the sampling grid.
        # basis_kernels size: out_chs x inp_chs x number of basis x kz x kz
        basis_kernels = basis_kernels.transpose(0, 1).transpose(1, 2)
        basis_kernels = basis_kernels.reshape(basis_kernels.shape[0],
                                              basis_kernels.shape[1],
                                              basis_kernels.shape[2],
                                              kernel_size, kernel_size)

        # Apply the scaling transformations. The rescaled copies are stacked along
        # the basis dimension (dim 2).
        self.multiscale_basis_kernels = torch.cat(
            [self.resize_conv_kernel(basis_kernels, f) for f in scale_factors], dim=2)
        num_basis = self.multiscale_basis_kernels.shape[2]

        # Initialize weights: one coefficient per (output field, input field,
        # basis kernel).
        stdv = np.sqrt(1 / (in_frames * kernel_size * kernel_size))
        self.weights = nn.Parameter(torch.empty(out_frames, in_frames, num_basis).to(device))
        self.weights.data.uniform_(-stdv, stdv)

        # One bias per output *field*; forward() expands it to that field's channels.
        # A bias that differs between the channels of a regular-representation
        # field is not invariant under the channel permutation a rotation induces,
        # so it would break equivariance. For trivial (last-layer) outputs each
        # field has a single channel, so this equals a per-channel bias.
        self.bias = nn.Parameter(torch.zeros(out_frames).to(device))
        self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Convolve x with the current filters (plus ReLU unless last_layer)."""
        # Multiply the equivariant basis by the trainable weights:
        # (p, q, b, k, l) x (o, i, b) -> (o, p, i, q, k, l)
        conv_filters = torch.einsum('pqbkl,oib->opiqkl',
                                    self.multiscale_basis_kernels.to(self.weights.device),
                                    self.weights)
        n_out, out_size, n_in, in_size = conv_filters.shape[:4]
        conv_filters = conv_filters.reshape(n_out * out_size, n_in * in_size,
                                            self.kernel_size, self.kernel_size)

        # Output channel index is o * out_size + p (from the reshape above), so
        # repeat each field's bias out_size times in a row.
        bias = self.bias.repeat_interleave(out_size)
        out = F.conv2d(x, conv_filters, bias, padding=(self.kernel_size - 1) // 2)

        if self.last_layer:
            return out
        return self.relu(e2cnn.nn.GeometricTensor(out, self.feat_type_hid)).tensor

    def resize_conv_kernel(self,
                           kernel,  # 5-D tensor: out_chs x inp_chs x n_basis x kz x kz
                           scale_factor,  # scale applied to both spatial dims
                           mode='trilinear'  # interpolation mode for F.interpolate
                           ):
        """Rescale the spatial dims of ``kernel`` and restore the kz x kz size.

        The basis dimension (dim 2) is left unscaled. Shrunk kernels
        (scale_factor < 1) are zero-padded back to kz x kz; enlarged ones are
        cropped to kz x kz. The result has the same dtype and device as ``kernel``.
        """
        old_size = kernel.size(-1)

        # Interpolate the last three dims (n_basis, kz, kz); the basis dim is kept at scale 1.
        resized_kernel = F.interpolate(kernel, scale_factor=(1, scale_factor, scale_factor), mode=mode)
        new_size = resized_kernel.size(-1)

        # NOTE(review): when new_size is even (e.g. 5 -> 2 at 0.5, 5 -> 10 at 2) the
        # pad/crop below cannot be centred, which shifts the kernel by half a pixel
        # and likely weakens exact rotation equivariance. Confirm whether odd
        # target sizes were intended.
        if scale_factor < 1:
            new_kernel = torch.zeros_like(kernel)
            padding = (old_size - new_size) // 2
            new_kernel[..., padding:padding + new_size, padding:padding + new_size] = resized_kernel
        else:
            padding = (new_size - old_size) // 2
            new_kernel = resized_kernel[..., padding:padding + old_size, padding:padding + old_size]

        return new_kernel

In [None]:
# Scale and Rotation Equivariant ResNet Block
class ScaleRotEquivResBlock(torch.nn.Module):
    """Two stacked ScaleRotEquivLayers with a residual connection.

    When in_frames != out_frames the shortcut is a third ScaleRotEquivLayer
    (``self.upscale``) mapping in_frames -> out_frames. Otherwise the shortcut is
    the identity and ``self.upscale`` is None, so no unused projection weights are
    allocated or handed to the optimizer.
    """

    def __init__(self,
                 in_frames,
                 out_frames,
                 kernel_size,
                 N,
                 scale_factors,
                 ):
        super(ScaleRotEquivResBlock, self).__init__()

        layer_kwargs = dict(kernel_size=kernel_size, N=N, scale_factors=scale_factors)

        self.layer1 = ScaleRotEquivLayer(in_frames=in_frames, out_frames=out_frames, **layer_kwargs)
        self.layer2 = ScaleRotEquivLayer(in_frames=out_frames, out_frames=out_frames, **layer_kwargs)

        # Projection shortcut only when the widths differ.
        if in_frames != out_frames:
            self.upscale = ScaleRotEquivLayer(in_frames=in_frames, out_frames=out_frames, **layer_kwargs)
        else:
            self.upscale = None

        self.in_frames = in_frames
        self.out_frames = out_frames

    def forward(self, x):
        """Return layer2(layer1(x)) + shortcut(x)."""
        out = self.layer2(self.layer1(x))

        # residual connection
        shortcut = x if self.upscale is None else self.upscale(x)
        return out + shortcut

In [None]:
# Scale and Rotation Equivariant ResNet
class ScaleRotEquivResNet(torch.nn.Module):
    """Stack of scale/rotation-equivariant layers and residual blocks.

    The layout is:
      - input_layer: in_frames trivial fields -> widths[0] regular fields;
      - for each consecutive pair (prev, w) in widths: a widening block
        ResBlock(prev, w) followed by a same-width block ResBlock(w, w);
      - last_layer: widths[-1] regular fields -> out_frames trivial fields.

    The default widths reproduce the original 16 ... 1024 stack. input_layer and
    last_layer are also kept as attributes; they are the same module objects
    that appear in ``self.model``.

    Raises:
        ValueError: if ``widths`` is empty.
    """

    def __init__(self,
                 in_frames,
                 out_frames,
                 kernel_size,
                 N,
                 scale_factors,
                 widths=(16, 32, 64, 128, 256, 512, 1024)):

        super(ScaleRotEquivResNet, self).__init__()

        widths = list(widths)
        if not widths:
            raise ValueError("widths must contain at least one entry")

        self.input_layer = ScaleRotEquivLayer(in_frames=in_frames,
                                              out_frames=widths[0],
                                              kernel_size=kernel_size,
                                              N=N,
                                              scale_factors=scale_factors,
                                              first_layer=True)

        self.last_layer = ScaleRotEquivLayer(in_frames=widths[-1],
                                             out_frames=out_frames,
                                             kernel_size=kernel_size,
                                             N=N,
                                             scale_factors=scale_factors,
                                             last_layer=True)

        layers = [self.input_layer]
        for prev_width, width in zip(widths[:-1], widths[1:]):
            layers += [ScaleRotEquivResBlock(prev_width, width, kernel_size, N, scale_factors),
                       ScaleRotEquivResBlock(width, width, kernel_size, N, scale_factors)]
        layers.append(self.last_layer)
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full stack; x is expected as B x C x H x W."""
        return self.model(x)

<h2>DATA LOADER</h2>

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset of (input, target) tensor pairs stored as .pt files.

    For the ID at each position of ``indices``, loads ``h_<ID>.pt`` as the input
    and ``T_<ID>.pt`` as the target from the directory ``direc``, and returns both
    cast to float32.
    """

    def __init__(self, indices, direc):
        self.list_IDs = indices
        self.direc = direc

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        # os.path.join works whether or not direc ends with a path separator.
        # map_location='cpu' lets tensors saved from a GPU load on CPU-only hosts;
        # the training/eval loops move batches to `device` themselves.
        # NOTE(review): torch.load unpickles its input; only use it on trusted files.
        x = torch.load(os.path.join(self.direc, 'h_' + str(ID) + '.pt'), map_location='cpu')
        y = torch.load(os.path.join(self.direc, 'T_' + str(ID) + '.pt'), map_location='cpu')

        return x.float(), y.float()

In [None]:
# Data configuration.
# NOTE(review): train/valid/test all read the same directory and the same single
# index (0), so validation and test are not held out. Placeholder split? TODO.
batch_size = 4

train_direc = '../simulated_data_reg/'
valid_direc = '../simulated_data_reg/'
test_direc = '../simulated_data_reg/'

train_indices = list(range(0, 1))
valid_indices = list(range(0, 1))
test_indices = list(range(0, 1))

# Load data
train_set = Dataset(train_indices, train_direc)
valid_set = Dataset(valid_indices, valid_direc)
test_set = Dataset(test_indices, test_direc)

# Only training batches are shuffled. Evaluation order stays fixed so the
# preds/trues arrays returned by the eval functions line up across epochs.
train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

<h2>BUILD MODEL</h2>

In [None]:
# Build the network: 1 input field -> 1 output field, 5x5 kernels,
# rotation group of order 4, basis resampled at five scales.
model = ScaleRotEquivResNet(
    in_frames=1,
    out_frames=1,
    kernel_size=5,
    N=4,
    scale_factors=[0.5, 0.75, 1, 1.5, 2],
)
model = model.to(device)

In [None]:
# Smoke test: one forward pass on a random 1x1x200x200 input to check the output
# shape. no_grad avoids building an autograd graph for this throwaway pass.
with torch.no_grad():
    x = torch.randn(1, 1, 200, 200).to(device)
    print(model(x).shape)

In [None]:
# torchsummary's `summary` defaults to device="cuda". Pass the active device type
# so this cell also runs on CPU-only machines.
summary(model, (1, 200, 200), device=device.type)

In [None]:
optimizer = 'adam'
lr = 0.01
lr_gamma = 0.1  # NOTE(review): appears unused in the cells shown (scheduler?)

data_dir = '../simulated_data_scale'  # NOTE(review): appears unused; loaders use the *_direc paths

# Materialise as a list: `filter(...)` is a one-shot iterator. Once the optimizer
# consumed it, re-building the optimizer would see an empty parameter list.
parameters = [p for p in model.parameters() if p.requires_grad]
sum(p.numel() for p in parameters)  # number of trainable scalars

In [None]:
# Turn the optimizer name from the config cell into an optimizer instance.
# Re-running this cell keeps an already-built optimizer. An unknown name raises
# here instead of leaving `optimizer` as a string that only fails later in training.
if not isinstance(optimizer, optim.Optimizer):
    if optimizer == 'adam':
        optimizer = optim.Adam(parameters, lr=lr)
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer!r}")

loss_fun = torch.nn.MSELoss()

In [None]:
# Train epoch function

def train_epoch(train_loader, model, optimizer, loss_function):
    """Run one optimisation pass over ``train_loader`` and return the epoch RMSE.

    A channel axis is inserted at dim 1 of each batch (xx, yy). The model is then
    rolled out once per entry along yy's dim 1, feeding each prediction back as
    input. The per-step losses are summed before a single backward/step.

    Returns:
        sqrt of the mean per-batch (per-step-averaged) loss, rounded to 5 decimals.
    """
    train_mse = []
    for xx, yy in train_loader:
        xx = xx.to(device).unsqueeze(1)
        yy = yy.to(device).unsqueeze(1)

        loss = 0
        for y in yy.transpose(0, 1):
            im = model(xx)
            # Autoregressive input update: drop the first two input channels and
            # append the new prediction.
            xx = torch.cat([xx[:, 2:], im], 1)
            # Align the target with the prediction's shape. A (B, H, W) target
            # against a (B, 1, H, W) prediction would broadcast to (B, B, H, W)
            # and compare every prediction with every target. A mismatched element
            # count now raises instead.
            loss += loss_function(im, y.reshape(im.shape))
        train_mse.append(loss.item() / yy.shape[1])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return round(np.sqrt(np.mean(train_mse)), 5)

In [None]:
# Eval epoch function

def eval_epoch(valid_loader, model, loss_function):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Uses the same rollout as train_epoch.

    Returns:
        (rmse, preds, trues): rmse is the sqrt of the mean per-batch
        (per-step-averaged) loss, rounded to 5 decimals. preds and trues are
        numpy arrays concatenated over batches along axis 0.
    """
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        for xx, yy in valid_loader:
            xx = xx.to(device).unsqueeze(1)
            yy = yy.to(device).unsqueeze(1)

            loss = 0
            ims = []
            for y in yy.transpose(0, 1):
                im = model(xx)
                xx = torch.cat([xx[:, 2:], im], 1)
                # Align the target with the prediction's shape so MSELoss cannot
                # silently broadcast (B, H, W) against (B, 1, H, W).
                loss += loss_function(im, y.reshape(im.shape))
                ims.append(im.unsqueeze(1).cpu().numpy())

            preds.append(np.concatenate(ims, axis=1))
            trues.append(yy.cpu().numpy())
            valid_mse.append(loss.item() / yy.shape[1])

    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
    return valid_mse, preds, trues

In [None]:
''' Test epoch function '''

def test_epoch(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        loss_curve = []
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)

            xx = xx.unsqueeze(1)
            yy = yy.unsqueeze(1)

            loss = 0
            ims = []
            
            for y in yy.transpose(0, 1):
                # y = y.unsqueeze(1)
                im = model(xx)
                im = im.squeeze(1)
                im = im.unsqueeze(1)
                xx = torch.cat([xx[:, 2:], im], 1)
                mse = loss_function(im, y)
                loss += mse
                loss_curve.append(mse.item())
                ims.append(im.unsqueeze(1).cpu().data.numpy())
           
            ims = np.concatenate(ims, axis = 1)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            valid_mse.append(loss.item()/yy.shape[1])
            
        loss_curve = np.array(loss_curve).reshape(-1,yy.shape[1])
        preds = np.concatenate(preds, axis = 0)  
        trues = np.concatenate(trues, axis = 0)  
        valid_mse = np.mean(valid_mse)
        loss_curve = np.sqrt(np.mean(loss_curve, axis = 0))
    return valid_mse, preds, trues, loss_curve

In [None]:
# Training history (per-epoch values returned by train_epoch / eval_epoch) and timings.
train_mse = []
valid_mse = []
test_mse = []  # NOTE(review): rebound to test_epoch's scalar result after training
times = []

# Best validation score so far. Starting at +inf means any finite validation
# score improves on it, so a best model is always recorded.
min_mse = float('inf')

n_epochs = 60

In [None]:
# Weights of the epoch with the lowest validation score, kept as a CPU snapshot.
best_state = None

for i in tqdm(range(n_epochs)):

    print('EPOCH: ', i+1)

    start = time.time()

    # optimizer.step() is only called inside train_epoch, after backward().
    # train_epoch leaves the last batch's gradients in place, so stepping here as
    # well would re-apply those stale gradients.
    model.train()
    train_mse.append(train_epoch(train_loader, model, optimizer, loss_fun))

    model.eval()
    mse, _, _ = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        # Snapshot the weights. Binding `best_model = model` would only alias the
        # live model, which keeps training and so always ends as the final weights.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    end = time.time()

    times.append(end - start)

    # Early Stopping but train at least for 50 epochs
    # if (len(train_mse) > 50 and np.mean(valid_mse[-5:]) >= np.mean(valid_mse[-10:-5])):
    #         break

    # train_epoch / eval_epoch return the square root of the mean loss (RMSE).
    print('TRAIN RMSE: ', train_mse[-1])
    print('VALID RMSE: ', valid_mse[-1])
    print('TIME: ', end - start)
    print('----------------------------------')

# Restore the best validation weights (if any epoch improved on min_mse), then test.
if best_state is not None:
    model.load_state_dict(best_state)
best_model = model

test_mse, preds, trues, loss_curve = test_epoch(test_loader, best_model, loss_fun)

In [None]:
# Plot loss curves

import matplotlib.pyplot as plt

# train_epoch / eval_epoch return sqrt(mean loss), so these curves are RMSE.
# Epochs are numbered from 1 to match the 'EPOCH: ' printout in the training loop.
epochs = range(1, len(train_mse) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_mse, label='Train')
ax.plot(epochs, valid_mse, label='Valid')
ax.set_xlabel('Epoch #')
ax.set_ylabel('RMSE')
ax.set_title('RMSE')
ax.legend()
ax.grid()
plt.show()