In [2]:
from dir import *
from VAE_model import *
from VAE_model_2 import *
from VAE_model_single import *
from VAE_MoG_model import *
from training import *
from extras import *
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
from scipy.stats import shapiro
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import itertools
plt.style.use('ggplot')

# 1) Data exploration 

## 1.1) Overall exploration

In [None]:
data = pd.read_csv(PANGENOME_MATRIX_CSV, index_col=[0], header=[0])

In [None]:
phylogroup_data = pd.read_csv(PHYLOGROUPS_DATA, index_col=[0], header=[0])

In [None]:
phylogroup_data

In [None]:
data

In [None]:
data.dtypes

In [None]:
data.transpose()

In [None]:
# Rows of the transposed matrix whose entries are all 0.
# The original transposed the frame three times and re-selected every column
# (a no-op); transposing once and masking the rows gives the same result.
data.transpose().loc[lambda frame: frame.eq(0).all(axis=1)]

In [None]:
data.columns

In [None]:
percent_GF_present = data.astype(bool).sum(axis=0) / len(data.index) * 100

In [None]:
percent_GF_present

In [None]:
# plt.figure(figsize=(10, 8))
# percent_GF_present.iloc[:100].plot(kind='bar', color='dodgerblue')
# plt.xlabel('Genomes')
# plt.ylabel('Percentage of GFs present in the genome')
# plt.show()

In [None]:
frequency1 = data.sum(axis=1)

