# Noise2Void Experiment

[CVPR 2019: Noise2Void - Learning Denoising from Single Noisy Images](https://openaccess.thecvf.com/content_CVPR_2019/papers/Krull_Noise2Void_-_Learning_Denoising_From_Single_Noisy_Images_CVPR_2019_paper.pdf)

Key idea: 
- Blind-spot network (BSN): the network learns the mapping between the masked pixel(s) and all the other pixels
- The expected value of the output is identical to the clean image when the noise is zero-mean and i.i.d. and the amount of training data tends to infinity.

Pros: 
1. Denoising from a single noisy image, which is highly practical for medical denoising tasks
2. Can adapt to any noise distribution, since it learns denoising directly from the noisy image

Cons:
1. Masking certain pixel(s) decreases the denoising quality, especially for high-frequency content
2. Strong assumption of zero-mean noise and i.i.d. pixels, leading to poor performance when dealing with structured noise
3. Denoising performance is slightly worse than N2N but still better than BM3D

In [None]:
import os
import copy
import json
import time

import numpy as np
import matplotlib.pyplot as plt

from module.utils import calculate_metrics, display_image_in_detail, plot_2d_data, timer_decorator, display_4d_image, timer_decorator
from module.datasets import load_4d_dicom, save_4d_dicom, restore_data

from module.models import UNet2_5D
from module.datasets import MaskDataset


import h5py
from tqdm.notebook import tqdm


import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset

from torchsummary import summary

# Choose the compute device in priority order: CUDA GPU 0, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"device: {device}")

## Load preprocessed data

In [None]:
# load noisy data
# Read the preprocessed volume ('dataset') and its JSON-encoded metadata
# ('restore_info') from the HDF5 file; `restore_info` is passed to
# `restore_data` later in the notebook.
with h5py.File('./dataset/preprocessed/PT_20p 150_120 OSEM_real_0.00_batch.h5', 'r') as f:
    noisy_data = f['dataset'][...]  # `[...]` reads the whole dataset into memory
    restore_info = json.loads(f['restore_info'][()])  # `[()]` reads the scalar entry, then parse it as JSON
    
# Quick sanity check of dtype, shape and value statistics.
print(f"Noisy data...{noisy_data.dtype} (shape:{noisy_data.shape}; range:[{np.min(noisy_data)},{np.max(noisy_data)}]; mean:{np.mean(noisy_data)}); std:{np.std(noisy_data)}")

print(restore_info)

# Show one slice, picked by indices 0 / 11 / 38 on the three leading axes
# (presumably batch / time / depth, leaving a 2-D image — TODO confirm).
display_image_in_detail(noisy_data[0, 11, 38], title="prepared noisy data")

## Process Denoising

## 1. import denoising network

In [None]:
# Instantiate the denoising network and move it to the selected device.
# NOTE(review): in_channels=3 presumably matches the three slice inputs
# (top / masked middle / bottom) the model is called with — confirm against UNet2_5D.
model = UNet2_5D(in_channels=3, out_channels=1)
model = model.to(device)
# Count only the parameters that require gradients (i.e. trainable ones).
print("The number of parameters of the network is: ",  sum(p.numel() for p in model.parameters() if p.requires_grad))

# Layer-by-layer summary for three separate inputs, each of size (1, 192, 192).
summary(model, [(1, 192, 192), (1, 192, 192), (1, 192, 192)])

## 2. create mask dataset and dataloader

In [None]:
def split_tensor(data_tensor):
    """Split a tensor along its first axis into train, validation and test parts.

    The first sample (index 0) is held out as the test tensor. The remaining
    samples are shuffled and divided 80 % / 20 % into train and validation
    tensors, so the two never share a sample.

    Args:
        data_tensor (torch.Tensor): Tensor whose first axis enumerates samples.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(train_tensor, val_tensor, test_tensor)``. ``test_tensor`` is a
            view of the first sample; the other two are copies gathered by
            index and may be empty when fewer than two samples remain.
    """
    train_ratio = 0.8

    test_tensor = data_tensor[0:1]
    rest_tensor = data_tensor[1:]

    total_samples = rest_tensor.shape[0]
    train_length = int(train_ratio * total_samples)

    # Shuffle the sample indices once and cut the permutation in two, so the
    # train and validation parts are disjoint. (Returning `Subset.dataset`
    # from `random_split` would hand back the whole unsplit tensor for both.)
    permutation = torch.randperm(total_samples).to(rest_tensor.device)
    train_tensor = rest_tensor[permutation[:train_length]]
    val_tensor = rest_tensor[permutation[train_length:]]

    return train_tensor, val_tensor, test_tensor


# convert the ndarray to a float32 tensor, inserting a singleton axis at position 2
noisy_tensor = torch.tensor(noisy_data[:, :, np.newaxis, :, :, :], dtype=torch.float32) 
noisy_tensor = noisy_tensor.to(device)
print(f"noisy_tensor\n shape: {noisy_tensor.shape}; range:({torch.min(noisy_tensor)},{torch.max(noisy_tensor)}) ;mean:{torch.mean(noisy_tensor)}; std:{torch.std(noisy_tensor)} ") # (batch, time, channel, depth, height, width)


# create dataset

train_tensor, val_tensor, test_tensor = split_tensor(noisy_tensor) # split input tensor into train / val / test parts

# Second argument of MaskDataset (presumably the number of masked pixels per sample — TODO confirm against MaskDataset)
num_masks = 1

train_dataset = MaskDataset(train_tensor, num_masks)
val_dataset = MaskDataset(val_tensor, num_masks)
test_dataset = MaskDataset(test_tensor, num_masks)


# create dataloader
# NOTE(review): `noisy_tensor` was moved to `device` above, so the datasets wrap
# device-resident tensors. DataLoader worker processes (num_workers > 0) generally
# cannot safely access CUDA tensors created in the parent process — confirm this
# works on the target platform, or keep the data on the CPU / use num_workers=0.
batch_size = 32
num_workers = 12

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# example dataloader feed
# Pull one batch; the loader yields (top, masked middle, bottom, unmasked middle) tensors.
top_slice, mask_middle_slice, bottom_slice, middle_slice = next(iter(train_loader))

# Print dtype, shape and value statistics for each tensor of the batch.
# The range is reported as (min, max), matching the other printouts in this notebook.
for _name, _batch in (
    ("top_slice", top_slice),
    ("mask_middle_slice", mask_middle_slice),
    ("bottom_slice", bottom_slice),
    ("middle_slice", middle_slice),
):
    print(f"{_name}: {_batch.dtype} {_batch.shape} range:({torch.min(_batch)},{torch.max(_batch)}); mean:{torch.mean(_batch)}; std:{torch.std(_batch)}")

# select which sample of the batch to show
idx = 1

fig, axes = plt.subplots(1, 5, figsize=(25, 4))

# First four panels: the three network inputs and the unmasked target slice.
_panels = (
    ("Top", top_slice),
    ("Masked middle", mask_middle_slice),
    ("Bottom", bottom_slice),
    ("True middle", middle_slice),
)
for _axis, (_title, _batch) in zip(axes, _panels):
    _axis.imshow(_batch[idx].squeeze().cpu().numpy(), cmap='hot')
    _axis.set_title(_title)

# Last panel: binary map of the pixels where the masked input differs from the
# original slice, i.e. where the mask changed the value.
difference = (mask_middle_slice[idx] != middle_slice[idx]).float().squeeze().cpu().numpy()
axes[4].imshow(difference, cmap='hot')
axes[4].set_title('mask position')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3.  define training modules 

In [None]:
def train(model, train_loader, criterion, optimizer, scheduler, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network called as ``model(top_slice, mask_middle_slice, bottom_slice)``.
        train_loader: Iterable yielding ``(top_slice, mask_middle_slice,
            bottom_slice, middle_slice)`` batches; must support ``len()``.
        criterion: Loss callable applied as ``criterion(outputs, middle_slice)``.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        scheduler: Learning-rate scheduler; stepped once at the end of the epoch.
        device: Device every batch tensor is moved to before the forward pass
            (a no-op when the tensors already live there).

    Returns:
        float: Sum of the batch losses divided by ``len(train_loader)``.
    """
    model.train()
    total_loss = 0.0
    pbar = tqdm(train_loader, desc="Training", dynamic_ncols=True)
    # Iterate over the progress bar itself so that it actually advances.
    for top_slice, mask_middle_slice, bottom_slice, middle_slice in pbar:
        top_slice = top_slice.to(device)
        mask_middle_slice = mask_middle_slice.to(device)
        bottom_slice = bottom_slice.to(device)
        middle_slice = middle_slice.to(device)

        # Clear gradients left over from the previous batch; otherwise they
        # would accumulate across iterations.
        optimizer.zero_grad()

        # Forward
        outputs = model(top_slice, mask_middle_slice, bottom_slice)
        loss = criterion(outputs, middle_slice)

        # Backward: compute this batch's gradients first, then update once.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({"batch_loss": batch_loss})

    scheduler.step()
    avg_loss = total_loss / len(train_loader)
    return avg_loss


def validate(model, val_loader, criterion, device):
    """Evaluate the model on the validation loader and return the mean batch loss.

    Runs in eval mode with gradient tracking disabled; no parameters are updated.

    Args:
        model: Network called as ``model(top_slice, mask_middle_slice, bottom_slice)``.
        val_loader: Iterable yielding ``(top_slice, mask_middle_slice,
            bottom_slice, middle_slice)`` batches; must support ``len()``.
        criterion: Loss callable applied as ``criterion(outputs, middle_slice)``.
        device: Device every batch tensor is moved to before the forward pass
            (a no-op when the tensors already live there).

    Returns:
        float: Sum of the batch losses divided by ``len(val_loader)``.
    """
    model.eval()
    total_loss = 0.0
    pbar = tqdm(val_loader, desc="Validating", dynamic_ncols=True)
    with torch.no_grad():
        for top_slice, mask_middle_slice, bottom_slice, middle_slice in pbar:
            top_slice, mask_middle_slice, bottom_slice, middle_slice = (
                t.to(device)
                for t in (top_slice, mask_middle_slice, bottom_slice, middle_slice)
            )
            outputs = model(top_slice, mask_middle_slice, bottom_slice)
            batch_loss = criterion(outputs, middle_slice).item()
            total_loss += batch_loss
            pbar.set_postfix({"batch_loss": batch_loss})

    avg_loss = total_loss / len(val_loader)

    return avg_loss

@timer_decorator
def test(model, test_loader, device):
    """Run inference over the test loader and collect predictions and targets.

    Runs in eval mode with gradient tracking disabled.

    Args:
        model: Network called as ``model(top_slice, mask_middle_slice, bottom_slice)``.
        test_loader: Iterable yielding ``(top_slice, mask_middle_slice,
            bottom_slice, middle_slice)`` batches.
        device: Device the three input tensors are moved to before the forward
            pass (a no-op when they already live there).

    Returns:
        tuple[list, list]: ``(preds, targets)`` — one CPU tensor per batch in
        each list: the model output and the corresponding unmasked middle slice.
    """
    model.eval()
    preds = []
    targets = []
    pbar = tqdm(test_loader, desc="Testing", dynamic_ncols=True)
    with torch.no_grad():
        for top_slice, mask_middle_slice, bottom_slice, middle_slice in pbar:
            top_slice = top_slice.to(device)
            mask_middle_slice = mask_middle_slice.to(device)
            bottom_slice = bottom_slice.to(device)

            pred = model(top_slice, mask_middle_slice, bottom_slice)
            preds.append(pred.cpu())  # prediction
            targets.append(middle_slice.cpu())  # original (unmasked) input data

    return preds, targets


@timer_decorator
def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5, save_path='path/to/your/directory'):
    """Train with per-epoch validation, checkpointing and early stopping.

    After every epoch the validation loss is compared with the best one seen
    so far. On improvement the weights are snapshotted and written to
    ``<save_path>/best_model_epoch_<epoch>.pth``; otherwise a counter is
    incremented and training stops once it reaches ``patience``.

    Args:
        model: Network to train (updated in place).
        train_loader: Loader passed to ``train``.
        val_loader: Loader passed to ``validate``.
        criterion: Loss callable.
        optimizer: Optimizer used by ``train``.
        scheduler: Learning-rate scheduler used by ``train``.
        device: Device forwarded to ``train`` and ``validate``.
        num_epochs (int): Maximum number of epochs.
        patience (int): Consecutive non-improving epochs tolerated before stopping.
        save_path (str): Directory for checkpoints; created if it does not exist.

    Returns:
        The same ``model`` object, loaded with the weights of the best epoch.
    """
    # Make sure the checkpoint directory exists before the first torch.save.
    os.makedirs(save_path, exist_ok=True)

    best_loss = float('inf')
    early_stopping_counter = 0
    best_epoch = -1
    # Fallback snapshot in case no epoch ever improves on the initial `inf`.
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, scheduler, device)
        val_loss = validate(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Save the model with the best validation loss
        if val_loss < best_loss:
            best_loss = val_loss
            # Deep copy: state_dict() returns references to the live parameters.
            best_model_wts = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1  # 1-based counting for epoch
            best_save_path = os.path.join(save_path, f'best_model_epoch_{best_epoch}.pth')
            torch.save(best_model_wts, best_save_path)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f"Early stopping after {patience} epochs without improvement.")
                break

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model


In [None]:
# define loss, optimizer, lr_scheduler
criterion = nn.L1Loss()  # mean absolute error between prediction and target
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by gamma (0.7) every `step_size` (20) scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)


## 4. process training

In [None]:
# Train for at most 200 epochs, stopping early after 10 epochs without a better
# validation loss; checkpoints are written under ./check_points.
trained_model = train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=200, patience=10, save_path="./check_points")

## 5. process denoising

In [None]:
# load best model
# NOTE(review): the checkpoint name is hard-coded to epoch 30 — point it at the
# `best_model_epoch_<N>.pth` file that training actually wrote.
# `map_location=device` lets a checkpoint saved on one device (e.g. a GPU) be
# loaded on whichever device is active in this session.
# `load_state_dict` updates `model` in place; the value kept in `best_model` is
# its report of missing/unexpected keys, not a model object.
best_model = model.load_state_dict(torch.load('check_points/best_model_epoch_30.pth', map_location=device))

# denoising
# `test` builds a (preds, targets) pair of per-batch tensor lists; what ends up in
# `denoised_tensor` also depends on what `timer_decorator` returns — TODO confirm.
denoised_tensor = test(model, test_loader, device)

## 6. metrics evaluation

# Save denormalized denoised data into 16-bit DICOM files

In [None]:
# denormalize the denoised data using the metadata loaded with the noisy input
# NOTE(review): `denoised_data` is not assigned in any visible cell — the denoising
# cell stores its result in `denoised_tensor`. The network output presumably has to
# be assembled into an ndarray of the layout `restore_data` expects before this
# cell can run — TODO confirm and add the conversion.
restored_data = restore_data(denoised_data, restore_info)
print(f"restore_data: {restored_data.dtype} shape:{restored_data.shape}; range:({np.min(restored_data)},{np.max(restored_data)}); mean:{np.mean(restored_data)}; std:{np.std(restored_data)}")
display_4d_image(restored_data)

In [None]:
# save denormalized denoised data as 16-bit gray-scale .DICOM files
# Folder with the original DICOM series, passed to save_4d_dicom alongside the restored data.
origin_dicom_folder = './dataset/10_05_2021_PET_only/PT_20p 150_120 OSEM'

# NOTE(review): the output path contains 'BM4D' although this notebook runs
# Noise2Void — possibly left over from another experiment; confirm the folder
# so results of a different method are not overwritten.
output_folder = './dataset/denoised/BM4D/PT_20p 150_120 OSEM_real_0.00_batch'

save_4d_dicom(origin_dicom_folder, restored_data, output_folder)