# Import

In [None]:
import os
from os.path import join
import random
import glob

import numpy as np
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import nibabel as nib
from matplotlib import pyplot as plt

def check_array_equality(ob1, ob2):
  """Assert that ``ob2`` equals ``ob1``.

  Tensors and numpy arrays are compared element-wise, but only after checking
  that the shapes match: element-wise ``==`` broadcasts, so without the shape
  check a saved ``[1]`` would "equal" a new ``[1, 1, 1]``. Any other object is
  compared with plain ``==``.

  Raises:
    AssertionError: if the shapes or the values differ.
  """
  if torch.is_tensor(ob1) or isinstance(ob1, np.ndarray):
    shape1, shape2 = tuple(ob1.shape), tuple(np.shape(ob2))
    assert shape2 == shape1, f'Shape mismatch: {shape2} != {shape1}'
    assert (ob2 == ob1).all()
  else:
    assert ob2 == ob1

def check_or_save(obj, path, index=None, header=None):
  """Save ``obj`` to ``path`` the first time; on later calls verify it is unchanged.

  DataFrames are written/read as CSV and compared after rounding to 6 decimals,
  with NaNs in matching positions counted as equal. Any other object is written
  with ``torch.save`` and compared against ``torch.load(path)`` through
  ``check_array_equality``; lists are compared element by element after
  checking that their lengths match.

  Args:
    obj: DataFrame, list, tensor/array or other torch-serialisable object.
    path: file to create, or to compare against if it already exists.
    index: DataFrame only - whether the index is written/compared (required).
    header: DataFrame only - whether a header row is written/read (required).

  Raises:
    ValueError: if ``obj`` is a DataFrame and ``index`` or ``header`` is None.
    AssertionError: if ``obj`` differs from what is already stored at ``path``.
  """
  if isinstance(obj, pd.DataFrame):
    if index is None or header is None:
      raise ValueError('Index and header must be specified for saving a dataframe')
    if not os.path.exists(path):
      obj.to_csv(path, index=index, header=header)
      return
    saved_df = pd.read_csv(path) if header else pd.read_csv(path, header=None)
    # Strip labels so only positions and values are compared. read_csv is called
    # without index_col, so a saved index comes back as an ordinary column; when
    # index=True, obj's index is likewise moved into a column.
    naked_df = saved_df.reset_index(drop=True)
    naked_df.columns = range(naked_df.shape[1])
    naked_obj = obj.reset_index(drop=not index)
    naked_obj.columns = range(naked_obj.shape[1])
    assert naked_df.shape == naked_obj.shape, (
        f'Dataframe shape {naked_obj.shape} differs from saved shape {naked_df.shape}')
    if naked_df.round(6).equals(naked_obj.round(6)):
      return
    # .equals is also strict about dtypes, so fall back to an element-wise
    # comparison in which NaNs at the same position count as equal.
    diff = (naked_df.round(6) == naked_obj.round(6))
    diff[naked_df.isnull()] = naked_df.isnull() & naked_obj.isnull()
    assert diff.all().all(), "Dataframe is not the same as saved dataframe"
    return

  if not os.path.exists(path):
    print(f'Saving to {path}')
    torch.save(obj, path)
    return
  saved_obj = torch.load(path)
  if isinstance(obj, list):
    # Without this check a shorter list passed silently (only the common prefix
    # was compared) and a longer one raised IndexError.
    assert len(obj) == len(saved_obj), (
        f'List length {len(obj)} differs from saved length {len(saved_obj)}')
    for item, saved_item in zip(obj, saved_obj):
      check_array_equality(item, saved_item)
  else:
    check_array_equality(obj, saved_obj)

# Functions

In [None]:
def power(tensor, gamma):
    """Raise ``tensor`` element-wise to ``gamma``, preserving the sign of negatives.

    When any element is negative the result is ``sign(x) * |x| ** gamma``, which
    stays real for fractional exponents; otherwise plain ``x ** gamma`` is used.
    Expects an object with torch-style ``.min()``, ``.sign()`` and ``.abs()`` methods.
    """
    has_negative = tensor.min() < 0
    return tensor.sign() * tensor.abs() ** gamma if has_negative else tensor ** gamma

