In [54]:
from __future__ import division
import os
import logging
import time
import glob
import datetime
import argparse
import numpy as np
from scipy.io import loadmat, savemat

import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from arch_unet import UNet
import utils as util
from collections import OrderedDict
from tqdm import tqdm

In [169]:
parser = argparse.ArgumentParser()
parser.add_argument("--noisetype", type=str, default="gauss25", choices=['gauss25', 'gauss5_50', 'poisson30', 'poisson5_50'])
# NOTE: passed verbatim to torch.load by load_network; the default is a literal
# path (no glob expansion happens), so a real checkpoint path must be supplied.
parser.add_argument('--checkpoint', type=str, default='./*.pth')
parser.add_argument('--test_dir', type=str, default='./data/test')
parser.add_argument('--save_test_path', type=str, default='./test')
parser.add_argument('--log_name', type=str, default='b2u_unet_g25_112rf20')
parser.add_argument('--gpu_devices', default='0', type=str)
parser.add_argument('--parallel', action='store_true')
parser.add_argument('--n_feature', type=int, default=48)
parser.add_argument('--n_channel', type=int, default=1)
parser.add_argument("--beta", type=float, default=20.0)
# Required: it names the sub-directory read under --test_dir and written under
# --save_test_path; leaving it unset would only fail later inside os.path.join.
parser.add_argument('--dataset_name', type=str, required=True)
parser.add_argument('--repeat_times', type=int, default=3)

# Arguments are given explicitly because this runs inside a notebook, where
# sys.argv belongs to the Jupyter kernel rather than to this script.
# NOTE(review): --beta 19.4 matches the commented-out g5-50 checkpoint name
# (beta19.4), while the active checkpoint's filename says beta19.7 -- confirm.
opt = parser.parse_args(['--noisetype', 'gauss5_50',
                         # '--checkpoint', './pretrained_models/g5-50_112rf20_beta19.4.pth',
                         '--checkpoint', './pretrained_models/Confocal_MICE_112rf20_beta19.7.pth',
                         '--test_dir', './data/test',
                         '--save_test_path', './test',
                         '--log_name', 'b2u_sunet_fmdd_112rf20',
                         '--beta', '19.4',
                         '--dataset_name', 'Crop'])

# Timestamp string; not used by the cells shown.
systime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
# Presumably consumed by a seeded get_generator() helper that is not defined in
# the cells shown -- TODO confirm.
operation_seed_counter = 0
# Restrict visible GPUs; presumably only effective if CUDA has not been
# initialised yet in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_devices
torch.set_num_threads(8)

In [170]:
class MicroscopyDataset(Dataset):
    """Folder of images served as ``(image, name)`` pairs for inference.

    Args:
        img_dir: directory whose regular, non-hidden files are the images.
        channels: ``1`` reads every image as grayscale (cv2 flag ``0``); any
            other value is handed to ``cv2.imread`` unchanged as its *flags*
            argument (cv2 interprets it as flags, not as a channel count).
        transform: optional callable applied to the decoded numpy array.
    """

    def __init__(self, img_dir, channels=3, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Sorted for a deterministic order (os.listdir order is arbitrary).
        # Sub-directories (e.g. .ipynb_checkpoints) and hidden files such as
        # .DS_Store are skipped because they are not decodable images.
        self.img_filenames = sorted(
            f for f in os.listdir(img_dir)
            if not f.startswith('.') and os.path.isfile(os.path.join(img_dir, f))
        )
        self.channels = 0 if channels == 1 else channels

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, idx):
        """Return ``(image, name)``; ``name`` is the filename minus its last extension."""
        filename = self.img_filenames[idx]
        name = os.path.splitext(filename)[0]
        path = os.path.join(self.img_dir, filename)

        image = cv2.imread(path, self.channels)
        if image is None:
            # cv2.imread reports unreadable files by returning None, not raising.
            raise IOError("cv2.imread could not read image: {}".format(path))

        if self.transform:
            image = self.transform(image)

        return image, name

