In [1]:
# !pip install nibabel
# !pip install medpy
# !pip install tensorboard
# !pip install --upgrade tensorboard
# !pip install --upgrade torch

In [2]:
import os
from os import listdir
from os.path import isfile, join
import time
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import nibabel as nib
from medpy.io import load
import random

%matplotlib inline

## utils

In [3]:
def mpl_image_grid(images):
    """
    Create an image grid from an array of images. Show up to 16 images in one figure

    Arguments:
        image {Torch tensor} -- NxWxH array of images

    Returns:
        Matplotlib figure
    """
    # Create a figure to contain the plot.
    n = min(images.shape[0], 16) # no more than 16 thumbnails
    rows = 4
    cols = (n // 4) + (1 if (n % 4) != 0 else 0)
    figure = plt.figure(figsize=(2*rows, 2*cols))
    plt.subplots_adjust(0, 0, 1, 1, 0.001, 0.001)
    for i in range(n):
        # Start next subplot.
        plt.subplot(cols, rows, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        if images.shape[1] == 3:
            # this is specifically for 3 softmax'd classes with 0 being bg
            # We are building a probability map from our three classes using 
            # fractional probabilities contained in the mask
            vol = images[i].detach().numpy()
            img = [[[(1-vol[0,x,y])*vol[1,x,y], (1-vol[0,x,y])*vol[2,x,y], 0] \
                            for y in range(vol.shape[2])] \
                            for x in range(vol.shape[1])]
            plt.imshow(img)
        else: # plotting only 1st channel
            plt.imshow((images[i, 0]*255).int(), cmap= "gray")

    return figure

def log_to_tensorboard(writer, loss, data, target, prediction_softmax, prediction, counter):
    """Logs data to Tensorboard

    Arguments:
        writer {SummaryWriter} -- PyTorch Tensorboard wrapper to use for logging
        loss {float} -- loss
        data {tensor} -- image data
        target {tensor} -- ground truth label
        prediction_softmax {tensor} -- softmax'd prediction
        prediction {tensor} -- raw prediction (to be used in argmax)
        counter {int} -- batch and epoch counter
    """
    writer.add_scalar("Loss",\
                    loss, counter)
    writer.add_figure("Image Data",\
        mpl_image_grid(data.float().cpu()), global_step=counter)
    writer.add_figure("Mask",\
        mpl_image_grid(target.float().cpu()), global_step=counter)
    writer.add_figure("Probability map",\
        mpl_image_grid(prediction_softmax.cpu()), global_step=counter)
    writer.add_figure("Prediction",\
        mpl_image_grid(torch.argmax(prediction.cpu(), dim=1, keepdim=True)), global_step=counter)

def save_numpy_as_image(arr, path):
    """
    This saves image (2D array) as a file using matplotlib

    Arguments:
        arr {array} -- 2D array of pixels
        path {string} -- path to file
    """
    plt.imshow(arr, cmap="gray") #Needs to be in row,col order
    plt.savefig(path)

def med_reshape(image, new_shape):
    """
    This function reshapes 3D data to new dimension padding with zeros
    and leaving the content in the top-left corner

    Arguments:
        image {array} -- 3D array of pixel data
        new_shape {3-tuple} -- expected output shape

    Returns:
        3D array of desired shape, padded with zeroes
    """

    reshaped_image = np.zeros(new_shape)
    x, y, z = image.shape
    reshaped_image[:x, :y, :z] = image

    return reshaped_image

## volume_stats

In [4]:
def Dice3d(a, b):
    """
    This will compute the Dice Similarity coefficient for two 3-dimensional volumes
    Volumes are expected to be of the same size. We are expecting binary masks -
    0's are treated as background and anything else is counted as data

    Arguments:
        a {Numpy array} -- 3D array with first volume
        b {Numpy array} -- 3D array with second volume

    Returns:
        float
    """
    if len(a.shape) != 3 or len(b.shape) != 3:
        raise Exception(f"Expecting 3 dimensional inputs, got {a.shape} and {b.shape}")

    if a.shape != b.shape:
        raise Exception(f"Expecting inputs of the same shape, got {a.shape} and {b.shape}")

    intersection = np.sum((a>0) * (b>0))
    volumes = np.sum(a>0) + np.sum(b>0)

    if volumes == 0:
        return -1

    return 2.*float(intersection) / float(volumes)

def Jaccard3d(a, b):
    """
    This will compute the Jaccard Similarity coefficient for two 3-dimensional volumes
    Volumes are expected to be of the same size. We are expecting binary masks - 
    0's are treated as background and anything else is counted as data

    Arguments:
        a {Numpy array} -- 3D array with first volume
        b {Numpy array} -- 3D array with second volume

    Returns:
        float
    """
    if len(a.shape) != 3 or len(b.shape) != 3:
        raise Exception(f"Expecting 3 dimensional inputs, got {a.shape} and {b.shape}")

    if a.shape != b.shape:
        raise Exception(f"Expecting inputs of the same shape, got {a.shape} and {b.shape}")
        
    intersection = np.sum((a>0) * (b>0))
    union = np.sum(((a>0) + (b>0))>0)
    
    if union == 0:
        return -1
    
    return float(intersection) / float(union)

## HippocampusDatasetLoader

In [5]:
def LoadHippocampusData(root_dir, y_shape, z_shape):
    '''
    This function loads our dataset form disk into memory,
    reshaping output to common size

    Arguments:
        volume {Numpy array} -- 3D array representing the volume

    Returns:
        Array of dictionaries with data stored in seg and image fields as 
        Numpy arrays of shape [AXIAL_WIDTH, Y_SHAPE, Z_SHAPE]
    '''

    image_dir = join(root_dir, 'images')
    label_dir = join(root_dir, 'labels')

    images = [f for f in listdir(image_dir) if (
        isfile(join(image_dir, f)) and f[0] != ".")]

    out = []
    for f in images:

        # We would benefit from mmap load method here if dataset doesn't fit into memory
        # Images are loaded here using MedPy's load method. We will ignore header 
        # since we will not use it
        image, _ = load(join(image_dir, f))
        label, _ = load(join(label_dir, f))

        # normalize all images (but not labels) so that values are in [0..1] range
        image = image / np.max(image)

        # reshape data since CNN tensors need to be of the same size.
        # since we feed individual slices to the CNN, we only need to extend 2 dimensions (coronal and sagittal) out of 3 

        image = med_reshape(image, new_shape=(image.shape[0], y_shape, z_shape))
        label = med_reshape(label, new_shape=(label.shape[0], y_shape, z_shape)).astype(int)

        # Why do we need to cast label to int?
        # To get distinct class of labels 

        out.append({"image": image, "seg": label, "filename": f})

    # Hippocampus dataset only takes about 300 Mb RAM, so we can afford to keep it all in RAM
    print(f"Processed {len(out)} files, total {sum([x['image'].shape[0] for x in out])} slices")
    return np.array(out)

## SlicesDataset

In [6]:
class SlicesDataset(Dataset):
    """
    This class represents an indexable Torch dataset
    which could be consumed by the PyTorch DataLoader class
    """
    def __init__(self, data):
        self.data = data

        self.slices = []

        for i, d in enumerate(data):
            for j in range(d["image"].shape[0]):
                self.slices.append((i, j))

    def __getitem__(self, idx):
        """
        This method is called by PyTorch DataLoader class to return a sample with id idx

        Arguments: 
            idx {int} -- id of sample

        Returns:
            Dictionary of 2 Torch Tensors of dimensions [1, W, H]
        """
        slc = self.slices[idx]
        sample = dict()
        sample["id"] = idx
        
        # create two new keys in the "sample" dictionary, "image" and "seg"
        i, j = slc
        sample['image'] = torch.from_numpy(self.data[i]['image'][j, :, :][None, :])
        sample['seg'] = torch.from_numpy(self.data[i]['seg'][j, :, :][None, :])

        return sample

    def __len__(self):
        """
        This method is called by PyTorch DataLoader class to return number of samples in the dataset

        Returns:
            int
        """
        return len(self.slices)

## RecursiveUNet

In [7]:
class UNet(nn.Module):
    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4, norm_layer=nn.InstanceNorm2d):
        # norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNet, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
        for i in range(1, num_downs):
            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),
                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
                                             outermost=True)

        self.model = unet_block

    def forward(self, x):
        return self.model(x)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool2d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            down = [conv1, conv2]
            up = [conv3, conv4, final]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels,
                                        kernel_size=2, stride=2)
            model = [pool, conv1, conv2, upconv]
        else:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)

            down = [pool, conv1, conv2]
            up = [conv3, conv4, upconv]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_width, target_height):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
            return torch.cat([x, crop], 1)

