# Library

In [7]:
import os
import numpy as np
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.model_selection import train_test_split
from resources.plot_utils import plot_R2

In [2]:
# Device: run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Execute the companion notebook inside this notebook's namespace.
# NOTE(review): the cells below use VAE, count_parameters, train_VAE,
# validate_VAE and test_VAE without importing them, so they are presumably
# defined by VAE_core.ipynb -- confirm against that notebook.
%run VAE_core.ipynb

# Dataset

In [None]:
# Read in training data
data_dir = "/data/"
# NOTE: keep the trailing slash -- other cells build paths as `output_dir + name`.
output_dir = data_dir + 'model/'
os.makedirs(output_dir, exist_ok=True)

# RFP profiles
filename = os.path.join(data_dir, 'all_outputs.npy')
data_array = np.load(filename)
# View the raw array as (n_samples, 3, 201).
data_array = data_array.reshape([-1, 3, 201])
# Index 1 along the second axis is taken as the RFP profile. Integer indexing
# already removes that axis, so the result is (n_samples, 201). No squeeze():
# it would collapse a single-sample dataset to 1-D and break the axis=1
# reduction below.
RFP_data = data_array[:, 1, :]
print(f"RFP profiles: {RFP_data.shape}")

# Normalize each profile by its own maximum.
profile_max = RFP_data.max(axis=1, keepdims=True)
# A profile whose maximum is 0 would yield NaN/inf and poison training;
# divide such rows by 1 instead so they are left unchanged.
safe_profile_max = np.where(profile_max == 0, 1, profile_max)
normalized_RFP = RFP_data / safe_profile_max
print(f"Normalized RFP profiles: {normalized_RFP.shape}")

# Plot -- a 10x10 grid (100 panels); every panel shows one randomly drawn
# normalized profile with the axes hidden.
num_panels = 100
rows, cols = 10, 10
fig, axs = plt.subplots(rows, cols, figsize=(20, 20))
fig.suptitle('Random RFP Data', fontsize=16)
# axs.flat walks the grid row by row, i.e. in the same order as a nested
# row/column loop.
for ax in axs.flat:
    random_index = np.random.randint(normalized_RFP.shape[0])
    ax.plot(normalized_RFP[random_index])
    ax.set_ylim([0, 1])
    ax.axis('off')
plt.tight_layout()
# Leave headroom for the figure title.
plt.subplots_adjust(top=0.95)
plt.show()

# Train

In [None]:
# Parameters

# --- Data ---
data = normalized_RFP        # the VAE is trained on the normalized RFP profiles
seq_length = data.shape[1]   # length of the second axis of `data`

# --- Model / batching ---
batch_size = 32              # DataLoader batch size
latent_dim, latent_channel = 16, 16   # forwarded to the VAE constructor

# --- Optimization ---
epochs = 1000                # upper bound on training epochs
lr, min_lr = 1e-3, 5e-6      # Adam learning rate and the floor applied in the loop
gamma = 0.99                 # decay factor of the exponential LR scheduler
weight_decay = 1e-5          # Adam weight decay
# Forwarded to train_VAE / validate_VAE / test_VAE.
# NOTE(review): presumably the weight of the KL term in the loss -- confirm.
alpha = 2e-5

# Split data
# Convert once to a float32 tensor and insert a channel axis at dim 1.
# (Wrapping an existing tensor in torch.tensor() again only makes an extra copy
# and triggers a UserWarning.)
data = torch.tensor(data, dtype=torch.float32).unsqueeze(1)

# Hold out the trailing 10% as the test set, unshuffled. random_state is
# ignored by scikit-learn when shuffle=False; it is kept only for symmetry.
train_data, test_data, train_indices, test_indices = train_test_split(
    data, range(len(data)), test_size=0.1, random_state=25, shuffle=False)
# Randomly take 10% of the remainder for validation. The indices from the first
# split are passed through so train_indices / valid_indices keep referring to
# rows of the original `data`.
train_data, valid_data, train_indices, valid_indices = train_test_split(
    train_data, train_indices, test_size=0.1, random_state=25, shuffle=True)

print(' --------------------------------------------------- ')
print('Train data size: ', train_data.shape)
print('Validation data size: ', valid_data.shape)
print('Test data size: ', test_data.shape)

# Prepare DataLoader
# Only the training loader is shuffled; the evaluation loaders keep a fixed
# order so the batches seen by validation/testing are the same every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Model creation, loss function, and optimizer
model = VAE(seq_length, latent_dim, latent_channel)
print(model)
# Optional: resume from previously saved weights.
# model.load_state_dict(torch.load('VAE.pt'))

# Move the model to the selected device and report its size.
model = model.to(device)
print(f'The model has {count_parameters(model):,} parameters')

# Training setup: mean-squared-error criterion and Adam with weight decay.
criterion = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Early stopping state
best_test_loss = np.inf   # lowest test loss seen so far
epochs_no_improve = 0     # consecutive epochs without improvement
patience = 30             # stop once epochs_no_improve reaches this value

