In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


import os
import sys
import h5py
import time
import numpy as np
import pandas as pd
from PIL import Image
from IPython.display import display as ipy_display, clear_output
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# setting path
sys.path.append('../')
deplabv3 = __import__('Deeplabv3')

In [3]:
# setting path
sys.path.append('../')
import networks

In [4]:
source_dataset =  __import__('step1_source_dataset')

In [5]:
target_dataset =  __import__('step1_target_dataset')

# parameters

In [6]:
num_classes = 5
batch_size = 16
suffix = 'run4'
epoch = 130000
dataset_name = 'abdomen'

# CUDA

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Path

In [8]:
source_ct_train_dir = "../../data/h5py/"
source_ct_test_dir = "../../data/h5py/"

In [9]:
target_mr_train_dir = "../../data/h5py/"
target_mr_test_dir = "../../data/h5py/"

# label_ids_abdomen

In [10]:
label_ids_abdomen = {"ignore": 0,
    "lv_myo": 1,
    "la_blood": 2,
    "lv_blood": 3,
    "aa": 4,
}
label_ids = label_ids_abdomen

# Load model 

In [11]:
dpv3 = deplabv3.DeepLabV3(num_classes)
classifier = networks.classifier(num_classes)

dpv3 = dpv3.to(device)
classifier = classifier.to(device)

# parallel
dpv3 = torch.nn.DataParallel(dpv3)
classifier = torch.nn.DataParallel(classifier)

In [12]:
dpv3_checkpoint = torch.load('../record-data/' + 'dpv3_weights_' + str(epoch) + '.pth')
classifier_checkpoint = torch.load('../record-data/' + 'classifier_weights_' + str(epoch) + '.pth')

dpv3.load_state_dict(dpv3_checkpoint)
classifier.load_state_dict(classifier_checkpoint)
print("Loaded model weights")

Loaded model weights


In [13]:
class Combined_Model(nn.Module):
    
    def __init__(self, Unet, classifier):
        super(Combined_Model, self).__init__()
        
        self.Unet = Unet
        self.classifier = classifier
        
    def forward(self, x):
        
        output1 = self.Unet(x)
        output2 = self.classifier(output1)
        
        return output2

In [14]:
# Create the combined model
combined = Combined_Model(dpv3, classifier)
combined = combined.to(device)

# Dataset

In [15]:
dataloader = source_dataset.get_dataloader( source_ct_train_dir,  source_ct_test_dir, \
                                           num_classes, batch_size, domain = 'source' )

source_train_dataset = dataloader["train"].dataset
source_test_dataset = dataloader["test"].dataset

In [16]:
dataloader = target_dataset.get_dataloader( target_mr_train_dir,  target_mr_test_dir, num_classes, \
                                           batch_size, domain = 'target' )

target_train_dataset = dataloader["train"].dataset
target_test_dataset = dataloader["test"].dataset

# learn gaussians

In [17]:
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