## UNetInferenceAgent

In [8]:
class UNetInferenceAgent:
    """
    Stores model and parameters and some methods to handle inferencing
    """
    def __init__(self, parameter_file_path='', model=None, device="cpu", patch_size=64):

        self.model = model
        self.patch_size = patch_size
        self.device = device

        if model is None:
            self.model = UNet(num_classes=3)

        if parameter_file_path:
            self.model.load_state_dict(torch.load(parameter_file_path, map_location=self.device))

        self.model.to(device)

    def single_volume_inference(self, volume):
        """
        Runs inference on a single volume of conformant patch size

        Arguments:
            volume {Numpy array} -- 3D array representing the volume

        Returns:
            3D NumPy array with prediction mask
        """
        self.model.eval()

        # volume is a numpy array of shape [X,Y,Z] and I will slice X axis
        slices = []

        # create mask for each slice across the X (0th) dimension. 
        # put all slices into a 3D Numpy array

        for ix in range(0, volume.shape[0]):
#             img = volume[ix,:,:]
#             slc = img.astype(np.single)/np.max(img)
#             slc_tensor = torch.from_numpy(slc).unsqueeze(0).unsqueeze(0).to(self.device)
            slice_tensor = torch.from_numpy(volume[ix,:,:].astype(np.single)).unsqueeze(0).unsqueeze(0)
            pred = self.model(slice_tensor.to(self.device))
            mask = torch.argmax(np.squeeze(pred.cpu().detach()), dim=0)
            slices.append(mask)
        return np.dstack(slices).transpose(2, 0, 1)


    
    def single_volume_inference_unpadded(self, volume, patch_size):
        """
        Runs inference on a single volume of arbitrary patch size,
        padding it to the conformant size first

        Arguments:
            volume {Numpy array} -- 3D array representing the volume

        Returns:
            3D NumPy array with prediction mask
        """
        
        volume = med_reshape(volume, (volume.shape[0], patch_size, patch_size))
        
        return single_volume_inference(self, volume)

