In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as functional
from torch import optim

In [2]:
def conv3x3(in_channels: int, out_channels: int, padding=0):
	"""Build a 3x3, stride-1 ``nn.Conv2d`` (with bias) from ``in_channels`` to ``out_channels``.

	``padding=1`` keeps the spatial size unchanged; ``padding=0`` shrinks it by 2.
	"""
	kernel = 3
	return nn.Conv2d(in_channels, out_channels, kernel, padding=padding)


def max_pool_2d():
	"""Build a 2x2 max-pool with stride 2 (halves height and width, flooring)."""
	window = 2
	return nn.MaxPool2d(window, window)

# Network building blocks

In [3]:
class UnetEncodeLayer(nn.Module):
	"""Standard encoder unit: 3x3 conv -> BatchNorm [-> ReLU] [-> 2x2 max-pool].

	Args:
		in_channels: input feature channels.
		out_channels: output feature channels.
		activated: append a ReLU after the BatchNorm.
		max_pool: append a 2x2 / stride-2 max-pool at the end.
		padding: padding of the 3x3 convolution (1 keeps H and W unchanged).
	"""
	def __init__(self, in_channels: int, out_channels: int, activated=True,max_pool=False, padding=0):
		super().__init__()
		# Sub-module order (conv, bn, relu, pool) fixes the state_dict keys
		# (layer.0, layer.1, ...), so it must not be reordered.
		stages = [conv3x3(in_channels, out_channels, padding=padding)]
		stages.append(nn.BatchNorm2d(out_channels))
		if activated:
			stages.append(nn.ReLU())
		if max_pool:
			stages.append(max_pool_2d())
		self.layer = nn.Sequential(*stages)

	def forward(self, x):
		"""Apply the stacked stages to ``x``."""
		return self.layer(x)
	
