# SETUP

The first part of the notebook sets up the PCA-wavelet network; the training comes later. Most of this code comes from the original authors.

In [1]:
import os
import sys
import time

import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule

# Project-local helpers live outside this folder; extend the path before importing them.
sys.path.append('../segmentation_helper')
import data_loader as dl
import model_broker as mb

# Make float64 the default Keras float dtype, so Keras layers/variables created after this line default to double precision.
tf.keras.backend.set_floatx("float64")

GPU device not found
Found GPU at: 


In [2]:
def scaledtanh(x):
    """Squash x into (-1, 1) with a half-slope tanh, i.e. tanh(x / 2)."""
    halved = x*0.5
    return tf.math.tanh(halved)

def scaledatanh(x):
    """Inverse of scaledtanh: maps (-1, 1) back onto the real line as 2 * atanh(x)."""
    unscaled = tf.math.atanh(x)
    return unscaled*2.0


In [3]:
def conv_calculate_A_and_b(imghead, seghead, img_train,seg_train):
    """Fit a per-pixel affine map (equivalent to a 1x1 convolution) between two decompositions.

    Every spatial position of every (image, segmentation) pair is one regression sample:
    x is the img_channels-vector of imghead(image) at that position and y is the
    seg_channels-vector of seghead(segmentation) at the same position. The least-squares
    problem  y ~= A^T x + b  is solved as

        A = pinv(Cxx) @ Cxy,    b = mean(y) - A^T mean(x)

    where Cxx and Cxy are the centred scatter matrices accumulated over all samples.

    The code reads channels from axis 3 and spatial size from axes 1-2, so imghead and
    seghead are assumed to return 4-D (batch, height, width, channels) outputs that
    share the same spatial size.

    Args:
        imghead: image decomposition model.
        seghead: segmentation decomposition model.
        img_train: dataset of images; element [0] is fed to imghead.
        seg_train: dataset of segmentations paired element-wise with img_train.

    Returns:
        A: (img_channels, seg_channels) tensor. Reshape to [1, 1, img_channels, seg_channels]
           to use it as a tf.nn.conv2d filter.
        b: (seg_channels,) tensor, suitable for tf.nn.bias_add.

    Raises:
        ValueError: if the decompositions differ in spatial size or the datasets are empty.
    """
    # Probe one element of each dataset to learn the decomposition shapes.
    imdecom_shape = imghead(next(iter(img_train))[0]).shape
    segdecom_shape = seghead(next(iter(seg_train))[0]).shape
    img_channels = imdecom_shape[3]
    seg_channels = segdecom_shape[3]
    imdecom_2d_shape = imdecom_shape[1]*imdecom_shape[2]
    segdecom_2d_shape = segdecom_shape[1]*segdecom_shape[2]
    # Row i of the image matrix must correspond to row i of the segmentation matrix.
    if imdecom_2d_shape != segdecom_2d_shape:
        raise ValueError(
            "conv_calculate_A_and_b: image decomposition has %s spatial positions but "
            "segmentation decomposition has %s" % (imdecom_2d_shape, segdecom_2d_shape))

    xxt = np.zeros([img_channels, img_channels])   # sum of x x^T
    yxt = np.zeros([img_channels, seg_channels])   # sum of x y^T
    x_m = np.zeros([img_channels])                 # sum of x
    y_m = np.zeros([seg_channels])                 # sum of y
    n_samples = 0                                  # number of pixel samples seen

    bar = tqdm.notebook.tqdm(total = int(img_train.cardinality()))
    try:
        for img_item, seg_item in zip(img_train, seg_train):
            bar.update(1)
            # One row per spatial position, one column per channel.
            mat = tf.reshape(imghead(img_item[0]), [-1, img_channels]).numpy()
            segmat = tf.reshape(seghead(seg_item[0]), [-1, seg_channels]).numpy()

            xxt += mat.T @ mat
            yxt += mat.T @ segmat
            x_m += mat.sum(axis=0)
            y_m += segmat.sum(axis=0)
            n_samples += mat.shape[0]
    finally:
        bar.close()

    if n_samples == 0:
        raise ValueError("conv_calculate_A_and_b: no training samples (empty dataset)")

    # Centre the scatter matrices: sum(x x^T) - (sum x)(sum x)^T / N, likewise for x y^T.
    xxt = xxt - np.outer(x_m, x_m) / n_samples
    yxt = yxt - np.outer(x_m, y_m) / n_samples
    A = tf.linalg.matmul(tf.linalg.pinv(xxt), yxt)
    b = (y_m - tf.linalg.matvec(A, x_m, transpose_a=True)) / n_samples
    return A,b