class RandomGamma(torch.nn.Module):
    """Random gamma augmentation.

    Each call draws ``gamma ~ U(0.25, 1.75)`` from NumPy's global RNG and returns
    ``power(pic, gamma)``. ``power`` is the module-level helper, which preserves
    the sign of negative inputs.
    """

    def forward(self, pic):
        """Return ``pic`` raised element-wise to a freshly sampled gamma."""
        # Implemented as ``forward`` rather than overriding ``__call__``: nn.Module's
        # ``__call__`` dispatches to ``forward`` (running any registered hooks), so
        # ``RandomGamma()(pic)`` keeps working exactly as before.
        gamma = np.random.uniform(low=0.25, high=1.75)
        return power(pic, gamma)

    def __repr__(self):
        return self.__class__.__name__ + '()'

In [None]:
def power(tensor, gamma):
    """Sign-preserving element-wise power.

    This cell redefines ``power`` with the same behaviour as the earlier
    ``Functions`` cell. Non-negative input is returned as ``tensor ** gamma``;
    if any element is negative, ``sign(x) * |x| ** gamma`` is returned instead.
    """
    if not tensor.min() < 0:
        return tensor ** gamma
    magnitude = tensor.abs() ** gamma
    return tensor.sign() * magnitude

In [None]:
def crop_image(image, cx, cy, size):
    """Crop a channels-first image to a ``size`` x ``size`` window centred at (cx, cy).

    The window is taken over the last two axes, so ``image`` may be 3-D
    (C, X, Y) or 4-D (C, D, X, Y). Any part of the window outside the image is
    zero-filled, so the last two output axes are always exactly ``size`` long,
    including for odd ``size`` and for windows partly or wholly outside the image.

    Args:
        image: array-like accepted by ``np.asarray`` (e.g. ndarray or CPU tensor).
        cx, cy: integer centre indices along the last two axes.
        size: side length of the square window.

    Returns:
        numpy array of shape ``image.shape[:-2] + (size, size)`` with image's dtype.

    Raises:
        ValueError: if ``image`` is neither 3-D nor 4-D.
    """
    image = np.asarray(image)
    if image.ndim not in (3, 4):
        raise ValueError('Unsupported dimension, image.ndim = {0}; expected 3 or 4.'.format(image.ndim))
    size = int(size)
    X, Y = image.shape[-2:]
    # Window is [x1, x2) x [y1, y2). Using x2 = x1 + size keeps odd sizes from
    # being shortened by one (cx +/- size//2 spans only 2*(size//2) pixels).
    x1, y1 = cx - size // 2, cy - size // 2
    x2, y2 = x1 + size, y1 + size
    # Part of the window that actually lies inside the image (may be empty)
    x1_, x2_ = max(x1, 0), min(x2, X)
    y1_, y2_ = max(y1, 0), min(y2, Y)
    # Zero canvas = constant zero padding for everything outside the image
    crop = np.zeros(image.shape[:-2] + (size, size), dtype=image.dtype)
    if x1_ < x2_ and y1_ < y2_:
        crop[..., x1_ - x1:x2_ - x1, y1_ - y1:y2_ - y1] = image[..., x1_:x2_, y1_:y2_]
    return crop

# Process

To parse the original NIfTI (`.nii.gz`) images into three-slice stacks (ES, mid-beat, ED) for training, run the cell below.

In [None]:
# Insert path to folder containing all downloaded subject folders here
image_base_folder = ''
# Fail loudly while the placeholder is unset: an empty path would otherwise be
# globbed relative to the working directory and the cell would go on to save
# (and overwrite) the output .pt files.
if not image_base_folder:
    raise ValueError('Set image_base_folder to the folder containing all downloaded subject folders')

