In [1]:
import glob
import os
import nibabel as nib
import numpy as np
from scipy.ndimage.interpolation import zoom
from scipy import ndimage
import pickle
import nibabel as nib
import random
from scipy.ndimage.interpolation import zoom


from torchvision import transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision


In [2]:
def count_parameters(model):
    """Return the total number of scalar parameters in `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def read_mutant_txt(path):
    """Read a text file and return its non-empty lines, in file order.

    Args:
        path: Path of the text file to read.

    Returns:
        list[str]: one entry per line that is not empty once its newline is removed.
        Only newline characters are stripped; other surrounding whitespace is kept.
    """
    name_list = []
    # `with` guarantees the file handle is closed, even if reading fails.
    with open(path) as fo:
        for line in fo:
            striped_line = line.strip('\n')
            if striped_line != '':
                name_list.append(striped_line)
    return name_list


In [3]:
save_name = 'All_data_112_64_64_longitudinal_remove_small.pickle'
pickle_path = os.path.join(os.getcwd(), 'data', save_name)
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(pickle_path, "rb") as input_file:
    all_train_data = pickle.load(input_file)

In [4]:
# One (volume name, intensity-shifted volume) pair per record: field 2 is the name,
# field 3 the image array, shifted down by 0.5.
test_data = [
    (all_train_data[idx][2], all_train_data[idx][3] - 0.5)
    for idx in range(len(all_train_data))
]

print(len(test_data))


433


In [5]:
from torch.utils.data import Dataset, DataLoader

class Mouse_sub_volumes(Dataset):
    """Mouse BV volumes dataset over ``(image_name, volume)`` pairs.

    Each item is a dict ``{'image': float32 volume with a leading channel axis,
    'image_name': name}``. No label is included in the sample.
    """

    def __init__(self, all_data, transform=None):
        """
        Args:
            all_data: Indexable sequence of ``(image_name, volume)`` pairs, where
                ``volume`` is a NumPy array.
            transform (callable, optional): Optional transform applied to each
                sample dict before it is returned.
        """
        self.all_data = all_data
        self.transform = transform

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, num):
        img_name, cur_img = self.all_data[num]

        # Prepend a singleton channel axis and cast to float32 for Conv3d input.
        img = np.float32(cur_img[np.newaxis, ...])
        sample = {'image': img, 'image_name': img_name}
        # Apply the optional per-sample transform that the constructor accepts.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class VGG_net(nn.Module):
    """3-D VGG-style two-class classifier for single-channel volumes.

    Four stages of two dilated 3x3x3 convolutions, each followed by ReLU and then
    BatchNorm. The first three stages end with 2x max-pooling and Dropout3d. The
    final 72-channel maps go through pool4 and pool5, which are both
    AdaptiveAvgPool3d. The two resulting 72-d vectors are concatenated (144
    features) and fc1 maps them to 2 logits.

    Attribute names and their registration order are intentionally unchanged. They
    define the checkpoint state_dict keys and the order of ``children()``.
    """

    def __init__(self, conv_drop_rate=0.10, linear_drop_rate=0.10):
        super().__init__()
        # linear_drop_rate is accepted but not used by any layer.

        def dilated_conv(c_in, c_out):
            # kernel 3, dilation 2, padding 2, stride 1: spatial size is preserved.
            return nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                             stride=1, padding=2, dilation=2)

        # Stage 1: 1 -> 12 channels, then halve the resolution.
        self.conv1 = dilated_conv(1, 12)
        self.conv1_bn = nn.BatchNorm3d(12)
        self.conv2 = dilated_conv(12, 12)
        self.conv2_bn = nn.BatchNorm3d(12)
        self.pool1 = nn.MaxPool3d(2, 2)
        self.dropout1 = nn.Dropout3d(conv_drop_rate)

        # Stage 2: 12 -> 24 channels, then halve the resolution.
        self.conv3 = dilated_conv(12, 24)
        self.conv3_bn = nn.BatchNorm3d(24)
        self.conv4 = dilated_conv(24, 24)
        self.conv4_bn = nn.BatchNorm3d(24)
        self.pool2 = nn.MaxPool3d(2, 2)
        self.dropout2 = nn.Dropout3d(conv_drop_rate)

        # Stage 3: 24 -> 48 channels, then halve the resolution.
        self.conv5 = dilated_conv(24, 48)
        self.conv5_bn = nn.BatchNorm3d(48)
        self.conv6 = dilated_conv(48, 48)
        self.conv6_bn = nn.BatchNorm3d(48)
        self.pool3 = nn.MaxPool3d(2, 2)
        self.dropout3 = nn.Dropout3d(conv_drop_rate)

        # Stage 4: 48 -> 72 channels, then two global average-pool branches.
        self.conv7 = dilated_conv(48, 72)
        self.conv7_bn = nn.BatchNorm3d(72)
        self.conv8 = dilated_conv(72, 72)
        self.conv8_bn = nn.BatchNorm3d(72)
        self.pool4 = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout4 = nn.Dropout3d(conv_drop_rate)
        self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout5 = nn.Dropout3d(conv_drop_rate)

        # 2 x 72 pooled features -> 2 class logits.
        self.fc1 = nn.Linear(144, 2)

    @staticmethod
    def _conv_relu_bn(x, conv, bn):
        # Layer order used throughout this model: conv -> ReLU -> BatchNorm.
        return bn(F.relu(conv(x)))

    def forward(self, x):
        downsampling_stages = (
            (self.conv1, self.conv1_bn, self.conv2, self.conv2_bn, self.pool1, self.dropout1),
            (self.conv3, self.conv3_bn, self.conv4, self.conv4_bn, self.pool2, self.dropout2),
            (self.conv5, self.conv5_bn, self.conv6, self.conv6_bn, self.pool3, self.dropout3),
        )
        for conv_a, bn_a, conv_b, bn_b, pool, drop in downsampling_stages:
            x = self._conv_relu_bn(x, conv_a, bn_a)
            x = drop(pool(self._conv_relu_bn(x, conv_b, bn_b)))

        x = self._conv_relu_bn(x, self.conv7, self.conv7_bn)
        x = self._conv_relu_bn(x, self.conv8, self.conv8_bn)

        pooled_a = self.dropout4(self.pool4(x)).view(-1, 72)
        pooled_b = self.dropout5(self.pool5(x)).view(-1, 72)
        return self.fc1(torch.cat((pooled_a, pooled_b), 1))


In [7]:
def test(model, device, test_loader):
    model.eval()    
    predicted_names_labels = []
    
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(test_loader):
            inputs, inputs_names = sample_batched['image'], sample_batched['image_name']  
            inputs = inputs.to(device)
            # forward + backward + optimize
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            predicted_names_labels.append((inputs_names, predicted.cpu().numpy()))
            
    return predicted_names_labels

In [8]:
# Use the first CUDA device when available; wrap the model when several GPUs exist.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = VGG_net()
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    net = nn.DataParallel(net)
net.to(device)
print("There are {} parameters in the model".format(count_parameters(net)))

Let's use 2 GPUs!
There are 355358 parameters in the model


In [9]:
# Load the trained weights so they work with or without nn.DataParallel and on CPU-only
# machines: map tensors onto `device`, strip a DataParallel 'module.' key prefix if the
# checkpoint has one, and load into the unwrapped VGG_net.
checkpoint_state = torch.load('./model/mut_clas_2019_12_07_e150_global3.pth', map_location=device)
checkpoint_state = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in checkpoint_state.items()
}
base_net = net.module if isinstance(net, nn.DataParallel) else net
base_net.load_state_dict(checkpoint_state)

In [10]:
# Predict a class for every test volume, one volume per batch, in dataset order.
Mouse_dataset = Mouse_sub_volumes(test_data)
test_dataloader = DataLoader(
    Mouse_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
test_dic = test(net, device, test_dataloader)

In [11]:
# Fraction of volumes predicted as class 0 ('mutant'; class 1 is 'normal').
# Counts every prediction in every batch, so the result is correct for any batch size.
all_predictions = np.concatenate([batch_preds for _, batch_preds in test_dic])
normal_count = int(np.sum(all_predictions == 1))
print((len(all_predictions) - normal_count) / len(all_predictions))

0.2540415704387991


In [12]:
# Predicted class index -> human-readable class name.
mutant_name = dict(enumerate(('mutant', 'normal')))

In [13]:
save_name = 'All_data_112_64_64_longitudinal_volume_surface.pickle'
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(os.path.join(os.getcwd(), 'data', save_name), "rb") as input_file:
    mutant_label = pickle.load(input_file)
# Volume name (field 1) -> (field 3, field 4); presumably (volume, surface) per the file name — TODO confirm.
volume_surface_dic = {
    mutant_label[idx][1]: (mutant_label[idx][3], mutant_label[idx][4])
    for idx in range(len(mutant_label))
}
    

In [15]:
# Spot-check: the first image name of the third prediction batch.
third_batch_names = test_dic[2][0]
third_batch_names[0]

'20161219_En1_E12_E1a'

In [17]:
# One line per volume: "<name>: <predicted class name> <field 3> <field 4>", using the
# lookup dict built from the volume/surface pickle. Every item of every batch is written,
# and `with` closes the file even if a lookup fails.
with open(r"./predicted_mutant_volume_surface.txt", "w") as out_file:
    for batch_names, batch_preds in test_dic:
        for name, pred in zip(batch_names, batch_preds):
            volume_value, surface_value = volume_surface_dic[name]
            out_file.write(str(name) + ': ' + mutant_name[int(pred)] + " "
                           + str(volume_value) + " " + str(surface_value) + '\n')

In [None]:
# Intended: a 2x2 table of true label vs predicted label (0 = 'mutant', 1 = 'normal'),
# plus the indices of the misclassified entries.
# FIXME(review): this cell looks stale. test() returns one (image_names, predictions)
# tuple per batch, so test_dic[i][0] is a list of names, not a label. `names == 0` is
# always False, so no branch ever matches: the table stays all zeros and both lists stay
# empty. A per-volume ground-truth label source is needed; test_data does not carry one.
cross_table = np.zeros([2,2])
mut_to_nor = []   # presumably: mutant volumes predicted as normal
nor_to_mul = []   # presumably: normal volumes predicted as mutant ("mul" likely a typo of "mut")

for i in range(len(test_dic)):
    if test_dic[i][0] ==0 and test_dic[i][1] ==0:
        cross_table[0,0] += 1
    elif  test_dic[i][0] ==0 and test_dic[i][1] ==1:
        cross_table[0,1] += 1
        mut_to_nor.append(i)
    elif test_dic[i][0] ==1 and test_dic[i][1] ==0:
        cross_table[1,0] += 1
        nor_to_mul.append(i)
    elif test_dic[i][0] ==1 and test_dic[i][1] ==1:
        cross_table[1,1] += 1
print(cross_table)
print(mut_to_nor)
print(nor_to_mul)

In [None]:
# Preview rather than dumping the whole list: test_dic holds one
# (image_names, predictions) entry per test batch.
len(test_dic), test_dic[:5]

In [None]:
# Export the volumes listed in mut_to_nor as NIfTI files.
# test_data entries are (image_name, image - 0.5): the image is element [1]
# (element [0] is the name string), and +0.5 undoes the intensity shift.
for i in mut_to_nor:
    print(i)
    img_nft = nib.Nifti1Image(np.squeeze(test_data[i][1] + 0.5), np.eye(4))
    img_save_data_path = './img/mul_img{}_cam.nii'.format(i)
    nib.save(img_nft, img_save_data_path)

In [None]:
# Export the volumes listed in nor_to_mul as NIfTI files.
# test_data entries are (image_name, image - 0.5): the image is element [1]
# (element [0] is the name string), and +0.5 undoes the intensity shift.
for i in nor_to_mul:
    print(i)
    img_nft = nib.Nifti1Image(np.squeeze(test_data[i][1] + 0.5), np.eye(4))
    img_save_data_path = './img/nor_img{}_cam.nii'.format(i)
    nib.save(img_nft, img_save_data_path)

In [None]:
# Intended: export every volume whose label is 0 ('mutant') as a NIfTI file.
# FIXME(review): stale with respect to the current test_data layout, where each entry is
# (image_name, image - 0.5).
#   - test_data[i][1] is the image array, so `== 0` gives an array and the `if` raises
#     "truth value of an array ... is ambiguous" (unless the array has one element).
#   - test_data[i][0] is the name string, so `+ 0.5` would raise TypeError.
# A per-volume label source is needed before this cell can run as intended.
for i in range(len(test_data)):
    if test_data[i][1] == 0:
        img_nft = nib.Nifti1Image(np.squeeze(test_data[i][0]+0.5),np.eye(4))
        img_save_data_path = './img/mul_img{}.nii'.format(i)
        nib.save(img_nft,img_save_data_path)

In [None]:
def compute_saliency_maps(X, y, model, grad_scale=10.0):
    """
    Compute vanilla-gradient class saliency maps for inputs X and target classes y.

    Input:
    - X: Input batch (leaf tensor). With this notebook's DataLoader it is
      (N, 1, D, H, W). `X.requires_grad_()` is called on it in place.
    - y: LongTensor of shape (N,) giving, per sample, the class whose score is explained.
      Must be on the same device as the model output.
    - model: Classifier returning scores of shape (N, num_classes). It is switched to eval mode.
    - grad_scale: Constant seed gradient for the backward pass. It scales every map
      uniformly; 10.0 reproduces the previous behaviour.

    Returns:
    - saliency: |d score_y / d X| * grad_scale, with singleton dimensions squeezed
      (for N = 1 the batch and channel axes are removed).
    """
    # Make sure the model is in "test" mode (BatchNorm running stats, no dropout).
    model.eval()

    # Gradients are taken with respect to the input itself; clear any stale gradient
    # so repeated calls on the same tensor do not accumulate.
    X.requires_grad_()
    X.grad = None

    scores = model(X)

    # Score of the requested class for each sample: shape (N,).
    correct_scores = scores.gather(1, y.view(-1, 1)).squeeze(1)

    # The seed gradient must match correct_scores' shape. full_like also places it on
    # the right device, so this works on CPU and GPU and for any batch size.
    correct_scores.backward(torch.full_like(correct_scores, grad_scale))

    # Magnitude of the input gradient is the saliency.
    saliency = X.grad.detach().abs()
    return saliency.squeeze()

In [None]:
# Binary saliency maps for every test volume, saved next to the input volume.
test_Mouse_dataset = Mouse_sub_volumes(test_data)
test_dataloader = DataLoader(test_Mouse_dataset, batch_size=1, shuffle=False, num_workers=4)
net.eval()

for i_batch, sample_batched in enumerate(test_dataloader):
    inputs = sample_batched['image'].to(device)
    # Mouse_sub_volumes samples only carry 'image' and 'image_name'. Use a ground-truth
    # 'label' when one is present; otherwise explain the class the network predicts.
    if 'label' in sample_batched:
        labels = sample_batched['label'].to(device)
    else:
        with torch.no_grad():
            labels = torch.max(net(inputs), 1)[1]
    saliency = compute_saliency_maps(inputs, labels, net)

    # Binarize at 20% of the peak gradient. Build the mask from the raw values in one
    # step. Two successive in-place assignments would re-test the 1s written by the
    # first, and zero them whenever 0.2 * max_value > 1.
    max_value = torch.max(saliency)
    saliency = (saliency >= max_value * 0.2).to(saliency.dtype)

    # +0.5 undoes the intensity shift applied when test_data was built.
    img_nft = nib.Nifti1Image(np.squeeze(inputs.cpu().detach().numpy() + 0.5), np.eye(4))
    img_save_data_path = './saliency_map/img_label{}_{}.nii'.format(labels.cpu().numpy()[0], i_batch)
    nib.save(img_nft, img_save_data_path)

    saliency_nft = nib.Nifti1Image(np.squeeze(saliency.cpu().numpy()), np.eye(4))
    saliency_save_data_path = './saliency_map/salency_label{}_{}.nii'.format(labels.cpu().numpy()[0], i_batch)
    nib.save(saliency_nft, saliency_save_data_path)
        

In [None]:
# Peak value of the current saliency tensor, moved to the CPU for display.
saliency.max().cpu()

In [None]:
# Binarize the saliency map at 20% of its peak: voxels >= threshold -> 1, others -> 0.
# The mask is computed from the raw values in one step. Two successive in-place
# assignments would re-test the 1s written by the first and zero them whenever
# 0.2 * max_value > 1.
max_value = torch.max(saliency)
saliency = (saliency >= max_value * 0.2).to(saliency.dtype)

In [None]:
# net.fc1.weight.data

In [None]:
# outputs= []
# def hook(module, input, output):
#     outputs.append(output)

# net.conv8_bn.register_forward_hook(hook)
# out = net(res)
# out = net(res1)
# print(outputs)


# test_Mouse_dataset = Mouse_sub_volumes(test_data)
# test_dataloader = DataLoader(test_Mouse_dataset, batch_size=1, shuffle=False, num_workers=4)

# for i_batch, sample_batched in enumerate(test_dataloader):
#     inputs, labels = sample_batched['image'], sample_batched['label']  
#     inputs = inputs.to(device)
#     labels = labels.to(device)
#     saliency = compute_saliency_maps(inputs, labels, net)

In [None]:
# Class activation maps (CAM) for every test volume, saved next to the input volume.
# NOTE: compute_cam_maps is defined in a later cell; run that cell before this one.
base_net = net.module if isinstance(net, nn.DataParallel) else net
base_net.cpu()
base_net.eval()

fc_weight = base_net.fc1.weight.data

# Capture the conv8_bn output (the last spatial feature maps) during a normal forward
# pass. Slicing children() into an nn.Sequential does not reproduce VGG_net.forward:
#   - it drops the F.relu calls;
#   - with [:-3] it still includes pool4, so the features collapse to 1x1x1;
#   - under nn.DataParallel it yields an empty model.
_cam_features = {}


def _store_cam_features(module, module_input, module_output):
    # Forward hook: remember the most recent conv8_bn output.
    _cam_features['conv8'] = module_output


def res50_conv(x):
    """Return base_net's conv8_bn feature maps for input x (name kept from earlier code)."""
    with torch.no_grad():
        base_net(x)
    return _cam_features['conv8']


