<a href="https://colab.research.google.com/github/unerriar/igm-selffocus-nn/blob/main/IGM_selffocus_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Imports and dataset loading

In [None]:
import os, os.path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

import torch.cuda
import torch.nn as nn
from itertools import chain
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader

from google.colab import drive
from google.colab import output

In [None]:
! git clone https://github.com/unerriar/igm-selffocus-nn.git

Cloning into 'igm-selffocus-nn'...
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 36 (delta 9), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (36/36), done.


In [43]:
os.environ['KAGGLE_CONFIG_DIR'] = '/content/igm-selffocus-nn'

In [None]:
! kaggle datasets download -d unerriar/igmselffocus
! unzip 'igmselffocus.zip'

In [None]:
! unzip 'igmselffocus.zip'

#Helper classes

FieldMediumDataset: *loads the data into memory and returns it sample by sample (batching is done by the DataLoader).*

Workflow: *implements the main train-test routine and makes it easy to load the best saved model weights.*

In [None]:
class FieldMediumDataset(Dataset):
    """
    In-memory dataset of (input field, output field, medium) samples.

    Sample number i is read from the files in_<i>.bin, out_<i>.bin and
    medium_<i>.bin located in `dir`. Every selected sample is loaded
    eagerly in the constructor.

    Parameters

    dir : str
        directory holding the .bin files
    field_res : int
        side length of the square grid each field component is reshaped to
    elems : sequence of int
        arguments forwarded to range() to select the sample numbers,
        e.g. [start, stop]
    """
    def __init__(self, dir, field_res, elems):
        super().__init__()

        self.field_res = field_res
        self.elements = elems
        self.dir = dir

        indices = range(*elems)
        self.in_fields  = np.stack([self._extract_field(f'in_{i}.bin') for i in indices])
        self.out_fields = np.stack([self._extract_field(f'out_{i}.bin') for i in indices])
        self.medium     = np.stack([self._extract_medium(f'medium_{i}.bin') for i in indices])

    def _extract_field(self, f):
        """
        Read one field file and return it as a (4, field_res, field_res) real array.

        The file is read as a flat complex64 sequence in which two components
        are interleaved (even positions / odd positions). The returned channels
        are: real and imaginary part of the even-position component, then real
        and imaginary part of the odd-position component.
        """
        #np.fromfile opens and closes the file itself when given a path
        data = np.fromfile(os.path.join(self.dir, f), dtype=np.complex64)
        shape = (self.field_res, self.field_res)
        data_p = data[::2].reshape(shape)
        data_m = data[1::2].reshape(shape)
        return np.stack((np.real(data_p), np.imag(data_p), np.real(data_m), np.imag(data_m)))

    def _extract_medium(self, f):
        """
        Read one medium file (flat float64 values) and return it as float32.
        """
        #reading by path avoids the file handle the manual open() used to leak
        return np.fromfile(os.path.join(self.dir, f), dtype=np.float64).astype('float32')

    def __getitem__(self, i):
        """
        Return the i-th loaded sample as (in_field, out_field, medium).
        """
        return (self.in_fields[i], self.out_fields[i], self.medium[i])

    def __len__(self):
        """
        Number of loaded samples.
        """
        #counting what was actually loaded stays correct for any range() arguments
        #(a step, or a single stop value), unlike elements[1] - elements[0]
        return len(self.in_fields)

