In [1]:
import os

# Expose only the first GPU (in PCI-bus order) to this process. These variables
# must be set before CUDA is initialised, hence before the torch/monai imports below.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": '0'})

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Resized,
    EnsureTyped,
    ToTensord,
    Lambdad
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
# %matplotlib inline
from PIL import Image
from tqdm import tqdm 
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
import natsort
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai import data, transforms
from monai.utils import first, set_determinism
from monai.data import NumpyReader
from scipy import ndimage
import numpy as np
from torch.autograd import Variable
import gc
import GPUtil
from monai.inferers import sliding_window_inference
import pydicom
import re
import torchsummary
from torchsummary import summary
import time
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
import glob

# Release memory left over from a previous run in this kernel, then report per-GPU load.
gc.collect()
torch.cuda.empty_cache()
GPUtil.showUtilization()

# With CUDA_VISIBLE_DEVICES='0' set above, 'cuda' resolves to the single visible GPU.
device = torch.device("cuda")

gpu_count = torch.cuda.device_count()
print('Count of using GPUs:', gpu_count)
current_gpu = torch.cuda.current_device()
print('Current cuda device:', current_gpu)

| ID | GPU | MEM |
------------------
|  0 |  0% |  4% |
|  1 |  0% |  0% |
|  2 |  0% |  0% |
|  3 |  0% |  0% |
Count of using GPUs: 1
Current cuda device: 0


