<a href="https://colab.research.google.com/github/unerriar/igm-selffocus-nn/blob/main/IGM_selffocus_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Imports

Run this cell first.

In [None]:
import os, os.path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

import torch.cuda
import torch.nn as nn
from itertools import chain
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader

from google.colab import drive
from google.colab import output

In [None]:
! git clone https://github.com/unerriar/igm-selffocus-nn.git

Cloning into 'igm-selffocus-nn'...
remote: Enumerating objects: 139, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 139 (delta 8), reused 0 (delta 0), pack-reused 110[K
Receiving objects: 100% (139/139), 1.93 MiB | 8.66 MiB/s, done.
Resolving deltas: 100% (50/50), done.


# Kaggle upload

Run this section to automatically download the four mode datasets used in the article.

In [None]:
# Point the KAGGLE_CONFIG_DIR environment variable at the cloned repository directory.
# NOTE(review): presumably the repository ships the kaggle.json credentials that the
# Kaggle CLI used in the next cell looks up in this directory -- confirm.
os.environ['KAGGLE_CONFIG_DIR'] = '/content/igm-selffocus-nn'

In [None]:
! kaggle datasets download -d 'unerriar/igmselffocus'
! unzip 'igmselffocus.zip'

#Helper classes

FieldMediumDataset: *loads the data into memory and returns individual samples (batching is done by the `DataLoader`).*

Workflow: *implements the main train-validation routine and makes it easy to load the best saved model weights.*

In [None]:
class FieldMediumDataset(Dataset):
    """
    In-memory dataset of (input field, output field, medium parameters) samples.

    For every index ``i`` in ``range(*elems)`` three binary files are read from
    ``dir``: ``in_<i>.bin``, ``out_<i>.bin`` and ``medium_<i>.bin``.
    All samples are loaded eagerly in the constructor.

    Parameters

    dir : str
        directory that contains the ``*.bin`` files
    field_res : int
        side length of the (square) field grid stored in the field files
    elems : sequence of int
        arguments forwarded to ``range`` to select sample indices,
        e.g. ``[start, stop]``
    """
    def __init__(self, dir, field_res, elems):
        super().__init__()

        self.field_res = field_res
        self.elements = elems
        self.dir = dir

        indices = range(*elems)
        self.in_fields = np.stack([self._extract_field(f'in_{i}.bin') for i in indices])
        self.out_fields = np.stack([self._extract_field(f'out_{i}.bin') for i in indices])
        self.medium = np.stack([self._extract_medium(f'medium_{i}.bin') for i in indices])

    def _extract_field(self, f):
        """
        Read one field file and return a float array of shape (4, field_res, field_res).

        The file is read as a flat complex64 stream in which two components are
        interleaved: even-indexed values form the first component, odd-indexed
        values the second. The returned channels are
        (Re first, Im first, Re second, Im second).
        """
        # Passing the path lets numpy open *and close* the file itself.
        data = np.fromfile(os.path.join(self.dir, f), dtype=np.complex64)
        grid = (self.field_res, self.field_res)
        data_p = data[::2].reshape(grid)
        data_m = data[1::2].reshape(grid)
        return np.stack((np.real(data_p), np.imag(data_p), np.real(data_m), np.imag(data_m)))

    def _extract_medium(self, f):
        """
        Read one medium file (flat float64 stream) and return it as a float32 vector.
        """
        # The original opened the file manually and never closed it (handle leak).
        return np.fromfile(os.path.join(self.dir, f), dtype=np.float64).astype('float32')

    def __getitem__(self, i):
        """Return the tuple (in_field, out_field, medium) of the i-th loaded sample."""
        return (self.in_fields[i], self.out_fields[i], self.medium[i])

    def __len__(self):
        """Return the number of loaded samples."""
        # Counting the loaded samples stays correct when `elems` carries a step,
        # unlike elems[1] - elems[0].
        return len(self.in_fields)

class Workflow():
    """
    Train / validate / test driver around a single model.

    Parameters

    model : torch.nn.Module
        network to optimise; it is moved to CUDA when available, else kept on CPU
    model_name : str
        base name of the checkpoint file (``<model_name>.pth``)
    save_dir : str or None
        directory for the best-model checkpoint; with None no checkpoint is
        written and `load_best` is unavailable
    loss_fn : callable
        zero-argument factory of the loss (e.g. the ``nn.MSELoss`` class)
    optimizer : callable
        optimizer class, called as ``optimizer(model.parameters(), **optim_params)``
    optim_params : dict or None
        keyword arguments for the optimizer; None means ``{'lr': .1}``
    """
    def __init__(self, model, model_name='Untitled', save_dir=None,
                 loss_fn=nn.MSELoss, optimizer=torch.optim.SGD, optim_params=None):
        # A None sentinel replaces the former dict default, which was shared
        # between all instances created without explicit optim_params.
        if optim_params is None:
            optim_params = {'lr': .1}

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('using ' + self.device)

        self.model = model.to(self.device)
        self.loss_fn = loss_fn()
        self.save_dir = save_dir
        self.optimizer = optimizer(model.parameters(), **optim_params)
        self.model_name = model_name
        self.optim_params = optim_params

        # A true infinity: any finite first validation loss becomes the best one.
        self.best_loss = float('inf')

    def train(self, n_epoch, trainloader, validloader,
              history_tracing=50, refresh_time=5, in_fields=True):
        """
        Full train-validation routine

        Parameters

        n_epoch : int
            number of epochs to learn
        trainloader : torch.utils.data.DataLoader
            train dataset dataloader
        validloader : torch.utils.data.DataLoader
            validation dataset dataloader
        history_tracing : int
            number of past epochs to track detailed data about.
            Does not affect model behaviour
        refresh_time : int
            number of epochs between plots and info being refreshed.
            Does not affect the model behaviour
        in_fields : 'zeros' | True | False
            True:    the model will get both input and output fields (requires 8 in_channels)
            False:   the model will get only output fields (requires 4 in_channels)
            'zeros': the model will get zero tensor instead of input fields (requires 8 in_channels)
        """
        train_loss_history = []
        valid_loss_history = []
        checkpoint_path = self._checkpoint_path()

        for e in range(n_epoch):
            train_loss = self._train(trainloader, in_fields=in_fields)
            valid_loss = self._valid(validloader, in_fields=in_fields)

            train_loss_history.append(train_loss)
            valid_loss_history.append(valid_loss)

            #saving best model
            if valid_loss < self.best_loss:
                self.best_loss = valid_loss
                if checkpoint_path is not None:
                    torch.save(self.model.state_dict(), checkpoint_path)

            #learning visualization
            if (e+1) % refresh_time == 0:
                self._show_progress(e+1, n_epoch, history_tracing,
                                    train_loss_history, valid_loss_history)

    def test(self, dataloader, in_fields=True):
        """
        Single epoch test: prints and returns the mean per-batch loss.
        """
        test_loss = self._run_epoch(dataloader, in_fields, training=False)
        print(f'Test loss: {test_loss}')
        return test_loss

    def load_best(self):
        """
        Load the best saved model weights from ``<save_dir>/<model_name>.pth``.

        Raises ValueError when the workflow was created without save_dir.
        """
        checkpoint_path = self._checkpoint_path()
        if checkpoint_path is None:
            raise ValueError('save_dir is not set, there is no checkpoint to load')
        # map_location lets a checkpoint written on GPU be loaded on a CPU-only runtime.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    def _train(self, dataloader, in_fields=True):
        """
        Single epoch train phase; returns the mean per-batch loss.
        """
        return self._run_epoch(dataloader, in_fields, training=True)

    def _valid(self, dataloader, in_fields=True):
        """
        Single epoch valid phase; returns the mean per-batch loss.
        """
        return self._run_epoch(dataloader, in_fields, training=False)

    def _checkpoint_path(self):
        """Return the checkpoint file path, or None when save_dir is not set."""
        if self.save_dir is None:
            return None
        return os.path.join(self.save_dir, self.model_name + '.pth')

    def _predict(self, in_field, out_field, in_fields):
        """Assemble the model input according to the in_fields mode and run the model."""
        if in_fields == 'zeros':
            model_input = torch.cat((torch.zeros_like(in_field), out_field), dim=1)
        elif in_fields:
            model_input = torch.cat((in_field, out_field), dim=1)
        else:
            model_input = out_field
        return self.model(model_input)

    def _run_epoch(self, dataloader, in_fields, training):
        """One pass over dataloader; weights are updated only when training is True."""
        batch_num = len(dataloader)

        self.model.train(training)

        ep_loss = 0
        # No autograd graph is needed (or built) during validation and testing.
        with torch.set_grad_enabled(training):
            for in_field, out_field, medium in dataloader:
                in_field = in_field.to(self.device)
                out_field = out_field.to(self.device)
                medium = medium.to(self.device)

                #error computation
                loss = self.loss_fn(self._predict(in_field, out_field, in_fields), medium)
                ep_loss += loss.item()

                #backpropagation
                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

        return ep_loss / batch_num

    def _show_progress(self, epoch, n_epoch, history_tracing,
                       train_loss_history, valid_loss_history):
        """Clear the cell output, print loss statistics and redraw the loss curves."""
        train_loss = train_loss_history[-1]
        valid_loss = valid_loss_history[-1]

        #loss dynamics monitoring
        hist_len = min(epoch, history_tracing)
        recent = valid_loss_history[-hist_len:]
        valid_loss_mean = np.mean(recent)
        valid_loss_std = np.std(recent)
        valid_loss_delta = valid_loss - valid_loss_history[-hist_len]

        output.clear()
        print(f'Epoch {epoch}/{n_epoch}:\nTrain loss: {train_loss:.4}, validation loss: {valid_loss:.4}')
        print(f'Validation loss during last {hist_len} epochs:')
        print(f'Mean: {valid_loss_mean:.4}, std dev: {valid_loss_std:.4}, delta: {valid_loss_delta:.6}:')
        print(f'Best validation loss: {self.best_loss:.4}')

        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(10,7))
        for ax in (ax1, ax2):
            ax.plot(range(len(train_loss_history)), train_loss_history, label='train')
            ax.plot(range(len(valid_loss_history)), valid_loss_history, label='valid')
            ax.set_xlabel('epoch')
        ax1.set_ylabel('loss')
        ax2.set_yscale('log')
        ax2.set_ylabel('loss (log scale)')
        ax1.legend()
        ax2.legend()

        plt.show()