class Workflow():
    """
    Train / test / validate driver around a single model.

    Keeps the model, the loss function and the optimizer together, tracks the
    lowest test loss seen so far and stores the corresponding weights in
    <save_dir>/<model_name>.pth.

    Parameters

    model : torch.nn.Module
        network to work with; moved to 'cuda' when available, else to 'cpu'
    model_name : str
        base name of the checkpoint file
    save_dir : str | None
        directory of the checkpoint file; None disables checkpoint saving
    loss_fn : callable
        called without arguments to build the loss module
    optimizer : callable
        called as optimizer(model_parameters, **optim_params)
    optim_params : dict | None
        keyword arguments of the optimizer; None stands for {'lr': .1}
    """
    def __init__(self, model, model_name='Untitled', save_dir=None,
                 loss_fn=nn.MSELoss, optimizer=torch.optim.SGD, optim_params=None):
        #a dict literal as the default value would be one object shared by all instances
        if optim_params is None:
            optim_params = {'lr': .1}

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('using ' + self.device)

        self.model = model.to(self.device)
        self.loss_fn = loss_fn()
        self.save_dir = save_dir
        self.optimizer = optimizer(self.model.parameters(), **optim_params)
        self.model_name = model_name
        self.optim_params = optim_params

        #any finite first test loss must count as an improvement
        self.best_loss = float('inf')

    def train(self, n_epoch, trainloader, testloader,
              history_tracing=50, refresh_time=5, in_fields=True):
        """
        Full train-test routine

        Parameters

        n_epoch : int
            number of epochs to learn
        trainloader : torch.utils.data.DataLoader
            train dataset dataloader
        testloader : torch.utils.data.DataLoader
            test dataset dataloader
        history_tracing : int
            number of past epochs to track detailed data about.
            Does not affect model behaviour
        refresh_time : int
            number of epochs between plots and info being refreshed.
            Does not affect the model behaviour
        in_fields : 'zeros' | True | False
            True:    the model will get both input and output fields (requires 8 in_channels)
            False:   the model will get only output fields (requires 4 in_channels)
            'zeros': the model will get zero tensor instead of input fields (requires 8 in_channels)
        """
        train_loss_history = []
        test_loss_history  = []

        for e in range(n_epoch):
            train_loss = self._train(trainloader, in_fields=in_fields)
            test_loss  = self._test(testloader, in_fields=in_fields)

            train_loss_history.append(train_loss)
            test_loss_history.append(test_loss)

            #saving best model (the instance's own model, not a global one)
            if test_loss < self.best_loss:
                self.best_loss = test_loss
                if self.save_dir is not None:
                    torch.save(self.model.state_dict(), self._checkpoint_path())

            #learning visualization
            if (e+1) % refresh_time == 0:
                self._report(e, n_epoch, train_loss_history, test_loss_history, history_tracing)

    def validate(self, dataloader, in_fields=True):
        """
        Single epoch validation

        Prints and returns the loss averaged over the batches of `dataloader`.
        `in_fields` has the same meaning as in train().
        """
        loss = self._run_epoch(dataloader, in_fields, training=False)
        print(f'Validation loss: {loss}')
        return loss

    def load_best(self):
        """
        Load the best saved model weights into the model.

        Raises ValueError when no save_dir was given.
        """
        #map_location lets weights saved on a GPU be loaded on a CPU-only machine
        state = torch.load(self._checkpoint_path(), map_location=self.device)
        self.model.load_state_dict(state)

    def _train(self, dataloader, in_fields=True):
        """
        Single epoch train phase
        """
        return self._run_epoch(dataloader, in_fields, training=True)

    def _test(self, dataloader, in_fields=True):
        """
        Single epoch test phase
        """
        return self._run_epoch(dataloader, in_fields, training=False)

    def _checkpoint_path(self):
        """
        Path of the checkpoint file, <save_dir>/<model_name>.pth
        """
        if self.save_dir is None:
            raise ValueError('save_dir is not set, there is no checkpoint location')
        return os.path.join(self.save_dir, self.model_name + '.pth')

    def _forward(self, in_field, out_field, in_fields):
        """
        Assemble the model input according to the in_fields mode and run the model
        """
        if in_fields == 'zeros':
            #explicit zeros: 0*in_field would let NaN/inf of in_field through
            net_input = torch.cat((torch.zeros_like(in_field), out_field), dim=1)
        elif in_fields:
            net_input = torch.cat((in_field, out_field), dim=1)
        else:
            net_input = out_field
        return self.model(net_input)

    def _run_epoch(self, dataloader, in_fields, training):
        """
        One pass over dataloader; returns the loss averaged over batches.
        Weights are updated only when training is True.
        """
        batch_num = len(dataloader)
        if batch_num == 0:
            raise ValueError('dataloader yields no batches')

        self.model.train(training)

        ep_loss = 0
        #no autograd graph is built during evaluation
        with torch.set_grad_enabled(training):
            for in_field, out_field, medium in dataloader:
                in_field  = in_field.to(self.device)
                out_field = out_field.to(self.device)
                medium    = medium.to(self.device)

                #error computation
                predict = self._forward(in_field, out_field, in_fields)
                loss = self.loss_fn(predict, medium)
                ep_loss += loss.item()

                #backpropagation
                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

        return ep_loss / batch_num

    def _report(self, e, n_epoch, train_loss_history, test_loss_history, history_tracing):
        """
        Print loss statistics of the recent epochs and redraw the loss curves
        """
        #loss dynamics monitoring
        hist_len = min(e+1, history_tracing)
        recent = test_loss_history[-hist_len:]
        test_loss_mean  = np.mean(recent)
        test_loss_std   = np.std(recent)
        test_loss_delta = test_loss_history[-1] - test_loss_history[-hist_len]

        output.clear()
        print(f'Epoch {e+1}/{n_epoch}:\nTrain loss: {train_loss_history[-1]:.4}, test loss: {test_loss_history[-1]:.4}')
        print(f'Test loss during last {hist_len} epochs:')
        print(f'Mean: {test_loss_mean:.4}, std dev: {test_loss_std:.4}, delta: {test_loss_delta:.6}:')
        print(f'Best test loss: {self.best_loss:.4}')

        #same curves twice: linear scale on top, log scale below
        fig, axes = plt.subplots(2, 1, sharex=True, figsize=(10,7))
        for ax, ylabel, log_scale in zip(axes, ('loss', 'loss (log scale)'), (False, True)):
            ax.plot(range(len(train_loss_history)), train_loss_history, label='train')
            ax.plot(range(len(test_loss_history )), test_loss_history,  label='test')
            if log_scale:
                ax.set_yscale('log')
            ax.set_xlabel('epoch')
            ax.set_ylabel(ylabel)
            ax.legend()

        plt.show()