In [2]:
def weights_init_normal(m):
    """Re-initialise one layer's weights in place; meant to be passed to ``Module.apply``.

    - Any layer whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0.0, 0.02). The bias is left untouched.
    - Any layer whose class name contains "BatchNorm2d":
      weight ~ N(1.0, 0.02) and bias = 0.
    - Every other layer is left unchanged.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# Generator / discriminator definitions (the generator is ResNet-based, not a U-Net)

class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: reflect-padded 7x7 stem -> two stride-2 downsampling convs ->
    ``n_blocks`` residual blocks at 4*ngf channels -> two stride-2 transposed-conv
    upsampling stages -> reflect-padded 7x7 output conv -> Tanh (outputs in [-1, 1]).
    Adapted from Justin Johnson's fast-neural-style
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc=2, output_nc=1, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Build the generator.

        Parameters:
            input_nc (int)      -- channels of the input tensor
            output_nc (int)     -- channels of the output tensor
            ngf (int)           -- filter count of the first (and last) conv stage
            norm_layer          -- normalization layer class, or a functools.partial of one
            use_dropout (bool)  -- insert Dropout(0.5) inside each residual block
            n_blocks (int)      -- number of residual blocks (must be >= 0)
            padding_type (str)  -- residual-block padding: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()

        # Conv biases are only kept when the norm layer is InstanceNorm2d;
        # BatchNorm2d has its own affine shift, which makes the bias redundant.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        n_down = 2
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Encoder: every stage halves H/W and doubles the channel count.
        channels = ngf
        for _ in range(n_down):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Bottleneck: shape-preserving residual blocks.
        layers.extend(
            ResnetBlock(channels, padding_type=padding_type, norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=use_bias)
            for _ in range(n_blocks)
        )

        # Decoder: every stage doubles H/W and halves the channel count.
        for _ in range(n_down):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=use_bias),
                norm_layer(channels // 2),
                nn.ReLU(True),
            ])
            channels //= 2

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        # The attribute name and layer order define the state_dict keys
        # ("model.<idx>.*") that saved checkpoints rely on.
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Map ``input`` (input_nc channels) to an output_nc-channel image."""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F = [pad] conv3x3 -> norm -> ReLU -> [Dropout(0.5)] -> [pad] conv3x3 -> norm.
    Channel count and spatial size are preserved for every supported padding type.
    Original ResNet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Store the residual branch produced by ``build_conv_block`` as ``self.conv_block``."""
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the residual branch F as an ``nn.Sequential``.

        Parameters:
            dim (int)           -- channels in and out of every conv
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- normalization layer class (or partial)
            use_dropout (bool)  -- insert Dropout(0.5) after the first ReLU
            use_bias (bool)     -- whether the convs carry a bias
        Raises:
            NotImplementedError -- for an unknown ``padding_type``
        """
        def padding_for_conv():
            # Returns (explicit padding layers, padding argument for the conv).
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = []
        for stage in range(2):
            pad_layers, conv_pad = padding_for_conv()
            layers += pad_layers
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias), norm_layer(dim)]
            # Only the first conv is followed by a non-linearity (and optional dropout);
            # the second feeds straight into the skip-connection sum.
            if stage == 0:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + conv_block(x)`` (identity skip connection)."""
        residual = self.conv_block(x)
        return x + residual

class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    Layout: a stride-2 input conv + LeakyReLU, (n_layers - 1) stride-2
    conv/norm/LeakyReLU stages, one stride-1 conv/norm/LeakyReLU stage, and a
    final stride-1 conv producing a 1-channel map of per-patch scores.
    Channel width doubles per stage, capped at 8 * ndf.
    """

    def __init__(self, input_nc=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Build the discriminator.

        Parameters:
            input_nc (int)  -- channels of the input tensor
            ndf (int)       -- filter count of the first conv
            n_layers (int)  -- depth; the network has max(n_layers, 1) stride-2 convs
            norm_layer      -- normalization layer class, or a functools.partial of one
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so the conv bias is redundant; it is only
        # kept when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kernel, pad = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages for n = 1 .. n_layers-1, then one stride-1 stage at
        # n = n_layers (that last stage is always present, even for n_layers <= 1).
        in_mult = 1
        for n in [*range(1, n_layers), n_layers]:
            out_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kernel,
                          stride=2 if n < n_layers else 1, padding=pad, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ]
            in_mult = out_mult

        # Output 1-channel prediction map.
        layers.append(nn.Conv2d(ndf * in_mult, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A):
        """Return the per-patch score map for ``img_A``."""
        return self.model(img_A)
    
# Build both networks (generator first, then discriminator), initialise them with
# weights_init_normal, wrap each in DataParallel on GPU 0 and move it to `device`.
generator = ResnetGenerator().apply(weights_init_normal)
discriminator = NLayerDiscriminator().apply(weights_init_normal)

# Module.to() moves parameters in place and returns the same module object.
discriminator = nn.DataParallel(discriminator, device_ids=[0]).to(device)
model = nn.DataParallel(generator, device_ids=[0]).to(device)
model

DataParallel(
  (module): ResnetGenerator(
    (model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(2, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): ResnetBlock(
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [4]:
def list_sort_nicely(l):
    """Sort a list of strings in natural ("human") order, in place.

    Runs of ASCII digits compare numerically, so "img2.dcm" sorts before
    "img10.dcm"; everything else compares as ordinary text.

    Parameters:
        l (list[str]) -- list to sort; it is mutated in place.

    Returns:
        The same list object, now sorted (returned so calls can be chained).
    """
    def alphanum_key(s):
        # re.split with a capturing group alternates text / digit-run pieces and
        # always starts with a (possibly empty) text piece, so the odd positions
        # are exactly the ASCII digit runs. Converting by position (instead of
        # try/except around int()) keeps each key position one type: int() would
        # also accept non-ASCII digits such as '\u0663' in text pieces, mixing int
        # and str at the same position and making the sort raise TypeError.
        return [int(piece) if idx % 2 else piece
                for idx, piece in enumerate(re.split('([0-9]+)', s))]

    l.sort(key=alphanum_key)
    return l

print("working...")

# Root of the paired DICOM dataset. Each split folder (train/, test/) holds the
# modality sub-folders map/, vnc_wi/ and vnc_wo/, each with one level of
# sub-directories containing *.dcm slices.
DATASET_ROOT = "/mnt/nas203/forGPU/leesangy/Task_VQ_mismatch/final_dataset"


def build_file_dicts(split, root=DATASET_ROOT):
    """Pair the map / vnc_wi / vnc_wo DICOM slices of one dataset split.

    Slices are matched purely by their natural-sorted position within each
    modality folder, so the three folders must hold the same number of files
    in corresponding order.

    Parameters:
        split (str) -- split folder name under ``root`` (e.g. "train", "test")
        root (str)  -- dataset root directory

    Returns:
        list of dicts with keys "map_wi", "vnc_wi", "vnc_wo" (file paths) and
        "map_wi_name" (the map path again; the transforms only load the three
        image keys, so this copy keeps the file name available after loading).

    Raises:
        ValueError: if the modality folders hold different numbers of files.
            zip() would otherwise silently drop the surplus and misalign pairs.
    """
    map_wis = list_sort_nicely(glob.glob(os.path.join(root, split, "map", "*", "*.dcm")))
    vnc_wis = list_sort_nicely(glob.glob(os.path.join(root, split, "vnc_wi", "*", "*.dcm")))
    vnc_wos = list_sort_nicely(glob.glob(os.path.join(root, split, "vnc_wo", "*", "*.dcm")))
    if not (len(map_wis) == len(vnc_wis) == len(vnc_wos)):
        raise ValueError(
            f"{split}: file counts differ (map={len(map_wis)}, "
            f"vnc_wi={len(vnc_wis)}, vnc_wo={len(vnc_wos)})"
        )
    return [
        {"map_wi": map_wi, "vnc_wi": vnc_wi, "vnc_wo": vnc_wo, "map_wi_name": map_wi}
        for map_wi, vnc_wi, vnc_wo in zip(map_wis, vnc_wis, vnc_wos)
    ]


train_files = build_file_dicts("train")
test_files = build_file_dicts("test")

print(len(train_files))
print(len(test_files))

def squeeze(image):
    """Return ``image`` with every singleton axis removed (delegates to ``image.squeeze()``)."""
    return image.squeeze()

# Keys of the three paired images carried through the pipeline.
IMAGE_KEYS = ["map_wi", "vnc_wi", "vnc_wo"]
VNC_KEYS = ["vnc_wi", "vnc_wo"]

data_transforms = Compose([
    # Read each slice, drop singleton axes, then make sure a leading channel axis exists.
    LoadImaged(keys=IMAGE_KEYS),
    Lambdad(keys=IMAGE_KEYS, func=squeeze),
    EnsureChannelFirstd(keys=IMAGE_KEYS),
    # Window the VNC inputs to [-1024, -100] (presumably HU) and rescale to [0, 1], clipping.
    ScaleIntensityRanged(keys=VNC_KEYS, a_min=-1024.0, a_max=-100.0, b_min=0.0, b_max=1.0, clip=True),
    # Rescale the target map from [0, 50] to [0, 1], clipping.
    ScaleIntensityRanged(keys=["map_wi"], a_min=0.0, a_max=50.0, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=IMAGE_KEYS, spatial_size=[256, 256]),
    EnsureTyped(keys=IMAGE_KEYS),
    ToTensord(keys=IMAGE_KEYS),
])

# Sanity check: push the first test sample through the transforms and show its shape.
check_loader = DataLoader(Dataset(data=test_files, transform=data_transforms), batch_size=1)
check_ds = check_loader.dataset
check_data = first(check_loader)

# [0][0] strips the batch and channel axes.
vnc_wi = check_data["vnc_wi"][0][0]
vnc_wo = check_data["vnc_wo"][0][0]
map_wi = check_data["map_wi"][0][0]
print(f"image shape: {vnc_wi.shape}, label shape: {vnc_wo.shape}")


working...
42600
6000
image shape: torch.Size([256, 256]), label shape: torch.Size([256, 256])


In [6]:
# Cache the whole test split in memory. cache_rate is a fraction in [0, 1]; the
# previous value 2.0 only worked because values above 1.0 are capped to the full dataset.
test_ds = CacheDataset(data=test_files, transform=data_transforms, cache_rate=1.0, num_workers=8)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, shuffle=False)

root_dir = "/mnt/nas203/forGPU/leesangy/final_inference/vanilla_gan/"
checkpoint_epoch = 22

# map_location makes loading independent of the GPU id the checkpoint was saved
# from (only GPU 0 is visible in this process).
model.load_state_dict(torch.load(os.path.join(root_dir, f"epoch{checkpoint_epoch}_model.pth"), map_location=device))
model.eval()

# The discriminator is not used by the inference loop in this cell; it is
# restored only so the in-memory models match the checkpoint.
discriminator.load_state_dict(torch.load(os.path.join(root_dir, f"epoch{checkpoint_epoch}_D.pth"), map_location=device))
discriminator.eval()

start_time = time.time()
import nibabel as nib
import cv2
from skimage import exposure

# NIfTI predictions are written here. Create the directory up front so nib.save
# does not fail on a fresh output location.
output_dir = os.path.join(root_dir, "test22", "nifti")
os.makedirs(output_dir, exist_ok=True)


def mask_and_orient_prediction(pred, reference_map, floor=1e-5):
    """Mask a generator output with the support of a reference map and orient it for export.

    Parameters:
        pred (array-like)           -- generator output for one slice; singleton axes are squeezed
        reference_map (array-like)  -- reference map for the same slice; voxels > 0 form the mask
        floor (float)               -- value written over negative predictions inside the mask

    Returns:
        np.ndarray: the masked prediction rotated by 180 degrees relative to ``pred``
        (two successive ``np.rot90(k=3)`` calls). Assumes batch_size=1, so the
        squeezed inputs are 2-D (H, W) — TODO confirm if the batch size changes.
    """
    pred = np.squeeze(np.asarray(pred))
    ref = np.squeeze(np.asarray(reference_map))
    mask = np.rot90(ref, k=3) > 0
    # Outside the mask the product is (+/-)0; only in-mask negatives remain < 0.
    masked = np.rot90(pred, k=3) * mask
    masked[masked < 0] = floor
    return np.rot90(masked, k=3)


print("test started...")

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        map_wi = batch["map_wi"].to(device)
        real_vnc_wi = batch["vnc_wi"].to(device)
        real_vnc_wo = batch["vnc_wo"].to(device)
        # Characters [-16:-4] of the source path: the 12 characters before a
        # 4-character extension (".dcm"). Assumes fixed-width file names — verify.
        map_wi_name = batch["map_wi_name"][0][-16:-4]

        # The generator takes both VNC images stacked on the channel axis.
        real_A_cat = torch.cat((real_vnc_wi, real_vnc_wo), 1)
        map_fake = model(real_A_cat)

        test_outputs = mask_and_orient_prediction(
            map_fake.detach().cpu().numpy(), map_wi.detach().cpu().numpy()
        )
        image = nib.Nifti1Image(test_outputs, affine=np.eye(4))
        nib.save(image, os.path.join(output_dir, f"{map_wi_name}.nii.gz"))

        if i % 100 == 0:
            print(i, (time.time()-start_time)/60, "min")