In [None]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

# from data_classes.py_files.custom_datasets import *

# from model_classes.py_files.cnn_model import *
# from model_classes.py_files.pigan_model import *

# from functions import *

%matplotlib qt

In [5]:
# Directory with the preprocessed volumes. Each sample is a pair of files named
# "<subj>__<proj>__<aug>__pcmra.npy" and "<subj>__<proj>__<aug>__mask.npy".
# Set PCMRA_DATA_ROOT to point at a different copy without editing the notebook.
root = os.environ.get(
    "PCMRA_DATA_ROOT",
    "/home/ptenkaate/scratch/Master-Thesis/Dataset/original_normalized_rotated",
)

# Keep only the pcmra files (their mask partner is derived from the name below).
# os.listdir returns entries in arbitrary order; sort so the sample order, and
# therefore the slice index of each sample in the viewer, is reproducible.
files = sorted(
    file for file in os.listdir(root) if file.split("__")[-1] == "pcmra.npy"
)

subjs = []
projs = []
augs = []
pcmras = []
masks = []

for file in files:
    # file names are expected to have exactly four "__"-separated parts
    subj, proj, aug, imtype = file.split("__")
    prefix = os.path.join(root, f"{subj}__{proj}__{aug}")

    pcmra = np.load(f"{prefix}__pcmra.npy")
    mask = np.load(f"{prefix}__mask.npy")
    # the viewer adds pcmra + mask element-wise, so they must line up
    assert pcmra.shape == mask.shape, (
        f"shape mismatch for {subj}__{proj}__{aug}: "
        f"pcmra {pcmra.shape} vs mask {mask.shape}"
    )

    subjs.append(subj)
    projs.append(proj)
    augs.append(aug)
    pcmras.append(pcmra)
    masks.append(mask)


In [6]:
class Show_images(object):
    """
    Scroll through slices of one or more 3D arrays shown side by side.

    suptitles: either a str, used as the figure title for every slice, or a
        sequence with one title per slice (indexed by slice number).
    images_titles: one or more (array, title) tuples. Each array must be 3D
        and is displayed as image[:, :, slice]; the number of slices is taken
        from the first array.
    min_max: (vmin, vmax) colour limits passed to imshow for every subfigure.

    Keep a reference to the instance (e.g. `window = Show_images(...)`):
    matplotlib's callback registry holds only a weak reference to the bound
    `onscroll` method.

    Raises ValueError if no (array, title) tuple is given.
    """
    def __init__(self, suptitles, *images_titles, min_max=(0, 1)):
        if not images_titles:
            raise ValueError("Show_images needs at least one (image, title) tuple")

        # split tuples with (image, title) into lists
        self.images = [x[0] for x in images_titles]
        self.titles = [x[1] for x in images_titles]

        # number of slices along the third axis (unpacking enforces a 3D array)
        _, _, self.slices = self.images[0].shape

        # a single string is repeated as the title of every slice
        if isinstance(suptitles, str):
            self.suptitles = [suptitles] * self.slices
        else:
            self.suptitles = suptitles

        self.fig, self.ax = plt.subplots(1, len(images_titles))

        # plt.subplots returns a bare Axes for a single column; wrap it so the
        # rest of the class can always iterate over / index self.ax
        if not isinstance(self.ax, np.ndarray):
            self.ax = np.array([self.ax])

        # start at slice 10 if more than 20 slices, otherwise at the middle slice
        self.ind = 10 if self.slices > 20 else self.slices // 2

        # one subfigure per image, each titled and showing the starting slice
        self.plots = []
        for sub_ax, image, title in zip(self.ax, self.images, self.titles):
            sub_ax.set_title(title)
            plot = sub_ax.imshow(image[:, :, self.ind], vmin=min_max[0], vmax=min_max[1])
            self.plots.append(plot)

        # label the slice that is actually displayed (self.ind), not slice 0
        self.fig.suptitle(self.suptitles[self.ind])
        self.ax[0].set_ylabel('Slice Number: %s' % self.ind)

        # link figure to mouse scroll movement
        self.plot_show = self.fig.canvas.mpl_connect('scroll_event', self.onscroll)

    def onscroll(self, event):
        """
        Show the previous slice on scroll 'up' and the next slice otherwise,
        wrapping around at either end.
        """
        if event.button == 'up':
            self.ind = (self.ind - 1) % self.slices
        else:
            self.ind = (self.ind + 1) % self.slices

        self.update()

    def update(self):
        """
        Redraw every subfigure, the figure title and the slice label for self.ind.
        """
        self.fig.suptitle(self.suptitles[self.ind])

        for plot, image in zip(self.plots, self.images):
            plot.set_data(image[:, :, self.ind])

        self.ax[0].set_ylabel('Slice Number: %s' % self.ind)
        # draw_idle coalesces rapid scroll events into a single repaint
        self.fig.canvas.draw_idle()

In [7]:
def show_data(data, shape=(64, 64, 24)):
    """
    Open a Show_images viewer over a collection of (pcmra, mask) samples.

    All samples are stacked along their third (slice) axis, so scrolling moves
    through every slice of sample 0, then sample 1, and so on. Three panels are
    shown: the pcmra volumes, the masks, and their element-wise sum.

    data: five parallel sequences [subjs, projs, augs, pcmras, masks]; entry i
        of each describes sample i. pcmras[i] and masks[i] are arrays indexed
        as [:, :, slice] and are cast to float64.
    shape: unused; retained so existing calls that pass it keep working.
        The number of titles per sample comes from that sample's actual slice
        count, keeping titles aligned with the stacked volume.

    Returns the Show_images instance. Keep a reference to it
    (e.g. `window = show_data(...)`) so its scroll callback stays connected.

    Raises ValueError if data contains no samples.
    """
    subjs, projs, augs, pcmra_list, mask_list = data[0], data[1], data[2], data[3], data[4]

    n_samples = len(subjs)
    if n_samples == 0:
        raise ValueError("show_data needs at least one sample")

    titles = []
    pcmra_stack = []
    mask_stack = []

    for i in range(n_samples):
        pcmra = pcmra_list[i].astype(np.float64)
        mask = mask_list[i].astype(np.float64)
        pcmra_stack.append(pcmra)
        mask_stack.append(mask)

        # one figure title per slice of this sample, using its real depth
        titles += [f"{i}: {subjs[i]} {projs[i]} {augs[i]}"] * pcmra.shape[2]

    pcmras = np.concatenate(pcmra_stack, axis=2)
    masks = np.concatenate(mask_stack, axis=2)

    # overlay: mask voxels are lifted by 1 on top of the pcmra intensity
    pcmra_masks = pcmras + masks

    return Show_images(titles, (pcmras, "pcmras"), (masks, "masks"), (pcmra_masks, "pcmras + masks"), min_max=[0, 2])

# Open the interactive viewer over every loaded sample; the result is bound to
# `window` so the viewer object (and its scroll handler) is not discarded.
window = show_data(
    data=[subjs, projs, augs, pcmras, masks],
)