In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


import os
import sys
import h5py
import time
import numpy as np
import pandas as pd
from PIL import Image
from IPython.display import display as ipy_display, clear_output
import matplotlib.pyplot as plt

In [2]:
# setting path
sys.path.append('../')
deplabv3 = __import__('Deeplabv3')

In [3]:
# setting path
sys.path.append('../')
import networks

In [4]:
source_dataset =  __import__('step1_source_dataset')

In [5]:
target_dataset =  __import__('step1_target_dataset')

# parameters

In [6]:
num_classes = 5
batch_size = 16
suffix = 'run4'
epoch = 130000
dataset_name = 'abdomen'

# CUDA

In [7]:
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
device = 'cuda:3' 
device

'cuda:3'

# Path

In [8]:
source_ct_train_dir = "../../data/h5py/"
source_ct_test_dir = "../../data/h5py/"

In [9]:
target_mr_train_dir = "../../data/h5py/"
target_mr_test_dir = "../../data/h5py/"

# label_ids_abdomen

In [10]:
label_ids_abdomen = {"ignore": 0,
    "lv_myo": 1,
    "la_blood": 2,
    "lv_blood": 3,
    "aa": 4,
}
label_ids = label_ids_abdomen

# Load model 

In [11]:
dpv3 = deplabv3.DeepLabV3(num_classes)
classifier = networks.classifier(num_classes)

dpv3 = dpv3.to(device)
classifier = classifier.to(device)

# parallel
dpv3 = torch.nn.DataParallel(dpv3)
classifier = torch.nn.DataParallel(classifier)

In [12]:
dpv3_checkpoint = torch.load('../record-data/' + 'dpv3_weights_' + str(epoch) + '.pth')
classifier_checkpoint = torch.load('../record-data/' + 'classifier_weights_' + str(epoch) + '.pth')

dpv3.load_state_dict(dpv3_checkpoint)
classifier.load_state_dict(classifier_checkpoint)

dpv3 = dpv3.module
classifier = classifier.module

dpv3 = dpv3.to(device)
classifier = classifier.to(device)

print("Loaded model weights")

Loaded model weights


In [13]:
class Combined_Model(nn.Module):
    
    def __init__(self, Unet, classifier):
        super(Combined_Model, self).__init__()
        
        self.Unet = Unet
        self.classifier = classifier
        
    def forward(self, x):
        
        output1 = self.Unet(x)
        output2 = self.classifier(output1)
        
        return output2

In [14]:
# Create the combined model
combined = Combined_Model(dpv3, classifier)
combined = combined.to(device)

# Dataset

In [15]:
dataloader = source_dataset.get_dataloader( source_ct_train_dir,  source_ct_test_dir,  num_classes, batch_size, domain = 'source' )

source_train_dataset = dataloader["train"].dataset
source_test_dataset = dataloader["test"].dataset

In [16]:
dataloader = target_dataset.get_dataloader( target_mr_train_dir,  target_mr_test_dir, num_classes, batch_size, domain = 'target' )

target_train_dataset = dataloader["train"].dataset
target_test_dataset = dataloader["test"].dataset

# learn gaussians

In [17]:
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

