In [1]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= '0'

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Resized,
    EnsureTyped,
    ToTensord,
    Lambdad
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
# %matplotlib inline
from PIL import Image
from tqdm import tqdm 
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
import natsort
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai import data, transforms
from monai.utils import first, set_determinism
from monai.data import NumpyReader
from scipy import ndimage
import numpy as np
from torch.autograd import Variable
import gc
import GPUtil
from monai.inferers import sliding_window_inference
import pydicom
import re
import torchsummary
from torchsummary import summary
import time
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
import glob

gc.collect()
torch.cuda.empty_cache()
GPUtil.showUtilization()

device = torch.device('cuda')

print('Count of using GPUs:', torch.cuda.device_count())
print('Current cuda device:', torch.cuda.current_device())

| ID | GPU | MEM |
------------------
|  0 |  0% |  0% |
|  1 |  0% |  0% |
|  2 |  0% |  0% |
|  3 |  0% |  0% |
Count of using GPUs: 1
Current cuda device: 0


In [2]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
    
    
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc=2, output_nc=1, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()

        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        encoder_model = [nn.ReflectionPad2d(3),
                         nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                         norm_layer(ngf),
                         nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            encoder_model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                              norm_layer(ngf * mult * 2),
                              nn.ReLU(True)]

        self.encoder_model = nn.Sequential(*encoder_model)

        self.map_decoder = self.map_build_decoder(ngf, output_nc, norm_layer, use_bias, n_downsampling, n_blocks)
        self.emph_decoder = self.build_decoder(ngf, 2, norm_layer, use_bias, n_downsampling, n_blocks)

    def map_build_decoder(self, ngf, output_nc, norm_layer, use_bias, n_downsampling, n_blocks):
        decoder_model = []
        
        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks
            decoder_model += [ResnetBlock(ngf*mult, padding_type='reflect', norm_layer=norm_layer, use_dropout=False, use_bias=use_bias)]
        
        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            decoder_model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                                 kernel_size=3, stride=2,
                                                 padding=1, output_padding=1,
                                                 bias=use_bias),
                              norm_layer(int(ngf * mult / 2)),
                              nn.ReLU(True)]
            
        decoder_model += [nn.ReflectionPad2d(3)]
        decoder_model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        decoder_model += [nn.Tanh()]

        return nn.Sequential(*decoder_model)
    
    def build_decoder(self, ngf, output_nc, norm_layer, use_bias, n_downsampling, n_blocks):
        decoder_model = []
        
        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            decoder_model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                                 kernel_size=3, stride=2,
                                                 padding=1, output_padding=1,
                                                 bias=use_bias),
                              norm_layer(int(ngf * mult / 2)),
                              nn.ReLU(True)]
            
        decoder_model += [nn.ReflectionPad2d(3)]
        decoder_model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        decoder_model += [nn.Tanh()]

        return nn.Sequential(*decoder_model)
    
    def forward(self, input):
        encoded = self.encoder_model(input)
        return self.map_decoder(encoded), self.emph_decoder(encoded)

class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, img_A, img_B):
        """Standard forward."""
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
    
generator = ResnetGenerator().apply(weights_init_normal)
discriminator = NLayerDiscriminator().apply(weights_init_normal)

discriminator = nn.DataParallel(discriminator, device_ids=[0,1])
discriminator.to(device)

model = nn.DataParallel(generator, device_ids=[0,1])
model.to(device)