In [171]:
class Masker(object):
    """Produces masked copies of an image batch for blind-spot inference.

    ``mask`` builds a single masked view; ``train`` builds all ``width**2``
    fixed-position views and flattens them into the batch dimension.
    """

    def __init__(self, width=4, mode='interpolate', mask_type='all'):
        # width: side length of the square cell in which one pixel is masked.
        self.width = width
        self.mode = mode
        self.mask_type = mask_type

    def mask(self, img, mask_type=None, mode=None):
        """Return ``(masked_img, mask)`` for a single mask pattern.

        ``mask_type`` and ``mode`` fall back to the instance defaults when None.
        Only ``mode == 'interpolate'`` is implemented; anything else raises
        NotImplementedError (after the mask has been generated).
        """
        chosen_mode = self.mode if mode is None else mode
        chosen_type = self.mask_type if mask_type is None else mask_type

        pattern = generate_mask(img, width=self.width, mask_type=chosen_type)
        keep = torch.ones(pattern.shape).to(img.device) - pattern
        if chosen_mode != 'interpolate':
            raise NotImplementedError
        return interpolate_mask(img, pattern, keep), pattern

    def train(self, img):
        """Stack every fixed mask position into the batch axis.

        Returns ``(views, masks)`` shaped ``(N*width**2, C, H, W)`` and
        ``(N*width**2, 1, H, W)``; rows ``n*width**2 + k`` use mask ``fix_k``.
        """
        n, c, h, w = img.shape
        num_views = self.width ** 2
        stacked_views = torch.zeros((n, num_views, c, h, w), device=img.device)
        stacked_masks = torch.zeros((n, num_views, 1, h, w), device=img.device)
        for k in range(num_views):
            masked, pattern = self.mask(img, mask_type='fix_{}'.format(k))
            stacked_views[:, k] = masked
            stacked_masks[:, k] = pattern
        return stacked_views.reshape(-1, c, h, w), stacked_masks.reshape(-1, 1, h, w)

In [172]:
def load_network(load_path, network, strict=True, map_location="cpu"):
    """Load a state_dict checkpoint into ``network`` and return that same object.

    Keys saved from a (Distributed)DataParallel model carry a ``module.``
    prefix; it is stripped so the weights fit the bare module.

    Args:
        load_path: path handed directly to ``torch.load`` (no glob expansion).
        network: model to fill. If it is wrapped in (Distributed)DataParallel,
            the weights go into ``network.module`` but the wrapper itself is
            returned, so callers doing ``net = load_network(path, net)`` keep it.
        strict: forwarded to ``load_state_dict``.
        map_location: forwarded to ``torch.load``. The default ``"cpu"`` avoids
            failures or extra GPU memory when the checkpoint was saved on a GPU;
            ``load_state_dict`` copies values into the existing parameters, so
            the model's device placement is unchanged.

    Returns:
        The ``network`` argument, wrapper included.
    """
    assert load_path is not None
    target = network
    if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        target = network.module
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    load_net = torch.load(load_path, map_location=map_location)
    prefix = "module."
    load_net_clean = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in load_net.items()
    )
    target.load_state_dict(load_net_clean, strict=strict)
    return network

