# Library

In [13]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, ExponentialLR
from torch.utils.data import DataLoader
from resources.plot_utils import plot_R2

# Dataset

In [None]:
# Read in training data
data_directory = "./all_data/"
# Keep the trailing '/': later cells build file paths by string concatenation.
output_dir = os.path.join(data_directory, 'model/')

# Load RFP profiles: the flat array is reshaped to (n_samples, 3, 201) and index 1
# of the middle axis is used as the RFP profile.
data_array = np.load(os.path.join(data_directory, 'all_outputs.npy'))
data_array = data_array.reshape([-1, 3, 201])
RFP_data = data_array[:, 1, :]  # (n_samples, 201); no squeeze, so a single sample stays 2D

# Normalize each profile by its own peak. A profile whose peak is 0 is divided by 1
# (left unchanged) instead of producing NaN/inf training targets.
RFP_peak = RFP_data.max(axis=1, keepdims=True)
normalized_RFP = RFP_data / np.where(RFP_peak == 0, 1, RFP_peak)
normalized_RFP = normalized_RFP.reshape([-1, 1, 201])

# Parameters (original, unscaled values)
filename = os.path.join(data_directory, 'all_params.npy')
params_array = np.load(filename)

# Screened parameters with their sampling range and scaling, in column order.
# 'exp' min-max scales log(value); 'linear' is plain min-max scaling.
# Keeping name, range and scaling on one row stops the range dict and the
# option list from drifting out of sync.
param_scaling = [
    # name     min       max        scaling
    ('DC',    0.5e-3,   12.5e-2,   'exp'),
    ('aC',    0.1,      1,         'linear'),
    ('aA',    100,      100000,    'exp'),
    ('aT',    10,       8000,      'exp'),
    ('aL',    5,        500,       'exp'),
    ('dA',    0.001,    0.1,       'exp'),
    ('dT',    3,        300,       'exp'),
    ('dL',    0.144,    14.4,      'exp'),
    ('alpha', 1,        5,         'linear'),
    ('beta',  2,        2000,      'exp'),
    ('Kphi',  1,        10,        'linear'),
    ('N0',    200000,   5000000,   'linear'),
]
scaling_ranges = {name: [low, high] for name, low, high, _ in param_scaling}
scaling_options = [option for _name, _low, _high, option in param_scaling]

# Column names of all_params.npy, in file order
all_params = ['DC', 'DN', 'DA', 'DB', 'aC', 'aA', 'aB', 'aT', 'aL', 'bN', 'dA', 'dB', 'dT', 'dL', 'k1',
              'k2', 'KN', 'KP', 'KT', 'KA', 'KB', 'alpha', 'beta', 'Cmax', 'a', 'b', 'm', 'n', 'Kphi', 'l',
              'N0', 'G1', 'G2', 'G3', 'G4', 'G5', 'G6', 'G7', 'G8', 'G9', 'G10', 'G11', 'G12', 'G13', 'G14',
              'G15', 'G16', 'G17', 'G18', 'G19', 'alpha_p', 'beta_p', 'seeding_v']
# Selected columns follow the scaling table, so column i is scaled with row i.
screening_params = list(scaling_ranges)
sceening_params = screening_params  # backward-compatible alias for the original (misspelled) name
selected_param_idx = [all_params.index(param) for param in screening_params]
params_array = params_array[:, selected_param_idx]

# Pattern types
pattern_types_array = np.load(os.path.join(data_directory, 'all_types.npy'))
pattern_types_array = pattern_types_array[:, 1]

assert normalized_RFP.shape[0] == params_array.shape[0] == pattern_types_array.shape[0], \
    'profiles, parameters and pattern types must have one row per simulation'

print('---------------------------------------------')
print(f"RFP profiles: {normalized_RFP.shape}")
print(f"Parameters: {params_array.shape}")
print(f"Pattern types:  {pattern_types_array.shape}")

# Train 

