In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from Deeper_CDCGAN import Generator, Discriminator
from tqdm import tqdm

In [2]:
lung_CT = np.load('D:\DATA\VESSEL_DATA\lung_CT.npy')
lung_masks = np.load('D:\DATA\VESSEL_DATA\lung_masks.npy')

In [3]:
class Lung_CT_DATA(Dataset):
    def __init__(self, lungs_np, masks_np, split_type, transforms = None):
        self.lungs_np = lungs_np
        self.masks_np = masks_np
        self.MAX_VAL = 2**15-1
        self.MIN_VAL = -4579
        self.val_split = 0.1
        self.val_num = int(self.val_split*self.lungs_np.shape[0])
        self.split_type = split_type
        self.transforms = transforms
        
        if self.split_type == 'val': 
            self.val_lungs = torch.Tensor([i for i in tqdm(self.lungs_np[:self.val_num])]).view(-1, 1, 512, 512)
            self.val_lungs -= self.MIN_VAL
            self.val_lungs /= (self.MAX_VAL-self.MIN_VAL)
            self.val_masks = torch.Tensor([i for i in tqdm(self.masks_np[:self.val_num])]).view(-1, 1, 512, 512)
        elif self.split_type == 'test':
            self.test_lungs = torch.Tensor([i for i in tqdm(self.lungs_np[self.val_num:self.val_num*2])]).view(-1, 1, 512, 512)
            self.test_lungs -= self.MIN_VAL
            self.test_lungs /= (self.MAX_VAL-self.MIN_VAL)
            self.test_masks = torch.Tensor([i for i in tqdm(self.masks_np[self.val_num:self.val_num*2])]).view(-1, 1, 512, 512)
        elif self.split_type == 'train':
            self.train_lungs = torch.Tensor([i for i in tqdm(self.lungs_np[self.val_num*2:])]).view(-1, 1, 512, 512)
            self.train_lungs -= self.MIN_VAL
            self.train_lungs /= (self.MAX_VAL-self.MIN_VAL)
            self.train_masks = torch.Tensor([i for i in tqdm(self.masks_np[self.val_num*2:])]).view(-1, 1, 512, 512)
            
    def __len__(self):
        if self.split_type == 'train':
            return self.train_lungs.shape[0]
        elif self.split_type == 'val':
            return self.val_lungs.shape[0]
        elif self.split_type == 'test':
            return self.test_lungs.shape[0]
        
    def __getitem__(self, idx):
        if self.split_type == 'train':
            return self.train_lungs[idx], self.train_masks[idx]
        elif self.split_type == 'val':
            return self.val_lungs[idx], self.val_masks[idx]
        elif self.split_type == 'test':
            return self.test_lungs[idx], self.test_masks[idx]

In [4]:
train = Lung_CT_DATA(lung_CT, lung_masks, split_type = 'train', transforms = None)
val = Lung_CT_DATA(lung_CT, lung_masks, split_type = 'val', transforms = None)
test = Lung_CT_DATA(lung_CT, lung_masks, split_type = 'test', transforms = None)

100%|██████████| 6875/6875 [00:00<00:00, 2296761.45it/s]
100%|██████████| 6875/6875 [00:00<00:00, 1723600.72it/s]
100%|██████████| 859/859 [00:00<?, ?it/s]
100%|██████████| 859/859 [00:00<00:00, 861321.33it/s]
100%|██████████| 859/859 [00:00<00:00, 862352.11it/s]
100%|██████████| 859/859 [00:00<00:00, 884366.01it/s]


In [5]:
BATCH_SIZE = 1
train_loader = DataLoader(train, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test, batch_size = BATCH_SIZE, shuffle = True)

In [6]:
device = torch.device('cpu')

In [7]:
EPOCHS = 1
IN_CHANNELS = 1
net_G = Generator(IN_CHANNELS).to(device)
net_D = Discriminator(IN_CHANNELS).to(device)
optimizer_G = torch.optim.Adam(net_G.parameters())
optimizer_D = torch.optim.Adam(net_D.parameters())
criterion = nn.BCELoss()
print("Starting Training...")
losses_G_train = []
losses_D_train = []
losses_G_val = []
losses_D_val = []
for epoch in range(1, EPOCHS+1):
    index = -1
    for idx, (imgs, masks) in enumerate(train_loader):
        index = idx
        net_G.train()
        net_D.train()
        imgs, masks= imgs.to(device), masks.to(device)
        
        #Discriminator
        net_D.zero_grad()
        output = net_D(masks).view(-1)
        labels = torch.ones(output.shape[0]).reshape(output.shape[0],).to(device)
        lossD_real = criterion(output, labels)
        lossD_real.backward()

        gen_out = net_G(imgs)
        output = net_D(gen_out.detach()).view(-1)
        labels = torch.zeros(output.shape[0]).reshape(output.shape[0],).to(device)
        
        lossD_fake = criterion(output, labels)
        lossD_fake.backward()
        loss_D = lossD_fake + lossD_real
        optimizer_D.step()
        
        #Generator
        net_G.zero_grad()
        output = net_D(gen_out).view(-1)
        labels = torch.ones(output.shape[0]).reshape(output.shape[0],).to(device)
        loss_G = criterion(output, labels)
        loss_G.backward()
        optimizer_G.step()
    with torch.no_grad():
        net_D.eval()
        net_G.eval()
        for (im_val, mask_val) in val_loader:
            #Discriminator
            output_val = net_D(mask_val).view(-1)
            labels_val = torch.ones(output_val.shape[0]).reshape(output_val.shape[0],).to(device)
            lossD_real_val = criterion(output_val, labels_val)

            gen_out_val = net_G(im_val)
            output_val = net_D(gen_out_val.detach()).view(-1)
            labels_val = torch.zeros(output_val.shape[0]).reshape(output_val.shape[0],).to(device)

            lossD_fake_val = criterion(output_val, labels_val)
            loss_D_val = lossD_fake_val + lossD_real_val
            #Generator
            output_val = net_D(gen_out_val).view(-1)
            labels_val = torch.ones(output_val.shape[0]).reshape(output_val.shape[0],).to(device)
            loss_G_val = criterion(output_val, labels_val)

    if index % 100 == 0:
        losses_G_train.append(loss_G.item())
        losses_D_train.append(loss_D.item())
        losses_G_val.append(loss_G_val.item())
        losses_D_val.append(loss_D_val.item())
        print("[%s/%s] [%s/%s]\tLoss D Train: %s\tLoss G Train: %s\tLoss D Val: %s\tLoss G Val: %s" 
              % (epoch, EPOCHS, idx, len(train_loader), loss_G.item(), loss_D.item(), loss_D_val.item(), loss_G_val.item()))
           

Starting Training...


KeyboardInterrupt: 

In [None]:
img, mask = 0, 0
for im, ms in val_loader:
    img = im
    mask = ms
    break

In [None]:
plt.imshow(img[0].view(512, 512), cmap = 'gray')

In [None]:
out_gen = net_G(img[0].view(1,1, 512, 512))
plt.imshow(out_gen.detach().numpy().reshape(512, 512), cmap = 'gray')

In [None]:
plt.imshow(mask[0].view(512, 512), cmap = 'gray')