# Library


In [90]:
import os
import matplotlib.pyplot as plt
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.utils import shuffle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from resources.data_utils import peak_to_ring_num, scale_dataset, get_RFP_type 
from resources.plot_utils import plot_R2, plot_profiles_peak_1channel

In [92]:
# Device: run on the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [104]:
# Execute the two companion notebooks so the names they define are available in this kernel.
# NOTE(review): VAE, CombinedModel, CustomDataset, count_parameters, train_combined,
# validate_combined and test_combined are used below but never imported in this notebook;
# presumably they come from these two notebooks — confirm.
%run MLP_VAE_core.ipynb
%run VAE_core.ipynb

# Dataset

In [None]:
# Read in training data
# NOTE(review): the data location is a hard-coded absolute path; adjust `data_dir` for your machine.
# `output_dir` keeps its trailing slash because later cells build paths with `output_dir + name`.
data_dir = "/data/"
output_dir = data_dir + 'model/'
os.makedirs(output_dir, exist_ok=True)

# Load 1D profiles: reshape to (n_samples, 3, 201) and keep the channel at index 1
data_array = np.load(os.path.join(data_dir, 'all_outputs.npy'))
data_array = data_array.reshape([-1, 3, 201])
# Integer-indexing the channel axis already yields a 2-D (n_samples, 201) array.
# No squeeze here: it would also drop the sample axis when there is exactly one
# sample and make the per-row max below fail.
RFP_data = data_array[:, 1, :]

# Normalize RFP: divide every profile by its own maximum.
# Profiles whose maximum is exactly 0 are divided by 1 instead, so they do not
# turn into NaN/inf and contaminate training.
RFP_peaks = RFP_data.max(axis=1, keepdims=True)
RFP_peaks = np.where(RFP_peaks == 0, 1, RFP_peaks)
normalized_RFP = RFP_data / RFP_peaks
normalized_RFP = normalized_RFP.reshape([-1, 1, 201])

# Parameters
params_array = np.load(os.path.join(data_dir, 'all_params.npy'))
# [lower, upper] range for each screened parameter; consumed by scale_dataset in a later cell
scaling_ranges = {
    'DC': [0.5e-3, 12.5e-2],
    'aC': [0.1, 1],
    'aA': [100, 100000],
    'aT': [10, 8000],
    'aL': [5, 500],
    'dA': [0.001, 0.1],
    'dT': [3, 300],
    'dL': [0.144, 14.4],
    'alpha': [1, 5],
    'beta':  [2, 2000],
    'Kphi':  [1, 10],
    'N0':  [200000, 5000000]
}
# One scaling mode per screened parameter
# NOTE(review): presumably in the same order as `sceening_params` — confirm against scale_dataset
scaling_options = ['exp','linear','exp','exp','exp',
                   'exp','exp','exp','linear','exp',
                   'linear','linear']
# Column names of all_params.npy, in column order
all_params = ['DC', 'DN', 'DA', 'DB', 'aC','aA', 'aB', 'aT', 'aL', 'bN','dA', 'dB', 'dT', 'dL', 'k1', 
              'k2', 'KN', 'KP', 'KT', 'KA', 'KB', 'alpha','beta', 'Cmax', 'a', 'b', 'm', 'n', 'Kphi', 'l', 
              'N0', 'G1','G2','G3','G4','G5','G6','G7','G8','G9','G10','G11','G12', 'G13','G14',
             'G15','G16','G17','G18', 'G19', 'alpha_p','beta_p', 'seeding_v']
# Subset of columns used as model inputs
sceening_params = ['DC',  'aC', 'aA', 'aT', 'aL', 'dA','dT', 'dL', 'alpha','beta','Kphi', 'N0']
assert len(scaling_options) == len(sceening_params) == len(scaling_ranges), \
    "scaling_ranges / scaling_options must have one entry per screened parameter"
selected_param_idx = [all_params.index(param) for param in sceening_params]
params_array = params_array[:, selected_param_idx]

# Pattern types (column 1 of the types file is kept)
pattern_types_array = np.load(os.path.join(data_dir, 'all_types.npy'))
RFP_types_array = pattern_types_array[:, 1]