#Network

General artificial neural network architecture.

In [None]:
class Network(nn.Module):
    """
    Convolutional regressor: three stride-4 convolutions followed by a three-layer MLP.

    Each convolution uses a 4x4 kernel with stride 4, so a 256x256 input shrinks
    256 -> 64 -> 16 -> 4; the first linear layer assumes that final 4x4 grid.
    The network outputs 4 values per sample.

    Parameters

    in_channels : int
        number of channels of the input tensor
    hidden_neurons : sequence of int
        output widths of the three convolutions and of the first two linear layers
    cact : callable
        factory of the activation placed after every convolution
    lact : callable
        factory of the activation placed after the first two linear layers
    """
    def __init__(self, in_channels=4, hidden_neurons=(4, 8, 16, 40, 60), cact=nn.Tanh, lact=nn.ReLU):
        super().__init__()

        # Layers are created in the order conv, act, conv, act, ... so that
        # parameter initialisation order and state_dict keys stay the same.
        conv_channels = (in_channels, hidden_neurons[0], hidden_neurons[1], hidden_neurons[2])
        conv_layers = []
        for n_in, n_out in zip(conv_channels[:-1], conv_channels[1:]):
            conv_layers.append(nn.Conv2d(n_in, n_out, kernel_size=4, stride=4))
            conv_layers.append(cact())
        self.convolutions = nn.Sequential(*conv_layers)

        self.flatten = nn.Flatten()

        flat_features = hidden_neurons[2] * 4**2
        dense_layers = [
            nn.Linear(flat_features, hidden_neurons[3]),
            lact(),
            nn.Linear(hidden_neurons[3], hidden_neurons[4]),
            lact(),
            nn.Linear(hidden_neurons[4], 4),
        ]
        self.linear = nn.Sequential(*dense_layers)

    def forward(self, x):
        """Apply the convolutions, flatten the feature maps and run the linear head."""
        features = self.convolutions(x)
        return self.linear(self.flatten(features))

