In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


import os
import sys
import h5py
import time
import numpy as np
import pandas as pd
from PIL import Image
from IPython.display import display as ipy_display, clear_output
import matplotlib.pyplot as plt

In [2]:
# setting path
sys.path.append('/media/rs37890/d28c4aed-3c7e-4203-8590-f72f868ee829/rs37890/Medical_images_source_MR/Deeplabv3_source_MR/')
deplabv3 = __import__('Deeplabv3')

In [3]:
# setting path
import networks

In [4]:
sys.path.append('/media/rs37890/d28c4aed-3c7e-4203-8590-f72f868ee829/rs37890/Medical_images_source_MR/Deeplabv3_source_MR/Adapt/')
dataset = __import__('Adaptation-step1-dataset')

# parameters

In [5]:
num_classes = 5
batch_size = 1
suffix = 'source'
epoch = 20000
dataset_name = 'abdomen'

# CUDA

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

# Path

In [7]:
root = '/media/rs37890/d28c4aed-3c7e-4203-8590-f72f868ee829/rs37890/Medical_images_source_MR/'

In [8]:
source_mr_train_dir =  root + "data/data/h5py/"
source_mr_test_dir = root + "data/data/h5py/"

In [9]:
target_ct_train_dir = root + "data/data/h5py/"
target_ct_test_dir = root + "data/data/h5py/"

# label_ids_abdomen

In [10]:
label_ids_abdomen = {"ignore": 0,
    "lv_myo": 1,
    "la_blood": 2,
    "lv_blood": 3,
    "aa": 4,
}
label_ids = label_ids_abdomen

# Load model 

In [11]:
dpv3 = deplabv3.DeepLabV3(num_classes)
classifier = networks.classifier(num_classes)

dpv3 = dpv3.to(device)
classifier = classifier.to(device)

# parallel
dpv3 = torch.nn.DataParallel(dpv3)
classifier = torch.nn.DataParallel(classifier)

In [12]:
root = '/media/rs37890/d28c4aed-3c7e-4203-8590-f72f868ee829/rs37890/Medical_images_source_MR/Deeplabv3_source_MR/record-data/'

In [13]:
dpv3_checkpoint = torch.load( root + f'dpv3_weights_{epoch}.pth')
classifier_checkpoint = torch.load( root + f'classifier_weights_{epoch}.pth')

dpv3.load_state_dict(dpv3_checkpoint)
classifier.load_state_dict(classifier_checkpoint)
print("Loaded model weights")

Loaded model weights


In [14]:
class Combined_Model(nn.Module):
    
    def __init__(self, Unet, classifier):
        super(Combined_Model, self).__init__()
        
        self.Unet = Unet
        self.classifier = classifier
        
    def forward(self, x):
        
        output1 = self.Unet(x)
        output2 = self.classifier(output1)
        
        return output2

In [15]:
# Create the combined model
combined = Combined_Model(dpv3, classifier)
combined = combined.to(device)

# Dataset

In [16]:
dataloader = dataset.get_dataloader( target_ct_train_dir,  
                                     target_ct_test_dir, 
                                     num_classes,
                                     batch_size,
                                     domain = 'source',
                                   )

# train_dataset = dataloader["train"].dataset
test_dataset = dataloader["test"].dataset

train_dataset = test_dataset

# learn gaussians

In [17]:
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