In [4]:
def connected_calculate_A_and_b(imghead, seghead, img_train,seg_train):
    """Fit a dense affine map from the flattened image decomposition to the flattened segmentation one.

    Each (image, segmentation) pair is one regression sample: x is the flattened
    imghead(image) and y is the flattened seghead(segmentation). The least-squares
    problem  y ~= A^T x + b  is solved as

        A = pinv(Cxx) @ Cxy,    b = mean(y) - A^T mean(x)

    where Cxx and Cxy are the centred scatter matrices accumulated over the dataset.

    Args:
        imghead: image decomposition model.
        seghead: segmentation decomposition model.
        img_train: dataset of images; element [0] is fed to imghead.
        seg_train: dataset of segmentations paired element-wise with img_train.

    Returns:
        A: (imgflat, segflat) tensor.
        b: (segflat,) tensor.

    Raises:
        ValueError: if the datasets are empty.
    """
    imgflat = int(np.prod(imghead(next(iter(img_train))[0]).shape))
    segflat = int(np.prod(seghead(next(iter(seg_train))[0]).shape))
    n = 0.0

    # The scatter accumulators need full matrix shapes. A 1-D zeros vector only "worked"
    # by being broadcast into a TF tensor on the first `+=`.
    xxt = np.zeros([imgflat, imgflat])   # sum of x x^T
    yxt = np.zeros([imgflat, segflat])   # sum of x y^T
    x = np.zeros([imgflat])              # sum of x
    y = np.zeros([segflat])              # sum of y

    bar = tqdm.notebook.tqdm(total = int(img_train.cardinality()))
    try:
        for img_item, seg_item in zip(img_train, seg_train):
            bar.update(1)

            mat = tf.reshape(imghead(img_item[0]), [-1]).numpy()
            segmat = tf.reshape(seghead(seg_item[0]), [-1]).numpy()

            xxt += np.outer(mat, mat)
            yxt += np.outer(mat, segmat)
            x += mat
            y += segmat
            n += 1
    finally:
        bar.close()

    if n == 0:
        raise ValueError("connected_calculate_A_and_b: no training samples (empty dataset)")

    print("loop calculated")
    # Centre: sum(x x^T) - (sum x)(sum x)^T / n, and likewise for x y^T.
    xxt = xxt - np.outer(x, x) / n
    yxt = yxt - np.outer(x, y) / n
    print("calculating inverse")
    inverse_xxt = tf.linalg.pinv(xxt)
    print("calculating A")
    A = tf.linalg.matmul(inverse_xxt, yxt)
    print("calculating b")
    # (sum y - A^T sum x) / n == mean(y) - A^T mean(x)
    b = (y - tf.linalg.matvec(A, x, transpose_a=True)) / n
    return A,b