In [18]:
def learn_gaussians(data, y_true, z_model, sm_model, batch_size, label_ids, rho=.97, initial_means=None):    
    
    # z_model: deeplab v3
    # sm_model: combined
    
    num_classes = len(label_ids)
    
    means = initial_means
    if initial_means is None:
        means = np.zeros((num_classes, num_classes)) # (5,5)
        
    covs = np.zeros((num_classes, num_classes, num_classes)) # shapeL (5, 5, 5)
    cnt = np.zeros(num_classes) # shape: (5,)
    
    
    data = data.transpose((2, 0, 1)) # Before is [256, 256, 32] after is [32, 256, 256]
    data = np.expand_dims(data, axis=0) # [1, 32, 256, 256]
    data = np.expand_dims(data, axis=0) # [1, 1, 32, 256, 256]
    
    y_true = y_true.transpose((2, 0, 1)) # Before is [256, 256, 32] after is [32, 256, 256]
    y_true = np.expand_dims(y_true, axis=0) # [1, 32, 256, 256]
    
    
    assert len(data.shape) == 5
    assert len(y_true.shape) == 4
    
    assert data.shape[2] == 32
    assert y_true.shape[1] == 32
    
    num_of_correct_prediction_each_classes = np.zeros(5, dtype=int)
    
    z_model.eval()
    sm_model.eval()

    with torch.no_grad():

        input_data = torch.from_numpy(data).to(device) # input_data:  torch.Size([1, 1, 32, 256, 256]) 
        input_data = input_data.float()

        
        # Predict latent features
        zs = z_model(input_data) # (8, 5, 32, 256, 256)
        zs = zs.cpu()
        zs = zs.permute(0, 2, 3, 4, 1)
        zs = zs.reshape(-1, num_classes).detach().numpy() # (8 * 32 * 256 * 256, 5)


        # Get softmax outputs
        # (48, 5, 32, 256, 256) => (32 * 256 * 256 * 8, 5)
        y_hat = sm_model(input_data).softmax(dim = 1).cpu()
        y_hat = y_hat.permute(0, 2, 3, 4, 1)
        y_hat = y_hat.reshape(-1, num_classes).detach().numpy()
            
        # downwards across rows (axis 0) 
        # running horizontally across columns (axis 1) = (axis=-1)
        # Because of continual slice, the vmax likes  [0.2908752  0.29087755 0.29087812 ... 0.29117948 0.2911795  0.29118514]
        vmax = np.max(y_hat, axis = 1) 
        y_hat = np.argmax(y_hat, axis = 1) 
        
        # before ravel: (8, 32, 256, 256);
        y_t = y_true.ravel()
        
        # Keep a few exemplars per class
        for label in label_ids:
            c = label_ids[label]

            ind = (y_t == c) & (y_hat == c) & (vmax > rho)
            
            if np.sum(ind) > 0:
                # We have at least one sample
                curr_data = zs[ind]
                num_of_correct_prediction_each_classes[c] = np.sum(ind)

                if initial_means is None:
                    # Only update means and counts

                    # 1d + 1d == (5, ) + (5, )
                    means[c] += np.sum(curr_data, axis=0)
                    cnt[c] += np.sum(ind)

                else:
                    # ! here means are scaled to their final values
                    # Example: 
                    # curr_data.shape = (8, 5); 8 is the sum of ind
                    # means[c].shape:  (5,)
                    # curr_data - means[c]:  (8, 5)
                    # np.transpose(curr_data - means[c]:  (5, 8)
                    # np.dot = (5, 8) * (8, 5) = (5, 5)

                    sigma = np.dot(np.transpose(curr_data - means[c]), curr_data - means[c])
                    assert sigma.shape == (num_classes, num_classes)
                    covs[c] += sigma
                    cnt[c] += np.sum(ind)
        

    
    # Normalize results
    for i in range(num_classes):
        if initial_means is None:
            means[i] /= (cnt[i] + 1e-10)
        covs[i] /= (cnt[i] - 1)
        
    assert np.isnan(means).any() == False
    assert np.isnan(cnt).any() == False
    
    return means, covs, cnt, num_of_correct_prediction_each_classes

# Find the most similarity based on label space

In [19]:
filename = 'Gaussian_similarity_dataframe_27779.csv'

if os.path.exists(filename):
    
    df = pd.read_csv(filename)
    restart_image_idx = df.iloc[-1]['source_images_idx']
    restart_i = df.iloc[-1]['source_i']
    restart_j = df.iloc[-1]['source_j']
    restart_k = df.iloc[-1]['source_k']
    
else:
    df = pd.DataFrame(columns=["source_images_idx", "source_i", "source_j", "source_k", 
                               "means", "covs", "label_frequency", "perdict_frequency"])
    restart_image_idx = 0
    restart_i = 0
    restart_j = 0
    restart_k = 0

In [20]:
counter = 0

