### Imports

In [1]:
# Guard flag for the import cell below: while True, that cell times and prints the
# torch import and then sets this to False, so re-running that cell stays quiet.
first_time_importing_torch = True

In [2]:
import os

# NOTE: Importing torch the first time will always take a long time!
import time
# Time the first torch import (only while `first_time_importing_torch` is True) and
# clear the flag afterwards, so re-running this cell does not print the timing again.
if first_time_importing_torch:
    print(f"Importing torch ...")
    import_torch_start_time = time.time() 
import torch
if first_time_importing_torch:
    import_torch_end_time = time.time()
    print(f"Importing torch took {import_torch_end_time - import_torch_start_time} seconds")
    first_time_importing_torch = False

import torch.nn as nn
from torch.utils.data import WeightedRandomSampler
import torch.nn.functional as F
from torch.utils.data import Dataset

# from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

from skimage.metrics import structural_similarity
# from skimage.metrics import peak_signal_noise_ratio

from functools import partial

import numpy as np
import matplotlib.pyplot as plt
import datetime

from PIL import Image

# Optional
from tqdm import tqdm # progress bar

import wandb # Optional, for logging

import json
import yaml

Importing torch ...
Importing torch took 9.568702697753906 seconds


In [3]:
# Show which interpreter, torch build and working directory the kernel is using
# (useful to confirm the notebook runs inside the intended virtual environment).
!which python
print(f"Torch version: {torch.__version__}")
print(f"Path: {os.getcwd()}")

/mnt/c/Users/t/Documents/GIT/DISSERTATION/LearningRegularizationParameterMaps/venv/bin/python
Torch version: 2.3.0+cu121
Path: /mnt/c/Users/t/Documents/GIT/DISSERTATION/LearningRegularizationParameterMaps/chest_xray


In [4]:
# Suppress warnings
# NOTE(review): this silences *every* warning (deprecations, numerical warnings, ...)
# for the rest of the kernel session; consider narrowing it by category or module.
import warnings
warnings.filterwarnings("ignore")

In [5]:
# When True, every test_* helper below returns immediately (no image loading or plotting).
# DISABLING_TESTS = False
DISABLING_TESTS = True   # Disable tests for less output

### Use GPU

In [6]:
def select_device() -> torch.device:
    """Pick the best available torch device: CUDA first, then Apple MPS, then CPU.

    A short message naming the chosen device is printed.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using {torch.cuda.get_device_name(0)}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        # torch.backends.mps has no get_device_name(); calling it raised AttributeError on Macs.
        print("Using Apple GPU with MPS")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device


DEVICE = select_device()

# Tensors created without an explicit device will now be allocated on DEVICE.
torch.set_default_device(DEVICE)

Using NVIDIA GeForce RTX 4090


### CONFIG

In [7]:
# Root folder of the SIDD dataset; load_images_SIDD iterates over its per-image sub-folders.
SIDD_DATA_PATH = "../data/dyn_img_static/tmp/SIDD_Small_sRGB_Only/Data"

In [8]:
# Root folder of the chest X-ray dataset; helpers below read "<root>/<stage>/<label>" folders such as train/NORMAL.
CHEST_XRAY_BASE_DATA_PATH = "../data/chest_xray"

In [9]:
def get_config(base_data_path: str = "../data/chest_xray") -> dict:
    """Return the experiment configuration (data, model, training, logging) as a new dict.

    Args:
        base_data_path: Root of the chest X-ray dataset. The default equals the
            module-level ``CHEST_XRAY_BASE_DATA_PATH``; it is a parameter rather than a
            shadowing local copy so other dataset locations can be configured.
            The train/val/test paths point to its ``<split>/NORMAL`` sub-folders.

    Returns:
        A freshly built dict, so callers may mutate it without affecting later calls.
    """
    return {
        "project": "chest_xray",
        "dataset": base_data_path,
        "train_data_path": f"{base_data_path}/train/NORMAL",
        "val_data_path": f"{base_data_path}/val/NORMAL",
        "test_data_path": f"{base_data_path}/test/NORMAL",
        "train_num_samples": 200,
        "val_num_samples": 8,
        "test_num_samples": 1,

        # "patch": 512,
        # "stride": 512,
        "resize_square": 256,
        "min_sigma": 0.1,
        "max_sigma": 0.5,
        "batch_size": 1,
        "random_seed": 42,

        "architecture": "UNET-PDHG",
        "in_channels": 1,
        "out_channels": 2,
        "init_filters": 64,
        "n_blocks": 3,
        "activation": "LeakyReLU",
        "downsampling_kernel": (2, 2, 1),
        "downsampling_mode": "max",
        "upsampling_kernel": (2, 2, 1),
        "upsampling_mode": "linear_interpolation",

        "optimizer": "Adam",
        "learning_rate": 1e-4,
        "loss_function": "MSELoss",

        # "up_bound": 0.5,
        "up_bound": 0,
        "T": 128,

        "epochs": 10_000,
        "device": "cuda:0",

        "wandb_mode": "online",
        "save_epoch_wandb": 10,
        "save_epoch_local": 2,
        "save_dir": "tmp_2",
    }

print(get_config())

{'project': 'chest_xray', 'dataset': '../data/chest_xray', 'train_data_path': '../data/chest_xray/train/NORMAL', 'val_data_path': '../data/chest_xray/val/NORMAL', 'test_data_path': '../data/chest_xray/test/NORMAL', 'train_num_samples': 200, 'val_num_samples': 8, 'test_num_samples': 1, 'resize_square': 256, 'min_sigma': 0.1, 'max_sigma': 0.5, 'batch_size': 1, 'random_seed': 42, 'architecture': 'UNET-PDHG', 'in_channels': 1, 'out_channels': 2, 'init_filters': 64, 'n_blocks': 3, 'activation': 'LeakyReLU', 'downsampling_kernel': (2, 2, 1), 'downsampling_mode': 'max', 'upsampling_kernel': (2, 2, 1), 'upsampling_mode': 'linear_interpolation', 'optimizer': 'Adam', 'learning_rate': 0.0001, 'loss_function': 'MSELoss', 'up_bound': 0, 'T': 128, 'epochs': 10000, 'device': 'cuda:0', 'wandb_mode': 'online', 'save_epoch_wandb': 10, 'save_epoch_local': 2, 'save_dir': 'tmp_2'}


------

### Import the image and transform the data

#### Download the data

In [10]:
# # REMEMBER TO COMMENT THIS OUT IF THE DATA HAS BEEN DOWNLOADED!
# !wget https://competitions.codalab.org/my/datasets/download/a26784fe-cf33-48c2-b61f-94b299dbc0f2
# !unzip "a26784fe-cf33-48c2-b61f-94b299dbc0f2" -d .

#### Load SIDD images

In [11]:
def get_npy_file(sample_path: str, scale_factor: float) -> torch.Tensor:
    """Load a pre-computed ``xf_scale_factor<s>.npy`` array as a normalised float tensor.

    Args:
        sample_path: Folder containing the ``.npy`` file.
        scale_factor: Scale factor encoded in the file name, with ``.`` replaced
            by ``_`` (e.g. ``0.5`` -> ``xf_scale_factor0_5.npy``).

    Returns:
        A ``torch.float`` tensor with a new leading axis, divided by 255
        (so 8-bit pixel data ends up in [0, 1]).
    """
    scale_factor_str = str(scale_factor).replace('.', '_')
    file_path = os.path.join(sample_path, f"xf_scale_factor{scale_factor_str}.npy")
    xf = torch.tensor(np.load(file_path), dtype=torch.float)
    return xf.unsqueeze(0) / 255

In [12]:
# TODO: CHANGE THIS TO YOUR PATH
# NOTE: Windows uses \\ instead of /
def load_images_SIDD(ids: list, take_npy_files: bool) -> list:
    """Load the ground-truth sRGB image of every SIDD sample whose id is in ``ids``.

    Args:
        ids: Image ids as 4-character strings, e.g. ``["0065", "0083"]``; a sample's
            id is the first 4 characters of its folder name under ``SIDD_DATA_PATH``.
        take_npy_files: Currently unused (the ``.npy`` loading path is disabled);
            kept so existing calls keep working.

    Returns:
        List of RGB PIL images, in sorted folder order.

    Raises:
        AssertionError: If a ground-truth image is not in RGB mode.
    """
    data_path = SIDD_DATA_PATH
    wanted_ids = set(ids)
    images = []
    k = 0

    # Sorted so the load order (and the k/len progress counter) is the same on every filesystem.
    for folder in sorted(os.listdir(data_path)):
        img_id = folder[:4]  # The first 4 characters of folder name is the image id (0001, 0002, ..., 0200)
        if img_id not in wanted_ids:
            continue
        k += 1
        print(f'loading image id {img_id}, {k}/{len(ids)}')

        # Use only the ground truth images (GT = Ground Truth)
        file_path = os.path.join(data_path, folder, "GT_SRGB_010.PNG")

        # Keep an in-memory copy so the file handle is closed as soon as the image is read.
        with Image.open(file_path) as image:
            assert image.mode == 'RGB', f"Image mode is not RGB: {image.mode}"  # For now, expect RGB images
            images.append(image.copy())

    return images

In [13]:
def test_load_images_SIDD():
    """Visual smoke test: load SIDD sample "0065", print its size and display it."""
    if DISABLING_TESTS:
        return
    sidd_images = load_images_SIDD(["0065"], False)
    for picture in sidd_images:
        print(picture.size)
        plt.imshow(picture)

test_load_images_SIDD()

#### Load Chest X-ray images

In [14]:
# TODO: CHANGE THIS TO YOUR PATH
# NOTE: Windows uses \\ instead of /
def load_images_chest_xray(data_path: str, ids: list) -> list:
    """Load the ``.jpeg`` images of ``data_path`` selected by integer index.

    The file list is sorted, so a given index always refers to the same file no
    matter in which order the filesystem lists the directory.

    Args:
        data_path: Folder containing the ``.jpeg`` files (e.g. ``.../train/NORMAL``).
        ids: Indices into the sorted ``.jpeg`` file list. Indices outside
            ``[0, number of files)`` (including negative ones) are skipped.

    Returns:
        List of PIL images, one per valid index, in ``ids`` order.
    """
    jpeg_files = sorted(f for f in os.listdir(data_path) if f.endswith(".jpeg"))

    images = []
    for idx in tqdm(ids):
        if not 0 <= idx < len(jpeg_files):
            continue
        # Copy into memory inside the `with` so each file handle is released right away
        # instead of staying open for every loaded image.
        with Image.open(os.path.join(data_path, jpeg_files[idx])) as image:
            images.append(image.copy())

    return images

In [15]:
def test_load_images_chest_xray(stage="train", label="NORMAL"):
    """Visual smoke test: load the first X-ray of ``<stage>/<label>``, print its size and show it."""
    if DISABLING_TESTS:
        return
    folder = f"{CHEST_XRAY_BASE_DATA_PATH}/{stage}/{label}"
    loaded = load_images_chest_xray(folder, [0])
    for picture in loaded:
        print(picture.size)
        plt.imshow(picture, cmap='gray')
    plt.show();

test_load_images_chest_xray()

------

#### Convert image to grayscale

In [16]:
def convert_to_grayscale(image: "Image.Image") -> "Image.Image":
    """Return a single-channel ("L" mode) copy of a PIL image.

    The annotation names ``Image.Image`` (the image class); ``Image`` alone is the
    ``PIL.Image`` module.
    """
    return image.convert('L')

In [17]:
def test_convert_to_grayscale():
    """Visual smoke test: convert the first training X-ray to grayscale and display it."""
    if DISABLING_TESTS:
        return
    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for picture in load_images_chest_xray(source_dir, [0]):
        gray = convert_to_grayscale(picture)
        # cmap='gray' only controls how a single-channel image is rendered; it converts nothing.
        plt.imshow(gray, cmap='gray')

test_convert_to_grayscale()

#### Transform image

In [18]:
def crop_to_square(image: "Image.Image") -> "Image.Image":
    """Centre-crop ``image`` to a square whose side is its shorter dimension.

    Integer offsets are used on purpose. With float box coordinates, PIL rounds
    each edge independently (round-half-to-even), so an odd size difference could
    give a crop one pixel too small in one direction, i.e. not square
    (e.g. 1000x857 -> 856x857).
    """
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))

In [19]:
def test_crop_to_square():
    """Visual smoke test: show the first training X-ray before and after centre-cropping."""
    if DISABLING_TESTS:
        return

    def _show(img):
        plt.imshow(img, cmap='gray')
        plt.show();

    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for picture in load_images_chest_xray(source_dir, [0]):
        _show(picture)
        _show(crop_to_square(picture))

test_crop_to_square()

In [20]:
def crop_to_square_and_resize(image: "Image.Image", side_len: int) -> "Image.Image":
    """Centre-crop ``image`` to a square (via ``crop_to_square``) and resize it to ``side_len`` x ``side_len``.

    Resizing uses PIL's default resampling filter for ``Image.resize``.
    The annotation names ``Image.Image`` (the image class), not the ``PIL.Image`` module.
    """
    return crop_to_square(image).resize(size=(side_len, side_len))

In [21]:
def test_crop_to_square_and_resize():
    """Visual smoke test: show the first training X-ray, then its 120x120 square version."""
    if DISABLING_TESTS:
        return

    def _show(img):
        # cmap='gray' only affects rendering; it does not convert the image to grayscale.
        plt.imshow(img, cmap='gray')
        plt.show();

    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for picture in load_images_chest_xray(source_dir, [0]):
        _show(picture)
        resized = crop_to_square_and_resize(picture, 120)
        print(resized.size)
        _show(resized)

test_crop_to_square_and_resize()

#### Convert to numpy array

In [22]:
def convert_to_numpy(image):
    """Return ``image`` (PIL image, array or array-like) as a NumPy array via ``np.asarray``.

    No copy is made when the input already is a compatible ndarray.
    """
    return np.asarray(image)

In [23]:
def test_convert_to_numpy():
    """Visual smoke test: convert a grayscale X-ray to NumPy and display the array."""
    if DISABLING_TESTS:
        return
    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for picture in load_images_chest_xray(source_dir, [0]):
        gray = convert_to_grayscale(picture)
        print(f"Before conversion: {type(gray)}")
        pixels = convert_to_numpy(gray)
        print(f"After conversion: {type(pixels)}")
        plt.imshow(pixels, cmap='gray')  # imshow accepts NumPy arrays directly

test_convert_to_numpy()

#### Convert to tensor

Tensors allow efficient computation on the GPU.

In [24]:
def convert_to_tensor_4D(image_numpy):
    """Turn a pixel array with values in [0, 255] into a float tensor scaled to [0, 1].

    A leading and a trailing singleton axis are added, so e.g. an (H, W) grayscale
    array becomes a (1, H, W, 1) tensor.
    """
    tensor = torch.tensor(image_numpy, dtype=torch.float)
    return tensor[None, ..., None] / 255

In [25]:
def test_convert_to_tensor():
    """Visual smoke test: grayscale -> NumPy -> 4-D tensor, then print its size and display it."""
    if DISABLING_TESTS:
        return
    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for image in load_images_chest_xray(source_dir, [0]):
        tensor_4d = convert_to_tensor_4D(convert_to_numpy(convert_to_grayscale(image)))
        print(tensor_4d.size())
        plt.imshow(tensor_4d.squeeze(0).to("cpu"), cmap='gray')


test_convert_to_tensor()

#### Add synthetic noise

<!-- artificial Gaussian noise

Real images contain noise.

It is difficult to obtain a pair of clean and noisy images of one exact same scene.

For training, it is common to add synthetic noise to an image that is considered clean and then try to reconstruct it.

There are many types of noise and many ways to add it. For example, we can add salt-and-pepper noise. (?) We can add more noise in some parts of the image and less in others. We can also combine several noise-adding strategies to build more robust models.

For our purposes, we focus on Gaussian noise, which is sufficient for most cases.

(?) We add noise with the same distribution to every pixel; we do not use strategies that focus on certain regions. -->

In [26]:
def get_variable_noise(sigma_min, sigma_max):
    """Draw one noise level uniformly from [sigma_min, sigma_max), returned as a shape-(1,) tensor."""
    u = torch.rand(1)
    return sigma_min + (sigma_max - sigma_min) * u

def add_noise(xf: torch.Tensor, sigma) -> torch.Tensor:
    """Add zero-mean Gaussian noise whose standard deviation is ``sigma`` times the image's std.

    This matches standardising ``xf``, adding ``sigma * N(0, 1)`` and undoing the
    standardisation, i.e. ``std * ((xf - mu) / std + sigma * n) + mu = xf + sigma * std * n``.
    The simplified form avoids a 0/0 (an all-NaN image) when ``xf`` is constant, and it
    draws the noise on ``xf``'s own device and dtype.

    Args:
        xf: Image tensor of any shape.
        sigma: Relative noise level: a float, or a tensor broadcastable to ``xf``
            (e.g. the shape-(1,) output of ``get_variable_noise``).

    Returns:
        A new tensor ``xf + sigma * std(xf) * n`` with i.i.d. ``n ~ N(0, 1)``; ``xf`` is not modified.
    """
    noise = torch.randn_like(xf)
    return xf + sigma * torch.std(xf) * noise

In [27]:
def test_add_noise():
    """Visual smoke test: show a 120x120 X-ray with fixed (0.1) and random (0.1-0.2) relative noise."""
    if DISABLING_TESTS:
        return
    source_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for rgb_image in load_images_chest_xray(source_dir, [0]):
        gray = crop_to_square_and_resize(convert_to_grayscale(rgb_image), 120)
        print(f"grayscale_image.size: {gray.size}")
        clean_4d = convert_to_tensor_4D(convert_to_numpy(gray))
        constant_noise_img = add_noise(clean_4d, sigma=0.1)
        random_sigma = get_variable_noise(sigma_min=0.1, sigma_max=0.2)
        variable_noise_img = add_noise(clean_4d, random_sigma)
        plt.imshow(gray, cmap='gray')
        plt.show();
        for noisy in (constant_noise_img, variable_noise_img):
            plt.imshow(noisy.squeeze(0).to("cpu"), cmap='gray')
            plt.show();

    with torch.no_grad():
        torch.cuda.empty_cache()

test_add_noise()

------

### Calculate PSNR

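PSNR (peak signal-to-noise ratio) is a common metric for measuring the quality of a noisy or reconstructed image against a clean reference; higher is better.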

Below, we compare the image before and after adding synthetic noise.

In [28]:
def PSNR(original, compressed, max_pixel: float = 1.0):
    """Peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        original: Reference tensor.
        compressed: Tensor to compare (e.g. a noisy or reconstructed image),
            same shape as (or broadcastable to) ``original``.
        max_pixel: Largest possible pixel value: 1.0 (default) for images scaled to
            [0, 1] as produced by ``convert_to_tensor_4D``, 255.0 for raw 8-bit data.

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))`` as a 0-d tensor. If the inputs are
        identical (MSE == 0, PSNR would be infinite), the sentinel ``100`` is returned.
    """
    mse = torch.mean((original - compressed) ** 2)
    if mse == 0:
        # No difference at all: PSNR is infinite, so report a large sentinel value.
        return 100
    return 20 * torch.log10(max_pixel / torch.sqrt(mse))

In [29]:
def test_PSNR():
    """Visual smoke test: PSNR of an image with itself, then against a sigma=0.5 noisy copy."""
    if DISABLING_TESTS:
        return
    xray_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for rgb_image in load_images_chest_xray(xray_dir, [0]):
        square = crop_to_square_and_resize(convert_to_grayscale(rgb_image), 120)
        clean_4d = convert_to_tensor_4D(convert_to_numpy(square))

        print(f"PSNR of original image: {PSNR(clean_4d, clean_4d)} dB")
        plt.imshow(clean_4d.squeeze(0).to("cpu"), cmap='gray')
        plt.show();

        noisy_4d = add_noise(clean_4d, sigma=0.5)
        print(f"PSNR of constant noise image: {PSNR(noisy_4d, clean_4d):.2f} dB")
        plt.imshow(noisy_4d.squeeze(0).to("cpu"), cmap='gray')
        plt.show();


test_PSNR()

---

### Calculate SSIM

In [30]:
def SSIM(tensor_2D_a: torch.Tensor, tensor_2D_b: torch.Tensor, data_range: float=1) -> float:
    """Structural similarity (SSIM) of two 2-D tensors via skimage's ``structural_similarity``.

    Both tensors are first moved to the CPU, detached and converted to NumPy arrays.
    ``data_range`` is the value span of the data (1 for images scaled to [0, 1]).
    """
    arrays = [t.to("cpu").detach().numpy() for t in (tensor_2D_a, tensor_2D_b)]
    return structural_similarity(*arrays, data_range=data_range)

In [31]:
def test_SSIM(sigma=0.5):
    """Visual smoke test: SSIM of a 120x120 X-ray with itself and with a noisy copy."""
    if DISABLING_TESTS:
        return
    xray_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for rgb_image in load_images_chest_xray(xray_dir, [0]):
        square = crop_to_square_and_resize(convert_to_grayscale(rgb_image), 120)
        clean_2d = convert_to_tensor_4D(convert_to_numpy(square)).squeeze(0).squeeze(-1)

        print(f"image_tensor_2D: {clean_2d.size()}")
        print(f"SSIM of original image: {SSIM(clean_2d, clean_2d)}")
        plt.imshow(clean_2d.cpu(), cmap='gray')
        plt.show();

        noisy_2d = add_noise(clean_2d, sigma=sigma)
        print(f"noisy_image_tensor_2D: {noisy_2d.size()}")
        print(f"SSIM of noisy image (sigma={sigma}): {SSIM(noisy_2d, clean_2d):.2f}")
        plt.imshow(noisy_2d.cpu(), cmap='gray')
        plt.show();

    with torch.no_grad():
        torch.cuda.empty_cache()

test_SSIM()

------

### Reconstruct an image with PDHG

#### Calculate the gradient

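<!-- The gradient operator $\nabla$ computes finite differences; it is not itself a Laplacian, but the composition $\nabla^H \nabla = -\Delta$ (computed by `apply_GHG`) is, up to sign, a discrete Laplacian.

With `dim=3` (as used below) there are $x$, $y$ and $t$ gradients; a purely 2D image would only need the $x$ and $y$ gradients. -->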

In [32]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

class GradOperators(torch.nn.Module):
    """
    Finite-difference gradient operator G, its adjoint G^H and the composition G^H G,
    implemented with (transposed) convolutions.

    `apply_G` takes a tensor whose last `dim` axes are the differentiated axes and
    inserts an extra axis of size `dim` in front of them (one channel per
    differentiation axis); `apply_GH` expects that axis and removes it again.
    Complex inputs are split into real and imaginary parts, processed separately
    and recombined.
    """
    @staticmethod
    def diff_kernel(ndim, mode):
        # 1-D 3-tap stencil for the requested difference scheme. Convolution here is a
        # cross-correlation, so e.g. (0, -1, 1) with padding 1 yields x[i+1] - x[i].
        if mode == "doublecentral":
            kern = torch.tensor((-1, 0, 1))
        elif mode == "central":
            kern = torch.tensor((-1, 0, 1)) / 2
        elif mode == "forward":
            kern = torch.tensor((0, -1, 1))
        elif mode == "backward":
            kern = torch.tensor((-1, 1, 0))
        else:
            raise ValueError(f"mode should be one of (central, forward, backward, doublecentral), not {mode}")
        # Kernel of shape (ndim, 1, 3, ..., 3): output channel i carries the stencil along
        # axis i, placed at the centre index (1) of every other axis.
        kernel = torch.zeros(ndim, 1, *(ndim * (3,)))
        for i in range(ndim):
            idx = tuple([i, 0, *(i * (1,)), slice(None), *((ndim - i - 1) * (1,))])
            kernel[idx] = kern
        return kernel

    def __init__(self, dim:int=2, mode:str="doublecentral", padmode:str = "circular"):
        """
        An Operator for finite Differences / Gradients
        Implements the forward as apply_G and the adjoint as apply_GH.
        
        Args:
            dim (int, optional): Dimension. Defaults to 2.
            mode (str, optional): one of doublecentral, central, forward or backward. Defaults to "doublecentral".
            padmode (str, optional): one of constant, replicate, circular or reflect. Defaults to "circular".
        """
        super().__init__()
        # Non-persistent buffer: follows .to(device) but is not saved in the state_dict.
        self.register_buffer("kernel", self.diff_kernel(dim, mode), persistent=False)
        self._dim = dim
        # Pick conv1d/2d/3d (and the transposed variants) matching `dim`.
        self._conv = (torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d)[dim - 1]
        self._convT = (torch.nn.functional.conv_transpose1d, torch.nn.functional.conv_transpose2d, torch.nn.functional.conv_transpose3d)[dim - 1]
        # Pads each of the last `dim` axes by 1 on both sides.
        self._pad = partial(torch.nn.functional.pad, pad=2 * dim * (1,), mode=padmode)
        # NOTE(review): sqrt(dim) / sqrt(4 * dim) look like upper bounds on the operator
        # norm ||G|| for these stencils (exposed as `normGHG`) - confirm.
        if mode == 'central':
            self._norm = (self.dim) ** (1 / 2)
        else:
            self._norm = (self.dim * 4) ** (1 / 2)

    @property
    def dim(self):
        # Number of differentiated (trailing) axes.
        return self._dim
    
    def apply_G(self, x):
        """
        Forward operator y = G x.

        Returns a tensor of shape (*x.shape[:-dim], dim, *x.shape[-dim:]).
        """
        if x.is_complex():
            # Stack real/imag parts on a new leading axis.
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        # Flatten all leading axes into the batch axis with a single input channel.
        xr = xr.reshape(-1, 1, *x.shape[-self.dim :])
        xp = self._pad(xr)
        # padding=0 because the input was already padded by 1 -> same spatial size as x.
        y = self._conv(xp, weight=self.kernel, bias=None, padding=0)
        if x.is_complex():
            y = y.reshape(2, *x.shape[: -self.dim], self.dim, *x.shape[-self.dim :])
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape[0 : -self.dim], self.dim, *x.shape[-self.dim :])

        del x, xr, xp

        return y

    def apply_GH(self, x):
        """
        Adjoint operator y = G^H x.

        Expects x of shape (..., dim, *spatial) and returns (..., *spatial).
        """
        if x.is_complex():
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        # The gradient axis becomes the channel axis of the transposed convolution.
        xr = xr.reshape(-1, self.dim, *x.shape[-self.dim :])
        xp = self._pad(xr)
        # On the 1-padded input, conv_transpose with padding=2 returns the original size:
        # (N + 2 - 1) - 2 * 2 + 3 = N.
        y = self._convT(xp, weight=self.kernel, bias=None, padding=2)
        if x.is_complex():
            y = y.reshape(2, *x.shape[: -self.dim - 1], *x.shape[-self.dim :])
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape[: -self.dim - 1], *x.shape[-self.dim :])

        del x, xr, xp

        return y
    
    def apply_GHG(self, x):
        """
        Composition y = G^H G x, computed in one pass; output has the same shape as x.
        """
        if x.is_complex():
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        xr = xr.reshape(-1, 1, *x.shape[-self.dim :])
        xp = self._pad(xr)
        tmp = self._conv(xp, weight=self.kernel, bias=None, padding=0)
        tmp = self._pad(tmp)
        y = self._convT(tmp, weight=self.kernel, bias=None, padding=2)
        if x.is_complex():
            y = y.reshape(2, *x.shape)
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape)

        del x, xr, xp, tmp

        return y

    def forward(self, x, direction=1):
        """
        direction > 0: G x;  direction < 0: G^H x;  direction == 0: G^H G x.
        """
        if direction>0:
            return self.apply_G(x)
        elif direction<0:
            return self.apply_GH(x)
        else:
            return self.apply_GHG(x)

    @property
    def normGHG(self):
        # Value set in __init__: sqrt(dim) for "central" mode, sqrt(4 * dim) otherwise.
        return self._norm

#### Helper function for PDHG: Clip act

In [33]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

class ClipAct(nn.Module):
    """``nn.Module`` wrapper around ``clipact``: element-wise clipping to ``[-threshold, threshold]``."""

    def forward(self, x, threshold):
        clipped = clipact(x, threshold)
        return clipped


def clipact(x, threshold):
    """Clip ``x`` element-wise to the box ``[-threshold, threshold]``.

    For complex ``x`` the real and imaginary parts are clipped independently. In
    that case ``threshold`` must be a tensor: it gets a trailing axis so it
    broadcasts over the (real, imag) pair.
    """
    if not x.is_complex():
        return torch.clamp(x, -threshold, threshold)
    real_view = torch.view_as_real(x)
    bounded = torch.clamp(real_view, -threshold.unsqueeze(-1), threshold.unsqueeze(-1))
    return torch.view_as_complex(bounded)

#### Only PDHG

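Running PDHG with a large T (many PDHG iterations) can fill up GPU memory. A likely reason: the iterations are unrolled, so when gradients are enabled, autograd keeps the intermediate results of every iteration alive for backpropagation. Memory therefore grows linearly with T.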

In [34]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

def reconstruct_with_PDHG(
        x_dynamic_image_tensor_5D, lambda_reg, T,
):
    """
    Denoises the image with T unrolled iterations of the PDHG algorithm for
    lambda-weighted TV denoising (Algorithm 2 - unrolled PDHG, page 18 of the paper).

    Parameters:
        x_dynamic_image_tensor_5D: The (noisy) (dynamic) image tensor.
        Size of the tensor: (`patches`, `channels`, `Nx`, `Ny`, `Nt`) where

        - `patches`: number of patches (batch size)
        - `channels`: number of (colour) channels; must be 1
        - `Nx`: number of pixels in x
        - `Ny`: number of pixels in y
        - `Nt`: number of time steps (frames)

        lambda_reg: The regularization parameter. Either a scalar or a tensor
            broadcastable to the gradient field of shape (`patches`, 1, 3, `Nx`, `Ny`, `Nt`).
            A 5-D per-patch map of shape (`patches`, C, `Nx`, `Ny`, `Nt`) with C in {1, 3}
            (such as the map from `DynamicImageStaticPrimalDualNN.get_lambda_cnn`) is
            aligned to that shape automatically.
        T: Number of iterations. With T == 0 a copy of the input is returned.

    Returns:
        The reconstructed image tensor, same shape as the input.

    Note:
        The iterations are unrolled. When gradients are enabled (e.g. lambda_reg comes
        from a CNN), autograd keeps the iterates of every iteration alive for
        backpropagation, so memory grows linearly with T.
    """
    dim = 3  # differentiate along x, y and t
    patches, channels, Nx, Ny, Nt = x_dynamic_image_tensor_5D.shape

    assert channels == 1, "Only grayscale images are supported."

    device = x_dynamic_image_tensor_5D.device
    dtype = x_dynamic_image_tensor_5D.dtype

    # starting values (primal)
    xbar = x_dynamic_image_tensor_5D.clone()
    x0 = x_dynamic_image_tensor_5D.clone()
    xnoisy = x_dynamic_image_tensor_5D.clone()
    x1 = x0  # result when T == 0 (otherwise x1 would be undefined)

    # dual variables: p for the data term, q for the gradient (TV) term.
    # apply_G maps (patches, 1, Nx, Ny, Nt) to (patches, 1, dim, Nx, Ny, Nt), so q gets
    # exactly that shape; a (patches, dim, ...) q only broadcast correctly for patches == 1
    # and mixed different patches together otherwise.
    p = x_dynamic_image_tensor_5D.clone()
    q = torch.zeros(patches, channels, dim, Nx, Ny, Nt, dtype=dtype, device=device)

    if torch.is_tensor(lambda_reg) and lambda_reg.dim() == 5:
        # (patches, C, Nx, Ny, Nt) -> (patches, 1, C, Nx, Ny, Nt) so it lines up with q.
        lambda_reg = lambda_reg.unsqueeze(1)

    # operator norms
    op_norm_AHA = torch.sqrt(torch.tensor(1.0, device=device))
    op_norm_GHG = torch.sqrt(torch.tensor(12.0, device=device))
    # operator norm of K = [A, \nabla]
    # https://iopscience.iop.org/article/10.1088/0031-9155/57/10/3065/pdf,
    # see page 3083
    L = torch.sqrt(op_norm_AHA**2 + op_norm_GHG**2)

    # Fixed step sizes from the raw starting value 10. These are plain tensors: per-call
    # nn.Parameters were re-created on every call, so they could never be trained and
    # only added unused leaves to the autograd graph.
    raw_step = torch.tensor(10.0, device=device)
    sigma = (1 / L) * torch.sigmoid(raw_step)  # \in (0,1/L)
    tau = (1 / L) * torch.sigmoid(raw_step)    # \in (0,1/L)
    theta = torch.sigmoid(raw_step)            # \in (0,1)

    # .to(device) moves the kernel buffer next to the input, whatever the default device is.
    GradOps = GradOperators(dim=dim, mode="forward", padmode="circular").to(device)
    clip_act = ClipAct()

    # Algorithm 2 - Unrolled PDHG algorithm (page 18)
    # TODO: In the paper, L is one of the inputs but not used anywhere in the pseudo code???
    for kT in range(T):
        # update p (dual of the data term 1/2 ||x - xnoisy||^2)
        p = (p + sigma * (xbar - xnoisy)) / (1. + sigma)
        # update q (dual of the weighted TV term): projection onto [-lambda, lambda]
        q = clip_act(q + sigma * GradOps.apply_G(xbar), lambda_reg)
        # primal update
        x1 = x0 - tau * p - tau * GradOps.apply_GH(q)

        if kT != T - 1:
            # update xbar (extrapolation step)
            xbar = x1 + theta * (x1 - x0)
            x0 = x1

    # Release cached allocator blocks once per call. Doing this every iteration only
    # added overhead: memory still referenced by the autograd graph cannot be freed.
    torch.cuda.empty_cache()

    return x1

In [35]:
def test_reconstruct_with_PDHG():
    """Visual smoke test: denoise a heavily noised 512x512 X-ray with plain PDHG and a fixed lambda.

    A higher lambda means stronger regularisation, a higher T more iterations,
    and a higher PSNR a better reconstruction.
    """
    if DISABLING_TESTS:
        return

    def _show(img):
        plt.imshow(img, cmap='gray')
        plt.show();

    xray_dir = f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL"
    for rgb_image in load_images_chest_xray(xray_dir, [0]):
        square = crop_to_square_and_resize(convert_to_grayscale(rgb_image), 512)
        clean_4d = convert_to_tensor_4D(convert_to_numpy(square))
        print(f"Image tensor size: {clean_4d.size()}")
        assert len(clean_4d.size()) == 4, "The image should be 4D"
        _show(clean_4d.squeeze(0).to("cpu"))

        TEST_SIGMA = 0.5  # Relatively high noise
        noisy_4d = add_noise(clean_4d, sigma=TEST_SIGMA)
        print(f"PSNR of constant noise image: {PSNR(clean_4d, noisy_4d):.2f} dB")
        print(f"SSIM of constant noise image: {SSIM(clean_4d.squeeze(0).squeeze(-1), noisy_4d.squeeze(0).squeeze(-1)):.2f}")
        _show(noisy_4d.squeeze(0).to("cpu"))

        TEST_LAMBDA = 0.04
        pdhg_input_5d = noisy_4d.unsqueeze(0)
        print(f"PDHG input size: {pdhg_input_5d.size()}")
        assert len(pdhg_input_5d.size()) == 5, "The input for PDHG should be 5D"
        denoised_5d = torch.clamp(
            reconstruct_with_PDHG(pdhg_input_5d, lambda_reg=TEST_LAMBDA, T=128),
            0, 1)  # keep values inside [0, 1]
        print(f"PSNR of reconstructed image: {PSNR(clean_4d, denoised_5d.squeeze(0)):.2f} dB")
        _show(denoised_5d.squeeze(0).squeeze(0).to("cpu").detach().numpy())

    with torch.no_grad():
        torch.cuda.empty_cache()

    print("""
In this example, a lot of noise has been applied to the original image. The PDHG algorithm tries to reconstruct the image from the noisy image. It did remove some noise and improved the PSNR value. However, the quality has been degraded significantly. We will see whether we can improve this by learning a set of parameters map.
""")

test_reconstruct_with_PDHG()

------

### Full Architecture

<!-- UNET to PDHG

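The whole architecture can be seen as self-supervised: the data only contains clean images, and the noisy inputs are generated from them synthetically, so no manually labelled data is needed.

The whole model takes an image as input and outputs an image.

The UNET itself only outputs the regularisation parameter map; PDHG then uses this map to reconstruct the denoised image. -->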

In [36]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

class DynamicImageStaticPrimalDualNN(nn.Module):
    """Unrolled-PDHG denoiser whose regularisation map is predicted by a CNN (the U-Net in this project).

    forward(x): lambda = get_lambda_cnn(x) (unless ``lambda_map`` is given), then
    reconstruct_with_PDHG(x, lambda, T).

    Args:
        T: Number of unrolled PDHG iterations.
        cnn_block: Network returning a 2-channel map from the padded input
            (only stored in ``"lambda_cnn"`` mode).
        mode: ``"lambda_cnn"``, ``"lambda_xyt"`` or ``"lambda_xy_t"``.
        up_bound: If > 0 the map is squashed into (0, up_bound) with a sigmoid,
            otherwise ``0.1 * softplus`` is used.
        phase: ``"training"`` or ``"testing"``; selects the padding applied before the CNN.

    NOTE(review): ``forward`` always calls ``get_lambda_cnn`` when no ``lambda_map`` is
    passed, so the ``"lambda_xyt"`` / ``"lambda_xy_t"`` modes (which never set ``self.cnn``)
    fail there and their ``self.lambda_reg`` is never read. Likewise ``self.tau``,
    ``self.sigma`` and ``self.theta`` are registered parameters that ``forward`` never
    uses (``reconstruct_with_PDHG`` has its own step sizes), so they get no gradient.
    """

    def __init__(
        self,
        T=128,
        cnn_block=None,
        mode="lambda_cnn",
        up_bound=0,
        phase="training",
    ):
        super().__init__()

        # gradient operators and clipping function
        dim = 3
        self.GradOps = GradOperators(dim, mode="forward", padmode="circular")

        # operator norms
        self.op_norm_AHA = torch.sqrt(torch.tensor(1.0))
        self.op_norm_GHG = torch.sqrt(torch.tensor(12.0))
        # operator norm of K = [A, \nabla]
        # https://iopscience.iop.org/article/10.1088/0031-9155/57/10/3065/pdf,
        # see page 3083
        self.L = torch.sqrt(self.op_norm_AHA**2 + self.op_norm_GHG**2)

        # function for projecting
        self.ClipAct = ClipAct()

        if mode == "lambda_xyt":
            # one single lambda for x, y and t
            self.lambda_reg = nn.Parameter(torch.tensor([-1.5]), requires_grad=True)

        elif mode == "lambda_xy_t":
            # one (shared) lambda for x, y and one lambda for t
            self.lambda_reg = nn.Parameter(
                torch.tensor([-4.5, -1.5]), requires_grad=True
            )

        elif mode == "lambda_cnn":
            # CNN block (the U-Net in this project) estimating the lambda map; it must
            # yield a two-channel output: one map for lambda_xy and one for lambda_t.
            self.cnn = cnn_block
            self.up_bound = torch.tensor(up_bound)

        # number of iterations
        self.T = T
        self.mode = mode

        # raw (pre-sigmoid) step-size parameters, same starting value as in reconstruct_with_PDHG
        self.tau = nn.Parameter(torch.tensor(10.0), requires_grad=True)
        self.sigma = nn.Parameter(torch.tensor(10.0), requires_grad=True)
        # theta should be in (0, 1) after the sigmoid
        self.theta = nn.Parameter(torch.tensor(10.0), requires_grad=True)

        # "training": reflect padding in x, y and t before the CNN;
        # "testing": reflect padding in x, y and circular padding in t.
        self.phase = phase

    def get_lambda_cnn(self, x):
        """Predict the strictly positive regularisation map for ``x`` with the CNN.

        The input is padded (depending on ``self.phase``) before the CNN, and the CNN
        output is cropped back to the input size. The first (spatial, xy) output channel
        is duplicated, so a 2-channel CNN output becomes a 3-channel map (presumably
        matching the x, y and t gradient channels of the PDHG step).

        Raises:
            ValueError: If ``self.phase`` is neither "training" nor "testing".
        """
        # Padding size is arbitrarily chosen (ideally tied to the CNN's receptive field);
        # it seems important in order not to create "holes" in the lambda maps in t-direction.
        npad_xy = 4
        # npad_t = 8
        npad_t = 0  # TODO: Time dimension should not be necessary for single image input.
        # npad_t is 0 so single images run without changing the 3D config; with npad_t > 0
        # the number of frames apparently must be greater than npad_t.

        pad = (npad_t, npad_t, npad_xy, npad_xy, npad_xy, npad_xy)

        if self.phase == "training":
            x = F.pad(x, pad, mode="reflect")
        elif self.phase == "testing":
            pad_refl = (0, 0, npad_xy, npad_xy, npad_xy, npad_xy)
            pad_circ = (npad_t, npad_t, 0, 0, 0, 0)
            x = F.pad(x, pad_refl, mode="reflect")
            x = F.pad(x, pad_circ, mode="circular")
        else:
            # Without padding, the crop below would shrink the map and misalign it with x.
            raise ValueError(f"phase must be 'training' or 'testing', not {self.phase!r}")

        # estimate parameter map (the cnn is the U-Net block in this project)
        lambda_cnn = self.cnn(x)

        # crop back to the unpadded size (negative padding crops)
        neg_pad = tuple(-k for k in pad)
        lambda_cnn = F.pad(lambda_cnn, neg_pad)

        # double the spatial map and stack
        lambda_cnn = torch.cat((lambda_cnn[:, 0, ...].unsqueeze(1), lambda_cnn), dim=1)

        # constrain the map to be strictly positive
        if self.up_bound > 0:
            # ... and bounded from above by up_bound
            lambda_cnn = self.up_bound * self.op_norm_AHA * torch.sigmoid(lambda_cnn)
        else:
            lambda_cnn = 0.1 * self.op_norm_AHA * F.softplus(lambda_cnn)

        return lambda_cnn

    def forward(self, x, lambda_map=None):
        """Denoise ``x`` (patches, 1, Nx, Ny, Nt) with PDHG.

        Uses ``lambda_map`` if given, otherwise the map predicted by ``get_lambda_cnn``.
        """
        lambda_reg = self.get_lambda_cnn(x) if lambda_map is None else lambda_map
        # x is used on the device it already lives on. Tensor.to is not in-place, so a
        # bare `x.to(DEVICE)` statement would not move it.
        return reconstruct_with_PDHG(x, lambda_reg, self.T)

------

### Data loading class

In [37]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

class DynamicImageStaticDenoisingDataset(Dataset):
    """Dataset of chest X-ray images that yields ``(noisy, clean)`` pairs.

    Each image is cropped to a square, resized to ``resize_square`` x ``resize_square``,
    converted to greyscale and scaled from [0, 255] to [0, 1]. All images are stored in
    one tensor ``self.xf`` of shape ``(N, 1, Nx, Ny, 1)``: a channel axis plus a singleton
    "time" axis kept for compatibility with the original dynamic-image code.

    Noise is added on the fly in ``__getitem__`` (via ``add_noise``) with a noise level
    drawn uniformly from ``[sigma[0], sigma[1])``, so repeated access to the same index
    gives different noisy images.

    Args:
        data_path: Folder passed to ``load_images_chest_xray``.
        ids: Image ids passed to ``load_images_chest_xray``.
        resize_square: Side length of the square images.
        sigma: ``(min, max)`` range of the noise level.
        device: Device the returned tensors are moved to.
    """

    def __init__(
        self,
        data_path: str,
        ids: list,
        resize_square = 120,
        sigma = (0.1, 0.5),
        device: str = "cuda"
    ):
        self.device = device
        self.resize_square = resize_square

        xray_images = load_images_chest_xray(data_path, ids)

        xf_list = []
        for image in xray_images:
            image = crop_to_square_and_resize(image, self.resize_square)
            image = image.convert('L')  # convert to grey scale
            image_data = np.asarray(image)
            xf = torch.tensor(image_data, dtype=torch.float)
            # Assume image is in [0, 255] range
            xf = xf / 255
            assert len(xf.size()) == 2, f"Expected 2D tensor, got {xf.size()}"
            xf = xf.unsqueeze(0)  # Add channel dimension
            xf = xf.unsqueeze(-1)  # Add time dimension. TODO: For legacy dynamic image code only. Will remove later.
            xf_list.append(xf)
        xf = torch.stack(xf_list, dim=0)  # shape (N, 1, Nx, Ny, Nt)
        assert len(xf.size()) == 5, f"Expected 5D tensor, got {xf.size()}"
        assert xf.size(1) == 1, f"Expected 1 channel, got {xf.size(1)}"
        assert xf.size(2) == self.resize_square, f"Expected width (Nx) of {self.resize_square}, got {xf.size(-3)}"
        assert xf.size(3) == self.resize_square, f"Expected height (Ny) of {self.resize_square}, got {xf.size(-2)}"
        assert xf.size(4) == 1, f"Expected 1 time step, got {xf.size(-1)}"

        # Per-sample temporal "TV" (sum of squared differences along the time axis),
        # inherited from the dynamic-image code to weight samples for a
        # WeightedRandomSampler.
        xf_patches_tv = (xf[..., 1:] - xf[..., :-1]).pow(2).sum(dim=[1, 2, 3, 4])
        total_tv = torch.sum(xf_patches_tv)
        if total_tv > 0:
            # Normalise to a probability vector.
            self.samples_weights = xf_patches_tv / total_tv
        else:
            # Static images (Nt == 1) have no temporal variation, so every entry is 0 and
            # normalising would give 0/0 = NaN weights ("invalid multinomial
            # distribution" in WeightedRandomSampler). Use uniform weights instead.
            self.samples_weights = torch.full_like(xf_patches_tv, 1.0 / len(xf_patches_tv))

        self.xf = xf
        self.len = xf.shape[0]

        self.sigma_min = sigma[0]
        self.sigma_max = sigma[1]

    def __getitem__(self, index):
        """Return ``(noisy, clean)`` for sample ``index``, both moved to ``self.device``."""
        # Fresh noise level in [sigma_min, sigma_max) on every access.
        sigma = self.sigma_min + torch.rand(1) * (self.sigma_max - self.sigma_min)
        clean = self.xf[index]
        noisy = add_noise(clean, sigma)
        return (
            noisy.to(device=self.device),
            clean.to(device=self.device),
        )

    def __len__(self):
        return self.len

------

### UNET

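The specific U-Net architecture we use has the following parts:

- an initial double convolution (two 3x3x1 convolutions, each followed by an activation) that maps the input channel(s) to `init_filters` feature maps;
- `n_blocks` encoder blocks. Each one halves both spatial dimensions with max or average pooling, then applies a double convolution that doubles the number of channels;
- `n_blocks` decoder blocks. Each one doubles both spatial dimensions and halves the channels, using either trilinear interpolation followed by a 1x1x1 convolution or a transposed convolution. It then concatenates the matching encoder feature map (skip connection) and applies a double convolution;
- a final 1x1x1 convolution that maps `init_filters` channels to `out_channels` (2 by default).

By default we use Leaky ReLU activations instead of ReLU or sigmoid.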

In [38]:
# Used https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py as a reference

class DoubleConv(nn.Module):
    """Two ``convolution -> activation`` stages mapping ``in_channels`` to ``out_channels``.

    Both convolutions are padded so the spatial size is unchanged: a (3, 3, 1) kernel
    with (1, 1, 0) padding in 3D, or a 3x3 kernel with padding 1 in 2D.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by each of the two convolutions.
        n_dimensions: 3 for ``nn.Conv3d`` or 2 for ``nn.Conv2d``; other values raise
            ``ValueError``.
        activation: "LeakyReLU" (negative slope 0.01) or "ReLU"; other values raise
            ``ValueError``.
    """

    def __init__(
            self, in_channels: int, out_channels: int, n_dimensions=3, activation="LeakyReLU"):
        super(DoubleConv, self).__init__()

        def make_conv(c_in, c_out):
            if n_dimensions == 2:
                return nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1))
            if n_dimensions == 3:
                return nn.Conv3d(c_in, c_out, kernel_size=(3, 3, 1), padding=(1, 1, 0))
            # 1-dimensional (or any other) convolution is not supported
            raise ValueError(f"Unsupported number of dimensions: {n_dimensions}")

        def make_activation():
            if activation == "ReLU":
                return nn.ReLU(inplace=True)
            if activation == "LeakyReLU":
                return nn.LeakyReLU(negative_slope=0.01, inplace=True)
            raise ValueError(f"Unsupported activation function: {activation}")

        first_conv = make_conv(in_channels, out_channels)
        first_act = make_activation()
        second_conv = make_conv(out_channels, out_channels)
        second_act = make_activation()
        # A single nn.Sequential named `conv_block` keeps the state_dict keys
        # (conv_block.0.*, conv_block.2.*) compatible with saved checkpoints.
        self.conv_block = nn.Sequential(first_conv, first_act, second_conv, second_act)

    def forward(self, x: torch.Tensor):
        return self.conv_block(x)
        

class EncodeBlock3d(nn.Module):
    """U-Net encoder stage: halve both spatial axes, then double the channel count.

    The input is pooled with stride (2, 2, 1) (the last axis is untouched) and then fed
    through a ``DoubleConv`` mapping ``in_channels`` to ``2 * in_channels``.

    Args:
        in_channels: Number of input channels.
        n_dimensions: Forwarded to ``DoubleConv``.
        activation: Forwarded to ``DoubleConv``.
        downsampling_kernel: Flat square pooling kernel ``(k, k, 1)``. A list such as
            ``[2, 2, 1]`` (e.g. read back from a JSON/YAML config) is accepted as well.
        downsampling_mode: "max" (``nn.MaxPool3d``) or "avg" (``nn.AvgPool3d``); anything
            else raises ``ValueError``.
    """

    def __init__(
            self, in_channels: int, n_dimensions=3,
            activation="LeakyReLU",
            downsampling_kernel=(2, 2, 1), downsampling_mode="max"):
        super(EncodeBlock3d, self).__init__()

        # Normalise to a tuple so list-valued kernels compare equal in the check below.
        downsampling_kernel = tuple(downsampling_kernel)
        k = downsampling_kernel[0]  # Assume kernel has shape (k, k, 1)
        assert downsampling_kernel == (k, k, 1), f"Expected a flat square kernel like {(k, k, 1)}, got {downsampling_kernel}"
        stride = (2, 2, 1)  # Stride 2x2 to halve each side
        # Output side = floor((N + 2p - k) / 2) + 1, which equals N / 2 for even N
        # when p = (k - 1) // 2.
        padding = ((k - 1) // 2, (k - 1) // 2, 0)
        if downsampling_mode == "max":
            self.pool = nn.MaxPool3d(
                kernel_size=downsampling_kernel, stride=stride, padding=padding)
        elif downsampling_mode == "avg":
            self.pool = nn.AvgPool3d(
                kernel_size=downsampling_kernel, stride=stride, padding=padding)
        else:
            raise ValueError(f"Unknown pooling method: {downsampling_mode}")

        self.double_conv = DoubleConv(in_channels, in_channels * 2, n_dimensions, activation=activation)

    def forward(self, x: torch.Tensor):
        x = self.pool(x)
        x = self.double_conv(x)
        return x



class DecodeBlock3d(nn.Module):
    """U-Net decoder stage: double both spatial axes, merge the skip connection, refine.

    ``in_channels`` is upsampled by 2 in the first two spatial axes and reduced to
    ``in_channels // 2`` channels. It is then concatenated (channel axis) with the
    matching encoder output, and a ``DoubleConv`` maps the result back to
    ``in_channels // 2`` channels. This assumes the encoder output has
    ``in_channels // 2`` channels and exactly twice the input's spatial size.

    Args:
        in_channels: Channels of the decoder input.
        n_dimensions: Forwarded to ``DoubleConv``.
        activation: Forwarded to ``DoubleConv``.
        upsampling_kernel: Flat square kernel ``(k, k, 1)`` for the transposed
            convolution (lists are accepted too); unused for linear interpolation.
        upsampling_mode: "linear_interpolation" (trilinear upsampling + 1x1x1 conv) or
            "transposed_convolution"; anything else raises ``ValueError``.
    """

    def __init__(
            self, in_channels: int, n_dimensions=3,
            activation="LeakyReLU",
            upsampling_kernel=(2, 2, 1), upsampling_mode="linear_interpolation"):
        super(DecodeBlock3d, self).__init__()

        if upsampling_mode == "linear_interpolation":
            self.upsampling = nn.Sequential(
                nn.Upsample(
                    scale_factor=(2, 2, 1),  # Assume the shape is (Nx, Ny, 1) where Nx is the image width and Ny is the image height.
                    mode='trilinear', align_corners=True),
                # 1x1x1 convolution to reduce the number of channels while keeping the size the same
                nn.Conv3d(
                    in_channels, in_channels // 2,
                    kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0))
            )
        elif upsampling_mode == "transposed_convolution":
            # Normalise to a tuple so list-valued kernels compare equal in the check below.
            upsampling_kernel = tuple(upsampling_kernel)
            k = upsampling_kernel[0]  # Assume kernel has shape (k, k, 1)
            assert upsampling_kernel == (k, k, 1), f"Expected a flat square kernel like {(k, k, 1)}, got {upsampling_kernel}"
            stride = (2, 2, 1)  # Stride 2x2 to double each side
            # ConvTranspose output side = 2 * (N - 1) - 2p + k + output_padding.
            # With p = (k - 1) // 2, output_padding = k % 2 makes this exactly 2N for
            # every k (using output_padding = p only works for k = 2 and k = 3).
            padding = ((k - 1) // 2, (k - 1) // 2, 0)
            output_padding = (k % 2, k % 2, 0)
            self.upsampling = nn.ConvTranspose3d(
                in_channels, in_channels // 2,
                kernel_size=upsampling_kernel, stride=stride, padding=padding,
                output_padding=output_padding,
            )
        else:
            raise ValueError(f"Unsupported upsampling method: {upsampling_mode}")

        self.double_conv = DoubleConv(in_channels, in_channels // 2, n_dimensions, activation=activation)

    def forward(self, x: torch.Tensor, x_encoder_output: torch.Tensor):
        x = self.upsampling(x)
        x = torch.cat([x_encoder_output, x], dim=1)  # skip-connection. No cropping since we ensure that the size is the same.
        x = self.double_conv(x)
        return x



class UNet3d(nn.Module):
    def __init__(
            self, in_channels=1, out_channels=2, init_filters=32, n_blocks=3,
            activation="LeakyReLU",
            downsampling_kernel=(2, 2, 1), downsampling_mode="max",
            upsampling_kernel=(2, 2, 1), upsampling_mode="linear_interpolation",
    ):
        """
        Assume that input is 5D tensor of shape (batch_size, channels, Nx, Ny, Nt)
        where Nx is the image width and Ny is the image height.
        Assume that batch_size = 1, channels = 1, Nx = Ny (square image), Nt = 1 (static image).
        Nx and Ny must be multiples of 2**n_blocks so that every skip connection lines up.
        NOTE: The convention used in pytorch documentation is (batch_size, channels, Nt, Ny, Nx).
        "channels" is equivalent to the number of filters or features.

        Structure: c0x0 (DoubleConv, in_channels -> init_filters), n_blocks encoder
        blocks (each doubles the channels and halves Nx, Ny), n_blocks decoder blocks
        (each halves the channels and doubles Nx, Ny), c1x1 (1x1x1 conv,
        init_filters -> out_channels).

        Our paper (figure 2):
            - in_channels = 1
            - out_channels = 2
            - init_filters = 32
            - n_blocks = 3
            - pooling: max pooling 2x2
            - convolution padding = 1
                - padding 1 keeps the size of the "image" the same after each 3x3 convolution,
                  so the skip-connection does NOT need to crop the encoder's output.
            - upsampling kernel: 2x2 ?
            - up_mode: linear interpolation

        U-Net paper (2015, Olaf Ronneberger https://arxiv.org/abs/1505.04597):
            - in_channels = 1
            - out_channels = 2
            - init_filters = 64
            - n_blocks = 4
            - pooling: max pooling 2x2
            - convolution padding = 0
                - padding 0 shrinks the "image" by 2 in each dimension after each 3x3 convolution,
                  so the skip-connection has to crop the encoder's output to match the decoder's input.
            - upsampling kernel: 2x2
            - up_mode: ? (linear interpolation or transposed convolution)
        """
        super(UNet3d, self).__init__()

        # NOTE: attribute names (c0x0, encoder, decoder, c1x1) are part of the state_dict
        # keys of saved checkpoints; do not rename them.
        self.c0x0 = DoubleConv(
            in_channels=in_channels,
            out_channels=init_filters,
            activation=activation
        )
        self.encoder = nn.ModuleList([
            EncodeBlock3d(
                in_channels=init_filters * 2**i,
                activation=activation,
                downsampling_kernel=downsampling_kernel,
                downsampling_mode=downsampling_mode
            ) for i in range(n_blocks)
        ])
        self.decoder = nn.ModuleList([
            DecodeBlock3d(
                in_channels=init_filters * 2**(n_blocks-i),
                activation=activation,
                upsampling_kernel=upsampling_kernel,
                upsampling_mode=upsampling_mode
            ) for i in range(n_blocks)
        ])
        # Final 1x1x1 convolution mapping the features to the output channels.
        self.c1x1 = nn.Conv3d(
            in_channels=init_filters,
            out_channels=out_channels,
            kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0)
        )

    def forward(self, x: torch.Tensor):
        """Map a (1, 1, Nx, Ny, 1) tensor to a (1, out_channels, Nx, Ny, 1) tensor."""
        # NOTE: The convention used in pytorch documentation is (batch_size, channels, Nt, Ny, Nx).
        assert len(x.size()) == 5, f"Expected 5D tensor, got {x.size()}"
        batch_size, channels, Nx, Ny, Nt = x.size()
        assert channels == 1, f"Expected 1 channel, got {channels}"  # TODO: Allow multiple channels (colour images)
        assert Nx == Ny, f"Expected square image, got ({Nx}, {Ny})"  # TODO: Allow non-square images
        assert Nt == 1, f"Expected 1 time step, got {Nt}"  # TODO: Allow multiple time steps (dynamic images, video)
        assert batch_size == 1, f"Expected batch size 1, got {batch_size}"  # TODO: Might train with larger batch size

        n_blocks = len(self.encoder)
        factor = 2 ** n_blocks
        # Each encoder stage halves Nx and Ny and each decoder stage doubles them, so the
        # skip-connection concatenation only lines up if no intermediate size is odd,
        # i.e. Nx and Ny must be multiples of 2**n_blocks (">= 2**n_blocks" is not enough).
        assert Nx >= factor and Nx % factor == 0, f"Expected width (Nx) to be a positive multiple of {factor}, got {Nx}"
        assert Ny >= factor and Ny % factor == 0, f"Expected height (Ny) to be a positive multiple of {factor}, got {Ny}"

        x = self.c0x0(x)

        # Remember the input of every encoder stage for the matching decoder stage.
        encoder_outputs = []
        for enc_block in self.encoder:
            encoder_outputs.append(x)
            x = enc_block(x)
        # Decoder stages consume the skip connections deepest-first.
        for dec_block, skip in zip(self.decoder, reversed(encoder_outputs)):
            x = dec_block(x, skip)

        return self.c1x1(x)

In [39]:
def assert_and_clear_cuda(expected, actual):
    """Check ``expected == actual`` without raising, for the shape checks below.

    On a mismatch an error message is printed and the CUDA caching allocator is asked to
    release cached memory, so one failing check does not abort the remaining ones. An
    explicit ``if`` is used instead of ``assert`` so the check still runs under
    ``python -O``.

    Returns:
        True if the values are equal, False otherwise.
    """
    if expected == actual:
        return True
    print(f"!!! ERROR !!! Expected: {expected}, got {actual}")
    # empty_cache() records no autograd ops, so no torch.no_grad() is needed; it is
    # safe to call when CUDA is not in use.
    torch.cuda.empty_cache()
    return False
    

def test_unet_3d():
    """Smoke-test output shapes of UNet3d and of the torch layers it is built from.

    Skipped when ``DISABLING_TESTS`` is true. Mismatches are reported (printed, not
    raised) by ``assert_and_clear_cuda``.
    """
    if DISABLING_TESTS:
        return

    # Only shapes are checked, so no autograd graph is needed. Without no_grad every
    # 512x512 activation (up to 128 channels) would be kept alive for a backward pass
    # that never happens.
    with torch.no_grad():
        input_tensor = torch.randn(1, 1, 512, 512, 1)  # batch size of 1, 1 channel, 512x512x1 volume

        config = get_config()

        model = UNet3d(
            init_filters=32,
            n_blocks=config["n_blocks"],
            activation="ReLU",
            downsampling_kernel=(2, 2, 1),
            downsampling_mode=config["downsampling_mode"],
            upsampling_kernel=(2, 2, 1),
            upsampling_mode=config["upsampling_mode"],
        )
        output = model(input_tensor)
        print(f"UNet output shape: {output.shape}")
        assert_and_clear_cuda((1, 2, 512, 512, 1), output.shape)

        conv_3d = nn.Conv3d(1, 64, kernel_size=3, stride=1, padding=1)
        conv_3d_output = conv_3d(input_tensor)
        print(f"Conv3d output shape: {conv_3d_output.shape}")
        assert_and_clear_cuda((1, 64, 512, 512, 1), conv_3d_output.shape)

        double_conv_3d = DoubleConv(64, 128)
        double_conv_output = double_conv_3d(conv_3d_output)
        print(f"{DoubleConv.__name__} output shape: {double_conv_output.shape}")
        assert_and_clear_cuda((1, 128, 512, 512, 1), double_conv_output.shape)

        max_3d = nn.MaxPool3d((3, 3, 1), stride=(2, 2, 1), padding=(1, 1, 0))
        max_3d_output_1 = max_3d(input_tensor)
        print(f"MaxPool3d output 1 shape: {max_3d_output_1.shape}")
        assert_and_clear_cuda((1, 1, 256, 256, 1), max_3d_output_1.shape)

        max_3d_input = torch.randn(1, 128, 512, 512, 1)
        max_3d_output_2 = max_3d(max_3d_input)
        print(f"MaxPool3d output 2 shape: {max_3d_output_2.shape}")
        assert_and_clear_cuda((1, 128, 256, 256, 1), max_3d_output_2.shape)

        conv_transpose_3d = nn.ConvTranspose3d(
            1024, 512,
            kernel_size=(3, 3, 1),
            stride=(2, 2, 1),
            padding=(1, 1, 0),
            output_padding=(1, 1, 0)
        )
        conv_transpose_3d_input = torch.randn(1, 1024, 32, 32, 1)
        conv_transpose_3d_output = conv_transpose_3d(conv_transpose_3d_input)
        print(f"ConvTranspose3d output shape: {conv_transpose_3d_output.shape}")
        assert_and_clear_cuda((1, 512, 64, 64, 1), conv_transpose_3d_output.shape)

        up_sample = nn.Upsample(
            scale_factor=(2, 2, 1),
            mode='trilinear', align_corners=True)
        up_sample_output = up_sample(input_tensor)
        print(f"Upsample output shape: {up_sample_output.shape}")
        assert_and_clear_cuda((1, 1, 1024, 1024, 1), up_sample_output.shape)

    torch.cuda.empty_cache()

test_unet_3d()
# Release cached GPU memory left over from the test (no autograd involved, so no no_grad needed).
torch.cuda.empty_cache()

In [40]:
# Hand cached-but-unused GPU memory back to the driver. empty_cache() records no
# autograd operations, so wrapping it in torch.no_grad() makes no difference.
torch.cuda.empty_cache()

------

### Create data loader

In [41]:
# Code adapted from ...

def _make_split_dataset(config, split):
    """Build the denoising dataset for one split: "train", "val" or "test"."""
    num_samples = config[f"{split}_num_samples"]
    return DynamicImageStaticDenoisingDataset(
        data_path=config[f"{split}_data_path"],
        ids=list(range(num_samples)),
        sigma=(config["min_sigma"], config["max_sigma"]),
        resize_square=config["resize_square"],
        device=config["device"],
    )


def get_datasets(config):
    """Create the training, validation and test datasets described by ``config``.

    For each split ``s`` in ("train", "val", "test") the keys ``{s}_data_path`` and
    ``{s}_num_samples`` select the images (ids ``0 .. {s}_num_samples - 1``). The keys
    ``min_sigma``, ``max_sigma``, ``resize_square`` and ``device`` are shared by all
    splits.

    Returns:
        ``(dataset_train, dataset_valid, dataset_test)``
    """
    dataset_train = _make_split_dataset(config, "train")
    dataset_valid = _make_split_dataset(config, "val")
    dataset_test = _make_split_dataset(config, "test")

    print(f"Number of training samples: {len(dataset_train)}")
    print(f"Number of validation samples: {len(dataset_valid)}")
    print(f"Number of test samples: {len(dataset_test)}")

    return dataset_train, dataset_valid, dataset_test



def get_dataloaders(config):
    """Create the train / validation / test DataLoaders described by ``config``.

    Uses ``config["batch_size"]``. Every loader gets its own ``torch.Generator`` on
    ``config["device"]``, seeded with ``config["random_seed"]``, so the training shuffle
    order is reproducible. Only the training loader shuffles. Weighted sampling (a
    ``WeightedRandomSampler`` over ``dataset.samples_weights``) is currently disabled.

    Returns:
        ``(dataloader_train, dataloader_valid, dataloader_test)``
    """
    dataset_train, dataset_valid, dataset_test = get_datasets(config)
    batch_size = config["batch_size"]
    device = config["device"]
    random_seed = config["random_seed"]

    def make_loader(dataset, shuffle):
        # One freshly seeded generator per loader, as before.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size,
            generator=torch.Generator(device=device).manual_seed(random_seed),
            shuffle=shuffle,
        )

    return (
        make_loader(dataset_train, shuffle=True),
        make_loader(dataset_valid, shuffle=False),
        make_loader(dataset_test, shuffle=False),
    )

#### Test data loader

In [42]:
def test_dataloader():
    """Visual check: plot the first six (noisy, clean) training pairs side by side.

    Skipped when ``DISABLING_TESTS`` is true. The squeezes drop the batch, channel and
    time axes, so this assumes a batch size of 1.
    """
    if DISABLING_TESTS:
        return
    loader_train, loader_valid, loader_test = get_dataloaders(get_config())
    for batch_number, (noisy, clean) in enumerate(loader_train, start=1):
        print(f"Batch {batch_number}")
        print(f"x size: {noisy.size()}")
        print(f"y size: {clean.size()}")
        for panel, picture in enumerate((noisy, clean), start=1):
            plt.subplot(1, 2, panel)
            plt.imshow(picture.squeeze(0).squeeze(0).squeeze(-1).to("cpu"), cmap='gray')
            plt.axis('off')
        plt.show()
        if batch_number == 6:
            break

    del loader_train, loader_valid, loader_test

test_dataloader()

------

### Training

#### Code for one epoch

In [43]:
# Code taken from https://www.github.com/koflera/LearningRegularizationParameterMaps

# def train_iteration(optimizer, model, loss_func, sample):
#     optimizer.zero_grad(set_to_none=True)  # Zero your gradients for every batch!
#     noisy_image, clean_image = sample
#     denoised_image = model(noisy_image)
#     loss = loss_func(denoised_image, clean_image)
#     loss.backward()
    
#     if loss.item() != loss.item():
#         raise ValueError("NaN returned by loss function...")

#     optimizer.step()

#     denoised_image = denoised_image.squeeze(0).squeeze(0).squeeze(-1)
#     clean_image = clean_image.squeeze(0).squeeze(0).squeeze(-1)

#     psnr = PSNR(denoised_image, clean_image)
#     ssim = SSIM(denoised_image, clean_image)

#     return loss.item(), psnr, ssim


def train_epoch(model, data_loader, optimizer, loss_func) -> "tuple[float, float, float]":
    """Run one optimisation pass over ``data_loader``.

    For every batch: clear gradients, denoise the noisy image, compute
    ``loss_func(denoised, clean)``, back-propagate and take one optimiser step. PSNR and
    SSIM between the denoised and clean images are accumulated alongside the loss.

    Args:
        model: Module mapping a noisy 5D image tensor to a denoised one.
        data_loader: Yields ``(noisy_image_5d, clean_image_5d)`` pairs. The metric code
            squeezes the batch, channel and time axes, so it assumes batch size 1.
        optimizer: Optimiser over ``model``'s parameters.
        loss_func: Loss called as ``loss_func(prediction, target)``.

    Returns:
        ``(avg_loss, avg_psnr, avg_ssim)`` averaged over the batches.

    Raises:
        ValueError: if the loss is NaN. The check runs before any backward pass or
            parameter update for that batch.
    """
    running_loss = 0.
    running_psnr = 0.
    running_ssim = 0.
    num_batches = len(data_loader)

    for noisy_image_5d, clean_image_5d in data_loader:
        # Gradients accumulate in .grad across backward() calls, so clear them for every
        # batch (set_to_none frees the tensors instead of zero-filling them).
        optimizer.zero_grad(set_to_none=True)
        denoised_image_5d = model(noisy_image_5d)
        loss = loss_func(denoised_image_5d, clean_image_5d)

        loss_value = loss.item()  # one host/device sync per batch
        if loss_value != loss_value:  # NaN is the only value that is not equal to itself
            raise ValueError("NaN returned by loss function...")
        loss.backward()
        optimizer.step()

        # Metrics need no gradients; detach so they don't keep references to the graph.
        denoised_image_2d = denoised_image_5d.detach().squeeze(0).squeeze(0).squeeze(-1)
        clean_image_2d = clean_image_5d.squeeze(0).squeeze(0).squeeze(-1)
        running_psnr += PSNR(denoised_image_2d, clean_image_2d)
        running_ssim += SSIM(denoised_image_2d, clean_image_2d)
        running_loss += loss_value

    return (
        running_loss / num_batches,
        running_psnr / num_batches,
        running_ssim / num_batches,
    )


# def validate_iteration(model, loss_func, sample):
#     noisy_image, clean_image = sample
#     denoised_image = model(noisy_image)
#     loss = loss_func(denoised_image, clean_image)
#     denoised_image = denoised_image.squeeze(0).squeeze(0).squeeze(-1)
#     clean_image = clean_image.squeeze(0).squeeze(0).squeeze(-1)

#     psnr = PSNR(denoised_image, clean_image)
#     ssim = SSIM(denoised_image, clean_image)

#     return loss.item(), psnr, ssim


def validate_epoch(model, data_loader, loss_func) -> "tuple[float, float, float]":
    """Evaluate ``model`` on ``data_loader`` without updating it.

    Computes ``loss_func(denoised, clean)``, PSNR and SSIM for every batch under
    ``torch.no_grad()``. The model's train/eval mode is left unchanged; that is the
    caller's decision.

    Args:
        model: Module mapping a noisy 5D image tensor to a denoised one.
        data_loader: Yields ``(noisy_image_5d, clean_image_5d)`` pairs. The metric code
            squeezes the batch, channel and time axes, so it assumes batch size 1.
        loss_func: Loss called as ``loss_func(prediction, target)``.

    Returns:
        ``(avg_loss, avg_psnr, avg_ssim)`` averaged over the batches.
    """
    running_loss = 0.
    running_psnr = 0.
    running_ssim = 0.
    num_batches = len(data_loader)

    # No backward pass happens during validation. Without no_grad every forward pass
    # would record an autograd graph and keep its intermediate tensors alive for nothing.
    with torch.no_grad():
        for noisy_image_5d, clean_image_5d in data_loader:
            denoised_image_5d = model(noisy_image_5d)
            loss = loss_func(denoised_image_5d, clean_image_5d)

            denoised_image_2d = denoised_image_5d.squeeze(0).squeeze(0).squeeze(-1)
            clean_image_2d = clean_image_5d.squeeze(0).squeeze(0).squeeze(-1)
            running_psnr += PSNR(denoised_image_2d, clean_image_2d)
            running_ssim += SSIM(denoised_image_2d, clean_image_2d)
            running_loss += loss.item()

    return (
        running_loss / num_batches,
        running_psnr / num_batches,
        running_ssim / num_batches,
    )


#### Prep for training

In [44]:
# Free cached GPU memory before training starts; empty_cache() is not an autograd
# operation, so it does not need a torch.no_grad() context.
torch.cuda.empty_cache()

In [45]:
def save_image(image_tensor_2D, image_name, folder_name):
    """Save a 2D image tensor with values in [0, 1] as an 8-bit greyscale PNG.

    Values are clipped to [0, 1], scaled to [0, 255] and rounded to the nearest integer.
    The result is written to ``{folder_name}/{image_name}.png``; ``folder_name`` is
    created if it does not exist.
    """
    image_tensor_2D = torch.clamp(image_tensor_2D, 0, 1)  # Clip the values to 0 and 1
    image_numpy = image_tensor_2D.detach().to("cpu").numpy()
    # Round rather than truncate. astype(np.uint8) alone floors, so a pixel loaded as
    # k/255 that comes back as k - 1e-7 after float arithmetic would be saved as k - 1.
    image_uint8 = np.rint(image_numpy * 255).astype(np.uint8)
    os.makedirs(folder_name, exist_ok=True)
    Image.fromarray(image_uint8).save(f"{folder_name}/{image_name}.png")

In [46]:
def make_testcase():
    """Write a clean/noisy chest X-ray pair to ``testcases_tmp/`` and display both.

    One training X-ray is cropped/resized to 512x512 and saved as
    ``chest_xray_clean.png``. Noise with ``sigma=0.5`` is then added and the result saved
    as ``chest_xray_noisy.png``. Both PNGs are read back and plotted side by side.
    """
    CHEST_XRAY_BASE_DATA_PATH = "../data/chest_xray"
    folder_name = "testcases_tmp"
    image_name = "chest_xray_clean"
    noisy_image_name = "chest_xray_noisy"

    image = Image.open(f"{CHEST_XRAY_BASE_DATA_PATH}/train/NORMAL/IM-0115-0001.jpeg")
    image = crop_to_square_and_resize(image, 512)

    # PIL's save() does not create missing directories.
    os.makedirs(folder_name, exist_ok=True)
    image.save(f"{folder_name}/{image_name}.png")

    # Add noise to the image
    noisy_image_tensor_4D = add_noise(convert_to_tensor_4D(convert_to_numpy(image)), sigma=0.5)
    noisy_image_tensor_2D = noisy_image_tensor_4D.squeeze(0).squeeze(-1)
    save_image(noisy_image_tensor_2D, noisy_image_name, folder_name)

    # Read the saved images back so we display exactly what is on disk.
    clean_image = Image.open(f"{folder_name}/{image_name}.png")
    noisy_image = Image.open(f"{folder_name}/{noisy_image_name}.png")

    plt.subplot(1, 2, 1)
    plt.imshow(clean_image, cmap='gray')
    plt.axis('off')
    plt.title("Clean Image")
    plt.subplot(1, 2, 2)
    plt.imshow(noisy_image, cmap='gray')
    plt.axis('off')
    plt.title("Noisy Image")
    plt.show()

# make_testcase()

#### Optional: Use wandb to log the training process

In [47]:
# Optional: Use wandb to log the training process
# !wandb login
def init_wandb(config):
    """Log in to Weights & Biases and start a run for the given training ``config``.

    ``config["project"]`` is used both as the wandb project and as
    ``WANDB_NOTEBOOK_NAME``. ``config["wandb_mode"]`` is exported as ``WANDB_MODE``
    (https://docs.wandb.ai/quickstart). The ``config`` passed in is recorded as the run's
    hyperparameters, so the logged values match the ones actually used for training.
    """
    project_name = config["project"]
    os.environ['WANDB_NOTEBOOK_NAME'] = project_name
    os.environ['WANDB_MODE'] = config["wandb_mode"]  # https://docs.wandb.ai/quickstart
    wandb.login()
    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,

        # track the hyperparameters actually in use (not a freshly built default config)
        config=config,
    )

In [48]:
# def temp_log_to_files():
#     model_states_dir = "tmp_2/model-chest_xray-2024_06_06_05_51_27"
#     config = get_config()
#     with open(f"{model_states_dir}/config.json", "w") as f:
#         json.dump(config, f, indent=4)
#     with open(f"{model_states_dir}/config.yaml", "w") as f:
#         yaml.dump(config, f)
#     with open(f"{model_states_dir}/config.txt", "w") as f:
#         f.write(str(config))

# def test_temp_log_to_files():
#     temp_log_to_files()

# test_temp_log_to_files()

#### Start training

In [49]:
# Code adapted from https://www.github.com/koflera/LearningRegularizationParameterMaps


def _build_pdhg(config, pretrained_model_path=None, is_state_dict=False):
    """Build (or load) the primal-dual network that will be trained.

    Args:
        config: Run configuration dict (network hyperparameters are read from it).
        pretrained_model_path: Optional path to a checkpoint, with or without the
            ``.pt`` suffix. ``None`` means "start from freshly initialised weights".
        is_state_dict: If True, the checkpoint is a ``state_dict`` loaded into a
            freshly built network; otherwise it is a whole pickled model.

    Returns:
        ``(pdhg, unet)`` where ``unet`` is the freshly built CNN block, or ``None``
        when a whole pickled model was loaded instead.
    """
    checkpoint_path = None
    if pretrained_model_path is not None:
        checkpoint_path = str(pretrained_model_path)
        if not checkpoint_path.endswith(".pt"):
            checkpoint_path = f"{checkpoint_path}.pt"

    if checkpoint_path is not None and not is_state_dict:
        # NOTE(security): torch.load unpickles arbitrary objects -- only load
        # checkpoints you produced yourself / trust.
        pdhg = torch.load(checkpoint_path, map_location=DEVICE).to(DEVICE)
        return pdhg, None

    # Define CNN block
    unet = UNet3d(
        in_channels=config["in_channels"],
        out_channels=config["out_channels"],
        init_filters=config["init_filters"],
        n_blocks=config["n_blocks"],
        activation=config["activation"],
        downsampling_kernel=config["downsampling_kernel"],
        downsampling_mode=config["downsampling_mode"],
        upsampling_kernel=config["upsampling_kernel"],
        upsampling_mode=config["upsampling_mode"],
    ).to(DEVICE)

    # Construct primal-dual operator with nn
    pdhg = DynamicImageStaticPrimalDualNN(
        cnn_block=unet,
        T=config["T"],
        phase="training",
        up_bound=config["up_bound"],
    ).to(DEVICE)

    if checkpoint_path is not None:
        pdhg.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    return pdhg, unet


def _write_dataloader_summary(dataloader, stage, model_states_dir):
    """Write a human-readable summary of ``dataloader`` to ``dataloader_{stage}.txt``."""
    dataset = dataloader.dataset
    # Fetch sample 0 once so every line below describes the same sample
    # (and the dataset's __getitem__ is not run three times).
    first_sample = dataset[0]
    with open(f"{model_states_dir}/dataloader_{stage}.txt", "w") as f:
        f.write(f"Batch size: {dataloader.batch_size}\n\n")
        f.write(f"Number of batches: {len(dataloader)}\n\n")
        f.write(f"Number of samples: {len(dataset)}\n\n")
        f.write(f"Samples weights:\n{str(dataset.samples_weights)}\n\n")
        f.write(f"Sample 0 size:\n{str(len(first_sample))}  {str(first_sample[0].size())}\n\n")
        f.write(f"Sample 0:\n{str(first_sample)}\n\n")


def _write_run_files(model_states_dir, config, unet, pdhg, dataloader_train, dataloader_valid):
    """Dump the config, model architecture and data summaries into ``model_states_dir``."""
    with open(f"{model_states_dir}/config.json", "w") as f:
        json.dump(config, f, indent=4)
    with open(f"{model_states_dir}/config.yaml", "w") as f:
        yaml.dump(config, f)
    with open(f"{model_states_dir}/config.txt", "w") as f:
        f.write(str(config))
    # unet is None when a whole pickled model was loaded; its architecture is
    # then only available through pdhg's repr below.
    if unet is not None:
        with open(f"{model_states_dir}/unet.txt", "w") as f:
            f.write(str(unet))
    with open(f"{model_states_dir}/pdhg_net.txt", "w") as f:
        f.write(str(pdhg))

    _write_dataloader_summary(dataloader_train, "train", model_states_dir)
    _write_dataloader_summary(dataloader_valid, "val", model_states_dir)


def start_training(config, pretrained_model_path=None, is_state_dict=False, start_epoch=0):
    """Train the primal-dual network and return it.

    Each call writes its outputs to a new timestamped directory
    ``{config["save_dir"]}/model-{config["project"]}-{timestamp}``. That directory
    receives the config/architecture dumps, periodic checkpoints
    ``model_epoch_{k}.pt`` and ``final_model.pt``. Metrics and checkpoints are
    also logged to wandb.

    Args:
        config: Run configuration dict (see ``get_config``).
        pretrained_model_path: Optional path to a checkpoint to start from, with
            or without the ``.pt`` suffix (e.g. a ``model_epoch_k`` file from an
            earlier run directory).
        is_state_dict: If True, the checkpoint holds a ``state_dict``; otherwise
            it holds a whole pickled model.
        start_epoch: 0-based epoch index to resume counting from.

    Returns:
        The trained network.
    """
    dataloader_train, dataloader_valid, dataloader_test = get_dataloaders(config)
    del dataloader_test  # Not used for now

    pdhg, unet = _build_pdhg(config, pretrained_model_path, is_state_dict)
    pdhg.train(True)

    # TODO: Sometimes, creating the optimizer gives this error:
    #   AttributeError: partially initialized module 'torch._dynamo' has no attribute 'trace_rules' (most likely due to a circular import)
    optimizer = torch.optim.Adam(pdhg.parameters(), lr=config["learning_rate"])
    loss_function = torch.nn.MSELoss()

    num_epochs = config["epochs"]
    save_epoch_local = config["save_epoch_local"]
    save_epoch_wandb = config["save_epoch_wandb"]

    # Prepare a fresh, timestamped directory for this run's outputs.
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    model_name = f"model-{config['project']}-{timestamp}"
    model_states_dir = f"{config['save_dir']}/{model_name}"
    os.makedirs(model_states_dir, exist_ok=True)

    _write_run_files(model_states_dir, config, unet, pdhg, dataloader_train, dataloader_valid)

    init_wandb(config)
    try:
        for epoch in tqdm(range(start_epoch, num_epochs)):
            epoch_number = epoch + 1

            # Model training
            pdhg.train(True)
            training_loss, training_psnr, training_ssim = train_epoch(pdhg, dataloader_train, optimizer, loss_function)

            # Model validation
            pdhg.train(False)
            torch.cuda.empty_cache()
            with torch.no_grad():
                validation_loss, validation_psnr, validation_ssim = validate_epoch(pdhg, dataloader_valid, loss_function)
            torch.cuda.empty_cache()

            # Log all metrics in ONE call: every wandb.log call advances wandb's
            # step counter, so separate calls would spread one epoch over 6 steps.
            wandb.log({
                "epoch": epoch_number,
                "training_loss": training_loss,
                "training PSNR": training_psnr,
                "training SSIM": training_ssim,
                "validation_loss": validation_loss,
                "validation PSNR": validation_psnr,
                "validation SSIM": validation_ssim,
            })

            is_local_save = epoch_number % save_epoch_local == 0
            is_wandb_save = epoch_number % save_epoch_wandb == 0
            checkpoint_name = f"model_epoch_{epoch_number}"
            checkpoint_path = f"{model_states_dir}/{checkpoint_name}.pt"

            # The wandb upload needs this epoch's file on disk, so save whenever
            # either schedule fires (previously a wandb-only epoch crashed or
            # uploaded a stale checkpoint).
            if is_local_save or is_wandb_save:
                torch.save(pdhg, checkpoint_path)

            if is_local_save:
                print(f"Epoch {epoch_number} - VALIDATION LOSS: {validation_loss} - VALIDATION PSNR: {validation_psnr} - VALIDATION SSIM: {validation_ssim}")

            if is_wandb_save:
                wandb.log_model(checkpoint_path, name=checkpoint_name)

        # Save the entire model
        final_model_path = f"{model_states_dir}/final_model.pt"
        torch.save(pdhg, final_model_path)
        wandb.log_model(final_model_path, name="final_model")
    finally:
        # Close the wandb run even if training is interrupted (e.g. KeyboardInterrupt)
        # so it is not left open in the notebook kernel.
        wandb.finish()
        torch.cuda.empty_cache()

    return pdhg

In [50]:
# Release cached GPU memory on both sides of the long training run.
# (empty_cache is unaffected by autograd mode, so no no_grad() context is needed.)
torch.cuda.empty_cache()
pdhg = start_training(get_config())
torch.cuda.empty_cache()

100%|██████████| 200/200 [00:00<00:00, 531.38it/s]
100%|██████████| 8/8 [00:00<00:00, 422.29it/s]
100%|██████████| 1/1 [00:00<00:00, 156.53it/s]


Number of training samples: 200
Number of validation samples: 8
Number of test samples: 1


[34m[1mwandb[0m: Currently logged in as: [33mtrung-vuthanh24[0m ([33mwof[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 2/10000 [02:31<208:38:58, 75.13s/it]

Epoch 2 - VALIDATION LOSS: 0.000529649099917151 - VALIDATION PSNR: 33.54196548461914 - VALIDATION SSIM: 0.8761545017737343


  0%|          | 4/10000 [04:55<203:00:15, 73.11s/it]

Epoch 4 - VALIDATION LOSS: 0.0005272753405733965 - VALIDATION PSNR: 33.12583923339844 - VALIDATION SSIM: 0.8636786973334103


  0%|          | 6/10000 [07:16<198:18:08, 71.43s/it]

Epoch 6 - VALIDATION LOSS: 0.0005482014385052025 - VALIDATION PSNR: 32.74043273925781 - VALIDATION SSIM: 0.8676025050788522


  0%|          | 8/10000 [09:41<200:18:53, 72.17s/it]

Epoch 8 - VALIDATION LOSS: 0.0005328176594048273 - VALIDATION PSNR: 33.14311218261719 - VALIDATION SSIM: 0.8621784463845195


  0%|          | 9/10000 [10:53<199:34:33, 71.91s/it]

Epoch 10 - VALIDATION LOSS: 0.000640691021544626 - VALIDATION PSNR: 32.56898498535156 - VALIDATION SSIM: 0.8432286367046387


  0%|          | 12/10000 [14:29<199:39:52, 71.97s/it]

Epoch 12 - VALIDATION LOSS: 0.0006608612602576613 - VALIDATION PSNR: 32.43912124633789 - VALIDATION SSIM: 0.8413247543117851


  0%|          | 14/10000 [16:54<200:21:55, 72.23s/it]

Epoch 14 - VALIDATION LOSS: 0.0004593834310071543 - VALIDATION PSNR: 33.759403228759766 - VALIDATION SSIM: 0.8813983966838718


  0%|          | 16/10000 [19:13<196:04:15, 70.70s/it]

Epoch 16 - VALIDATION LOSS: 0.00041007623440236785 - VALIDATION PSNR: 34.25713348388672 - VALIDATION SSIM: 0.8907237733704746


  0%|          | 18/10000 [21:36<197:01:59, 71.06s/it]

Epoch 18 - VALIDATION LOSS: 0.000559211395739112 - VALIDATION PSNR: 33.036312103271484 - VALIDATION SSIM: 0.8685357870541216


  0%|          | 19/10000 [22:49<198:09:47, 71.47s/it]

Epoch 20 - VALIDATION LOSS: 0.00045891201807535253 - VALIDATION PSNR: 34.292633056640625 - VALIDATION SSIM: 0.8914911675924136


  0%|          | 22/10000 [26:22<197:18:09, 71.19s/it]

Epoch 22 - VALIDATION LOSS: 0.00048368175157520454 - VALIDATION PSNR: 33.88203811645508 - VALIDATION SSIM: 0.8823628774550856


  0%|          | 24/10000 [28:47<199:18:07, 71.92s/it]

Epoch 24 - VALIDATION LOSS: 0.0003871411936415825 - VALIDATION PSNR: 34.71660614013672 - VALIDATION SSIM: 0.899261038973987


  0%|          | 26/10000 [31:12<200:47:32, 72.47s/it]

Epoch 26 - VALIDATION LOSS: 0.00040500891373085324 - VALIDATION PSNR: 34.7791862487793 - VALIDATION SSIM: 0.8987308402458132


  0%|          | 28/10000 [33:39<202:09:24, 72.98s/it]

Epoch 28 - VALIDATION LOSS: 0.00035001253309019376 - VALIDATION PSNR: 34.83082580566406 - VALIDATION SSIM: 0.9062612866925299


  0%|          | 29/10000 [34:51<201:11:07, 72.64s/it]

Epoch 30 - VALIDATION LOSS: 0.00037827111555088777 - VALIDATION PSNR: 34.88362121582031 - VALIDATION SSIM: 0.9053624335483611


  0%|          | 32/10000 [38:26<199:13:18, 71.95s/it]

Epoch 32 - VALIDATION LOSS: 0.00043932616608799435 - VALIDATION PSNR: 34.39997863769531 - VALIDATION SSIM: 0.8941261506207436


  0%|          | 34/10000 [40:54<201:44:49, 72.88s/it]

Epoch 34 - VALIDATION LOSS: 0.0002972186866827542 - VALIDATION PSNR: 35.72433853149414 - VALIDATION SSIM: 0.9175557815672457


  0%|          | 36/10000 [43:17<199:09:38, 71.96s/it]

Epoch 36 - VALIDATION LOSS: 0.00027310905534250196 - VALIDATION PSNR: 36.071781158447266 - VALIDATION SSIM: 0.9248201044979394


  0%|          | 38/10000 [45:41<199:12:27, 71.99s/it]

Epoch 38 - VALIDATION LOSS: 0.0003665745571197476 - VALIDATION PSNR: 34.615421295166016 - VALIDATION SSIM: 0.906040623887837


  0%|          | 39/10000 [46:50<196:50:38, 71.14s/it]

Epoch 40 - VALIDATION LOSS: 0.0004912187614536379 - VALIDATION PSNR: 33.435672760009766 - VALIDATION SSIM: 0.8811187613446116


  0%|          | 42/10000 [50:21<193:54:58, 70.10s/it]

Epoch 42 - VALIDATION LOSS: 0.00038893837700015865 - VALIDATION PSNR: 34.494598388671875 - VALIDATION SSIM: 0.9018046757439078


  0%|          | 44/10000 [52:43<195:29:18, 70.69s/it]

Epoch 44 - VALIDATION LOSS: 0.00036939786514267325 - VALIDATION PSNR: 34.74403762817383 - VALIDATION SSIM: 0.9039224660063983


  0%|          | 46/10000 [55:05<195:50:50, 70.83s/it]

Epoch 46 - VALIDATION LOSS: 0.00030232896460802294 - VALIDATION PSNR: 35.81128692626953 - VALIDATION SSIM: 0.9209830250827672


  0%|          | 48/10000 [57:26<195:24:12, 70.68s/it]

Epoch 48 - VALIDATION LOSS: 0.00034077424788847566 - VALIDATION PSNR: 34.9600830078125 - VALIDATION SSIM: 0.9119537046234012


  0%|          | 49/10000 [58:34<193:04:32, 69.85s/it]

Epoch 50 - VALIDATION LOSS: 0.00021923283020441886 - VALIDATION PSNR: 37.046287536621094 - VALIDATION SSIM: 0.9386927822607756


  1%|          | 52/10000 [1:02:02<191:53:40, 69.44s/it]

Epoch 52 - VALIDATION LOSS: 0.0003673026949400082 - VALIDATION PSNR: 35.33648681640625 - VALIDATION SSIM: 0.910733312981695


  1%|          | 54/10000 [1:04:26<195:24:47, 70.73s/it]

Epoch 54 - VALIDATION LOSS: 0.00029702629217354115 - VALIDATION PSNR: 35.50544357299805 - VALIDATION SSIM: 0.9196727258666456


  1%|          | 56/10000 [1:06:48<196:07:28, 71.00s/it]

Epoch 56 - VALIDATION LOSS: 0.0002957500983029604 - VALIDATION PSNR: 35.736873626708984 - VALIDATION SSIM: 0.9229258887722791


  1%|          | 58/10000 [1:09:08<194:19:23, 70.36s/it]

Epoch 58 - VALIDATION LOSS: 0.0002680224788491614 - VALIDATION PSNR: 36.27826690673828 - VALIDATION SSIM: 0.928620468129456


  1%|          | 59/10000 [1:10:16<192:12:16, 69.60s/it]

Epoch 60 - VALIDATION LOSS: 0.0003491407851470285 - VALIDATION PSNR: 35.05546569824219 - VALIDATION SSIM: 0.9118540765836238


  1%|          | 62/10000 [1:13:49<194:20:04, 70.40s/it]

Epoch 62 - VALIDATION LOSS: 0.00027957662859989796 - VALIDATION PSNR: 35.796905517578125 - VALIDATION SSIM: 0.9246150663497822


  1%|          | 64/10000 [1:16:13<196:10:23, 71.08s/it]

Epoch 64 - VALIDATION LOSS: 0.0003529121859173756 - VALIDATION PSNR: 34.927581787109375 - VALIDATION SSIM: 0.9108519602858127


  1%|          | 66/10000 [1:18:32<193:51:35, 70.25s/it]

Epoch 66 - VALIDATION LOSS: 0.00035089821722067427 - VALIDATION PSNR: 34.998321533203125 - VALIDATION SSIM: 0.9105672003024518


  1%|          | 68/10000 [1:20:51<192:28:45, 69.77s/it]

Epoch 68 - VALIDATION LOSS: 0.000347106069966685 - VALIDATION PSNR: 34.970523834228516 - VALIDATION SSIM: 0.9122961050192266


  1%|          | 69/10000 [1:22:03<194:27:31, 70.49s/it]

Epoch 70 - VALIDATION LOSS: 0.00023465396861865884 - VALIDATION PSNR: 36.85015869140625 - VALIDATION SSIM: 0.9356771672672928


  1%|          | 72/10000 [1:25:32<193:45:58, 70.26s/it]

Epoch 72 - VALIDATION LOSS: 0.00040097639794112183 - VALIDATION PSNR: 34.385955810546875 - VALIDATION SSIM: 0.9022229073853791


  1%|          | 74/10000 [1:27:53<194:32:11, 70.56s/it]

Epoch 74 - VALIDATION LOSS: 0.00023762480486766435 - VALIDATION PSNR: 36.63970184326172 - VALIDATION SSIM: 0.935305886307925


  1%|          | 76/10000 [1:30:15<194:36:24, 70.59s/it]

Epoch 76 - VALIDATION LOSS: 0.0003940462302125525 - VALIDATION PSNR: 34.74559783935547 - VALIDATION SSIM: 0.9054066377715915


  1%|          | 78/10000 [1:32:38<196:27:44, 71.28s/it]

Epoch 78 - VALIDATION LOSS: 0.00034513271566538606 - VALIDATION PSNR: 34.963287353515625 - VALIDATION SSIM: 0.912837093700707


  1%|          | 79/10000 [1:33:49<196:41:55, 71.38s/it]

Epoch 80 - VALIDATION LOSS: 0.00026064098892675247 - VALIDATION PSNR: 35.9832649230957 - VALIDATION SSIM: 0.9302587318755686


  1%|          | 82/10000 [1:37:25<197:28:28, 71.68s/it]

Epoch 82 - VALIDATION LOSS: 0.000311141305246565 - VALIDATION PSNR: 35.959598541259766 - VALIDATION SSIM: 0.9215619608459173


  1%|          | 84/10000 [1:39:49<198:51:43, 72.20s/it]

Epoch 84 - VALIDATION LOSS: 0.00034039319143630564 - VALIDATION PSNR: 35.30596160888672 - VALIDATION SSIM: 0.9150253140468971


  1%|          | 86/10000 [1:42:14<199:05:58, 72.30s/it]

Epoch 86 - VALIDATION LOSS: 0.00031303641480917577 - VALIDATION PSNR: 35.61842727661133 - VALIDATION SSIM: 0.9201453929680586


  1%|          | 88/10000 [1:44:34<195:47:48, 71.11s/it]

Epoch 88 - VALIDATION LOSS: 0.00033145905763376504 - VALIDATION PSNR: 34.996612548828125 - VALIDATION SSIM: 0.9152106139618306


  1%|          | 89/10000 [1:45:43<194:07:15, 70.51s/it]

Epoch 90 - VALIDATION LOSS: 0.0002912766594818095 - VALIDATION PSNR: 35.82451248168945 - VALIDATION SSIM: 0.9247400425116421


  1%|          | 92/10000 [1:49:15<194:31:38, 70.68s/it]

Epoch 92 - VALIDATION LOSS: 0.0002547591002439731 - VALIDATION PSNR: 36.3867073059082 - VALIDATION SSIM: 0.9316996127523184


  1%|          | 94/10000 [1:51:32<191:48:18, 69.71s/it]

Epoch 94 - VALIDATION LOSS: 0.0003408956890780246 - VALIDATION PSNR: 35.07303237915039 - VALIDATION SSIM: 0.9142447826727926


  1%|          | 96/10000 [1:53:50<191:08:46, 69.48s/it]

Epoch 96 - VALIDATION LOSS: 0.00032813549933052855 - VALIDATION PSNR: 35.54044723510742 - VALIDATION SSIM: 0.9184089979340732


  1%|          | 98/10000 [1:56:13<194:07:46, 70.58s/it]

Epoch 98 - VALIDATION LOSS: 0.0003458922565187095 - VALIDATION PSNR: 35.56175994873047 - VALIDATION SSIM: 0.9153265869057029


  1%|          | 99/10000 [1:57:23<193:22:38, 70.31s/it]

Epoch 100 - VALIDATION LOSS: 0.0003343791941006202 - VALIDATION PSNR: 34.9278450012207 - VALIDATION SSIM: 0.9154135895027817


  1%|          | 102/10000 [2:00:56<194:38:58, 70.80s/it]

Epoch 102 - VALIDATION LOSS: 0.00033844957033579703 - VALIDATION PSNR: 35.00516128540039 - VALIDATION SSIM: 0.9148561459319889


  1%|          | 104/10000 [2:03:21<196:22:30, 71.44s/it]

Epoch 104 - VALIDATION LOSS: 0.0003360724585945718 - VALIDATION PSNR: 35.02655792236328 - VALIDATION SSIM: 0.915411765521109


  1%|          | 106/10000 [2:05:50<200:57:38, 73.12s/it]

Epoch 106 - VALIDATION LOSS: 0.0003593874935177155 - VALIDATION PSNR: 34.768558502197266 - VALIDATION SSIM: 0.9107075823293624


  1%|          | 108/10000 [2:08:12<198:08:52, 72.11s/it]

Epoch 108 - VALIDATION LOSS: 0.00034939427314384375 - VALIDATION PSNR: 35.13239288330078 - VALIDATION SSIM: 0.9133848748758138


  1%|          | 109/10000 [2:09:24<197:39:35, 71.94s/it]

Epoch 110 - VALIDATION LOSS: 0.0002870567222998943 - VALIDATION PSNR: 35.59452819824219 - VALIDATION SSIM: 0.9243994943804146


  1%|          | 112/10000 [2:12:55<193:41:56, 70.52s/it]

Epoch 112 - VALIDATION LOSS: 0.0003005559501616517 - VALIDATION PSNR: 36.033512115478516 - VALIDATION SSIM: 0.9252033938596547


  1%|          | 114/10000 [2:15:18<195:23:55, 71.15s/it]

Epoch 114 - VALIDATION LOSS: 0.00028121394279878587 - VALIDATION PSNR: 36.472042083740234 - VALIDATION SSIM: 0.9278268400105536


  1%|          | 116/10000 [2:17:36<192:31:51, 70.12s/it]

Epoch 116 - VALIDATION LOSS: 0.000359990015567746 - VALIDATION PSNR: 35.04943084716797 - VALIDATION SSIM: 0.9123132260185536


  1%|          | 118/10000 [2:19:59<194:16:33, 70.77s/it]

Epoch 118 - VALIDATION LOSS: 0.00029601975438708905 - VALIDATION PSNR: 35.634117126464844 - VALIDATION SSIM: 0.9244800361987352


  1%|          | 119/10000 [2:21:06<190:42:18, 69.48s/it]

Epoch 120 - VALIDATION LOSS: 0.0002469412029313389 - VALIDATION PSNR: 36.44025802612305 - VALIDATION SSIM: 0.9339513930444867


  1%|          | 122/10000 [2:24:37<191:33:39, 69.81s/it]

Epoch 122 - VALIDATION LOSS: 0.0002518455794415786 - VALIDATION PSNR: 36.671024322509766 - VALIDATION SSIM: 0.9335212541900874


  1%|          | 124/10000 [2:26:54<190:02:21, 69.27s/it]

Epoch 124 - VALIDATION LOSS: 0.0002763662678262335 - VALIDATION PSNR: 36.289215087890625 - VALIDATION SSIM: 0.928150092772156


  1%|▏         | 126/10000 [2:29:14<190:27:59, 69.44s/it]

Epoch 126 - VALIDATION LOSS: 0.00033188222369062714 - VALIDATION PSNR: 35.02590560913086 - VALIDATION SSIM: 0.9153557106952369


  1%|▏         | 128/10000 [2:31:34<191:52:43, 69.97s/it]

Epoch 128 - VALIDATION LOSS: 0.00030083262390689924 - VALIDATION PSNR: 35.61758804321289 - VALIDATION SSIM: 0.9228709043703414


  1%|▏         | 129/10000 [2:32:44<191:40:52, 69.91s/it]

Epoch 130 - VALIDATION LOSS: 0.0003105562182099675 - VALIDATION PSNR: 35.66880416870117 - VALIDATION SSIM: 0.9211706869410873


  1%|▏         | 132/10000 [2:36:15<192:57:36, 70.39s/it]

Epoch 132 - VALIDATION LOSS: 0.00035625053715193644 - VALIDATION PSNR: 34.889183044433594 - VALIDATION SSIM: 0.9122425550304055


  1%|▏         | 134/10000 [2:38:29<188:43:06, 68.86s/it]

Epoch 134 - VALIDATION LOSS: 0.00033243604593735654 - VALIDATION PSNR: 35.61548614501953 - VALIDATION SSIM: 0.9169084310282171


  1%|▏         | 136/10000 [2:40:47<188:25:14, 68.77s/it]

Epoch 136 - VALIDATION LOSS: 0.000248892754825647 - VALIDATION PSNR: 36.21527862548828 - VALIDATION SSIM: 0.9331591341245176


  1%|▏         | 138/10000 [2:43:04<188:08:07, 68.68s/it]

Epoch 138 - VALIDATION LOSS: 0.00031843715623836033 - VALIDATION PSNR: 35.555931091308594 - VALIDATION SSIM: 0.9198114882367254


  1%|▏         | 139/10000 [2:44:15<190:30:19, 69.55s/it]

Epoch 140 - VALIDATION LOSS: 0.00033345598967571277 - VALIDATION PSNR: 35.36564254760742 - VALIDATION SSIM: 0.9173023254754842


  1%|▏         | 142/10000 [2:47:38<186:57:22, 68.27s/it]

Epoch 142 - VALIDATION LOSS: 0.00036535420986183453 - VALIDATION PSNR: 34.74984359741211 - VALIDATION SSIM: 0.9095958930162191


  1%|▏         | 144/10000 [2:49:54<186:06:54, 67.98s/it]

Epoch 144 - VALIDATION LOSS: 0.00027643107023322955 - VALIDATION PSNR: 35.88651657104492 - VALIDATION SSIM: 0.9268287914208174


  1%|▏         | 146/10000 [2:52:09<186:11:36, 68.02s/it]

Epoch 146 - VALIDATION LOSS: 0.0002766620837064693 - VALIDATION PSNR: 36.26758575439453 - VALIDATION SSIM: 0.9290746532284617


  1%|▏         | 148/10000 [2:54:25<185:07:58, 67.65s/it]

Epoch 148 - VALIDATION LOSS: 0.0002946799013443524 - VALIDATION PSNR: 35.946510314941406 - VALIDATION SSIM: 0.925978309684217


  1%|▏         | 149/10000 [2:55:31<183:44:23, 67.15s/it]

Epoch 150 - VALIDATION LOSS: 0.00034150109422625974 - VALIDATION PSNR: 35.085445404052734 - VALIDATION SSIM: 0.9137340226638614


  2%|▏         | 152/10000 [2:58:53<183:05:15, 66.93s/it]

Epoch 152 - VALIDATION LOSS: 0.000309484592435183 - VALIDATION PSNR: 35.80126953125 - VALIDATION SSIM: 0.9222302638097704


  2%|▏         | 154/10000 [3:01:14<188:45:15, 69.01s/it]

Epoch 154 - VALIDATION LOSS: 0.0003561236408131663 - VALIDATION PSNR: 34.63604736328125 - VALIDATION SSIM: 0.9102271287562549


  2%|▏         | 156/10000 [3:03:29<186:10:07, 68.08s/it]

Epoch 156 - VALIDATION LOSS: 0.0002564784472269821 - VALIDATION PSNR: 36.45441436767578 - VALIDATION SSIM: 0.9328806550278663


  2%|▏         | 158/10000 [3:05:45<186:00:01, 68.04s/it]

Epoch 158 - VALIDATION LOSS: 0.00031043196577229537 - VALIDATION PSNR: 35.40005111694336 - VALIDATION SSIM: 0.9190166779707671


  2%|▏         | 159/10000 [3:06:50<183:29:58, 67.13s/it]

Epoch 160 - VALIDATION LOSS: 0.0002694737695492222 - VALIDATION PSNR: 36.38864517211914 - VALIDATION SSIM: 0.9302303025422394


  2%|▏         | 162/10000 [3:10:15<184:31:29, 67.52s/it]

Epoch 162 - VALIDATION LOSS: 0.00029538580656662816 - VALIDATION PSNR: 36.070098876953125 - VALIDATION SSIM: 0.9262582060577274


  2%|▏         | 164/10000 [3:12:29<184:08:45, 67.40s/it]

Epoch 164 - VALIDATION LOSS: 0.00026145575611735694 - VALIDATION PSNR: 36.44545364379883 - VALIDATION SSIM: 0.9325863391139507


  2%|▏         | 166/10000 [3:14:41<182:11:12, 66.69s/it]

Epoch 166 - VALIDATION LOSS: 0.00024891187331377296 - VALIDATION PSNR: 36.58866882324219 - VALIDATION SSIM: 0.9346448818268777


  2%|▏         | 168/10000 [3:16:54<181:44:51, 66.55s/it]

Epoch 168 - VALIDATION LOSS: 0.00028515247322502546 - VALIDATION PSNR: 35.67227554321289 - VALIDATION SSIM: 0.9241216882035732


  2%|▏         | 169/10000 [3:18:03<183:56:29, 67.36s/it]

Epoch 170 - VALIDATION LOSS: 0.00019492914179863874 - VALIDATION PSNR: 37.219886779785156 - VALIDATION SSIM: 0.9450116080608368


  2%|▏         | 172/10000 [3:21:23<181:09:14, 66.36s/it]

Epoch 172 - VALIDATION LOSS: 0.0003023121762453229 - VALIDATION PSNR: 36.26837921142578 - VALIDATION SSIM: 0.9268178145405352


  2%|▏         | 174/10000 [3:23:39<183:20:07, 67.17s/it]

Epoch 174 - VALIDATION LOSS: 0.0002973877217300469 - VALIDATION PSNR: 35.7850456237793 - VALIDATION SSIM: 0.9248237727540731


  2%|▏         | 176/10000 [3:25:49<180:18:05, 66.07s/it]

Epoch 176 - VALIDATION LOSS: 0.00034867860995291267 - VALIDATION PSNR: 35.30437469482422 - VALIDATION SSIM: 0.91454977151376


  2%|▏         | 178/10000 [3:28:07<183:54:08, 67.40s/it]

Epoch 178 - VALIDATION LOSS: 0.000286972128378693 - VALIDATION PSNR: 35.799659729003906 - VALIDATION SSIM: 0.9261088892021179


  2%|▏         | 179/10000 [3:29:11<180:53:11, 66.31s/it]

Epoch 180 - VALIDATION LOSS: 0.00027443369526736205 - VALIDATION PSNR: 36.400203704833984 - VALIDATION SSIM: 0.9297228274735212


  2%|▏         | 182/10000 [3:32:35<184:21:29, 67.60s/it]

Epoch 182 - VALIDATION LOSS: 0.00035158868558937684 - VALIDATION PSNR: 35.001548767089844 - VALIDATION SSIM: 0.9138487049567252


  2%|▏         | 184/10000 [3:34:45<180:48:41, 66.31s/it]

Epoch 184 - VALIDATION LOSS: 0.0002520295975045883 - VALIDATION PSNR: 36.45884323120117 - VALIDATION SSIM: 0.9341832568040639


  2%|▏         | 186/10000 [3:37:00<182:07:49, 66.81s/it]

Epoch 186 - VALIDATION LOSS: 0.00029419652673823293 - VALIDATION PSNR: 35.777835845947266 - VALIDATION SSIM: 0.9250544037533998


  2%|▏         | 188/10000 [3:39:10<180:11:28, 66.11s/it]

Epoch 188 - VALIDATION LOSS: 0.0002715908867685357 - VALIDATION PSNR: 35.93291091918945 - VALIDATION SSIM: 0.928843917667806


  2%|▏         | 189/10000 [3:40:19<182:19:37, 66.90s/it]

Epoch 190 - VALIDATION LOSS: 0.00032606338936602697 - VALIDATION PSNR: 35.10631561279297 - VALIDATION SSIM: 0.9176880382221342


  2%|▏         | 192/10000 [3:43:40<182:08:01, 66.85s/it]

Epoch 192 - VALIDATION LOSS: 0.00024160092470992822 - VALIDATION PSNR: 36.464111328125 - VALIDATION SSIM: 0.9355003487617076


  2%|▏         | 194/10000 [3:45:54<181:50:59, 66.76s/it]

Epoch 194 - VALIDATION LOSS: 0.0002790745857055299 - VALIDATION PSNR: 35.84130859375 - VALIDATION SSIM: 0.9275471413393318


  2%|▏         | 196/10000 [3:48:03<179:38:14, 65.96s/it]

Epoch 196 - VALIDATION LOSS: 0.00027364871675672475 - VALIDATION PSNR: 36.544795989990234 - VALIDATION SSIM: 0.9313261407248378


  2%|▏         | 198/10000 [3:50:18<181:13:26, 66.56s/it]

Epoch 198 - VALIDATION LOSS: 0.0002934743133664597 - VALIDATION PSNR: 35.7060661315918 - VALIDATION SSIM: 0.9238618588876724


  2%|▏         | 199/10000 [3:51:22<178:54:21, 65.71s/it]

Epoch 200 - VALIDATION LOSS: 0.00036478156471275724 - VALIDATION PSNR: 34.66748046875 - VALIDATION SSIM: 0.90987455339095


  2%|▏         | 202/10000 [3:54:45<181:24:41, 66.65s/it]

Epoch 202 - VALIDATION LOSS: 0.0002964882314699935 - VALIDATION PSNR: 35.962215423583984 - VALIDATION SSIM: 0.9259514474279136


  2%|▏         | 204/10000 [3:56:59<182:34:59, 67.10s/it]

Epoch 204 - VALIDATION LOSS: 0.00028318175736785633 - VALIDATION PSNR: 35.89763641357422 - VALIDATION SSIM: 0.9268292911542355


  2%|▏         | 206/10000 [3:59:11<180:30:53, 66.35s/it]

Epoch 206 - VALIDATION LOSS: 0.00031411071722686756 - VALIDATION PSNR: 35.175682067871094 - VALIDATION SSIM: 0.9195810060505867


  2%|▏         | 208/10000 [4:01:26<182:14:12, 67.00s/it]

Epoch 208 - VALIDATION LOSS: 0.0002825392730301246 - VALIDATION PSNR: 35.88070297241211 - VALIDATION SSIM: 0.9263142667081355


  2%|▏         | 209/10000 [4:02:33<182:52:33, 67.24s/it]

Epoch 210 - VALIDATION LOSS: 0.0003208733414794551 - VALIDATION PSNR: 35.283756256103516 - VALIDATION SSIM: 0.9188420100540222


  2%|▏         | 212/10000 [4:05:52<181:55:55, 66.91s/it]

Epoch 212 - VALIDATION LOSS: 0.00036683038706541993 - VALIDATION PSNR: 34.86371612548828 - VALIDATION SSIM: 0.9099643606356977


  2%|▏         | 214/10000 [4:08:04<179:54:50, 66.19s/it]

Epoch 214 - VALIDATION LOSS: 0.00029069105767121073 - VALIDATION PSNR: 35.67029571533203 - VALIDATION SSIM: 0.9238623330842257


  2%|▏         | 216/10000 [4:10:19<182:11:33, 67.04s/it]

Epoch 216 - VALIDATION LOSS: 0.00035213758656027494 - VALIDATION PSNR: 35.328243255615234 - VALIDATION SSIM: 0.9149887891363502


  2%|▏         | 218/10000 [4:12:29<178:38:16, 65.74s/it]

Epoch 218 - VALIDATION LOSS: 0.0003116064763162285 - VALIDATION PSNR: 35.36986541748047 - VALIDATION SSIM: 0.9195555667069554


  2%|▏         | 219/10000 [4:13:38<181:16:52, 66.72s/it]

Epoch 220 - VALIDATION LOSS: 0.00036297169026511256 - VALIDATION PSNR: 34.77337646484375 - VALIDATION SSIM: 0.9095520766729713


  2%|▏         | 222/10000 [4:16:56<178:40:48, 65.79s/it]

Epoch 222 - VALIDATION LOSS: 0.00033085439281421714 - VALIDATION PSNR: 34.86854934692383 - VALIDATION SSIM: 0.9154157181285918


  2%|▏         | 224/10000 [4:19:15<184:07:11, 67.80s/it]

Epoch 224 - VALIDATION LOSS: 0.0003851678811770398 - VALIDATION PSNR: 34.77357482910156 - VALIDATION SSIM: 0.9066855047324002


  2%|▏         | 226/10000 [4:21:22<178:01:49, 65.57s/it]

Epoch 226 - VALIDATION LOSS: 0.00021680699683201965 - VALIDATION PSNR: 36.7559928894043 - VALIDATION SSIM: 0.9391128912281096


  2%|▏         | 228/10000 [4:23:40<183:05:43, 67.45s/it]

Epoch 228 - VALIDATION LOSS: 0.00024628059327369556 - VALIDATION PSNR: 36.37434768676758 - VALIDATION SSIM: 0.9313689734283387


  2%|▏         | 229/10000 [4:24:45<181:05:20, 66.72s/it]

Epoch 230 - VALIDATION LOSS: 0.00027362506170902634 - VALIDATION PSNR: 35.94867706298828 - VALIDATION SSIM: 0.9285044610626996


  2%|▏         | 232/10000 [4:28:09<183:37:53, 67.68s/it]

Epoch 232 - VALIDATION LOSS: 0.000264236916336813 - VALIDATION PSNR: 36.64875793457031 - VALIDATION SSIM: 0.9316752725005149


  2%|▏         | 234/10000 [4:30:19<180:10:02, 66.41s/it]

Epoch 234 - VALIDATION LOSS: 0.00036179366361466236 - VALIDATION PSNR: 35.21176528930664 - VALIDATION SSIM: 0.9121333028935492


  2%|▏         | 236/10000 [4:32:37<183:24:02, 67.62s/it]

Epoch 236 - VALIDATION LOSS: 0.00029359376640059054 - VALIDATION PSNR: 35.638851165771484 - VALIDATION SSIM: 0.9243041349476278


  2%|▏         | 238/10000 [4:34:54<184:22:38, 67.99s/it]

Epoch 238 - VALIDATION LOSS: 0.0003333763770569931 - VALIDATION PSNR: 35.67246627807617 - VALIDATION SSIM: 0.9195459636522233


  2%|▏         | 239/10000 [4:35:58<181:06:53, 66.80s/it]

Epoch 240 - VALIDATION LOSS: 0.00036813282440562034 - VALIDATION PSNR: 35.1348876953125 - VALIDATION SSIM: 0.9117472208134533


  2%|▏         | 242/10000 [4:39:22<183:27:04, 67.68s/it]

Epoch 242 - VALIDATION LOSS: 0.00036942904444003943 - VALIDATION PSNR: 34.87533187866211 - VALIDATION SSIM: 0.9083314280927777


  2%|▏         | 244/10000 [4:41:31<179:15:30, 66.15s/it]

Epoch 244 - VALIDATION LOSS: 0.00030009218789928127 - VALIDATION PSNR: 36.0730094909668 - VALIDATION SSIM: 0.9239983661218285


  2%|▏         | 246/10000 [4:43:49<182:53:02, 67.50s/it]

Epoch 246 - VALIDATION LOSS: 0.00029161613383621443 - VALIDATION PSNR: 35.682525634765625 - VALIDATION SSIM: 0.9244437668302208


  2%|▏         | 248/10000 [4:46:00<180:09:11, 66.50s/it]

Epoch 248 - VALIDATION LOSS: 0.00028319407647359185 - VALIDATION PSNR: 35.677696228027344 - VALIDATION SSIM: 0.9252815439525545


  2%|▏         | 249/10000 [4:47:09<182:09:15, 67.25s/it]

Epoch 250 - VALIDATION LOSS: 0.00034875919300247915 - VALIDATION PSNR: 34.65412139892578 - VALIDATION SSIM: 0.9123527431230547


  3%|▎         | 252/10000 [4:50:28<180:49:51, 66.78s/it]

Epoch 252 - VALIDATION LOSS: 0.00030589326343033463 - VALIDATION PSNR: 35.669189453125 - VALIDATION SSIM: 0.9219303118260502


  3%|▎         | 254/10000 [4:52:38<178:25:57, 65.91s/it]

Epoch 254 - VALIDATION LOSS: 0.00028151731294201454 - VALIDATION PSNR: 36.1456413269043 - VALIDATION SSIM: 0.9279932082232236


  3%|▎         | 256/10000 [4:54:56<182:29:18, 67.42s/it]

Epoch 256 - VALIDATION LOSS: 0.00028035600917064585 - VALIDATION PSNR: 36.72664260864258 - VALIDATION SSIM: 0.9305701383580715


  3%|▎         | 258/10000 [4:57:06<179:51:09, 66.46s/it]

Epoch 258 - VALIDATION LOSS: 0.00031594229221809655 - VALIDATION PSNR: 35.10581588745117 - VALIDATION SSIM: 0.9184106435637771


  3%|▎         | 259/10000 [4:58:16<182:52:15, 67.58s/it]

Epoch 260 - VALIDATION LOSS: 0.0003010830860148417 - VALIDATION PSNR: 35.83244705200195 - VALIDATION SSIM: 0.9230903165376783


  3%|▎         | 262/10000 [5:01:39<182:10:18, 67.35s/it]

Epoch 262 - VALIDATION LOSS: 0.00036459519287745934 - VALIDATION PSNR: 34.68967056274414 - VALIDATION SSIM: 0.9098358403179198


  3%|▎         | 264/10000 [5:03:52<180:20:15, 66.68s/it]

Epoch 264 - VALIDATION LOSS: 0.0003356430252097198 - VALIDATION PSNR: 35.34543991088867 - VALIDATION SSIM: 0.9164926865060925


  3%|▎         | 266/10000 [5:06:06<181:21:41, 67.07s/it]

Epoch 266 - VALIDATION LOSS: 0.00032929301414696965 - VALIDATION PSNR: 35.280433654785156 - VALIDATION SSIM: 0.9169970019387305


  3%|▎         | 268/10000 [5:08:19<179:52:09, 66.54s/it]

Epoch 268 - VALIDATION LOSS: 0.00026458745560375974 - VALIDATION PSNR: 36.08882522583008 - VALIDATION SSIM: 0.9296759120440782


  3%|▎         | 269/10000 [5:09:25<179:13:04, 66.30s/it]

Epoch 270 - VALIDATION LOSS: 0.0003118315744359279 - VALIDATION PSNR: 35.22855758666992 - VALIDATION SSIM: 0.9200015018681288


  3%|▎         | 272/10000 [5:12:48<180:19:15, 66.73s/it]

Epoch 272 - VALIDATION LOSS: 0.00030378371047845576 - VALIDATION PSNR: 35.77657699584961 - VALIDATION SSIM: 0.922428287843585


  3%|▎         | 274/10000 [5:15:02<181:38:36, 67.23s/it]

Epoch 274 - VALIDATION LOSS: 0.00029949019335617777 - VALIDATION PSNR: 35.56055450439453 - VALIDATION SSIM: 0.922325068249464


  3%|▎         | 276/10000 [5:17:16<180:47:29, 66.93s/it]

Epoch 276 - VALIDATION LOSS: 0.00029611906393256504 - VALIDATION PSNR: 35.46913146972656 - VALIDATION SSIM: 0.9222652174130379


  3%|▎         | 278/10000 [5:19:33<182:58:24, 67.75s/it]

Epoch 278 - VALIDATION LOSS: 0.00025011565958266146 - VALIDATION PSNR: 36.29338836669922 - VALIDATION SSIM: 0.9326839519882202


  3%|▎         | 279/10000 [5:20:40<182:49:54, 67.71s/it]

Epoch 280 - VALIDATION LOSS: 0.00035907120764022693 - VALIDATION PSNR: 34.953704833984375 - VALIDATION SSIM: 0.9121282201841026


  3%|▎         | 282/10000 [5:24:05<184:46:58, 68.45s/it]

Epoch 282 - VALIDATION LOSS: 0.0003399142988200765 - VALIDATION PSNR: 35.2580451965332 - VALIDATION SSIM: 0.9160443183009923


  3%|▎         | 284/10000 [5:26:15<179:20:36, 66.45s/it]

Epoch 284 - VALIDATION LOSS: 0.00035362361450097524 - VALIDATION PSNR: 34.78614044189453 - VALIDATION SSIM: 0.9109451564477979


  3%|▎         | 286/10000 [5:28:33<183:23:33, 67.97s/it]

Epoch 286 - VALIDATION LOSS: 0.0002630924736877205 - VALIDATION PSNR: 36.34474182128906 - VALIDATION SSIM: 0.9309860425941348


  3%|▎         | 288/10000 [5:30:43<178:53:44, 66.31s/it]

Epoch 288 - VALIDATION LOSS: 0.0003333433123771101 - VALIDATION PSNR: 34.92708969116211 - VALIDATION SSIM: 0.9146633686380385


  3%|▎         | 289/10000 [5:31:52<181:01:24, 67.11s/it]

Epoch 290 - VALIDATION LOSS: 0.0004067187473992817 - VALIDATION PSNR: 34.32199478149414 - VALIDATION SSIM: 0.902222355234787


  3%|▎         | 292/10000 [5:35:12<179:26:38, 66.54s/it]

Epoch 292 - VALIDATION LOSS: 0.0002822849801304983 - VALIDATION PSNR: 35.78861999511719 - VALIDATION SSIM: 0.9262097134944797


  3%|▎         | 294/10000 [5:37:31<183:19:36, 68.00s/it]

Epoch 294 - VALIDATION LOSS: 0.000267428681581805 - VALIDATION PSNR: 36.44463348388672 - VALIDATION SSIM: 0.9310924164261818


  3%|▎         | 296/10000 [5:39:40<179:02:27, 66.42s/it]

Epoch 296 - VALIDATION LOSS: 0.00036858440944342874 - VALIDATION PSNR: 34.861995697021484 - VALIDATION SSIM: 0.9102322366385982


  3%|▎         | 298/10000 [5:41:59<182:33:14, 67.74s/it]

Epoch 298 - VALIDATION LOSS: 0.00031470633621211164 - VALIDATION PSNR: 35.20939636230469 - VALIDATION SSIM: 0.9198216558849812


  3%|▎         | 299/10000 [5:43:04<180:17:48, 66.91s/it]

Epoch 300 - VALIDATION LOSS: 0.0003168524490320124 - VALIDATION PSNR: 36.01068878173828 - VALIDATION SSIM: 0.92207514126648


  3%|▎         | 302/10000 [5:46:27<181:52:47, 67.52s/it]

Epoch 302 - VALIDATION LOSS: 0.00020059000326000387 - VALIDATION PSNR: 37.3159294128418 - VALIDATION SSIM: 0.9447886064318419


  3%|▎         | 304/10000 [5:48:41<181:20:39, 67.33s/it]

Epoch 304 - VALIDATION LOSS: 0.00031128250520851 - VALIDATION PSNR: 35.54544448852539 - VALIDATION SSIM: 0.920008974990487


  3%|▎         | 306/10000 [5:50:56<181:35:48, 67.44s/it]

Epoch 306 - VALIDATION LOSS: 0.0002584232006483944 - VALIDATION PSNR: 36.15325164794922 - VALIDATION SSIM: 0.9314689606826008


  3%|▎         | 308/10000 [5:53:09<180:47:30, 67.15s/it]

Epoch 308 - VALIDATION LOSS: 0.0003163663459417876 - VALIDATION PSNR: 35.10554504394531 - VALIDATION SSIM: 0.9192226920357496


  3%|▎         | 309/10000 [5:54:19<183:02:24, 68.00s/it]

Epoch 310 - VALIDATION LOSS: 0.0003111279656877741 - VALIDATION PSNR: 36.06949234008789 - VALIDATION SSIM: 0.9230974881803692


  3%|▎         | 312/10000 [5:57:43<183:40:39, 68.25s/it]

Epoch 312 - VALIDATION LOSS: 0.0003063157892029267 - VALIDATION PSNR: 35.78925704956055 - VALIDATION SSIM: 0.9218193131660521


  3%|▎         | 314/10000 [5:59:54<179:39:36, 66.77s/it]

Epoch 314 - VALIDATION LOSS: 0.00028628524341911543 - VALIDATION PSNR: 36.08013153076172 - VALIDATION SSIM: 0.9272073215476722


  3%|▎         | 316/10000 [6:02:12<182:17:29, 67.77s/it]

Epoch 316 - VALIDATION LOSS: 0.0002860404674720485 - VALIDATION PSNR: 35.88433837890625 - VALIDATION SSIM: 0.9261549570622891


  3%|▎         | 318/10000 [6:04:22<178:31:16, 66.38s/it]

Epoch 318 - VALIDATION LOSS: 0.0003292380843049614 - VALIDATION PSNR: 35.75611877441406 - VALIDATION SSIM: 0.9196263131128103


  3%|▎         | 319/10000 [6:05:32<181:13:01, 67.39s/it]

Epoch 320 - VALIDATION LOSS: 0.0003833877053693868 - VALIDATION PSNR: 34.39378356933594 - VALIDATION SSIM: 0.9050351282667518


  3%|▎         | 322/10000 [6:08:50<178:38:03, 66.45s/it]

Epoch 322 - VALIDATION LOSS: 0.0003269564695074223 - VALIDATION PSNR: 35.52947998046875 - VALIDATION SSIM: 0.9180052292878627


  3%|▎         | 324/10000 [6:11:06<180:34:39, 67.18s/it]

Epoch 324 - VALIDATION LOSS: 0.0002837726588040823 - VALIDATION PSNR: 35.836952209472656 - VALIDATION SSIM: 0.926407300871104


  3%|▎         | 326/10000 [6:13:22<181:46:41, 67.65s/it]

Epoch 326 - VALIDATION LOSS: 0.00032182926588575356 - VALIDATION PSNR: 34.99773025512695 - VALIDATION SSIM: 0.9173120824667216


  3%|▎         | 328/10000 [6:15:33<178:40:46, 66.51s/it]

Epoch 328 - VALIDATION LOSS: 0.00038305046837194823 - VALIDATION PSNR: 34.44379425048828 - VALIDATION SSIM: 0.9060647207965702


  3%|▎         | 329/10000 [6:16:41<179:20:03, 66.76s/it]

Epoch 330 - VALIDATION LOSS: 0.0002789246846077731 - VALIDATION PSNR: 35.75346374511719 - VALIDATION SSIM: 0.9278768568938374


  3%|▎         | 332/10000 [6:20:00<177:21:01, 66.04s/it]

Epoch 332 - VALIDATION LOSS: 0.00028816280064347666 - VALIDATION PSNR: 36.24591827392578 - VALIDATION SSIM: 0.9274474269187898


  3%|▎         | 334/10000 [6:22:21<183:52:48, 68.48s/it]

Epoch 334 - VALIDATION LOSS: 0.00036209350946592167 - VALIDATION PSNR: 35.1313591003418 - VALIDATION SSIM: 0.9121992681185155


  3%|▎         | 336/10000 [6:24:36<182:48:49, 68.10s/it]

Epoch 336 - VALIDATION LOSS: 0.0003404253839107696 - VALIDATION PSNR: 34.99803161621094 - VALIDATION SSIM: 0.9145178802606762


  3%|▎         | 338/10000 [6:26:55<184:17:17, 68.66s/it]

Epoch 338 - VALIDATION LOSS: 0.00033933422491827514 - VALIDATION PSNR: 35.226829528808594 - VALIDATION SSIM: 0.9158319172499254


  3%|▎         | 339/10000 [6:28:02<182:26:44, 67.99s/it]

Epoch 340 - VALIDATION LOSS: 0.00035500552166922716 - VALIDATION PSNR: 34.98568344116211 - VALIDATION SSIM: 0.9117821680172981


  3%|▎         | 342/10000 [6:31:30<183:26:47, 68.38s/it]

Epoch 342 - VALIDATION LOSS: 0.00036986489885748597 - VALIDATION PSNR: 35.09318542480469 - VALIDATION SSIM: 0.9109006434108018


  3%|▎         | 344/10000 [6:33:50<185:54:25, 69.31s/it]

Epoch 344 - VALIDATION LOSS: 0.00029688537961192196 - VALIDATION PSNR: 36.10093307495117 - VALIDATION SSIM: 0.926054184999287


  3%|▎         | 346/10000 [6:36:03<182:10:40, 67.93s/it]

Epoch 346 - VALIDATION LOSS: 0.0002865320329874521 - VALIDATION PSNR: 36.176998138427734 - VALIDATION SSIM: 0.9271213128800839


  3%|▎         | 348/10000 [6:38:22<183:41:35, 68.51s/it]

Epoch 348 - VALIDATION LOSS: 0.0003101384500041604 - VALIDATION PSNR: 35.458160400390625 - VALIDATION SSIM: 0.9213855432092846


  3%|▎         | 349/10000 [6:39:31<183:38:35, 68.50s/it]

Epoch 350 - VALIDATION LOSS: 0.0003353826023158035 - VALIDATION PSNR: 35.340763092041016 - VALIDATION SSIM: 0.9164963974298537


  4%|▎         | 352/10000 [6:43:03<187:00:18, 69.78s/it]

Epoch 352 - VALIDATION LOSS: 0.00035099336673738435 - VALIDATION PSNR: 35.14740753173828 - VALIDATION SSIM: 0.9127686008268745


  4%|▎         | 354/10000 [6:45:22<186:01:08, 69.42s/it]

Epoch 354 - VALIDATION LOSS: 0.00030607601365773007 - VALIDATION PSNR: 35.45933532714844 - VALIDATION SSIM: 0.9202018550044299


  4%|▎         | 356/10000 [6:47:41<185:55:15, 69.40s/it]

Epoch 356 - VALIDATION LOSS: 0.0003178999922965886 - VALIDATION PSNR: 35.547584533691406 - VALIDATION SSIM: 0.9188956957210302


  4%|▎         | 358/10000 [6:50:04<189:18:19, 70.68s/it]

Epoch 358 - VALIDATION LOSS: 0.0003399076304049231 - VALIDATION PSNR: 35.074249267578125 - VALIDATION SSIM: 0.9148462395936995


  4%|▎         | 359/10000 [6:51:14<188:18:41, 70.32s/it]

Epoch 360 - VALIDATION LOSS: 0.00028202015528222546 - VALIDATION PSNR: 36.222938537597656 - VALIDATION SSIM: 0.927976135694325


  4%|▎         | 362/10000 [6:54:49<190:03:23, 70.99s/it]

Epoch 362 - VALIDATION LOSS: 0.0002650203532539308 - VALIDATION PSNR: 36.09346008300781 - VALIDATION SSIM: 0.9295093246584833


  4%|▎         | 364/10000 [6:57:13<191:24:19, 71.51s/it]

Epoch 364 - VALIDATION LOSS: 0.00036496916982287075 - VALIDATION PSNR: 34.82459259033203 - VALIDATION SSIM: 0.9103722999065146


  4%|▎         | 366/10000 [6:59:34<189:18:15, 70.74s/it]

Epoch 366 - VALIDATION LOSS: 0.00030326830528792925 - VALIDATION PSNR: 35.56913757324219 - VALIDATION SSIM: 0.9228832676136495


  4%|▎         | 368/10000 [7:01:58<190:33:52, 71.22s/it]

Epoch 368 - VALIDATION LOSS: 0.0003961091861128807 - VALIDATION PSNR: 34.302223205566406 - VALIDATION SSIM: 0.9034333898619189


  4%|▎         | 369/10000 [7:03:09<190:13:44, 71.11s/it]

Epoch 370 - VALIDATION LOSS: 0.0003786295637837611 - VALIDATION PSNR: 34.76189422607422 - VALIDATION SSIM: 0.9085719345287457


  4%|▎         | 372/10000 [7:06:41<189:52:24, 71.00s/it]

Epoch 372 - VALIDATION LOSS: 0.00031507014500675723 - VALIDATION PSNR: 35.24254608154297 - VALIDATION SSIM: 0.9190383468151391


In [None]:
# Return PyTorch's cached-but-unused accelerator memory to the driver after
# training. empty_cache() only releases allocator blocks that no live tensor
# references, and autograd plays no part in it, so no torch.no_grad() context
# is needed. Backends are checked in the same order as the device-selection
# cell (CUDA first, then MPS); on CPU there is no allocator cache to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()

------

### Inference

Demonstrate the trained model by denoising a sample test image.

#### Test denoising

In [None]:
# # CODE TO INFER AND SHOW SOME RESULTS HERE



# def simple_plot(input_image_tensor_5D, subplot_index, image_name, clean_image_tensor_5D, folder_name, num_rows=2, num_cols=3):
#     plot_image_tensor_2D = input_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1)
#     clean_image_tensor_2D = clean_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1)
#     psnr_value = PSNR(clean_image_tensor_2D, plot_image_tensor_2D)
#     ssim_value = SSIM(clean_image_tensor_2D, plot_image_tensor_2D)
#     plt.subplot(num_rows, num_cols, subplot_index)
#     plt.axis('off')
#     plt.imshow(plot_image_tensor_2D.to("cpu").detach().numpy(), cmap='gray')
#     plt.title(f"{image_name} PSNR: {psnr_value:.2f} dB\n{image_name} SSIM: {ssim_value:.2f}", fontsize=10)
#     # Write the image to a file
#     save_image(plot_image_tensor_2D, f"{image_name}", folder_name)

# def get_image_tensor_5D(image):
#     image = image.convert("L")
#     image_numpy = np.asarray(image)
#     image_tensor_4D = convert_to_tensor_4D(image_numpy)
#     image_tensor_5D = image_tensor_4D.unsqueeze(0).to(DEVICE)
#     return image_tensor_5D

# def denoise(pdhg: DynamicImageStaticPrimalDualNN, noisy_image_tensor_5D):
#     pdhg.eval()
#     with torch.no_grad():
#         best_lambda_map = pdhg.get_lambda_cnn(noisy_image_tensor_5D)
#     x_denoised_lambda_map_best_tensor_5D = reconstruct_with_PDHG(noisy_image_tensor_5D, best_lambda_map, pdhg.T)
#     # x_denoised_lambda_map_best_tensor_5D = torch.clamp(x_denoised_lambda_map_best_tensor_5D, 0, 1)
#     with torch.no_grad():
#         torch.cuda.empty_cache()
#     return best_lambda_map, x_denoised_lambda_map_best_tensor_5D


# def brute_force_lambda(noisy_image_tensor_5D, clean_image_tensor_5D, T, min_value=0.01, max_value=0.1, num_values=10):
#     # TODO: Brute-force single lambda
#     best_psnr = 0
#     best_lambda = 0
#     lambas = list(np.linspace(min_value, max_value, num_values))
#     psnr_values = []
#     for lambda_value in lambas:
#         with torch.no_grad():
#             x_denoised_single_lambda_tensor_5D = reconstruct_with_PDHG(noisy_image_tensor_5D, lambda_value, T)
#         psnr_value = PSNR(clean_image_tensor_5D, x_denoised_single_lambda_tensor_5D)
#         psnr_value = psnr_value.item()
#         # Convert to float
#         psnr_value = np.float64(psnr_value)
#         if psnr_value > best_psnr:
#             best_psnr = psnr_value
#             best_lambda = lambda_value
#         psnr_values.append(psnr_value)

#     # Plot the PSNR values
#     plt.plot(lambas, psnr_values)
#     plt.xlabel("Lambda")
#     plt.ylabel("PSNR")
#     plt.title("PSNR vs Lambda")
#     plt.show()
    
#     return best_lambda


# def test_denoise(pdhg: DynamicImageStaticPrimalDualNN=None, model_name="", best_lambda=None):
#     """
#     Testing denoising with pre-trained parameters.
#     """
#     clean_image = Image.open(f"testcases/chest_xray_clean.png")
#     noisy_image = Image.open(f"testcases/chest_xray_noisy.png")
#     clean_image_tensor_5D = get_image_tensor_5D(clean_image)
#     noisy_image_tensor_5D = get_image_tensor_5D(noisy_image)

#     if best_lambda is None:
#         best_lambda = brute_force_lambda(noisy_image_tensor_5D, clean_image_tensor_5D, T=pdhg.T, min_value=0.01, max_value=1, num_values=100)

#     print(f"Best lambda: {best_lambda}")

#     k_w, k_h = 256, 256

#     folder_name = f"./tmp/images/model_{model_name}-kernel_{k_w}-best_lambda_{str(best_lambda).replace('.', '_')}-time_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
#     os.makedirs(folder_name, exist_ok=True)

#     plt.figure(figsize=(15, 6)) # Set the size of the plot

#     simple_plot(clean_image_tensor_5D, 1, "clean", clean_image_tensor_5D, folder_name)
#     simple_plot(noisy_image_tensor_5D, 2, "noisy", clean_image_tensor_5D, folder_name)

#     x_denoised_single_lambda_tensor_5D = reconstruct_with_PDHG(noisy_image_tensor_5D, best_lambda, T=pdhg.T)
    
#     best_lambda_map, x_denoised_lambda_map_tensor_5D = denoise(pdhg, noisy_image_tensor_5D)

#     # Clip to [0, 1]. The calculations may make it slightly below 0 and above 1
#     x_denoised_single_lambda_tensor_5D = torch.clamp(x_denoised_single_lambda_tensor_5D, 0, 1)
#     x_denoised_lambda_map_tensor_5D = torch.clamp(x_denoised_lambda_map_tensor_5D, 0, 1)

#     simple_plot(x_denoised_single_lambda_tensor_5D, 3, f"single_lambda_best_{str(best_lambda).replace('.', '_')}", clean_image_tensor_5D, folder_name)
#     simple_plot(x_denoised_lambda_map_tensor_5D, 4, "lambda_map_best_using_function", clean_image_tensor_5D, folder_name)

#     lambda_map_1 = best_lambda_map[:, 0:1, :, :, :]
#     lambda_map_2 = best_lambda_map[:, 1:2, :, :, :]
#     lambda_map_3 = best_lambda_map[:, 2:3, :, :, :]

#     lambda_map_1 = torch.clamp(lambda_map_1, 0, 1)
#     lambda_map_2 = torch.clamp(lambda_map_2, 0, 1)
#     lambda_map_3 = torch.clamp(lambda_map_3, 0, 1)

#     simple_plot(lambda_map_1, 5, "lambda_map_1", clean_image_tensor_5D, folder_name)
#     simple_plot(lambda_map_3, 6, "lambda_map_3", clean_image_tensor_5D, folder_name)

#     plt.savefig(f"{folder_name}/results.png")

#     plt.show();

#     with open(f"{folder_name}/log.txt", "w") as f:
#         f.write(f"Best lambda: {best_lambda}\n")
#         f.write(f"PSNR (single lambda): {PSNR(clean_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1), x_denoised_single_lambda_tensor_5D.squeeze(0).squeeze(0).squeeze(-1))}\n")
#         f.write(f"PSNR (lambda map): {PSNR(clean_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1), x_denoised_lambda_map_tensor_5D.squeeze(0).squeeze(0).squeeze(-1))}\n")
#         f.write(f"Config: {get_config()}\n")

#     with torch.no_grad():
#         torch.cuda.empty_cache()

# model_dir = "./tmp_2/model-2024_06_05_23_51_27"
# epoch = 4000
# pdhg = torch.load(f"{model_dir}/model_epoch_{epoch}.pt")

# test_denoise(
#     pdhg=pdhg,
#     model_name=f"chest_xray_demo-epoch_{epoch}",
#     best_lambda=0.08
# )

# with torch.no_grad():
#     torch.cuda.empty_cache()

In [None]:
# Return PyTorch's cached-but-unused accelerator memory to the driver.
# empty_cache() only releases allocator blocks that no live tensor references,
# and autograd plays no part in it, so no torch.no_grad() context is needed.
# Backends are checked in the same order as the device-selection cell
# (CUDA first, then MPS); on CPU there is no allocator cache to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
# def temp_test():

---

### Create a video

In [None]:
# def create_video(model_name, start_epoch=20, end_epoch=10_000, step=20):
#     clean_image_path = "./test_cases/turtle_clean/turtle clean.png"
#     noisy_image_path = "./test_cases/turtle_noisy/turtle noisy.png"
#     clean_image_tensor_5D = get_image_tensor_5D(clean_image_path)
#     noisy_image_tensor_5D = get_image_tensor_5D(noisy_image_path)
#     clean_image_tensor_2D = clean_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1)
#     noisy_image_tensor_2D = noisy_image_tensor_5D.squeeze(0).squeeze(0).squeeze(-1)

#     psnr_noisy = PSNR(noisy_image_tensor_2D, clean_image_tensor_2D)
#     ssim_noisy = SSIM(noisy_image_tensor_2D, clean_image_tensor_2D)



#     frames_folder = f"./tmp/{model_name}"
#     model_folder=f"./tmp_2/{model_name}"
#     os.makedirs(frames_folder, exist_ok=True)
#     os.makedirs(f"{frames_folder}/denoised", exist_ok=True)
#     os.makedirs(f"{frames_folder}/lambda_map_1", exist_ok=True)
#     os.makedirs(f"{frames_folder}/lambda_map_2", exist_ok=True)
#     os.makedirs(f"{frames_folder}/lambda_map_3", exist_ok=True)

#     with open(f"./tmp/{model_name}/metrics.csv", "w") as f:
#         f.write(f"Image, PSNR, SSIM\n")
#         f.write(f"Noisy, {psnr_noisy:.2f}, {ssim_noisy:.2f}\n")

#         for epoch in range(start_epoch, end_epoch + 1, step):
#             model_name = f"model_epoch_{epoch}"
#             pdhg = torch.load(f"{model_folder}/{model_name}.pt")
#             best_lambda_map, x_denoised_lambda_map_best_tensor_5D = denoise(pdhg, noisy_image_tensor_5D)
#             x_denoised_lambda_map_best_tensor_5D = torch.clamp(x_denoised_lambda_map_best_tensor_5D, 0, 1)

#             x_denoised_lambda_map_best_tensor_2D = x_denoised_lambda_map_best_tensor_5D.squeeze(0).squeeze(0).squeeze(-1)
#             psnr_denoised = PSNR(x_denoised_lambda_map_best_tensor_2D, clean_image_tensor_2D)
#             ssim_denoised = SSIM(x_denoised_lambda_map_best_tensor_2D, clean_image_tensor_2D)
#             f.write(f"{epoch}, {psnr_denoised:.2f}, {ssim_denoised:.2f}\n")

#             denoised_image_to_save = Image.fromarray((x_denoised_lambda_map_best_tensor_2D.to("cpu").detach().numpy() * 255).astype(np.uint8))
#             denoised_image_to_save.save(f"{frames_folder}/denoised/{epoch}.png")

#             lambda_map_1 = best_lambda_map[:, 0:1, :, :, :]
#             lambda_map_2 = best_lambda_map[:, 1:2, :, :, :]
#             lambda_map_3 = best_lambda_map[:, 2:3, :, :, :]
#             lambda_map_1 = torch.clamp(lambda_map_1, 0, 1)
#             lambda_map_2 = torch.clamp(lambda_map_2, 0, 1)
#             lambda_map_3 = torch.clamp(lambda_map_3, 0, 1)

#             lambda_map_1_to_save = Image.fromarray((lambda_map_1.squeeze(0).squeeze(0).squeeze(-1).to("cpu").detach().numpy() * 255).astype(np.uint8))
#             lambda_map_1_to_save.save(f"{frames_folder}/lambda_map_1/{epoch}.png")

#             lambda_map_2_to_save = Image.fromarray((lambda_map_2.squeeze(0).squeeze(0).squeeze(-1).to("cpu").detach().numpy() * 255).astype(np.uint8))
#             lambda_map_2_to_save.save(f"{frames_folder}/lambda_map_2/{epoch}.png")

#             lambda_map_3_to_save = Image.fromarray((lambda_map_3.squeeze(0).squeeze(0).squeeze(-1).to("cpu").detach().numpy() * 255).astype(np.uint8))
#             lambda_map_3_to_save.save(f"{frames_folder}/lambda_map_3/{epoch}.png")
        

#     # # Create the video
#     # frames = []
#     # for epoch in range(start_epoch, end_epoch + 1, step):
#     #     frames.append(cv2.imread(f"{frames_folder}/frame_{epoch}.png"))
#     # height, width, layers = frames[0].shape
#     # size = (width, height)
#     # out = cv2.VideoWriter(f"{frames_folder}/video.avi", cv2.VideoWriter_fourcc(*'DIVX'), 1, size)
#     # for i in range(len(frames)):
#     #     out.write(frames[i])
#     # out.release()

# create_video("model_turtle_2024_06_04_04_19_21", start_epoch=20, end_epoch=10_000, step=20)

In [None]:
# Return PyTorch's cached-but-unused accelerator memory to the driver.
# empty_cache() only releases allocator blocks that no live tensor references,
# and autograd plays no part in it, so no torch.no_grad() context is needed.
# Backends are checked in the same order as the device-selection cell
# (CUDA first, then MPS); on CPU there is no allocator cache to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
# def test_revisualise():

#     def vis(image_folder, image_name):
#         image_path = f"{image_folder}/{image_name}.png"
#         image = Image.open(image_path)
#         plt.imshow(image, cmap='gray')
#         plt.show();

#     image_folder = "tmp/PRESENT/presentation-img_turtle-best_lambda_0_07-kernel_256-model_-trained_on_-time_2024_06_04_22_59_31-epoch_100_000"
#     image_names = [
#         "lambda_map_3",
#         "single_lambda_best_0_06000000000000001",
#         "clean",
#     ]

#     for image_name in image_names:
#         vis(image_folder, image_name)

# test_revisualise()