In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
import torch
import skimage.io as skio

from tqdm import tqdm
from src.utils.dataset import DatasetSUPPORT_test_stitch, FrameReader,DatasetSUPPORT_incremental_load
from model.SUPPORT import SUPPORT
from src.utils.util import parse_arguments,get_coordinate_generator
import argparse
import random
import logging
import time
from torch.utils.tensorboard import SummaryWriter
from src.utils.dataset import gen_train_dataloader, random_transform
from src.utils.util import parse_arguments
from model.SUPPORT import SUPPORT


In [2]:
def train(train_dataloader, model, optimizer, rng, writer, epoch, opt):
    """
    Train a model for a single epoch

    Arguments:
        train_dataloader: (Pytorch DataLoader)
        model: (Pytorch nn.Module)
        optimizer: (Pytorch optimzer)
        rng: numpy random number generator
        writer: (Tensorboard writer)
        epoch: epoch of training (int)
        opt: argparse dictionary

    Returns:
        loss_list: list of total loss of each batch ([float])
        loss_list_l1: list of L1 loss of each batch ([float])
        loss_list_l2: list of L2 loss of each batch ([float])
        corr_list: list of correlation of each batch ([float])
    """

    is_rotate = True if model.bs_size[0] == model.bs_size[1] else False
    
    # initialize
    model.train()
    loss_list_l1 = []
    loss_list_l2 = []
    loss_list = []

    L1_pixelwise = torch.nn.L1Loss()
    L2_pixelwise = torch.nn.MSELoss()

    loss_coef = opt.loss_coef

    # training
    for i, data in enumerate(tqdm(train_dataloader)):

        (noisy_image, _, ds_idx) = data
        noisy_image, _ = random_transform(noisy_image, None, rng, is_rotate)
        
        B, T, X, Y = noisy_image.shape
        noisy_image = noisy_image.cuda()
        noisy_image_target = torch.unsqueeze(noisy_image[:, int(T/2), :, :], dim=1)

        optimizer.zero_grad()
        noisy_image_denoised = model(noisy_image)
        loss_l1_pixelwise = L1_pixelwise(noisy_image_denoised, noisy_image_target)
        loss_l2_pixelwise = L2_pixelwise(noisy_image_denoised, noisy_image_target)
        loss_sum = loss_coef[0] * loss_l1_pixelwise + loss_coef[1] * loss_l2_pixelwise
        loss_sum.backward()
        optimizer.step()

        loss_list_l1.append(loss_l1_pixelwise.item())
        loss_list_l2.append(loss_l2_pixelwise.item())
        loss_list.append(loss_sum.item())

        # print log
        if (epoch % opt.logging_interval == 0) and (i % opt.logging_interval_batch == 0):
            loss_mean = np.mean(np.array(loss_list))
            loss_mean_l1 = np.mean(np.array(loss_list_l1))
            loss_mean_l2 = np.mean(np.array(loss_list_l2))

            ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            writer.add_scalar("Loss_l1/train_batch", loss_mean, epoch*len(train_dataloader) + i)
            writer.add_scalar("Loss_l2/train_batch", loss_mean_l1, epoch*len(train_dataloader) + i)
            writer.add_scalar("Loss/train_batch", loss_mean_l2, epoch*len(train_dataloader) + i)
            
            logging.info(f"[{ts}] Epoch [{epoch}/{opt.n_epochs}] Batch [{i+1}/{len(train_dataloader)}] "+\
                f"loss : {loss_mean:.4f}, loss_l1 : {loss_mean_l1:.4f}, loss_l2 : {loss_mean_l2:.4f}")

    return loss_list, loss_list_l1, loss_list_l2

In [3]:
numFramesToPred = 9
opt = argparse.Namespace(random_seed=0,
                         epoch=4,
                         n_epochs=20,
                         exp_name='mytest',
                         results_dir='Y:\\WJC\\SUPPORT\\results',
                         input_frames=numFramesToPred,
                         is_folder=False,
                         noisy_data=['E:\\20230621-162555\\FR00001_PV00001.raw'],
                         patch_size=[numFramesToPred, 128, 128],
                         patch_interval=[1, 64, 64],
                         batch_size=16,
                         totalFramesPerEpoch=10000,
                         nConsecFrames=32,
                         model='.\\results\\saved_models\\mytest\\model_6.pth',
                         depth=5,
                         blind_conv_channels=64,
                         one_by_one_channels=[32, 16],
                         last_layer_channels=[64, 32, 16],
                         bs_size=[4, 4],
                         bp=False,
                         unet_channels=[16, 32, 64, 128, 256],
                         lr=0.0005,
                         loss_coef=[0.5, 0.5],
                         use_CPU=False, n_cpu=8,
                         logging_interval_batch=50,
                         logging_interval=1,
                         sample_interval=10,
                         sample_max_t=600,
                         checkpoint_interval=1)


