In [1]:
import glob
import json
import os
import time

import matplotlib.pyplot as plt
import netCDF4 as nc
import numpy as np
import torch
import tqdm
from torchinfo import summary

from tcupgan import DataGenerator, LSTMUNet
from tcupgan.disc import PatchDiscriminator
from tcupgan.fc_utils import query_images_and_masks
from tcupgan.losses import fc_tversky
from tcupgan.vanilla import TemporalUNet

%matplotlib inline

Initialize the LSTM U-Net and the 3D convolutional U-Net with the same layer sizes so we can compare their performance.

In [2]:
# Shared architecture settings: both generators get the same encoder widths
# and bottleneck sizes so they are compared at matching layer sizes.
device = 'cpu'  # change to 'cuda' to run on a GPU
hidden = [8, 18, 24, 32, 64, 128, 128]
bottleneck_dims = [64, 32, 16, 8]

# Each constructor receives its own copy of the lists so one model can never
# observe changes the other might make to them.
generator_lstm = LSTMUNet(hidden_dims=list(hidden), bottleneck_dims=list(bottleneck_dims),
                          input_channels=1, output_channels=1).to(device)
generator_unet = TemporalUNet(hidden_dims=list(hidden), bottleneck_dims=list(bottleneck_dims),
                              input_channels=1, output_channels=1).to(device)


Load the epoch-150 checkpoint for each model and put both models in evaluation mode.

In [3]:
# Checkpoint files for each generator.
LSTM_CHECKPOINT = './checkpoints-FC.resized-gamma05/generator_ep_150.pth'
UNET_CHECKPOINT = './checkpoints-3DUNET/generator_ep_150.pth'

# Inference only: put both generators in eval mode.
for generator in (generator_unet, generator_lstm):
    generator.eval()

# NOTE(review): depending on the torch version / `weights_only` default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
generator_lstm.load_state_dict(torch.load(LSTM_CHECKPOINT, map_location=device))
generator_unet.load_state_dict(torch.load(UNET_CHECKPOINT, map_location=device))

<All keys matched successfully>

In [4]:
def get_pred(img, which_gen, device=None):
    """Run a generator on a single (unbatched) input and return its prediction.

    Parameters
    ----------
    img : array-like
        One input sample without a batch axis; it is cast to float32.
        NOTE(review): for the U-Nets in this notebook this appears to be a
        (z, 1, y, x) slice stack -- confirm against the model definitions.
    which_gen : callable
        The model to evaluate (e.g. ``generator_unet`` or ``generator_lstm``).
    device : str or torch.device, optional
        Device to run on. Defaults to the device of the model's first
        parameter (CPU if the model has no parameters), so the input always
        lands on the same device as the weights instead of depending on a
        notebook-global ``device`` variable.

    Returns
    -------
    numpy.ndarray
        The model output for this sample with the batch axis removed, on CPU.
    """
    if device is None:
        parameters = getattr(which_gen, 'parameters', None)
        first_param = next(iter(parameters()), None) if parameters is not None else None
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Add a leading batch axis of size 1 and move the input next to the weights.
    batch = torch.as_tensor(np.expand_dims(img, axis=0), dtype=torch.float32, device=device)
    with torch.no_grad():
        pred = which_gen(batch)
    # Drop the batch axis; np.array copies so the result owns its memory.
    return np.array(pred[0].cpu().numpy())

Load the list of image slices and their corresponding subject IDs.

In [5]:
# JSON list describing the image-slice stacks for each subject.
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's default encoding.
SLICE_MAPPING_PATH = './image_stack_list.json'
with open(SLICE_MAPPING_PATH, 'r', encoding='utf-8') as f:
    slice_mapping = json.load(f)

FileNotFoundError: [Errno 2] No such file or directory: './image_stack_list.json'

In [6]:
# Keep only subjects whose mapping entry has strictly more than
# SLICE_THRESHOLD elements; element 0 of each entry is used as the subject ID.
SLICE_THRESHOLD = 10

all_slices = []
for entry in slice_mapping:
    if len(entry) > SLICE_THRESHOLD:
        print(f"Subject {entry[0]} has {len(entry)} slices")
        all_slices.append(entry[0])

Subject 50494631 has 50 slices
Subject 56399759 has 25 slices
Subject 56399824 has 60 slices
Subject 56399954 has 189 slices
Subject 56400162 has 114 slices
Subject 56400344 has 21 slices
Subject 56400369 has 30 slices
Subject 56400408 has 150 slices
Subject 56400598 has 67 slices
Subject 56400789 has 105 slices
Subject 56477542 has 285 slices
Subject 56478334 has 274 slices
Subject 56479153 has 395 slices
Subject 56480550 has 85 slices
Subject 56480760 has 160 slices
Subject 56481693 has 13 slices
Subject 56481883 has 305 slices


Loop through each subject: build the 3D image cube from its 2D slices and predict a mask with both models. The input cube and both predictions are saved to one NetCDF file per subject in the `3D_cube_renditions` directory.

In [7]:
# Output directory for the per-subject NetCDF files. nc.Dataset(..., 'w') does
# not create missing parent directories, so make sure it exists up front.
OUTPUT_DIR = './3D_cube_renditions'
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _to_xyz(cube):
    """Drop the singleton axis 1 of a (z, 1, y, x) stack and reorder it to (x, y, z)."""
    return np.transpose(np.squeeze(cube, axis=1), (2, 1, 0))


for each_sequence in all_slices:
    # First mapping entry containing this subject ID.
    matched_seq = next(entry for entry in slice_mapping if each_sequence in entry)

    # query the input images from the subject images and build the cube
    # `umii_path` corresponds to the root folder containing the images and masks
    # directories from the DataLad repo
    imgs, masks = query_images_and_masks(matched_seq, umii_path='../umii-fatchecker-dataset/')

    # NOTE(review): meaning of the `_4` key suffix is not visible here -- confirm.
    cube_key = f'{matched_seq[0]}_4'
    data_3d = np.asarray(imgs[cube_key]) / 255.  # scale 8-bit pixel values to [0, 1]
    data_mask = np.asarray(masks[cube_key])  # not written to the NetCDF file below

    # get the prediction from the LSTM and 3D U-Nets
    pred_mask_3d_unet = get_pred(data_3d, generator_unet)
    pred_mask_3d_lstm = get_pred(data_3d, generator_lstm)

    # Size the NetCDF dimensions from the data instead of assuming 256x256 slices.
    img_xyz = _to_xyz(data_3d)
    n_x, n_y, n_z = img_xyz.shape

    # save out to a NetCDF file
    with nc.Dataset(os.path.join(OUTPUT_DIR, f'{matched_seq[0]}.nc'), 'w') as dset:
        dset.createDimension('x', n_x)
        dset.createDimension('y', n_y)
        dset.createDimension('z', n_z)

        imgVar = dset.createVariable('img', 'f8', ('x', 'y', 'z'))
        unetVar = dset.createVariable('unet', 'f8', ('x', 'y', 'z'))
        tcupVar = dset.createVariable('tcup', 'f8', ('x', 'y', 'z'))

        imgVar[:] = img_xyz
        unetVar[:] = _to_xyz(pred_mask_3d_unet)
        tcupVar[:] = _to_xyz(pred_mask_3d_lstm)

    print(f'Done {each_sequence}')

Done 56479153