In [None]:
# Run on the GPU when torch can see one, otherwise on the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda') if has_gpu else torch.device('cpu')
print('Device:', device)

# Hyperparameters shared by the training cells
batch_size = 16
seq_length = data_array.shape[-1]  # points per 1D profile (last axis of the 3D data_array)


In [4]:
# Normalize params: sampling range and scaling type of each screened parameter.
# scale_dataset pairs the i-th entry here with column i of the parameter array.
_param_table = (
    # name     min       max        scaling
    ('DC',    0.5e-3,   12.5e-2,   'exp'),
    ('aC',    0.1,      1,         'linear'),
    ('aA',    100,      100000,    'exp'),
    ('aT',    10,       8000,      'exp'),
    ('aL',    5,        500,       'exp'),
    ('dA',    0.001,    0.1,       'exp'),
    ('dT',    3,        300,       'exp'),
    ('dL',    0.144,    14.4,      'exp'),
    ('alpha', 1,        5,         'linear'),
    ('beta',  2,        2000,      'exp'),
    ('Kphi',  1,        10,        'linear'),
    ('N0',    200000,   5000000,   'linear'),
)
scaling_ranges = {name: [lo, hi] for name, lo, hi, _kind in _param_table}
scaling_options = [kind for _name, _lo, _hi, kind in _param_table]

def scale_feature(value, min_val, max_val, option):
    """Map ``value`` from the range ``[min_val, max_val]`` onto ``[0, 1]``.

    Works element-wise, so ``value`` may be a scalar or a NumPy array.

    Args:
        value: Raw parameter value(s).
        min_val: Lower bound of the parameter range.
        max_val: Upper bound of the parameter range.
        option: ``"linear"`` for plain min-max scaling, ``"exp"`` for min-max
            scaling of ``log(value)``.

    Returns:
        The scaled value(s); inputs inside the range land in ``[0, 1]``.

    Raises:
        ValueError: If ``option`` is neither ``"linear"`` nor ``"exp"``.
    """
    if option == "linear":
        return (value - min_val) / (max_val - min_val)
    if option == "exp":
        log_min = np.log(min_val)
        return (np.log(value) - log_min) / (np.log(max_val) - log_min)
    raise ValueError(f"Unknown scaling option: {option}")


def scale_dataset(dataset, ranges, opt):
    """Scale every column of a 2D parameter array with :func:`scale_feature`.

    Column ``i`` is scaled with the ``i``-th ``[min, max]`` pair of ``ranges``
    (dict insertion order) and the scaling option ``opt[i]``.

    Args:
        dataset: Array-like of shape ``(n_samples, n_params)`` with raw values.
        ranges: Dict ``name -> [min, max]`` with one entry per column.
        opt: Sequence of scaling options, one per column.

    Returns:
        A new floating-point array of the same shape. Float inputs keep their
        dtype; integer inputs are promoted to float64 (previously they were
        scaled into an integer array and silently truncated).

    Raises:
        ValueError: If ``dataset`` is not 2D, or if the number of columns,
            ranges and options disagree, or if an option is unknown.
    """
    data = np.asarray(dataset)
    if data.ndim != 2:
        raise ValueError(f"dataset must be 2D, got shape {data.shape}")
    if data.shape[1] != len(ranges) or len(opt) != len(ranges):
        raise ValueError(
            f"dataset has {data.shape[1]} columns, but {len(ranges)} ranges "
            f"and {len(opt)} scaling options were given"
        )

    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    scaled_dataset = np.empty(data.shape, dtype=out_dtype)
    for col, ((min_val, max_val), option) in enumerate(zip(ranges.values(), opt)):
        # scale_feature is element-wise, so each column is scaled in one call.
        scaled_dataset[:, col] = scale_feature(data[:, col], min_val, max_val, option)
    return scaled_dataset

In [5]:
%run MLP_VAE_core.ipynb
%run VAE_core.ipynb