for source_idx in range( restart_image_idx, len(source_train_dataset) ):
    
    source_data, source_label = source_train_dataset[source_idx] 
    shape_source = source_label.shape # (126, 113, 203)
    crop_size = (32, 32, 32)

    total_steps_i = shape_source[0] - crop_size[0] + 1
    step_size_i = total_steps_i // 10  # Integer division to get whole number steps

    total_steps_j = shape_source[1] - crop_size[1] + 1
    step_size_j = total_steps_j // 10  # Integer division to get whole number steps

    total_steps_z = shape_source[2] - crop_size[2] + 1
    step_size_z = total_steps_z // 10

    # Slide the smaller array across the larger one
    if step_size_i != 0:
        x_range = [i for i in range( 0,  total_steps_i, step_size_i)]
        if x_range[-1] != (shape_source[0] - crop_size[0]):
            x_range.append(shape_source[0] - crop_size[0])
    else:
        x_range = [i for i in range( 0,  total_steps_i, 1)]
    
    if step_size_j != 0:
        y_range = [j for j in range( 0,  total_steps_j, step_size_j)]
        if y_range[-1] != (shape_source[1] - crop_size[1]):
            y_range.append(shape_source[1] - crop_size[1])
    else:
        y_range = [j for j in range( 0,  total_steps_j, 1)]
    
    if step_size_z != 0:
        z_range = [z for z in range( 0,  total_steps_z, step_size_z)]
        if z_range[-1] != (shape_source[2] - crop_size[2]):
            z_range.append(shape_source[2] - crop_size[2])
    else:
        z_range = [z for z in range( 0,  total_steps_z, 1)]
    
    # Slide the smaller array across the larger one
    for i in x_range:
        
        if i < restart_i: continue
        
        for j in y_range:
            
            if j < restart_j: continue
            
            for k in z_range:
                
                if k < restart_k: continue

                print(source_idx, i, j, k)
                print(len(df))
                counter += 1
                # Extract  
                data = source_data[i:i+crop_size[0], j:j+crop_size[1], k:k+crop_size[2]] # (32, 32, 32)
                y_true = source_label[i:i+crop_size[0], j:j+crop_size[1], k:k+crop_size[2]]

                
                means, _ , ct, _ = learn_gaussians(data, y_true, dpv3,\
                                                combined, batch_size, label_ids, rho=0.97, initial_means=None)
                

                means, covs, ct, y_perdict = learn_gaussians(data, y_true, dpv3,\
                                                  combined, batch_size, label_ids, rho=0.97, initial_means=means)
                
                
                
                n_samples = np.zeros(5, dtype=int)
                cls, ns = np.unique(y_true, return_counts=True)
                for h in range(len(cls)):
                    n_samples[int(cls[h])] = ns[h]
                
                n_samples_2 = y_perdict
                
                new_row = pd.DataFrame({"source_images_idx": [source_idx],
                                "source_i": [i],
                                "source_j": [j],
                                "source_k": [k],
                                "means": [means],
                                "covs": [covs],
                                "label_frequency": [n_samples], 
                                "perdict_frequency": [n_samples_2]})


                df = pd.concat([df, new_row], ignore_index=True)
                
                if counter % 10000 == 0:
                    # Save DataFrame to .csv file
                    df.to_csv('Gaussian_similarity_dataframe_27779.csv', index=False)
                
                    # Clear last display and display the DataFrame
                    clear_output(wait=True)
                    display(df)
                    
                if restart_k !=0: restart_k = 0
            if restart_j != 0: restart_j = 0
        if restart_i != 0: restart_i = 0
                
                
df.to_csv('Gaussian_similarity_dataframe_27779.csv', index=False)              