hook_handle = base_net.conv8_bn.register_forward_hook(_store_cam_features)
try:
    test_Mouse_dataset = Mouse_sub_volumes(test_data)
    test_dataloader = DataLoader(test_Mouse_dataset, batch_size=1, shuffle=False, num_workers=4)

    for i_batch, sample_batched in enumerate(test_dataloader):
        inputs = sample_batched['image']
        # Samples carry no 'label' key; fall back to the predicted class.
        if 'label' in sample_batched:
            labels = sample_batched['label']
        else:
            with torch.no_grad():
                labels = torch.max(base_net(inputs), 1)[1]
        saliency = compute_cam_maps(inputs, labels, base_net, fc_weight, res50_conv)

        # +0.5 undoes the intensity shift applied when test_data was built.
        img_nft = nib.Nifti1Image(np.squeeze(inputs.numpy() + 0.5), np.eye(4))
        img_save_data_path = './cam_map/img_label{}_{}.nii'.format(labels.numpy()[0], i_batch)
        nib.save(img_nft, img_save_data_path)

        saliency_nft = nib.Nifti1Image(np.squeeze(saliency), np.eye(4))
        saliency_save_data_path = './cam_map/salency_label{}_{}.nii'.format(labels.numpy()[0], i_batch)
        nib.save(saliency_nft, saliency_save_data_path)