In [18]:
def learn_gaussians(dataset, z_model, sm_model, batch_size, label_ids, rho=.97, initial_means=None):    
    
    num_classes = len(label_ids)
    
    means = initial_means
    if initial_means is None:
        means = np.zeros((num_classes, num_classes)) 
        
    covs = np.zeros((num_classes, num_classes, num_classes))
    cnt = np.zeros(num_classes) 
    
    N = len(dataset)
    
    i = 0

    while i < N:
        print(i,'/',N)
        X, Y, sample_cnts = dataset[i]

        if len(X) < 80:
            batch_splits = [[X], [Y]]
        # elif:
        #     batch_splits = [X[:len(X)//2], X[len(X)//2:]], [Y[:len(Y)//2], Y[len(Y)//2:]]
        else:
            batch_splits = [ X[:len(X)//4], X[len(X)//4 : len(X)//2 ], X[ len(X)//2 : (3*len(X))//4 ], X[ (3*len(X))//4: ]], \
            [Y[ :len(Y)//4], Y[ len(Y)//4 : len(Y)//2 ], Y[ len(Y)//2 : (3*len(Y))//4], X[ (3*len(Y))//4:]]
        
        for data, y_true in zip(*batch_splits):
            print(np.shape(data))
            print(np.shape(y_true))
            num_of_correct_prediction_each_classes = np.zeros(5, dtype=int)
            
            z_model.eval()
            sm_model.eval()
            
            with torch.no_grad():
                
                input_data = torch.from_numpy(data).to(device)
                input_data = input_data.float()
                
                zs = z_model(input_data).cpu() 
                zs = zs.permute(0, 2, 3, 4, 1)
                zs = zs.reshape(-1, num_classes).detach().numpy()

                y_hat = sm_model(input_data).softmax(dim = 1).cpu()
                y_hat = y_hat.permute(0, 2, 3, 4, 1)
                y_hat = y_hat.reshape(-1, num_classes).detach().numpy()
            
            vmax = np.max(y_hat, axis = 1) 
            y_hat = np.argmax(y_hat, axis = 1) 
            
            y_t = y_true.ravel()

            for label in label_ids:
                c = label_ids[label]

                # ind = (y_t == c) & (y_hat == c) & (vmax > rho)
                ind = (y_t == c)  & (vmax > rho)
                
                if np.sum(ind) > 0:
                    curr_data = zs[ind]
                    num_of_correct_prediction_each_classes[c] = np.sum(ind)

                    if initial_means is None:
                        means[c] += np.sum(curr_data, axis=0)
                        cnt[c] += np.sum(ind)

                    else:
                        
                        sigma = np.dot(np.transpose(curr_data - means[c]), curr_data - means[c])
                        covs[c] += sigma
                        cnt[c] += np.sum(ind)
        
        i += 1

        
    # Normalize results
    for i in range(num_classes):
        if initial_means is None:
            means[i] /= (cnt[i] + 1e-10)
        covs[i] /= (cnt[i] - 1)
        
    assert np.isnan(means).any() == False
    assert np.isnan(cnt).any() == False
    
    return means, covs, cnt

In [19]:
start_time = time.time()
means, _ , ct = learn_gaussians(train_dataset, dpv3, combined, batch_size, label_ids, rho=0.97, initial_means=None)

print("computed means in", time.time() - start_time)

0 / 4
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 1, 32, 32, 32)
1 / 4
(72, 1, 32, 32, 32)
(72, 32, 32, 32)
2 / 4
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 1, 32, 32, 32)
3 / 4
(72, 1, 32, 32, 32)
(72, 32, 32, 32)
computed means in 5.770476341247559


In [20]:
start_time = time.time()
means, covs, ct = learn_gaussians(train_dataset, dpv3, combined, batch_size, label_ids, rho=0.97, initial_means=means)

print("finished training gaussians in", time.time() - start_time)

0 / 4
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 32, 32, 32)
(20, 1, 32, 32, 32)
(20, 1, 32, 32, 32)
1 / 4
(72, 1, 32, 32, 32)
(72, 32, 32, 32)
2 / 4
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 32, 32, 32)
(25, 1, 32, 32, 32)
(25, 1, 32, 32, 32)
3 / 4
(72, 1, 32, 32, 32)
(72, 32, 32, 32)
finished training gaussians in 3.412978410720825


In [21]:
np.save("./extras/means_" + dataset_name + "_" + suffix + ".npy", means)
np.save("./extras/covs_" + dataset_name + "_" + suffix + ".npy", covs)