In [14]:
import os
from pprint import pprint
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

print("Importing torch ...")
begin = time.time()
import torch
print("Imported torch in {:.2f} seconds".format(time.time() - begin))

print("Importing torchvision ...")
begin = time.time()
from torchvision import transforms
print("Imported torchvision in {:.2f} seconds".format(time.time() - begin))

# from networks.static_img_primal_dual_nn import StaticImgPrimalDualNN
from metrics import metrics
from data.turtle_data_loading import get_file_paths, TurtleDataset

Importing torch ...
Imported torch in 0.00 seconds
Importing torchvision ...
Imported torchvision in 0.00 seconds


In [26]:
upper_level = ".."

data_path=f"{upper_level}/../data/turtle_id_2022/turtles-data/data"

# test_reg_map_path = "../test_results/test_reg_map"
test_reg_map_path = "test_reg_map"
os.makedirs(test_reg_map_path, exist_ok=True)

test_scalar_reg_path = "../test_results/test_scalar_reg"

# "T": 128, "train_num_samples": 500, "val_num_samples": 50, "resize_square": 256, "sigmas": "[0.05, 0.1, 0.15, 0.2, 0.25, 0.3]", "init_filters": 32, "n_blocks": 3,
model_path=f"{upper_level}/tmp_2/model-turtle_id_2022-2024_06_21_14_31_50/model_epoch_30.pt"

size = 256
num_samples = 5 # 0 for all

sigmas = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

default_device = "cuda"
torch.set_default_device(default_device)

In [16]:
file_paths:dict = get_file_paths(data_path, "test", num_samples, sigmas, size)

Loading original image paths in images_crop_resize_256_greyscale 

