In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

from models import ConvAutoencoder
from utils import Config, save_output, weighted_mse
from utils import Config, save_output, weighted_mse

In [2]:
# Initialise config: one Config instance provides every path and hyperparameter
# (input path, split size, batch size, learning rate, ...) read by later cells.
config: Config = Config()

In [3]:
# Preprocess data: min-max scale the whole array to [0, 1] using its global
# min/max, split it into train/test sets, and wrap both in DataLoaders.
raw_data = np.load(config.input_path)

# Vectorised scaling. The previous per-sample list comprehension recomputed the
# global min/max once per sample (quadratic in the number of samples) and built
# a list of ndarrays, which torch.Tensor converts very slowly.
# NOTE(review): min/max are taken over the full dataset, including the rows that
# end up in the test split; consider computing them on the training split only.
data_min = np.min(raw_data)
data_max = np.max(raw_data)
data_range = data_max - data_min
if data_range == 0:
    # Dividing by zero here would silently turn every sample into NaN.
    raise ValueError(
        f"Cannot min-max normalise {config.input_path}: every value equals {data_min}."
    )
input_data = (raw_data - data_min) / data_range
del raw_data  # only the normalised copy is needed from here on

# random_state comes from config.seed when the Config defines it; otherwise the
# split stays unseeded (not reproducible), exactly as before.
train_array, test_array = train_test_split(
    input_data,
    test_size=config.test_size,
    shuffle=True,
    random_state=getattr(config, "seed", None),
)
train_tensor = torch.as_tensor(train_array, dtype=torch.float32)
test_tensor = torch.as_tensor(test_array, dtype=torch.float32)

train_ds = torch.utils.data.TensorDataset(train_tensor)
test_ds = torch.utils.data.TensorDataset(test_tensor)
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)


In [4]:
# Initialise model
# Initialise model: use the GPU when one is visible, otherwise fall back to CPU,
# then build the autoencoder on that device and attach an RMSprop optimizer.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)

model = ConvAutoencoder(config.z_dim).to(device)  # nn.Module.to returns the module itself
optimizer = torch.optim.RMSprop(model.parameters(), lr=config.lr)

Device:  cuda


In [None]:
# Train the autoencoder (mixed precision on CUDA) and record, for every epoch,
# the mean per-batch training loss and mean per-batch test loss.
IMAGE_HEIGHT, IMAGE_WIDTH = 534, 1200  # spatial size each sample is reshaped to

train_loss_values = []
test_loss_values = []

# Autocast/GradScaler are CUDA AMP tools; enable them only when training on CUDA
# so a CPU run uses plain float32 with an unscaled backward pass.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(config.n_epochs):
    print(f"Epoch {epoch + 1}")
    running_train_loss = 0.0
    running_test_loss = 0.0

    model.train()
    optimizer.zero_grad()

    for (inputs,) in train_dl:
        inputs = inputs.view(-1, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = weighted_mse(inputs, outputs, config.weighting_parameter)

        # Scale the loss before backward to avoid float16 gradient underflow;
        # scaler.step unscales the gradients before calling optimizer.step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        running_train_loss += loss.item()

    model.eval()
    with torch.no_grad():
        for (inputs,) in test_dl:
            inputs = inputs.view(-1, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = weighted_mse(inputs, outputs, config.weighting_parameter)

            running_test_loss += loss.item()

    # Mean loss per batch. Divide by the batch count (len(loader)), not the last
    # enumerate index: the old code was off by one, inflating the averages and
    # raising ZeroDivisionError whenever a loader produced a single batch.
    train_loss_values.append(running_train_loss / len(train_dl))
    test_loss_values.append(running_test_loss / len(test_dl))

In [None]:
# Persist the trained model using the project helper utils.save_output.
# NOTE(review): the destination and file format are defined in utils.save_output
# (presumably somewhere under config.output_directory) — confirm there.
save_output(config, model)

In [None]:
# Plot the per-epoch mean train/test loss, save the figure to the configured
# output directory, and report the last recorded values.
# Epochs are numbered from 1 to match the "Epoch N" lines printed during training.
train_epochs = range(1, len(train_loss_values) + 1)
test_epochs = range(1, len(test_loss_values) + 1)

fig, ax = plt.subplots()
ax.plot(train_epochs, train_loss_values, label="Training loss")
ax.plot(test_epochs, test_loss_values, label="Test loss")
ax.set_title(f"lr = {config.lr}, bs = {config.batch_size}, num_epochs = {config.n_epochs}")
ax.set_ylabel("Loss")
ax.set_xlabel("Epochs")
ax.legend()
# os.path.join works whether or not output_directory ends with a separator;
# plain string concatenation silently wrote to a sibling path when it did not.
fig.savefig(os.path.join(config.output_directory, "loss_plot.png"))
plt.show()

# Guard against an empty history (e.g. n_epochs == 0), which would raise IndexError.
if train_loss_values:
    print(f"Final training loss: {train_loss_values[-1]}")
if test_loss_values:
    print(f"Final test loss: {test_loss_values[-1]}")

In [None]:
# Visual sanity check: run one training sample through the autoencoder and show
# the reconstruction as an image.
IMAGE_HEIGHT, IMAGE_WIDTH = 534, 1200  # spatial size each sample is reshaped to
EXAMPLE_INDEX = 1  # which training sample to reconstruct

# Enter inference mode explicitly instead of relying on the training cell having
# finished with model.eval() (it would not have if it was interrupted mid-epoch).
model.eval()
with torch.no_grad():
    input_tensor = train_tensor[EXAMPLE_INDEX].view(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)
    # The model already lives on `device`, and no_grad means there is no graph to
    # detach, so the output only needs reshaping and moving to host memory.
    example_output = model(input_tensor).reshape(IMAGE_HEIGHT, IMAGE_WIDTH).cpu().numpy()

fig, ax = plt.subplots()
image = ax.imshow(example_output, cmap="jet")
ax.set_title(f"Reconstruction of training sample {EXAMPLE_INDEX}")
fig.colorbar(image, ax=ax)
plt.show()