In [None]:
import numpy as np
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset
import pandas as pd
import matplotlib.pyplot as plt
from dir import *
from VAE_model import *
from VAE_model_single import *
from VAE_MoG_model import *
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
import itertools
from sklearn.preprocessing import StandardScaler
# Global seeds so random debug samples, weight initialisation, DataLoader shuffling
# and VAE sampling are reproducible on Restart & Run All.
SEED = 12345
np.random.seed(SEED)
torch.manual_seed(SEED)
plt.style.use('ggplot')

# 1) Data exploration 

## 1.1) Overall exploration

In [None]:
data = pd.read_csv(PANGENOME_MATRIX_CSV, index_col=[0], header=[0])

In [None]:
data

In [None]:
data.dtypes

In [None]:
data.transpose()

In [None]:
data.transpose()[data.transpose()[data.transpose().columns].eq(0).all(1)]

In [None]:
data.columns

In [None]:
percent_GF_present = data.astype(bool).sum(axis=0) / len(data.index) * 100

In [None]:
percent_GF_present

In [None]:
plt.figure(figsize=(10, 8))
percent_GF_present.iloc[:100].plot(kind='bar')
plt.xlabel('Genomes')
plt.ylabel('Percentage of GFs present in the genome')
plt.show()

In [None]:
frequency1 = data.sum(axis=1)

In [None]:
plt.figure(figsize=(10,8))
plt.hist(frequency1)
plt.xlabel('Gene count')
plt.ylabel('Frequency')
plt.show()

In [None]:
frequency2 = data.sum(0)

In [None]:
frequency2

In [None]:
# plt.figure(figsize=(10,8))
# plt.hist(frequency2, bin=20)
# plt.xlabel('Genome size')
# plt.ylabel('Gene Gamily Frequency')
# plt.show()

In [None]:
threshold_data = []
thresholds = np.linspace(0, 20, num=10)

for i in thresholds:
    row_sums = data.sum(axis=1)
    threshold_data.append(len(data[row_sums >= i]))


In [None]:
threshold_data

In [None]:
thresholds

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(thresholds, threshold_data)
plt.plot(thresholds, threshold_data)
plt.xlabel('Gene Number Thershold')
plt.ylabel('Gene Frequency')
plt.show()

## 1.2) PCA

In [None]:
# Apply PCA
pca = PCA(n_components=2)
data_pca = pca.fit_transform(data.transpose())
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2'])

# Visualize the first two principal components
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', data=df_pca)
plt.title('PCA')
plt.show()

In [None]:
df_pca

In [None]:
from scipy.stats import shapiro

In [None]:
shapiro_test_pc1 = shapiro(df_pca['PC1'])
shapiro_test_pc2 = shapiro(df_pca['PC2'])
print(f"Shapiro-Wilk Test for PC1: {shapiro_test_pc1}")
print(f"Shapiro-Wilk Test for PC2: {shapiro_test_pc2}")

# 2) Data preprocessing  

In [None]:
row_sums = data.sum(axis=1)
filtered_data = data[row_sums >= 20]

In [None]:
filtered_data

In [None]:
data_array_t = np.array(filtered_data.transpose())

In [None]:
data_array_t

In [None]:
data_array_t.shape

In [None]:
data_array_t.shape[1]

In [None]:
# # Normalizing the data with Standard Scaler
# scaler = StandardScaler()
# data_normalized = scaler.fit_transform(data_array_t)

# Convert to PyTorch tensor
data_tensor = torch.tensor(data_array_t, dtype=torch.float32)

# Split into train and test sets
train_data, val_data = train_test_split(data_tensor, test_size=0.2, random_state=12345)

# TensorDataset
train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)

# DataLoaders for main training
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Dataloader for overfitting on one sample (for dubbiging purposes)
input_dim = data_array_t.shape[1]
binary_data = torch.tensor(np.random.randint(0, 2, size=(1, input_dim)), dtype=torch.float32)
single_sample_dataset = TensorDataset(binary_data)
single_sample_loader = DataLoader(single_sample_dataset, batch_size=1, shuffle=True)

# Dataloader fot a small subset for overfitting (again, for debugging)
small_subset_indices = np.random.choice(len(train_dataset), size=256, replace=False)
small_subset = Subset(train_dataset, small_subset_indices)
small_loader = DataLoader(small_subset, batch_size=batch_size, shuffle=True)

In [None]:
len(train_dataset)

In [None]:
len(train_loader)

In [None]:
len(train_dataset)

In [None]:
len(val_dataset)

In [None]:
train_data

In [None]:
TensorDataset(train_data)

In [None]:
TensorDataset(torch.tensor(train_data))

In [None]:
print(data_tensor)

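# 3) Overfitting on a single sample and a small data subset

Sanity check before full training: a correctly wired VAE should be able to drive its loss down on a single sample or a tiny subset; if it cannot, the model or the loss is broken.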

## 3.1) Overfitting on a single sample

In [None]:
# NOTE(review): `single_sample` is a standard-normal (continuous) draw and is not used
# by the overfitting loops below, which iterate `single_sample_loader` (a random 0/1
# vector). BCE targets must lie in [0, 1], so this tensor could not serve as a
# reconstruction target as-is.
single_sample = torch.randn(1, data_array_t.shape[1])

In [None]:
single_sample.shape[0]

In [None]:
# NO GRADIENT CLIPPING AND SCHEDULER
# Sanity check: overfit one random binary sample with plain Adam.
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
free_bits = 0.1   # NOTE: not used by this loop
n_epochs = 10     # KL warm-up length in epochs (beta reaches beta_end after this)

model = VAE_single(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Overfitting one sample to see if the model is broken
model.train()
num_epochs = 1000

# Collecting data for visualisation
train_loss_vals1 = []
train_loss_vals2 = []
kl_divergences_no_beta = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    # Linear KL warm-up over n_epochs, then held at beta_end. Without the clamp, beta
    # would keep growing past beta_end (to ~90 by the last of the 1000 epochs).
    beta = beta_start + (beta_end - beta_start) * min(1.0, epoch / n_epochs)
    epoch_kl_divergence = 0
    epoch_kl_divergence_beta = 0

    # `batch_x` rather than `data`, so the pangenome DataFrame `data` is not clobbered.
    for (batch_x,) in single_sample_loader:
        batch_x = batch_x.to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dimensions.
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence += kl_divergence_loss.item()
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()

        # Only `loss` (method 1) is optimised; `loss2` (method 2, beta-weighted KL) is
        # evaluated on the same model purely for comparison.
        loss = reconstruction_loss + kl_divergence_loss
        loss2 = reconstruction_loss + kl_divergence_loss_beta

        # Backpropagation
        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method1) = {loss.item()}\nLoss (method2) = {loss2.item()}")

    train_loss_vals1.append(loss.item())
    train_loss_vals2.append(loss2.item())

    # Per-sample KL over the loader actually iterated (not the full training set).
    n_seen = len(single_sample_loader.dataset)
    kl_divergences_no_beta.append(epoch_kl_divergence / n_seen)
    kl_divergences_beta.append(epoch_kl_divergence_beta / n_seen)

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
epochs = np.linspace(1, 1000, num=1000)

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1)
plt.plot(epochs, train_loss_vals1, label='train loss (no KL annelaing)')
plt.scatter(epochs, train_loss_vals2)
plt.plot(epochs, train_loss_vals2, label='train loss using KL annelaing')
plt.ylim(-10, 1000)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("train_loss_comparisons_GS.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta)
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing')
plt.scatter(epochs, kl_divergences_beta)
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling')
plt.xlim(0, 50)
plt.xlabel('KL divergence')
plt.ylabel('Epoch')
plt.legend()
plt.savefig("kl_divergence_comparison_no_GS.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# GRADIENT CLIPPING PLUS SCHEDULER USED
# Same single-sample sanity check as above, with gradient-norm clipping and a cosine
# learning-rate schedule added.
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
free_bits = 0.1   # NOTE: not used by this loop
n_epochs = 10     # KL warm-up length in epochs (beta reaches beta_end after this)

# Model
model = VAE_single(input_dim, hidden_dim, latent_dim)

# Optimizer and scheduler. NOTE(review): with T_max=10 over 1000 epochs the cosine
# schedule oscillates (the LR climbs back to 1e-3 every 20 epochs) rather than decaying
# once -- confirm this is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

# Overfitting
model.train()
num_epochs = 1000

# Collecting data for visualisation
train_loss_vals1 = []
train_loss_vals2 = []
kl_divergences_no_beta = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    # Linear KL warm-up over n_epochs, then held at beta_end. Without the clamp, beta
    # would keep growing past beta_end (to ~90 by the last of the 1000 epochs).
    beta = beta_start + (beta_end - beta_start) * min(1.0, epoch / n_epochs)
    epoch_kl_divergence = 0
    epoch_kl_divergence_beta = 0

    # `batch_x` rather than `data`, so the pangenome DataFrame `data` is not clobbered.
    for (batch_x,) in single_sample_loader:
        batch_x = batch_x.to(torch.float)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch_x)

        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')

        # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dimensions.
        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence += kl_divergence_loss.item()
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()

        # Only `loss` (method 1) is optimised; `loss2` (method 2) is logged for comparison.
        loss = reconstruction_loss + kl_divergence_loss
        loss2 = reconstruction_loss + kl_divergence_loss_beta

        # Backpropagation with gradient-norm clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss (method1) = {loss.item()}")
        print(f"Epoch {epoch}: Loss (method2) = {loss2.item()}")

    train_loss_vals1.append(loss.item())
    train_loss_vals2.append(loss2.item())

    # Per-sample KL over the loader actually iterated (not the full training set).
    n_seen = len(single_sample_loader.dataset)
    kl_divergences_no_beta.append(epoch_kl_divergence / n_seen)
    kl_divergences_beta.append(epoch_kl_divergence_beta / n_seen)

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1)
plt.plot(epochs, train_loss_vals1, label='train loss (no KL annelaing)')
plt.scatter(epochs, train_loss_vals2)
plt.plot(epochs, train_loss_vals2, label='train loss using KL annelaing')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("train_loss_comparisons_GS.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta)
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing')
plt.scatter(epochs, kl_divergences_beta)
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling')
plt.xlim(0, 50)
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("kl_divergence_comparison_GS.pdf", format="pdf", bbox_inches="tight")
plt.show()

## 3.2) Overfitting on a small train subset

In [None]:
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
free_bits = 0.1

model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
num_epochs = 1000  

train_loss_vals1 = []
# train_loss_vals2 = []
kl_divergences_no_beta = []
# kl_divergences_beta = []

for epoch in range(num_epochs):
    beta = beta_start + (beta_end - beta_start) * epoch / num_epochs
    epoch_kl_divergence = 0
    for data in small_loader:
        data = data[0].to(torch.float)
        # print(data)
    
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        
        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, data, reduction='sum')
        # print(reconstruction_loss.item())

        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        epoch_kl_divergence += kl_divergence_loss.item()
        # epoch_kl_divergence_beta += kl_divergence_loss_beta.item()
        
        # Total loss
        loss = reconstruction_loss + kl_divergence_loss
        # loss2 = reconstruction_loss + kl_divergence_loss_beta
        
        loss.backward()
        optimizer.step()   

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method1) = {loss.item()}")

    train_loss_vals1.append(loss.item())
    # train_loss_vals2.append(loss2.item())

    kl_divergences_no_beta.append(epoch_kl_divergence / len(train_loader.dataset))
    # kl_divergences_beta.append(epoch_kl_divergence_beta / len(train_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
# Save trained model
torch.save(model.state_dict(), "saved_small_VAE1.pt")
print("Model saved.")

In [None]:
epochs = np.linspace(1, 1000, num=1000)

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals1)
plt.plot(epochs, train_loss_vals1)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("train_loss_small_ds1.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# plt.figure(figsize=(10,8))
# plt.scatter(epochs, kl_divergences_no_beta)
# plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing')
# plt.scatter(epochs, kl_divergences_beta)
# plt.plot(epochs, kl_divergences_beta, label = 'KL anneling')
# plt.xlim(0, 50)
# plt.xlabel('Epoch')
# plt.ylabel('KL divergence value')
# plt.legend()
# plt.savefig("kl_divergence_comparison_GS.pdf", format="pdf", bbox_inches="tight")
# plt.show()

In [None]:
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64
beta_start = 0.1
beta_end = 1.0
free_bits = 0.1

model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
num_epochs = 1000  

# train_loss_vals1 = []
train_loss_vals2 = []
# kl_divergences_no_beta = []
kl_divergences_beta = []

for epoch in range(num_epochs):
    beta = beta_start + (beta_end - beta_start) * epoch / num_epochs
    epoch_kl_divergence = 0
    for data in small_loader:
        data = data[0].to(torch.float)
        # print(data)
    
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        
        reconstruction_loss = nn.functional.binary_cross_entropy(reconstruction, data, reduction='sum')
        # print(reconstruction_loss.item())

        kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_divergence_loss_beta = beta * kl_divergence_loss
        # epoch_kl_divergence += kl_divergence_loss.item()
        epoch_kl_divergence_beta += kl_divergence_loss_beta.item()
        
        # Total loss
        # loss = reconstruction_loss + kl_divergence_loss
        loss2 = reconstruction_loss + kl_divergence_loss_beta
        
        loss2.backward()
        optimizer.step()   

    if epoch % 100 == 0:
        print(f"Epoch {epoch}\nLoss (method2) = {loss2.item()}")

    # train_loss_vals1.append(loss.item())
    train_loss_vals2.append(loss2.item())

    # kl_divergences_no_beta.append(epoch_kl_divergence / len(train_loader.dataset))
    kl_divergences_beta.append(epoch_kl_divergence_beta / len(train_loader.dataset))

print(f"Final Loss after {num_epochs} epochs: {loss.item()}")

In [None]:
# Save trained model
torch.save(model.state_dict(), "saved_small_VAE2.pt")
print("Model saved.")

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals2)
plt.plot(epochs, train_loss_vals2)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.legend()
plt.savefig("train_loss_small_d2.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, kl_divergences_no_beta)
plt.plot(epochs, kl_divergences_no_beta, label='no KL annealing')
plt.scatter(epochs, kl_divergences_beta)
plt.plot(epochs, kl_divergences_beta, label = 'KL anneling')
plt.xlabel('Epoch')
plt.ylabel('KL divergence value')
plt.legend()
plt.savefig("kl_divergence_comparison_1_2.pdf", format="pdf", bbox_inches="tight")
plt.show()

# 4) Training the VAE on the full training set (monitored on the validation set)

In [None]:
def train(model, optimizer, scheduler, n_epochs, train_loader, val_loader, beta_start, beta_end, free_bits, max_norm,
          early_stopping_patience=None):
    """Train a VAE, tracking per-sample train and validation losses every epoch.

    The optimised objective ("method 1") is BCE reconstruction (summed) plus the
    closed-form KL divergence to a standard normal prior. A "method 2" loss, with the
    KL term weighted by a beta ramped linearly from `beta_start` (epoch 0) towards
    `beta_end`, is computed on the same batches for comparison only; it is not
    back-propagated.

    Args:
        model: module whose forward returns (reconstruction, mu, logvar); the
            reconstruction must lie in [0, 1] for BCE.
        optimizer: torch optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per epoch, or None for a fixed LR.
            ReduceLROnPlateau is stepped with the epoch's validation loss.
        n_epochs: maximum number of epochs.
        train_loader, val_loader: loaders yielding 1-tuples `(x,)` (TensorDataset).
        beta_start, beta_end: KL-weight ramp used for the method-2 loss.
        free_bits: accepted for interface compatibility; currently unused.
        max_norm: gradient-norm clipping threshold.
        early_stopping_patience: stop after this many consecutive epochs without
            improvement of the validation loss; None (default) disables it.

    Side effects:
        Rebinds the module-level globals `train_loss_vals`, `train_loss_vals2` and
        `val_loss_vals` (one entry per completed epoch) and prints progress every
        10 epochs plus a final summary.
    """
    global train_loss_vals, train_loss_vals2, val_loss_vals
    train_loss_vals = []
    train_loss_vals2 = []
    val_loss_vals = []

    # Batches follow the model's parameters, so a model moved to a GPU also works.
    device = next(model.parameters()).device
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(n_epochs):
        beta = beta_start + (beta_end - beta_start) * epoch / n_epochs

        model.train()
        epoch_train_loss = 0.0
        epoch_train_loss2 = 0.0
        for batch in train_loader:
            x = batch[0].to(device=device, dtype=torch.float)
            optimizer.zero_grad()
            recon_x, mu, logvar = model(x)

            reconstruction_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
            kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = reconstruction_loss + kl_divergence_loss
            loss2 = reconstruction_loss + (beta * kl_divergence_loss)

            loss.backward()
            # Clip the global gradient norm to keep early updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()

            epoch_train_loss += loss.item()
            epoch_train_loss2 += loss2.item()

        avg_train_loss = epoch_train_loss / len(train_loader.dataset)
        avg_train_loss2 = epoch_train_loss2 / len(train_loader.dataset)
        train_loss_vals.append(avg_train_loss)
        train_loss_vals2.append(avg_train_loss2)

        model.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                x = batch[0].to(device=device, dtype=torch.float)
                recon_x, mu, logvar = model(x)
                reconstruction_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
                kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                epoch_val_loss += (reconstruction_loss + kl_divergence_loss).item()

        avg_val_loss = epoch_val_loss / len(val_loader.dataset)
        val_loss_vals.append(avg_val_loss)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}:\n"
                  f" Learning Rate: {optimizer.param_groups[0]['lr']}\n"
                  f" Train Loss (method 1): {avg_train_loss}\n"
                  f" Train Loss (method 2): {avg_train_loss2}\n"
                  f" Validation Loss: {avg_val_loss}")

        # Early stopping on this epoch's validation loss (not a running total).
        if early_stopping_patience is not None:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= early_stopping_patience:
                    print("Early stopping triggered")
                    break

    # Averages over the epochs that actually ran (n_epochs unless stopped early).
    n_run = len(train_loss_vals)
    if n_run:
        final_avg_train_loss = sum(train_loss_vals) / n_run
        final_avg_train_loss2 = sum(train_loss_vals2) / n_run
        final_avg_val_loss = sum(val_loss_vals) / n_run
    else:
        final_avg_train_loss = final_avg_train_loss2 = final_avg_val_loss = float('nan')

    print(f"\nFinal Average Training Loss (method 1): {final_avg_train_loss}")
    print(f"Final Average Training Loss (method 2): {final_avg_train_loss2}")
    print(f"Final Average Validation Loss: {final_avg_val_loss}")

    


In [None]:
# Create a smaller subset of the training data.
# NOTE(review): this draws a new random 256-sample subset and rebinds `small_loader`,
# replacing the one built in the data-preprocessing section. The full-data `train`
# call below uses `train_loader`, not this subset. `replace=False` also requires at
# least 256 training genomes.
small_subset_indices = np.random.choice(len(train_dataset), size=256, replace=False)
small_subset = Subset(train_dataset, small_subset_indices)
small_loader = DataLoader(small_subset, batch_size=batch_size, shuffle=True)

In [None]:
# Model
input_dim = data_array_t.shape[1]
hidden_dim = 512
latent_dim = 64

model1 = VAE(input_dim, hidden_dim, latent_dim).to(device)

# Optimizer and scheduler
optimizer = torch.optim.Adam(model1.parameters(), lr=1e-5)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

max_norm = 1.0 
beta_start = 0.1
beta_end = 1.0
free_bits = 0.1
n_epochs = 100

train(model=model1, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, free_bits=free_bits, max_norm=max_norm)

In [None]:
# Save trained model
torch.save(model1.state_dict(), "saved_base_VAE.pt")
print("Model saved.")

In [None]:
epochs = np.linspace(1, 100, num=100)

In [None]:
len(train_loss_vals)

In [None]:
plt.figure(figsize=(10,8))
plt.scatter(epochs, train_loss_vals)
plt.plot(epochs, train_loss_vals, label='Train Loss')
plt.scatter(epochs, val_loss_vals)
plt.plot(epochs, val_loss_vals, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig("first_model_train_val_loss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# NOTE(review): disabled mixture-of-Gaussians prior experiment. `train` computes the KL
# term in closed form against a standard normal N(0, I), so running VAEWithMoGPrior
# through it presumably would not use the mixture prior -- TODO confirm against
# VAE_MoG_model before re-enabling.
# num_components = 3
# model2 = VAEWithMoGPrior(input_dim, hidden_dim, latent_dim, num_components).to(device)
# optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
# # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# train(model=model2, optimizer=optimizer, scheduler=scheduler, n_epochs=n_epochs, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, free_bits=free_bits, max_norm=max_norm)

In [None]:
# # save trained model
# torch.save(model2.state_dict(), "saved_MoG_VAE.pt")
# print("Model saved.")

# Inspecting the latent space of the fitted model(s)

In [None]:
# Function to extract latent variables
def get_latent_variables(model, data_loader, device):
    """Encode every batch from `data_loader` and stack the posterior means.

    `model.encode(x)` must return a (mean, logvar) pair; only the mean is kept.
    Batches are expected as 1-tuples `(x,)`, as yielded by a TensorDataset loader.
    Returns a numpy array with one row per sample, in loader order.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for (x,) in data_loader:
            mean, _ = model.encode(x.to(device))
            chunks.append(mean.cpu().numpy())
    return np.concatenate(chunks, axis=0)

In [None]:
# Encode the training genomes in a fixed order: `train_loader` reshuffles every
# epoch, so its latents could not be matched back to rows of `train_data`.
ordered_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
latents = get_latent_variables(model1, ordered_train_loader, device)

# t-SNE of the latent means; random_state fixes the otherwise random layout.
tsne = TSNE(n_components=2, random_state=12345)
latents_2d = tsne.fit_transform(latents)

# Colour each genome by the row sum of its input vector (gene families present, for a
# 0/1 matrix) so the colour map and colour bar carry information.
gf_present = train_data.sum(dim=1).numpy()

plt.figure(figsize=(10, 8))
scatter = plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=gf_present, cmap='viridis')
plt.colorbar(scatter, label='Gene families present')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.title('Latent space of the base VAE (training genomes)')
plt.savefig("latent_space_visualisation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# NOTE(review): disabled. `data_loader` is not defined anywhere in this notebook (e.g.
# `train_loader` exists), and `plt.scatter` receives `cmap` without a `c=` array, so the
# colour map and colour bar carry no information.
# # Trying to get the latent space
# # Get latent variables
# latents = get_latent_variables(model2, data_loader, device)

# # Apply t-SNE for dimensionality reduction
# tsne = TSNE(n_components=2)
# latents_2d = tsne.fit_transform(latents)

# # Plot the latent space
# plt.figure(figsize=(10, 8))
# scatter = plt.scatter(latents_2d[:, 0], latents_2d[:, 1], cmap='viridis')
# plt.colorbar(scatter)
# # plt.xlim(-400, 400)
# # plt.ylim(-400, 400)
# plt.xlabel('t-SNE Dimension 1')
# plt.ylabel('t-SNE Dimension 2')
# plt.title('Latent Space Visualization')
# plt.show()

# 5) Generation from the trained model (evaluation)

In [None]:
# Load the trained model. map_location='cpu' lets a checkpoint written from a GPU run
# load into this CPU model.
model = VAE(input_dim, hidden_dim, latent_dim)
model.load_state_dict(torch.load('saved_base_VAE.pt', map_location='cpu'))
model.eval()

# Generate new samples by decoding draws from the prior p(z) = N(0, I). This is the
# standard normal that the KL term in training pulls the approximate posterior towards.
# A seeded generator makes the samples reproducible.
num_samples = 10
generator = torch.Generator().manual_seed(12345)
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim, generator=generator)
    generated_samples = model.decode(z).cpu().numpy()

# Presumably `decode` returns per-gene-family presence probabilities (training applies
# BCE to the model's reconstruction) -- binarise at 0.5.
threshold = 0.5
binary_generated_samples = (generated_samples > threshold).astype(float)

print("Binarised generated samples:\n", binary_generated_samples)
print("Generated probabilities:\n", generated_samples)


# 6) Grid search for the best hyperparameters

In [None]:
# NOTE(review): disabled grid search. Before re-enabling:
#   - `train1` is not defined in this notebook; the training function above is `train`.
#   - `scheduler=0` fails when `train` calls `scheduler.step()`.
#   - The loop rebinds globals other cells read (e.g. section 5 rebuilds the VAE from
#     hidden_dim / latent_dim).
#   - 3**7 = 2187 combinations x 10 epochs is a long run.
# # gridsearch
# hidden_dim_values = [256, 512, 1024]
# latent_dim_values = [32, 64, 128]
# learning_rate_values = [1e-3, 1e-4, 1e-5]
# beta_start_values = [0.01, 0.1, 0.2]
# beta_end_values = [0.5, 1.0, 2.0]
# free_bits_values = [0.0, 0.1, 0.2]
# max_norm_values = [0.5, 1.0, 2.0]

# # Experiment with different hyperparameter combinations
# for hidden_dim, latent_dim, learning_rate, beta_start, beta_end, free_bits, max_norm in itertools.product(
#     hidden_dim_values, latent_dim_values, learning_rate_values, beta_start_values, beta_end_values, free_bits_values, max_norm_values):
#     print(f"Training with hidden_dim={hidden_dim}, latent_dim={latent_dim}, learning_rate={learning_rate}, beta_start={beta_start}, beta_end={beta_end}, free_bits={free_bits}, max_norm={max_norm}")
#     model = VAE(input_dim, hidden_dim, latent_dim).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#     train1(model, optimizer, scheduler=0, n_epochs=10, train_loader=train_loader, val_loader=val_loader, beta_start=beta_start, beta_end=beta_end, free_bits=free_bits, max_norm=max_norm)