#Network

General artificial neural network architecture.

In [None]:
class Network(nn.Module):
    """
    Three strided convolutions followed by a three-layer fully connected head.

    Parameters

    in_channels : int
        number of channels of the input tensor
    hidden_neurons : sequence of 5 ints
        output channels of the three convolutions, then the sizes of the
        two hidden linear layers
    cact : callable
        builds the activation placed after every convolution
    lact : callable
        builds the activation placed after each hidden linear layer

    The first linear layer takes hidden_neurons[2] * 4**2 features, and the
    last one produces 4 outputs.
    """
    def __init__(self, in_channels=4, hidden_neurons=(16, 32, 16, 64, 32), cact=nn.Tanh, lact=nn.ReLU):
        super().__init__()

        #channel count before and after each of the three convolutions
        widths = (in_channels, hidden_neurons[0], hidden_neurons[1], hidden_neurons[2])

        #every convolution has kernel 4 and stride 4;
        #the side of a 256-wide input goes 256 -> 64 -> 16 -> 4
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(n_in, n_out, kernel_size=4, stride=4))
            stages.append(cact())
        self.convolutions = nn.Sequential(*stages)

        self.flatten = nn.Flatten()

        flat_features = widths[-1] * 4**2
        first_hidden  = hidden_neurons[3]
        second_hidden = hidden_neurons[4]
        self.linear = nn.Sequential(
            nn.Linear(flat_features, first_hidden),
            lact(),
            nn.Linear(first_hidden, second_hidden),
            lact(),
            nn.Linear(second_hidden, 4)
        )

    def forward(self, x):
        """
        Run x through the convolutions, flatten the result and apply the linear head.
        """
        features = self.flatten(self.convolutions(x))
        return self.linear(features)

#Dataloader construction

In [70]:
#dataset geometry and split sizes
field_resolution = 256  #spatial resolution of the input and output fields
batch_size = 64
total_n = 1024  #total number of dataset samples
test_n  = 256   #number of samples in the test subset
valid_n = 128   #number of samples in the validation subset
train_n = total_n - test_n - valid_n  #number of samples in the train subset
dir = 'SingularDataset/SingularDataset'

#consecutive sample ranges: validation first, then test, the rest is for training
valid_dataset, test_dataset, train_dataset = (
    FieldMediumDataset(dir, field_resolution, list(bounds))
    for bounds in (
        (0, valid_n),
        (valid_n, valid_n + test_n),
        (valid_n + test_n, total_n),
    )
)

#one shuffling loader per subset
valid_dataloader, test_dataloader, train_dataloader = (
    DataLoader(subset, batch_size, shuffle=True)
    for subset in (valid_dataset, test_dataset, train_dataset)
)

#Training

This section is optional: run it if you want to experiment with the training parameters.

In [None]:
#fixed seed, so that repeated runs start from the same random state
torch.manual_seed(1)

#what the model receives as input:
#  True    - input and output field profiles
#  False   - output field profiles only
#  'zeros' - channels for both, with the input field channels set to zero
in_fields_mode = False
in_channels = 4 if not in_fields_mode else 8

model = Network(in_channels=in_channels, hidden_neurons=(16, 32, 16, 64, 32), cact=nn.Tanh, lact=nn.ReLU)
optim_params = {'lr':.0003}

workflow = Workflow(model, 'Singular', save_dir='/content', optimizer=torch.optim.Adam, optim_params=optim_params)

#500 epochs; the statistics window and the plot refresh period are both 20 epochs
workflow.train(
    500, train_dataloader, test_dataloader,
    history_tracing=20, refresh_time=20, in_fields=in_fields_mode,
)

#reload the saved weights and score them on the validation subset
workflow.load_best()
workflow.validate(valid_dataloader, in_fields=in_fields_mode)

#Loading pretrained networks

With this section you can load the model pretrained by the authors.

In [None]:
#fixed seed, so that repeated runs start from the same random state
torch.manual_seed(1)

#what the model receives as input:
#  True    - input and output field profiles
#  False   - output field profiles only
#  'zeros' - channels for both, with the input field channels set to zero
in_fields_mode = False
in_channels = 4 if not in_fields_mode else 8

model = Network(in_channels=in_channels, hidden_neurons=(16, 32, 16, 64, 32), cact=nn.Tanh, lact=nn.ReLU)
optim_params = {'lr':.0002}

#save_dir points at the cloned repository, where the weights file is looked up
workflow = Workflow(
    model, 'Singular',
    save_dir='/content/igm-selffocus-nn', optimizer=torch.optim.Adam, optim_params=optim_params,
)

#load the stored weights and score them on the validation subset
workflow.load_best()
workflow.validate(valid_dataloader, in_fields=in_fields_mode)