In [5]:
def build_model_instance(train,
                test,
                dataset,
                model_name,
                keep_percent=1.0,
                count=3,
                sample_size=100,
                activity_regularizer=None,
                inverse_activity_regularizer=None,
                activation_before=False,
                check_build=False):
    """Create a ModelBroker for one decomposition model and return its head pair.

    The broker gets the given hyper-parameters and dirname ``<dataset>_<model_name>``.
    build_model() is called first and load_model() second; the returned head and
    inverse head come from load_model().
    NOTE(review): build_model()'s own return value is discarded (only unpacked). Confirm
    this is intended, e.g. that build saves the model and load re-reads it.

    Returns:
        (head, invhead, stats). stats is None unless check_build is truthy, in which case
        it is (train_psnr, train_ncc, test_psnr, test_ncc) from
        broker.check_build(..., stats_only=True) on the train and test sets.
    """
    broker_config = {
        "trainset": train,
        "testset": test,
        "dirname": dataset + "_" + model_name,
        "keep_percent": keep_percent,
        "count": count,
        "sample_size": sample_size,
        "activity_regularizer": activity_regularizer,
        "inverse_activity_regularizer": inverse_activity_regularizer,
        "activation_before": activation_before,
    }
    broker = mb.ModelBroker(**broker_config)

    _built_head, _built_invhead = broker.build_model()
    head, invhead = broker.load_model()
    if not check_build:
        return head, invhead, None

    train_psnr, train_ncc = broker.check_build(head, invhead, train, stats_only=True)
    test_psnr, test_ncc = broker.check_build(head, invhead, test, stats_only=True)
    return head, invhead, (train_psnr, train_ncc, test_psnr, test_ncc)

In [16]:
def conv_metric_calculate(img_ds,seg_ds):
    """Per-sample Dice/IoU for the convolutional (per-pixel) linear inverse on a paired dataset.

    Each image is decomposed with imghead, mapped channel-wise through the fitted (A, b)
    as a 1x1 "VALID" convolution plus bias, and reconstructed with seginvhead.
    - Ground-truth foreground: pixels where the minimum of seg_base[0] over axis 2 is exactly 0.
    - Predicted foreground: pixels where the minimum of the reconstruction over axis 2 is below
      threshold_intensity.
    Axis 2 is presumably the channel axis of an (H, W, C) segmentation; TODO confirm.

    Relies on notebook globals imghead, seginvhead, A, b, dice_coef and iou_coef.
    A is expected to be either an (in_channels, out_channels) channel map, or an
    already-shaped [1, 1, in_channels, out_channels] conv2d filter.

    Args:
        img_ds: dataset of images (element [0] is fed to imghead).
        seg_ds: dataset of ground-truth segmentations paired element-wise with img_ds.

    Returns:
        (iou_coeff_vals, dice_coeff_vals, n): per-sample score lists and the sample count.
    """
    threshold_intensity = 0.01
    dice_coeff_vals = []
    iou_coeff_vals = []
    n = 0
    # tf.nn.conv2d needs a [filter_h, filter_w, in_channels, out_channels] kernel.
    A_filter = A if len(A.shape) == 4 else tf.reshape(A, [1, 1, A.shape[0], A.shape[1]])
    # Iterate the datasets that were passed in (previously the train globals were used).
    for image, seg_base in zip(img_ds, seg_ds):
        imgdecom = imghead(image[0])
        conv = tf.nn.conv2d(imgdecom, A_filter, 1, "VALID")
        conv = tf.nn.bias_add(conv, b)
        seg = seginvhead(conv)
        y_true = tf.cast(tf.reduce_min(seg_base[0], 2) == 0, tf.float64)
        y_pred = tf.cast(tf.reduce_min(seg[0], 2) < threshold_intensity, tf.float64)
        dice_coeff_vals.append(dice_coef(y_true, y_pred))
        iou_coeff_vals.append(iou_coef(y_true, y_pred))
        n += 1
    return iou_coeff_vals,dice_coeff_vals,n


