In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import tensorflow as tf
from utils.tfrecord_utils import *
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import numpy as np
import nibabel as nib
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import sys
import shutil
import itertools
from collections import defaultdict
from IPython import display
import time

In [None]:
tick_size = 20

# Global seaborn/matplotlib style: 10x10 figures, large serif labels, empty grid
# line style, white axes background with a dark-grey edge, and all four spines on.
sns.set(rc={
    'figure.figsize': (10, 10),
    'font.size': 25,
    'axes.labelsize': 25,
    'xtick.labelsize': tick_size,
    'ytick.labelsize': tick_size,
    'font.family': 'serif',
    'grid.linestyle': '',
    'axes.facecolor': 'white',
    'axes.edgecolor': '0.2',
    **{'axes.spines.' + spine: True for spine in ('bottom', 'left', 'right', 'top')},
})

palette = sns.color_palette("Set2", n_colors=6, desat=1)

# View NIfTI Test Volumes

In [None]:
def anim_data(img_path, side_length):
    """Display a ``side_length`` x ``side_length`` montage of slices from a volume.

    The volume is loaded with ``nib.load`` and slices are taken along its third
    axis (``img_vol[:, :, k]``) at indices ``0, step, 2*step, ...`` where
    ``step = depth // side_length**2`` (clamped to at least 1). Each slice is
    transposed and drawn in grey-scale with a fixed intensity window of
    [0, 255], filling the grid row by row. After the figure is shown the cell
    output is cleared with ``wait=True``, so the next call replaces it and a
    loop of calls animates in place.

    Args:
        img_path: path to a file readable by ``nib.load``.
        side_length: number of rows (and columns) in the montage grid.
    """
    img_vol = nib.load(img_path).get_fdata()
    depth = img_vol.shape[2]

    # A volume with fewer slices than grid panels would give a stride of 0,
    # which makes range() raise ValueError; clamp to 1 and leave the surplus
    # panels empty instead of indexing past the end of target_slices.
    step = max(1, depth // side_length ** 2)
    target_slices = list(range(0, depth, step))

    # squeeze=False keeps axs 2-D even for side_length == 1, where plt.subplots
    # would otherwise return a single Axes that cannot be indexed as axs[i, j].
    fig, axs = plt.subplots(side_length, side_length, squeeze=False)
    try:
        # axs.flat walks the grid row by row, matching the original i/j loops.
        for k, ax in enumerate(axs.flat):
            if k < len(target_slices):
                ax.imshow(img_vol[:, :, target_slices[k]].T, cmap='Greys_r', vmin=0.0, vmax=255.0)
            ax.set_xticks([])
            ax.set_yticks([])

        gap = 1e-4
        fig.subplots_adjust(wspace=gap, hspace=gap)

        # plt.show() returns None, so there is nothing for display.display() to render.
        plt.show()
    finally:
        # Release the figure even if drawing is interrupted (e.g. Ctrl-C in a loop).
        plt.close(fig)

    display.clear_output(wait=True)
    time.sleep(0.001)

In [None]:
FNAMES_PATH = Path("data/test_filenames.txt")

# Each line of the listing is comma-separated and its first field is an image
# path. Blank lines are skipped: Path('') would resolve to the working directory.
# Paths are resolved, de-duplicated, and sorted by parent-directory name (the
# class label). The full path is a tie-breaker: iterating a set of Paths follows
# string hashes, which Python randomizes per process, so without it the order
# *within* a class (and any positional lookup into this list) changed between
# kernel restarts.
with open(FNAMES_PATH, 'r') as fnames_file:
    fnames = sorted(
        {Path(line.strip().split(',')[0]).resolve() for line in fnames_file if line.strip()},
        key=lambda f: (f.parts[-2], str(f)),
    )

In [None]:
for fname in fnames:
    # One 4 x 4 montage per listed volume; each one replaces the previous
    # montage in this cell's output.
    anim_data(fname, side_length=4)

# View TFRecord Slices

In [None]:
def anim_data(img_slice):
    """Show a single 2-D array (transposed, grey-scale, axes hidden), then queue
    a clear of the cell output so the next call replaces this frame.

    NOTE: this redefinition shadows the earlier ``anim_data(img_path,
    side_length)`` montage helper for every cell executed after this one.
    """
    frame = img_slice.T
    plt.imshow(frame, cmap='Greys_r')
    plt.axis('off')

    shown = plt.show()
    display.display(shown)
    display.clear_output(wait=True)
    time.sleep(0.001)

In [None]:
def parse(r):
    """Decode one serialized record with ``parse_into_slice`` (from
    utils.tfrecord_utils). NOTE(review): ``(256, 256)`` looks like a slice
    shape and ``6`` like a class count — confirm against that helper."""
    return parse_into_slice(r, (256, 256), 6)


ds = tf.data.TFRecordDataset('data/tfrecord_dir/dataset_fold_0_train.tfrecord').map(parse)

In [None]:
for x, *y in ds:
    # Only channel 0 of the first parsed element is displayed; the remaining
    # parsed elements (collected in y) are ignored here.
    channel0 = x.numpy()[:, :, 0]
    anim_data(channel0)

In [None]:
def annot_figure(fname, num_figs, offset, title, slice_indices=(80, 110, 140)):
    """Display a grid of orthogonal slices from a volume under a title.

    The grid is ``side`` x ``side`` with ``side = int(np.sqrt(num_figs))``.
    Columns are split into thirds by view:

    * ``j < side/3``             -> ``vol[:, :, idx].T``     (slice along axis 2)
    * ``side/3 <= j < 2*side/3`` -> ``vol[:, idx, ::-1].T``  (slice along axis 1)
    * otherwise                  -> ``vol[idx, :, ::-1].T``  (slice along axis 0)

    The last two views reverse axis 2 of the volume before transposing. For each
    view, the k-th panel drawn (row-major) uses ``slice_indices[k]``; with
    ``num_figs=9`` each view fills one column of three panels. (The original
    ``a``/``c``/``s`` naming suggests axial/coronal/sagittal — that only holds
    if the volume is stored in that orientation; confirm against the data.)

    Args:
        fname: path readable by ``nib.load``.
        num_figs: panel count; only its integer square root is used.
        offset: unused; kept so existing callers keep working.
        title: passed to ``suptitle``.
        slice_indices: indices used, in order, for each view's panels.

    Raises:
        IndexError: if a view needs more panels than ``len(slice_indices)``
            (e.g. ``side > 3`` with the defaults) or an index is out of bounds
            for the volume.
    """
    vol = nib.load(fname).get_fdata()
    side = int(np.sqrt(num_figs))

    # squeeze=False keeps axs 2-D even for side == 1 (otherwise a bare Axes).
    fig, axs = plt.subplots(side, side, squeeze=False)
    try:
        a_count = c_count = s_count = 0
        for i in range(side):
            for j in range(side):
                if j < side / 3:
                    cur_slice = vol[:, :, slice_indices[a_count]].T
                    a_count += 1
                elif j < 2 * side / 3:
                    cur_slice = vol[:, slice_indices[c_count], ::-1].T
                    c_count += 1
                else:
                    cur_slice = vol[slice_indices[s_count], :, ::-1].T
                    s_count += 1

                ax = axs[i, j]
                ax.imshow(cur_slice, cmap='Greys_r', vmin=0.0, vmax=255.0)
                ax.set_xticks([])
                ax.set_yticks([])

        fig.suptitle(title)
        # tight_layout() recomputes wspace/hspace, so it must run *before* the
        # near-zero gap is applied; calling it afterwards discarded the gap.
        fig.tight_layout()
        gap = 1e-4
        fig.subplots_adjust(wspace=gap, hspace=gap)

        # plt.show() returns None, so there is nothing for display.display() to render.
        plt.show()
    finally:
        # Release the figure even when the caller interrupts drawing with Ctrl-C.
        plt.close(fig)

    display.clear_output(wait=True)
    time.sleep(0.001)

In [None]:
FNAMES_PATH = Path("data/test_filenames.txt")

# Each line of the listing is comma-separated and its first field is an image
# path. Blank lines are skipped: Path('') would resolve to the working directory.
# Paths are resolved, de-duplicated, and sorted by parent-directory name (the
# class label). The full path is a tie-breaker: iterating a set of Paths follows
# string hashes, which Python randomizes per process, so without it the order
# *within* a class (and lookups like fnames_by_class['T1'][n]) changed between
# kernel restarts.
with open(FNAMES_PATH, 'r') as fnames_file:
    fnames = sorted(
        {Path(line.strip().split(',')[0]).resolve() for line in fnames_file if line.strip()},
        key=lambda f: (f.parts[-2], str(f)),
    )

In [None]:
def fname_to_class(fname):
    """Return the class label of a path: its second-to-last path component
    (the name of the directory containing the file, as written in the path).

    Raises IndexError when the path has fewer than two components.
    """
    path_components = fname.parts
    return path_components[-2]

In [None]:
def _group_by_class(paths):
    """Group ``paths`` into ``{class_label: [path, ...]}`` using fname_to_class.

    Keys appear in first-seen order and each list keeps the input order. Unlike
    an ``itertools.groupby`` dict-comprehension, this does not require ``paths``
    to be pre-sorted by class: groupby emits one group per contiguous run, so a
    class appearing in two runs would silently keep only its last run.
    """
    groups = defaultdict(list)
    for path in paths:
        groups[fname_to_class(path)].append(path)
    return dict(groups)


fnames_by_class = _group_by_class(fnames)

In [None]:
for class_name, class_fnames in fnames_by_class.items():
    # One line per class: label and number of files.
    print(class_name, len(class_fnames))

In [None]:
offset = 125  # position in the T1 list to start browsing from


def _show_with_pause(fname, title):
    """Render ``annot_figure`` for one file with a pause/quit protocol.

    Ctrl-C while the figure renders pauses at an ``input()`` prompt. Pressing
    Enter re-renders the same file (and a further Ctrl-C during that re-render
    pauses again rather than escaping with a traceback). Ctrl-C at the prompt
    returns False to stop browsing. Returns True once the figure rendered
    without interruption.
    """
    while True:
        try:
            annot_figure(fname, 9, 10, title)
            return True
        except KeyboardInterrupt:
            try:
                input()
            except KeyboardInterrupt:
                return False


for i, fname in enumerate(fnames_by_class['T1'][offset:]):
    if not _show_with_pause(fname, i + offset):
        break
        
        

In [None]:
# Inspect one T1 volume by its position in the class list.
n = 136
fname = fnames_by_class['T1'][n]
print(fname)
annot_figure(fname, num_figs=9, offset=10, title=n)