finally:
    # Do not leave the hook attached to the shared model.
    hook_handle.remove()
    

In [None]:
# Display the most recent saliency peak (max_value), as set by the saliency thresholding code above.
max_value

In [None]:
# cam represent class saliency map
def compute_cam_maps(X, y, model, fc_weight, feature_extract, zoom_factor=8):
    """Compute a class activation map (CAM) for a single input volume.

    The CAM is the fc-weighted sum of the channel feature maps returned by
    `feature_extract`, upsampled with scipy's `zoom` and min-max scaled to [0, 1].

    Args:
        X: Input batch holding one volume (batch size 1).
        y: Class index to explain (int or 1-element LongTensor).
        model: The classifier; it is only switched to eval mode here.
        fc_weight: Weight matrix of the final linear layer, shape (num_classes, K).
            K may be a multiple of the feature channel count C. In VGG_net, fc1 sees
            two global-average-pooled copies of the same 72 channels (pool4 and pool5
            are both AdaptiveAvgPool3d), so each channel's CAM weight is the sum of its
            matching column in every copy.
        feature_extract: Callable mapping X to feature maps of shape (1, C, D, H, W).
        zoom_factor: Upsampling factor passed to `zoom`. The default of 8 matches the
            three 2x max-pools of VGG_net.

    Returns:
        np.ndarray: The CAM scaled to [0, 1], or all zeros if the map is constant.

    Raises:
        ValueError: If K is not a multiple of C.
    """
    model.eval()

    with torch.no_grad():
        features = feature_extract(X).squeeze(0)  # (C, D, H, W)
        channels = features.shape[0]

        class_weights = fc_weight[y].reshape(-1)
        if class_weights.numel() % channels != 0:
            raise ValueError(
                'fc_weight row has {} entries, not a multiple of {} feature channels'.format(
                    class_weights.numel(), channels))
        # Fold every pooled copy of the channels onto one weight per channel.
        channel_weights = class_weights.reshape(-1, channels).sum(dim=0)
        channel_weights = channel_weights.to(device=features.device, dtype=features.dtype)

        # Weighted sum over the channel axis -> (D, H, W).
        cam = torch.tensordot(channel_weights, features, dims=1)

    cam = zoom(cam.detach().cpu().numpy(), zoom_factor)

    cam = cam - np.min(cam)
    peak = np.max(cam)
    # Avoid 0/0 (NaN) when the map is constant.
    if peak > 0:
        cam = cam / peak

    return cam

In [None]:
# Debug output: the integers 1..95, one per line.
print(*range(1, 96), sep='\n')

In [None]:
# Debug: print the output shape of the model truncated before its last two children.
# Unwrap nn.DataParallel first. Its only child is the wrapped VGG_net, so slicing its
# children() gives an empty nn.Sequential.
base_net = net.module if isinstance(net, nn.DataParallel) else net
base_net.cpu()
base_net.eval()

fc_weight = base_net.fc1.weight.data

# All registered children except dropout5 and fc1. NOTE: as an nn.Sequential this skips
# the F.relu calls made in VGG_net.forward; it is only used to inspect shapes here.
res50_conv = nn.Sequential(*list(base_net.children())[:-2])

test_Mouse_dataset = Mouse_sub_volumes(test_data)
test_dataloader = DataLoader(test_Mouse_dataset, batch_size=1, shuffle=False, num_workers=4)

# no_grad instead of setting requires_grad=False, so the shared model's parameters
# are not permanently frozen. Samples have no 'label' key, so none is read.
with torch.no_grad():
    for i_batch, sample_batched in enumerate(test_dataloader):
        inputs = sample_batched['image']
        print(res50_conv(inputs).shape)