def get_mid_beat_slice(im, es_slice):
    """Return the middle z slice halfway between frame 0 and the ES frame.

    ``im`` is indexed as ``im[x, y, z, t]``. The middle z slice of every frame is
    compared pixel-for-pixel with ``es_slice`` to find the frame with the most
    exactly-equal pixels (taken as the ES frame). The slice at frame
    ``es_frame // 2`` is then returned, with values above the volume's 99th
    percentile clipped to that percentile.

    Returns None if no frame shares a single pixel value with ``es_slice``, or
    if the clipped ES-frame slice does not match ``es_slice`` (``np.allclose``).
    ``im`` itself is not modified.
    """
    mid_z = im.shape[2] // 2
    best_overlap_es = 0
    best_i_es = None
    # Search every frame; a hard-coded range(50) raised IndexError on shorter series.
    for i in range(im.shape[3]):
        overlap_es = (es_slice == im[:, :, mid_z, i]).sum()
        if overlap_es > best_overlap_es:
            best_overlap_es = overlap_es
            best_i_es = i
    if best_i_es is None:
        return None

    # Only the upper percentile is used for clipping
    val_h = np.percentile(im, 99.0)
    # np.minimum returns a new array, so the caller's volume is left untouched
    # (clipping through a slice view wrote into ``im``).
    im_slice = np.minimum(im[:, :, mid_z, best_i_es], val_h)
    if im_slice.shape != np.shape(es_slice) or not np.allclose(im_slice, es_slice):
        return None
    # Halfway between frame 0 (presumably ED - TODO confirm) and the ES frame
    return np.minimum(im[:, :, mid_z, best_i_es // 2], val_h)

def _pad_to_square(arr):
    """Zero-pad a 2-D array so both sides equal the longer side.

    Padding is split between the two edges; when the size difference is odd
    the extra row/column goes on the trailing edge. Padding ``d // 2`` on both
    edges left such slices one pixel short of square, and those subjects were
    then dropped as 'Shapes didnt match'.
    """
    n_rows, n_cols = arr.shape
    side = max(n_rows, n_cols)
    d_rows, d_cols = side - n_rows, side - n_cols
    return np.pad(arr,
                  ((d_rows // 2, d_rows - d_rows // 2), (d_cols // 2, d_cols - d_cols // 2)),
                  'constant', constant_values=0)


# Build {subject_id: tensor stacking the ES, mid-beat and ED middle z slices}
# and record which subjects failed.
all_subjects = {}

problem_ids = []  # folders that did not yield a complete three-slice stack
missing_ids = []  # subject ids with at least one of the three NIfTI files missing

# image_base_folder holds one sub-folder per subject, so glob its children;
# globbing the folder path itself only ever returns that single folder.
subject_folders = sorted(
    f for f in glob.glob(join(image_base_folder, '*')) if os.path.isdir(f))

for folder in subject_folders:
    _id = os.path.basename(os.path.normpath(folder))

    if _id in all_subjects:
        continue

    to_stack = []
    es_slice = None
    # Order matters: the ES slice is needed to locate the ES frame in sa.nii.gz
    for cycle_position in ['sa_ES.nii.gz', 'sa.nii.gz', 'sa_ED.nii.gz']:
        path = join(folder, cycle_position)
        if not os.path.exists(path):
            missing_ids.append(_id)
            print(f'Missing files: {folder}')
            break

        im = nib.load(path).get_fdata()

        # Too few z-axis slices are bad quality images
        if im.shape[2] <= 7:
            print(f'Too few z-axis slices: {folder}')
            break

        # Full cycle volumes are used to extract middle of heart beat slice
        if cycle_position == 'sa.nii.gz':
            mid_heart_slice = get_mid_beat_slice(im, es_slice)
            if mid_heart_slice is None:
                print(f'ES didnt match: {folder}')
                break
        else:
            mid_heart_slice = im[:, :, im.shape[2] // 2]

        # Keep the unpadded ES slice for matching against the full cycle
        if cycle_position == 'sa_ES.nii.gz':
            es_slice = mid_heart_slice

        to_stack.append(torch.tensor(_pad_to_square(mid_heart_slice)))

    # Slices of different sizes cannot be stacked; record the subject as a
    # problem instead of letting torch.stack raise.
    if len(to_stack) == 3 and len({t.shape for t in to_stack}) != 1:
        print(f'Shapes didnt match: {folder}')
        to_stack = []

    if len(to_stack) == 3:
        ims_stacked_t_n = torch.stack(to_stack)
        if ims_stacked_t_n.shape == (3, 208, 208):
            # Pad 208x208 to 210x210. torch's pad keeps a tensor; np.pad would
            # return an ndarray and leave the dict with mixed value types.
            ims_stacked_t_n = torch.nn.functional.pad(ims_stacked_t_n, (1, 1, 1, 1), mode='constant', value=0)
        all_subjects[_id] = ims_stacked_t_n
    else:
        problem_ids.append(folder)

torch.save(all_subjects, 'preprocessed_cardiac_dict.pt')
torch.save(problem_ids, 'problem_ids_cardiac.pt')
torch.save(missing_ids, 'missing_ids_cardiac.pt')

In [None]:
# Pad the 208x208 to 210x210
for key, im in all_subjects.items():
    if im.shape == (3, 208, 208):
        # torch's pad keeps each entry a tensor. np.pad would return an ndarray,
        # leaving the dict with mixed value types; ndarrays are presumably also
        # rejected by torch.load(weights_only=True) when the saved dict is reloaded.
        # Reassigning existing keys while iterating is safe (the dict size does not change).
        all_subjects[key] = torch.nn.functional.pad(torch.as_tensor(im), (1, 1, 1, 1), mode='constant', value=0)

In [None]:
# Verify only 210x210 left: the displayed set lists every distinct stack shape
all_shapes = [tuple(stack.shape) for stack in all_subjects.values()]
set(all_shapes)

In [None]:
# Persist the padded stacks to the same file the preprocessing cell wrote
torch.save(obj=all_subjects, f='preprocessed_cardiac_dict.pt')

In [None]:
# Random quality control: up to 10 distinct random subjects, one row each,
# columns are the three stacked slices (ES, mid-beat, ED)
random.seed(2025)
keys = list(all_subjects.keys())
# random.sample draws without replacement, so no subject is shown twice
# (repeated random.randrange calls could pick the same subject)
qc_keys = random.sample(keys, k=min(10, len(keys)))
# squeeze=False keeps axarr 2-D even when only one row is drawn
f, axarr = plt.subplots(len(qc_keys), 3, figsize=(20, 5 * len(qc_keys)), squeeze=False)
for i, key in enumerate(qc_keys):
    im = all_subjects[key]
    for channel in range(3):
        axarr[i, channel].imshow(im[channel, :, :])

## Split

In [None]:
# Insert path to the dataset root here; it must contain a 'tabular' sub-folder.
# (An empty string means the current working directory.)
BASE = ''
TABULAR_BASE = join(BASE, 'tabular')

# Load the stacks from the same relative path the preprocessing cells saved them
# to, instead of a hard-coded, machine-specific absolute path.
all_subjects = torch.load('preprocessed_cardiac_dict.pt')
len(all_subjects)

In [None]:
# Random quality control of the loaded stacks: up to 10 distinct random subjects,
# one row each, columns are the three stacked slices (ES, mid-beat, ED)
random.seed(2030)
keys = list(all_subjects.keys())
# random.sample draws without replacement, so no subject is shown twice
qc_keys = random.sample(keys, k=min(10, len(keys)))
# squeeze=False keeps axarr 2-D even when only one row is drawn
f, axarr = plt.subplots(len(qc_keys), 3, figsize=(20, 5 * len(qc_keys)), squeeze=False)
for i, key in enumerate(qc_keys):
    im = all_subjects[key]
    assert im.shape == (3, 210, 210)
    for channel in range(3):
        axarr[i, channel].imshow(im[channel, :, :])

In [None]:
# Tabular features, indexed by 'eid' (matched against the imaging ids below)
tabular_csv_path = join(TABULAR_BASE, 'cardiac_feature_668815_vector_labeled_noOH.csv')
tabular_df = pd.read_csv(tabular_csv_path).set_index('eid')

In [None]:
tabular_ids = list(tabular_df.index)

print(f'There are {len(all_subjects)} images in the dataset')
print(f'There are {len(tabular_ids)} tabular entries in the dataset')

# Imaging ids are stored as strings (folder names); convert to int to match 'eid'
imaging_ids = sorted(int(i) for i in all_subjects.keys())

# Set lookup keeps the intersection linear; `i in tabular_ids` on a list rescanned
# the whole list for every imaging id. Order of overlap_ids is unchanged (sorted).
tabular_id_set = set(tabular_ids)
overlap_ids = [i for i in imaging_ids if i in tabular_id_set]

print(f'There are {len(overlap_ids)} overlap between images and tabular entries in the dataset')

In [None]:
from sklearn.model_selection import train_test_split

# Roughly 70 / 15 / 15 split of the subjects that have both an image and a
# tabular row; the fixed random_state makes the split reproducible.
train_ids, rest_ids = train_test_split(overlap_ids, test_size=0.3, random_state=2023)
val_ids, test_ids = train_test_split(rest_ids, test_size=0.5, random_state=2023)

# Each id list is saved on the first run; later runs verify it is unchanged.
for split_name, split_ids in (('train', train_ids), ('val', val_ids), ('test', test_ids)):
    check_or_save(split_ids, join(TABULAR_BASE, f'ids_{split_name}_tabular_imaging.pt'))