In [17]:
def connected_metric_calculate(img_ds,seg_ds):
    """Per-sample Dice/IoU for the dense (fully connected) linear inverse on a paired dataset.

    Each image decomposition is flattened to a (1, -1) row, mapped through the global
    (A, b) as A^T x + b, reshaped to seghead's output shape, and reconstructed with
    seginvhead.
    - Ground-truth foreground: pixels where the minimum of seg_base[0] over axis 2 is exactly 0.
    - Predicted foreground: pixels where the minimum of the reconstruction over axis 2 is below
      the threshold.

    Relies on notebook globals imghead, seghead, seginvhead, A, b, dice_coef and iou_coef.

    Returns:
        (iou_coeff_vals, dice_coeff_vals, n): per-sample score lists and the sample count.
    """
    threshold_intensity = 0.01
    # Shape the predicted decomposition must be reshaped back into before inversion.
    seg_decom_shape = seghead(next(iter(seg_ds))[0]).shape
    dice_scores = []
    iou_scores = []
    for image, seg_base in zip(img_ds, seg_ds):
        flat_decom = tf.reshape(imghead(image[0]), (1, -1))
        predicted_decom = tf.linalg.matvec(A, flat_decom, transpose_a=True) + b
        seg = seginvhead(tf.reshape(predicted_decom, seg_decom_shape))
        y_true = tf.cast(tf.reduce_min(seg_base[0], 2) == 0, tf.float64)
        y_pred = tf.cast(tf.reduce_min(seg[0], 2) < threshold_intensity, tf.float64)
        dice_scores.append(dice_coef(y_true, y_pred))
        iou_scores.append(iou_coef(y_true, y_pred))
    return iou_scores, dice_scores, len(dice_scores)

In [18]:
def calculate_metrics(seg_ds,img_ds,imghead,seghead,seginvhead, method):
    """Summarise per-sample Dice/IoU scores produced by a metric routine.

    Args:
        seg_ds: segmentation dataset, forwarded to `method`.
        img_ds: image dataset paired with seg_ds, forwarded to `method`.
        imghead, seghead, seginvhead: accepted for interface compatibility but NOT used here.
            The metric routines in this notebook read these from globals instead.
        method: callable (img_ds, seg_ds) -> (iou_vals, dice_vals, n).

    Returns:
        (dice_mean, iou_mean, dice_std, iou_std). The std values are population
        standard deviations (divided by n).

    Raises:
        ValueError: if `method` evaluated no samples (n == 0).
    """
    iou_coeff_vals, dice_coeff_vals, n = method(img_ds, seg_ds)
    if n == 0:
        raise ValueError("calculate_metrics: no samples were evaluated (empty dataset?)")

    def _mean_and_std(vals):
        # Population mean / standard deviation over the n evaluated samples.
        mean = sum(vals) / n
        std = (sum((v - mean) ** 2 for v in vals) / n) ** 0.5
        return mean, std

    dice_coeff_mean, dice_coeff_std = _mean_and_std(dice_coeff_vals)
    iou_coeff_mean, iou_coeff_std = _mean_and_std(iou_coeff_vals)
    return dice_coeff_mean, iou_coeff_mean, dice_coeff_std, iou_coeff_std

In [19]:
def dice_coef(y_true, y_pred,smooth=1):
    """Smoothed Dice coefficient of two masks, computed over all of their elements.

    dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)

    Returns:
        The coefficient as a Python float.
    """
    flat_true, flat_pred = (tf.reshape(mask, [-1]) for mask in (y_true, y_pred))
    overlap = tf.reduce_sum(flat_true * flat_pred, 0)
    numerator = 2. * overlap + smooth
    denominator = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) + smooth
    return float(numerator / denominator)

In [20]:
def iou_coef(y_true, y_pred,smooth=1):
    """Smoothed intersection-over-union of two masks, computed over all of their elements.

    iou = (sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth)

    Both masks are flattened first, matching dice_coef. The previous version reduced only
    over axis 0, which for an (H, W) mask averaged per-column IoUs. It also ignored
    `smooth` in favour of a hard-coded 1.

    Returns:
        The coefficient as a Python float.
    """
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) - intersection
    return float((intersection + smooth) / (union + smooth))

