# Import statements

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from numpy import load, savez
import nets

# Folder for saving results

In [None]:
# Path of the folder containing this Jupyter notebook
folder = '/Users/admin/Documents/GitHub/Parametric-inverse-problem-topology/'

# Output locations used by the rest of the notebook:
#   results_folder -> <folder>/results/Complete/          (loss history)
#   weights_folder -> <folder>/results/Complete/weights/  (model checkpoints)
results_folder = folder + 'results/Complete/'
weights_folder = results_folder + 'weights/'

# os.makedirs creates every missing intermediate directory in one call and,
# with exist_ok=True, does nothing when the directories are already there.
# This replaces three separate "check os.path.exists, then os.mkdir" steps,
# which could fail if a directory appeared between the check and the creation.
os.makedirs(weights_folder, exist_ok=True)

# Neural networks
A multilayer perceptron composed of 6 hidden layers with 3000 neurons each. The activation function is ReLU. Dropout is applied before each hidden layer.

In [None]:
# Width list handed to nets.net_loop_complete. According to the notebook text
# the network has six hidden layers of 3000 neurons each; the list is built
# from that repeated part plus the first and last entries.
# NOTE(review): the last entry is 1 while the training targets assembled in
# the dataset cell have 8 columns -- confirm how nets.net_loop_complete
# interprets this list.
hidden_width = 3000
n_hidden = 6
dimensions = [60] + [hidden_width] * n_hidden + [1]

# Everything in this notebook runs on the CPU.
device = "cpu"
model = nets.net_loop_complete(dimensions)
model = model.to(device)

# Dataset

In [None]:
# Load the archive holding both the training and the validation arrays.
data = load(folder + 'data/dataset_loop_complete.npz')

# Training set
# Inputs: every row of the noisy visibilities is divided by its own maximum.
vis_train = data['vis_train_noisy']
max_vis_train = np.max(vis_train, axis = 1)
mmax_vis_train = np.expand_dims(max_vis_train, axis = 1)
# The repeat count has to be an integer: np.repeat rejects a float count
# such as 60. It is taken from the data so that it always matches the row
# length instead of relying on a hard-coded 60.
mmax_vis_train = np.repeat(mmax_vis_train, vis_train.shape[1], axis=1)
vis_train = vis_train / mmax_vis_train
vis_train = torch.from_numpy(vis_train)
vis_train = vis_train.to(device)

# Targets: each parameter becomes one column, rescaled with fixed constants
# (presumably chosen so that the values fall roughly in [0, 1] -- confirm
# against the parameter ranges used to generate the dataset).
xc_train = data['xc_train']
xc_train = (np.expand_dims(xc_train, 1) + 30.) / 60.

yc_train = data['yc_train']
yc_train = (np.expand_dims(yc_train, 1) + 30.) / 60.

fwhm_train = data['fwhm_train']
fwhm_train = np.expand_dims(fwhm_train, 1) / 15.

# The flux is divided by the same per-row maximum used for the visibilities.
flux_train = data['flux_train'] / max_vis_train
flux_train = np.expand_dims(flux_train, 1)

ecc_train   = data['ecc_train']
# Multiply by pi/180 (presumably alpha is stored in degrees -- confirm).
alpha_train = data['alpha_train']/180. * np.pi
c_train     = data['c_train']

# Three auxiliary targets computed from (ecc, alpha, c).
xx_train = ecc_train * (1. + c_train * np.sin(alpha_train)) * np.cos(2.*alpha_train)
yy_train = ecc_train * (1. + c_train * np.sin(alpha_train)) * np.sin(2.*alpha_train)
zz_train = ecc_train * np.cos(alpha_train) * c_train

xx_train  = (np.expand_dims(xx_train, 1) + 6.)/13.
yy_train  = (np.expand_dims(yy_train, 1) + 6.)/13.
zz_train  = np.expand_dims(zz_train, 1) + 0.5
ecc_train = np.expand_dims(ecc_train, 1) / 5.

# Column order: xc, yc, flux, fwhm, ecc, xx, yy, zz.
target_train = np.concatenate((xc_train, yc_train, flux_train, fwhm_train, ecc_train,
                               xx_train, yy_train, zz_train), axis=1 )
target_train = torch.from_numpy(target_train)
target_train = target_train.to(device)

# Validation set
# Inputs: every row of the noisy visibilities is divided by its own maximum.
vis_valid = data['vis_valid_noisy']
max_vis_valid = np.max(vis_valid, axis = 1)
mmax_vis_valid = np.expand_dims(max_vis_valid, axis = 1)
# The repeat count has to be an integer: np.repeat rejects a float count
# such as 60. It is taken from the data so that it always matches the row
# length instead of relying on a hard-coded 60.
mmax_vis_valid = np.repeat(mmax_vis_valid, vis_valid.shape[1], axis=1)
vis_valid = vis_valid / mmax_vis_valid
vis_valid = torch.from_numpy(vis_valid)
vis_valid = vis_valid.to(device)