In [33]:
cuda = torch.cuda.is_available() and (not opt.use_CPU)
cuda

True

In [None]:
random.seed(0)
torch.manual_seed(0)

# ----------
# Initialize: Create sample and checkpoint directories
# ----------
print(opt)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
rng = np.random.default_rng(opt.random_seed)

os.makedirs(opt.results_dir + "/images/{}".format(opt.exp_name), exist_ok=True)
os.makedirs(opt.results_dir + "/saved_models/{}".format(opt.exp_name), exist_ok=True)
os.makedirs(opt.results_dir + "/logs".format(opt.exp_name), exist_ok=True)
logging.basicConfig(level=logging.INFO, filename=opt.results_dir + "/logs/{}.log".format(opt.exp_name),\
    filemode="a", format="%(name)s - %(levelname)s - %(message)s")
writer = SummaryWriter(opt.results_dir + "/tsboard/{}".format(opt.exp_name))

# ----------
# Model, Optimizers, and Loss
# ----------
model = SUPPORT(in_channels=opt.input_frames, mid_channels=opt.unet_channels, depth=opt.depth,\
     blind_conv_channels=opt.blind_conv_channels, one_by_one_channels=opt.one_by_one_channels,\
            last_layer_channels=opt.last_layer_channels, bs_size=opt.bs_size, bp=opt.bp)

optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

if cuda:
    model = model.cuda()

# print(opt.results_dir + "/saved_models/%s/model_%d.pth" % (opt.exp_name, opt.epoch-1))
# exit()

if opt.epoch != 0:
    model.load_state_dict(torch.load(opt.results_dir + "/saved_models/%s/model_%d.pth" % (opt.exp_name, opt.epoch-1)))
    optimizer.load_state_dict(torch.load(opt.results_dir + "/saved_models/%s/optimizer_%d.pth" % (opt.exp_name, opt.epoch-1)))
    print('Loaded pre-trained model and optimizer weights of epoch {}'.format(opt.epoch-1))

# ----------
# Training & Validation
# ----------
for epoch in range(opt.epoch, opt.n_epochs):
    #reload random parts of the data every epoch (when too large to fit all in memory)
    dataloader_train = gen_train_dataloader(opt.patch_size, opt.patch_interval, opt.batch_size, \
        opt.noisy_data,totalFrames=opt.totalFramesPerEpoch,numConsecFrames=opt.nConsecFrames)

    loss_list, loss_list_l1, loss_list_l2 =\
        train(dataloader_train, model, optimizer, rng, writer, epoch, opt)

    # logging
    ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    if (epoch % opt.logging_interval == 0):
        loss_mean = np.mean(np.array(loss_list))
        loss_mean_l1 = np.mean(np.array(loss_list_l1))
        loss_mean_l2 = np.mean(np.array(loss_list_l2))

        writer.add_scalar("Loss/train", loss_mean, epoch)
        writer.add_scalar("Loss_l1/train", loss_mean_l1, epoch)
        writer.add_scalar("Loss_l2/train", loss_mean_l2, epoch)
        logging.info(f"[{ts}] Epoch [{epoch}/{opt.n_epochs}] "+\
            f"loss : {loss_mean:.4f}, loss_l1 : {loss_mean_l1:.4f}, loss_l2 : {loss_mean_l2:.4f}")

    if (opt.checkpoint_interval != -1) and (epoch % opt.checkpoint_interval == 0):
        model_loc = opt.results_dir + "/saved_models/%s/model_%d.pth" % (opt.exp_name, epoch)
        torch.save(model.state_dict(), model_loc)
        torch.save(optimizer.state_dict(), opt.results_dir + "/saved_models/%s/optimizer_%d.pth" % (opt.exp_name, epoch))

    # if (epoch % opt.sample_interval == 0):
    #     skio.imsave(opt.results_dir + "/images/%s/denoised_%d.pth" % (opt.exp_name, epoch), )

    