DataParallel(
  (module): ResnetGenerator(
    (encoder_model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(2, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
    )
    (map_decoder): Sequential(
      (0): ResnetBlock(
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (2): BatchNorm2d(256, eps=1e-05, mom

In [2]:
def list_sort_nicely(l):
    def tryint(s):
        try:
            return int(s)
        except:
            return s
    
    def alphanum_key(s):
        return [ tryint(c) for c in re.split('([0-9]+)', s)]
    l.sort(key=alphanum_key)
    return l

print("working...")

map_wi_dir = "/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_dataset/test/map/*/*.dcm"
vnc_wi_dir = "/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_dataset/test/vnc_wi/*/*.dcm"
vnc_wo_dir = "/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_dataset/test/vnc_wo/*/*.dcm"
#emph_dir = "/mnt/nas203/forGPU/leesangy/final_dataset/test/emp/*/*.dcm"

map_wis = list_sort_nicely((glob.glob(map_wi_dir)))
vnc_wis = list_sort_nicely((glob.glob(vnc_wi_dir)))
vnc_wos = list_sort_nicely((glob.glob(vnc_wo_dir)))
#emphs = list_sort_nicely((glob.glob(emph_dir)))

data_dicts = [{"map_wi": map_wi_name, "vnc_wi": vnc_wi_name, "vnc_wo": vnc_wo_name, "map_wi_name": map_wi_name} for map_wi_name, vnc_wi_name, vnc_wo_name, map_wi_name in zip(map_wis, vnc_wis, vnc_wos, map_wis)]
test_files = data_dicts

print(len(test_files))

def squeeze(image):
    image = image.squeeze()
    return image

data_transforms = Compose(
    [
        LoadImaged(keys=["map_wi", "vnc_wi", "vnc_wo"]),
        Lambdad(keys=["map_wi", "vnc_wi", "vnc_wo"], func=squeeze),
        EnsureChannelFirstd(keys=["map_wi", "vnc_wi", "vnc_wo"]),
        ScaleIntensityRanged(keys=["vnc_wi", "vnc_wo"], a_min=-1024.0, a_max=-100.0, b_min=0.0, b_max=1.0, clip=True),
        ScaleIntensityRanged(keys=["map_wi"], a_min=0.0, a_max=50.0, b_min=0.0, b_max=1.0, clip=True),
        #AddChanneld(keys=["image"]),
        #Resized(keys=['image'], spatial_size=[1,224,224]),
        Resized(keys=["map_wi", "vnc_wi", "vnc_wo"], spatial_size=[256,256]),
        EnsureTyped(keys=["map_wi", "vnc_wi", "vnc_wo"]),
        ToTensord(keys=["map_wi", "vnc_wi", "vnc_wo"])]
)

check_ds = Dataset(data=test_files, transform=data_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
vnc_wi, vnc_wo, map_wi = (check_data["vnc_wi"][0][0], check_data["vnc_wo"][0][0], check_data["map_wi"][0][0])
print(f"image shape: {vnc_wi.shape}, label shape: {vnc_wo.shape}")

working...
6000


monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.


image shape: torch.Size([256, 256]), label shape: torch.Size([256, 256])


In [5]:
#train_ds = CacheDataset(data=train_files, transform=data_transforms, cache_rate=2.0, num_workers=8)
#train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)

root_dir = "/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_inference/multitasking_G_256/"

model.load_state_dict(torch.load(os.path.join(root_dir, "epoch22_model.pth")))
model.eval()

discriminator.load_state_dict(torch.load(os.path.join(root_dir, "epoch22_D.pth")))
discriminator.eval()

test_ds = CacheDataset(data=test_files, transform=data_transforms, cache_rate=2.0, num_workers=8)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, shuffle=False)

import nibabel as nib
import cv2
from skimage import exposure
import time

print("test started...")

time_list = []

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        map_wi = batch["map_wi"].to(device)
        real_vnc_wi = batch["vnc_wi"].to(device)
        real_vnc_wo = batch["vnc_wo"].to(device)
        map_wi_name = batch["map_wi_name"][0][-16:-4]
        
        mask = np.rot90(map_wi.detach().cpu().squeeze(), k=3)
        mask[mask > 0] = 1
        mask[mask < 0] = 0
        
        #plt.imshow(mask)
        #plt.show()
        
        real_A_cat = torch.cat((real_vnc_wi, real_vnc_wo),1)
        #start_time = time.time()
        map_fake, emph_mask = model(real_A_cat)
        #time_one_slice = time.time()-start_time
        #print(time_one_slice, "sec")
        #time_list.append(time_one_slice)
        
        test_outputs= np.rot90(map_fake.detach().cpu().numpy().squeeze(),k=3)
        test_outputs = test_outputs * mask
        test_outputs[test_outputs <0] =1e-5
        test_outputs = test_outputs * mask
        image = nib.Nifti1Image(np.rot90(test_outputs, k=3), affine=np.eye(4))
        
        nib.save(image, f"/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_inference/multitasking_G_256/test/nifti/{map_wi_name}.nii.gz")

        if i % 199 == 0:
            print(i)

In [14]:
print(np.sum(time_list))

0.8347554206848145


In [16]:
print(len(time_list))

200
