## Objective 1: Build a SURROGATE MODEL using RNN

### Modules

In [1]:
import torch
import numpy as np
from model import models
from model import dataloader
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error

#### Device for training (cpu, cuda or mps)

In [None]:
# Choose the compute backend: CUDA when available, then MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device: ", device)

### Define the encoder model, the predictor model and the loss function you want to use.

 - Encoders available: 
   - 'PCA'
   - 'Linear'
   - 'CVAE'
   - 'CAE1'
   - 'CAE2' or 'best'


 - Predictors available:
   - 'LSTM0'
   - 'LSTM1'


 - Loss functions available:
   - 'MSE'
   - 'L1Loss'
   - 'BCE'
   - 'MSSSIM'

In [None]:
# Pipeline configuration: one name for each stage of the surrogate model.
loss_type = 'BCE'           # loss function handed to models.load_loss_function
predictor_model = 'best'    # latent-space predictor
encoder_model = 'PCA'       # compression model handed to models.load_encoder

#### Path to the data

In [None]:
# Locations of the .npy arrays used in this notebook (all under data/).
background_data_path = 'data/Ferguson_fire_background.npy'
obs_data_path = 'data/Ferguson_fire_obs.npy'
test_data_path = 'data/Ferguson_fire_test.npy'
train_data_path = 'data/Ferguson_fire_train.npy'

#### Load Encoder

In [None]:
# Build the encoder named by `encoder_model`; the result is either a
# scikit-learn PCA object or another model (later cells branch on isinstance(..., PCA)).
model_encoder = models.load_encoder(encoder_model, device=device)
# Loss named by `loss_type`, used by the optional encoder training cell.
# NOTE: `loss_function` is rebound to the 'MSE' loss before the predictor is trained.
loss_function = models.load_loss_function(loss_type)

#### Encoder architecture

In [None]:
# Summarise the encoder: total explained variance (in %) for PCA,
# the model's own describe() output for everything else.
if not isinstance(model_encoder, PCA):
    model_encoder.describe()
else:
    print(np.cumsum(model_encoder.explained_variance_ratio_)[-1]*100, "%")

#### Load train data

In [None]:
# Training set, prepared by the project dataloader for the chosen encoder.
train_data = dataloader.load_data(
    dataset_path=train_data_path,
    model=model_encoder,
    batch_size=32,
    shuffle=True,
)

#### Train loop (Optional)

In [None]:
# Optional encoder training (torch.nn encoders only) - uncomment to run:
# losses = models.train_model(epochs=25, model=model_encoder, data_loader=train_data, learning_rate=0.001, loss_function=loss_function, device=device)

# Report the PCA explained variance (in %), or put a torch encoder into eval mode.
if not isinstance(model_encoder, PCA):
    model_encoder.eval()
else:
    print(np.cumsum(model_encoder.explained_variance_ratio_)[-1]*100, "%")

#### Visualise the losses

In [None]:
# plt.plot(losses)
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Training Loss')
# plt.show()

#### Autoencoder results

#### Load test data 

In [None]:
# Raw test frames as a NumPy array ...
test_data_array = np.load(test_data_path)
# ... and the same file prepared by the project dataloader (order preserved).
test_data = dataloader.load_data(
    dataset_path=test_data_path,
    model=model_encoder,
    batch_size=10,
    shuffle=False,
)

#### Show reconstructed

In [None]:
if not isinstance(model_encoder, PCA):
    # Torch encoder: reconstruct the inputs of the first test batch.
    test_loader = iter(test_data)
    test_batch = next(test_loader)[0]

    with torch.no_grad():
        outputs = model_encoder(test_batch.to(device)).cpu().numpy()
        test_batch = test_batch.cpu().numpy()
else:
    # PCA: flatten the first five frames, project them and map them back.
    test_batch = test_data[:5].reshape(-1, 256*256)
    outputs = model_encoder.inverse_transform(model_encoder.transform(test_batch))
    outputs = outputs.reshape(-1, 256, 256)
    test_batch = test_batch.reshape(-1, 256, 256)
    print(outputs.shape)

