# Imports and Utility functions

## <font color='orange'>Imports</font>

In [None]:
import numpy as np
import pandas as pd
import shutil, time, os, requests, random, copy
from itertools import permutations 
import seaborn as sns
import imageio
from skimage.transform import rotate, AffineTransform, warp, resize
#from google.colab.patches import cv2_imshow
from IPython.display import clear_output, Image, SVG
import h5py

#%tensorflow_version 2.x
#%load_ext tensorboard
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, GlobalAveragePooling2D, AveragePooling2D, BatchNormalization, Reshape

#from tensorflow.keras.layers import Conv3D, MaxPooling3D, GlobalAveragePooling3D, AveragePooling3D

from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Concatenate, Lambda, LeakyReLU

from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras import regularizers, activations
from tensorflow.keras.utils import to_categorical, Sequence

from tensorflow.keras.utils import plot_model

from sklearn.utils import shuffle
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, f1_score
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split


import matplotlib.pyplot as plt
#import matplotlib.animation as animation
%matplotlib inline

In [None]:
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input as ppi_irv2

## <font color='orange'>Downloading and Extracting Data</font>

In [None]:
!wget http://download.cs.stanford.edu/deep/MRNet-v1.0.zip

In [None]:
!unzip MRNet-v1.0.zip -d ~/MRNet

In [None]:
# Check the extracted dataset root (the directory the unzip cell writes to).
!ls ~/MRNet/MRNet-v1.0/

## <font color='orange'>Analysing and Cleaning Data</font>

### Files

In [None]:
!ls ~/MRNet-v1/MRNet-v1.0

In [None]:
# Dataset root. os.listdir (and the other os/np file helpers used below) do NOT expand '~',
# so it must be expanded explicitly. The directory matches where the unzip cell extracts the
# archive (~/MRNet/MRNet-v1.0, i.e. /root/MRNet/MRNet-v1.0 on Colab).
mrnet_path = os.path.expanduser('~/MRNet/MRNet-v1.0')
contents = os.listdir(mrnet_path)
print(contents)
print('\nLabel Files...')
label_files = [x for x in contents if x.endswith('.csv')]
print(label_files)

### Real Labels

In [None]:
# Training labels: one headerless CSV per task (abnormal, ACL tear, meniscal tear),
# read in that order. On Colab the files live under /root/MRNet/MRNet-v1.0/.
trabn, tracl, trmen = (
    pd.read_csv(mrnet_path + '/' + csv_name, header=None)
    for csv_name in ('train-abnormal.csv', 'train-acl.csv', 'train-meniscus.csv')
)

In [None]:
# Give each two-column training label frame readable column names.
for label_df in (trabn, tracl, trmen):
    label_df.columns = ['patient_id', 'label']

In [None]:
# One row per training patient with the three binary labels side by side
# (inner joins on patient_id; columns: patient_id, abn, acl, men).
tr_multilabel = (
    trabn.rename(columns={'label': 'abn'})
    .merge(tracl.rename(columns={'label': 'acl'}), on='patient_id')
    .merge(trmen.rename(columns={'label': 'men'}), on='patient_id')
)
tr_multilabel.head()

In [None]:
# Validation labels: same three headerless CSVs as for training, with a 'valid-' prefix.
# On Colab the files live under /root/MRNet/MRNet-v1.0/.
valabn, valacl, valmen = (
    pd.read_csv(mrnet_path + '/' + csv_name, header=None)
    for csv_name in ('valid-abnormal.csv', 'valid-acl.csv', 'valid-meniscus.csv')
)

In [None]:
# Give each two-column validation label frame readable column names.
for label_df in (valabn, valacl, valmen):
    label_df.columns = ['patient_id', 'label']

In [None]:
# One row per validation patient with the three binary labels side by side
# (inner joins on patient_id; columns: patient_id, abn, acl, men).
val_multilabel = (
    valabn.rename(columns={'label': 'abn'})
    .merge(valacl.rename(columns={'label': 'acl'}), on='patient_id')
    .merge(valmen.rename(columns={'label': 'men'}), on='patient_id')
)
val_multilabel.head(120)

### <font color='blue'>Filename DataFrame</font>

In [None]:
# One row per training exam: patient_id parsed from the '<id>.npy' file names in the axial
# folder, sorted by patient_id with a fresh RangeIndex. Entries that are not .npy files
# (e.g. .DS_Store) are skipped; previously they crashed the int() parse.
# NOTE(review): this axial listing is later reused for other planes (the sagittal generator),
# which assumes every plane folder contains the same file names.
tr_axial_files = [f for f in os.listdir(mrnet_path + '/train/' + 'axial') if f.endswith('.npy')]
tr_filenames_df = (
    pd.DataFrame({'filename': tr_axial_files})
    .assign(patient_id=lambda d: d['filename'].str[:-len('.npy')].astype(int))
    [['patient_id', 'filename']]
    .sort_values(by='patient_id', ignore_index=True)
)

tr_filenames_df

In [None]:
# One row per validation exam: patient_id parsed from the '<id>.npy' file names in the axial
# folder, sorted by patient_id with a fresh RangeIndex. Entries that are not .npy files
# (e.g. .DS_Store) are skipped; previously they crashed the int() parse.
val_axial_files = [f for f in os.listdir(mrnet_path + '/valid/' + 'axial') if f.endswith('.npy')]
val_filenames_df = (
    pd.DataFrame({'filename': val_axial_files})
    .assign(patient_id=lambda d: d['filename'].str[:-len('.npy')].astype(int))
    [['patient_id', 'filename']]
    .sort_values(by='patient_id', ignore_index=True)
)

val_filenames_df

## <font color='orange'>Visualizing the Data</font>

In [None]:
# Class balance of the training ACL labels (second column of tracl is the 0/1 label).
ax = tracl.iloc[:, 1].hist(figsize=(10, 5))
ax.set(title='Train ACL label distribution', xlabel='label', ylabel='count');

In [None]:
# Number of positive ACL cases in the training set.
int((tracl.iloc[:, 1] == 1).sum())

## <font color='orange'>Utility Functions</font>

### <font color='blue'>Declaring the required PATH variables</font>

In [None]:
# Paths and global settings shared by the rest of the notebook.
# On Colab the dataset root is '/root/MRNet/MRNet-v1.0/'.

train_dir = mrnet_path + '/train'
valid_dir = mrnet_path + '/valid'
axial_mode = 'axial'
sagit_mode = 'sagittal'
coron_mode = 'coronal'
base_dir = mrnet_path

NUM_FRAMES = 1      # slice indices drawn per sample by the data generator
batch_size = 32
NUM_CLASSES = 1000  # one class per selected jigsaw permutation
NUM_PATCHES = 9     # 3 x 3 patch grid; the pretext-task cells below use the same value

### <font color='blue'>Callbacks</font>

In [None]:
def get_callbacks(pord, acctype, save_dir='/saved_models/', mode='sagittal'):
    """Build the Keras callbacks used for training.

    Currently this is a single ModelCheckpoint that keeps only the best model,
    judged by the validation metric ``'val_' + acctype + '_accuracy'`` (higher
    is better).

    Args:
        pord: tag inserted into the checkpoint file name
            (``<mode>_<pord>_best_model.h5``).
        acctype: prefix of the monitored metric. The monitored quantity is
            ``'val_<acctype>_accuracy'``. TODO confirm this matches the metric
            names of the compiled model.
        save_dir: directory for the checkpoint file. It is created if missing.
            The default is the previously hard-coded absolute path, which
            needs write access to ``/``.
        mode: MRI plane used as the file-name prefix. It was previously
            hard-coded to ``'sagittal'``, which is still the default.

    Returns:
        list: callbacks to pass to ``model.fit``.
    """
    # Fail early (at set-up) rather than at the first checkpoint if the directory can't be created.
    os.makedirs(save_dir, exist_ok=True)
    model_name = mode + '_' + pord + '_best_model.h5'

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(save_dir, model_name),
        monitor='val_' + acctype + '_accuracy', verbose=1,
        save_best_only=True, mode='max')

    return [checkpoint]

### <font color='blue'>Performance Metrics</font>

In [None]:
# Binary-classification helpers (adapted from util_wk2).
# Convention throughout: a score counts as a positive prediction when it is
# strictly greater than the threshold ``th`` (default 0.5); labels are 0/1.

def _confusion_count(y, pred, th, predicted_positive, actual_label):
    """Count samples whose thresholded prediction is ``predicted_positive`` and whose label is ``actual_label``."""
    above = pred > th
    predicted_mask = above if predicted_positive else np.logical_not(above)
    return np.sum(predicted_mask & (y == actual_label))


def TP(y, pred, th=0.5):
    """Number of true positives: score > th and label 1."""
    return _confusion_count(y, pred, th, True, 1)


def TN(y, pred, th=0.5):
    """Number of true negatives: score <= th and label 0."""
    return _confusion_count(y, pred, th, False, 0)


def FN(y, pred, th=0.5):
    """Number of false negatives: score <= th but label 1."""
    return _confusion_count(y, pred, th, False, 1)


def FP(y, pred, th=0.5):
    """Number of false positives: score > th but label 0."""
    return _confusion_count(y, pred, th, True, 0)

def get_accuracy(y, pred, th=0.5):
    """Fraction of samples classified correctly at threshold ``th``."""
    correct = TP(y, pred, th) + TN(y, pred, th)
    wrong = FP(y, pred, th) + FN(y, pred, th)
    return correct / (correct + wrong)

def get_prevalence(y):
    """Fraction of positive (label 1) samples in ``y``."""
    return np.sum(y) / len(y)

def sensitivity(y, pred, th=0.5):
    """True-positive rate TP / (TP + FN); NaN (with a numpy warning) when there are no positives."""
    hits = TP(y, pred, th)
    return hits / (hits + FN(y, pred, th))

def specificity(y, pred, th=0.5):
    """True-negative rate TN / (TN + FP); NaN (with a numpy warning) when there are no negatives."""
    rejections = TN(y, pred, th)
    return rejections / (rejections + FP(y, pred, th))

def get_ppv(y, pred, th=0.5):
    """Positive predictive value (precision) TP / (TP + FP)."""
    hits = TP(y, pred, th)
    return hits / (hits + FP(y, pred, th))

def get_npv(y, pred, th=0.5):
    """Negative predictive value TN / (TN + FN)."""
    rejections = TN(y, pred, th)
    return rejections / (rejections + FN(y, pred, th))


def get_performance_metrics(y, pred, class_labels, tp=TP,
                            tn=TN, fp=FP,
                            fn=FN,
                            acc=get_accuracy, prevalence=get_prevalence,
                            spec=specificity, sens=sensitivity, ppv=get_ppv,
                            npv=get_npv, auc=roc_auc_score, f1=f1_score,
                            thresholds=None):
    """Tabulate per-class confusion counts and threshold-based metrics.

    Args:
        y: array of true 0/1 labels with one column per class (indexed ``y[:, i]``).
        pred: array of scores with the same layout as ``y``.
        class_labels: names of the columns. They become the "Injury" index.
        tp, tn, fp, fn, acc, prevalence, spec, sens, ppv, npv, auc, f1:
            metric callables. The defaults are the helpers defined in this
            notebook plus scikit-learn's ``roc_auc_score`` / ``f1_score``.
            tp/tn/fp/fn/acc/spec/sens/ppv/npv are called as ``f(y_i, pred_i, th)``.
        thresholds: per-class decision thresholds. If it is None or its length
            differs from ``class_labels``, 0.5 is used for every class.

    Returns:
        pandas.DataFrame indexed by "Injury", one row per class, values rounded to 3 decimals.
    """
    # None instead of a mutable [] default; [] from existing callers still works.
    if thresholds is None or len(thresholds) != len(class_labels):
        thresholds = [.5] * len(class_labels)

    columns = ["Injury", "TP", "TN", "FP", "FN", "Accuracy", "Prevalence",
               "Sensitivity",
               "Specificity", "PPV", "NPV", "AUC", "F1", "Threshold"]
    rows = []
    for i, label in enumerate(class_labels):
        y_i, pred_i, th = y[:, i], pred[:, i], thresholds[i]
        rows.append([label,
                     # The counts now use the same per-class threshold as the
                     # rates next to them (they were previously always at 0.5).
                     round(tp(y_i, pred_i, th), 3),
                     round(tn(y_i, pred_i, th), 3),
                     round(fp(y_i, pred_i, th), 3),
                     round(fn(y_i, pred_i, th), 3),
                     round(acc(y_i, pred_i, th), 3),
                     round(prevalence(y_i), 3),
                     round(sens(y_i, pred_i, th), 3),
                     round(spec(y_i, pred_i, th), 3),
                     round(ppv(y_i, pred_i, th), 3),
                     round(npv(y_i, pred_i, th), 3),
                     round(auc(y_i, pred_i), 3),
                     round(f1(y_i, pred_i > th), 3),
                     round(th, 3)])

    df = pd.DataFrame(rows, columns=columns)
    return df.set_index("Injury")

def bootstrap_metric(y, pred, classes, metric='auc', bootstraps=100, fold_size=1000,
                     random_state=None):
    """Stratified bootstrap distribution of a per-class metric.

    For each class column ``c``, positives and negatives are resampled
    separately with replacement, so each bootstrap sample keeps the class's
    prevalence. Each sample has about ``fold_size`` rows. The metric is
    evaluated on every sample.

    Args:
        y: array of true 0/1 labels, one column per class.
        pred: array of scores with the same layout as ``y``.
        classes: class names; only ``len(classes)`` is used.
        metric: 'AUC', 'Sensitivity', 'Specificity' or 'Accuracy' (case-insensitive).
            The previous version compared case-sensitively, so its own default
            'auc' matched nothing and crashed.
        bootstraps: number of bootstrap samples per class.
        fold_size: approximate size of each bootstrap sample.
        random_state: optional seed for reproducible sampling. None uses
            numpy's global RNG, as before.

    Returns:
        numpy.ndarray of shape ``(len(classes), bootstraps)`` with the metric values.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    metric_funcs = {
        'auc': roc_auc_score,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'accuracy': get_accuracy,
    }
    try:
        metric_func = metric_funcs[metric.lower()]
    except KeyError:
        raise ValueError('Unknown metric %r; expected one of AUC, Sensitivity, '
                         'Specificity, Accuracy' % (metric,)) from None

    # One RandomState shared by all draws so a seed does not repeat the same sample every iteration.
    rng = np.random.RandomState(random_state) if random_state is not None else None

    statistics = np.zeros((len(classes), bootstraps))
    for c in range(len(classes)):
        # Build the frame directly; filling an empty DataFrame via .loc[:, col] is not reliable.
        df = pd.DataFrame({'y': y[:, c], 'pred': pred[:, c]})
        # get positive examples for stratified sampling
        df_pos = df[df.y == 1]
        df_neg = df[df.y == 0]
        prevalence = len(df_pos) / len(df)
        n_pos = int(fold_size * prevalence)
        n_neg = int(fold_size * (1 - prevalence))
        for i in range(bootstraps):
            # stratified sampling of positive and negative examples
            pos_sample = df_pos.sample(n=n_pos, replace=True, random_state=rng)
            neg_sample = df_neg.sample(n=n_neg, replace=True, random_state=rng)

            y_sample = np.concatenate([pos_sample.y.values, neg_sample.y.values])
            pred_sample = np.concatenate([pos_sample.pred.values, neg_sample.pred.values])
            statistics[c, i] = metric_func(y_sample, pred_sample)
    return statistics

def get_confidence_intervals(y, pred, class_labels):
    """Bootstrap mean and 5%-95% interval of AUC, sensitivity, specificity and accuracy.

    Args:
        y: array of true 0/1 labels, one column per class.
        pred: array of scores with the same layout as ``y``.
        class_labels: class names, used as the row index of each table.

    Returns:
        dict mapping each metric name ('AUC', 'Sensitivity', 'Specificity',
        'Accuracy') to a one-column DataFrame of strings formatted as
        ``"mean (low-high)"``.
    """
    metric_dfs = {}
    for metric in ('AUC', 'Sensitivity', 'Specificity', 'Accuracy'):
        stats = bootstrap_metric(y, pred, class_labels, metric)
        # Per-class summaries, computed once per metric rather than once per class.
        means = stats.mean(axis=1)
        upper = np.quantile(stats, .95, axis=1)
        lower = np.quantile(stats, .05, axis=1)

        summary = pd.DataFrame(columns=["Mean " + metric + " (CI 5%-95%)"])
        for label, centre, low, high in zip(class_labels, means, lower, upper):
            summary.loc[label] = ["%.2f (%.2f-%.2f)" % (centre, low, high)]
        metric_dfs[metric] = summary
    return metric_dfs


In [None]:
def build_augmentations():
    """Enumerate every augmentation combination as an indexed dict.

    Returns:
        dict mapping consecutive ids 0..53 to ``[rotation, shift_x, shift_y, scale]``.
        The grid is rotation in {-15, 0, 15}, each shift in {-6, 0, 6} and scale
        in {1, 1.15}. Ids follow nested-loop order (rotation outermost, scale innermost).
    """
    rotations = [-15, 0, 15]
    shifts_x = [-6, 0, 6]
    shifts_y = [-6, 0, 6]
    scales = [1, 1.15]
    # Shear is not part of the grid (it was an unused empty placeholder).
    combos = [
        [rot, dx, dy, scale]
        for rot in rotations
        for dx in shifts_x
        for dy in shifts_y
        for scale in scales
    ]
    return dict(enumerate(combos))


augmentations = build_augmentations()

In [None]:
# Inspect the augmentation table: id -> [rotation, shift_x, shift_y, scale].
augmentations

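## Pretext Patch Prediction Labels

Each pretext class is one ordering (permutation) of the grid patches, written as a string such as `'[0, 1, 2, 3, 4, 5, 6, 7, 8]'`. The network sees the patches shuffled in that order and has to predict which permutation was used.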

In [None]:
# 3 x 3 grid of patches for the jigsaw pretext task.
NUM_PATCHES = 9

In [None]:
def build_ppp_labels(num_patches):
    """Map every ordering of ``num_patches`` patch positions to a class id.

    Keys are ``str(list(perm))`` (e.g. ``'[0, 2, 1]'``). Values are consecutive
    integers in the lexicographic order that ``itertools.permutations``
    produces for ``range(num_patches)``, so the identity ordering gets 0.
    The dict has ``num_patches!`` entries.
    """
    return {
        str(list(perm)): class_id
        for class_id, perm in enumerate(permutations(range(num_patches)))
    }

# All NUM_PATCHES! orderings (362,880 for NUM_PATCHES = 9); later cells narrow this to a 1000-class subset.
PPP_LABELS = build_ppp_labels(NUM_PATCHES)

In [None]:
#print(PPP_LABELS)

In [None]:
def hamdist(l1, l2):
    """Hamming distance between two integer lists encoded as strings.

    Both arguments use the ``str(list)`` format of the PPP label keys, e.g.
    ``'[0, 1, 2]'``. Whitespace around entries is ignored.

    Args:
        l1, l2: string-encoded integer lists.

    Returns:
        int: number of positions at which the two lists differ.

    Raises:
        ValueError: if an entry is not an integer, or the lists have different
            lengths. Previously a longer ``l1`` raised IndexError and a longer
            ``l2`` was silently truncated.
    """
    def _parse(encoded):
        # '[0, 1, 2]' -> [0, 1, 2]; '[]' -> []
        body = encoded.strip('[]').strip()
        return [int(tok) for tok in body.split(',')] if body else []

    first, second = _parse(l1), _parse(l2)
    if len(first) != len(second):
        raise ValueError('hamdist needs equal-length lists, got %d and %d'
                         % (len(first), len(second)))
    return sum(a != b for a, b in zip(first, second))

In [None]:
# Sanity check on a hand-computed example: only position 4 (value 4) agrees, so the distance is 8.
_example_dist = hamdist('[1,2,3,4,5,6,7,8,9]','[9,6,5,4,7,2,3,1,8]')
assert _example_dist == 8, _example_dist
_example_dist

In [None]:
# Greedy selection of well-separated permutations: starting from the identity,
# scan every candidate in PPP_LABELS' order and keep it only if it differs from
# EVERY permutation kept so far in at least 5 positions (Hamming distance >= 5).
keys = ['[0, 1, 2, 3, 4, 5, 6, 7, 8]']

ppplabels = list(PPP_LABELS.keys())

for pl in ppplabels:
    # all() stops at the first too-close kept key; hamdist is pure, so the result is unchanged.
    if all(hamdist(kept, pl) >= 5 for kept in keys):
        keys.append(pl)

In [None]:
# Number of candidate permutations scanned by the selection loop above.
# NOTE(review): the number actually selected is len(keys), not len(ppplabels).
len(ppplabels)

In [None]:
# Final permutation -> class-id mapping: sample N_PPP_CLASSES of the selected candidate `keys`
# and make sure the identity permutation is always one of the classes.
# - A seeded RNG keeps the mapping identical across kernel restarts. Labels and any saved
#   model are only meaningful together with this exact mapping.
# - `keys` is no longer overwritten, so re-running this cell is safe. Previously a re-run
#   could call random.sample on 999 items and raise ValueError.
PPP_IDENTITY_KEY = '[0, 1, 2, 3, 4, 5, 6, 7, 8]'
N_PPP_CLASSES = 1000
PPP_SAMPLE_SEED = 42

sampled_keys = random.Random(PPP_SAMPLE_SEED).sample(keys, N_PPP_CLASSES)
if PPP_IDENTITY_KEY not in sampled_keys:
    # Identity becomes class 0; drop the last sample so there are still N_PPP_CLASSES classes.
    sampled_keys = [PPP_IDENTITY_KEY] + sampled_keys[:-1]
PPP_LABELS_DICT = {key: class_id for class_id, key in enumerate(sampled_keys)}

In [None]:
# From here on PPP_LABELS is the reduced permutation -> class-id mapping;
# the full enumeration built earlier is replaced.
PPP_LABELS = PPP_LABELS_DICT

In [None]:
# The identity permutation (patches in their original order) must be one of the pretext classes.
'[0, 1, 2, 3, 4, 5, 6, 7, 8]' in PPP_LABELS

In [None]:
# Number of pretext classes; should equal the NUM_CLASSES passed to the generator and model.
len(PPP_LABELS)

## SSL PPP Data Generator

In [None]:
# Settings for the generator below: 3 x 3 patch grid and one class per selected permutation.
NUM_PATCHES = 9
NUM_CLASSES = 1000

In [None]:
class PPPDataGen(Sequence):
    """Keras ``Sequence`` that yields jigsaw ("patch permutation prediction") batches.

    Batch ``idx`` is built from a single MRI volume. That is the file at
    position ``idx % len(filenames_df)`` of ``filenames_df['filename']``,
    loaded from ``<base_dir>/<phase>/<mode>/``. The batch is assembled as follows:

    1. ``batch_size`` random slices are drawn from that volume.
    2. Each slice is cut into a ``sqrt(num_patches) x sqrt(num_patches)`` grid
       of 64x64 patches, with a small random offset inside each grid cell.
    3. The patches are emitted in the order of a randomly chosen permutation
       key from ``ppp_labels_dict``.
    4. The target is the one-hot class id stored for that key.

    ``__getitem__`` returns ``(inputs, labels)``:

    * ``inputs`` is a list of ``num_patches`` arrays of shape
      ``(batch_size, 64, 64, 3)``. Entry ``j`` holds the ``j``-th shuffled
      patch of every sample, passed through ``preprocess_input`` when one is given.
    * ``labels`` is a float array of shape ``(batch_size, num_classes)``.

    Augmentation is applied only when ``phase == 'train'`` and ``data_aug`` is
    true. It consists of gamma correction per slice, plus a random affine
    transform and an optional horizontal flip per patch.
    """

    IMG_SIZE = 256    # slices are reshaped to IMG_SIZE x IMG_SIZE
    PATCH_SIZE = 64   # side length of every emitted patch

    def __init__(self, phase, mode, base_dir, filenames_df, preprocess_input=None,
                 ppp_labels_dict=PPP_LABELS, augmentations_dict=augmentations,
                 batch_size=8, num_patches=NUM_PATCHES, num_frames=NUM_FRAMES,
                 num_classes=NUM_CLASSES, hor_flip=True, data_aug=True,
                 label_pool_size=None):
        """
        Args:
            phase: split sub-directory of ``base_dir`` ('train' or 'valid').
                Only 'train' is augmented.
            mode: MRI plane sub-directory (e.g. 'sagittal').
            base_dir: dataset root.
            filenames_df: DataFrame whose 'filename' column lists the volumes to load.
            preprocess_input: optional callable applied to every patch array.
                It is skipped when None; passing None used to crash.
            ppp_labels_dict: permutation key (``'[i, j, ...]'``) -> class id.
            augmentations_dict: id -> ``[rotation, shift_x, shift_y, scale]``.
            batch_size: number of slices (samples) per batch.
            num_patches: patches per slice; expected to be a perfect square.
            num_frames: slice indices drawn per sample. The reshape in
                ``_load_batch_images`` only supports 1.
            num_classes: length of the one-hot label vector.
            hor_flip: randomly mirror patches left-right during augmentation.
            data_aug: master switch for augmentation in the 'train' phase.
                Previously this flag was ignored.
            label_pool_size: sample permutations only from the first
                ``label_pool_size`` keys of ``ppp_labels_dict`` (insertion order).
                None (the default) uses every key. The previous implementation
                always used the first 100, so most classes never occurred.
        """
        self.base_dir = base_dir
        self.ph_mode_dir = base_dir + '/' + phase + '/' + mode
        # One batch per directory entry (see __len__).
        self.filenames = os.listdir(self.ph_mode_dir)
        self.phase = phase
        self.mode = mode
        self.batch_size = batch_size
        self.num_patches = num_patches
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.filenames_df = filenames_df
        self.preprocess_input = preprocess_input
        self.ppp_labels_dict = ppp_labels_dict
        self.augmentations_dict = augmentations_dict
        self.hor_flip = hor_flip
        self.data_aug = data_aug
        self.label_pool_size = label_pool_size

        # 256-entry uint8 look-up tables for gamma 1.0, 1.15 and 0.85.
        self.invGamma100 = 1.0
        self.invGamma115 = 1.0 / 1.15
        self.invGamma085 = 1.0 / 0.85
        self.table100 = self._gamma_table(self.invGamma100)
        self.table115 = self._gamma_table(self.invGamma115)
        self.table085 = self._gamma_table(self.invGamma085)
        # Dict views of the same tables, kept for backward compatibility.
        self.gamma_dict100 = dict(zip(range(256), self.table100))
        self.gamma_dict115 = dict(zip(range(256), self.table115))
        self.gamma_dict085 = dict(zip(range(256), self.table085))

    @staticmethod
    def _gamma_table(inv_gamma):
        """uint8 table mapping pixel value k in [0, 255] to ``(k / 255) ** inv_gamma * 255``."""
        return np.array([((k / 255.0) ** inv_gamma) * 255 for k in np.arange(0, 256)]).astype("uint8")

    def get_random_shuffle_order(self, batch_sz):
        """Return a random permutation of ``range(batch_sz)`` as a list."""
        blist = list(range(batch_sz))
        random.shuffle(blist)
        return blist

    def load_volume(self, mode, file_idx):
        """Load the .npy volume at position ``file_idx`` (wrapped modulo the table length) of ``filenames_df``.

        ``mode`` is unused; the plane directory comes from ``self.ph_mode_dir``.
        """
        filePoolLen = self.filenames_df.shape[0]
        file_idx = file_idx % filePoolLen
        npy_file = np.load(self.ph_mode_dir + '/' + self.filenames_df['filename'].iloc[file_idx])
        return npy_file

    def get_frames(self, mode, idx):
        """Load volume ``idx`` and return ``num_frames`` randomly chosen slices (first axis) of it."""
        image_volume = self.load_volume(mode, idx)
        tot_frames = image_volume.shape[0]
        frame_idxs = np.random.randint(0, tot_frames, size=self.num_frames)
        frames = np.array(image_volume[frame_idxs, :, :])
        return frames

    def __len__(self):
        """Number of batches per epoch: one per entry in the phase/plane directory."""
        return len(self.filenames)

    def _load_batch_images(self, idx):
        """Return ``batch_size`` random slices of volume ``idx`` as a float64 ``(B, 256, 256, 3)`` array.

        The volume is read from disk once per batch. It used to be re-read for
        every sample, even though every sample comes from the same file. Each
        sample still draws its own random slice index. The grayscale slice is
        replicated into 3 channels.
        """
        image_volume = self.load_volume(self.mode, idx)
        tot_frames = image_volume.shape[0]
        slices = []
        for _ in range(self.batch_size):
            frame_idxs = np.random.randint(0, tot_frames, size=self.num_frames)
            frame = np.array(image_volume[frame_idxs, :, :]).reshape((self.IMG_SIZE, self.IMG_SIZE, 1))
            slices.append(np.repeat(frame, 3, axis=2).astype(np.float64))
        if not slices:
            return np.zeros((0, self.IMG_SIZE, self.IMG_SIZE, 3))
        return np.stack(slices)

    def _extract_patch(self, image, position, grid, crop_window, eachgap, augment):
        """Cut the PATCH_SIZE x PATCH_SIZE patch for grid ``position`` out of ``image``.

        The patch starts at a random offset in ``[0, eachgap)`` inside its
        ``crop_window``-sized grid cell, so neighbouring patches do not touch.
        """
        # position % grid indexes the first (row) axis and position // grid the
        # second (column) axis, i.e. positions are numbered column-major. This is
        # the original layout. The permutation task only needs the numbering to be
        # consistent.
        first = position % grid
        second = position // grid
        jitter = max(eachgap, 1)  # np.random.randint(0, 0) would raise
        start_first = crop_window * first + np.random.randint(0, jitter)
        start_second = crop_window * second + np.random.randint(0, jitter)
        patch = image[start_first:start_first + self.PATCH_SIZE,
                      start_second:start_second + self.PATCH_SIZE, :]
        return self.__augment(patch) if augment else patch

    def __getitem__(self, idx):
        """Build batch ``idx``: shuffled patches of random slices of one volume, plus one-hot permutation labels."""
        batch_imgs = self._load_batch_images(idx)
        augment = self.phase == 'train' and self.data_aug

        ppp_labels = list(self.ppp_labels_dict.keys())
        pool = len(ppp_labels)
        if self.label_pool_size is not None:
            pool = min(self.label_pool_size, pool)

        grid = int(np.sqrt(self.num_patches))
        crop_window = int(self.IMG_SIZE / grid)
        eachgap = int((crop_window - self.PATCH_SIZE) / 2.)

        patch_inputs = [[] for _ in range(self.num_patches)]
        batch_labs = np.zeros((batch_imgs.shape[0], self.num_classes))
        for i in range(batch_imgs.shape[0]):
            if augment:
                batch_imgs[i] = self.gamma_correction(batch_imgs[i])

            label_key = ppp_labels[np.random.randint(0, pool)]
            jumbling_order = [int(tok) for tok in label_key.strip('[]').split(',')]
            if len(jumbling_order) != self.num_patches:
                raise ValueError('permutation %s has %d entries but num_patches is %d'
                                 % (label_key, len(jumbling_order), self.num_patches))
            # Slot j of the output receives the patch at grid position jumbling_order[j].
            for slot, position in enumerate(jumbling_order):
                patch_inputs[slot].append(
                    self._extract_patch(batch_imgs[i], position, grid, crop_window, eachgap, augment))

            batch_labs[i, int(self.ppp_labels_dict[label_key])] = 1.0

        model_inputs = [
            np.stack(patches) if patches else np.zeros((0, self.PATCH_SIZE, self.PATCH_SIZE, 3))
            for patches in patch_inputs
        ]
        if self.preprocess_input is not None:
            model_inputs = [self.preprocess_input(patches) for patches in model_inputs]

        return model_inputs, batch_labs

    def on_epoch_end(self):
        """Shuffle the volume order between epochs. A new DataFrame is assigned; the caller's frame is untouched."""
        self.filenames_df = self.filenames_df.sample(frac=1).reset_index(drop=True)

    def gamma_correction(self, temp_patch):
        """Apply one of three gamma curves (1.0, 1.15, 0.85), chosen at random, via look-up table.

        ``temp_patch`` is truncated to int and used to index a 256-entry uint8
        table, so its values must lie in [0, 255]. Returns a uint8 array of the
        same shape. Plain table indexing gives the same values as the former
        per-pixel ``np.vectorize(dict.get)`` but is much faster.

        Raises:
            ValueError: if a pixel value lies outside [0, 255].
        """
        tables = (self.table100, self.table115, self.table085)
        table = tables[np.random.choice([0, 1, 2])]
        pixel_idx = temp_patch.astype('int')
        if pixel_idx.size and (pixel_idx.min() < 0 or pixel_idx.max() > 255):
            raise ValueError('gamma_correction expects pixel values in [0, 255]')
        return table[pixel_idx]

    def __augment(self, temp_patch):
        """Randomly rotate/translate/scale one ``(H, W, C)`` patch and optionally mirror it.

        A random entry ``[rotation, shift_x, shift_y, scale]`` of
        ``augmentations_dict`` is applied with scikit-image. ``preserve_range``
        keeps the original intensity scale. With ``hor_flip`` the patch is
        mirrored left-right with probability 1/2.
        """
        transforms = np.random.choice(list(self.augmentations_dict.keys()))
        transformations = self.augmentations_dict[transforms]
        temp_patch = rotate(temp_patch, transformations[0], preserve_range=True)
        scale, shift_x, shift_y = transformations[3], transformations[1], transformations[2]
        affine = np.array([[scale, 0, shift_x],
                           [0, scale, shift_y],
                           [0, 0, 1]])
        temp_patch = warp(temp_patch, AffineTransform(matrix=affine).inverse, preserve_range=True)

        # axis=1 is the width axis of an (H, W, C) patch. The former axis=2
        # reversed the colour channels, which are identical copies of the
        # grayscale slice, so the "flip" had no effect.
        if self.hor_flip and np.random.choice([True, False]):
            temp_patch = np.flip(temp_patch, axis=1)

        return temp_patch

In [None]:
# Training-set generator over sagittal volumes: augmentation on, horizontal flipping off.
dg = PPPDataGen(
    'train',
    'sagittal',
    mrnet_path,
    tr_filenames_df,
    preprocess_input=ppi_irv2,
    ppp_labels_dict=PPP_LABELS,
    augmentations_dict=augmentations,
    batch_size=16,
    num_frames=1,
    num_classes=NUM_CLASSES,
    hor_flip=False,
    data_aug=True,
)

In [None]:
# Draw one batch: ppp_imgs is a list of per-slot patch arrays, ppp_labs the one-hot labels.
ppp_imgs, ppp_labs = dg[1]
print(ppp_imgs[0].shape)
print(ppp_labs.shape)

In [None]:
# One-hot permutation labels of the sample batch; ppp_labs.argmax(axis=1) gives the class ids.
ppp_labs

In [None]:
# Empirical check of the pretext-label distribution: draw N_LABEL_CHECK_BATCHES batches from `dg`
# and count how often each permutation class occurs. Labels are one-hot rows, so argmax
# recovers the class id. The previous loop called int() on a whole 1000-wide row, which raises.
N_LABEL_CHECK_BATCHES = 1000

label_ids = np.concatenate(
    [dg[i][1].argmax(axis=1) for i in range(N_LABEL_CHECK_BATCHES)]
)
label_counts = np.bincount(label_ids, minlength=dg.num_classes)
samples = {int(class_id): int(count) for class_id, count in enumerate(label_counts) if count > 0}

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(list(samples.keys()), list(samples.values()))
ax.set(title='Sampled permutation classes', xlabel='class id', ylabel='count')
plt.show()


In [None]:
# Show a single patch: output slot PATCH_SLOT of sample SAMPLE_IDX in the batch.
# ppp_imgs[slot] is a whole batch (batch, 64, 64, 3), so one sample must be selected first;
# passing the 4-D array straight to imshow raises "Invalid shape".
# (x + 1) / 2 undoes the [-1, 1] scaling of the InceptionResNetV2 preprocessing for display.
PATCH_SLOT = 3
SAMPLE_IDX = 0

fig, ax = plt.subplots()
ax.imshow((ppp_imgs[PATCH_SLOT][SAMPLE_IDX] + 1) / 2.0, cmap='gray')
ax.set_title('slot %d, sample %d' % (PATCH_SLOT, SAMPLE_IDX))
# Hide ticks and tick labels on both axes.
ax.tick_params(axis='both', which='both',
               bottom=False, top=False, left=False, right=False,
               labelbottom=False, labelleft=False);

In [None]:
# 4 x 4 grid of output slot 0 for the first 16 samples of the batch (row-major order).
fig, axs = plt.subplots(4, 4, figsize=(80, 80))
for sample_idx, ax in enumerate(axs.flat):
    ax.imshow((ppp_imgs[0][sample_idx] + 1) / 2.0, cmap='gray')
plt.show()

## Sagittal

In [None]:
# Settings for the sagittal model section below.
# NOTE(review): `rate` is not used in the visible part of the notebook; presumably a dropout rate, confirm where it is consumed.
rate = 0.55
NUM_CLASSES = 1000  # one output class per selected jigsaw permutation (len(PPP_LABELS))

## Manual Model

In [None]:
def _rocket_conv(x, filters, strides=1):
    """Apply a 3x3 'same'-padded ReLU Conv2D (He-normal seed 16 init, L2 1e-4 kernel penalty) to `x`."""
    return Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same', activation='relu',
                  kernel_initializer=tf.keras.initializers.he_normal(seed=16),
                  kernel_regularizer=regularizers.l2(0.0001))(x)


def _rocket_branch(input_shape):
    """Build one input branch: Input -> conv256 x2 -> 2x2 max-pool -> conv512 x2 -> 2x2 max-pool.

    Returns:
        (input_tensor, output_tensor) of the branch.
    """
    branch_inp = Input(shape=input_shape)
    x = _rocket_conv(branch_inp, 256)
    x = _rocket_conv(x, 256)
    x = MaxPool2D(pool_size=(2, 2), strides=2)(x)
    x = _rocket_conv(x, 512)
    x = _rocket_conv(x, 512)
    x = MaxPool2D(pool_size=(2, 2), strides=2)(x)
    return branch_inp, x


def rocket_model(input_shape = (64,64,3), num_classes=1000):
    """Build the nine-input pretext network.

    Nine identical branches (separate weights, one per input) are concatenated
    along channels. A 2048-filter conv stem follows, which splits into a
    strided-conv path and a max-pool path. The two paths are re-concatenated,
    globally average-pooled, and fed through two 1024-unit ReLU Dense layers and
    a `num_classes`-way softmax.

    Layers are created in exactly the same order as the original hand-unrolled
    version, so Keras's auto-generated layer names (e.g. 'max_pooling2d_1', which
    later cells may look up by name) are unchanged.

    Args:
        input_shape: (H, W, C) shape of each of the nine inputs. Keep H and W
            multiples of 8 (e.g. 64, 256). Otherwise the strided-conv and max-pool
            stem paths can produce different spatial sizes and the second
            Concatenate fails.
        num_classes: width of the softmax output (default 1000, as before).

    Returns:
        An uncompiled `Model` named 'rocket_model' with nine inputs.
    """
    # Build the branches one after another (Input, convs, pools) to keep layer numbering.
    branches = [_rocket_branch(input_shape) for _ in range(9)]
    branch_inputs = [inp for inp, _ in branches]
    branch_outputs = [out for _, out in branches]

    model_stem = Concatenate()(branch_outputs)
    model_stem = _rocket_conv(model_stem, 2048)

    # Two downsampling paths over the same stem: learned (stride-2 conv) and fixed (max-pool).
    model_stem1 = _rocket_conv(model_stem, 1024)
    model_stem1 = _rocket_conv(model_stem1, 1024, strides=2)

    model_stem2 = _rocket_conv(model_stem, 1024)
    model_stem2 = MaxPool2D(pool_size=(2, 2), strides=2)(model_stem2)

    model_stem12 = Concatenate()([model_stem1, model_stem2])
    model_stem12 = GlobalAveragePooling2D()(model_stem12)

    model_stem12 = Dense(1024, activation='relu', kernel_initializer=tf.keras.initializers.he_normal(seed=16),
                         kernel_regularizer=regularizers.l2(0.0001))(model_stem12)
    model_stem12 = Dense(1024, activation='relu', kernel_initializer=tf.keras.initializers.he_normal(seed=16),
                         kernel_regularizer=regularizers.l2(0.0001))(model_stem12)
    output = Dense(num_classes, activation='softmax')(model_stem12)

    return Model(inputs=branch_inputs, outputs=output, name='rocket_model')


In [None]:
# Pretext network with nine 64x64x3 inputs.
pretext_model = rocket_model(input_shape=(64, 64, 3))

In [None]:
# Layer-by-layer summary of the pretext network built above.
pretext_model.summary()

In [None]:
# Render the built network to a PNG. plot_model needs a Model instance; the
# original passed the `rocket_model` builder function itself.
plot_model(pretext_model, to_file='rocket_model.png', show_shapes=True)

### Optimizer and data generators

In [None]:
# RMSprop with a staircase exponential decay:
#   lr(step) = 1e-4 * 0.95 ** floor(step / 1130)
# NOTE(review): 1130 matches the training-set size hard-coded in the oversampling cell — confirm it is meant as one decay per 1130 steps.
pretext_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0001,
    decay_steps=1130,
    decay_rate=0.95,
    staircase=True,
)
optimizer = tf.keras.optimizers.RMSprop(pretext_lr_schedule)

In [None]:
# Pretext training generator: batches of 32, augmentation on.
tdg = PPPDataGen(
    'train', 'sagittal', mrnet_path, tr_filenames_df,
    preprocess_input=ppi_irv2,
    ppp_labels_dict=PPP_LABELS,
    augmentations_dict=augmentations,
    batch_size=32,
    num_frames=1,
    num_classes=NUM_CLASSES,
    hor_flip=False,
    data_aug=True,
)

In [None]:
# Pretext validation generator.
# NOTE(review): data_aug=True here means validation batches are augmented too
# (the downstream validation generator uses data_aug=False) — confirm intended.
vdg = PPPDataGen(
    'valid', 'sagittal', mrnet_path, val_filenames_df,
    preprocess_input=ppi_irv2,
    ppp_labels_dict=PPP_LABELS,
    augmentations_dict=augmentations,
    batch_size=32,
    num_frames=1,
    num_classes=NUM_CLASSES,
    hor_flip=False,
    data_aug=True,
)

In [None]:
# Categorical cross-entropy over the softmax head, tracking categorical accuracy
# (metrics given as a list, the documented form for Model.compile).
pretext_model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

In [None]:
# Train the pretext network. Keep the History object for later inspection,
# as the downstream training cell does with `dshist`.
pretext_hist = pretext_model.fit(
    tdg,
    epochs=5,
    callbacks=get_callbacks('pretext', 'categorical'),
    validation_data=vdg,
)

## Downstream: ACL classification from the pretext features

In [None]:
# Start from a fresh Keras session before rebuilding. Keras numbers auto-generated
# layer names with a session-global counter, so building rocket_model a second
# time in the same session would name its pools 'max_pooling2d_19', ... and any
# get_layer('max_pooling2d_1')-style lookup would fail.
K.clear_session()
# Rebuild the pretext network for full 256x256 slices and load the pretext weights.
# Conv kernels do not depend on spatial size, and global average pooling fixes the
# Dense input width, so the weights trained at 64x64 still fit.
pretext_model = rocket_model((256, 256, 3))
pretext_model.load_weights('/saved_models/sagittal_pretext_best_model.h5')

In [None]:
# List every layer name with its output shape, to choose which features to reuse below.
for layer in pretext_model.layers:
    print(layer.name, layer.output_shape)

In [None]:
# The nine branch outputs (the second MaxPool of each input branch) are exactly
# the inputs of the 9-way Concatenate that opens the stem of rocket_model.
# Reading them from that layer, instead of via auto-generated names such as
# 'max_pooling2d_1', keeps this cell correct even if the model was built more
# than once in the session (which shifts the auto-numbering).
branch_concat = next(
    layer for layer in pretext_model.layers
    if isinstance(layer, Concatenate)
    and isinstance(layer.input, list)
    and len(layer.input) == 9
)
(pretext_out1, pretext_out2, pretext_out3,
 pretext_out4, pretext_out5, pretext_out6,
 pretext_out7, pretext_out8, pretext_out9) = branch_concat.input

In [None]:
# Stack the nine branch feature maps along axis 0 (the batch axis), so every
# frame fed to any input becomes its own row for the shared head below.
branch_feature_maps = [
    pretext_out1, pretext_out2, pretext_out3,
    pretext_out4, pretext_out5, pretext_out6,
    pretext_out7, pretext_out8, pretext_out9,
]
pretext_out = Concatenate(axis=0)(branch_feature_maps)

In [None]:
# Inspect the stacked branch features (frames from all nine inputs along axis 0).
pretext_out.shape

In [None]:
def _disc_conv(x, filters, strides):
    """Apply a 3x3 'same'-padded ReLU Conv2D (He-normal seed 16 init, L2 1e-4 kernel penalty) to `x`."""
    conv = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same', activation='relu',
                  kernel_initializer=tf.keras.initializers.he_normal(seed=16),
                  kernel_regularizer=regularizers.l2(0.0001))
    return conv(x)


# Three stages, each a stride-1 conv followed by a stride-2 conv that halves H and W.
disc1 = _disc_conv(pretext_out, 512, 1)
disc1 = _disc_conv(disc1, 512, 2)

disc2 = _disc_conv(disc1, 1024, 1)
disc2 = _disc_conv(disc2, 1024, 2)

disc3 = _disc_conv(disc2, 1024, 1)
disc3 = _disc_conv(disc3, 1024, 2)


In [None]:
# Head: global-average-pool each frame, then take the max over axis 0 (every
# frame in the batch) with keepdims. A batch holding one study therefore yields
# a single pooled vector, which passes through a 1024-unit ReLU layer and one
# sigmoid unit.
gap = GlobalAveragePooling2D()(disc3)

maxoverframes = tf.keras.layers.Lambda(
    lambda frame_feats: tf.keras.backend.max(frame_feats, axis=0, keepdims=True)
)(gap)

fc1 = Dense(
    1024,
    activation='relu',
    kernel_initializer=tf.keras.initializers.he_normal(seed=16),
    kernel_regularizer=regularizers.l2(0.0001),
)(maxoverframes)

out = Dense(
    1,
    activation='sigmoid',
    kernel_initializer=tf.keras.initializers.he_normal(seed=16),
    kernel_regularizer=regularizers.l2(0.0001),
)(fc1)

In [None]:
# Downstream model: the nine pretext inputs feeding the new binary head.
dsmodel = Model(inputs=pretext_model.input, outputs=out)

In [None]:
# Layer-by-layer summary of the downstream model (pretext trunk plus new head).
dsmodel.summary()

## Oversampling the minority ACL class

In [None]:
# Balance the training tables for the ACL task. Rows of the minority class are
# resampled with replacement and appended until both classes have equal counts.
acl_labels = tr_multilabel['acl'].to_numpy()
NUM_1 = np.count_nonzero(acl_labels == 1)
NUM_0 = np.count_nonzero(acl_labels == 0)
min_class = int(np.argmin(np.array([NUM_0, NUM_1])))

# Extra minority rows needed. This was previously derived from a hard-coded
# training-set size of 1130; the class counts give it directly.
gapnum = abs(NUM_1 - NUM_0)

# Sample row *positions*, not index labels. The rows are selected with .iloc
# below, and the same positions must pick the matching rows of tr_filenames_df.
minority_positions = np.flatnonzero(acl_labels == min_class)
INDICES = np.random.choice(minority_positions, gapnum)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; pd.concat replaces it.
tr_acl_multilabel = pd.concat([tr_multilabel, tr_multilabel.iloc[INDICES, :]], ignore_index=True)
tr_acl_filenames_df = pd.concat([tr_filenames_df, tr_filenames_df.iloc[INDICES, :]], ignore_index=True)

# The label and filename tables must stay row-aligned, and the classes must now be balanced.
assert len(tr_acl_multilabel) == len(tr_acl_filenames_df)
assert (tr_acl_multilabel['acl'] == 0).sum() == (tr_acl_multilabel['acl'] == 1).sum()

print(tr_acl_multilabel)
print(tr_acl_filenames_df)

## Downstream Data Generator

In [None]:
class DSDataGen(Sequence):
    """Keras ``Sequence`` yielding one sagittal MRNet study per item for the downstream model.

    ``__getitem__(idx)`` returns ``(inputs, labels)``:

    * ``inputs`` -- a list of nine float arrays, each ``(numf, H, W, 3)``. They hold
      consecutive, equal-sized groups of the frames sampled from the study
      (grayscale repeated to 3 channels), optionally augmented and then passed
      through ``preprocess_input``.
    * ``labels`` -- a ``(1, 1)`` array with the study's ``injury`` label.

    Args:
        phase: sub-directory of ``base_dir`` ('train', 'valid', ...).
        base_dir: dataset root; volumes are read from ``base_dir/phase/<mode>/<filename>``.
        labs_df: label table. ``labs_df[injury]`` is indexed positionally, so it must
            be row-aligned with ``filenames_df``.
        filenames_df: table with a ``'filename'`` column naming files loaded with ``np.load``.
        injury: label column to predict.
        preprocess_input: callable applied to the stacked frames (must be provided).
        batch_size: only affects ``__len__``. Every item is a single study, because
            the downstream head takes a max over axis 0 (all frames in the batch).
        max_batch_size: unused; kept for backward compatibility.
        data_aug: if True, apply one random rotation/scale/shift to every frame of a study.
        num_frames: requested frames per study (capped per study at its slice count).
        num_classes: stored on the instance; not otherwise used here.
    """

    def __init__(self, phase, base_dir, labs_df, filenames_df, injury, preprocess_input = None,batch_size=8, max_batch_size = 32, data_aug = True, num_frames = NUM_FRAMES, num_classes=NUM_CLASSES):
        self.base_dir = base_dir
        self.ph_mode_dir = base_dir + '/' + phase
        self.filenames = os.listdir(self.ph_mode_dir)
        self.phase = phase
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.filenames_df = filenames_df
        self.preprocess_input = preprocess_input

        self.mode = ['sagittal', 'coronal', 'axial']
        self.injury = injury

        self.mllabs = labs_df

        # Positions into filenames_df / labs_df; reshuffled by on_epoch_end.
        self.indices = list(range(self.filenames_df.shape[0]))

        self.data_aug = data_aug

    def get_random_shuffle_order(self, batch_sz):
        """Return a shuffled list of ``range(batch_sz)``."""
        blist = list(range(batch_sz))
        random.shuffle(blist)
        return blist

    def load_volume(self, mode, file_idx):
        """Load the volume for row ``file_idx`` (wrapped modulo the table length) of ``filenames_df``."""
        file_pool_len = self.filenames_df.shape[0]
        file_idx = file_idx % file_pool_len
        return np.load(self.ph_mode_dir + '/' + mode + '/' + self.filenames_df['filename'].iloc[file_idx])

    def get_frames(self, mode, idx):
        """Sample frames from study ``self.indices[idx]`` and return a float64 ``(n, H, W, 3)`` array.

        With ``n_req = min(self.num_frames, slice_count)``, ``nf_mid = n_req // 2`` and
        ``nf_lr = nf_mid // 2``, the frames are, in order:

        * ``nf_lr`` random slices from the first quarter of the volume,
        * ``nf_mid`` consecutive slices starting where that quarter ends,
        * ``nf_lr`` random slices from the last quarter.

        Each group is sorted.
        """
        image_volume = self.load_volume(mode, self.indices[idx])
        tot_frames = image_volume.shape[0]
        # Local cap only. The original overwrote self.num_frames here, so a single
        # short volume permanently shrank the clip length for every later study.
        num_frames = min(self.num_frames, tot_frames)

        nf_mid = num_frames // 2
        nf_lr = nf_mid // 2
        left_sec_end = tot_frames // 2 - tot_frames // 4
        right_sec_start = tot_frames // 2 + tot_frames // 4

        left_frames = sorted(random.sample(range(left_sec_end), nf_lr))
        right_frames = sorted(random.sample(range(right_sec_start, tot_frames), nf_lr))
        # Sampling all nf_mid items of a length-nf_mid range and sorting is just the range.
        mid_frames = list(range(left_sec_end, left_sec_end + nf_mid))
        frame_idxs = np.array(left_frames + mid_frames + right_frames, dtype=np.intp)

        # Gather all slices at once (no quadratic np.append loop), then repeat the
        # single grayscale channel three times to match the 3-channel model inputs.
        frames = image_volume[frame_idxs].astype(np.float64)
        return np.repeat(frames[..., np.newaxis], 3, axis=-1)

    def __len__(self):
        """Number of items per epoch: ``floor(len(filenames_df) / batch_size)``."""
        return int(np.floor(len(self.filenames_df) / self.batch_size))

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for study ``self.indices[idx]`` (see class docstring).

        Raises:
            ValueError: if fewer than 9 frames were sampled, so the nine inputs would be empty.
        """
        # One study per item. The original appended the same study batch_size
        # times; the nine-way split below never meaningfully used the extra copies.
        sagittal_imgs = self.get_frames('sagittal', idx)
        ds_batch_labs = np.asarray(self.mllabs[self.injury].iloc[self.indices[idx]]).reshape((1, -1))

        if self.data_aug:
            sagittal_imgs = self.__augment(sagittal_imgs)

        sagittal_imgs = self.preprocess_input(sagittal_imgs)

        # Split into nine equal, contiguous groups, one per model input. The size
        # comes from the frames actually sampled, so every group is full;
        # remainder frames (count % 9) are dropped.
        numf = sagittal_imgs.shape[0] // 9
        if numf == 0:
            raise ValueError(
                f'study at position {self.indices[idx]} yielded {sagittal_imgs.shape[0]} frames; '
                'at least 9 are needed for the nine model inputs')
        inputs = [sagittal_imgs[i * numf:(i + 1) * numf] for i in range(9)]

        return inputs, ds_batch_labs

    def on_epoch_end(self):
        """Shuffle the study order in place between epochs."""
        random.shuffle(self.indices)

    def __augment(self, batch_imgs):
        """Apply one random rotation, scale and shift to every image of ``batch_imgs`` (modified in place).

        Rotation is -30/0/+30 degrees, scale is 1 or 1.2, and the x/y shift is -25/0/+25 pixels each.
        """
        rotang = np.random.choice([-30, 0, 30])
        scale = np.random.choice([1, 1.2])
        transformation_matrix = np.array([[scale, 0,     np.random.choice([-25, 0, 25])],
                                          [0,     scale, np.random.choice([-25, 0, 25])],
                                          [0,     0,     1]])
        # The same transform is used for every frame; build the inverse map once.
        inverse_map = AffineTransform(matrix=transformation_matrix).inverse
        for i in range(batch_imgs.shape[0]):
            batch_imgs[i] = rotate(batch_imgs[i], rotang, preserve_range=True)
            batch_imgs[i] = warp(batch_imgs[i], inverse_map, preserve_range=True)
        return batch_imgs

In [None]:
# Downstream training generator on the oversampled ACL tables. Augmentation is on;
# up to 36 frames per study are split across the nine model inputs.
tdg = DSDataGen(
    'train', mrnet_path, tr_acl_multilabel, tr_acl_filenames_df, 'acl',
    preprocess_input=ppi_irv2,
    batch_size=1,
    data_aug=True,
    num_frames=36,
    num_classes=1,
)

In [None]:
# Downstream validation generator: original (not oversampled) validation tables, no augmentation.
vdg = DSDataGen(
    'valid', mrnet_path, val_multilabel, val_filenames_df, 'acl',
    preprocess_input=ppi_irv2,
    batch_size=1,
    data_aug=False,
    num_frames=36,
    num_classes=1,
)

In [None]:
# Small fixed learning rate for fine-tuning the downstream model.
dsopt = tf.keras.optimizers.Adam(learning_rate=1e-5)

In [None]:
# Compile with the Adam(1e-5) optimizer defined in the previous cell. The original
# passed `dsmodel.optimizer`, but dsmodel had not been compiled yet and so had no
# configured optimizer, which left `dsopt` unused.
dsmodel.compile(
    optimizer=dsopt,
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()],
)

In [None]:
# Fine-tune the downstream model for 30 epochs, validating on vdg each epoch.
dshist = dsmodel.fit(
    tdg,
    epochs=30,
    validation_data=vdg,
    callbacks=get_callbacks('downstream', 'binary'),
)

In [None]:
# Reload the best downstream checkpoint (presumably written by the callbacks from get_callbacks('downstream', ...)).
dsmodel = tf.keras.models.load_model(
    '/saved_models/sagittal_downstream_best_model.h5',
    compile=True,
)

In [None]:
# Empty float64 (0, 1) accumulators for per-study predictions and labels.
preds = np.empty((0, 1))
labs = np.empty((0, 1))

In [None]:
# Score every validation study (vdg yields one study per item).
# Per-study arrays go into lists and are concatenated once. This avoids repeated
# np.append copies and makes the cell safe to re-run: it no longer appends to
# preds/labs left over from a previous run.
pred_batches = [np.empty((0, 1))]
lab_batches = [np.empty((0, 1))]
for i in range(len(val_filenames_df)):
    inps, lab = vdg[i]
    # Inference mode made explicit (the model has no dropout/batch-norm, so the result is the same).
    pred = dsmodel(inps, training=False)
    pred_batches.append(pred.numpy().reshape((-1, 1)))
    lab_batches.append(np.asarray(lab, dtype=np.float64).reshape((-1, 1)))
preds = np.concatenate(pred_batches, axis=0)
labs = np.concatenate(lab_batches, axis=0)

In [None]:
# One row per scored validation study, one column (the sigmoid output).
preds.shape

In [None]:
# Validation metrics for the ACL task from the collected labels and predictions.
perf_df = get_performance_metrics(labs, preds, ['acl'])

In [None]:
# Display the ACL validation metrics table computed above.
perf_df