In [23]:
# Each experiment is a (setting name, value) pair consumed by the experiment loop.
counts_exp = [("count", c) for c in range(3, 5)]
keep_percents_exp = [("keep_percent", k / 10) for k in range(1, 5)]
train_sizes_exp = [("train_size", 2000 * m) for m in range(1, 4)]
# Resolution sweep currently disabled; to enable: [("res", 2**p) for p in range(6, 9)]
res_exp = []
experiments = [*counts_exp, *keep_percents_exp, *train_sizes_exp, *res_exp]

In [None]:

test_size = 300

# Standard experiment settings. Each entry of `experiments` overrides exactly ONE of these,
# so the defaults are re-applied at the start of every iteration. Otherwise a value set by
# one experiment (e.g. count=4) would silently carry over into all later experiments.
default_settings = {
    "count": 3,
    "keep_percent": 0.1,
    "train_size": None,  # forwarded as DataLoader(take=None); presumably "no limit" -- confirm
    "res": 128,
}
activity_regularizer = scaledtanh
inverse_activity_regularizer = scaledatanh
dataset = "pets"

calculate_A_and_b = connected_calculate_A_and_b
# The metric routine must match the layout of the fitted (A, b): a dense matrix for the
# connected fit, a per-pixel channel map for the convolutional fit.
if calculate_A_and_b is connected_calculate_A_and_b:
    metric_calculate = connected_metric_calculate
else:
    metric_calculate = conv_metric_calculate

records = []
df = pd.DataFrame()

for variable, value in tqdm.notebook.tqdm(experiments):
    if variable not in default_settings:
        raise ValueError(f"Unknown experiment variable {variable!r}")
    settings = {**default_settings, variable: value}
    count = settings["count"]
    keep_percent = settings["keep_percent"]
    train_size = settings["train_size"]
    res = settings["res"]

    loader = dl.DataLoader(IMAGE_SIZE=res, dataset=dataset, take=train_size)
    img_ds = loader.import_processed_img()
    seg_ds = loader.import_processed_seg()

    # The first `test_size` pairs are held out for testing; the remainder is used for training.
    img_test = img_ds.take(test_size)
    seg_test = seg_ds.take(test_size)
    img_train = img_ds.skip(test_size)
    seg_train = seg_ds.skip(test_size)

    record = {
        "count": count,
        "keep_percent": keep_percent,
        "activity_regularizer": activity_regularizer is not None,
        "training_data_size": train_size,
        "res": res,
    }

    img_train_start = time.perf_counter()
    imghead, imginvhead, stats = build_model_instance(
        img_train, img_test, dataset, "img",
        keep_percent=keep_percent, count=count,
        activity_regularizer=activity_regularizer,
        inverse_activity_regularizer=inverse_activity_regularizer,
        check_build=True)
    img_train_end = time.perf_counter()
    psnr_train, ncc_train, psnr_test, ncc_test = stats

    record["img_channel_size"] = imghead(next(iter(img_train))[0]).shape[-1]
    record["img_train_time"] = img_train_end - img_train_start
    record["train_img_psnr"] = psnr_train
    record["train_img_ncc"] = ncc_train
    record["test_img_psnr"] = psnr_test
    record["test_img_ncc"] = ncc_test

    seg_train_start = time.perf_counter()
    seghead, seginvhead, stats = build_model_instance(
        seg_train, seg_test, dataset, "seg",
        count=count, keep_percent=keep_percent,
        activity_regularizer=activity_regularizer,
        inverse_activity_regularizer=inverse_activity_regularizer,
        check_build=True)
    seg_train_end = time.perf_counter()
    psnr_train, ncc_train, psnr_test, ncc_test = stats

    record["seg_channel_size"] = seghead(next(iter(seg_train))[0]).shape[-1]
    record["seg_train_time"] = seg_train_end - seg_train_start
    record["train_seg_psnr"] = psnr_train
    record["train_seg_ncc"] = ncc_train
    record["test_seg_psnr"] = psnr_test
    record["test_seg_ncc"] = ncc_test

    # A and b are module-level globals on purpose: the metric routines read them.
    train_time_start = time.perf_counter()
    A, b = calculate_A_and_b(imghead, seghead, img_train, seg_train)
    train_time_end = time.perf_counter()
    record["linear_inverse_train_time"] = train_time_end - train_time_start

    dice_mean_test, iou_mean_test, dice_std_test, iou_std_test = calculate_metrics(
        seg_test, img_test, imghead, seghead, seginvhead, metric_calculate)
    dice_mean_train, iou_mean_train, dice_std_train, iou_std_train = calculate_metrics(
        seg_train, img_train, imghead, seghead, seginvhead, metric_calculate)

    record["dice_mean_train"] = dice_mean_train
    record["iou_mean_train"] = iou_mean_train
    record["dice_std_train"] = dice_std_train
    record["iou_std_train"] = iou_std_train
    record["dice_mean_test"] = dice_mean_test
    record["iou_mean_test"] = iou_mean_test
    record["dice_std_test"] = dice_std_test
    record["iou_std_test"] = iou_std_test

    records.append(record)
    # Rebuild after every experiment so `df` holds all completed results even if a later one fails.
    df = pd.DataFrame(records)

  0%|          | 0/9 [00:00<?, ?it/s]

  record = pd.Series()