# Top row: original frames, bottom row: their reconstructions.
fig, ax = plt.subplots(2, 5, figsize=(20, 10))
for i in range(5):
    for row, images in enumerate((test_batch, outputs)):
        ax[row, i].imshow(images[i], cmap='magma')
plt.show()


#### Compress training data into latent space

In [None]:
if isinstance(model_encoder, PCA):
    # PCA works on flattened frames: one row per frame.
    train_latent = model_encoder.transform(train_data.reshape(train_data.shape[0], -1))
else:
    # NOTE(review): train_data was created with shuffle=True. If the predictor
    # relies on the temporal order of the latent vectors, encode from an
    # unshuffled loader instead - confirm against dataloader.load_data.
    latent_batches = []
    with torch.no_grad():
        for x, _ in train_data:
            latent_batches.append(model_encoder.encode(x.to(device)).cpu().numpy())

    if latent_batches:
        # Concatenate once instead of growing the array on every batch; the
        # latent width now comes from the encoder output rather than a fixed 64.
        # float64 matches what the previous zero-seeded concatenation produced.
        train_latent = np.concatenate(latent_batches, axis=0).astype(np.float64, copy=False)
    else:
        # An empty loader previously yielded an empty (0, 64) array; keep that.
        train_latent = np.zeros((0, 64))
    print(train_latent.shape)

#### Load predictor (LSTM) model

In [None]:
# The predictor maps latent vectors to latent vectors, so its input and
# output widths both equal the latent dimension.
input_size = outputs_size = train_latent.shape[-1]

# Honour the predictor chosen in the configuration cell (`predictor_model`)
# instead of a hard-coded name.
model_predictor = models.load_predictor(predictor_model, model_encoder=model_encoder, input_size=input_size, output_size=outputs_size, device=device)


#### Train predictor on latent space data (Optional)

In [None]:
# The predictor is trained with the 'MSE' loss (this rebinds `loss_function`).
loss_function = models.load_loss_function('MSE')

# Hyperparameters depend on the compression method: (learning rate, epochs).
learning_rate, num_epochs = (0.001, 200) if isinstance(model_encoder, PCA) else (0.0001, 600)