In [None]:
# Sweep the VAE latent size. For every latent_dim: train a VAE on the normalized RFP
# profiles, then train an MLP that feeds the scaled parameters through the VAE decoder,
# and record R^2 for both stages. The train/valid/test split is fixed and shared.
split_ratio_list = [0.1, 0.25, 0.5, 0.9, 0.95]  # NOTE(review): not used below
latent_dim_list = [8, 16, 32, 64]
num_models = len(latent_dim_list)
latent_channel = 16
torch_seed = 0  # re-seeded per model so weight init and DataLoader shuffling are reproducible

# Optimizer / schedule settings (identical for every latent size)
vae_lr, vae_weight_decay = 3e-5, 1e-5
vae_fit_kwargs = dict(alpha=2e-5, epochs=1000, warmup_epochs=10, gamma=0.99,
                      min_lr=5e-6, patience=30)
mlp_lr, mlp_weight_decay = 1e-4, 1e-5
# alpha is forwarded to train_combined & co.; the original notes it is not used there.
mlp_fit_kwargs = dict(alpha=2e-5, epochs=1000, warmup_epochs=8, gamma=0.995,
                      min_lr=1e-7, patience=30, target_loss=1e-4)

model_list = []  # NOTE(review): never populated; kept for compatibility
vae_r2_traing_list = []
vae_r2_test_list = []
mlp_r2_traing_list = []
mlp_r2_test_list = []

os.makedirs(output_dir, exist_ok=True)