In [173]:
class AugmentNoise(object):
    """Synthetic Gaussian / Poisson noise selected by a style string.

    Supported styles:
        ``gauss<s>``        fixed std ``s/255``
        ``gauss<a>_<b>``    std drawn uniformly from ``[a/255, b/255]``
        ``poisson<l>``      fixed rate ``l``
        ``poisson<a>_<b>``  rate drawn uniformly from ``[a, b]``

    Raises:
        ValueError: for any other style string.
    """

    def __init__(self, style):
        print(style)
        if style.startswith('gauss'):
            family = 'gauss'
            self.params = [
                float(p) / 255.0 for p in style.replace('gauss', '').split('_')
            ]
        elif style.startswith('poisson'):
            family = 'poisson'
            self.params = [
                float(p) for p in style.replace('poisson', '').split('_')
            ]
        else:
            raise ValueError("Unsupported noise style: {!r}".format(style))

        if len(self.params) == 1:
            self.style = family + "_fix"
        elif len(self.params) == 2:
            self.style = family + "_range"
        else:
            # Previously self.style was left unset, failing later with AttributeError.
            raise ValueError("Expected 1 or 2 noise parameters in {!r}".format(style))

    def add_train_noise(self, x):
        """Return a noisy copy of the batch tensor ``x`` (noise drawn with torch).

        One std / rate is drawn per sample (shape ``(N, 1, 1, 1)``, so ``x`` is
        presumably ``(N, C, H, W)``); Gaussian noise is then drawn independently
        for every element of ``x``.

        NOTE(review): ``get_generator`` is not defined in the cells shown; it
        must be provided elsewhere before this method is called.
        """
        shape = x.shape
        if self.style == "gauss_fix":
            std = self.params[0] * torch.ones((shape[0], 1, 1, 1), device=x.device)
            # Draw full-shape standard normal noise and scale it. The previous
            # torch.normal(mean=float, std=(N,1,1,1), out=noise) resized `noise`
            # to std's shape, giving one constant offset per image.
            noise = torch.randn(shape, generator=get_generator(),
                                device=x.device, dtype=x.dtype) * std
            return x + noise
        elif self.style == "gauss_range":
            min_std, max_std = self.params
            std = torch.rand(size=(shape[0], 1, 1, 1),
                             device=x.device) * (max_std - min_std) + min_std
            noise = torch.randn(shape, generator=get_generator(),
                                device=x.device, dtype=x.dtype) * std
            return x + noise
        elif self.style == "poisson_fix":
            lam = self.params[0]
            lam = lam * torch.ones((shape[0], 1, 1, 1), device=x.device)
            noised = torch.poisson(lam * x, generator=get_generator()) / lam
            return noised
        elif self.style == "poisson_range":
            min_lam, max_lam = self.params
            lam = torch.rand(size=(shape[0], 1, 1, 1),
                             device=x.device) * (max_lam - min_lam) + min_lam
            noised = torch.poisson(lam * x, generator=get_generator()) / lam
            return noised

    def add_valid_noise(self, x):
        """Return a float32 noisy copy of the numpy array ``x`` (numpy global RNG).

        Range styles draw one parameter of shape ``(1, 1, 1)``, which broadcasts
        over a 3-D ``x`` (presumably ``(H, W, C)`` -- TODO confirm).
        """
        shape = x.shape
        if self.style == "gauss_fix":
            std = self.params[0]
            return np.array(x + np.random.normal(size=shape) * std,
                            dtype=np.float32)
        elif self.style == "gauss_range":
            min_std, max_std = self.params
            std = np.random.uniform(low=min_std, high=max_std, size=(1, 1, 1))
            return np.array(x + np.random.normal(size=shape) * std,
                            dtype=np.float32)
        elif self.style == "poisson_fix":
            lam = self.params[0]
            return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)
        elif self.style == "poisson_range":
            min_lam, max_lam = self.params
            lam = np.random.uniform(low=min_lam, high=max_lam, size=(1, 1, 1))
            return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)