# Wrap the latent vectors in a dataloader for the predictor.
train_latent_loader = dataloader.load_data(
    dataset_path=train_latent,
    model=model_predictor,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Train the predictor on the latent vectors, keeping the returned losses for plotting.
losses = models.train_model(
    epochs=num_epochs,
    model=model_predictor,
    data_loader=train_latent_loader,
    learning_rate=learning_rate,
    loss_function=loss_function,
    device=device,
)
# Switch to evaluation mode for the cells that follow.
model_predictor.eval()

#### Visualize the training loss    

In [None]:
# Training loss curve of the predictor.
plt.plot(losses)
plt.gca().set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
plt.show()

#### Test prediction model on test data:

In [None]:
def test_data_performance(model_predictor, model_encoder, test_data, timesteps):
    """Evaluate the predictor on test data in latent and physical space.

    The test frames are compressed with ``model_encoder``, turned into
    input/target pairs by ``dataloader.load_data`` and passed through
    ``model_predictor``. The notebook-level ``loss_function`` is evaluated on
    every batch, once on the latent vectors and once on the decoded fields.
    Both results are averaged over the number of batches, so they do not grow
    with the size of the test set.

    Relies on the notebook globals ``device``, ``loss_function``,
    ``dataloader`` and ``PCA``.

    Args:
        model_predictor: The trained predictor model.
        model_encoder: The trained encoder model (a PCA object or a model
            exposing ``encode`` / ``decode``).
        test_data: The test data as an array of frames.
        timesteps: The number of timesteps passed to ``dataloader.load_data``.

    Returns:
        tuple: ``(mse_latent, mse_physical)``, the mean per-batch loss in the
        latent space and in the physical space.
    """
    is_pca = isinstance(model_encoder, PCA)

    if is_pca:
        test_compressed = model_encoder.transform(test_data.reshape(test_data.shape[0], -1))
    else:
        # Encode through a dataloader for memory safety.
        test_data_loader = dataloader.load_data(dataset_path=test_data, model=model_encoder, timesteps=timesteps, batch_size=100, shuffle=False)
        encoded_batches = []

        with torch.no_grad():
            for x, _ in test_data_loader:
                encoded = model_encoder.encode(x.view(-1, 1, 256, 256).to(device))
                encoded_batches.append(encoded.cpu().numpy())

        if encoded_batches:
            # Concatenate once; float64 matches the previous zero-seeded result.
            test_compressed = np.concatenate(encoded_batches, axis=0).astype(np.float64, copy=False)
        else:
            # An empty loader previously yielded an empty (0, 64) array; keep that.
            test_compressed = np.zeros((0, 64))

    # Create test inputs and targets in the same way as for the training set.
    test_compressed_loader = dataloader.load_data(dataset_path=test_compressed, model=model_predictor, timesteps=timesteps, batch_size=100, shuffle=False)

    mse_latent = 0.0
    mse_physical = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in test_compressed_loader:
            y = y.to(device)
            output = model_predictor(x.to(device))
            mse_latent += loss_function(output, y).item()

            if is_pca:
                # PCA works on NumPy arrays, so go through the CPU and back.
                outputs_inverse = torch.Tensor(model_encoder.inverse_transform(output.cpu().numpy()))
                targets_inverse = torch.Tensor(model_encoder.inverse_transform(y.cpu().numpy()))
            else:
                # decode() returns tensors on the model's device; pass them to
                # the loss directly. Re-wrapping them with torch.Tensor(...)
                # fails for tensors that are not on the CPU.
                latent_dim = output.shape[-1]
                outputs_inverse = model_encoder.decode(output.view(-1, 1, latent_dim))
                targets_inverse = model_encoder.decode(y.view(-1, 1, latent_dim))

            mse_physical += loss_function(outputs_inverse, targets_inverse).item()
            num_batches += 1

    # Report the mean over batches rather than the running sum.
    if num_batches:
        mse_latent /= num_batches
        mse_physical /= num_batches

    print(f"Test loss on latent space: {mse_latent}")
    print(f"Test loss on physical space: {mse_physical}")

    return mse_latent, mse_physical

In [None]:
# Why are we calculating MSE between compressed and then decompressed target and output from LSTM?

### Load test data

In [None]:
# Reload the raw test frames (this rebinds `test_data` to a NumPy array) and evaluate.
test_data = np.load(test_data_path)
test_data_performance(
    model_predictor=model_predictor,
    model_encoder=model_encoder,
    test_data=test_data,
    timesteps=10,
)

### Predict on background data

In [None]:
# Load the background and observation data

background_data = np.load(background_data_path)
obs_data = np.load(obs_data_path)
background_data.shape, obs_data.shape

In [None]:
def predict_on_background(model_encoder, model_predictor, background_data):
    """Predict from the background frames and decode the result to images.

    Every frame is compressed with ``model_encoder``, passed through
    ``model_predictor`` and mapped back to physical space. The whole pass runs
    under ``torch.no_grad()`` because no gradients are needed for inference.

    Relies on the notebook globals ``device`` and ``PCA``.

    Args:
        model_encoder: The trained encoder model (a PCA object or a model
            exposing ``encode`` / ``decode``).
        model_predictor: The trained predictor model.
        background_data: The background data as an array of frames.

    Returns:
        numpy.ndarray: The decoded predictions, reshaped to ``(-1, 256, 256)``.
    """
    is_pca = isinstance(model_encoder, PCA)

    with torch.no_grad():
        # Compress the frames into the latent space.
        if is_pca:
            background_compressed = model_encoder.transform(background_data.reshape(background_data.shape[0], -1))
        else:
            frames = torch.from_numpy(np.expand_dims(background_data, axis=1)).float().to(device)
            background_compressed = model_encoder.encode(frames).cpu().numpy()

        # Insert a singleton axis at position 1 before calling the predictor
        # (presumably the sequence axis - confirm against the predictor model).
        inputs = torch.Tensor(np.expand_dims(background_compressed, axis=1))

        # Forward pass on the selected device.
        outputs = model_predictor(inputs.to(device))

        # Decompress the predictions back into images.
        if is_pca:
            outputs_inverse = model_encoder.inverse_transform(outputs.cpu().numpy()).reshape(-1, 256, 256)
        else:
            outputs_inverse = model_encoder.decode(outputs).cpu().numpy().reshape(-1, 256, 256)

    return outputs_inverse

### Predict on the background data

In [None]:
#Perform prediction on background data
background_predict = predict_on_background(model_encoder=model_encoder, model_predictor=model_predictor, background_data=background_data)

# Store the background prediction to perform data assmilations
np.save('data/predictions_lstm.npy', background_predict[:-1])

In [None]:
def plot_background_obs(backgound_data, background_predict, obs_data):
    """Plot the background, prediction and observation frames in three rows.

    Row 0 shows the background frames, row 1 the predictions and row 2 the
    observations. Prediction ``i`` is drawn one column to the right, i.e.
    below background frame ``i + 1`` (presumably the frame it forecasts).

    Args:
        backgound_data: The background data; its first axis sets the number
            of frames plotted.
        background_predict: The background prediction, indexed per frame.
        obs_data: The observation data, indexed per frame.
    """
    # Same colour range for all panels.
    vmin, vmax = 0, 1

    n_frames = backgound_data.shape[0]
    # The shifted prediction row needs n_frames + 1 columns. Never use fewer
    # than the original 6 so existing figures keep their layout.
    n_cols = max(6, n_frames + 1)
    _, axes = plt.subplots(3, n_cols, figsize=(20, 10))

    cax = None
    for i in range(n_frames):
        cax = axes[0, i].imshow(backgound_data[i], vmin=vmin, vmax=vmax)
        axes[0, i].set_title(f"Background {i+1}")
        axes[1, i+1].imshow(background_predict[i], vmin=vmin, vmax=vmax)
        axes[1, i+1].set_title(f"Prediction {i+1}")
        axes[2, i].imshow(obs_data[i], vmin=vmin, vmax=vmax)
        axes[2, i].set_title(f"Observation {i+1}")

    # A shared colorbar only makes sense if at least one image was drawn.
    if cax is not None:
        plt.colorbar(cax, ax=axes.ravel().tolist())

    # Do not show axis ticks or frames, including on the unused panels.
    for ax in axes.ravel():
        ax.axis('off')

    plt.show()

####    Plot the background, prediction and observation data

In [None]:
# Compare background, prediction and observation frames side by side.
plot_background_obs(
    backgound_data=background_data,
    background_predict=background_predict,
    obs_data=obs_data,
)

#### Save results for analysis 

In [None]:
# MSE between the backgound and the their corresponding predictions
loss = loss_function(torch.Tensor(background_predict[:4]), torch.Tensor(background_data[1:]))
loss.item()

# Write results to file
with open('results_lstm.txt', 'a') as f:
    if isinstance(model_encoder, PCA):
        f.write(f"Loss on prediction on background data for PCA: {loss.item()}\n")
    else:
        f.write(f"Loss on prediction on background data for Autoencoder: {loss.item()}\n")

### What did we learn:




1. We started with a convolutional LSTM, which gave comparatively good results. However, training was very expensive on our limited hardware, so we adopted reduced-order modelling.
2. We then used an autoencoder to compress the images and an LSTM to predict in the latent space. This was quick to train. However, the error propagated at every step from one model to the next.
3. Ultimately, we used PCA since it gave the best results: the data is non-linear, but not strongly so, and PCA can represent it with a manageable number of principal components for the LSTM's latent space.
4. The advantage of this approach is that it scales with the data size (image resolution or number of images), since the LSTM predicts in the latent space.
5. We experimented with various loss functions for the encoder. In our testing, the perceptual-index loss (MSSSIM) gave good reconstructions of the data, but it did not show good results when trained for only a few epochs.
6. We explored different loss functions because most pixel values in the given dataset are 0, so however badly the model performed, the MSE loss stayed below 1. This also affected the reconstructed images: the wildfire images no longer looked like wildfire scenarios. That is why we used the MSSSIM perceptual index, so that the model can at least replicate realistic scenarios.