In [1]:
import os
import copy
import json
import time

import numpy as np
import matplotlib.pyplot as plt

from module.utils import calculate_metrics, display_image_in_detail, plot_2d_data, timer_decorator, display_4d_image, timer_decorator
from module.datasets import load_4d_dicom, save_4d_dicom, restore_data, split_data

from module.models import UNet2D
from module.datasets import Nb2Nb2D_Dataset
from module.loss import SSIMLoss, SSIM_MAELoss, SSIM_MSELoss



import h5py
from tqdm.notebook import tqdm


import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset

from torchsummary import summary

device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"device: {device}")

%matplotlib tk

device: cuda:0


In [2]:
# Load the pre-processed noisy volume together with `restore_info`, the
# metadata stored next to it (judging by its keys: original min/max, z-score
# mean/std and noise range — presumably what `restore_data` needs to undo the
# normalisation; TODO confirm).
NOISY_DATA_PATH = './dataset/preprocessed/PT_20p 150_120 OSEM_real_0.00_batch.h5'

with h5py.File(NOISY_DATA_PATH, 'r') as f:
    noisy_data = f['dataset'][...]          # read the whole dataset into a numpy array
    restore_info = json.loads(f['restore_info'][()])

# One-line summary; the parenthesised group now closes after std, not after mean.
print(f"Noisy data...{noisy_data.dtype} (shape:{noisy_data.shape}; range:[{np.min(noisy_data)},{np.max(noisy_data)}]; mean:{np.mean(noisy_data)}; std:{np.std(noisy_data)})")

print(restore_info)


Noisy data...float32 (shape:(11, 24, 71, 192, 192); range:[0.0,1.0]; mean:0.5030226111412048); std:0.02147510275244713
{'original_min': -32768.0, 'original_max': 32767.0, 'z_score_mean': 201.4710693359375, 'z_score_std_dev': 1407.2664794921875, 'noise_min': 0.0, 'noise_max': 1.0}


In [3]:
# Insert a singleton channel axis in front of the last three axes, e.g.
# (11, 24, 71, 192, 192) -> (11, 24, 1, 71, 192, 192) for this dataset.
# torch.as_tensor reuses noisy_data's buffer when it is already a float32 CPU
# array (it is float32 here), instead of torch.tensor's unconditional multi-GB
# copy. NOTE: noisy_tensor then shares memory with noisy_data, so an in-place
# change to one is visible in the other.
noisy_tensor = torch.as_tensor(noisy_data, dtype=torch.float32).unsqueeze(-4)

In [4]:
# Choose one position to inspect, plus the depth slices directly above and below it.
batch_idx = 0    # index along axis 0
time_idx = 11    # index along axis 1 (24 entries in this dataset)
depth_idx = 38   # index along axis 3 (71 entries); needs a neighbour on both sides

# Negative indices silently wrap to the far end of the volume, so check that
# depth_idx - 1 and depth_idx + 1 are real neighbouring slices.
assert 0 <= time_idx < noisy_tensor.shape[1], f"time_idx={time_idx} out of range"
assert 1 <= depth_idx <= noisy_tensor.shape[3] - 2, (
    f"depth_idx={depth_idx} must have a neighbouring slice on both sides"
)

# Each slice keeps the singleton channel axis, e.g. shape (1, 192, 192) here.
top = noisy_tensor[batch_idx, time_idx, :, depth_idx - 1]
middle = noisy_tensor[batch_idx, time_idx, :, depth_idx]
bottom = noisy_tensor[batch_idx, time_idx, :, depth_idx + 1]


In [5]:
def random_neighbor_subsample(tensor, k=2, generator=None):
    """
    Perform random neighbor sub-sampling on a batch tensor.

    The spatial plane of ``tensor`` is tiled into non-overlapping ``k x k``
    cells. Two distinct positions inside a cell are drawn at random, once per
    call, and the same pair is used for every cell, channel and batch item.
    The pixels at those two positions form the two sub-sampled images.
    Rows/columns that do not fill a complete cell (H or W not a multiple of
    ``k``) are dropped, as ``torch.Tensor.unfold`` does.

    NOTE(review): the two positions are any two distinct positions of the
    cell (for k=2 they may be diagonal) and are shared by all cells; confirm
    this is the intended sampling scheme for "neighboring" pixels.

    Args:
        tensor (torch.Tensor): Input tensor of shape [batch, channels, height, width].
        k (int, optional): The size of the cell for sub-sampling. Defaults to 2.
            Must be >= 2 so that a cell holds two distinct positions.
        generator (torch.Generator, optional): RNG used to draw the positions.
            Defaults to None, i.e. the global torch RNG (the previous
            behaviour). Pass a seeded generator for reproducible samples.

    Returns:
        tuple: Two sub-sampled tensors (g1, g2) each of shape
        [batch, channels, height//k, width//k].

    Raises:
        ValueError: If ``tensor`` is not 4-D or ``k < 2``.
    """
    if tensor.dim() != 4:
        raise ValueError(
            f"expected a 4-D [batch, channels, height, width] tensor, got shape {tuple(tensor.shape)}"
        )
    if k < 2:
        raise ValueError(f"k must be >= 2 to draw two distinct positions, got {k}")

    B, C, H, W = tensor.shape

    # [B, C, H//k, W//k, k, k] -> [B, C, H//k, W//k, k*k]; flattened position p
    # is row offset p // k and column offset p % k inside its cell.
    cells = tensor.unfold(2, k, k).unfold(3, k, k)
    cells = cells.reshape(B, C, H // k, W // k, k * k)

    # Draw on the CPU and use plain Python ints as indices: valid on any device.
    idx1, idx2 = torch.randperm(k * k, generator=generator)[:2].tolist()

    # Integer indexing already removes the last axis. No squeeze here: the old
    # squeeze(-1) also removed the width axis whenever W // k == 1.
    g1 = cells[..., idx1]
    g2 = cells[..., idx2]

    return g1, g2


In [14]:
# Sub-sampling is random; seed the global torch RNG here so this cell (and any
# figure saved from its outputs) gives the same result on every re-run.
SUBSAMPLE_SEED = 0
torch.manual_seed(SUBSAMPLE_SEED)

# unsqueeze(0) adds the batch axis: (1, H, W) -> (1, 1, H, W).
g1_top, g2_top = random_neighbor_subsample(top.unsqueeze(0))
g1_middle, g2_middle = random_neighbor_subsample(middle.unsqueeze(0))
g1_bottom, g2_bottom = random_neighbor_subsample(bottom.unsqueeze(0))

In [34]:
# Export one slice as a borderless 'hot' heat-map PNG.
# To export a sub-sampled view instead, set slice_to_save = g2_bottom;
# squeeze() drops the leading singleton axes of either tensor.
slice_to_save = bottom
output_path = 'bottom.png'

# A fresh figure on every run, so re-executing the cell never draws onto a
# stale "current" axes left over from another cell.
fig, ax = plt.subplots()
ax.imshow(slice_to_save.detach().cpu().squeeze().numpy(), cmap='hot')
ax.set_axis_off()

# Strip margins and ticks so the saved PNG has no border around the image.
fig.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0, wspace=0)
ax.margins(0, 0)
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
fig.savefig(output_path, bbox_inches='tight', pad_inches=0)