In [1]:
from __future__ import division
import os
import logging
import time
import glob
import datetime
import argparse
import numpy as np
from scipy.io import loadmat, savemat

import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from arch_unet import UNet
import utils as util
from collections import OrderedDict
from tqdm import tqdm

from Sber_utils import Masker, load_network, generate_mask, interpolate_mask, depth_to_space

In [5]:
# Inference options. In the notebook they come from the explicit argv list
# handed to parse_args below, not from sys.argv.
parser = argparse.ArgumentParser()

# Noise model the checkpoint was trained for.
parser.add_argument('--noisetype', type=str, default='gauss25',
                    choices=['gauss25', 'gauss5_50', 'poisson30', 'poisson5_50'])
# Checkpoint and input/output locations.
parser.add_argument('--checkpoint', type=str, default='./*.pth')
parser.add_argument('--test_dir', type=str, default='./data/test')
parser.add_argument('--save_test_path', type=str, default='./test')
parser.add_argument('--log_name', type=str, default='b2u_unet_g25_112rf20')
# Device selection.
parser.add_argument('--gpu_devices', type=str, default='0')
parser.add_argument('--parallel', action='store_true')
# Network width / channel count (must match the checkpoint).
parser.add_argument('--n_feature', type=int, default=48)
parser.add_argument('--n_channel', type=int, default=1)
# Blend weight of the unmasked prediction, dataset sub-folder, repeat count.
parser.add_argument('--beta', type=float, default=20.0)
parser.add_argument('--dataset_name', type=str)
parser.add_argument('--repeat_times', type=int, default=3)

opt = parser.parse_args([
    '--noisetype', 'gauss5_50',
    # alternative checkpoint: './pretrained_models/g5-50_112rf20_beta19.4.pth'
    '--checkpoint', './experiments/my_models/2023-07-13-14-50/models/epoch_model_100.pth',
    '--test_dir', './data/test',
    '--save_test_path', './test',
    '--log_name', 'b2u_sunet_fmdd_112rf20',
    '--beta', '19.4',
    '--dataset_name', 'small_test',
])

# Run timestamp and seed counter (not referenced by the cells shown here).
systime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
operation_seed_counter = 0

# The device mask only takes effect if set before CUDA is first initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_devices
# Cap PyTorch's intra-op CPU thread pool.
torch.set_num_threads(8)

