In [None]:
import sys
import numpy as np
import torch
import torch.optim as optim

import matplotlib.pyplot as plt

sys.path.insert(0, "/home/simon/Documents/scripts/VIEWS_FAO_index/src/architectures/")

from vea_001 import VAE, vae_loss

In [None]:
# Build a synthetic stand-in for the real data: `month_dim` independent
# `space_dim` x `space_dim` grids of log-normally distributed values
# (strictly positive and right-skewed).
space_dim = 128
month_dim = 100

# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# numpy for the synthetic data below, torch for the VAE built in the next cell
# (presumably random weight init / latent sampling — no torch RNG is consumed
# between here and that cell).
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

np_array = rng.lognormal(size=(month_dim, space_dim, space_dim)).astype(np.float32)

# Convert to a PyTorch tensor (shares memory with np_array) and add a leading
# batch axis: shape (1, month_dim, space_dim, space_dim).
synth_tensor = torch.from_numpy(np_array).unsqueeze(0)

In [None]:

# Prefer a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the model with its default configuration (from the project's
# vea_001 module) and move it onto the selected device.
vae = VAE().to(device)

# Adam over all VAE parameters with a learning rate of 1e-3.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# Objective used by the training loop (also from vea_001).
loss_fn = vae_loss

# Number of full passes over the synthetic months in the training cell.
num_epochs = 100


In [4]:
# Train the VAE one month at a time: each month's grid is fed as a single
# (1, 1, space_dim, space_dim) sample, so every optimizer step sees one image.
LOG_EVERY = 10  # print the mean training loss every LOG_EVERY epochs

# The data never changes between steps, so move the whole stack to the device
# once instead of copying one month slice per step.
synth_tensor_device = synth_tensor.to(device)

# Per-epoch mean training loss, kept so the curve can be plotted afterwards.
train_losses = []

for epoch in range(num_epochs):
    vae.train()
    train_loss = 0.0

    for month in range(month_dim):
        # Integer indexing drops the month axis; unsqueeze(1) puts a size-1
        # channel axis back: (1, 1, space_dim, space_dim).
        synth_sub_tensor = synth_tensor_device[:, month, :, :].unsqueeze(1)

        optimizer.zero_grad()

        # Unpacked as (reconstruction, mean, logvar) and passed straight to loss_fn.
        outputs, mean, logvar = vae(synth_sub_tensor)
        loss = loss_fn(outputs, synth_sub_tensor, mean, logvar)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean loss over the months seen this epoch (guarded against month_dim == 0).
    train_loss /= max(month_dim, 1)
    train_losses.append(train_loss)

    if (epoch + 1) % LOG_EVERY == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}")


In [None]:
# NOTE(review): unfinished validation step, left commented out. It references
# `val_loader` and `best_val_loss`, neither of which is defined anywhere in this
# notebook, and its "this epoch" wording suggests it was meant to run inside the
# training epoch loop. The training loop names the VAE's third output `logvar`;
# `sigma` below would be that same value, not a standard deviation.
# TODO: build a held-out split, initialise best_val_loss, and wire this in — or delete it.
#   # Set the VAE to evaluation mode
#   vae.eval()
#
#   # Sum of validation loss for this epoch
#   val_loss = 0.0
#
#   # Iterate over the validation data
#   for data in val_loader:
#       # Move data to the chosen device
#       data = data.to(device)
#
#       # Forward pass
#       outputs, mu, sigma = vae(data)
#
#       # Compute loss
#       loss = loss_fn(outputs, data, mu, sigma)
#
#       # Accumulate validation loss
#       val_loss += loss.item()
#
#   # Compute average validation loss for this epoch
#   val_loss /= len(val_loader)
#
#   # Print validation loss for this epoch
#   print(f"Validation Loss: {val_loss:.4f}\n")
#
#   # Save the model if validation loss decreases
#   if val_loss < best_val_loss:
#       best_val_loss = val_loss
#       torch.save(vae.state_dict(), 'best_vae.pth')
#

In [None]:
# Reconstruct the first synthetic month with the trained model.
# unsqueeze(1) restores the channel axis dropped by indexing: (1, 1, space_dim, space_dim).
synth_test_tensor = synth_tensor[:, 0, :, :].unsqueeze(1).to(device)

# eval() switches layers such as dropout/batch-norm to inference behaviour;
# no_grad() avoids building an autograd graph that is never backpropagated.
vae.eval()
with torch.no_grad():
    output, mu, sigma = vae(synth_test_tensor)
# NOTE(review): the training loop unpacks this third output as `logvar`;
# the `sigma` name is kept from the original and is not a standard deviation.

# Collapse to a 2-D image for plotting. Using space_dim (rather than a literal
# 128) keeps this in sync with the data cell and fails loudly on a size mismatch.
output = output.view(space_dim, space_dim)

In [None]:
# Show the original first month next to its VAE reconstruction.
original_img = synth_test_tensor.squeeze().cpu().numpy()
recon_img = output.detach().cpu().numpy()

# Share one colour scale across both panels: with independent autoscaling the
# two images can look alike even when their value ranges differ a lot (the
# lognormal input is heavily right-skewed).
vmin = float(min(np.nanmin(original_img), np.nanmin(recon_img)))
vmax = float(max(np.nanmax(original_img), np.nanmax(recon_img)))

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
for panel, img, title in zip(ax, (original_img, recon_img), ("Original", "Reconstruction")):
    im = panel.imshow(img, interpolation='nearest', cmap='viridis', vmin=vmin, vmax=vmax)
    panel.set_title(title)
fig.colorbar(im, ax=ax, shrink=0.8)
plt.show()