In [1]:
from __future__ import division
import os
import logging
import time
import glob
import datetime
import argparse
import numpy as np
from scipy.io import loadmat, savemat

import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from arch_unet import UNet
import utils as util
from collections import OrderedDict
from tqdm import tqdm

from Sber_utils import Masker, load_network, generate_mask, interpolate_mask, depth_to_space

In [20]:
def _str2bool(value):
    """Convert a command-line string to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual spellings instead.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line style configuration; values are supplied explicitly below
# because the notebook has no real argv.
parser = argparse.ArgumentParser()
parser.add_argument("--noisetype", type=str, default="gauss25", choices=['gauss25', 'gauss5_50', 'poisson30', 'poisson5_50'])
parser.add_argument('--checkpoint', type=str, default='./*.pth')
parser.add_argument('--test_dir', type=str, default='./data/test')
parser.add_argument('--save_test_path', type=str, default='./test')
parser.add_argument('--log_name', type=str, default='b2u_unet_g25_112rf20')
parser.add_argument('--gpu_devices', default='0', type=str)
parser.add_argument('--parallel', action='store_true')
parser.add_argument('--n_feature', type=int, default=48)
parser.add_argument('--n_channel', type=int, default=1)
parser.add_argument("--beta", type=float, default=20.0)
parser.add_argument('--dataset_name', type=str)
parser.add_argument('--repeat_times', type=int, default=1)
# When true, images that already have an output file are skipped.
parser.add_argument('--_continue', type=_str2bool, default=True)

opt = parser.parse_args(['--noisetype', 'gauss5_50',
                        # Alternative checkpoints used in earlier runs:
                        # '--checkpoint', './pretrained_models/g5-50_112rf20_beta19.4.pth',
                        # '--checkpoint', './experiments/my_models/b2u_second/models/epoch_model_100.pth',
                        # '--checkpoint', './experiments/my_models/b2u_first/models/epoch_model_100.pth',
                        '--checkpoint', './experiments/my_models/b2u_crystal_first/models/epoch_model_120.pth',
                        '--test_dir', './data/test',
                        '--save_test_path', './test',
                        '--log_name', 'b2u_Crystal_focus_0_dose_180_ep120',
                        '--beta', '19.4',
                        '--dataset_name', 'Crystal_focus_0_dose_180'])

# Timestamp of this run (minute resolution).
systime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
operation_seed_counter = 0
# Restrict the GPUs visible to CUDA to the ones requested in the options.
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_devices
torch.set_num_threads(8)

In [21]:
class Image_cropping():
    """Split an image into a grid of overlapping crops and stitch them back.

    Crops are numbered row-major: index ``v * h_steps + h`` is the crop in
    grid row ``v`` and grid column ``h``. Neighbouring crops overlap by
    ``margin`` pixels because the stride is ``crop_size - margin``.
    """

    def __init__(self, img_size, crop_size, margin) -> None:
        """Pre-compute the crop grid.

        Args:
            img_size: ``(height, width)`` of the full image; also used as the
                shape of the array produced by :meth:`concat`.
            crop_size: ``(height, width)`` of one crop.
            margin: Overlap in pixels between neighbouring crops.

        Raises:
            ValueError: If ``margin`` is not smaller than both crop dimensions
                (the stride would be zero or negative).
        """
        self.margin = margin
        self.img_size = img_size
        self.stride_v = crop_size[0] - margin
        self.stride_h = crop_size[1] - margin
        if self.stride_v <= 0 or self.stride_h <= 0:
            raise ValueError(
                f"margin ({margin}) must be smaller than both crop dimensions {tuple(crop_size)}")
        self.crop_size = crop_size
        self.h_steps = int(np.ceil(img_size[1]/self.stride_h))
        self.v_steps = int(np.ceil(img_size[0]/self.stride_v))
        self.crop_nums = self.h_steps * self.v_steps

    def get_position(self, index):
        """Return ``(v_step, h_step)`` grid coordinates of crop ``index``.

        The index is decomposed by the number of crops per row (``h_steps``)
        so that it agrees with the row-major order used by :meth:`concat`.
        """
        v_step, h_step = divmod(index, self.h_steps)

        return v_step, h_step

    def crop_image(self, img, index):
        """Return the crop of ``img`` with linear index ``index``.

        Crops that extend past the image border are truncated by the slicing,
        so edge crops may be smaller than ``crop_size``.
        """
        v_step, h_step = self.get_position(index)
        top = v_step * self.stride_v
        left = h_step * self.stride_h

        return img[top: top + self.crop_size[0],
                   left: left + self.crop_size[1]]

    def concat(self, img_array):
        """Stitch crops (in row-major index order) back into one image.

        For every crop that is not in the first row/column, the first
        ``margin // 2`` rows/columns are dropped and the previously written
        neighbour is kept there instead.

        Args:
            img_array: Sequence of ``crop_nums`` crops, ordered by crop index.

        Returns:
            Array of shape ``img_size`` (default NumPy float dtype).
        """
        img = np.empty(self.img_size)
        half_margin = self.margin // 2
        i = 0
        for v in range(self.v_steps):
            top = v * self.stride_v
            skip_top = half_margin if v > 0 else 0
            for h in range(self.h_steps):
                left = h * self.stride_h
                skip_left = half_margin if h > 0 else 0

                img[top + skip_top: top + self.crop_size[0],
                    left + skip_left: left + self.crop_size[1]] = img_array[i][skip_top:, skip_left:]

                i += 1

        return img

In [22]:
def apply_median(img, percentile, ksize=7):
    """Replace the darkest pixels of ``img`` with their median-filtered value.

    Pixels strictly below the given intensity percentile are taken from a
    median-blurred copy of the image; all other pixels are left unchanged.

    Args:
        img: Input image array accepted by ``cv2.medianBlur``.
            NOTE(review): OpenCV documents that kernel sizes above 5 require
            8-bit input - confirm callers pass uint8 data.
        percentile: Percentile (0-100) of pixel intensities used as threshold.
        ksize: Aperture size of the median filter (odd integer). Defaults to
            7, the value that used to be hard-coded.

    Returns:
        Array with the same shape as ``img``.
    """
    thresh = np.percentile(img.ravel(), percentile)
    median_img = cv2.medianBlur(img, ksize)
    # Select per pixel: median value where the pixel is below the threshold.
    return np.where(img < thresh, median_img, img)

In [23]:
class MicroscopyDataset(Dataset):
    """Dataset that yields crops of every image file found in a directory.

    Item ``i`` is crop ``i % crop_nums`` of image ``i // crop_nums``, so all
    crops of one image are consecutive.
    """

    def __init__(self, img_dir, cropping, channels=3, transform=None, apply_median=False):
        """Collect the image file names and store the cropping configuration.

        Args:
            img_dir: Directory whose entries are treated as image files.
            cropping: Object exposing ``crop_nums`` and
                ``crop_image(image, index)``.
            channels: ``1`` loads images as grayscale (imread flag 0); any
                other value is passed unchanged as the ``cv2.imread`` flag.
            transform: Optional callable applied to each crop.
            apply_median: If true, each crop is passed through the
                module-level ``apply_median`` function with percentile 10.
        """
        self.cropping_object = cropping
        self.img_dir = img_dir
        self.transform = transform
        self.crop_nums = cropping.crop_nums
        # os.listdir order is arbitrary; sort for a reproducible item order.
        self.img_filenames = sorted(os.listdir(img_dir))
        self.apply_median = apply_median
        if channels == 1:
            self.channels = 0
        else:
            # NOTE(review): the channel count itself is used as the imread
            # flag here (3 is not the same flag as IMREAD_COLOR) - confirm
            # this is intended.
            self.channels = channels

    def __len__(self):
        """Return the total number of crops over all images."""
        return (len(self.img_filenames) * self.crop_nums)

    def __getitem__(self, crop_idx):
        """Load the image owning ``crop_idx`` and return ``(crop, name)``.

        ``name`` is the file name up to its first dot.

        Raises:
            IOError: If OpenCV cannot read the image file.
        """
        image_idx, img_crop_idx = divmod(crop_idx, self.crop_nums)
        filename = self.img_filenames[image_idx]
        name = os.path.split(filename)[-1].split('.')[0]

        img_path = os.path.join(self.img_dir, filename)
        image = cv2.imread(img_path, self.channels)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise IOError(f"cv2.imread could not read image: {img_path}")

        crop = self.cropping_object.crop_image(image, img_crop_idx)

        if self.apply_median:
            # Refers to the module-level function, not the constructor flag.
            crop = apply_median(crop, 10)

        if self.transform:
            crop = self.transform(crop)

        return crop, name

In [24]:
# validation
# Outputs go to <save_test_path>/<dataset_name>/<log_name>/.
save_test_path = os.path.join(opt.save_test_path, opt.dataset_name)
os.makedirs(save_test_path, exist_ok=True)

validation_path = os.path.join(save_test_path, opt.log_name)
os.makedirs(validation_path, exist_ok=True)
np.random.seed(101)


# Dataset
test_transforms = transforms.Compose([transforms.ToTensor()])
dataset_dir = os.path.join(opt.test_dir, opt.dataset_name)
# Crop size equals the image size and margin is 0, i.e. one crop per image.
cropping = Image_cropping((768, 1024), (768, 1024), 0)
test_dataset = MicroscopyDataset(dataset_dir, cropping, opt.n_channel, transform=test_transforms, apply_median=False)
if opt._continue:
    # Resume: skip every input whose base name (text before the first dot)
    # already has an output file in validation_path.
    ready_filenames = [x.split('.')[0] for x in os.listdir(validation_path)]
    ready_lookup = set(ready_filenames)  # constant-time membership tests
    filenames = sorted(os.listdir(dataset_dir))
    print('all images: ', len(filenames))
    print('prepared images: ', len(ready_filenames))

    rest_filenames = [x for x in filenames if x.split('.')[0] not in ready_lookup]
    print('rest images: ', len(rest_filenames))
    test_dataset.img_filenames = rest_filenames

test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# Masker
masker = Masker(width=2, mode='interpolate', mask_type='all')
# Network
network = UNet(in_channels=opt.n_channel,
                out_channels=opt.n_channel,
                wf=opt.n_feature)
if opt.parallel:
    network = torch.nn.DataParallel(network)
network = network.cuda()
# load pre-trained model
network = load_network(opt.checkpoint, network, strict=True)
beta = opt.beta

# turn on eval mode
network.eval()


all images:  55
prepared images:  0
rest images:  55


UNet(
  (head): Sequential(
    (0): LR(
      (block): Sequential(
        (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
      )
    )
    (1): LR(
      (block): Sequential(
        (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
      )
    )
  )
  (down_path): ModuleList(
    (0-4): 5 x LR(
      (block): Sequential(
        (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
      )
    )
  )
  (up_path): ModuleList(
    (0): UP(
      (conv_1): LR(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): LeakyReLU(negative_slope=0.1, inplace=True)
        )
      )
      (conv_2): LR(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), pad

In [25]:

# Run the denoiser over every crop, stitch the crops of each image back
# together and save the result as an 8-bit grayscale PNG.
save_dir = validation_path
os.makedirs(save_dir, exist_ok=True)

repeat_times = opt.repeat_times

# Loop-invariant, so built once. Per the torchvision docs ToTensor only
# rescales uint8 input; here it receives float32 HWC arrays and returns CHW.
transformer = transforms.Compose([transforms.ToTensor()])
color_mode = 'L'

for i in range(repeat_times):
    # Denoised crops of the image currently being assembled.
    mid = []
    for im, name in tqdm(test_dataloader):

        if isinstance(im, torch.Tensor):
            # Drop the batch dimension (batch_size is 1) and go CHW -> HWC.
            if len(im.shape) == 4:
                im = im.squeeze(0)
            im = im.permute(1, 2, 0)
            im = im.numpy()

        noisy_im = np.array(im, dtype=np.float32)

        # padding to square, with the side rounded up to a multiple of 32
        H = noisy_im.shape[0]
        W = noisy_im.shape[1]
        val_size = (max(H, W) + 31) // 32 * 32
        noisy_im = np.pad(
            noisy_im,
            [[0, val_size - H], [0, val_size - W], [0, 0]],
            'reflect')

        noisy_im = transformer(noisy_im)
        noisy_im = torch.unsqueeze(noisy_im, 0)
        noisy_im = noisy_im.cuda()
        with torch.no_grad():
            n, c, h, w = noisy_im.shape

            net_input, mask = masker.train(noisy_im)

            # Network output on the masked inputs, kept only at the masked
            # positions and summed over the masked variants.
            noisy_output = (network(net_input)*mask).view(n, -1, c, h, w).sum(dim=1)
            # Network output on the unmasked input.
            exp_output = network(noisy_im)

        # Remove the padding again.
        pred_dn = noisy_output[:, :, :H, :W]
        pred_exp = exp_output[:, :, :H, :W]
        # Weighted average of both predictions, weight beta on the unmasked one.
        pred_mid = (pred_dn + beta*pred_exp) / (1 + beta)

        pred_mid = pred_mid.permute(0, 2, 3, 1)
        pred_mid = pred_mid.detach().cpu().clamp(0, 1).numpy().squeeze(0)
        pred255_mid = np.clip(pred_mid * 255.0 + 0.5, 0,
                              255).astype(np.uint8)

        mid.append(pred255_mid.squeeze())

        # Keep collecting until every crop of the current image is available.
        if len(mid) != cropping.crop_nums:
            continue
        mid_img = cropping.concat(mid)

        # name is a batch of size 1, hence name[0].
        save_path = os.path.join(save_dir, f"{name[0]}.png")
        Image.fromarray(mid_img).convert(color_mode).save(save_path)

        mid = []

100%|██████████| 55/55 [00:24<00:00,  2.20it/s]


In [26]:
# Ask PyTorch to release the unoccupied GPU memory held by its caching allocator.
torch.cuda.empty_cache()