# Evaluating the performance of the deinterlacing network

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.models import initialize_model_from_config
from src.datasets import initialize_dataset_from_config
from src.utils.common import remove_axis

checkpoint_dir_path = r'pretrained_checkpoint'

# Load the model checkpoint.
# os.path.join replaces the original '\model_config.yaml' concatenation, whose
# '\m' / '\d' are invalid string escapes and whose separator is Windows-only.
model = initialize_model_from_config(
    os.path.join(checkpoint_dir_path, 'model_config.yaml')
)
# Evaluation only: disable training-time behaviour (e.g. dropout) when the
# model is a torch module.
if isinstance(model, torch.nn.Module):
    model.eval()

# Initialize the dataset.
# NOTE(review): this evaluates on the training split (train=True); a held-out
# split would give a more representative score — confirm one is available.
dataset = initialize_dataset_from_config(
    os.path.join(checkpoint_dir_path, 'dataset_config.yaml'),
    train=True,
)
# Construct a dataloader iterator
dl = dataset.get_dataloader(batch_size=1, num_workers=0)
it = iter(dl)

# The number of samples to process and display
N = 3
# Region (rows, columns) that is displayed and over which the MSE is computed.
# NOTE(review): assumes images are at least 200x300 pixels; a smaller image
# gives an empty crop and therefore a NaN MSE.
crop_rows = slice(100, 200)
crop_cols = slice(100, 300)

# Number of colour channels per field; identical for every sample, so it is
# read once instead of on every loop iteration.
num_channels = model.config.num_color_channels

# squeeze=False keeps `axes` two-dimensional even when N == 1, so the
# axes[n, col] indexing below always works.
fig, axes = plt.subplots(N, 3, figsize=(14, 3*N), squeeze=False)
for n in range(N):
    # Get sample from dataloader
    fields, gt = next(it)

    # Process using the network; no_grad avoids building an autograd graph
    # that is never used for backpropagation here.
    with torch.no_grad():
        network = model.forward(fields).detach()

    # Clamp the estimate to [0, 1] in place (presumably the normalized pixel
    # range — confirm against the dataset's preprocessing).
    torch.clamp_(network, 0, 1)

    start_index, end_index = model.get_middle_field_channel_indices(
        fields,
        num_channels
    )

    # Baseline: plain linear interpolation of the middle field's rows
    interp = model._linear_row_interpolation(fields[:, start_index:end_index])

    # Construct complete images from known and estimated fields, then crop
    # all three to the same region of interest.
    im_true = model.sample_to_im((fields, gt), num_channels)[crop_rows, crop_cols, :]
    im_network = model.sample_to_im((fields, network), num_channels)[crop_rows, crop_cols, :]
    im_interp = model.sample_to_im((fields, interp), num_channels)[crop_rows, crop_cols, :]

    # Compute the MSE of each reconstruction against the ground truth,
    # over the cropped region only.
    error_network = np.mean((im_true-im_network)**2)
    error_interpolation = np.mean((im_true-im_interp)**2)

    # Plot the results; the title of the method with the lower error is bold
    # (neither is bold on an exact tie).
    axes[n, 0].imshow(im_network)
    fontweight = 'bold' if error_interpolation > error_network else 'normal'
    axes[n, 0].set_title(f'network mse: {error_network:.6e}', fontweight=fontweight)

    axes[n, 1].imshow(im_interp)
    fontweight = 'bold' if error_interpolation < error_network else 'normal'
    axes[n, 1].set_title(f'linear interpolation mse: {error_interpolation:.6e}', fontweight=fontweight)

    axes[n, 2].imshow(im_true)
    axes[n, 2].set_title('true')

# Clean up the plot
remove_axis(axes)
plt.tight_layout()
# Render the figure when run as a plain script; harmless inside a notebook.
plt.show()