## Finding the best images from the Generative VAE model

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
import torch.nn.functional as F
from livelossplot import PlotLosses

from torch.utils.data import Dataset, DataLoader, random_split, Subset
import os
from PIL import Image
import torchvision
from torchvision.transforms import Compose, ToTensor
from torchvision.transforms.functional import to_tensor
import torch.distributions as dist
from tqdm import tqdm

Matplotlib is building the font cache; this may take a moment.


In [2]:
from src.generative.model import VAE

In [2]:
# Loading in the Observation Data
obs_dataset = 'Data/Ferguson_fire_obs.npy' 
obs_dataset = np.load(obs_dataset)

In [3]:
# # Model Code for the ghost structure
# class VAE(nn.Module):
#     def __init__(self, input_image_dims, latent_dims, hidden_layers, activation=nn.ReLU, device="cpu"):
#         super().__init__()

#         # inputs.
#         self.input_image_dims = input_image_dims
#         self.c, self.h, self.w = input_image_dims
#         self.hidden_layers = hidden_layers
#         self.latent_dims = latent_dims
#         self.device = device
#         self.activation = activation
#         self.distribution = torch.distributions.Normal(0, 1)

#         # encoder layers.
#         modules = []
#         previous_dim = self.c * self.h * self.w
#         for h_dim in hidden_layers:
#             modules.append(nn.Linear(previous_dim, h_dim))
#             modules.append(activation())
#             previous_dim = h_dim
#         self.encoder = nn.Sequential(*modules)

#         self._mu = nn.Linear(hidden_layers[-1], self.latent_dims)
#         self._logvar = nn.Linear(hidden_layers[-1], self.latent_dims)

#         # decoder layers.
#         modules = []
#         current_dim = self.latent_dims
#         for h_dim in reversed(hidden_layers):
#             modules.append(nn.Linear(current_dim, h_dim))
#             modules.append(activation())
#             current_dim = h_dim
        
#         modules.append(nn.Linear(hidden_layers[0], self.c * self.h * self.w))
#         modules.append(nn.Sigmoid())
#         self.decoder = nn.Sequential(*modules)

#     def encode(self, x):
#         """"""
#         return self.encoder(x)
    
#     def decode(self, x):
#         """"""
#         return self.decoder(x)   

#     def sample_latent_space(self, mu, logvar):
#         """"""
#         sigma = torch.exp(0.5 * logvar)  # stability trick.
#         z = mu +  sigma * self.distribution.sample(mu.shape).to(self.device)
#         kl_div = (sigma**2 + mu**2 - torch.log(sigma) - 0.5).sum()
#         return z, kl_div

#     def forward(self, x):
#         """"""
#         encoded = self.encode(x.view(x.size(0), -1))  # make sure its 1D.
        
#         # get mu and logvar from latent space.
#         mu = self._mu(encoded)
#         logvar = self._logvar(encoded)
        
#         # reparamaterise trick.
#         z, kl_div = self.sample_latent_space(mu, logvar)

#         decoded = self.decode(z).view(-1, self.c, self.h, self.w)
        
#         return decoded, kl_div

In [11]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = VAE(input_image_dims=(1, 256, 256),
            hidden_layers=[512, 256, 128],
            latent_dims=16,
            activation=nn.ReLU,
            device=device).to(device)

In [13]:
# Load in the model
model_path = 'VAE_1024_128_bs32_lr001_ld16.pt'

# Load the model state dictionary with map_location to handle the CPU-only environment
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Load state dictionary into the model
model.load_state_dict(torch.load(model_path))

# Print the model to confirm it's loaded correctly
print(model)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [22]:
# Load in the model
model_path = 'VAE_1024_128_bs32_lr001_ld16.pt'

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = VAE(input_image_dims=(1, 256, 256),
            hidden_layers=[1024, 128],
            latent_dims=16,
            activation=nn.ReLU,
            device=device).to(device)

# Load the model state dictionary with map_location to handle the CPU-only environment
state_dict = torch.load(model_path, map_location=device)

# # Load state dictionary into the model
# model.load_state_dict(state_dict)

model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device))["model_state_dict"]
    )

# Print the model to confirm it's loaded correctly
print(model)

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=65536, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=128, bias=True)
    (3): ReLU()
  )
  (_mu): Linear(in_features=128, out_features=16, bias=True)
  (_logvar): Linear(in_features=128, out_features=16, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1024, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=65536, bias=True)
    (5): Sigmoid()
  )
)


In [None]:
from sklearn.metrics import mean_squared_error
# Define the data_assimilation function
def best_obs_mse_image(autoencoder, obs_dataset_path, num_generated=500, latent_dim=32, device='cpu'):
    """
    Perform data assimilation using a pre-trained VAE.
 
    Args:
        autoencoder (nn.Module): The pre-trained VAE model.
        background_dataset_path (str): Path to the file containing background images.
        num_generated (int): Number of images to generate from the latent space.
        latent_dim (int): Dimension of the latent space.
        device (str): The device to run the model on ('cpu' or 'cuda').
 
    Returns:
        lowest_mse (float): The lowest MSE value found.
        best_generated_image (np.ndarray): The generated image with the lowest MSE.
        best_background_image (np.ndarray): The background image with the lowest MSE.
        best_background_index (int): The index of the background image with the lowest MSE.
    """
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
 
    # Load the background images
    obs_images = np.load(obs_dataset_path)
    obs_images = obs_images.squeeze()  # Ensure images have correct dimensions
 
    # Generate images from latent space
    z = torch.randn(num_generated, latent_dim).to(device)
    with torch.no_grad():
        generated_images = autoencoder.decoder(z).cpu().numpy()
 
    # Initialize variables to store the lowest MSE and corresponding images
    lowest_mse = float('inf')
    best_generated_image = None
    best_obs_image = None
    best_obs_index = None
 
    # Iterate through generated images and compare with background images
    for i in range(num_generated):
        generated_image = generated_images[i].squeeze()
 
        for j in range(len(obs_images)):
            obs_image = obs_images[j].squeeze()
 
            # Compute MSE
            mse = mean_squared_error(obs_image, generated_image)
 
            # Update the lowest MSE and corresponding images
            if mse < lowest_mse:
                lowest_mse = mse
                best_generated_image = generated_image
                best_obs_image = obs_image
                best_obs_index = j
 
    # Plot the best matching images
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(best_generated_image, cmap='viridis')
    ax[0].set_title("Best Generated Image")
    ax[0].axis('off')
 
    ax[1].imshow(best_obs_image, cmap='viridis')
    ax[1].set_title("Best Observation Image")
    ax[1].axis('off')
 
    plt.show()
 
    return lowest_mse, best_generated_image, best_obs_image, best_obs_index