In [18]:
def learn_gaussians(data, y_true, z_model, sm_model, batch_size, label_ids, rho=.97, initial_means=None):    
    
    # z_model: deeplab v3
    # sm_model: combined
    
    num_classes = len(label_ids)
    
    means = initial_means
    if initial_means is None:
        means = np.zeros((num_classes, num_classes)) # (5,5)
        
    covs = np.zeros((num_classes, num_classes, num_classes)) # shapeL (5, 5, 5)
    cnt = np.zeros(num_classes) # shape: (5,)
    
    
    data = data.transpose((2, 0, 1)) # Before is [256, 256, 32] after is [32, 256, 256]
    data = np.expand_dims(data, axis=0) # [1, 32, 256, 256]
    data = np.expand_dims(data, axis=0) # [1, 1, 32, 256, 256]
    
    y_true = y_true.transpose((2, 0, 1)) # Before is [256, 256, 32] after is [32, 256, 256]
    y_true = np.expand_dims(y_true, axis=0) # [1, 32, 256, 256]
    
    
    assert len(data.shape) == 5
    assert len(y_true.shape) == 4
    
    assert data.shape[2] == 32
    assert y_true.shape[1] == 32
    
    num_of_correct_prediction_each_classes = np.zeros(5, dtype=int)
    
    z_model.eval()
    sm_model.eval()

    with torch.no_grad():

        input_data = torch.from_numpy(data).to(device) # input_data:  torch.Size([1, 1, 32, 256, 256]) 
        input_data = input_data.float()
        
        # Predict latent features
        zs = z_model(input_data) # (8, 5, 32, 256, 256)
        zs = zs.cpu()
        zs = zs.permute(0, 2, 3, 4, 1)
        zs = zs.reshape(-1, num_classes).detach().numpy() # (8 * 32 * 256 * 256, 5)


        # Get softmax outputs
        # (48, 5, 32, 256, 256) => (32 * 256 * 256 * 8, 5)
        y_hat = sm_model(input_data).softmax(dim = 1).cpu()
        y_hat = y_hat.permute(0, 2, 3, 4, 1)
        y_hat = y_hat.reshape(-1, num_classes).detach().numpy()
            
        # downwards across rows (axis 0) 
        # running horizontally across columns (axis 1) = (axis=-1)
        # Because of continual slice, the vmax likes  [0.2908752  0.29087755 0.29087812 ... 0.29117948 0.2911795  0.29118514]
        vmax = np.max(y_hat, axis = 1) 
        y_hat = np.argmax(y_hat, axis = 1) 
        
        # before ravel: (8, 32, 256, 256);
        y_t = y_true.ravel()
        
        # Keep a few exemplars per class
        for label in label_ids:
            c = label_ids[label]

            ind = (y_t == c) & (y_hat == c) & (vmax > rho)
            
            if np.sum(ind) > 0:
                # We have at least one sample
                curr_data = zs[ind]
                num_of_correct_prediction_each_classes[c] = np.sum(ind)

                if initial_means is None:
                    # Only update means and counts

                    # 1d + 1d == (5, ) + (5, )
                    means[c] += np.sum(curr_data, axis=0)
                    cnt[c] += np.sum(ind)

                else:
                    # ! here means are scaled to their final values
                    # Example: 
                    # curr_data.shape = (8, 5); 8 is the sum of ind
                    # means[c].shape:  (5,)
                    # curr_data - means[c]:  (8, 5)
                    # np.transpose(curr_data - means[c]:  (5, 8)
                    # np.dot = (5, 8) * (8, 5) = (5, 5)

                    sigma = np.dot(np.transpose(curr_data - means[c]), curr_data - means[c])
                    assert sigma.shape == (num_classes, num_classes)
                    covs[c] += sigma
                    cnt[c] += np.sum(ind)
        

    
    # Normalize results
    for i in range(num_classes):
        if initial_means is None:
            means[i] /= (cnt[i] + 1e-10)
        covs[i] /= (cnt[i] - 1)
        
    assert np.isnan(means).any() == False
    assert np.isnan(cnt).any() == False
    
    return means, covs, cnt, num_of_correct_prediction_each_classes

# Find the most similarity based on label space

In [19]:
filename = 'Gaussian_similarity_dataframe_2.csv'

if os.path.exists(filename):
    
    df = pd.read_csv(filename)
    restart_image_idx = df.iloc[-1]['source_images_idx']
    restart_i = df.iloc[-1]['source_i']
    restart_j = df.iloc[-1]['source_j']
    restart_k = df.iloc[-1]['source_k']
    
else:
    df = pd.DataFrame(columns=["source_images_idx", "source_i", "source_j", "source_k", 
                               "means", "covs", "label_frequency", "perdict_frequency"])
    restart_image_idx = 0
    restart_i = 0
    restart_j = 0
    restart_k = 0

In [20]:
counter = 0