100%|██████████| 5/5 [00:00<00:00, 94466.31it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_256_greyscale_noisy_0_05 

100%|██████████| 5/5 [00:00<00:00, 95760.37it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_256_greyscale_noisy_0_1 

100%|██████████| 5/5 [00:00<00:00, 33989.50it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_256_greyscale_noisy_0_15 

100%|██████████| 5/5 [00:00<00:00, 100342.20it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_256_greyscale_noisy_0_2 

100%|██████████| 5/5 [00:00<00:00, 98457.84it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_256_greyscale_noisy_0_25 

100%|██████████| 5/5 [00:00<00:00, 103307.98it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_256_greyscale_noisy_0_3 

100%|██████████| 5/5 [00:00<00:00, 43965.45it/s]


In [17]:
# file_paths is a dict with keys the sigmas and values the list of file paths
# Check that every list has the same length
for key in file_paths.keys():
    assert len(file_paths[key]) == num_samples

In [18]:
# pprint(file_paths)

In [19]:
test_dataset = TurtleDataset(data_path, file_paths, default_device)

Loading original images 

100%|██████████| 5/5 [00:00<00:00, 93.78it/s]


Loading noisy images sigma=0.05 

100%|██████████| 5/5 [00:00<00:00, 147.44it/s]


Loading noisy images sigma=0.1 

100%|██████████| 5/5 [00:00<00:00, 131.32it/s]


Loading noisy images sigma=0.15 

100%|██████████| 5/5 [00:00<00:00, 133.53it/s]


Loading noisy images sigma=0.2 

100%|██████████| 5/5 [00:00<00:00, 138.85it/s]


Loading noisy images sigma=0.25 

100%|██████████| 5/5 [00:00<00:00, 144.84it/s]


Loading noisy images sigma=0.3 

100%|██████████| 5/5 [00:00<00:00, 161.94it/s]


In [20]:
assert len(test_dataset) == num_samples * len(sigmas), f"len(test_dataset)={len(test_dataset)} != {num_samples} * {len(sigmas)}"

In [21]:
model = torch.load(model_path)
model.eval()

StaticImagePrimalDualNN(
  (pdhg): PDHG(
    (GradOps): GradOperators()
    (ClipAct): ClipAct()
  )
  (cnn): UNet3d(
    (c0x0): DoubleConv(
      (conv_block): Sequential(
        (0): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
        (2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
        (3): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (encoder): ModuleList(
      (0): EncodeBlock3d(
        (pool): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=(0, 0, 0), dilation=1, ceil_mode=False)
        (double_conv): DoubleConv(
          (conv_block): Sequential(
            (0): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
            (1): LeakyReLU(negative_slope=0.01, inplace=True)
            (2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
            (3): LeakyReLU(negative_

In [22]:
def save_npy(tensor, path):
    np.save(path, tensor.detach().cpu().numpy())

In [23]:
def test_reg_map():
    for i in tqdm(range(len(test_dataset))):
        
        sample_id = i % num_samples
        sigma = sigmas[i // num_samples]
        file = file_paths[sigma][sample_id]
        extension = file.split(".")[-1]
        file = file.replace(f".{extension}", "")
        # print(f"Sigma: {sigma}")
        # sigma_str = f"{sigma:.2f}".replace(".", "_")
        result_path = f"{test_reg_map_path}/{file}"
        
        if os.path.exists(result_path):
            # print(f"Path {result_path} already exists. Skipping ...")
            continue
        
        os.makedirs(result_path, exist_ok=True)
        noisy_4d, clean_4d = test_dataset[i]
        noisy_5d = noisy_4d.unsqueeze(0)
        clean_5d = clean_4d.unsqueeze(0)
        reg_map = model.cnn(noisy_5d)
        # # reg_map = model.cnn(noisy_4d)
        # print(f"Shape of noisy image: {noisy_5d.shape}")
        # denoised_5d = model(noisy_5d, reg_map)
        # # denoised_5d = model(noisy_4d, reg_map)
        
        denoised_5d = model(noisy_5d)
        denoised_4d = denoised_5d.squeeze(0)
        
        mse, psnr, ssim = metrics.compare(clean_4d, denoised_4d)
        with open(f"{result_path}/metrics.csv", "w") as f:
            f.write(f"MSE,PSNR,SSIM\n{mse},{psnr},{ssim}")
        # print(f"MSE: {mse}, PSNR: {psnr}, SSIM: {ssim}")
        
        reg_map_path = f"{result_path}/reg_map.npy"
        save_npy(reg_map, reg_map_path)
        denoised_path = f"{result_path}/denoised.npy"
        save_npy(denoised_5d, denoised_path)
        denoised_PIL = transforms.ToPILImage()(denoised_5d.squeeze(-1).squeeze(0).squeeze(0).detach().cpu())
        denoised_PIL.save(f"{result_path}/denoised.png")
        
        noisy_PIL = transforms.ToPILImage()(noisy_5d.squeeze(-1).squeeze(0).squeeze(0).detach().cpu())
        noisy_PIL.save(f"{result_path}/noisy.png")
        
        clean_PIL = transforms.ToPILImage()(clean_5d.squeeze(-1).squeeze(0).squeeze(0).detach().cpu())
        clean_PIL.save(f"{result_path}/clean.png")
        
        # plt.figure(figsize=(15, 5))
        
        # plt.subplot(1, 3, 1)
        # plt.imshow(noisy_PIL, cmap="gray")
        
        # plt.subplot(1, 3, 2)
        # plt.imshow(denoised_PIL, cmap="gray")
        
        # plt.subplot(1, 3, 3)
        # plt.imshow(clean_PIL, cmap="gray")
        
        # plt.show();
    

In [24]:
test_reg_map()

  0%|          | 0/30 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: '../test_results/test_reg_map'

In [None]:
def compare_to_scalar_reg():
    for i in tqdm(range(len(test_dataset))):
        
        sample_id = i % num_samples
        sigma = sigmas[i // num_samples]
        file = file_paths[sigma][sample_id]
        extension = file.split(".")[-1]
        file = file.replace(f".{extension}", "")
        

SyntaxError: invalid syntax (1716585453.py, line 10)