# The three arrays are shuffled and split by position later, so they must describe the same samples
assert len(normalized_RFP) == len(params_array) == len(RFP_types_array), \
    "profiles, parameters and pattern types must have the same number of samples"

# Check size
print('---------------------------------------------')
print(f"RFP profiles: {normalized_RFP.shape}")
print(f"Parameters: {params_array.shape}")
print(f"Pattern types:  {RFP_types_array.shape}")

# Dataloader

In [None]:
# Set up hyperparameters
batch_size = 32
latent_dim = 16
latent_channel = 16
seq_length = normalized_RFP.shape[2]
input_dim = params_array.shape[1]

# Set everything to float32
normalized_RFP = normalized_RFP.astype(np.float32)
params_array = params_array.astype(np.float32)

# Shuffle
# One joint call applies a single permutation to all three arrays (and raises if
# their lengths differ) instead of relying on three separate calls staying in sync.
normalized_RFP, params_array, RFP_types_array = shuffle(
    normalized_RFP, params_array, RFP_types_array, random_state=25
)

# Split datasets
# The data is already shuffled above, so the splits are purely positional
# (shuffle=False; random_state would have no effect and is omitted).
# Splitting all arrays in one call keeps profiles, parameters and types aligned.
ratio = 0.1

# Hold out the last `ratio` fraction as the test set
(train_data, test_data,
 train_params, test_params,
 train_types, test_types,
 train_indices, test_indices) = train_test_split(
    normalized_RFP, params_array, RFP_types_array, range(len(normalized_RFP)),
    test_size=ratio, shuffle=False
)

# Hold out the last `ratio` fraction of the remainder as the validation set.
# train_indices / valid_indices are positions within the pre-split training set.
(train_data, valid_data,
 train_params, valid_params,
 train_types, valid_types,
 train_indices, valid_indices) = train_test_split(
    train_data, train_params, train_types, range(len(train_data)),
    test_size=ratio, shuffle=False
)

# Inputs
norm_train_params = scale_dataset(train_params, scaling_ranges, scaling_options)
norm_valid_params = scale_dataset(valid_params, scaling_ranges, scaling_options)
norm_test_params = scale_dataset(test_params, scaling_ranges, scaling_options)

# Dataset
train_dataset = CustomDataset(norm_train_params, train_data, train_types)
valid_dataset = CustomDataset(norm_valid_params, valid_data, valid_types)
test_dataset = CustomDataset(norm_test_params, test_data, test_types)

# Dataloader
# Only the training loader is shuffled. Validation/test batches keep dataset order so
# evaluation is reproducible and predictions line up with test_params / test_types.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(' -------------------- Train set -------------------- ')
print('Data: ', train_data.shape)
print('Parameters: ', train_params.shape)
print('Pattern type: ', train_types.shape)

print(' -------------------- Validation set -------------------- ')
print('Data: ', valid_data.shape)
print('Parameters: ', valid_params.shape)
print('Pattern type: ', valid_types.shape)

print(' -------------------- Test set -------------------- ')
print('Data: ', test_data.shape)
print('Parameters: ', test_params.shape)
print('Pattern type: ', test_types.shape)

print(' -------------------- Check params ranges -------------------- ')
# check parameter ranges: per-column min/max of the scaled training inputs
upper_lims = np.max(norm_train_params, axis=0)
lower_lims = np.min(norm_train_params, axis=0)
for i in range(len(upper_lims)):
    print(sceening_params[i], ' -- Upper limits -- ',upper_lims[i], 'Lower limits -- ',lower_lims[i])

# Train MLP

## Load trained VAE

In [None]:
# Rebuild the VAE architecture and restore its trained weights
vae = VAE(seq_length, latent_dim, latent_channel)
filename = output_dir + 'VAE.pt'
print(filename)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when `device` fell back to the CPU.
vae.load_state_dict(torch.load(filename, map_location=device))
vae = vae.to(device)
vae.eval()

## Train

In [None]:
# Initiate model
# NOTE(review): the meaning of the literal 42 is defined by CombinedModel (not visible here) — confirm and name it.
model = CombinedModel(input_dim, latent_dim, vae.decoder, 42)