for source_idx in range( restart_image_idx, len(source_train_dataset) ):
    
    source_data, source_label = source_train_dataset[source_idx] 
    shape_source = source_label.shape # (126, 113, 203)
    crop_size = (32, 32, 32)

    total_steps_i = shape_source[0] - crop_size[0] + 1
    step_size_i = total_steps_i // 3  # Integer division to get whole number steps

    total_steps_j = shape_source[1] - crop_size[1] + 1
    step_size_j = total_steps_j // 3  # Integer division to get whole number steps

    total_steps_z = shape_source[2] - crop_size[2] + 1
    step_size_z = total_steps_z // 3

    # Slide the smaller array across the larger one
    if step_size_i != 0:
        x_range = [i for i in range( 0,  total_steps_i, step_size_i)]
        if x_range[-1] != (shape_source[0] - crop_size[0]):
            x_range.append(shape_source[0] - crop_size[0])
    else:
        x_range = [i for i in range( 0,  total_steps_i, 1)]
    
    if step_size_j != 0:
        y_range = [j for j in range( 0,  total_steps_j, step_size_j)]
        if y_range[-1] != (shape_source[1] - crop_size[1]):
            y_range.append(shape_source[1] - crop_size[1])
    else:
        y_range = [j for j in range( 0,  total_steps_j, 1)]
    
    if step_size_z != 0:
        z_range = [z for z in range( 0,  total_steps_z, step_size_z)]
        if z_range[-1] != (shape_source[2] - crop_size[2]):
            z_range.append(shape_source[2] - crop_size[2])
    else:
        z_range = [z for z in range( 0,  total_steps_z, 1)]
    
    # Slide the smaller array across the larger one
    for i in x_range:
        
        if i < restart_i: continue
        
        for j in y_range:
            
            if j < restart_j: continue
            
            for k in z_range:
                
                if k < restart_k: continue

                print(source_idx, i, j, k)
                print(len(df))
                counter += 1
                # Extract  
                data = source_data[i:i+crop_size[0], j:j+crop_size[1], k:k+crop_size[2]] # (32, 32, 32)
                y_true = source_label[i:i+crop_size[0], j:j+crop_size[1], k:k+crop_size[2]]
                
                means, _ , ct, _ = learn_gaussians(data, y_true, dpv3,\
                                                combined, batch_size, label_ids, rho=0.97, initial_means=None)
                

                means, covs, ct, y_perdict = learn_gaussians(data, y_true, dpv3,\
                                                  combined, batch_size, label_ids, rho=0.97, initial_means=means)
                
                
                
                n_samples = np.zeros(5, dtype=int)
                cls, ns = np.unique(y_true, return_counts=True)
                for h in range(len(cls)):
                    n_samples[int(cls[h])] = ns[h]
                
                n_samples_2 = y_perdict
                
                df = df.append({"source_images_idx": source_idx,
                                "source_i": i,
                                "source_j": j,
                                "source_k": k,
                                "means": means,
                                "covs": covs,
                                "label_frequency": n_samples, 
                                "perdict_frequency": n_samples_2}, ignore_index=True)
                
                if counter % 10000 == 0:
                    # Save DataFrame to .csv file
                    df.to_csv('Gaussian_similarity_dataframe_1298.csv', index=False)
                
                    # Clear last display and display the DataFrame
                    clear_output(wait=True)
                    display(df)
                    
                if restart_k !=0: restart_k = 0
            if restart_j != 0: restart_j = 0
        if restart_i != 0: restart_i = 0
                
                
df.to_csv('Gaussian_similarity_dataframe_1298.csv', index=False)              

0 0 0 0
0
0 0 0 22
1
0 0 0 44
2
0 0 0 66
3
0 0 29 0
4
0 0 29 22
5
0 0 29 44
6
0 0 29 66
7
0 0 58 0
8
0 0 58 22
9
0 0 58 44
10
0 0 58 66
11
0 0 86 0
12
0 0 86 22
13
0 0 86 44
14
0 0 86 66
15
0 29 0 0
16
0 29 0 22
17
0 29 0 44
18
0 29 0 66
19
0 29 29 0
20
0 29 29 22
21
0 29 29 44
22
0 29 29 66
23
0 29 58 0
24
0 29 58 22
25
0 29 58 44
26
0 29 58 66
27
0 29 86 0
28
0 29 86 22
29
0 29 86 44
30
0 29 86 66
31
0 58 0 0
32
0 58 0 22
33
0 58 0 44
34
0 58 0 66
35
0 58 29 0
36
0 58 29 22
37
0 58 29 44
38
0 58 29 66
39
0 58 58 0
40
0 58 58 22
41
0 58 58 44
42
0 58 58 66
43
0 58 86 0
44
0 58 86 22
45
0 58 86 44
46
0 58 86 66
47
0 86 0 0
48
0 86 0 22
49
0 86 0 44
50
0 86 0 66
51
0 86 29 0
52
0 86 29 22
53
0 86 29 44
54
0 86 29 66
55
0 86 58 0
56
0 86 58 22
57
0 86 58 44
58
0 86 58 66
59
0 86 86 0
60
0 86 86 22
61
0 86 86 44
62
0 86 86 66
63
1 0 0 0
64
1 0 0 26
65
1 0 0 52
66
1 0 0 78
67
1 0 0 79
68
1 0 32 0
69
1 0 32 26
70
1 0 32 52
71
1 0 32 78
72
1 0 32 79
73
1 0 64 0
74
1 0 64 26
75
1 0 64 52
76
1