def _warmup_factor(warmup_epochs):
    """Return a LambdaLR multiplier that ramps linearly to 1 over ``warmup_epochs`` epochs."""
    def factor(epoch):
        return (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 1.0
    return factor


def _fit_with_early_stopping(net, step_fns, loaders, optimizer, criterion, device, *,
                             alpha, epochs, warmup_epochs, gamma, min_lr, patience,
                             target_loss=None):
    """Train ``net`` with linear warm-up, exponential LR decay and early stopping.

    ``step_fns`` is ``(train_fn, valid_fn, test_fn)`` and ``loaders`` is
    ``(train_loader, valid_loader, test_loader)``. ``train_fn`` is called as
    ``train_fn(net, loader, optimizer, criterion, alpha, device)``, the other two
    as ``fn(net, loader, criterion, alpha, device)``. Exactly one scheduler steps
    per epoch. The weights with the lowest validation loss are restored before
    returning. Training stops early after ``patience`` epochs without improvement,
    or as soon as the validation loss drops below ``target_loss`` (if given).

    Returns:
        dict with the per-epoch 'train', 'valid' and 'test' losses.
    """
    train_fn, valid_fn, test_fn = step_fns
    train_loader, valid_loader, test_loader = loaders
    warmup = LambdaLR(optimizer, lr_lambda=_warmup_factor(warmup_epochs))
    decay = ExponentialLR(optimizer, gamma=gamma)

    history = {'train': [], 'valid': [], 'test': []}
    best_valid_loss = np.inf
    best_state = None
    epochs_no_improve = 0

    for epoch in range(epochs):
        train_loss = train_fn(net, train_loader, optimizer, criterion, alpha, device)
        valid_loss = valid_fn(net, valid_loader, criterion, alpha, device)
        test_loss = test_fn(net, test_loader, criterion, alpha, device)  # monitoring only
        history['train'].append(train_loss)
        history['valid'].append(valid_loss)
        history['test'].append(test_loss)

        # Clamp the learning rate from below
        for group in optimizer.param_groups:
            group['lr'] = max(group['lr'], min_lr)

        if epoch % 5 == 0:
            print('Epoch: {} Train: {:.7f}, Valid: {:.7f}, Test: {:.7f}, Lr:{:.8f}'.format(
                epoch + 1, train_loss, valid_loss, test_loss, optimizer.param_groups[-1]['lr']))

        # Warm-up first, then exponential decay -- never both in the same epoch.
        (warmup if epoch < warmup_epochs else decay).step()

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_state = {k: v.detach().clone() for k, v in net.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print('Early stopping!')
            break
        if target_loss is not None and valid_loss < target_loss:
            break

    if best_state is not None:
        net.load_state_dict(best_state)
    return history


def _vae_r2(vae, data, device, max_samples=1000):
    """R^2 of the VAE reconstruction of the first ``max_samples`` profiles in ``data``."""
    was_training = vae.training
    vae.eval()
    batch = data[:max_samples].to(device=device, dtype=torch.float32)
    with torch.no_grad():
        reconstruction, _, _ = vae(batch)
    vae.train(was_training)  # restore the mode; the decoder is reused by the MLP below
    return r2_score(batch.cpu().numpy().ravel(), reconstruction.cpu().numpy().ravel())


def _collect_predictions(net, loader, device):
    """Run ``net`` on every (params, profile, label) batch; return (predictions, targets) arrays."""
    net.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for params, data, _ in loader:
            reconstruction, _, _ = net(params.to(device))
            predictions.append(reconstruction.cpu().numpy())
            targets.append(data.cpu().numpy())
    return np.concatenate(predictions), np.concatenate(targets)


def _take(tensor, indices):
    """Select rows of a torch tensor with a NumPy index array."""
    return tensor[torch.as_tensor(indices, dtype=torch.long)]


def _as_f32(values):
    """Convert an array or tensor to a float32 tensor, copying only when needed."""
    return torch.as_tensor(values, dtype=torch.float32)


# ---- Fixed split: 10% test, then 10% of the remainder for validation ----
ratio = 0.1
n_samples = normalized_RFP.shape[0]
assert params_array.shape[0] == n_samples and pattern_types_array.shape[0] == n_samples, \
    'profiles, parameters and pattern types must have one row per simulation'

# Split index arrays (not the data) so profiles, parameters and labels stay aligned.
train_indices_all, test_indices = train_test_split(
    np.arange(n_samples), test_size=ratio, random_state=25)
train_indices, valid_indices = train_test_split(
    np.arange(len(train_indices_all)), test_size=ratio, random_state=15)

profiles = torch.as_tensor(normalized_RFP, dtype=torch.float32).reshape(-1, 1, seq_length)
train_data_all = _take(profiles, train_indices_all)
test_data = _take(profiles, test_indices)
train_data = _take(train_data_all, train_indices)
valid_data = _take(train_data_all, valid_indices)

train_params_all = params_array[train_indices_all]
train_labels_all = pattern_types_array[train_indices_all]
train_labels = train_labels_all[train_indices]
valid_labels = train_labels_all[valid_indices]
test_labels = pattern_types_array[test_indices]

# Scale each parameter split exactly once from raw values; re-applying scale_dataset
# to an already scaled array would rescale it again.
train_params = scale_dataset(train_params_all[train_indices], scaling_ranges, scaling_options)
valid_params = scale_dataset(train_params_all[valid_indices], scaling_ranges, scaling_options)
test_params = scale_dataset(params_array[test_indices], scaling_ranges, scaling_options)

print('Test set size: ', len(test_data))
print('Complete train set size: ', len(train_data_all))
print('Train set size: ', len(train_data))
print('Valid set size: ', len(valid_data))

# VAE loaders iterate over raw profiles
vae_loaders = tuple(DataLoader(split, batch_size=batch_size, shuffle=True)
                    for split in (train_data, valid_data, test_data))

# MLP loaders iterate over (params, profile, label) triples
train_dataset = CustomDataset_with_type(_as_f32(train_params), _as_f32(train_data), _as_f32(train_labels))
valid_dataset = CustomDataset_with_type(_as_f32(valid_params), _as_f32(valid_data), _as_f32(valid_labels))
test_dataset = CustomDataset_with_type(_as_f32(test_params), _as_f32(test_data), _as_f32(test_labels))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
mlp_loaders = (train_loader, valid_loader, test_loader)
print('Traning data points:', len(train_dataset))
print('Valid data points:', len(valid_dataset))
print('Test data points:', len(test_dataset))


for latent_dim in latent_dim_list:
    print('----------------------  latent dim: ', latent_dim, '----------------------')
    torch.manual_seed(torch_seed)

    print('------------------- VAE training ---------------------------')
    vae = VAE(seq_length, latent_dim, latent_channel).to(device)
    vae_optimizer = optim.Adam(vae.parameters(), lr=vae_lr, weight_decay=vae_weight_decay)
    vae_history = _fit_with_early_stopping(
        vae, (train_VAE, validate_VAE, test_VAE), vae_loaders,
        vae_optimizer, nn.MSELoss(), device, **vae_fit_kwargs)

    vae_path = os.path.join(output_dir, f'screening_VAE_latent_{latent_dim}.pt')
    print('model path: ', vae_path)
    torch.save(vae.state_dict(), vae_path)

    print('------------------- Evaluate VAE performance ---------------------------')
    r2_train = _vae_r2(vae, train_data, device)
    r2_test = _vae_r2(vae, test_data, device)
    print('******** VAE ********')
    print('R2 train: ', r2_train)
    print('R2 test: ', r2_test)
    vae_r2_traing_list.append(r2_train)
    vae_r2_test_list.append(r2_test)

    print('------------------- MLP training ---------------------------')
    input_dim = train_params.shape[1]
    model = CombinedModel(input_dim, latent_dim, vae.decoder, 42).to(device)
    # Only the MLP head is handed to the optimizer; decoder weights are not updated.
    mlp_optimizer = Adam(model.mlp.parameters(), lr=mlp_lr, weight_decay=mlp_weight_decay)
    mlp_history = _fit_with_early_stopping(
        model, (train_combined, validate_combined, test_combined), mlp_loaders,
        mlp_optimizer, MSELoss(), device, **mlp_fit_kwargs)

    mlp_path = os.path.join(output_dir, f'screening_MLP_latent_{latent_dim}.pt')
    print('model path: ', mlp_path)
    torch.save(model.state_dict(), mlp_path)

    print('------------------- Evaluate MLP - VAE performance ---------------------------')
    train_pred, train_ori = _collect_predictions(model, train_loader, device)
    test_pred, test_ori = _collect_predictions(model, test_loader, device)

    # One figure per latent size so later iterations do not overwrite earlier ones
    plot_R2(train_ori, train_pred, os.path.join(output_dir, f'MLP_VAE_train_R2_latent_{latent_dim}.png'))
    filename = os.path.join(output_dir, f'MLP_VAE_test_R2_latent_{latent_dim}.png')
    plot_R2(test_ori, test_pred, filename)
    print(filename)

    r2_train = r2_score(train_ori.ravel(), train_pred.ravel())
    r2_test = r2_score(test_ori.ravel(), test_pred.ravel())
    print('******** MLP ********')
    print('R2 train: ', r2_train)
    print('R2 test: ', r2_test)
    mlp_r2_traing_list.append(r2_train)
    mlp_r2_test_list.append(r2_test)
    
  

In [None]:
# R^2 versus VAE latent size, for the VAE alone and for the MLP -> decoder pipeline.
fig, axs = plt.subplots(2, 1, figsize=(5, 10), sharex=True)

panels = [
    ('VAE ', vae_r2_traing_list, vae_r2_test_list),
    ('MLP - VAE', mlp_r2_traing_list, mlp_r2_test_list),
]
for ax, (title, r2_train_values, r2_test_values) in zip(axs, panels):
    # Slice the x values so a sweep that was interrupted part-way still plots.
    ax.plot(latent_dim_list[:len(r2_train_values)], r2_train_values, marker='o', label='train')
    ax.plot(latent_dim_list[:len(r2_test_values)], r2_test_values, marker='o', label='test')
    ax.set_title(title)
    ax.set_ylabel('$R^2$')
    ax.set_xticks(latent_dim_list)
    ax.legend()
axs[-1].set_xlabel('Latent dimension')

plt.tight_layout()
plt.show()