In [None]:
import os
from pprint import pprint
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import img2pdf

print("Importing torch ...")
begin = time.time()
import torch
print("Imported torch in {:.2f} seconds".format(time.time() - begin))

print("Importing torchvision ...")
begin = time.time()
from torchvision import transforms
print("Imported torchvision in {:.2f} seconds".format(time.time() - begin))

# from networks.static_img_primal_dual_nn import StaticImgPrimalDualNN
from metrics import metrics
from data.turtle_data_loading import get_file_paths, TurtleDataset

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first tensor allocation.
default_device = "cuda" if torch.cuda.is_available() else "cpu"
if default_device != "cuda":
    print("CUDA is not available; falling back to CPU.")

# Tensors created without an explicit device are now placed on default_device.
torch.set_default_device(default_device)

In [None]:
from turtle_id_test_config import data_path, upper_level

from turtle_id_test_config import test_reg_map_path, test_scalar_reg_path

# Checkpoint to evaluate: epoch 30 of one training run stored under
# <upper_level>/models. The adjacent literals are concatenated by the parser,
# so the resulting path is a single string.
model_path = (
    f"{upper_level}/models/"
    "model-turtle_id_2022-2024_06_21_14_31_50/"
    "model_epoch_30.pt"
)

In [None]:
# Number of test files expected per sigma (the checks below compare the length
# of every per-sigma file list against this value).
num_samples = 100 # 0 for all
# Forwarded to get_file_paths — presumably the image side length in pixels; TODO confirm.
size = 256
# One entry per key of file_paths — presumably the noise standard deviations
# of the noisy test images; TODO confirm against the data generation code.
sigmas = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

In [None]:
# Collect the files of the "test" split for the configured sigmas.
# The result maps each sigma to its list of file paths.
file_paths: dict = get_file_paths(
    data_path,
    "test",
    num_samples,
    sigmas,
    size,
)

In [None]:
# file_paths is a dict with keys the sigmas and values the list of file paths.
# Every list must have the same length. When an explicit num_samples was
# requested that length must also equal it; num_samples == 0 means "all", so
# in that case there is no fixed count and only the lists are compared.
path_counts = {sigma: len(paths) for sigma, paths in file_paths.items()}
if num_samples:
    for sigma, count in path_counts.items():
        assert count == num_samples, (
            f"sigma={sigma}: found {count} files, expected {num_samples}"
        )
else:
    assert len(set(path_counts.values())) <= 1, (
        f"Unequal number of files per sigma: {path_counts}"
    )

In [None]:
# pprint(file_paths)

In [None]:
# Dataset built from the collected file lists; indexing it yields a
# (noisy, clean) pair. default_device is handed through to the dataset.
test_dataset = TurtleDataset(
    data_path,
    file_paths,
    default_device,
)

In [None]:
# The dataset must hold one item per (sigma, file) pair.
if num_samples:
    assert len(test_dataset) == num_samples * len(sigmas), f"len(test_dataset)={len(test_dataset)} != {num_samples} * {len(sigmas)}"
else:
    # num_samples == 0 means "all": the product above would always be 0, so
    # compare against the number of files that were actually collected.
    total_files = sum(len(paths) for paths in file_paths.values())
    assert len(test_dataset) == total_files, f"len(test_dataset)={len(test_dataset)} != {total_files}"

In [None]:
# The checkpoint holds a whole pickled model object (eval() is called on it
# directly), not just a state_dict:
#   - map_location remaps tensors saved on another device onto default_device;
#   - weights_only=False is needed for full-object checkpoints, because newer
#     PyTorch releases default to weights_only=True and refuse to unpickle them.
# NOTE(review): unpickling can execute arbitrary code — only load checkpoints
# from a trusted source.
model = torch.load(model_path, map_location=default_device, weights_only=False)
# Switch to evaluation mode before inference.
model.eval()

In [None]:
def save_npy(tensor, path):
    """Store a torch tensor on disk in NumPy's ``.npy`` format.

    Args:
        tensor: Tensor to save. It is detached from the autograd graph and
            copied to the CPU before the conversion to a NumPy array.
        path: Destination handed to ``np.save``.
    """
    array = tensor.detach().cpu().numpy()
    np.save(path, array)