6 76 28 62
579
6 76 28 93
580
6 76 28 94
581
6 76 56 0
582
6 76 56 31
583
6 76 56 62
584
6 76 56 93
585
6 76 56 94
586
6 76 84 0
587
6 76 84 31
588
6 76 84 62
589
6 76 84 93
590
6 76 84 94
591
6 76 85 0
592
6 76 85 31
593
6 76 85 62
594
6 76 85 93
595
6 76 85 94
596
7 0 0 0
597
7 0 0 23
598
7 0 0 46
599
7 0 0 68
600
7 0 21 0
601
7 0 21 23
602
7 0 21 46
603
7 0 21 68
604
7 0 42 0
605
7 0 42 23
606
7 0 42 46
607
7 0 42 68
608
7 0 62 0
609
7 0 62 23
610
7 0 62 46
611
7 0 62 68
612
7 28 0 0
613
7 28 0 23
614
7 28 0 46
615
7 28 0 68
616
7 28 21 0
617
7 28 21 23
618
7 28 21 46
619
7 28 21 68
620
7 28 42 0
621
7 28 42 23
622
7 28 42 46
623
7 28 42 68
624
7 28 62 0
625
7 28 62 23
626
7 28 62 46
627
7 28 62 68
628
7 56 0 0
629
7 56 0 23
630
7 56 0 46
631
7 56 0 68
632
7 56 21 0
633
7 56 21 23
634
7 56 21 46
635
7 56 21 68
636
7 56 42 0
637
7 56 42 23
638
7 56 42 46
639
7 56 42 68
640
7 56 62 0
641
7 56 62 23
642
7 56 62 46
643
7 56 62 68
644
7 83 0 0
645
7 83 0 23
646
7 83 0 46
647
7 83 0 68
64

13 57 60 26
1120
13 57 60 52
1121
13 57 60 78
1122
13 57 60 79
1123
13 57 61 0
1124
13 57 61 26
1125
13 57 61 52
1126
13 57 61 78
1127
13 57 61 79
1128
13 58 0 0
1129
13 58 0 26
1130
13 58 0 52
1131
13 58 0 78
1132
13 58 0 79
1133
13 58 20 0
1134
13 58 20 26
1135
13 58 20 52
1136
13 58 20 78
1137
13 58 20 79
1138
13 58 40 0
1139
13 58 40 26
1140
13 58 40 52
1141
13 58 40 78
1142
13 58 40 79
1143
13 58 60 0
1144
13 58 60 26
1145
13 58 60 52
1146
13 58 60 78
1147
13 58 60 79
1148
13 58 61 0
1149
13 58 61 26
1150
13 58 61 52
1151
13 58 61 78
1152
13 58 61 79
1153
14 0 0 0
1154
14 0 0 27
1155
14 0 0 54
1156
14 0 0 81
1157
14 0 21 0
1158
14 0 21 27
1159
14 0 21 54
1160
14 0 21 81
1161
14 0 42 0
1162
14 0 42 27
1163
14 0 42 54
1164
14 0 42 81
1165
14 0 63 0
1166
14 0 63 27
1167
14 0 63 54
1168
14 0 63 81
1169
14 21 0 0
1170
14 21 0 27
1171
14 21 0 54
1172
14 21 0 81
1173
14 21 21 0
1174
14 21 21 27
1175
14 21 21 54
1176
14 21 21 81
1177
14 21 42 0
1178
14 21 42 27
1179
14 21 42 54
1180
14 21

In [21]:
df.head()

Unnamed: 0,source_images_idx,source_i,source_j,source_k,means,covs,label_frequency,perdict_frequency
0,0,0,0,0,"[[-0.2888454165426351, 21.205936679489465, 14....","[[[31.35080320878039, -16.116169540764865, 17....","[32754, 14, 0, 0, 0]","[32754, 0, 0, 0, 0]"
1,0,0,0,22,"[[-0.9200588874347555, 21.125551447744435, 13....","[[[20.169113686316926, -2.4976810497262667, 17...","[32497, 271, 0, 0, 0]","[32188, 0, 0, 0, 0]"
2,0,0,0,44,"[[0.7346884341189195, 21.550614357287696, 14.1...","[[[13.767388398456706, 2.497217517689941, 16.0...","[32668, 100, 0, 0, 0]","[32534, 0, 0, 0, 0]"
3,0,0,0,66,"[[-0.37759032845497015, 21.670763015747003, 13...","[[[16.054269562356485, -5.580640555421809, 10....","[32768, 0, 0, 0, 0]","[32768, 0, 0, 0, 0]"
4,0,0,29,0,"[[-6.049997427213509, 15.614664415088633, 5.13...","[[[21.96983906685115, -0.8256350663203466, 23....","[27621, 4184, 0, 963, 0]","[26722, 3000, 0, 589, 0]"


In [22]:
ns

array([17085, 15683])

In [23]:
y_perdict

array([15495,     0,     0,     0, 14064])

In [24]:
len(df)
# 945773

1298