#Dataloader construction

In [None]:
# --- dataset geometry and split sizes ---
field_resolution = 256  # side length of the square input/output field grids
batch_size = 256
total_n = 2048  # samples in the whole dataset
valid_n = 512   # samples reserved for validation
test_n = 256    # samples reserved for testing
train_n = total_n - valid_n - test_n  # samples left for training
field_mode = 'Gaussian'  # one of: Gaussian | Supergaussian | Laguerre | Poincare
dir = f'{field_mode}Dataset'

# Sample indices are split as [0, test_n) -> test, [test_n, test_n+valid_n) -> valid,
# [test_n+valid_n, total_n) -> train.
train_dataset = FieldMediumDataset(dir, field_resolution, [test_n + valid_n, total_n])
valid_dataset = FieldMediumDataset(dir, field_resolution, [test_n, test_n + valid_n])
test_dataset = FieldMediumDataset(dir, field_resolution, [0, test_n])

# Every subset is served in shuffled batches of `batch_size` samples.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

#Training

This section is optional: run it only if you want to train the model yourself or experiment with the training parameters.

In [None]:
# Training cell: builds a fresh Network, trains it with Adam and evaluates it on the test subset.

# Seed the torch random number generator before the model is created.
torch.manual_seed(1)

in_fields_mode = True  #True  to give the input field profiles to the model
                       #False to use model without input field profiles
                       #'zeros' to use model with field profiles but set their channels to zero.