Namespace(random_seed=0, epoch=4, n_epochs=20, exp_name='mytest', results_dir='./results', input_frames=9, is_folder=False, noisy_data=['C:\\Voltage imaging\\20230621-162555\\FR00001_PV00001.raw'], patch_size=[9, 128, 128], patch_interval=[1, 64, 64], batch_size=16, totalFramesPerEpoch=10000, nConsecFrames=32, model='.\\results\\saved_models\\mytest\\model_6.pth', depth=5, blind_conv_channels=64, one_by_one_channels=[32, 16], last_layer_channels=[64, 32, 16], bs_size=[4, 4], bp=False, unet_channels=[16, 32, 64, 128, 256], lr=0.0005, loss_coef=[0.5, 0.5], use_CPU=False, n_cpu=8, logging_interval_batch=50, logging_interval=1, sample_interval=10, sample_max_t=600, checkpoint_interval=1)
Loaded pre-trained model and optimizer weights of epoch 3
file 1 of 1


Loading file 1: 100%|███████████████| 312/312 [00:11<00:00, 26.46it/s]




 36%|███████████████████████████▍                                                | 13655/37908 [16:04<28:08, 14.36it/s]

In [5]:
def validate(test_dataloader, model):
    """
    Validate a model with a test data
    
    Arguments:
        test_dataloader: (Pytorch DataLoader)
            Should be DatasetFRECTAL_test_stitch!
        model: (Pytorch nn.Module)

    Returns:
        denoised_stack: denoised image stack (Numpy array with dimension [T, X, Y])
    """
    with torch.no_grad():
        model.eval()
        # initialize denoised stack to NaN array.
        denoised_stack = np.zeros(test_dataloader.dataset.noisy_image.shape, dtype=np.float32)
        
        # stitching denoised stack
        # insert the results if the stack value was NaN
        # or, half of the output volume
        for _, (noisy_image, _, single_coordinate,mean,std) in enumerate(tqdm(test_dataloader, desc="validate")):
            noisy_image = noisy_image.cuda() #[b, z, y, x]
            noisy_image_denoised = model(noisy_image)
            T = noisy_image.size(1)
            for bi in range(noisy_image.size(0)): 
                stack_start_w = int(single_coordinate['stack_start_w'][bi])
                stack_end_w = int(single_coordinate['stack_end_w'][bi])
                patch_start_w = int(single_coordinate['patch_start_w'][bi])
                patch_end_w = int(single_coordinate['patch_end_w'][bi])

                stack_start_h = int(single_coordinate['stack_start_h'][bi])
                stack_end_h = int(single_coordinate['stack_end_h'][bi])
                patch_start_h = int(single_coordinate['patch_start_h'][bi])
                patch_end_h = int(single_coordinate['patch_end_h'][bi])

                stack_start_s = int(single_coordinate['init_s'][bi])
                
                denoised_stack[stack_start_s+(T//2), stack_start_h:stack_end_h, stack_start_w:stack_end_w] \
                    = (noisy_image_denoised[bi].squeeze()[patch_start_h:patch_end_h, patch_start_w:patch_end_w]*std+mean).cpu()

        # change nan values to 0 and denormalize
#         denoised_stack = denoised_stack * test_dataloader.dataset.std_image.numpy() + test_dataloader.dataset.mean_image.numpy()

        return denoised_stack


def normalize(image):
    """
    Normalize the image to [mean/std]=[0/1]

    Arguments:
        image: image stack (Pytorch Tensor with dimension [T, X, Y])

    Returns:
        image: normalized image stack (Pytorch Tensor with dimension [T, X, Y])
        mean_image: mean of the image stack (np.float)
        std_image: standard deviation of the image stack (np.float)
    """
    mean_image = torch.mean(image)
    std_image = torch.std(image)

    image -= mean_image
    image /= std_image

    return image, mean_image, std_image

In [15]:
    
print(opt)
########## Change it with your data ##############
data_file = opt.noisy_data[0]
output_file = "./results/denoised_0.h5"
patch_size = opt.patch_size
patch_interval = [1, 32, 32]
batch_size = 16    # lower it if memory exceeds.
##################################################

Namespace(random_seed=0, epoch=0, n_epochs=5, exp_name='mytest', results_dir='./results', input_frames=9, is_folder=False, noisy_data=['C:\\Voltage imaging\\20230621-162555\\FR00001_PV00001.raw'], patch_size=[9, 128, 128], patch_interval=[1, 64, 64], batch_size=16, totalFramesPerEpoch=10000, nConsecFrames=32, model='.\\results\\saved_models\\mytest\\model_6.pth', depth=5, blind_conv_channels=64, one_by_one_channels=[32, 16], last_layer_channels=[64, 32, 16], bs_size=[4, 4], bp=False, unet_channels=[16, 32, 64, 128, 256], lr=0.0005, loss_coef=[0.5, 0.5], use_CPU=False, n_cpu=8, logging_interval_batch=50, logging_interval=1, sample_interval=10, sample_max_t=600, checkpoint_interval=1)


In [16]:
model = SUPPORT(in_channels=opt.input_frames, mid_channels=opt.unet_channels, depth=opt.depth,\
         blind_conv_channels=opt.blind_conv_channels, one_by_one_channels=opt.one_by_one_channels,\
                last_layer_channels=opt.last_layer_channels, bs_size=opt.bs_size, bp=opt.bp).cuda()

In [17]:
epoch = 3
model_loc = opt.results_dir + "/saved_models/%s/model_%d.pth" % (opt.exp_name, epoch)
model_loc

'./results/saved_models/mytest/model_3.pth'

In [18]:
# model_file
model.load_state_dict(torch.load(model_loc))

<All keys matched successfully>

In [27]:
reader = FrameReader(data_file,maxFrames=2400,width=588,height=624,gap=1728,shuffle=False)
testset = DatasetSUPPORT_incremental_load(reader, patch_size=patch_size,\
    patch_interval=patch_interval)
testloader = torch.utils.data.DataLoader(testset, batch_size=10)
test_dataloader = testloader

init: torch.Size([36, 624, 588])


In [28]:
stack_shape = test_dataloader.dataset.output_size
stack_shape

torch.Size([2400, 624, 588])

In [29]:
pwd

'C:\\Voltage imaging\\SUPPORT'

In [30]:
from tifffile import imwrite, memmap

# with h5py.File(output_file, 'w') as hdf5_file:
#     # create an HDF5 dataset to store the denoised stack
#     denoised_stack = hdf5_file.create_dataset("denoised_stack", stack_shape, dtype=np.uint8)
with memmap(output_file, shape=stack_shape, dtype=np.uint8) as denoised_stack:
    with torch.no_grad():
        model.eval()
        # stitching denoised stack
        for _, (noisy_image, _, single_coordinate) in enumerate(tqdm(test_dataloader, desc="validate")):
            
            noisy_image, mean_image, std_image = normalize(noisy_image)
            noisy_image = noisy_image.cuda() #[b, z, y, x]
#             print(noisy_image.shape)
            noisy_image_denoised = model(noisy_image)
            T = noisy_image.size(1)
            for bi in range(noisy_image.size(0)): 
                stack_start_w = int(single_coordinate['stack_start_w'][bi])
                stack_end_w = int(single_coordinate['stack_end_w'][bi])
                patch_start_w = int(single_coordinate['patch_start_w'][bi])
                patch_end_w = int(single_coordinate['patch_end_w'][bi])

                stack_start_h = int(single_coordinate['stack_start_h'][bi])
                stack_end_h = int(single_coordinate['stack_end_h'][bi])
                patch_start_h = int(single_coordinate['patch_start_h'][bi])
                patch_end_h = int(single_coordinate['patch_end_h'][bi])

                stack_start_s = int(single_coordinate['init_s'][bi])

                denoised_stack[stack_start_s+(T//2), stack_start_h:stack_end_h, stack_start_w:stack_end_w] \
                    = (noisy_image_denoised[bi].squeeze()[patch_start_h:patch_end_h, patch_start_w:patch_end_w]*std_image + mean_image).cpu().numpy()

validate: 100%|██████████████████████████████████████████████████████████████████| 65063/65063 [22:11<00:00, 48.86it/s]


In [31]:
import tifffile
with h5py.File(output_file, 'r') as h5_file:
        data = h5_file["denoised_stack"][:]

with tifffile.TiffWriter(output_file+".tif", bigtiff=True) as tiff:
    for i in range(data.shape[0]):
        tiff.save(data[i])


  tiff.save(data[i])