keep_percent 0.20629283704945683
meanimg.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
Starting level 0
Completing 64.0
pca shape tf.Tensor([27 27], shape=(2,), dtype=int32)
keep_channels 5 keep_max 12.0
keep_channels 5
ufilts.shape (1, 1, 1, 27, 5)
end loop 64.0
Starting level 1
Completing 32.0
pca shape tf.Tensor([45 45], shape=(2,), dtype=int32)
keep_channels 9 keep_max 80.0
keep_channels 9
ufilts.shape (1, 1, 1, 45, 9)
end loop 32.0
Starting level 2
Completing 16.0
pca shape tf.Tensor([81 81], shape=(2,), dtype=int32)
keep_channels 16 keep_max 576.0
keep_channels 16
ufilts.shape (1, 1, 1, 81, 16)
end loop 16.0
saving to: models/pets_img
out.shape (1, 16, 16, 16)
keep_percent 0.20629283704945683
meanimg.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
Starting level 0
Completing 64.0
pca shape tf.Tensor([27 27], shape=(2,), dtype=int32)
keep_channels 5 keep_max 12.0
keep_channels 5
ufilts.

array([[[0.75257689, 0.6329248 , 0.51152802],
        [0.80355394, 0.67678201, 0.53089267],
        [0.8363238 , 0.71473986, 0.5422368 ],
        ...,
        [0.7453891 , 0.68484151, 0.59893322],
        [0.79147881, 0.66086167, 0.5059247 ],
        [0.87138146, 0.75016683, 0.58671206]],

       [[0.79103935, 0.6632157 , 0.54447073],
        [0.80558962, 0.67709655, 0.52565682],
        [0.7719363 , 0.65932906, 0.51291245],
        ...,
        [0.71955758, 0.63062316, 0.53655267],
        [0.8280946 , 0.68232256, 0.54134595],
        [0.91678393, 0.78020835, 0.64809018]],

       [[0.76215845, 0.64369255, 0.49240434],
        [0.83149129, 0.71384424, 0.56482464],
        [0.8471536 , 0.73336279, 0.56863344],
        ...,
        [0.58904696, 0.43896008, 0.2372956 ],
        [0.83756989, 0.70129973, 0.52172375],
        [0.90942645, 0.77951616, 0.6313836 ]],

       ...,

       [[0.67408615, 0.57938206, 0.36857191],
        [0.47982442, 0.31608072, 0.22587055],
        [0.55301565, 0

array([[[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        ...,
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569]],

       [[0.03921569, 0.03921569, 0.03921569],
        [0.23937032, 0.4215574 , 0.18896006],
        [0.22728512, 0.42014661, 0.19209871],
        ...,
        [0.24852918, 0.45332271, 0.2013016 ],
        [0.25310513, 0.43315765, 0.21225035],
        [0.03921569, 0.03921569, 0.03921569]],

       [[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        ...,
        [0.03921569, 0.04090313, 0.03921569],
        [0.03921569, 0.03949095, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569]],

       ...,

       [[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0

array([[[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        ...,
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569]],

       [[0.03921569, 0.03921569, 0.03921569],
        [0.23937032, 0.4215574 , 0.18896006],
        [0.22728512, 0.42014661, 0.19209871],
        ...,
        [0.24852918, 0.45332271, 0.2013016 ],
        [0.25310513, 0.43315765, 0.21225035],
        [0.03921569, 0.03921569, 0.03921569]],

       [[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        ...,
        [0.03921569, 0.04090313, 0.03921569],
        [0.03921569, 0.03949095, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569]],

       ...,

       [[0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0.03921569, 0.03921569],
        [0.03921569, 0

keep_percent 0.20629283704945683
meanimg.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
Starting level 0
Completing 64.0
pca shape tf.Tensor([27 27], shape=(2,), dtype=int32)
keep_channels 5 keep_max 12.0
keep_channels 5
ufilts.shape (1, 1, 1, 27, 5)
end loop 64.0
Starting level 1
Completing 32.0
pca shape tf.Tensor([45 45], shape=(2,), dtype=int32)
keep_channels 9 keep_max 80.0
keep_channels 9
ufilts.shape (1, 1, 1, 45, 9)
end loop 32.0
Starting level 2
Completing 16.0
pca shape tf.Tensor([81 81], shape=(2,), dtype=int32)
keep_channels 16 keep_max 576.0
keep_channels 16
ufilts.shape (1, 1, 1, 81, 16)
end loop 16.0
saving to: models/pets_seg
out.shape (1, 16, 16, 16)
keep_percent 0.20629283704945683
meanimg.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
self.mean.dtype <dtype: 'float64'>
Starting level 0
Completing 64.0
pca shape tf.Tensor([27 27], shape=(2,), dtype=int32)
keep_channels 5 keep_max 12.0
keep_channels 5
ufilts.

array([[[0.9215619 , 0.93773502, 0.95513868],
        [0.93311363, 0.94665527, 0.96270919],
        [0.91230255, 0.92977417, 0.94153887],
        ...,
        [0.97255594, 0.97256434, 0.98040748],
        [0.97020501, 0.97627145, 0.98695356],
        [0.97112608, 0.97504765, 0.98289078]],

       [[0.92536259, 0.94020659, 0.94943345],
        [0.92122996, 0.93691623, 0.94868094],
        [0.88733751, 0.90299743, 0.91479278],
        ...,
        [0.96360296, 0.9750613 , 0.98039216],
        [0.95545346, 0.97113973, 0.9750613 ],
        [0.95495296, 0.97063923, 0.9745608 ]],

       [[0.92078739, 0.9443168 , 0.9443168 ],
        [0.9254902 , 0.94117647, 0.95294118],
        [0.9294005 , 0.94509804, 0.95686275],
        ...,
        [0.96863151, 0.97646344, 0.9725756 ],
        [0.9691636 , 0.97466302, 0.97777271],
        [0.96491343, 0.96961623, 0.97511566]],

       ...,

       [[0.81596899, 0.44522753, 0.03921569],
        [0.83071268, 0.43768814, 0.03921569],
        [0.81676269, 0

array([[[0.86446941, 0.86054784, 0.63744569],
        [0.86661547, 0.86232626, 0.63346595],
        [0.87126535, 0.86342221, 0.63633889],
        ...,
        [0.99001223, 0.99001223, 0.99001223],
        [0.99215686, 0.99215686, 0.99215686],
        [0.9939338 , 0.9939338 , 0.9939338 ]],

       [[0.86085111, 0.86173433, 0.63228691],
        [0.85776585, 0.86168742, 0.63561463],
        [0.82080197, 0.82501465, 0.62198156],
        ...,
        [0.99356616, 0.99356616, 0.99356616],
        [0.99466914, 0.99466914, 0.99466914],
        [0.99416864, 0.99416864, 0.99416864]],

       [[0.87293202, 0.87293202, 0.62403494],
        [0.88223594, 0.88443631, 0.65145963],
        [0.88156766, 0.88548923, 0.65227842],
        ...,
        [0.99111521, 0.99111521, 0.99111521],
        [0.99111521, 0.99111521, 0.99111521],
        [0.99111521, 0.99111521, 0.99111521]],

       ...,

       [[0.94486278, 0.9359588 , 0.59473085],
        [0.97576904, 0.94715577, 0.67112076],
        [0.84795642, 0

array([[[0.86446941, 0.86054784, 0.63744569],
        [0.86661547, 0.86232626, 0.63346595],
        [0.87126535, 0.86342221, 0.63633889],
        ...,
        [0.99001223, 0.99001223, 0.99001223],
        [0.99215686, 0.99215686, 0.99215686],
        [0.9939338 , 0.9939338 , 0.9939338 ]],

       [[0.86085111, 0.86173433, 0.63228691],
        [0.85776585, 0.86168742, 0.63561463],
        [0.82080197, 0.82501465, 0.62198156],
        ...,
        [0.99356616, 0.99356616, 0.99356616],
        [0.99466914, 0.99466914, 0.99466914],
        [0.99416864, 0.99416864, 0.99416864]],

       [[0.87293202, 0.87293202, 0.62403494],
        [0.88223594, 0.88443631, 0.65145963],
        [0.88156766, 0.88548923, 0.65227842],
        ...,
        [0.99111521, 0.99111521, 0.99111521],
        [0.99111521, 0.99111521, 0.99111521],
        [0.99111521, 0.99111521, 0.99111521]],

       ...,

       [[0.94486278, 0.9359588 , 0.59473085],
        [0.97576904, 0.94715577, 0.67112076],
        [0.84795642, 0

  0%|          | 0/7049 [00:00<?, ?it/s]

loop calculated
calculating inverse


In [None]:
# Persist the experiment results table (one row per completed experiment).
RESULTS_CSV = "formal_experiment_connected"
df.to_csv(RESULTS_CSV)

In [None]:
import random

# Visual check of one test example using the dense (connected) linear inverse (A, b):
# top row   = input image | ground-truth segmentation | reconstructed segmentation
# bottom row = ground-truth foreground mask | predicted foreground mask
threshold_intensity = 0.01
rng = random.Random(0)  # fixed seed so the same example is shown on every re-run
reconstruct = seghead(next(iter(seg_train))[0]).shape
skip = rng.randint(0, 70)
image, seg_base = next(iter(zip(img_test.skip(skip), seg_test.skip(skip))))

imgdecom = tf.reshape(imghead(image[0]), (1, -1))
segdecom = tf.linalg.matvec(A, imgdecom, transpose_a=True) + b
seg = seginvhead(tf.reshape(segdecom, reconstruct))
y_true = tf.cast(tf.reduce_min(seg_base[0], 2) == 0, tf.float64)
y_pred = tf.cast(tf.reduce_min(seg[0], 2) < threshold_intensity, tf.float64)

fig, (ax_images, ax_masks) = plt.subplots(2, 1, figsize=(10, 7))
ax_images.imshow(np.hstack([image[0], seg_base[0], seg[0]]))
ax_images.set_title(f"Test sample {skip}: image | ground truth | prediction")
ax_images.axis("off")
ax_masks.imshow(np.hstack([y_true, y_pred]))
ax_masks.set_title("Foreground mask: ground truth | prediction")
ax_masks.axis("off")
fig.tight_layout()