# # Load trained model if exist
# filename = output_dir + 'MLP_VAE.pt'
# model.load_state_dict(torch.load(filename))

# Freeze VAE parameters
for param in model.decoder.parameters():
    param.requires_grad = False
print(model)
model = model.to(device)
print("The model has", count_parameters(model), "trainable parameters")

# Training setup
lr = 1e-3
min_lr = 1e-7
epochs = 1000
gamma = 0.98
weight_decay = 1e-5
alpha = 0 # 2e-5

criterion = nn.MSELoss()
# Only the MLP is optimized; the decoder was frozen above
optimizer = torch.optim.Adam(model.mlp.parameters(), lr= lr, weight_decay=weight_decay)

# Early stopping
best_valid_loss = np.inf
best_state_dict = None   # snapshot of the weights at the best validation loss
epochs_no_improve = 0
patience = 30

# Warm up
warmup_epochs = 10
def warmup_scheduler(epoch):
    """Return the LR multiplier: a linear ramp up to 1.0 over `warmup_epochs`, then 1.0."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    else:
        return 1.0
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_scheduler)
scheduler2 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)


# Training loop
train_loss_history = []
valid_loss_history = []
test_loss_history = []

for epoch in range(epochs):

    # Learning rate in effect for this epoch (read before the schedulers change it)
    current_lr = optimizer.param_groups[0]['lr']

    train_loss = train_combined(model, train_loader, optimizer, criterion, alpha, device)
    valid_loss = validate_combined(model, valid_loader, criterion, alpha, device)
    test_loss = test_combined(model, test_loader, criterion, alpha, device)

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)
    test_loss_history.append(test_loss)

    # Print loss
    if (epoch + 1) % 5 == 0: # every 5 epochs
        print('Epoch: {} Train: {:.7f}, Valid: {:.7f}, Test: {:.7f}, Lr:{:.8f}'.format(epoch + 1, train_loss_history[epoch], valid_loss_history[epoch], test_loss_history[epoch], current_lr))

    # Update learning rate: linear warm-up first, exponential decay afterwards
    if epoch < warmup_epochs:
        scheduler1.step()
    else:
        scheduler2.step()

    # Clamp minimum learning rate
    # Must run AFTER the scheduler step: clamping before it lets the step push the
    # rate below min_lr again, so the next epoch would train under the floor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = max(param_group['lr'], min_lr)

    # Check for early stopping
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Clone the tensors: state_dict() returns references that training keeps mutating
        best_state_dict = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        epochs_no_improve = 0  # Reset the counter
    else:
        epochs_no_improve += 1  # Increment the counter

    if epochs_no_improve >= patience:
        print('Early stopping!')
        break  # Exit the loop

# Restore the weights from the best validation epoch; otherwise the model that gets
# evaluated and saved is the one from `patience` epochs after the best point.
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)

In [None]:
# Loss curves for the three splits, drawn on a log-scaled y axis
plt.figure(figsize=(6, 3))
for loss_curve, curve_label in (
    (train_loss_history, 'Training'),
    (valid_loss_history, 'Validation'),
    (test_loss_history, 'Testing'),
):
    plt.semilogy(loss_curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Save

In [None]:
# Write the model weights (state_dict only) into the output directory
model_path = os.path.join(output_dir, "MLP_VAE.pt")
print(model_path)
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)

# Training results


In [None]:
def _collect_predictions(model, loader, device):
    """Run `model` over every batch of `loader` and gather predictions and targets.

    Parameters
    ----------
    model : callable returning a 3-tuple whose first element is the reconstruction
        (the other two are unpacked as mean and logvar and ignored here).
    loader : iterable yielding (params, data, <ignored>) batches of tensors.
    device : device the parameter batch is moved to before the forward pass.

    Returns
    -------
    (pred, ori) : two NumPy arrays, the per-batch reconstructions and the per-batch
        `data` tensors, each concatenated along the first axis in loader order.

    Raises
    ------
    ValueError : from np.concatenate when the loader yields no batches.
    """
    predictions, targets = [], []
    with torch.no_grad():
        for params, data, _ in loader:
            reconstruction, _mean, _logvar = model(params.to(device))
            predictions.append(reconstruction.cpu().numpy())
            # Targets are only stored, never fed to the model, so they are not moved to `device`
            targets.append(data.cpu().numpy())
    return np.concatenate(predictions), np.concatenate(targets)


model.eval()
train_pred, train_ori = _collect_predictions(model, train_loader, device)
test_pred, test_ori = _collect_predictions(model, test_loader, device)

# NOTE(review): these plots show the MLP+decoder model but are named 'VAE_*_R2.png';
# if the VAE notebook writes the same names into output_dir they will be overwritten — confirm.
filename = output_dir + 'VAE_train_R2.png'
plot_R2(train_ori, train_pred, filename)
filename = output_dir + 'VAE_test_R2.png'
plot_R2(test_ori, test_pred, filename)

# Accuracy by class

In [None]:
def R2_acc_by_type(ori_types, pred_types, ori_data, pred_data, target_class):
    """Score the samples whose original type equals `target_class`.

    Parameters
    ----------
    ori_types, pred_types : sequences of type labels for the original and predicted data.
    ori_data, pred_data : arrays indexable by a list of sample positions and
        providing `.flatten()`.
    target_class : label selecting which samples (by original type) are scored.

    Returns
    -------
    (r2, accuracy, correct_predictions, total_predictions)
        r2 is `r2_score` over the flattened selected samples, or 0 if none are selected.
        accuracy is correct / total among the selected samples, or 0 if none are selected.
    """
    # Classification accuracy restricted to samples of the target class
    correct_predictions = 0
    total_predictions = 0
    for predicted, actual in zip(pred_types, ori_types):
        if actual == target_class:
            total_predictions += 1
            correct_predictions = correct_predictions + (predicted == actual)

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

    # R2 over the same samples, located by their position in ori_types
    selected = [position for position, actual in enumerate(ori_types) if actual == target_class]
    ori_subset = ori_data[selected]
    pred_subset = pred_data[selected]

    if len(ori_subset) == 0:
        # Nothing of this class: report 0 rather than calling r2_score on empty input
        return 0, accuracy, correct_predictions, total_predictions

    r2 = r2_score(ori_subset.flatten(), pred_subset.flatten())
    return r2, accuracy, correct_predictions, total_predictions


In [None]:
def _report_by_ring_count(split_name, ori_data, pred_data, ring_counts=(1, 2, 3, 4)):
    """Print R2 and type accuracy per ring count for one data split.

    Types are derived from the profiles with get_RFP_type (second return value)
    and converted with peak_to_ring_num, once for the originals and once for the
    predictions; each ring count is then scored with R2_acc_by_type.

    Parameters
    ----------
    split_name : label printed in the section header (e.g. 'Train').
    ori_data, pred_data : original and predicted profiles of the split.
    ring_counts : class labels to report, in print order.

    Returns
    -------
    dict mapping each ring count to its (r2, accuracy, correct, total) tuple.
    """
    print(f' ------------------------- {split_name} ------------------------- ')
    _, ori_types = get_RFP_type(ori_data)
    _, pred_types = get_RFP_type(pred_data)
    ori_types = peak_to_ring_num(np.array(ori_types))
    pred_types = peak_to_ring_num(np.array(pred_types))

    # Get R2 and acc for each pattern type
    metrics = {}
    for ring in ring_counts:
        r2, acc, correct, total = R2_acc_by_type(ori_types, pred_types, ori_data, pred_data, ring)
        print(f' {ring} ring --- R2: ', r2, ' , acc: ', acc , ', correct#: ', correct, ', total#: ', total)
        metrics[ring] = (r2, acc, correct, total)
    return metrics


train_metrics = _report_by_ring_count('Train', train_ori, train_pred)
test_metrics = _report_by_ring_count('Test', test_ori, test_pred)

In [None]:
# Plot examples: for the first five test samples, the prediction followed by its original
for sample_idx in range(5):
    predicted_profile = test_pred[sample_idx].squeeze()
    plot_profiles_peak_1channel(predicted_profile)
    original_profile = test_ori[sample_idx].squeeze()
    plot_profiles_peak_1channel(original_profile)