In [6]:
class MicroscopyDataset(Dataset):
    """Test-time dataset over every image file in a single directory.

    Each item is ``(image, name)``: ``image`` is the array returned by
    ``cv2.imread`` (passed through ``transform`` when one is given) and
    ``name`` is the file name without its last extension.

    Args:
        img_dir: directory holding the images. Sub-directories and hidden
            dotfiles (e.g. ``.DS_Store``) are skipped; files are visited in
            sorted order so runs are reproducible.
        channels: 1 maps to imread flag 0 (cv2.IMREAD_GRAYSCALE); any other
            value is passed to ``cv2.imread`` unchanged as the flag.
            NOTE(review): the default 3 is IMREAD_COLOR | IMREAD_ANYDEPTH, so
            colour images come back in BGR order and 16-bit files keep their
            bit depth — confirm that is intended.
        transform: optional callable applied to the loaded array.

    Raises:
        IOError: from ``__getitem__`` when ``cv2.imread`` cannot read a file.
    """

    def __init__(self, img_dir, channels=3, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort for a stable output order, and
        # drop directories / dotfiles that cv2 cannot decode.
        self.img_filenames = sorted(
            fname for fname in os.listdir(img_dir)
            if not fname.startswith('.')
            and os.path.isfile(os.path.join(img_dir, fname))
        )
        # Stored as the cv2.imread flag; 0 == cv2.IMREAD_GRAYSCALE.
        self.channels = 0 if channels == 1 else channels

    def __len__(self):
        return len(self.img_filenames)

    @staticmethod
    def _stem(filename):
        """Return the base name of ``filename`` without its last extension
        (``'a.v2.png' -> 'a.v2'``), so names containing dots are not truncated."""
        return os.path.splitext(os.path.basename(filename))[0]

    def __getitem__(self, idx):
        filename = self.img_filenames[idx]
        path = os.path.join(self.img_dir, filename)

        image = cv2.imread(path, self.channels)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError(f"cv2.imread could not read image: {path}")

        if self.transform:
            image = self.transform(image)

        return image, self._stem(filename)

In [7]:
# --- Test data ---------------------------------------------------------
# ToTensor converts each HWC array from MicroscopyDataset into a CHW tensor;
# batch_size=1 with no shuffling gives a fixed, one-image-at-a-time order.
test_transforms = transforms.Compose([transforms.ToTensor()])
dataset_dir = os.path.join(opt.test_dir, opt.dataset_name)
test_dataset = MicroscopyDataset(dataset_dir, opt.n_channel, transform=test_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# --- Masker ------------------------------------------------------------
# Its train() output (masked inputs + mask) is consumed by the inference loop.
masker = Masker(width=4, mode='interpolate', mask_type='all')

# --- Network -----------------------------------------------------------
# UNet with matching input/output channels, optionally wrapped for multi-GPU,
# moved to the GPU, then loaded strictly from the checkpoint.
network = UNet(in_channels=opt.n_channel, out_channels=opt.n_channel, wf=opt.n_feature)
network = torch.nn.DataParallel(network) if opt.parallel else network
network = load_network(opt.checkpoint, network.cuda(), strict=True)
# Weight of the unmasked prediction when the two predictions are blended.
beta = opt.beta
network.eval()

# --- Output directories ------------------------------------------------
# Results go to <save_test_path>/<dataset_name>/<log_name>.
save_test_path = os.path.join(opt.save_test_path, opt.dataset_name)
os.makedirs(save_test_path, exist_ok=True)
validation_path = os.path.join(save_test_path, opt.log_name)
os.makedirs(validation_path, exist_ok=True)

np.random.seed(101)

In [None]:

save_dir = validation_path
os.makedirs(save_dir, exist_ok=True)

repeat_times = opt.repeat_times


def _to_uint8(img):
    """Map a float image in [0, 1] to uint8, rounding half up and clipping to [0, 255]."""
    return np.clip(img * 255.0 + 0.5, 0, 255).astype(np.uint8)


def _to_hwc(pred):
    """Convert a (1, C, H, W) tensor to an (H, W, C) numpy array clamped to [0, 1]."""
    return pred[0].permute(1, 2, 0).clamp(0, 1).detach().cpu().numpy()


def _save_gray_png(img_u8, path):
    """Save a uint8 (H, W, C) image as a grayscale ('L') PNG at ``path``.

    Single-channel images drop their channel axis. 3-channel images are flipped
    from BGR (the order cv2.imread returns in MicroscopyDataset) to the RGB order
    PIL assumes, so the 'L' conversion weights the correct channels.
    """
    if img_u8.ndim == 3 and img_u8.shape[2] == 1:
        img_u8 = img_u8[:, :, 0]
    elif img_u8.ndim == 3 and img_u8.shape[2] == 3:
        img_u8 = np.ascontiguousarray(img_u8[:, :, ::-1])
    Image.fromarray(img_u8).convert('L').save(path)


# ToTensor holds no per-image state, so build it once rather than per image.
to_tensor = transforms.ToTensor()

for i in range(repeat_times):
    for im, name in tqdm(test_dataloader):
        # default_collate returns a batch of strings as a list; unwrap the single
        # name (batch_size=1) so files are "img-0_dn.png", not "['img']-0_dn.png".
        if isinstance(name, (list, tuple)):
            name = name[0]

        if isinstance(im, torch.Tensor):
            if im.dim() == 4:
                im = im.squeeze(0)            # drop the batch dimension
            im = im.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        im = np.asarray(im, dtype=np.float32)

        # No synthetic noise is added: the loaded image itself is the network
        # input, so the saved "_clean" and "_noisy" PNGs contain the same image.
        noisy_im = im
        origin255 = _to_uint8(im)
        noisy255 = _to_uint8(noisy_im)

        # Reflect-pad bottom/right to a square whose side is a multiple of 32
        # (presumably required by the UNet's down-sampling depth — TODO confirm).
        H, W = noisy_im.shape[:2]
        val_size = (max(H, W) + 31) // 32 * 32
        padded = np.pad(noisy_im, [[0, val_size - H], [0, val_size - W], [0, 0]], 'reflect')
        noisy_in = to_tensor(padded).unsqueeze(0).cuda()

        with torch.no_grad():
            n, c, h, w = noisy_in.shape
            net_input, mask = masker.train(noisy_in)
            # Predict on the masked inputs, weight each prediction by `mask`,
            # and sum over the group dimension back into one image.
            noisy_output = (network(net_input) * mask).view(n, -1, c, h, w).sum(dim=1)
            # Prediction on the full, unmasked input.
            exp_output = network(noisy_in)

        # Crop the padding away, then blend the two predictions with weight beta.
        pred_dn = noisy_output[:, :, :H, :W]
        pred_exp = exp_output[:, :, :H, :W]
        pred_mid = (pred_dn + beta * pred_exp) / (1 + beta)

        outputs = {
            'clean': origin255,
            'noisy': noisy255,
            'dn': _to_uint8(_to_hwc(pred_dn)),
            'exp': _to_uint8(_to_hwc(pred_exp)),
            'mid': _to_uint8(_to_hwc(pred_mid)),
        }
        for suffix, img_u8 in outputs.items():
            _save_gray_png(img_u8, os.path.join(save_dir, f"{name}-{i}_{suffix}.png"))

In [None]:
# Hand cached-but-unused GPU memory back from PyTorch's allocator after inference.
if torch.cuda.is_available():
    torch.cuda.empty_cache()