In [2]:
import torch
from tifffile import imsave

from config import Eval
from model import UNet2D, Criterion
from data import getDataLoader

# Evaluation settings and the device shared by every cell below.
config = Eval()
device = torch.device('cuda')

# Build the network and overwrite its parameters with the 'net' entry of the
# checkpoint file; the constructor still receives `config`, but the weights
# themselves come from the checkpoint.
net = UNet2D(config).to(device)
checkpoint_path = "{}.pt".format(config.cpt_load_path)
net.load_state_dict(
    torch.load(checkpoint_path, map_location=device)['net']
)

# Inference only: switch to eval mode and convert the parameters to float16.
net.eval()
net.half()

# getDataLoader returns an indexable collection; only its first entry is used.
loaders = getDataLoader(config)
dataloader = loaders[0]

In [3]:
# Run the network over every batch of `dataloader`, merge each group of 100
# predicted subframes into one frame with `dataset.combineFrame`, sum those
# frames, and periodically write the running sum to a TIFF file.
#
# State left behind for later cells:
#   frame   -- sum of the combined frames since the last save, or None if the
#              loop ended right after a save (or never completed a frame)
#   outputs -- subframes collected for a frame that is not yet complete, or None
with torch.no_grad():
    frame = None
    outputs = None
    for i, (frames, _) in enumerate(dataloader):
        # forward pass in half precision, then scale the batch so that its
        # largest value becomes 1
        output  = net(frames.half().to(device))
        output /= output.max()

        # collect subframes along dim 0 until 100 of them are available
        # (`is None`: an identity test is the correct check for "nothing
        # collected yet", independent of how the tensor type defines `==`)
        if outputs is None: outputs = output
        else: outputs = torch.cat((outputs, output))

        # merge the collected subframes into one frame and add it to the sum
        if len(outputs) != 100: continue
        combined = dataloader.dataset.combineFrame(outputs) # type: ignore
        if frame is None:
            frame  = combined
        else:
            frame += combined
        outputs = None

        # every 100th batch: write the accumulated (not blurred) frame to its
        # own numbered file and start a new accumulation. The file name used
        # to be the literal 'i.tif', so each save overwrote the previous one.
        # NOTE(review): `frame` is a sum of combined frames and is not rescaled
        # before the uint8 cast -- confirm its values stay within [0, 1].
        if (i+1) % 100 != 0: continue
        imsave('data/eval/{}.tif'.format((i+1) // 100),
            (frame.cpu().detach() * 255).to(torch.uint8).numpy())
        frame = None

In [3]:
# Smooth the frame left in `frame` by the evaluation loop, rescale it so that
# its maximum is 1, and write it out as an 8-bit TIFF.
kernel = Criterion.gaussianKernel(3)
kernel = kernel.half().to(device)
frame  = Criterion.gaussianBlur3d(frame, kernel)  # type: ignore
frame.div_(frame.max())  # in-place, same as `frame /= frame.max()`

image = frame.cpu().detach() * 255
imsave('data/eval/30.tif', image.to(torch.uint8).numpy())

In [10]:
# Build the reference image from the labels: merge each group of 100 label
# subframes into one frame with `dataset.combineFrame`, sum all frames, then
# blur, rescale to a maximum of 1 and save as an 8-bit TIFF.
frame = None    # running sum of the combined label frames
outputs = None  # label subframes collected for the frame being assembled
for i, (frames, labels) in enumerate(dataloader):
    # collect label subframes along dim 0 until 100 of them are available
    # (`is None`: an identity test is the correct check for "nothing
    # collected yet", independent of how the tensor type defines `==`)
    labels = labels.half().to(device)
    if outputs is None: outputs = labels
    else: outputs = torch.cat((outputs, labels))

    # merge the collected subframes into one frame and add it to the sum
    if len(outputs) != 100: continue
    combined = dataloader.dataset.combineFrame(outputs) # type: ignore
    if frame is None:
        frame  = combined
    else:
        frame += combined
    outputs = None
# add Gaussian blur, rescale to a maximum of 1 and save
kernel = Criterion.gaussianKernel(3).half().to(device)
frame  = Criterion.gaussianBlur3d(frame, kernel)  # type: ignore
frame /= frame.max()
imsave('data/eval/label.tif', (frame.cpu().detach() * 255).to(torch.uint8).numpy())

In [8]:
# Blur the current `frame`, rescale it to a maximum of 1 and write it to
# data/eval/label.tif.
# NOTE(review): these are the same steps that end the preceding cell, and this
# cell's execution count (8) is lower than that cell's (10). Run in file order
# it blurs the already blurred `frame` a second time and overwrites label.tif
# -- confirm whether this cell is still needed.
kernel = Criterion.gaussianKernel(3)
kernel = kernel.half().to(device)
frame  = Criterion.gaussianBlur3d(frame, kernel)  # type: ignore
frame.div_(frame.max())  # in-place, same as `frame /= frame.max()`

image = frame.cpu().detach() * 255
imsave('data/eval/label.tif', image.to(torch.uint8).numpy())