This notebook evaluates N2N4M on an unseen test set to which synthetic noise has been added.
CoTCAT [1] is used as a baseline for comparison; both methods are scored by mean squared error (MSE) against the test targets.

1. Bultel B, Quantin C, Lozac’h L. Description of CoTCAT (Complement to CRISM Analysis Toolkit). IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing. 2015 Jun;8(6):3039–49. 

In [1]:
# Standard Imports
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import TensorDataset, DataLoader

# Internal imports
import n2n4m.preprocessing as preprocessing
from n2n4m.wavelengths import PLEBANI_WAVELENGTHS
from n2n4m.model import Noise2Noise1D
from n2n4m.model_functions import predict
from n2n4m.cotcat_denoise import cotcat_denoise
from n2n4m.n2n4m_denoise import instantiate_default_model

In [2]:
# The package root is taken to be two directory levels above the current working
# directory (assumes Jupyter was started from the notebook's own folder — TODO confirm
# if the notebook is moved).
PACKAGE_DIR = os.getcwd()
for _level in range(2):
    PACKAGE_DIR = os.path.dirname(PACKAGE_DIR)
DATA_DIR = os.path.join(PACKAGE_DIR, "data")

In [3]:
BATCH_SIZE = 1000  # If you have memory issues, reduce this number
NUM_BLAND_PIXELS = 150_000  # How many bland pixels to add to the dataset