class UnetUpscaleLayer(nn.Module):
	"""Decoder up-step: bilinear upsampling, then a padded 3x3 conv that halves the channels.

	Args:
		scale_factor: spatial upsampling factor, applied to both height and width.
		in_channels: input channels; the output has ``in_channels // 2`` channels.
	"""
	def __init__(self, scale_factor, in_channels):
		super().__init__()
		upsample = nn.Upsample(scale_factor=(scale_factor, scale_factor), mode='bilinear')
		# padding=1 so the conv keeps the upsampled H x W.
		halve_channels = conv3x3(in_channels, in_channels // 2, padding=1)
		self.layer = nn.Sequential(upsample, halve_channels)

	def forward(self, x):
		"""Upsample ``x`` and reduce its channel count."""
		return self.layer(x)

class UnetForwardDecodeLayer(nn.Module):
	"""Decoder unit: two (3x3 conv -> ReLU -> BatchNorm) stages.

	The first conv maps ``in_channels`` -> ``out_channels`` and the second keeps
	``out_channels``. Note that the order here is Conv -> ReLU -> BN, unlike
	``UnetEncodeLayer`` (Conv -> BN -> ReLU). It is kept as-is so that existing
	checkpoints (state_dict indices) stay valid.
	"""
	def __init__(self, in_channels, out_channels, padding=0):
		super().__init__()
		stages = []
		for stage_in in (in_channels, out_channels):
			stages += [
				conv3x3(in_channels=stage_in, out_channels=out_channels, padding=padding),
				nn.ReLU(),
				nn.BatchNorm2d(out_channels),
			]
		self.layer = nn.Sequential(*stages)

	def forward(self, x):
		"""Run both conv/ReLU/BN stages on ``x``."""
		return self.layer(x)

# Network architecture (U-Net)

In [4]:
class Urnet(nn.Module):
	"""U-Net style encoder/decoder producing per-pixel class logits.

	The encoder has five stages (64 -> 128 -> 256 -> 512 -> 1024 channels),
	separated by four 2x2 max-pools. The decoder upsamples four times and
	concatenates the matching encoder activation (skip connection) at each level.
	A final 1x1 conv yields ``num_classes`` x H x W raw logits (no softmax),
	as expected by ``nn.CrossEntropyLoss``.

	Args:
		in_channels: channels of the input image (default 3).
		num_classes: number of output logit channels / classes (default 6).

	Notes:
		* The third max-pool uses ``padding=1`` (a 300x300 input gives 150 -> 76 there).
		  Before every concatenation, the upsampled tensor is therefore center-cropped
		  (or zero-padded) to the spatial size of its skip tensor. For 300x300 inputs,
		  this reproduces the former fixed ``center_crop(..., 150)`` exactly.
		* Encoder activations are plain locals. They are not stored on the module,
		  so they are freed after each forward pass and not pickled by ``torch.save(net)``.
		* Module names, construction order and concatenation order are unchanged,
		  so existing state dicts still load.
	"""
	def __init__(self, in_channels: int = 3, num_classes: int = 6):
		super(Urnet, self).__init__()
		# encoding part of the Unet vanilla architecture
		self.encode1 = nn.Sequential(
			UnetEncodeLayer(in_channels, 64, padding=1),
			UnetEncodeLayer(64, 64, padding=1), ## keep dimensions unchanged
		)
		self.encode2 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(64, 128, padding=1),
			UnetEncodeLayer(128, 128, padding=1),
		)
		self.encode3 = nn.Sequential(
			# padding=1: odd output size at this level (e.g. 150 -> 76)
			nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
			UnetEncodeLayer(128, 256, padding=1),
			UnetEncodeLayer(256, 256, padding=1),
		)
		self.encode4 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(256, 512, padding=1),
			UnetEncodeLayer(512, 512, padding=1),
		)
		self.encode5 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(512, 1024, padding=1),
			UnetEncodeLayer(1024, 1024, padding=1),
		)
		# decoding part: each upscale halves the channels, each decode_forward
		# consumes the concatenation with the skip tensor.
		self.upscale1 = nn.Sequential(
			UnetUpscaleLayer(2, 1024)
		)
		self.decode_forward1 = nn.Sequential(
			UnetForwardDecodeLayer(1024, 512, padding=1)
		)
		self.upscale2 = nn.Sequential(
			UnetUpscaleLayer(2, 512)
		)
		self.decode_forward2 = nn.Sequential(
			UnetForwardDecodeLayer(512, 256, padding=1)
		)
		self.upscale3 = nn.Sequential(
			UnetUpscaleLayer(2, 256)
		)
		self.decode_forward3 = nn.Sequential(
			UnetForwardDecodeLayer(256, 128, padding=1)
		)
		self.upscale4 = nn.Sequential(
			UnetUpscaleLayer(2, 128)
		)
		self.decode_forward4 = nn.Sequential(
			UnetForwardDecodeLayer(128, 64, padding=1),
			# final 1x1 conv: one logit per class for each pixel
			nn.Conv2d(64, num_classes, kernel_size=1)
		)

	@staticmethod
	def _match_size(y: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
		"""Center-crop (or zero-pad) ``y`` to ``skip``'s H x W; returns ``y`` unchanged if already equal."""
		if y.shape[-2:] == skip.shape[-2:]:
			return y
		return functional.center_crop(y, [int(s) for s in skip.shape[-2:]])

	def forward(self, x: torch.Tensor):
		"""Return N x num_classes x H x W logits for an N x in_channels x H x W batch."""
		x1 = self.encode1(x)
		x2 = self.encode2(x1)
		x3 = self.encode3(x2)
		x4 = self.encode4(x3)
		x5 = self.encode5(x4)

		y1 = self._match_size(self.upscale1(x5), x4)
		y2 = self.decode_forward1(torch.concat((x4, y1), 1))

		y2 = self._match_size(self.upscale2(y2), x3)
		y3 = self.decode_forward2(torch.concat((x3, y2), 1))

		y3 = self._match_size(self.upscale3(y3), x2)
		# At this level, the upsampled tensor comes first in the concatenation (unlike
		# the other levels). This is kept so that trained weights remain valid.
		y4 = self.decode_forward3(torch.concat((y3, x2), 1))

		y4 = self._match_size(self.upscale4(y4), x1)
		segmap = self.decode_forward4(torch.concat((x1, y4), 1))
		return segmap
		


In [5]:
class Converter:
	"""Translate between RGB colour-coded label masks and integer class-index masks.

	``color_to_label`` maps a pure RGB colour, with channel values 0 or 1 (i.e. a
	``ToTensor``-scaled 0/255 colour), to a class index in ``[0, 6)``.
	"""
	def __init__(self):
		self.color_to_label = {
			(1, 1, 0): 0,  # Yellow (cars)
			(0, 1, 0): 1,  # Green (trees)
			(0, 0, 1): 2,  # Blue (buildings)
			(1, 0, 0): 3,  # Red (clutter)
			(1, 1, 1): 4,  # White (impervious surface)
			(0, 1, 1): 5,  # Aqua (low vegetation)
		}

	def iconvert(self, mask):
		"""
		Convert a class-label mask (as used by the CrossEntropy loss) back
		to an RGB mask, e.g. for plotting.
		input: class label mask, HxW (on any device)
		output: original mask, HxWx3, float64, on the CPU.
		Pixels whose value matches no label are left white (1, 1, 1).
		"""
		H, W = mask.shape
		output = torch.ones(H, W, 3, dtype=torch.float64)
		for color, label in self.color_to_label.items():
			# The boolean match must live on the CPU to index the CPU output tensor.
			match = torch.as_tensor(mask == label, device="cpu")
			output[match] = torch.tensor(color, dtype=torch.float64)
		return output

	def convert(self, mask):
		"""
		Convert a single 3xHxW RGB label mask (e.g. the output of ``ToTensor``)
		into the HxW 'class label mask' needed when computing the loss.
		In this representation, each pixel holds a value in [0, C), where C is
		the number of classes (6 here).
		Returns a torch.long tensor of shape HxW on ``mask``'s device.
		Colours are matched exactly, so channel values must be exactly 0 or 1.
		NOTE(review): pixels whose colour matches no entry silently keep
		label 0 (cars); consider an ignore_index if such pixels occur.
		"""
		_, H, W = mask.shape
		pixels = mask.permute(1, 2, 0).reshape(-1, 3)
		class_label_mask = torch.zeros(pixels.shape[0], dtype=torch.long, device=mask.device)
		for color, label in self.color_to_label.items():
			# Build the reference colour on the mask's device so GPU masks also work.
			reference = torch.tensor(color, dtype=torch.float64, device=mask.device)
			match = (pixels == reference).all(dim=1)
			class_label_mask[match] = label
		return class_label_mask.reshape(H, W)

# Dataset Class

In [6]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import torchvision.transforms as v2
from torch.utils.data import ConcatDataset



class PostDamDataset(Dataset):
	"""In-memory Potsdam segmentation dataset.

	Every image/label pair is read from disk in ``__init__``, converted to tensors
	and cached in ``self.data``, so ``__getitem__`` is a plain lookup. High RAM
	usage is expected.

	Files must be named ``Image_{i}.tif`` (in ``img_dir``) and ``Label_{i}.tif``
	(in ``masks_dir``) for ``i`` in ``0 .. n-1``, where ``n`` is the number of
	regular files in ``img_dir``.

	Args:
		img_dir: directory containing the RGB image tiles.
		masks_dir: directory containing the RGB colour-coded label tiles.
		transforms: optional callable applied once, at load time, to the image
			tensor only (never to the mask).

	Items are ``(image, class_label_mask, idx)``. ``class_label_mask`` is the
	HxW long tensor produced by ``Converter.convert``.
	"""
	def __init__(self, img_dir, masks_dir,transforms=None):
		self.idir = img_dir
		self.mdir = masks_dir
		self.data = {} # index : (image, mask, index)
		self.transforms = transforms
		self.items = os.listdir(self.idir)
		self.files = [item for item in self.items if os.path.isfile(os.path.join(self.idir, item))]
		self.c = Converter()
		to_tensor = ToTensor()
		for idx in tqdm(range(len(self.files)), desc='Loading dataset...'):
			img_path = os.path.join(self.idir, "Image_{}.tif".format(idx))
			mask_path = os.path.join(self.mdir, "Label_{}.tif".format(idx))
			# Close the file handles explicitly; Pillow does not necessarily close
			# them after decoding (notably for multi-frame-capable formats like TIFF).
			with Image.open(img_path) as tif_img:
				final_image = to_tensor(tif_img)
			with Image.open(mask_path) as tif_mask:
				class_label_mask = self.c.convert(to_tensor(tif_mask))
			if self.transforms:  # if transforms are provided, apply them to the image only
				# NOTE(review): spatial transforms (e.g. ElasticTransform) move pixels
				# in the image but not in the mask, so image and label become misaligned.
				final_image = self.transforms(final_image)
			self.data[idx] = (final_image, class_label_mask, idx)

	def __len__(self):
		return len(self.files)

	def __getitem__(self, idx):
		return self.data[idx]

# Image-only augmentation pipeline used for the "augmented" copy of the dataset.
# NOTE: `v2` is an alias of the classic `torchvision.transforms` module (see the
# import above), not `torchvision.transforms.v2`.
# NOTE(review): ElasticTransform is a spatial warp, but PostDamDataset applies
# these transforms to images only, so augmented labels are not warped to match.
transforms = v2.Compose(
    [
        v2.GaussianBlur(kernel_size=15, sigma=5),  # 15x15 kernel
        v2.ElasticTransform(alpha=200.0),
    ]
)


# Dataset preparation

In [7]:
#dataset = PostDamDataset("/content/drive/MyDrive/Postdam/Images", "/content/drive/MyDrive/Postdam/Labels")
# Here you need to specify dataset path for both images and labels
# (or set the POSTDAM_IMAGES / POSTDAM_LABELS environment variables).
# WARNING: dataset is pre-loaded in memory, so high RAM usage is expected.

# Fail fast: check for a GPU before spending minutes loading the dataset.
assert torch.cuda.is_available(), "Notebook is not configured properly!"
device = 'cuda:0'

images_path = os.environ.get("POSTDAM_IMAGES", "C:\\Users\\eros\\CVCS\\dataset\\Cropped_Postdam\\Postdam\\Images")
labels_path = os.environ.get("POSTDAM_LABELS", "C:\\Users\\eros\\CVCS\\dataset\\Cropped_Postdam\\Postdam\\Labels")
base_dataset = PostDamDataset(images_path, labels_path)
augmented_dataset = PostDamDataset(images_path, labels_path, transforms=transforms)
# Order matters: the split cell relies on [base | augmented] index layout.
dataset = ConcatDataset([base_dataset, augmented_dataset])

print("Training network on {}".format(torch.cuda.get_device_name(device=device)))
net = Urnet().to(device)
num_params = sum(p.numel() for p in net.parameters())
print(f"Number of parameters : {num_params}")
print("Dataset length: {}".format(len(dataset)))

Loading dataset...: 100%|██████████| 2400/2400 [00:15<00:00, 151.55it/s]
Loading dataset...: 100%|██████████| 2400/2400 [01:28<00:00, 26.99it/s]


Training network on NVIDIA GeForce GTX 1060 6GB
Number of parameters : 34525446
Dataset length: 4800


# DataLoader initialization

In [8]:
# Dataset train/validation split according to validation split and seed.
# `dataset` is ConcatDataset([base, augmented]): index i + n_base is the augmented
# copy of base image i, so both copies of an image always land in the same split.
batch_size = 4
validation_split = .2
random_seed = 42

dataset_size = len(dataset)  # base + augmented (4800 in the recorded run)
n_base = len(base_dataset)
assert dataset_size == 2 * n_base, "base and augmented datasets must have the same length"
base_indices = list(range(n_base))  # 0 ... n_base-1
np.random.seed(random_seed)  # legacy global RNG kept so the split stays identical
np.random.shuffle(base_indices)
augmented_indices = [i + n_base for i in base_indices]  # take corresponding augmented images
split = int(np.floor((1 - validation_split) * n_base))

train_indices = base_indices[:split] + augmented_indices[:split]
val_base_indices = base_indices[split:]
val_noisy_indices = augmented_indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
valid_base_sampler = SubsetRandomSampler(val_base_indices)
valid_noisy_sampler = SubsetRandomSampler(val_noisy_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
# for the validation loaders the batch size is the default, i.e. 1
validation_base_loader = torch.utils.data.DataLoader(dataset, sampler=valid_base_sampler)
validation_noisy_loader = torch.utils.data.DataLoader(dataset, sampler=valid_noisy_sampler)

print(f"Train dataset split(augmented): {len(train_indices)}")
print(f"Validation dataset split: {len(val_base_indices)}")
print(f"Validation dataset split(only noise): {len(val_noisy_indices)}")
crit = nn.CrossEntropyLoss()
# NOTE(review): momentum=0.999 is far above the common ~0.9 for SGD; confirm it is intended.
opt = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.999)

Train dataset split(augmented): 3840
Validation dataset split: 480
Validation dataset split(only noise): 480


# Validation Function

In [9]:
from sklearn.metrics import jaccard_score as jsc
def _save_eval_figure(converter, image, pred_mask, target_mask, path):
    """Save a 1x3 figure (input image, predicted mask, ground-truth mask) to ``path``."""
    fig, axarr = plt.subplots(1, 3)
    axarr[0].title.set_text('Original Image')
    axarr[0].imshow(image.permute(1, 2, 0))  # C x H x W -> H x W x C
    axarr[1].title.set_text('Model Output')
    axarr[1].imshow(converter.iconvert(pred_mask))
    axarr[2].title.set_text('Original Mask')
    axarr[2].imshow(converter.iconvert(target_mask))
    fig.savefig(path)
    plt.close(fig)


def eval_model(net, validation_loader, validation_len, show_progress = False, write_output=False, prefix=""):
    """Evaluate ``net`` and return its mean per-image IoU as ``(macro, weighted)``.

    For each image, the predicted class map (argmax over the logit channels) is
    compared with the target class-index mask using sklearn's ``jaccard_score``.
    The scores are then averaged over all evaluated images.

    Args:
        net: segmentation model returning N x C x H x W logits. It is switched
            to eval mode and left in it.
        validation_loader: yields ``(image, mask, index)`` batches of any size.
        validation_len: expected number of images (progress-bar total).
        show_progress: show a tqdm progress bar.
        write_output: save an image / prediction / ground-truth figure for every
            image to ``output/{prefix}_Image{k}.png`` and print the final scores.
        prefix: filename prefix for the saved figures.

    Returns:
        ``(macro, weighted)``. ``macro`` is the unweighted mean IoU over classes;
        ``weighted`` weights each class by its support in the target.

    Raises:
        ValueError: if the loader yields no images.
    """
    c = Converter()
    device = next(net.parameters()).device  # evaluate wherever the model lives
    macro = 0.0
    weighted = 0.0
    n_images = 0
    pbar = tqdm(total=validation_len) if show_progress else None
    if write_output:
        os.makedirs("output", exist_ok=True)
    with torch.no_grad():
        net.eval()
        for images, targets, _ in validation_loader:
            logits = net(images.to(device)).cpu()
            pred_masks = logits.argmax(dim=1)  # N x H x W class indices
            for image, pred_mask, target_mask in zip(images.cpu(), pred_masks, targets.cpu()):
                prediction = pred_mask.numpy().reshape(-1)
                target = target_mask.numpy().reshape(-1)
                weighted += jsc(target, prediction, average='weighted')  # takes into account label imbalance
                macro += jsc(target, prediction, average='macro')  # simple mean over each class
                if write_output:
                    figure_path = os.path.join("output", "{}_Image{}.png".format(prefix, n_images))
                    _save_eval_figure(c, image, pred_mask, target_mask, figure_path)
                n_images += 1
            if pbar is not None:
                pbar.update(len(images))
    if pbar is not None:
        pbar.close()
    if n_images == 0:
        raise ValueError("validation_loader yielded no images")
    # Average over the images actually evaluated, not the caller-supplied length.
    macro_score = macro / n_images
    weighted_score = weighted / n_images
    if write_output:
        print("Macro IoU score: {}".format(macro_score))
        print("Weighted IoU score: {}".format(weighted_score))
    return macro_score, weighted_score
    
    

# Training Loop

In [None]:
epochs = 150
validate = True  # set to validate also during training
loss_values = []   # mean training loss per epoch
macro_IoU = []     # per epoch: mean macro IoU on the base validation split
weighted_IoU = []  # per epoch: mean support-weighted IoU on the base validation split

for epoch in range(epochs):
    net.train()  # eval_model leaves the net in eval mode
    cumulative_loss = 0.0
    n_batches = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for image, mask, _ in pbar:
        image, mask = image.to(device), mask.to(device)
        mask_pred = net(image)  # already on `device`, like the model
        loss = crit(mask_pred, mask)
        cumulative_loss += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        n_batches += 1
        pbar.set_postfix({'Loss': cumulative_loss / n_batches})
    pbar.close()
    loss_values.append(cumulative_loss / max(n_batches, 1))
    if validate:
        # A DataLoader can be iterated repeatedly and SubsetRandomSampler
        # reshuffles on every pass, so the loader from the split cell is reused.
        macro, weighted = eval_model(net, validation_base_loader, len(val_base_indices), show_progress=True, write_output=False)
        macro_IoU.append(macro)
        weighted_IoU.append(weighted)
print("Training Done!")
fig, ax = plt.subplots()
ax.plot(loss_values)
ax.set(title="Training loss", xlabel="Epoch", ylabel="Mean cross-entropy loss")
plt.show()

# General Utils (save/load/evaluate)

In [None]:
# Single model evaluation on the augmented ("noisy") validation split; figures are written to output/.
# Sampler and loader are rebuilt so this cell only needs the split indices from the dataloader cell.
valid_noisy_sampler = SubsetRandomSampler(val_noisy_indices)
validation_noisy_loader = torch.utils.data.DataLoader(dataset, sampler=valid_noisy_sampler)
macro, weighted = eval_model(net, validation_noisy_loader, len(val_noisy_indices), show_progress=True, write_output=True)

In [11]:
# Load entire model (inference).
# The file is a pickled nn.Module: unpickling can execute arbitrary code, so only load trusted files.
# weights_only=False is required for whole-module files since PyTorch 2.6 made weights_only=True the default.
# NOTE: `opt` was built for the previous `net` and still points at its parameters; re-create it before training.
net = torch.load("D:\\Models\\urnet3.3\\urnet3_3.pt", weights_only=False)

In [None]:
# Resume training: restore model and optimizer state from a checkpoint dict.
# NOTE(review): this reads from ...\urnet3\ while the checkpoint-saving cell writes
# to ...\urnet3.3\ — confirm which directory is intended.
checkpoint = torch.load("D:\\Models\\urnet3\\checkpoint")
net.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
# Bookkeeping only: the training loop rebinds `epoch` and `loss` and does not read these values.
epoch, loss = checkpoint['epoch'], checkpoint['loss']

In [15]:
# Save the ENTIRE model (a pickled nn.Module, not only the weights) to the path below.
# Reload it with torch.load(..., weights_only=False). Save locally or in a git-ignored directory.
model_path = "D:\\Models\\urnet3.3\\urnet3_3.pt"
os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
torch.save(net, model_path)

In [14]:
# Write the per-epoch weighted IoU history to a text file, one value per line.
with open("weighted_IoU.txt", "w") as f:
    f.writelines(str(value) + "\n" for value in weighted_IoU)

In [11]:
# Save a training checkpoint (model AND optimizer state) so training can be resumed.
# Epoch count and last loss are taken from the training history instead of hand-copied constants.
checkpoint_path = "D:\\Models\\urnet3.3\\checkpoint"
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
torch.save({
            'epoch': len(loss_values),
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': loss_values[-1] if loss_values else None,
            }, checkpoint_path)