In [None]:
def _save_png_and_pdf(image_5d, result_path, stem):
    """Write ``image_5d`` as ``{result_path}/{stem}.png`` and convert that PNG to ``{stem}.pdf``."""
    # Squeeze the last axis and then the first two axes (each squeeze is a
    # no-op when the axis is not of size 1) before handing the tensor to PIL.
    image = image_5d.squeeze(-1).squeeze(0).squeeze(0).detach().cpu()
    png_path = f"{result_path}/{stem}.png"
    transforms.ToPILImage()(image).save(png_path)
    with open(f"{result_path}/{stem}.pdf", "wb") as f:
        f.write(img2pdf.convert(png_path))


def infer_reg_map(file, i):
    """Denoise test sample ``i`` and store all artefacts under ``{test_reg_map_path}/{file}``.

    Files written into the result directory (created if missing):
        metrics.csv   - MSE, PSNR and SSIM of the denoised image vs. the clean one
        reg_map.npy   - output of ``model.cnn`` for the noisy input
        denoised.npy  - output of ``model`` for the noisy input
        {denoised,noisy,clean}.png / .pdf - the three images

    Args:
        file: Result sub-directory relative to ``test_reg_map_path``.
        i: Index of the sample in ``test_dataset``.
    """
    result_path = f"{test_reg_map_path}/{file}"
    os.makedirs(result_path, exist_ok=True)

    noisy_4d, clean_4d = test_dataset[i]
    # Add a leading axis of size 1 before feeding the model.
    noisy_5d = noisy_4d.unsqueeze(0)
    clean_5d = clean_4d.unsqueeze(0)

    # Pure inference: without no_grad each forward pass records an autograd
    # graph that is never used (every result is detached before it is saved).
    with torch.no_grad():
        reg_map = model.cnn(noisy_5d)
        # NOTE(review): model(noisy_5d) presumably computes the regularisation
        # map again internally; the separate model.cnn call above is only used
        # to export the map — confirm against the model definition.
        denoised_5d = model(noisy_5d)
    denoised_4d = denoised_5d.squeeze(0)

    mse, psnr, ssim = metrics.compare(clean_4d, denoised_4d)
    with open(f"{result_path}/metrics.csv", "w") as f:
        f.write(f"MSE,PSNR,SSIM\n{mse},{psnr},{ssim}")

    save_npy(reg_map, f"{result_path}/reg_map.npy")
    save_npy(denoised_5d, f"{result_path}/denoised.npy")

    for stem, image_5d in (
        ("denoised", denoised_5d),
        ("noisy", noisy_5d),
        ("clean", clean_5d),
    ):
        _save_png_and_pdf(image_5d, result_path, stem)

In [None]:
def check_scalar_reg(file):
    """Load the scalar-regularisation results stored for one sample.

    Args:
        file: Sub-directory of ``test_scalar_reg_path`` that contains
            ``results.csv``.

    Returns:
        pandas.DataFrame: Contents of ``{test_scalar_reg_path}/{file}/results.csv``.

    Raises:
        FileNotFoundError: If the results file does not exist.
    """
    scalar_reg_path = f"{test_scalar_reg_path}/{file}"
    results_csv = f"{scalar_reg_path}/results.csv"
    # Hand the table back to the caller instead of dropping it.
    return pd.read_csv(results_csv)
    
    

In [None]:
def test_reg_map():
    for i in tqdm(range(len(test_dataset))):
        
        sample_id = i % num_samples
        sigma = sigmas[i // num_samples]
        file = file_paths[sigma][sample_id]
        extension = file.split(".")[-1]
        file = file.replace(f".{extension}", "")
        
        
        infer_reg_map(file, i)
        
        # check_scalar_reg(file)        
        

In [None]:
# Run inference over the whole test dataset; each sample's results are written
# to its own directory under test_reg_map_path.
test_reg_map()

In [None]:
def compare_to_scalar_reg():
    for i in tqdm(range(len(test_dataset))):
        
        sample_id = i % num_samples
        sigma = sigmas[i // num_samples]
        file = file_paths[sigma][sample_id]
        extension = file.split(".")[-1]
        file = file.replace(f".{extension}", "")
        