In [None]:
# Histogram of the per-row totals of the pangenome matrix (frequency1),
# drawn through the explicit Figure/Axes interface and saved as a PDF.
fig, ax = plt.subplots(figsize=(10, 8))
ax.hist(frequency1, color='dodgerblue')
ax.set_xlabel('Gene count')
ax.set_ylabel('Frequency')
fig.savefig("figures/gene_count.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
frequency2 = data.sum(0)

In [None]:
frequency2

In [None]:
# Histogram (20 bins) of the per-column totals of the pangenome matrix (frequency2).
plt.figure(figsize=(10,8))
plt.hist(frequency2, bins=20, color='dodgerblue')
plt.xlabel('Genome size')
# Label typo fixed: previously read 'Gene Gamily Frequency'.
plt.ylabel('Gene Family Frequency')
plt.savefig("figures/genome_size.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# For each threshold, count how many rows of `data` have a row total
# greater than or equal to that threshold.
threshold_data = []
thresholds = np.linspace(0, 20, num=10)

# The row totals do not depend on the threshold, so compute them once
# rather than on every loop iteration.
row_sums = data.sum(axis=1)

for i in thresholds:
    # Summing the boolean mask gives the same count as len(data[mask])
    # without building a filtered copy of the whole frame each time.
    threshold_data.append(int((row_sums >= i).sum()))

In [None]:
threshold_data

In [None]:
thresholds

In [None]:
# Number of rows meeting each row-total threshold, drawn as markers plus a line.
plt.figure(figsize=(10,8))
plt.scatter(thresholds, threshold_data, color='dodgerblue')
plt.plot(thresholds, threshold_data, color='dodgerblue')
# Label typo fixed: previously read 'Gene Number Thershold'.
plt.xlabel('Gene Number Threshold')
plt.ylabel('Gene Frequency')
plt.savefig("figures/gene_frequency.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
data.transpose()

## 1.2) PCA

In [None]:
merged_df = pd.merge(data.transpose(), phylogroup_data, how='inner', left_index=True, right_on='AccessionID')

In [None]:
merged_df

In [None]:
# Apply PCA
pca = PCA(n_components=2)
data_pca = pca.fit_transform(merged_df.iloc[:, :-1])
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2'])

In [None]:
# Visualize the first two principal components
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', hue = merged_df.Phylogroup.tolist(), data=df_pca)
plt.savefig("figures/PCA_graph.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
df_pca

In [None]:
shapiro_test_pc1 = shapiro(df_pca['PC1'])
shapiro_test_pc2 = shapiro(df_pca['PC2'])
print(f"Shapiro-Wilk Test for PC1: {shapiro_test_pc1}")
print(f"Shapiro-Wilk Test for PC2: {shapiro_test_pc2}")

# 2) Data preprocessing  

In [None]:
merged_df

In [None]:
# Keep only the numeric columns whose column total is at least 20, then
# re-attach the phylogroup label as the last column.
numeric_cols = merged_df.select_dtypes(include='number')
column_sums = numeric_cols.sum(axis=0)

filtered_columns = column_sums[column_sums >= 20].index

# Single assignment (the original assigned filtered_data twice, the first
# being overwritten immediately). The explicit copy means adding the
# 'Phylogroup' column below writes into an independent frame.
filtered_data = merged_df[filtered_columns].copy()
filtered_data['Phylogroup'] = merged_df['Phylogroup']

In [None]:
filtered_data

In [None]:
data_array_t = np.array(filtered_data.iloc[:, :-1])
phylogroups_array = np.array(filtered_data.iloc[:, -1])

In [None]:
data_array_t

In [None]:
phylogroups_array

In [None]:
data_array_t.shape

In [None]:
data_array_t.shape[1]

In [None]:
# Converting to PyTorch tensor
data_tensor = torch.tensor(data_array_t, dtype=torch.float32)

# Splitting into train / validation / test sets.
# First split: 70% train, 30% held out. Second split: the held-out part is
# divided with test_size=0.3333, i.e. roughly two thirds validation and
# one third test. Both splits use the same fixed random_state.
train_data, temp_data, train_labels, temp_labels = train_test_split(data_tensor, phylogroups_array, test_size=0.3, random_state=12345)
val_data, test_data, val_labels, test_labels = train_test_split(temp_data, temp_labels, test_size=0.3333, random_state=12345)

# Phylogroup labels aligned with test_data, used later to colour latent-space plots.
test_phylogroups = test_labels

# TensorDataset
train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)
test_dataset = TensorDataset(test_data)

# DataLoaders for main training
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Seeded generator for the two debugging datasets below, so that they are
# reproducible across kernel restarts like the seeded split above
# (the original used the unseeded global np.random state).
debug_rng = np.random.default_rng(12345)

# Dataloader for overfitting on one sample (for debugging purposes):
# a single random 0/1 vector with as many entries as there are input features.
input_dim = data_array_t.shape[1]
binary_data = torch.tensor(debug_rng.integers(0, 2, size=(1, input_dim)), dtype=torch.float32)
single_sample_dataset = TensorDataset(binary_data)
single_sample_loader = DataLoader(single_sample_dataset, batch_size=1, shuffle=True)

# Dataloader for a small subset for overfitting (again, for debugging).
# Sampling without replacement raises if more items are requested than exist,
# so cap the subset size at the size of the training set.
small_subset_size = min(256, len(train_dataset))
small_subset_indices = debug_rng.choice(len(train_dataset), size=small_subset_size, replace=False)
small_subset = Subset(train_dataset, small_subset_indices)
small_loader = DataLoader(small_subset, batch_size=batch_size, shuffle=True)

In [None]:
len(test_data)

In [None]:
len(test_phylogroups)

In [None]:
len(train_dataset)

In [None]:
len(train_loader)

In [None]:
len(train_dataset)

In [None]:
len(val_dataset)

In [None]:
train_data

In [None]:
TensorDataset(train_data)

In [None]:
TensorDataset(torch.tensor(train_data))

In [None]:
print(data_tensor)

# 3) Overfitting on a single sample and small data subset

## 3.1) Overfitting on a single sample

In [None]:
single_sample = torch.randn(1, data_array_t.shape[1])

In [None]:
single_sample.shape[0]

In [None]:
# NO GRADIENT CLIPPING AND SCHEDULER
# Overfit VAE_single on the one-sample loader to check that the model and
# loss are wired correctly. Two losses are tracked per step:
#   loss  = reconstruction + KL          (this is the one that is optimised)
#   loss2 = reconstruction + beta * KL   (recorded for comparison only)
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
n_epochs = 10  # length of the linear beta ramp, in epochs
input_dim = data_array_t.shape[1]

model = VAE_single(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Overfitting one sample to see if the model is broken
model.train()
num_epochs = 1000

# Collecting data for visualisation
train_loss_vals1 = []
train_loss_vals2 = []
kl_divergences_no_beta = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    # Linear ramp from beta_start, clamped at beta_end. Without the clamp the
    # ramp keeps growing after n_epochs (num_epochs is 100x larger here), so
    # beta would end far above beta_end.
    beta = min(beta_end, beta_start + (beta_end - beta_start) * epoch / n_epochs)
    epoch_kl_divergence = 0
    epoch_kl_divergence_beta = 0

    # The loop variable is deliberately not called `data`: that name holds the
    # pangenome DataFrame loaded at the top of the notebook.
    for batch in single_sample_loader:
        batch_x = batch[0].to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL term between N(mu, exp(logvar)) and N(0, I), summed over all elements.
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence += kl_divergence_loss.item()
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()

        # Total loss
        loss = reconstruction_loss + kl_divergence_loss
        loss2 = reconstruction_loss + kl_divergence_loss_beta

        # Backpropagation (only the un-weighted loss drives the update)
        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method1) = {loss.item()}\nLoss (method2) = {loss2.item()}")

    # The recorded losses are those of the last batch of the epoch.
    train_loss_vals1.append(loss.item())
    train_loss_vals2.append(loss2.item())

    kl_divergences_no_beta.append(epoch_kl_divergence / len(single_sample_loader.dataset))
    kl_divergences_beta.append(epoch_kl_divergence_beta / len(single_sample_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
torch.save(model.state_dict(), "models/saved_single_sample_VAE_1000.pt")
print("Model saved.")

In [None]:
epochs = np.linspace(1, 1000, num=1000)

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1, color='dodgerblue')
plt.plot(epochs, train_loss_vals1, label='train loss (no KL annelaing)', color='dodgerblue')
plt.scatter(epochs, train_loss_vals2, color='darkorange')
plt.plot(epochs, train_loss_vals2, label='train loss using KL annelaing', color='darkorange')
plt.xlim(0, 100)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("figures/train_loss_comparisons_no_GS_1000_ss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# KL divergence per epoch for the single-sample run, with and without the
# beta weight, zoomed to the first 50 epochs.
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta, color='dodgerblue')
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing', color='dodgerblue')
plt.scatter(epochs, kl_divergences_beta, color='darkorange')
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling', color='darkorange')
plt.xlim(0, 50)
# Axis labels were swapped in the original: epochs are plotted on x and the
# KL values on y, as in the matching plot for the clipped/scheduled run.
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("figures/kl_divergence_comparison_no_GS_1000_ss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
input_dim = data_array_t.shape[1]

In [None]:
input_dim

In [None]:
# GRADIENT CLIPPING PLUS SCHEDULER USED
# Same single-sample overfitting check as the un-clipped run, but with
# gradient-norm clipping (max_norm=1.0) and a StepLR schedule that halves the
# learning rate every 20 epochs. As before, only `loss` is optimised; `loss2`
# (beta-weighted KL) is recorded for comparison.
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
n_epochs = 10  # length of the linear beta ramp, in epochs

# Model
model = VAE_single(input_dim, hidden_dim, latent_dim)

# Overfitting
model.train()
num_epochs = 1000

# Optimizer and scheduler (the optimizer is created once; the original built
# an identical Adam instance twice and discarded the first).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Collecting data for visualisation
train_loss_vals1 = []
train_loss_vals2 = []
kl_divergences_no_beta = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    # Linear ramp from beta_start, clamped at beta_end so that beta does not
    # keep growing once the n_epochs-long ramp is over.
    beta = min(beta_end, beta_start + (beta_end - beta_start) * epoch / n_epochs)
    epoch_kl_divergence = 0
    epoch_kl_divergence_beta = 0

    # Loop variable is not called `data`, which is the pangenome DataFrame.
    for batch in single_sample_loader:
        batch_x = batch[0].to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL term between N(mu, exp(logvar)) and N(0, I), summed over all elements.
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence += kl_divergence_loss.item()
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()

        # Total loss
        loss = reconstruction_loss + kl_divergence_loss
        loss2 = reconstruction_loss + kl_divergence_loss_beta

        # Backpropagation, clipping the global gradient norm before the step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    # One scheduler step per epoch
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss (method1) = {loss.item()}")
        print(f"Epoch {epoch}: Loss (method2) = {loss2.item()}")

    # The recorded losses are those of the last batch of the epoch.
    train_loss_vals1.append(loss.item())
    train_loss_vals2.append(loss2.item())

    kl_divergences_no_beta.append(epoch_kl_divergence / len(single_sample_loader.dataset))
    kl_divergences_beta.append(epoch_kl_divergence_beta / len(single_sample_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1, color='dodgerblue')
plt.plot(epochs, train_loss_vals1, label='train loss (no KL annelaing)', color='dodgerblue')
plt.scatter(epochs, train_loss_vals2, color='darkorange')
plt.plot(epochs, train_loss_vals2, label='train loss using KL annelaing', color='darkorange')
plt.xlim(0, 100)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("figures/train_loss_comparisons_GS_1000_ss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta, color='dodgerblue')
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing', color='dodgerblue')
plt.scatter(epochs, kl_divergences_beta, color='darkorange')
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling', color='darkorange')
plt.xlim(0, 50)
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("figures/kl_divergence_comparison_GS_1000_ss.pdf", format="pdf", bbox_inches="tight")
plt.show()

## 3.2) Overfitting on a small train subset

In [None]:
# Overfit the VAE on the small training subset with the plain
# (un-annealed) loss: reconstruction + KL.
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
# Not used by this un-annealed run; still assigned so that the cell leaves the
# same globals behind as before.
beta_start = 0.1
beta_end = 1.0

model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
num_epochs = 1000

train_loss_vals1 = []
kl_divergences_no_beta = []

for epoch in range(num_epochs):
    epoch_kl_divergence = 0
    # Loop variable is not called `data`, which is the pangenome DataFrame.
    for batch in small_loader:
        batch_x = batch[0].to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL term between N(mu, exp(logvar)) and N(0, I), summed over the batch.
        # (The beta-weighted variant computed here originally was never used and has been dropped.)
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        epoch_kl_divergence += kl_divergence_loss.item()

        # Total loss
        loss = reconstruction_loss + kl_divergence_loss

        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method1) = {loss.item()}")

    # The recorded loss is that of the last batch of the epoch.
    train_loss_vals1.append(loss.item())

    # KL summed over the epoch, divided by the number of samples in the subset.
    kl_divergences_no_beta.append(epoch_kl_divergence / len(small_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
# Save trained model
torch.save(model.state_dict(), "models/saved_small_VAE1_1000.pt")
print("Model saved.")

In [None]:
epochs = np.linspace(1, 1000, num=1000)

In [None]:
# Training loss per epoch for the un-annealed small-subset run.
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1, color='dodgerblue')
# The line carries a label so that plt.legend() has an entry to show;
# the original called legend() with no labelled artists.
plt.plot(epochs, train_loss_vals1, label='train loss (no KL annealing)', color='dodgerblue')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("figures/train_loss_small_ds1_1000.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# KL divergence per epoch for the un-annealed small-subset run.
# Only the un-annealed curve is drawn here. In top-to-bottom order
# `kl_divergences_beta` still holds values from the single-sample experiment
# at this point; the annealed small-subset run happens in the next training
# cell, and the two curves are compared in the plot that follows it.
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta, color='dodgerblue')
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing', color='dodgerblue')
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("figures/kl_divergence_comparison_GS_ds1_1000.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# Overfit the VAE on the small training subset with KL annealing:
# loss2 = reconstruction + beta * KL, where beta rises linearly from
# beta_start towards beta_end over the num_epochs epochs of the run.
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0

model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
num_epochs = 1000

train_loss_vals2 = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    beta = beta_start + (beta_end - beta_start) * epoch / num_epochs
    epoch_kl_divergence_beta = 0
    # Loop variable is not called `data`, which is the pangenome DataFrame.
    for batch in small_loader:
        batch_x = batch[0].to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL term between N(mu, exp(logvar)) and N(0, I), summed over the batch.
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()

        # Total loss (beta-weighted KL)
        loss2 = reconstruction_loss + kl_divergence_loss_beta

        loss2.backward()
        optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method 2) = {loss2.item()}")

    # The recorded loss is that of the last batch of the epoch.
    train_loss_vals2.append(loss2.item())

    # Beta-weighted KL summed over the epoch, divided by the number of samples in the subset.
    kl_divergences_beta.append(epoch_kl_divergence_beta / len(small_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss2.item()}")

In [None]:
# Save trained model (weights only, via state_dict)
# NOTE(review): the file name ends in "_100", but the training cell above sets
# num_epochs = 1000 and the sibling checkpoint is "saved_small_VAE1_1000.pt".
# Confirm whether this should be "_1000" before anything relies on the name.
torch.save(model.state_dict(), "models/saved_small_VAE2_100.pt")
print("Model saved.")

In [None]:
# Training loss per epoch for the KL-annealed small-subset run.
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals2, color='dodgerblue')
# The line carries a label so that plt.legend() has an entry to show;
# the original called legend() with no labelled artists.
plt.plot(epochs, train_loss_vals2, label='train loss using KL annealing', color='dodgerblue')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("figures/train_loss_small_ds2_1000.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta, color='dodgerblue')
plt.plot(epochs, kl_divergences_no_beta, label = 'no KL annealing', color='dodgerblue')
plt.scatter(epochs, kl_divergences_beta, color='darkorange')
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling', color='darkorange')
# plt.xlim(0, 1000)
# plt.ylim(0, 20)
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("figures/kl_divergence_comparison_1_2_1000.pdf", format="pdf", bbox_inches="tight")
plt.show()

# 4) Training VAE model on full dataset (train + validation sets)

## 4.1) Training with no KL annealing 

In [None]:
# Model
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64

model1 = VAE(input_dim, hidden_dim, latent_dim).to(device)

# Optimizer and scheduler
optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

max_norm = 1.0 
beta_start = 0.1
beta_end = 1.0
n_epochs = 100

train_loss_vals, val_loss_vals = train_no_KL_annelaing(model=model1, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, max_norm=max_norm)

In [None]:
for batch in train_loader:
    if batch[0].size(0) > 0:
        print('+') 

In [None]:
data_array_t.shape[1]

In [None]:
# Save trained model
torch.save(model1.state_dict(), "models/saved_no_KL_annealing_VAE_100.pt")
print("Model saved.")

In [None]:
epochs = np.linspace(1, 100, num=100)

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals, color='dodgerblue')
plt.plot(epochs, train_loss_vals, label='Train Loss', color='dodgerblue')
plt.scatter(epochs, val_loss_vals, color='darkorange')
plt.plot(epochs, val_loss_vals, label='Validation Loss', color='darkorange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig("figures/model_train_val_loss_1_100.pdf", format="pdf", bbox_inches="tight")
plt.show()

## 4.2) Training using KL annealing 

In [None]:
model2 = VAE(input_dim, hidden_dim, latent_dim).to(device)

# Optimizer and scheduler
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

train_loss_vals2, val_loss_vals = train_with_KL_annelaing(model=model2, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, max_norm=max_norm)

In [None]:
# Save trained model
torch.save(model2.state_dict(), "models/saved_KL_annealing_VAE_100.pt")
print("Model saved.")

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals2, color='dodgerblue')
plt.plot(epochs, train_loss_vals2, label='Train Loss', color='dodgerblue')
plt.scatter(epochs, val_loss_vals, color='darkorange')
plt.plot(epochs, val_loss_vals, label='Validation Loss', color='darkorange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig("figures/model_train_val_loss_2_100.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals, color='dodgerblue')
plt.plot(epochs, train_loss_vals, label='No KL annealing', color='dodgerblue')
plt.scatter(epochs, train_loss_vals2, color='darkorange')
plt.plot(epochs, train_loss_vals2, label='with KL annelaing', color='darkorange')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("figures/compare_first_second_train_losses_100.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# # Load trained model 

# model = VAE(input_dim, hidden_dim, latent_dim)
# model.load_state_dict(torch.load('saved_KL_annealing_VAE.pt', map_location=device))
# model.eval()  

In [None]:
# recon_x, mu, logvar = model(data)

## 4.3) (Experiment) Training a MoG VAE (later)

In [None]:
# num_components = 3
# model2 = VAEWithMoGPrior(input_dim, hidden_dim, latent_dim, num_components).to(device)
# optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
# # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# train(model=model2, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, max_norm=max_norm)

In [None]:
# # save trained model
# torch.save(model2.state_dict(), "models/saved_MoG_VAE.pt")
# print("Model saved.")

# 5) Observing the latent spaces of the model(s) fitted

In [None]:
hidden_dim = 512
latent_dim = 64
input_dim = data_array_t.shape[1]

# Rebuild the architecture and load the weights saved after KL-annealed training.
# map_location=device lets a checkpoint saved on another device (e.g. a GPU)
# load on the current one. The model is then moved to the device that is
# passed to get_latent_variables below; the original left it on the CPU.
model2 = VAE(input_dim, hidden_dim, latent_dim)
model2.load_state_dict(torch.load('models/saved_KL_annealing_VAE_100.pt', map_location=device))
model2.to(device)
model2.eval()

# Get latent variables for the test set
latents = get_latent_variables(model2, test_loader, device)

In [None]:
# Apply t-SNE for dimensionality reduction.
# t-SNE is stochastic; a fixed random_state (the same seed as the data split)
# makes the embedding, and every figure built from it, reproducible on re-run.
tsne = TSNE(n_components=2, random_state=12345)
tsne_latents = tsne.fit_transform(latents)

plt.figure(figsize=(10, 8))
plt.scatter(tsne_latents[:, 0], tsne_latents[:, 1], color='dodgerblue')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.show()

In [None]:
# t-SNE embedding of the test-set latents, coloured by phylogroup.
# Columns are named after t-SNE dimensions: the original called them
# 'PC1'/'PC2', which put principal-component labels on the axes of a t-SNE plot.
df_tsne = pd.DataFrame(tsne_latents, columns=['tSNE1', 'tSNE2'])
df_tsne['phylogroup'] = test_phylogroups
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', hue = df_tsne['phylogroup'], data=df_tsne)
plt.savefig("figures/tsne_latent_space_visualisation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# Apply PCA
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2'])
df_pca['phylogroup'] = test_phylogroups


In [None]:
# Plot the PCA results
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', hue = df_pca['phylogroup'], data=df_pca)
plt.savefig("figures/pca_latent_space_visualisation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
df_pca

# 5) Hyperparameter tuning

## 5.1) Grid search for simple hyperparameter tuning

In [None]:
# # Gridsearch
# input_dim = data_array_t.shape[1]
# hidden_dim_values = [256, 512, 1024]
# latent_dim_values = [32, 64, 128]
# learning_rate_values = [0.01, 1e-3] # Decrease of learning rate causes higher average train loss, better if 0.01, 0.001
# # beta_start_values = [0.01, 0.1, 0.2]
# # beta_end_values = [0.5, 1.0, 2.0]
# # max_norm_values = [0.5, 1.0, 2.0]
# max_norm = 1.0 
# beta_start = 0.1
# beta_end = 1.0

# # beta_start, beta_end, max_norm
# for hidden_dim, latent_dim, learning_rate in itertools.product(
#     hidden_dim_values, latent_dim_values, learning_rate_values): #beta_start_values, beta_end_values, max_norm_values
#     print(f"Training with hidden_dim={hidden_dim}, latent_dim={latent_dim}, learning_rate={learning_rate}") # beta_start={beta_start}, beta_end={beta_end}, max_norm={max_norm}"
#     model = VAE(input_dim, hidden_dim, latent_dim).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
#     train_with_KL_annelaing(model=model, optimizer=optimizer, scheduler=scheduler, n_epochs=10, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, max_norm=max_norm)
#     print("--------------------------------------------------------------------------------------")

### result - best params hidden_dim = 1024, latent_dim = 32, lr = 1e-3 (based on average train and val loss)

In [None]:
input_dim = data_array_t.shape[1]
hidden_dim = 1024
latent_dim = 32
max_norm = 1.0 
beta_start = 0.1
beta_end = 1.0
n_epochs = 100


model = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

train_loss_vals2, val_loss_vals = train_with_KL_annelaing(model=model, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, max_norm=max_norm)

In [None]:
torch.save(model.state_dict(), "models/saved_KL_annealing_VAE_tuned_100.pt")
print("Model saved.")

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals2, color='dodgerblue')
plt.plot(epochs, train_loss_vals2, label='Train Loss', color='dodgerblue')
plt.scatter(epochs, val_loss_vals, color='darkorange')
plt.plot(epochs, val_loss_vals, label='Validation Loss', color='darkorange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# AHPT - after hyperparameter tuning
plt.savefig("figures/train_val_loss_AHPT_100.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# Visualize the first two principal components of the tuned model's
# test-set latents, coloured by phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2'])
df_pca['phylogroup'] = test_phylogroups

# Plot the PCA frame computed above. The original built a frame from
# `tsne_latents` (the t-SNE embedding of the earlier, untuned model) and
# saved that under the PCA file name.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', hue = df_pca['phylogroup'], data=df_pca)
plt.savefig("figures/pca_latent_space_visualisation_AHPT.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
len(latents)

In [None]:
# Apply t-SNE for dimensionality reduction to the current `latents`.
# A fixed random_state makes the stochastic embedding reproducible on re-run.
tsne = TSNE(n_components=2, random_state=12345)
data_tsne = tsne.fit_transform(latents)

In [None]:
# t-SNE embedding computed in the previous cell (`data_tsne`), coloured by phylogroup.
# The original plotted `tsne_latents`, the embedding of the earlier untuned
# model, so the freshly computed `data_tsne` was never shown.
df_tsne = pd.DataFrame(data_tsne, columns=['tSNE1', 'tSNE2'])
df_tsne['phylogroup'] = test_phylogroups
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', hue = df_tsne['phylogroup'], data=df_tsne)
# Saved with the "_AHPT" suffix (after hyperparameter tuning), like the other
# figures of this section, so it no longer overwrites the t-SNE figure
# written earlier for the untuned model.
plt.savefig("figures/tsne_latent_space_visualisation_AHPT.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
len(data_tsne)

# 6) Evaluation test 

In [None]:
# Reload the tuned model and score its reconstructions of the whole test set.
# NOTE(review): input_dim / hidden_dim / latent_dim are whatever earlier cells
# last assigned; they must match the architecture the checkpoint was saved with.
model = VAE(input_dim, hidden_dim, latent_dim)
# The model and test_data both live on the CPU in this cell, so map the
# checkpoint to the CPU explicitly; without map_location a checkpoint saved
# from a GPU fails to load on a machine that has none.
model.load_state_dict(torch.load('models/saved_KL_annealing_VAE_tuned_100.pt', map_location=torch.device('cpu')))

model.eval()
with torch.no_grad():
    recon_x, mu, logvar = model(test_data)

# Threshold the reconstruction at 0.5 to obtain 0/1 predictions
recon_x_binarized = (recon_x > 0.5).int()

# Scores over all entries of the test matrix at once (y_true first, then y_pred)
f1 = sklearn.metrics.f1_score(test_data.flatten(), recon_x_binarized.flatten())
print(f'F1 Score: {f1:.2f}')

accuracy = sklearn.metrics.accuracy_score(test_data.flatten(), recon_x_binarized.flatten())
print(f'Accuracy Score: {accuracy:.2f}')

In [None]:
# Per-row F1 and accuracy between each test row and its thresholded reconstruction.
recon_x_binarized = (recon_x > 0.5).int()

f1_scores = []
accuracy_scores = []
for genome_x, genome in zip(recon_x_binarized, test_data):
    # scikit-learn metrics take (y_true, y_pred): `genome` is the observed row,
    # `genome_x` its reconstruction. The original passed them the other way
    # round; binary F1 and accuracy are unchanged by the swap, but the call
    # now matches the convention used for the whole-matrix scores.
    f1_scores.append(sklearn.metrics.f1_score(genome, genome_x))
    accuracy_scores.append(sklearn.metrics.accuracy_score(genome, genome_x))

In [None]:
type(test_data)

In [None]:
plt.figure(figsize=(10,8))
plt.hist(f1_scores, color='dodgerblue')
plt.xlabel("F1 score")
plt.ylabel("Frequency")
plt.savefig("figures/f1_score_frequency_test_set_AHPT.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.hist(accuracy_scores, color='dodgerblue')
plt.xlabel("Accuracy score")
plt.ylabel("Frequency")
plt.savefig("figures/accuracy_score_frequency_test_set_AHPT.pdf", format="pdf", bbox_inches="tight")
plt.show()

# 7) Simulation/generation

In [2]:
state_dict = torch.load('models/saved_KL_annealing_VAE_BD_100_AHPT.pt', map_location=torch.device('cpu'))
print(state_dict.keys())

odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.1.weight', 'encoder.1.bias', 'encoder.1.running_mean', 'encoder.1.running_var', 'encoder.1.num_batches_tracked', 'encoder.3.weight', 'encoder.3.bias', 'encoder.4.weight', 'encoder.4.bias', 'encoder.4.running_mean', 'encoder.4.running_var', 'encoder.4.num_batches_tracked', 'encoder.6.weight', 'encoder.6.bias', 'encoder.7.weight', 'encoder.7.bias', 'encoder.7.running_mean', 'encoder.7.running_var', 'encoder.7.num_batches_tracked', 'mean_layer.weight', 'mean_layer.bias', 'logvar_layer.weight', 'logvar_layer.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.1.weight', 'decoder.1.bias', 'decoder.1.running_mean', 'decoder.1.running_var', 'decoder.1.num_batches_tracked', 'decoder.3.weight', 'decoder.3.bias', 'decoder.4.weight', 'decoder.4.bias', 'decoder.4.running_mean', 'decoder.4.running_var', 'decoder.4.num_batches_tracked', 'decoder.6.weight', 'decoder.6.bias', 'decoder.7.weight', 'decoder.7.bias', 'decoder.7.running_mean', 'd

## 7.1) Random sampling from latent space

In [23]:
# Load trained model
# NOTE(review): the dimensions are hard-coded here and must match the
# checkpoint loaded below; load_state_dict will raise on a shape mismatch.
input_dim = 7580
hidden_dim = 512
latent_dim = 32

# NOTE(review): the original comment here was cut off ("changes layer norm layer
# to batch norm layer and ..."). The checkpoint keys printed above contain
# running_mean / running_var buffers, so VAE_2 presumably uses batch-norm
# layers; confirm against VAE_model_2.
model = VAE_2(input_dim, hidden_dim, latent_dim)
model.load_state_dict(torch.load('models/saved_KL_annealing_VAE_BD_100_AHPT.pt',  map_location=torch.device('cpu')))  
model.eval()  

# Generate 10 new samples
num_samples = 10 
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim)  # Draw latent vectors from a standard normal (presumably the prior the KL term pulls the encoder towards; confirm against VAE_2)
    generated_samples = model.decode(z).cpu().numpy() 

# Binarise the decoder output: values above the threshold become 1.0, the rest 0.0
threshold = 0.5
binary_generated_samples = (generated_samples > threshold).astype(float)

print("Generated samples (binary):\n", binary_generated_samples)
print("\n")
print("Generated samples (sigmoid function output):\n", generated_samples)


Generated samples (binary):
 [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


Generated samples (sigmoid function output):
 [[1.7769073e-04 3.2114153e-07 1.8784209e-01 ... 7.4324407e-02
  1.7745629e-01 2.3535958e-05]
 [5.1390733e-05 1.0870822e-09 7.5103067e-02 ... 5.4485691e-03
  1.2381562e-02 1.4658183e-08]
 [2.2642735e-02 4.5095883e-02 4.3876082e-02 ... 9.1393203e-02
  2.0216005e-02 4.3009086e-06]
 ...
 [2.4069037e-02 1.5522110e-01 3.4874178e-02 ... 6.6571529e-03
  5.5042654e-03 4.7373556e-04]
 [5.3547625e-04 2.4078868e-02 3.4662257e-03 ... 8.0089215e-03
  4.9852736e-02 7.6145463e-04]
 [1.2590332e-03 5.8197655e-02 2.8905066e-02 ... 3.6031518e-02
  2.8072030e-03 1.9860319e-05]]


## 7.2) Grid sampling from latent space

In [None]:
# Apply PCA to the latent vectors; the fitted `pca` is used further down to map
# 2-D grid points back into the latent space via inverse_transform.
# NOTE(review): `latents` is not recomputed in this section. In top-to-bottom
# order it was last assigned in the hyperparameter-tuning section from a VAE
# instance, whereas `model` is now the VAE_2 checkpoint loaded in section 7.1.
# The grid points are decoded with `model`, so confirm that `latents` and
# `model` refer to the same latent space (or recompute `latents` here).
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2'])
df_pca['phylogroup'] = test_phylogroups


In [None]:
# Plot the PCA results
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', hue = df_pca['phylogroup'], data=df_pca)
plt.savefig("figures/pca_latent_space_visualisation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [24]:
# Build a grid_size x grid_size lattice of 2-D points covering
# [-scale, scale] on both axes, as a float32 tensor with one point per row.
# NOTE(review): the stored output of the cell that displays grid_points shows
# a +/-2 grid, i.e. it was produced with a different `scale`; re-run to refresh.
grid_size = 3
scale = 21.0

x = np.linspace(-scale, scale, num=grid_size)
y = np.linspace(-scale, scale, num=grid_size)
xx, yy = np.meshgrid(x, y)

# Flatten both coordinate grids and pair them up column-wise: (x, y) per row
grid_points = torch.tensor(
    np.column_stack((xx.flatten(), yy.flatten())),
    dtype=torch.float32,
)
    

In [25]:
grid_points

tensor([[-2., -2.],
        [ 0., -2.],
        [ 2., -2.],
        [-2.,  0.],
        [ 0.,  0.],
        [ 2.,  0.],
        [-2.,  2.],
        [ 0.,  2.],
        [ 2.,  2.]])

In [None]:
new_high_dim_grid_points = pca.inverse_transform(grid_points)

new_high_dim_grid_points = torch.tensor(new_high_dim_grid_points, dtype=torch.float32)

with torch.no_grad():
    generated_samples = model.decode(new_high_dim_grid_points).cpu().numpy()

In [43]:
threshold = 0.5
binary_generated_samples = (generated_samples > threshold).astype(float)

print("Generated samples (binary):\n", binary_generated_samples)
print("\n")
print("Generated samples (sigmoid function output):\n", generated_samples)

Generated samples (binary):
 [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


Generated samples (sigmoid function output):
 [[7.7252116e-05 1.2633001e-03 4.7773786e-02 ... 1.7216282e-01
  2.2460228e-01 2.9453573e-05]
 [6.8355046e-05 1.4295902e-03 5.1858902e-02 ... 1.6925547e-01
  2.2957756e-01 5.6759844e-05]
 [2.0140858e-04 3.8045559e-03 9.3354106e-02 ... 1.7241058e-01
  1.9027841e-01 1.2891165e-03]
 ...
 [7.9948673e-05 2.4576401e-03 5.0572816e-02 ... 1.5987152e-01
  2.2472443e-01 7.1877468e-05]
 [1.0823663e-04 1.2167287e-03 5.3681433e-02 ... 1.6496356e-01
  2.1537539e-01 1.2699362e-04]
 [3.2322315e-04 3.8754996e-03 1.0536576e-01 ... 1.8948840e-01
  1.8244295e-01 2.4207111e-03]]


## 7.3) Interpolation