# Targets: each parameter becomes one column, rescaled with fixed constants
# (presumably chosen so that the values fall roughly in [0, 1] -- confirm
# against the parameter ranges used to generate the dataset).
xc_valid = data['xc_valid']
xc_valid = (np.expand_dims(xc_valid, 1) + 30.) / 60.

yc_valid = data['yc_valid']
yc_valid = (np.expand_dims(yc_valid, 1) + 30.) / 60.

fwhm_valid = data['fwhm_valid']
fwhm_valid = np.expand_dims(fwhm_valid, 1) / 15.

# The flux is divided by the same per-row maximum used for the visibilities.
flux_valid = data['flux_valid'] / max_vis_valid
flux_valid = np.expand_dims(flux_valid, 1)

ecc_valid   = data['ecc_valid']
# Multiply by pi/180 (presumably alpha is stored in degrees -- confirm).
alpha_valid = data['alpha_valid']/180. * np.pi
c_valid     = data['c_valid']

# Three auxiliary targets computed from (ecc, alpha, c).
xx_valid = ecc_valid * (1. + c_valid * np.sin(alpha_valid)) * np.cos(2.*alpha_valid)
yy_valid = ecc_valid * (1. + c_valid * np.sin(alpha_valid)) * np.sin(2.*alpha_valid)
zz_valid = ecc_valid * np.cos(alpha_valid) * c_valid

xx_valid  = (np.expand_dims(xx_valid, 1) + 6.)/13.
yy_valid  = (np.expand_dims(yy_valid, 1) + 6.)/13.
zz_valid  = np.expand_dims(zz_valid, 1) + 0.5
ecc_valid = np.expand_dims(ecc_valid, 1) / 5.

# Column order: xc, yc, flux, fwhm, ecc, xx, yy, zz.
target_valid = np.concatenate((xc_valid, yc_valid, flux_valid, fwhm_valid, ecc_valid,
                               xx_valid, yy_valid, zz_valid), axis=1 )
target_valid = torch.from_numpy(target_valid)
target_valid = target_valid.to(device)

# Training

In [None]:
# Optimisation hyper-parameters
batch_size = 100
num_epochs_train = 1000
learning_rate = 0.0001
#scheduler parameters
# StepLR multiplies the learning rate by gamma once every step_size calls to
# scheduler.step(); the training loop calls it once per epoch.
step_size = 500
gamma = 0.8

# Inputs and targets are cast to float32 before being wrapped in the dataset.
train_dataset = nets.VisDatasetTrain(vis_train.float(), target_train.float())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

#number of batches in an epoch
# Taken from the loader itself: with the default drop_last=False the loader
# also yields a final, smaller batch when the number of samples is not a
# multiple of batch_size, so int(n_samples / batch_size) would undercount by
# one and the per-epoch loss buffer sized with n_batch would be too short.
n_batch = len(train_loader)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

#initialization of the arrays containing the values of the loss on training and validation set
# (one entry per epoch)
train_loss = np.zeros(num_epochs_train)
valid_loss = np.zeros(num_epochs_train)

# Train the model
print('Train on {} data'.format(vis_train.shape[0]))
print('Validate on {} data'.format(vis_valid.shape[0]))

# Number of batches the loader actually yields per epoch. The per-epoch loss
# buffer is sized with this value so that a final partial batch (sample count
# not a multiple of batch_size) cannot index past the end of the buffer.
total_step = len(train_loader)

# The validation tensors do not change between epochs: cast them to float32
# once instead of creating a fresh copy at every epoch.
vis_valid_f32 = vis_valid.float()
target_valid_f32 = target_valid.float()

# Train, validation...
for epoch in range(num_epochs_train):
    model.train()
    loss_epoch = np.zeros(total_step)
    for i, (vis, labels) in enumerate(train_loader):

        # clean the gradients
        optimizer.zero_grad()

        # forward pass
        outputs = model(vis)

        loss = criterion(outputs, labels)

        # backpropagation
        loss.backward()

        # optimise
        optimizer.step()

        # store the batch loss for the epoch average
        loss_epoch[i] = loss.item()

    # training loss of the epoch = mean of the per-batch losses
    train_loss[epoch] = np.mean(loss_epoch)

    # for each epoch, compute also validation loss
    model.eval()

    # No gradients are needed for validation: torch.no_grad() avoids building
    # the autograd graph for the whole validation set (saves memory and time)
    # without changing the value of the loss.
    with torch.no_grad():
        outputs = model(vis_valid_f32)
        loss = criterion(outputs, target_valid_f32)
    valid_loss[epoch] = loss.item()

    print('Epoch [{}/{}], Train loss: {:.7f}, Valid loss: {:.7f}'
          .format(epoch + 1, num_epochs_train, train_loss[epoch], valid_loss[epoch]))

    # save a checkpoint of the weights every 50 epochs
    if (epoch+1) % 50 == 0:
        torch.save(model.state_dict(), weights_folder + '/nn_complete_' + str(epoch + 1) + '.pt')

    # advance the learning-rate schedule once per epoch
    scheduler.step()

# Store the per-epoch training and validation losses
savez(results_folder + '/history', train_loss = train_loss, valid_loss = valid_loss)