Unnamed: 0,source_images_idx,source_i,source_j,source_k,means,covs,label_frequency,perdict_frequency
0,0,0,0,0,"[[-0.28571146606139625, 21.204547536178723, 14...","[[[31.335978016179237, -16.114325028327354, 17...","[32754, 14, 0, 0, 0]","[32754, 0, 0, 0, 0]"
1,0,0,0,6,"[[-0.494177851054571, 21.66812016989689, 14.26...","[[[31.988975393722537, -6.743702629660195, 26....","[32682, 86, 0, 0, 0]","[32608, 0, 0, 0, 0]"
2,0,0,0,12,"[[-0.9696461552831845, 21.908530290891097, 13....","[[[20.144235070568715, -3.890665311996754, 18....","[32592, 176, 0, 0, 0]","[32452, 0, 0, 0, 0]"
3,0,0,0,18,"[[-1.1793156030154612, 21.575080577679916, 13....","[[[21.395850187951797, -2.0197667613360255, 18...","[32521, 247, 0, 0, 0]","[32267, 0, 0, 0, 0]"
4,0,0,0,24,"[[-0.7695333785146009, 20.76432502490964, 13.3...","[[[19.368920152949404, -2.1646473495710037, 17...","[32493, 275, 0, 0, 0]","[32116, 0, 0, 0, 0]"
...,...,...,...,...,...,...,...,...
19995,11,42,20,24,"[[-3.0824664429528132, 7.399526137793127, -2.0...","[[[7.162089406694452, -2.9815736293771216, -0....","[2884, 1636, 15493, 10333, 2422]","[1490, 306, 14296, 8352, 1711]"
19996,11,42,20,32,"[[-1.7480537154977152, 4.639217486136282, -3.1...","[[[11.922739070341196, -1.995273397114842, 6.1...","[3783, 998, 16352, 7510, 4125]","[1499, 427, 15104, 6263, 2966]"
19997,11,42,20,40,"[[-1.6554243158480049, 11.778929089090123, 0.7...","[[[8.837984726185518, -8.999691032890446, -1.2...","[9040, 958, 14261, 4265, 4244]","[6123, 416, 12722, 3446, 3120]"
19998,11,42,20,48,"[[-4.952122511061913, 15.481363810415143, 0.89...","[[[34.14266088946012, -28.582415372013816, 1.4...","[17036, 945, 10070, 1266, 3451]","[14690, 397, 9044, 559, 2405]"


11 42 20 64
20000
11 42 20 72
20001
11 42 20 80
20002
11 42 20 84
20003
11 42 25 0
20004
11 42 25 8
20005
11 42 25 16
20006
11 42 25 24
20007
11 42 25 32
20008
11 42 25 40
20009
11 42 25 48
20010
11 42 25 56
20011
11 42 25 64
20012
11 42 25 72
20013
11 42 25 80
20014
11 42 25 84
20015
11 42 30 0
20016
11 42 30 8
20017
11 42 30 16
20018
11 42 30 24
20019
11 42 30 32
20020
11 42 30 40
20021
11 42 30 48
20022
11 42 30 56
20023
11 42 30 64
20024
11 42 30 72
20025
11 42 30 80
20026
11 42 30 84
20027
11 42 35 0
20028
11 42 35 8
20029
11 42 35 16
20030
11 42 35 24
20031
11 42 35 32
20032
11 42 35 40
20033
11 42 35 48
20034
11 42 35 56
20035
11 42 35 64
20036
11 42 35 72
20037
11 42 35 80
20038
11 42 35 84
20039
11 42 40 0
20040
11 42 40 8
20041
11 42 40 16
20042
11 42 40 24
20043
11 42 40 32
20044
11 42 40 40
20045
11 42 40 48
20046
11 42 40 56
20047
11 42 40 64
20048
11 42 40 72
20049
11 42 40 80
20050
11 42 40 84
20051
11 42 45 0
20052
11 42 45 8
20053
11 42 45 16
20054
11 42 45 24
20055
11

In [29]:
len(df)
# 945773

27780

In [30]:
df.head

<bound method NDFrame.head of       source_images_idx source_i source_j source_k  \
0                     0        0        0        0   
1                     0        0        0        6   
2                     0        0        0       12   
3                     0        0        0       18   
4                     0        0        0       24   
...                 ...      ...      ...      ...   
27775                15      101       75       63   
27776                15      101       75       72   
27777                15      101       75       81   
27778                15      101       75       90   
27779                15      101       75       91   

                                                   means  \
0      [[-0.28571146606139625, 21.204547536178723, 14...   
1      [[-0.494177851054571, 21.66812016989689, 14.26...   
2      [[-0.9696461552831845, 21.908530290891097, 13....   
3      [[-1.1793156030154612, 21.575080577679916, 13....   
4      [[-0.769533378