# Any truthy mode (True or 'zeros') feeds 8 channels to the model, False feeds 4.
in_channels = 8 if in_fields_mode else 4

# NOTE(review): hidden_neurons ends with (60, 40) here, while the pretrained-loading
# cell builds the network with (40, 60). Both cells use the same save_dir and model
# name, so a checkpoint written by this cell would not fit that architecture -- confirm
# which ordering is intended.
model = Network(
    in_channels = in_channels, 
    hidden_neurons = (4, 8, 16, 60, 40),
    cact = nn.Tanh,
    lact = nn.ReLU
)
optim_params = {'lr':0.0005}

# The best checkpoint is written to <save_dir>/<field_mode>.pth (model name = field_mode).
workflow = Workflow(
    model, field_mode,
    save_dir = '/content/igm-selffocus-nn',
    optimizer = torch.optim.Adam,
    optim_params = optim_params,
)

# history_tracing / refresh_time only control the printed statistics and the plots.
workflow.train(
    n_epoch = 2500, 
    trainloader = train_dataloader,
    validloader  = valid_dataloader,
    history_tracing = 20,
    refresh_time    = 20,
    in_fields = in_fields_mode
)
# Evaluates the model in its state after the last epoch (load_best is not called here).
workflow.test(
    test_dataloader,
    in_fields = in_fields_mode
)

#Uploading pretrained networks

With this section you can load the model pretrained by the authors.

In [None]:
# Loading cell: rebuilds the network, loads saved weights and evaluates them on the test subset.

# Seed the torch random number generator before the model is created.
torch.manual_seed(1)

in_fields_mode = False #True  to give the input field profiles to the model
                       #False to use model without input field profiles
                       #'zeros' to use model with field profiles but set their channels to zero.
# Any truthy mode (True or 'zeros') feeds 8 channels to the model, False feeds 4.
in_channels = 8 if in_fields_mode else 4

# The architecture must match the one the checkpoint was saved with.
# NOTE(review): the training cell uses hidden_neurons (4, 8, 16, 60, 40) -- confirm
# which ordering the stored weights expect.
model = Network(
    in_channels = in_channels, 
    hidden_neurons = (4, 8, 16, 40, 60),
    cact = nn.Tanh,
    lact = nn.ReLU
)
# An optimizer is still constructed by Workflow, but no training step runs in this cell.
optim_params = {'lr':.0005}

workflow = Workflow(
    model, field_mode,
    save_dir = '/content/igm-selffocus-nn',
    optimizer = torch.optim.Adam,
    optim_params = optim_params,
)

# Reads the weights from <save_dir>/<field_mode>.pth into the model.
workflow.load_best()
workflow.test(
    test_dataloader,
    in_fields = in_fields_mode
)

#Error statistics

In [None]:
# Dump the model predictions and the matching ground-truth medium parameters of
# every subset to '<field_mode>_<subset>_pred.bin' / '<field_mode>_<subset>_truth.bin'.
workflow.model.eval()

for dataloader, datatype in [(train_dataloader, '_train'),
                             (valid_dataloader, '_valid'),
                             (test_dataloader,  '_test')]:

    pred_batches = []
    truth_batches = []
    # Inference only: no autograd graph is required.
    with torch.no_grad():
        for in_field, out_field, medium in dataloader:
            # Build the input the same way Workflow does for the selected mode,
            # so 4-channel models and the 'zeros' mode are handled correctly.
            if in_fields_mode == 'zeros':
                model_input = torch.cat((torch.zeros_like(in_field), out_field), dim=1)
            elif in_fields_mode:
                model_input = torch.cat((in_field, out_field), dim=1)
            else:
                model_input = out_field

            # Use the device the workflow selected instead of a hard-coded 'cuda'.
            batch_pred = workflow.model(model_input.to(workflow.device))
            pred_batches.append(batch_pred.cpu().numpy())
            truth_batches.append(medium.numpy())

    # One concatenation per subset instead of re-copying the arrays every batch.
    pred = np.concatenate(pred_batches)
    truth = np.concatenate(truth_batches)

    # Writing by file name lets numpy open and close the files itself.
    pred.tofile(field_mode+datatype+'_pred.bin')
    truth.tofile(field_mode+datatype+'_truth.bin')