# Warm up
warmup_epochs = 10


def warmup_scheduler(epoch):
    """Return the learning-rate multiplier for ``epoch``.

    While ``epoch`` is below ``warmup_epochs`` the multiplier is
    ``(epoch + 1) / warmup_epochs``; from then on it is ``1.0``.
    """
    return (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 1.0

# Two schedulers share the same optimizer; the training loop steps exactly one
# of them per epoch (scheduler1 during warm-up, scheduler2 afterwards).
# scheduler1 scales the base lr by warmup_scheduler(epoch).
scheduler1 = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_scheduler)
# scheduler2 multiplies the current lr by gamma on every step.
scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# Training loop
# Per-epoch loss curves, one list per split.
train_loss_history = []
valid_loss_history = []
test_loss_history = []

# Early stopping is driven by the validation loss so that the test set is never
# used for a training decision and stays an unbiased estimate.
best_valid_loss = np.inf

for epoch in range(epochs):
    # Learning rate in effect for this epoch (read before the schedulers move it).
    current_lr = optimizer.param_groups[0]['lr']

    train_loss = train_VAE(model, train_loader, optimizer, criterion, alpha, device)
    valid_loss = validate_VAE(model, valid_loader, criterion, alpha, device)
    test_loss = test_VAE(model, test_loader, criterion, alpha, device)

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)
    test_loss_history.append(test_loss)

    # Print loss
    if (epoch + 1) % 5 == 0:  # every 5 epochs
        print('Epoch: {} Train: {:.7f}, Valid: {:.7f}, Test: {:.7f}, Lr:{:.8f}'.format(
            epoch + 1, train_loss, valid_loss, test_loss, current_lr))

    # Update learning rate: linear warm-up first, exponential decay afterwards.
    if epoch < warmup_epochs:
        scheduler1.step()
    else:
        scheduler2.step()

    # Clamp minimum learning rate. This must run AFTER the scheduler step;
    # clamping before it lets the next epoch train with min_lr * gamma, i.e.
    # below the intended floor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = max(param_group['lr'], min_lr)

    # Lowest test loss seen so far -- tracked for reporting only.
    if test_loss < best_test_loss:
        best_test_loss = test_loss

    # Check for early stopping (on the validation loss)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        epochs_no_improve = 0  # Reset the counter
    else:
        epochs_no_improve += 1  # Increment the counter

    if epochs_no_improve >= patience:
        print('Early stopping!')
        break  # Exit the loop


In [None]:
# Plot the per-epoch loss of every split on a logarithmic y-axis.
plt.figure(figsize=(6, 3))
for history, label in (
    (train_loss_history, 'Training'),
    (valid_loss_history, 'Validation'),
    (test_loss_history, 'Testing'),
):
    plt.semilogy(history, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Plot R2
def _as_numpy(batch):
    """Drop dim 1 of ``batch`` and return it as a NumPy array on the CPU."""
    return batch.squeeze(1).cpu().numpy()


model.eval()

with torch.no_grad():
    # Compare on equally sized sets: take as many training samples as there
    # are test samples.
    train_data_short = train_data[:len(test_data)].to(device)
    test_data_short = test_data.to(device)

    # The model returns three values; only the first one is used here.
    train_pred, _, _ = model(train_data_short)
    test_pred, _, _ = model(test_data_short)

# Convert inputs and predictions to NumPy for plotting.
train_data_short = _as_numpy(train_data_short)
test_data_short = _as_numpy(test_data_short)
train_pred = _as_numpy(train_pred)
test_pred = _as_numpy(test_pred)

# One R2 figure per split, written into the output directory.
for targets, predictions, png_name in (
    (train_data_short, train_pred, 'VAE_train_R2.png'),
    (test_data_short, test_pred, 'VAE_test_R2.png'),
):
    filename = output_dir + png_name
    plot_R2(targets, predictions, filename)

In [None]:
# Plot examples: the first five originals against their reconstructions,
# training samples on the top row and test samples on the bottom row.
fig, axs = plt.subplots(2, 5, figsize=(10, 3))

panel_rows = (
    ('Train', train_data_short, train_pred),
    ('Test', test_data_short, test_pred),
)
for row, (split_label, originals, reconstructions) in enumerate(panel_rows):
    for i in range(5):
        ax = axs[row, i]
        ax.plot(originals[i].squeeze(), label='Original', color='blue')
        ax.plot(reconstructions[i], label='Reconstructed', color='orange')
        ax.set_title(f'{split_label} {i + 1}')
        # ax.legend()  # uncomment to show a legend in every panel

plt.tight_layout()
plt.show()

# Save

In [None]:
# Save model: write the model's state_dict into the output directory.
filename = f'{output_dir}VAE.pt'
print(filename)
torch.save(model.state_dict(), filename)