def space_to_depth(x, block_size):
    """Fold each ``block_size x block_size`` spatial patch into the channel axis.

    ``(N, C, H, W) -> (N, C*block_size**2, H//block_size, W//block_size)``.
    Channel ``c*block_size**2 + i*block_size + j`` holds patch offset ``(i, j)``
    of input channel ``c``.
    """
    batch, chans, height, width = x.size()
    patches = torch.nn.functional.unfold(x, block_size, stride=block_size)
    out_shape = (batch, chans * block_size ** 2, height // block_size, width // block_size)
    return patches.view(*out_shape)

In [174]:
def generate_mask(img, width=4, mask_type='random'):
    """Build a blind-spot mask selecting one pixel in every ``width x width`` cell.

    Args:
        img: batch tensor ``(N, C, H, W)``; only its shape, dtype and device are used.
        width: cell side; H and W must be divisible by it.
        mask_type:
            ``'random'``  -- independent position per cell;
            ``'batch'``   -- one position per image, shared by all its cells;
            ``'all'``     -- one position shared by every cell of every image;
            ``'fix_<k>'`` -- position ``k`` (``0 <= k < width**2``, row-major
                             inside the cell) everywhere.

    Returns:
        int64 tensor ``(N, 1, H, W)`` on ``img.device``: 1 at the selected
        position of each cell, 0 elsewhere.

    Raises:
        ValueError: non-divisible H/W, unknown ``mask_type`` or out-of-range k.

    NOTE(review): the random modes call ``get_generator``, which is not defined
    in the cells shown; it must be provided elsewhere.
    """
    n, c, h, w = img.shape
    if h % width or w % width:
        raise ValueError(
            "image height/width ({}, {}) must be divisible by width={}".format(h, w, width))
    cells_per_img = (h // width) * (w // width)
    num_cells = n * cells_per_img
    positions = width ** 2
    device = img.device

    # rd_idx[k] = position chosen inside cell k; cells are ordered image-major,
    # then row-major over the cell grid (matching the view(...) below).
    if mask_type == 'random':
        rd_idx = torch.randint(low=0, high=positions, size=(num_cells,),
                               dtype=torch.int64, device=device,
                               generator=get_generator(device=device))
    elif mask_type == 'batch':
        per_image = torch.randint(low=0, high=positions, size=(n,),
                                  device=device,
                                  generator=get_generator(device=device))
        # repeat_interleave keeps each image's draw on that image's own cells;
        # .repeat would cycle the n draws across the cells of every image.
        rd_idx = per_image.repeat_interleave(cells_per_img)
    elif mask_type == 'all':
        rd_idx = torch.randint(low=0, high=positions, size=(1,),
                               device=device,
                               generator=get_generator(device=device)).repeat(num_cells)
    elif 'fix' in mask_type:
        index = int(mask_type.split('_')[-1])
        if not 0 <= index < positions:
            raise ValueError("fix index {} out of range [0, {})".format(index, positions))
        rd_idx = torch.full((num_cells,), index, dtype=torch.int64, device=device)
    else:
        raise ValueError("unknown mask_type: {!r}".format(mask_type))

    # Each cell owns `positions` consecutive slots in the flat buffer.
    offsets = torch.arange(0, num_cells * positions, positions,
                           dtype=torch.int64, device=device)
    flat = torch.zeros(num_cells * positions, dtype=torch.int64, device=device)
    flat[rd_idx + offsets] = 1

    grid = flat.view(n, h // width, w // width, positions).permute(0, 3, 1, 2)
    # pixel_shuffle moves channel k of each cell to offset (k // width, k % width).
    mask = torch.nn.functional.pixel_shuffle(grid.type_as(img), width)
    return mask.type(torch.int64)


def interpolate_mask(tensor, mask, mask_inv):
    """Replace masked pixels by a weighted average of their 8 neighbours.

    The 3x3 kernel ``[[0.5, 1, 0.5], [1, 0, 1], [0.5, 1, 0.5]] / 6`` is applied
    to every channel independently with zero padding, so pixels on the border
    average in zeros.

    Args:
        tensor: ``(N, C, H, W)`` floating tensor.
        mask: broadcastable to ``tensor``; 1 where the interpolated value is used.
        mask_inv: broadcastable to ``tensor``; 1 where the original value is kept.

    Returns:
        ``filtered * mask + tensor * mask_inv``.
    """
    n, c, h, w = tensor.shape
    device = tensor.device
    mask = mask.to(device)

    kernel = torch.tensor([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
                          dtype=torch.float64)
    # Normalise in float64, then match the input dtype/device so conv2d also
    # works for non-float32 inputs (a float32 weight would fail on float64 input).
    kernel = (kernel / kernel.sum()).to(device=device, dtype=tensor.dtype).view(1, 1, 3, 3)

    # reshape (not view) so non-contiguous inputs are accepted too.
    filtered_tensor = torch.nn.functional.conv2d(
        tensor.reshape(n * c, 1, h, w), kernel, stride=1, padding=1)

    return filtered_tensor.view_as(tensor) * mask + tensor * mask_inv

In [175]:
def depth_to_space(x, block_size):
    """Unfold channels back into space: ``(N, C*r*r, H, W) -> (N, C, H*r, W*r)``.

    ``r`` is ``block_size``; channel ``c*r*r + i*r + j`` lands at offset
    ``(i, j)`` of each ``r x r`` output patch.
    """
    shuffle = torch.nn.PixelShuffle(upscale_factor=block_size)
    return shuffle(x)

In [176]:
# Dataset: every image file under <test_dir>/<dataset_name>, one per batch,
# in dataset order (shuffle=False).
test_transforms = transforms.Compose([transforms.ToTensor()])
dataset_dir = os.path.join(opt.test_dir, opt.dataset_name)
test_dataset = MicroscopyDataset(dataset_dir, opt.n_channel, transform=test_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# Noise adder (only needed when synthetic noise is added to the test images).
noise_adder = AugmentNoise(style=opt.noisetype)
# Masker: width=4 blind-spot cells (presumably must match training -- TODO confirm).
masker = Masker(width=4, mode='interpolate', mask_type='all')
# Network: weights are loaded into the bare UNet first and only then wrapped in
# DataParallel, so the wrapper requested by --parallel is applied last and
# cannot be lost through load_network's DataParallel unwrapping.
network = UNet(in_channels=opt.n_channel,
               out_channels=opt.n_channel,
               wf=opt.n_feature)
network = load_network(opt.checkpoint, network, strict=True)
if opt.parallel:
    network = torch.nn.DataParallel(network)
network = network.cuda()
beta = opt.beta

# Inference only: eval mode (affects layers such as dropout/batch-norm, if any).
network.eval()

# Outputs go to <save_test_path>/<dataset_name>/<log_name>; makedirs also
# creates the intermediate <save_test_path>/<dataset_name> directory.
save_test_path = os.path.join(opt.save_test_path, opt.dataset_name)
validation_path = os.path.join(save_test_path, opt.log_name)
os.makedirs(validation_path, exist_ok=True)
# Seeds numpy's global RNG (used by AugmentNoise.add_valid_noise).
np.random.seed(101)

gauss5_50


In [177]:
# Release cached, unused GPU memory before the inference loop.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [179]:

# Denoise every test image and write five grayscale PNGs per image and repeat:
#   clean / noisy : the input image (identical here: no synthetic noise is added)
#   dn            : blind-spot prediction assembled from the width**2 masked views
#   exp           : prediction on the unmasked input
#   mid           : (dn + beta * exp) / (1 + beta)
save_dir = validation_path
os.makedirs(save_dir, exist_ok=True)

repeat_times = opt.repeat_times
# Built once instead of on every iteration.
transformer = transforms.Compose([transforms.ToTensor()])


def _to_uint8(arr):
    """Map a float array in [0, 1] to uint8 with round-to-nearest and clipping."""
    return np.clip(arr * 255.0 + 0.5, 0, 255).astype(np.uint8)


def _prediction_to_uint8(pred):
    """(1, C, H, W) tensor -> (H, W, C) uint8 array, clamped to [0, 1] first."""
    arr = pred.permute(0, 2, 3, 1).detach().cpu().clamp(0, 1).numpy().squeeze(0)
    return _to_uint8(arr)


def _save_gray(arr, path):
    """Save a uint8 (H, W, 1) or (H, W, C) array as an 8-bit 'L' PNG."""
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]  # drop the singleton channel axis -> 2-D grayscale array
    # NOTE(review): for n_channel != 1, cv2 presumably decoded the image as BGR,
    # while convert('L') assumes RGB -- confirm grayscale output is intended.
    Image.fromarray(arr).convert('L').save(path)


for i in range(repeat_times):
    # NOTE(review): with fixed masks and no synthetic noise, repeats will likely
    # produce identical outputs; they only matter if noise is re-enabled.
    for im, name in tqdm(test_dataloader):
        # The DataLoader collates string names into a length-1 sequence; unwrap
        # it so files are named "<name>-<i>_..." rather than "('<name>',)-<i>_...".
        if not isinstance(name, str):
            name = name[0]

        if isinstance(im, torch.Tensor):
            if im.dim() == 4:
                im = im.squeeze(0)  # drop the batch axis (batch_size=1)
            im = im.permute(1, 2, 0).numpy()  # CHW -> HWC
        im = np.array(im, dtype=np.float32)
        origin255 = _to_uint8(im)

        # The test images are real noisy acquisitions, so no synthetic noise is
        # added (use noise_adder.add_valid_noise(im) to test on synthetic noise).
        noisy_im = im.copy()
        noisy255 = _to_uint8(noisy_im)

        # Reflect-pad bottom/right to a square whose side is a multiple of 32
        # (presumably required by the UNet's downsampling -- TODO confirm).
        H, W = noisy_im.shape[:2]
        val_size = (max(H, W) + 31) // 32 * 32
        noisy_im = np.pad(
            noisy_im,
            [[0, val_size - H], [0, val_size - W], [0, 0]],
            'reflect')
        noisy_im = transformer(noisy_im).unsqueeze(0).cuda()

        with torch.no_grad():
            n, c, h, w = noisy_im.shape
            net_input, mask = masker.train(noisy_im)
            # masker.train stacks the width**2 fixed-position views; each view
            # keeps only its masked pixels, so summing over views yields one
            # blind-spot prediction per pixel.
            noisy_output = (network(net_input) * mask).view(n, -1, c, h, w).sum(dim=1)
            exp_output = network(noisy_im)

        # Crop the padding back off before blending.
        pred_dn = noisy_output[:, :, :H, :W]
        pred_exp = exp_output[:, :, :H, :W]
        pred_mid = (pred_dn + beta * pred_exp) / (1 + beta)

        outputs = (
            ('clean', origin255),
            ('noisy', noisy255),
            ('dn', _prediction_to_uint8(pred_dn)),
            ('exp', _prediction_to_uint8(pred_exp)),
            ('mid', _prediction_to_uint8(pred_mid)),
        )
        for suffix, arr in outputs:
            _save_gray(arr, os.path.join(save_dir, f"{name}-{i}_{suffix}.png"))




100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
100%|██████████| 1/1 [00:00<00:00,  1.59it/s]


In [37]:
# NOTE(review): this cell and the two below create the 'Crop' test set that the
# inference cells above read, yet they sit after the inference loop; on
# Restart & Run All they need to run before the dataset cell.
image = cv2.imread('./data/test/Microscopy/crop_noisy.tiff')
# cv2.imread returns None (rather than raising) when the file is missing or unreadable.
assert image is not None, "could not read ./data/test/Microscopy/crop_noisy.tiff"
image.shape

(768, 1024, 3)

In [38]:
# makedirs with exist_ok makes this cell safe to re-run and creates missing parents.
os.makedirs('./data/test/Crop', exist_ok=True)

In [39]:
# Save the top-left 480-row x 640-column region as the single 'Crop' test image.
crop = image[:480, :640, :]
cv2.imwrite('./data/test/Crop/crop_.tiff', crop)

True