In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from skimage import io
import utils_xy
from torchvision import transforms, utils
import numpy as np
from PIL import Image
from random import randint
import time
import quickNATv2

torch.Size([4, 256, 128, 128])


In [2]:
# Expose only physical GPU `gpu_id` (numbered in PCI bus order) to this process.
# These variables only take effect if set before CUDA is initialised here.
gpu_id = 1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

# Fall back to CPU when no CUDA device is visible, instead of naming a device
# that any later `.to(device)` would fail on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Seed torch (all devices) and numpy so weight initialisation and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

cuda


In [3]:
# Data locations. The root can be overridden with the BMMC_ROOT environment
# variable; the default is the original author's machine.
BMMC_ROOT = os.environ.get("BMMC_ROOT", "/home/xiaoyu/BMMC")
img_dir = os.path.join(BMMC_ROOT, 'BMMCdata')
mask_dir = os.path.join(BMMC_ROOT, 'BMMCmasks')

img_list = sorted(os.listdir(img_dir))
mask_list = sorted(os.listdir(mask_dir))

print("Training images numbers: "+str(len(img_list)))
print("Training mask Images numbers:"+str(len(mask_list)))

# Images and masks are paired by sorted position, so the counts must agree.
assert len(img_list) == len(mask_list), "image/mask count mismatch"

Training images numbers: 43
Training mask Images numbers:43


In [4]:
class TrainDataset(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    Images and masks are paired by position after sorting each directory's
    file names, so both directories must hold the same number of files whose
    sorted orders correspond.

    Each sample is a dict with keys ``'image'`` and ``'mask'``:
      * ``'image'``: the image as a CHW uint8 tensor, or ``transform`` applied
        to that tensor when ``transform`` is given. The image is only resized
        if ``transform`` does so.
      * ``'mask'``: a uint8 class-index tensor, always resized to 384x384.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir (string): Directory with the training images.
            mask_dir (string): Directory with the semantic-segmentation masks,
                one per image.
            transform (callable, optional): Applied to the CHW uint8 image
                tensor of each sample.

        Raises:
            ValueError: if the two directories hold different numbers of files.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.mask_dir = mask_dir
        # List each directory once here rather than on every item fetch.
        self.img_list = sorted(os.listdir(img_dir))
        self.mask_list = sorted(os.listdir(mask_dir))
        if len(self.img_list) != len(self.mask_list):
            raise ValueError(
                "image/mask count mismatch: %d images in %r vs %d masks in %r"
                % (len(self.img_list), img_dir, len(self.mask_list), mask_dir))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # Read from this instance's directories, not from notebook globals.
        img_arr = io.imread(os.path.join(self.img_dir, self.img_list[idx]))
        # HWC -> CHW (assumes a multi-channel image on disk).
        img_tensor = torch.from_numpy(np.uint8(img_arr)).permute(2, 0, 1)

        mask_arr = io.imread(os.path.join(self.mask_dir, self.mask_list[idx]))
        mask_tensor = torch.from_numpy(mask_arr)

        # after the compose, the mask value range is [0,1], the data type is torch.float32
        compose = transforms.Compose([transforms.ToPILImage(),
                                      transforms.Resize((384, 384), interpolation=Image.NEAREST),
                                      transforms.ToTensor()])
        mask_tensor = compose(torch.unsqueeze(mask_tensor, dim=0)).squeeze()

        # Quantise grey levels to class indices, e.g. 85/255, 170/255, 255/255
        # -> 0, 1, 2. NOTE(review): `.byte() - 1` is uint8 arithmetic, so a
        # pixel that rounds to 0 wraps to 255; presumably the masks contain no
        # such pixels (NLLLoss would reject target 255) — TODO confirm.
        mask_tensor = torch.round(mask_tensor / 0.3).byte() - 1

        if self.transform:
            img_tensor = self.transform(img_tensor)
        return {'image': img_tensor, 'mask': mask_tensor}

In [None]:
# Image pipeline: uint8 CHW tensor -> PIL image -> 384x384 -> float tensor in
# [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1].
# NOTE(review): nearest-neighbour resizing is unusual for RGB input (bilinear
# is more common); kept as-is to match the mask resize.
img_steps = [
    transforms.ToPILImage(),
    transforms.Resize((384, 384), interpolation=Image.NEAREST),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
compose = transforms.Compose(img_steps)

transformed_dataset = TrainDataset(
    img_dir=img_dir,
    mask_dir=mask_dir,
    transform=compose,
)

# Shuffled mini-batches of 5; the final batch of an epoch may be smaller.
dataloader = DataLoader(
    transformed_dataset,
    num_workers=4,
    shuffle=True,
    batch_size=5,
)

In [None]:
# Build the network. NOTE(review): the meaning of (3, 64, 3) is defined in
# quickNATv2 — presumably (input channels, base filters, number of classes);
# TODO confirm.
quicknat = quickNATv2.QuickNAT(3,64,3)
model_params = list(quicknat.parameters())

# Total number of scalar parameters. Tensor.numel() also counts 0-dim
# parameters correctly, whereas np.prod of an empty size list returns the
# float 1.0 and would turn the total into a float.
nb_param = sum(param.numel() for param in model_params)
print(nb_param)


1183240


In [None]:
# quicknat = quicknat.to(device)

In [None]:
# NLLLoss consumes log-probabilities, so this presumably relies on QuickNAT's
# forward ending in log_softmax — TODO confirm in quickNATv2. Targets are the
# class-index masks, cast to long in the training loop.
criterion = nn.NLLLoss(weight=None)

In [None]:
optimizer = optim.Adam(quicknat.parameters(), lr=0.001)

# Epochs are numbered from 1; the original range(1, 500) ran 499 epochs.
FIRST_EPOCH, LAST_EPOCH = 1, 499

# Move each batch to wherever the model's parameters live, so the loop works
# unchanged whether or not `quicknat` has been moved to `device`.
model_device = next(quicknat.parameters()).device
quicknat.train()

# NOTE(review): the logged losses hover around ln(3) ~= 1.10, i.e. near-uniform
# predictions over 3 classes — the model does not appear to be learning.
start = time.time()
for epoch in range(FIRST_EPOCH, LAST_EPOCH + 1):
    running_loss = 0.0
    num_batches = 0

    for sample_batched in dataloader:
        optimizer.zero_grad()

        inputs = sample_batched['image'].to(model_device)
        labels = sample_batched['mask'].long().to(model_device)

        # forward + backward + optimize. Inputs do not need requires_grad:
        # only parameter gradients are used, and input gradients waste memory.
        scores = quicknat(inputs)
        loss = criterion(scores, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError('dataloader yielded no batches; check img_dir / mask_dir')

    # Average the per-batch losses, then report with cumulative wall time.
    total_loss = running_loss / num_batches
    elapsed = (time.time() - start) / 60
    print('epoch=', epoch, '\t time=', elapsed, 'min', '\t loss=', total_loss)

print('Finish Training')

torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([3, 3, 384, 384])
torch.Size([3, 384, 384])
epoch= 1 	 time= 4.117654740810394 min 	 loss= 1.1086272133721247
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])

torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([3, 3, 384, 384])
torch.Size([3, 384, 384])
epoch= 15 	 time= 61.74028302431107 min 	 loss= 1.1142036782370672
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([3, 3, 384, 384])
torch.Size([3, 384, 384])
epoch= 16 	 time= 65.93143110275268 min 	 loss= 1.0962734619776409
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([5, 384, 384])
torch.Size([5, 3, 384, 384])
torch.Size([