In [4]:
def set_seed(seed):
    """
    Seed every random number generator this notebook relies on and remove cuDNN nondeterminism.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA RNGs. It then turns off the cuDNN auto-tuner, requests deterministic
    cuDNN algorithms, and disables cuDNN entirely.

    Parameters
    ----------
    seed : int
        Seed applied to every generator.

    Returns
    -------
    bool
        Always ``True``, so the cell output confirms the call ran.
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

    # benchmark=True would let cuDNN auto-tune and pick possibly non-deterministic
    # convolution algorithms, so switch the auto-tuner off.
    torch.backends.cudnn.benchmark = False
    # Only relevant if cuDNN is re-enabled later: forces deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    # Disable cuDNN altogether (slower on GPU, but removes it as a nondeterminism source).
    torch.backends.cudnn.enabled = False

    return True


set_seed(42)  # Fix seed for reproducibility.

True

Preprocessing

In [5]:
# Build the noisy test set and wrap it in a DataLoader for inference.
# NOTE(review): statement order matters here (seeded sampling and positional column
# slicing), so keep the steps in this order.

# Read the data.
mineral_dataset_path = os.path.join(
    DATA_DIR, "extracted_mineral_pixel_data", "mineral_pixel_dataset.json"
)
bland_dataset_path = os.path.join(
    DATA_DIR, "extracted_bland_pixel_data", "bland_pixel_dataset.json"
)
mineral_dataset = preprocessing.load_dataset(mineral_dataset_path)
bland_dataset = preprocessing.load_dataset(bland_dataset_path)

# Get as many bland pixels from the bland pixel set as desired.
# Sample equally from each image of bland pixels.
num_bland_images = bland_dataset["Image_Name"].nunique()
# NOTE(review): integer division — if NUM_BLAND_PIXELS < num_bland_images this is 0
# and no bland pixels are sampled at all.
samples_per_image = NUM_BLAND_PIXELS // num_bland_images
# Images with fewer than samples_per_image pixels contribute all of their pixels;
# random_state=42 makes the per-image draw reproducible.
bland_dataset_sample = (
    bland_dataset.groupby("Image_Name")
    .apply(lambda x: x.sample(min(len(x), samples_per_image), random_state=42))
    .reset_index(drop=True)
)

# Combine the bland and mineral datasets, then apply all preprocessing steps.
dataset = pd.concat(
    [mineral_dataset, bland_dataset_sample], ignore_index=True
).reset_index(drop=True)
dataset = preprocessing.expand_dataset(dataset)
dataset = preprocessing.drop_bad_bands(dataset, bands_to_keep=PLEBANI_WAVELENGTHS)
dataset = preprocessing.impute_bad_values(dataset, threshold=1)
dataset = preprocessing.impute_atmospheric_artefacts(
    dataset, wavelengths=PLEBANI_WAVELENGTHS
)
# NOTE(review): iloc[:, 3:] assumes the first 3 columns are non-spectral (ancillary)
# columns — confirm against the output of expand_dataset.
noise_dataset = preprocessing.generate_noisy_pixels(dataset.iloc[:, 3:], random_seed=42)
# Place the original columns and the generated noisy columns side by side, row-aligned.
input_target_dataset = pd.concat([dataset, noise_dataset], axis=1)
train_set, test_set = preprocessing.train_test_split(
    input_target_dataset, bland_pixels=True
)
# Carve a validation split out of the training split (train_set is overwritten).
train_set, validation_set = preprocessing.train_validation_split(
    train_set, bland_pixels=True
)

# Split the training, validation, and testing sets.
X_train, y_train, ancillary_train = preprocessing.split_features_targets_anciliary(
    train_set
)
X_test, y_test, ancillary_test = preprocessing.split_features_targets_anciliary(
    test_set
)
X_validation, y_validation, ancillary_validation = (
    preprocessing.split_features_targets_anciliary(validation_set)
)

# Fit a scaler to the training data, and then apply it to the validation and test data.
# Fitting on training features only keeps test statistics out of the scaling.
# Only the features (X_*) are scaled; the targets (y_*) keep their original values.
X_train, feature_scaler = preprocessing.standardise(X_train, method="RobustScaler")
X_test, _ = preprocessing.standardise(
    X_test, method="RobustScaler", scaler=feature_scaler
)
X_validation, _ = preprocessing.standardise(
    X_validation, method="RobustScaler", scaler=feature_scaler
)

# float32 tensors for the model.
X_test_tensor = torch.from_numpy(X_test.values).float()
y_test_tensor = torch.from_numpy(y_test.values).float()

# shuffle=False keeps batches in row order, so predictions stay aligned with y_test.
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Load model

In [28]:
# Load the trained weights onto the CPU. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs. It also avoids
# executing arbitrary pickled code from the weights file.
state_dict = torch.load(
    os.path.join(PACKAGE_DIR, "n2n4m", "data", "trained_model_weights.pt"),
    map_location=torch.device("cpu"),
    weights_only=True,
)
# If the model was trained on multiple GPUs (e.g. via DataParallel), every key carries
# a leading "module." prefix. We run inference on a bare model on the CPU, so strip
# that prefix. Only the prefix is removed, never "module." occurring inside a key.
# The dict is rebuilt only when needed, so any metadata on the loaded object is kept
# otherwise.
MODULE_PREFIX = "module."
if any(key.startswith(MODULE_PREFIX) for key in state_dict):
    state_dict = {
        (key[len(MODULE_PREFIX):] if key.startswith(MODULE_PREFIX) else key): value
        for key, value in state_dict.items()
    }
# NOTE(review): num_input_features=350 must match the number of retained bands
# (len(PLEBANI_WAVELENGTHS)) used to build X_test.
model = Noise2Noise1D(kernel_size=5, depth=3, num_blocks=4, num_input_features=350)
# Inference only: put any dropout/normalisation layers into evaluation mode
# (harmless if predict() also does this).
model.eval()
model.load_state_dict(state_dict)

<All keys matched successfully>

Run Inference

In [29]:
# Denoise the test inputs with N2N4M on the CPU.
# NOTE(review): assumes predict() concatenates batches in loader order, so rows stay
# aligned with y_test (test_loader was built with shuffle=False).
N2N4M_test_set_predictions = predict(model, test_loader, device="cpu")

In [30]:
loss_func = nn.MSELoss()
# Score against the (unscaled) test targets. The predictions (presumably float32) are
# cast explicitly to the target dtype rather than relying on implicit dtype promotion
# inside MSELoss. This keeps the precision explicit and comparable with the CoTCAT loss.
y_test_target = torch.from_numpy(y_test.values)
# NOTE: the name N2NHD_test_loss is kept for compatibility; it holds the N2N4M loss.
N2NHD_test_loss = loss_func(
    N2N4M_test_set_predictions.to(y_test_target.dtype), y_test_target
).item()
print(f"MSE for N2N4M on test set: {N2NHD_test_loss}")

MSE for N2N4M on test set: 4.664849076300921e-06


CoTCAT performance

In [32]:
# cotcat_denoise expects a 3D array of shape (NUM_COTCAT_IMAGES, pixels, bands).
# NOTE(review): 21 is presumably the number of images in the test split — confirm. If
# cotcat_denoise uses spatial neighbours, grouping arbitrary test rows into pseudo-images
# like this may mix pixels from different scenes.
NUM_COTCAT_IMAGES = 21
num_bands = len(PLEBANI_WAVELENGTHS)

# Undo the RobustScaler standardisation to get back to the original scale.
X_test_unstandardised = feature_scaler.inverse_transform(X_test)
if X_test_unstandardised.size % (NUM_COTCAT_IMAGES * num_bands) != 0:
    raise ValueError(
        f"Cannot reshape {X_test_unstandardised.shape} test pixels into "
        f"{NUM_COTCAT_IMAGES} images of {num_bands} bands; adjust NUM_COTCAT_IMAGES."
    )
X_test_unstandardised = X_test_unstandardised.reshape(
    NUM_COTCAT_IMAGES, -1, num_bands
)  # Reshape to be 3D for the cotcat_denoise function

In [33]:
# Run the CoTCAT baseline on the unscaled, 3D-reshaped test inputs.
# NOTE(review): assumes cotcat_denoise returns an array with the same total size as its
# input, because the next step flattens it back to one row per pixel.
cotcat_test_set_predictions = cotcat_denoise(
    X_test_unstandardised, wavelengths=PLEBANI_WAVELENGTHS
)

In [34]:
# Flatten CoTCAT's 3D output back to one row per pixel so it lines up with y_test.
# Use the band count (as was done when reshaping the CoTCAT input) instead of a
# hard-coded 350.
cotcat_test_set_predictions = cotcat_test_set_predictions.reshape(
    -1, len(PLEBANI_WAVELENGTHS)
)  # Reshape to be 2D for the loss function
cotcat_target = torch.from_numpy(y_test.values)
# Cast explicitly to the target dtype so the comparison does not depend on implicit
# dtype promotion inside MSELoss.
cotcat_test_loss = loss_func(
    torch.from_numpy(cotcat_test_set_predictions).to(cotcat_target.dtype), cotcat_target
).item()
print(f"MSE for cotcat on test set: {cotcat_test_loss}")

MSE for cotcat on test set: 4.951200574360762e-06