## UNetExperiment

In [9]:
class UNetExperiment:
    """
    This class implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    The basic life cycle of a UNetExperiment is:

        run():
            for epoch in n_epochs:
                train()
                validate()
        test()
    """
    def __init__(self, config, split, dataset):
        self.n_epochs = config.n_epochs
        self.split = split
        self._time_start = ""
        self._time_end = ""
        self.epoch = 0
        self.name = config.name

        # Create output folders
        dirname = f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{self.name}'
        self.out_dir = os.path.join(config.test_results_dir, dirname)
        os.makedirs(self.out_dir, exist_ok=True)

        # Create data loaders
        self.train_loader = DataLoader(SlicesDataset(dataset[split["train"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)
        self.val_loader = DataLoader(SlicesDataset(dataset[split["val"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)

        # access volumes directly for testing
        self.test_data = dataset[split["test"]]

        if not torch.cuda.is_available():
            print("WARNING: No CUDA device is found. This may take significantly longer!")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # use a recursive UNet model from German Cancer Research Center, Division of Medical Image Computing
        self.model = UNet()
        self.model.to(self.device)

        # use a standard cross-entropy loss since the model output is essentially
        # a tensor with softmax prediction of each pixel's probability of belonging to a certain class
        self.loss_function = torch.nn.CrossEntropyLoss()

        # use standard SGD method to optimize the weights
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.learning_rate)
        
        # Scheduler helps to update learning rate automatically
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

        # Set up Tensorboard. By default it saves data into runs folder. You need to launch
        self.tensorboard_train_writer = SummaryWriter(comment="_train")
        self.tensorboard_val_writer = SummaryWriter(comment="_val")

    def train(self):
        """
        This method is executed once per epoch and takes 
        care of model weight update cycle
        """
        print(f"Training epoch {self.epoch}...")
        self.model.train()

        # Loop over the minibatches
        for i, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Feed data to the model and feed target to the loss function
            data = batch['image'].float()
            target = batch['seg']
            prediction = self.model(data.to(self.device))
            prediction_softmax = F.softmax(prediction, dim=1)
            loss = self.loss_function(prediction_softmax, target[:, 0, :, :].to(self.device))

            # What does each dimension of variable prediction represent?
            # batch_size, 3 classes, coronal, axial

            loss.backward()
            self.optimizer.step()

            if (i % 10) == 0:
                # Output to console on every 10th batch
                print(f"\nEpoch: {self.epoch} Train loss: {loss}, {100*(i+1)/len(self.train_loader):.1f}% complete")

                counter = 100*self.epoch + 100*(i/len(self.train_loader))

                log_to_tensorboard(
                    self.tensorboard_train_writer,
                    loss,
                    data,
                    target,
                    prediction_softmax,
                    prediction,
                    counter)

            print(".", end='')

        print("\nTraining complete")

    def validate(self):
        """
        This method runs validation cycle, using same metrics as 
        Train method. Note that model needs to be switched to eval
        mode and no_grad needs to be called so that gradients do not 
        propagate
        """
        print(f"Validating epoch {self.epoch}...")

        # Turn off gradient accumulation by switching model to "eval" mode
        self.model.eval()
        loss_list = []

        with torch.no_grad():
            for i, batch in enumerate(self.val_loader):              
                data = batch['image'].float()
                target = batch['seg']
                prediction = self.model(data.to(self.device))
                prediction_softmax = F.softmax(prediction, dim=1)
                loss = self.loss_function(prediction_softmax, target[:, 0, :, :].to(self.device))

                print(f"Batch {i}. Data shape {data.shape} Loss {loss}")

                # We report loss that is accumulated across all of validation set
                loss_list.append(loss.item())

        self.scheduler.step(np.mean(loss_list))

        log_to_tensorboard(
            self.tensorboard_val_writer,
            np.mean(loss_list),
            data,
            target,
            prediction_softmax, 
            prediction,
            (self.epoch+1) * 100)
        print(f"Validation complete")

    def save_model_parameters(self):
        """
        Saves model parameters to a file in results directory
        """
        path = os.path.join(self.out_dir, "model.pth")

        torch.save(self.model.state_dict(), path)

    def load_model_parameters(self, path=''):
        """
        Loads model parameters from a supplied path or a
        results directory
        """
        if not path:
            model_path = os.path.join(self.out_dir, "model.pth")
        else:
            model_path = path

        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path))
        else:
            raise Exception(f"Could not find path {model_path}")

    def run_test(self):
        """
        This runs test cycle on the test dataset.
        Note that process and evaluations are quite different
        Here we are computing a lot more metrics and returning
        a dictionary that could later be persisted as JSON
        """
        print("Testing...")
        self.model.eval()

        inference_agent = UNetInferenceAgent(model=self.model, device=self.device)

        out_dict = {}
        out_dict["volume_stats"] = []
        dc_list = []
        jc_list = []

        # for every in test set
        for i, x in enumerate(self.test_data):
            pred_label = inference_agent.single_volume_inference(x["image"])

            # We compute and report Dice and Jaccard similarity coefficients which 
            # assess how close our volumes are to each other

            dc = Dice3d(pred_label, x["seg"])
            jc = Jaccard3d(pred_label, x["seg"])
            dc_list.append(dc)
            jc_list.append(jc)

            # STAND-OUT SUGGESTION: By way of exercise, consider also outputting:
            # * Sensitivity and specificity (and explain semantic meaning in terms of 
            #   under/over segmenting)
            # * Dice-per-slice and render combined slices with lowest and highest DpS
            # * Dice per class (anterior/posterior)

            out_dict["volume_stats"].append({
                "filename": x['filename'],
                "dice": dc,
                "jaccard": jc
                })
            print(f"{x['filename']} Dice {dc:.4f}. {100*(i+1)/len(self.test_data):.2f}% complete")

        out_dict["overall"] = {
            "mean_dice": np.mean(dc_list),
            "mean_jaccard": np.mean(jc_list)}

        print("\nTesting complete.")
        return out_dict

    def run(self):
        """
        Kicks off train cycle and writes model parameter file at the end
        """
        self._time_start = time.time()

        print("Experiment started.")

        # Iterate over epochs
        for self.epoch in range(self.n_epochs):
            self.train()
            self.validate()

        # save model for inferencing
        self.save_model_parameters()

        self._time_end = time.time()
        print(f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(self._time_end - self._time_start))}")

In [None]:
class Config:
    """
    Holds configuration parameters
    """
    def __init__(self):
        self.name = "Basic_unet"
        self.root_dir = r"../data"
        self.n_epochs = 10
        self.learning_rate = 0.0002
        self.batch_size = 8
        self.patch_size = 64
        self.test_results_dir = "test_results"

if __name__ == "__main__":
    # Get configuration
    c = Config()

    # Load data
    print("Loading data...")

    data = LoadHippocampusData(c.root_dir, y_shape = c.patch_size, z_shape = c.patch_size)

    # create three keys in the dictionary ("train", "val" and "test") with each key storing the array with indices 
    split = dict()
    n = len(data)
    idx_list = list(range(n))
    random.shuffle(idx_list)
    split['train'] = idx_list[ :int(n * 0.7)]
    split['val'] = idx_list[int(n * 0.7) : int(n * 0.85)]
    split['test'] = idx_list[int(n * 0.85):]  

    # Set up and run experiment
    exp = UNetExperiment(c, split, data)

    # run training
    exp.run()

    # prep and run testing
    results_json = exp.run_test()

    results_json["config"] = vars(c)

    with open(os.path.join(exp.out_dir, "results.json"), 'w') as out_file:
        json.dump(results_json, out_file, indent=2, separators=(',', ': '))

Loading data...
Processed 260 files, total 9198 slices
Experiment started.
Training epoch 0...

Epoch: 0 Train loss: 1.0856724977493286, 0.1% complete
..........
Epoch: 0 Train loss: 0.6045093536376953, 1.4% complete
..........
Epoch: 0 Train loss: 0.575187623500824, 2.6% complete
..........
Epoch: 0 Train loss: 0.567802369594574, 3.9% complete
..........
Epoch: 0 Train loss: 0.564964234828949, 5.1% complete
..........
Epoch: 0 Train loss: 0.568565309047699, 6.3% complete
..........
Epoch: 0 Train loss: 0.567985475063324, 7.6% complete
..........
Epoch: 0 Train loss: 0.553459107875824, 8.8% complete
..........
Epoch: 0 Train loss: 0.569023072719574, 10.1% complete
..........
Epoch: 0 Train loss: 0.574790894985199, 11.3% complete
..........
Epoch: 0 Train loss: 0.569877564907074, 12.5% complete
..........
Epoch: 0 Train loss: 0.561027467250824, 13.8% complete
..........
Epoch: 0 Train loss: 0.579002320766449, 15.0% complete
..........
Epoch: 0 Train loss: 0.593955934047699, 16.3% comple

Batch 39. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577476441860199
Batch 40. Data shape torch.Size([8, 1, 64, 64]) Loss 0.564964234828949
Batch 41. Data shape torch.Size([8, 1, 64, 64]) Loss 0.572227418422699
Batch 42. Data shape torch.Size([8, 1, 64, 64]) Loss 0.564933717250824
Batch 43. Data shape torch.Size([8, 1, 64, 64]) Loss 0.578483521938324
Batch 44. Data shape torch.Size([8, 1, 64, 64]) Loss 0.565055787563324
Batch 45. Data shape torch.Size([8, 1, 64, 64]) Loss 0.587943971157074
Batch 46. Data shape torch.Size([8, 1, 64, 64]) Loss 0.593009889125824
Batch 47. Data shape torch.Size([8, 1, 64, 64]) Loss 0.584190309047699
Batch 48. Data shape torch.Size([8, 1, 64, 64]) Loss 0.568107545375824
Batch 49. Data shape torch.Size([8, 1, 64, 64]) Loss 0.568077027797699
Batch 50. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577842652797699
Batch 51. Data shape torch.Size([8, 1, 64, 64]) Loss 0.575431764125824
Batch 52. Data shape torch.Size([8, 1, 64, 64]) Loss 0.567344605922699
Batch 

Batch 154. Data shape torch.Size([8, 1, 64, 64]) Loss 0.574363648891449
Batch 155. Data shape torch.Size([8, 1, 64, 64]) Loss 0.556938111782074
Batch 156. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577354371547699
Batch 157. Data shape torch.Size([8, 1, 64, 64]) Loss 0.597709596157074
Batch 158. Data shape torch.Size([8, 1, 64, 64]) Loss 0.582633912563324
Batch 159. Data shape torch.Size([8, 1, 64, 64]) Loss 0.569145143032074
Batch 160. Data shape torch.Size([8, 1, 64, 64]) Loss 0.564781129360199
Batch 161. Data shape torch.Size([8, 1, 64, 64]) Loss 0.579215943813324
Batch 162. Data shape torch.Size([8, 1, 64, 64]) Loss 0.567802369594574
Batch 163. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576896607875824
Batch 164. Data shape torch.Size([8, 1, 64, 64]) Loss 0.571586549282074
Batch 165. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577384889125824
Batch 166. Data shape torch.Size([8, 1, 64, 64]) Loss 0.563774049282074
Batch 167. Data shape torch.Size([8, 1, 64, 64]) Loss 0.58837121

Batch 18. Data shape torch.Size([8, 1, 64, 64]) Loss 0.574607789516449
Batch 19. Data shape torch.Size([8, 1, 64, 64]) Loss 0.573081910610199
Batch 20. Data shape torch.Size([8, 1, 64, 64]) Loss 0.578544557094574
Batch 21. Data shape torch.Size([8, 1, 64, 64]) Loss 0.585472047328949
Batch 22. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576408326625824
Batch 23. Data shape torch.Size([8, 1, 64, 64]) Loss 0.579795777797699
Batch 24. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576103150844574
Batch 25. Data shape torch.Size([8, 1, 64, 64]) Loss 0.573539674282074
Batch 26. Data shape torch.Size([8, 1, 64, 64]) Loss 0.585655152797699
Batch 27. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576286256313324
Batch 28. Data shape torch.Size([8, 1, 64, 64]) Loss 0.567436158657074
Batch 29. Data shape torch.Size([8, 1, 64, 64]) Loss 0.574333131313324
Batch 30. Data shape torch.Size([8, 1, 64, 64]) Loss 0.590293824672699
Batch 31. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577079713344574
Batch 

Batch 133. Data shape torch.Size([8, 1, 64, 64]) Loss 0.582237184047699
Batch 134. Data shape torch.Size([8, 1, 64, 64]) Loss 0.575645387172699
Batch 135. Data shape torch.Size([8, 1, 64, 64]) Loss 0.577415406703949
Batch 136. Data shape torch.Size([8, 1, 64, 64]) Loss 0.590049684047699
Batch 137. Data shape torch.Size([8, 1, 64, 64]) Loss 0.561729371547699
Batch 138. Data shape torch.Size([8, 1, 64, 64]) Loss 0.585594117641449
Batch 139. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576499879360199
Batch 140. Data shape torch.Size([8, 1, 64, 64]) Loss 0.574424684047699
Batch 141. Data shape torch.Size([8, 1, 64, 64]) Loss 0.576591432094574
Batch 142. Data shape torch.Size([8, 1, 64, 64]) Loss 0.573173463344574
Batch 143. Data shape torch.Size([8, 1, 64, 64]) Loss 0.599906861782074
Batch 144. Data shape torch.Size([8, 1, 64, 64]) Loss 0.561942994594574
Batch 145. Data shape torch.Size([8, 1, 64, 64]) Loss 0.581687867641449
Batch 146. Data shape torch.Size([8, 1, 64, 64]) Loss 0.55687707

In [None]:
results_json

In [None]:
for x in exp.train_loader:
    y = np.squeeze(x['image'][0, :, :])
    print(y.max())
    break

In [None]:
exp.model.eval()

# volume is a numpy array of shape [X,Y,Z] and I will slice X axis
slices = []

# create mask for each slice across the X (0th) dimension. 
# put all slices into a 3D Numpy array
# volume = exp.test_data[0]['image']
volume = exp.test_data[0]['image']
img = volume[0,:,:]
print('img', img.shape, img)
z = np.squeeze(img)
print(z.shape)
print(z.max())
# slc = img.astype(np.single)/np.max(img)
# print('slc', slc.shape, slc)
# slc_tensor = torch.from_numpy(slc).unsqueeze(0).unsqueeze(0).to(exp.device)
# print('slc_tensor', slc_tensor.shape, slc_tensor)
slc_tensor = torch.from_numpy(volume[0,:,:].astype(np.single)).unsqueeze(0).unsqueeze(0).to(exp.device)
pred = exp.model(slc_tensor)
print('pred', pred.shape, pred)
mask = torch.argmax(np.squeeze(pred.cpu().detach()), dim=0)
print('mask', mask.shape, mask)
print((mask>0).sum())

In [None]:
s = exp.test_data[0]['seg']
for x in exp.test_data:
    print((x['seg']>0).sum() / (x['seg']==0).sum())