In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as functional
from torch import optim

In [2]:
def conv3x3(in_channels: int, out_channels: int, padding=0):
	return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=padding)


def max_pool_2d():
	return nn.MaxPool2d(kernel_size=2, stride=2)

# Net most important blocks

In [3]:
class UnetEncodeLayer(nn.Module):
    # just a standard convolution layer.
	def __init__(self, in_channels: int, out_channels: int, activated=True,max_pool=False, padding=0):
		super(UnetEncodeLayer, self).__init__()
		layers = [
            conv3x3(in_channels, out_channels, padding=padding),
			nn.BatchNorm2d(out_channels),
            # nn.BatchNorm2d(out_channels),
        ]
		if activated:
			layers += [nn.ReLU()]
		if max_pool:
			layers += [max_pool_2d()]
		self.layer = nn.Sequential(*layers)
	
	def forward(self, x):
		return self.layer(x)
	
class UnetUpscaleLayer(nn.Module):
	def __init__(self, scale_factor, in_channels):
		super(UnetUpscaleLayer, self).__init__()
		layers = [
			nn.Upsample(scale_factor = (scale_factor,scale_factor), mode = 'bilinear'),
			conv3x3(in_channels, in_channels//2,padding=1)
		]
		self.layer = nn.Sequential(*layers)
	def forward(self, x):
		return self.layer(x)

class UnetForwardDecodeLayer(nn.Module):
	def __init__(self, in_channels, out_channels, padding=0):
		super(UnetForwardDecodeLayer, self).__init__()
		layers = [
			conv3x3(in_channels=in_channels, out_channels=out_channels, padding=padding),
			nn.ReLU(),
			nn.BatchNorm2d(out_channels),
			conv3x3(in_channels=out_channels, out_channels=out_channels, padding=padding),
			nn.ReLU(),
			nn.BatchNorm2d(out_channels),
		]
		self.layer = nn.Sequential(*layers)
	def forward(self, x):
		return self.layer(x)

# U-Net base structure

In [4]:
class Urnet(nn.Module):	
	def __init__(self):
		super(Urnet, self).__init__()
		self.residuals = []
    	# encoding part of the Unet vanilla architecture
		self.encode1 = nn.Sequential(
			UnetEncodeLayer(3, 64, padding=1),
			UnetEncodeLayer(64, 64, padding=1), ## keep dimensions unchanged
		)
		self.encode2 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(64, 128, padding=1),
			UnetEncodeLayer(128, 128, padding=1),
		)
		self.encode3 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
			UnetEncodeLayer(128, 256, padding=1),
			UnetEncodeLayer(256, 256, padding=1),
		)
		self.encode4 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(256, 512, padding=1),
			UnetEncodeLayer(512, 512, padding=1),
		)
		self.encode5 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(512, 1024, padding=1),
			UnetEncodeLayer(1024, 1024, padding=1),
		)
		self.upscale1 = nn.Sequential(
			UnetUpscaleLayer(2, 1024)
		)
		self.decode_forward1 = nn.Sequential(
			UnetForwardDecodeLayer(1024,512, padding=1)
		)
		self.upscale2 = nn.Sequential(
			UnetUpscaleLayer(2, 512)
		)
		self.decode_forward2 = nn.Sequential(
			UnetForwardDecodeLayer(512, 256, padding=1)
		)
		self.upscale3 = nn.Sequential(
			UnetUpscaleLayer(2,256)
		)
		self.decode_forward3 = nn.Sequential(
			UnetForwardDecodeLayer(256,128,padding=1)
		)
		self.upscale4 = nn.Sequential(
			UnetUpscaleLayer(2,128)
		)
		self.decode_forward4 = nn.Sequential(
			UnetForwardDecodeLayer(128,64, padding=1),
			nn.Conv2d(64, 6, kernel_size=1) # final conv 1x1
			# Model output is 6xHxW, so we have a prob. distribution
			# for each pixel (each pixel has a logit for each of the 6 classes.)
		)
	def forward(self, x: torch.Tensor):
		self.x1 = self.encode1(x)
		self.x2 = self.encode2(self.x1)
		self.x3 = self.encode3(self.x2)
		self.x4 = self.encode4(self.x3)
		self.x5 = self.encode5(self.x4)

		y1 = self.upscale1(self.x5)
		c1 = torch.concat((self.x4, y1), 1)
		y2 = self.decode_forward1(c1)
		
		y2 = self.upscale2(y2)
		c2 = torch.concat((self.x3, y2), 1)
		y3 = self.decode_forward2(c2)

		y3 = self.upscale3(y3)
		c3 = torch.concat((functional.center_crop(y3, 150), self.x2), 1)
		y4 = self.decode_forward3(c3)

		y4 = self.upscale4(y4)
		c4 = torch.concat((self.x1, y4), 1)
		segmap = self.decode_forward4(c4)
		return segmap

net = Urnet()
x = torch.rand((1,3,300,300))
output = net(x)

# Vision Transformer Base Structure

In [5]:

class PositionalEncoding(nn.Module):
	# I leave it here but it's not needed if we use convolutional tokenizer.
	def __init__(self, D, num_patches):
		super(PositionalEncoding, self).__init__()
		self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, D))
	def forward(self, x):
		return x + self.pos_embedding
	
class VisionTransformer(nn.Module):
	# D = embedding dimension (patch is p*p*3 and will be projected to be D dimensional)
	# N = number of patches
	# p = patch size	
	def __init__(self, D, num_heads):
		super(VisionTransformer, self).__init__()
		# self.linear_projection = nn.Linear(p*p*3, D) don't need it in this architecture
		# self.positional_encoding = PositionalEncoding(D, N) neither this.
		self.layer_norm1 = nn.LayerNorm(D)
		self.layer_norm2 = nn.LayerNorm(D)
		self.MHA = nn.MultiheadAttention(embed_dim=D, num_heads=num_heads, batch_first=True)
		self.mlp = nn.Sequential(
			# using D*4 hidden size according to original vision transformer paper
			nn.Linear(D, D*4),
            nn.GELU(),
            nn.Linear(D*4, D)
		)
		# we should have one of this for each head
	def forward(self, x):
		#x = self.linear_projection(x) # N, p*p*3 --> N, D
		#self.r1 = self.positional_encoding(x) # add positional encoding to x, embedded patches
		self.r1 = x		
		x = self.layer_norm1(self.r1)		
		x = self.MHA(x,x,x)[0]
		self.r2 = x + self.r1
		x = self.layer_norm2(x)
		x = self.mlp(x)
		return x + self.r2

def vision_transformer(D,num_heads):
    return VisionTransformer(D,num_heads)
    
class VisionTransformerEncoder(nn.Module):    
    def __init__(self, D, num_heads, layers):
        super(VisionTransformerEncoder, self).__init__()
        self.layers =[vision_transformer(D,num_heads) for _ in range(layers)]
        self.stack = nn.Sequential(*self.layers)
    def forward(self, x):
        return self.stack(x)

# Tunet(Transformer + Unet) base structure

In [6]:
class Tunet(nn.Module): # Unet + vision transformer
    def __init__(self, d, heads, layers):
        super(Tunet, self).__init__()
        self.d = d
        self.heads = heads
        self.layers = layers
        self.residuals = []        
        # encoding part of the Unet vanilla architecture
        self.encode1 = nn.Sequential(
            UnetEncodeLayer(3, d//16, padding=1),
        )
        self.encode2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            UnetEncodeLayer(d//16, d//8, padding=1),            
        )
        self.encode3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            UnetEncodeLayer(d//8, d//4, padding=1),            
        )
        self.encode4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            UnetEncodeLayer(d//4, d//2, padding=1),
        )
        self.encode5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),            
            UnetEncodeLayer(d//2, d, padding=1),
        )
        self.transformer = VisionTransformerEncoder(self.d,self.heads,self.layers)
        self.upscale1 = nn.Sequential(
            UnetUpscaleLayer(2, d)
        )
        self.decode_forward1 = nn.Sequential(
            UnetForwardDecodeLayer(d,d//2, padding=1)
        )
        self.upscale2 = nn.Sequential(
            UnetUpscaleLayer(2, d//2)
        )
        self.decode_forward2 = nn.Sequential(
            UnetForwardDecodeLayer(d//2, d//4, padding=1)
        )
        self.upscale3 = nn.Sequential(
            UnetUpscaleLayer(2,d//4)
        )
        self.decode_forward3 = nn.Sequential(
            UnetForwardDecodeLayer(d//4,d//8,padding=1)
        )
        self.upscale4 = nn.Sequential(
            UnetUpscaleLayer(2,d//8)
        )
        self.decode_forward4 = nn.Sequential(
            UnetForwardDecodeLayer(d//8,d//16, padding=1),
            nn.Conv2d(d//16, 6, kernel_size=1) # final conv 1x1
            # Model output is 6xHxW, so we have a prob. distribution
            # for each pixel (each pixel has a logit for each of the 6 classes.)
        )
    def forward(self,x):
        self.x1 = self.encode1(x)
        self.x2 = self.encode2(self.x1)
        self.x3 = self.encode3(self.x2)
        self.x4 = self.encode4(self.x3)       
        self.x5 = self.encode5(self.x4)        
        _, _,h,w = self.x5.shape
        self.sequence = self.x5.reshape(-1,h*w, self.d)
        attention_encoded = self.transformer(self.sequence)
        attention_encoded = attention_encoded.reshape(-1,self.d,h,w)
        y1 = self.upscale1(attention_encoded)
        c1 = torch.concat((self.x4, y1), 1)
        y2 = self.decode_forward1(c1)

        y2 = self.upscale2(y2)
        c2 = torch.concat((self.x3, y2), 1)
        y3 = self.decode_forward2(c2)

        y3 = self.upscale3(y3)
        c3 = torch.concat((functional.center_crop(y3, 150), self.x2), 1)
        y4 = self.decode_forward3(c3)

        y4 = self.upscale4(y4)
        c4 = torch.concat((self.x1, y4), 1)
        segmap = self.decode_forward4(c4)
        return segmap


# Converter
+ class containing some utils to convert segmentations masks.

In [7]:
class Converter:
	def __init__(self):
		self.color_to_label = {
            (1, 1, 0): 0,  # Yellow (cars)
            (0, 1, 0): 1, # Green (trees)
            (0, 0, 1): 2, # Blue (buildings)
            (1, 0, 0): 3,  # Red (clutter)
            (1, 1, 1): 4, # White(impervious surface),
            (0, 1, 1): 5 # Aqua (low vegetation)
        }
	def iconvert(self, mask):
		"""
		Function needed to convert the class label mask needed by CrossEntropy Function
		to the original mask.
		input: class label mask, HxW
		output: original mask, HxWx3
		"""
		H,W = mask.shape
		colors = torch.tensor(list(self.color_to_label.keys())).type(torch.float64)
		labels = torch.tensor(list(self.color_to_label.values())).type(torch.float64)
		output = torch.ones(H,W,3).type(torch.float64)
		for color, label in zip(colors, labels):
			match = (mask == label)
			output[match] = color
		return output
	def convert(self,mask):
		"""
		Function needed to convert the RGB (Nx3x300x300) mask into a 
		'class label mask' needed when computing the loss function.
		In this new representation for each pixel we have a value
		between [0,C) where C is the number of classes, so 6 in this case.
		This new tensor will have shape Nx300x300.
		"""			
		C,H,W = mask.shape
		colors = torch.tensor(list(self.color_to_label.keys()))
		labels = torch.tensor(list(self.color_to_label.values()))
		reshaped_mask = mask.permute(1, 2, 0).reshape(-1, 3)
		class_label_mask = torch.zeros(reshaped_mask.shape[0], dtype=torch.long)
		for color, label in zip(colors, labels):
			match = (reshaped_mask == color.type(torch.float64)).all(dim=1)
			class_label_mask[match] = label
		class_label_mask = class_label_mask.reshape(H,W)		
		return class_label_mask

# Dataset Class

In [8]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import torchvision.transforms as v2
from torch.utils.data import ConcatDataset

class PostDamDataset(Dataset):
	def __init__(self, img_dir, masks_dir, extension,transforms=None):
		self.idir = img_dir
		self.mdir = masks_dir		
		self.transforms = transforms		
		self.extension = extension
		self.items = os.listdir(self.idir)
		self.files = [item for item in self.items if os.path.isfile(os.path.join(self.idir, item))]
		self.c = Converter()
	def __len__(self):
		return len(self.files)
	def __getitem__(self, idx):
		img_path = os.path.join(self.idir, "Image_{}{}".format(idx, self.extension))
		mask_path = os.path.join(self.mdir, "Label_{}{}".format(idx, self.extension))
		tif_img = Image.open(img_path)
		tif_mask = Image.open(mask_path)
		if self.transforms:  # if transforms are provided, apply them
			final_image = self.transforms(ToTensor()(tif_img))
		# no transform is applied on mask obv.
		else:
			final_image = ToTensor()(tif_img)
		return (final_image, self.c.convert(ToTensor()(tif_mask)), idx)

transforms = v2.Compose([
    v2.GaussianBlur(kernel_size=(15), sigma=5),
    v2.ElasticTransform(alpha=200.0)    
])


# Dataset preparation
1) Must define path of images and labels;
2) Must define extension used for both images and labels;

In [9]:
#dataset = PostDamDataset("/content/drive/MyDrive/Postdam/Images", "/content/drive/MyDrive/Postdam/Labels")
# Here you need to specify dataset path for both images and labels.
# WARNING: dataset is pre-loaded in memory, so high RAM usage is expected.
images_path = "C:\\Users\\eros\\CVCS\\dataset\\Postdam_300x300_full\\Images"
labels_path = "C:\\Users\\eros\\CVCS\\dataset\\Postdam_300x300_full\\Labels"
extension = ".png"

base_dataset = PostDamDataset(images_path, labels_path, extension)
augmented_dataset = PostDamDataset(images_path, labels_path,extension, transforms=transforms)
dataset = ConcatDataset([base_dataset, augmented_dataset])

assert torch.cuda.is_available(), "Notebook is not configured properly!"
device = 'cuda:0'
print("Training network on {}".format(torch.cuda.get_device_name(device=device)))
net = Tunet(768, 12, 6).to(device)
num_params = sum([np.prod(p.shape) for p in net.parameters()])
print(f"Number of parameters : {num_params}")
print("Dataset length: {}".format(dataset.__len__()))

Training network on NVIDIA GeForce GTX 1060 6GB
Number of parameters : 54876246
Dataset length: 30400


# Dataloaders initializations

In [10]:
#Dataset train/validation split according to validation split and seed.
batch_size = 4
validation_split = .2
random_seed= 42

dataset_size = len(dataset)
base_indices = list(range(dataset_size//2))
np.random.seed(random_seed)
np.random.shuffle(base_indices)
augmented_indices = [i+(len(dataset)//2) for i in base_indices] # take coresponding augmented images
split = int(np.floor((1-validation_split) * (dataset_size//2)))

train_indices = base_indices[:split]+augmented_indices[:split]
val_base_indices = base_indices[split:]
val_noisy_indices = augmented_indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
valid_base_sampler = SubsetRandomSampler(val_base_indices)
valid_noisy_sampler = SubsetRandomSampler(val_noisy_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
#for validation loader batch size is default, so 1.
validation_base_loader = torch.utils.data.DataLoader(dataset ,sampler=valid_base_sampler)
validation_noisy_loader = torch.utils.data.DataLoader(dataset ,sampler=valid_noisy_sampler)

print(f"Train dataset split(augmented): {len(train_indices)}")
print(f"Validation dataset split: {len(val_base_indices)}")
print(f"Validation dataset split(only noise): {len(val_noisy_indices)}")
crit = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.999)

Train dataset split(augmented): 24320
Validation dataset split: 3040
Validation dataset split(only noise): 3040


# Validation Function

In [11]:
from sklearn.metrics import jaccard_score as jsc
def eval_model(net, validation_loader, validation_len, show_progress = False, write_output=False, prefix=""):
    # returns (macro, weighted) IoU
    c = Converter()
    macro = 0
    weighted = 0
    if show_progress:
        pbar = tqdm(total=validation_len)
    with torch.no_grad():
        net.eval()
        for i, (x,y, index) in enumerate(validation_loader):
            x, y = x.to(device), y.to(device)
            y_pred = net(x)
            x_ref = x.cpu()
            y_pred = y_pred.squeeze().cpu()
            _,pred_mask = torch.max(y_pred, dim=0)

            prediction = pred_mask.cpu().numpy().reshape(-1)
            target = y.cpu().numpy().reshape(-1)        
            weighted += jsc(target,prediction, average='weighted') # takes into account label imbalance
            macro += jsc(target,prediction, average='macro') # simple mean over each class.            
            if(write_output):
                fig ,axarr = plt.subplots(1,3)
                _,target_transformed_mask,_ = dataset.__getitem__(index.item())

                axarr[0].title.set_text('Original Image')
                axarr[0].imshow(x_ref.squeeze().swapaxes(0,2).swapaxes(0,1))

                axarr[1].title.set_text('Model Output')
                axarr[1].imshow(c.iconvert(pred_mask))

                axarr[2].title.set_text('Original Mask')
                axarr[2].imshow(c.iconvert(target_transformed_mask))
                plt.savefig("output\\{}_Image{}.png".format(prefix, i))
                plt.close(fig)
            if show_progress:
                pbar.update(1)
    macro_score = macro / validation_len
    weighted_score = weighted / validation_len
    if show_progress:
        pbar.close()
    if(write_output):
        print("Macro IoU score: {}".format(macro_score))        
        print("Weigthed IoU score: {}".format(weighted_score))
    return macro_score, weighted_score
    
    

# Training Loop
+ Before launching, make sure to define the local directory where checkpoints will be saved.

In [12]:
from pathlib import Path

epochs = 40
validate = True # set to validate also during training
loss_values = []
macro_IoU = []
weighted_IoU = []
checkpoint_directory = "D:\\Models\\Tunet1\\checkpoints"
if not Path(checkpoint_directory).is_dir():
    print("Please provide a valid directory to save checkpoints in.")
else:
    for epoch in range(epochs):
        cumulative_loss = 0
        tot = 0
        pbar = tqdm(total=len(train_loader), desc=f'Epoch {epoch}')
        net.train()
        for batch_index, (image, mask, _) in enumerate(train_loader):
            tot+=1
            image, mask = image.to(device), mask.to(device)        
            mask_pred = net(image).to(device)
            loss = crit(mask_pred, mask)
            cumulative_loss += loss.item()
            opt.zero_grad()
            loss.backward()
            opt.step()
            pbar.update(1)
            pbar.set_postfix({'Loss': cumulative_loss/tot})
        pbar.close()
        loss_values.append(cumulative_loss/tot)
        if validate:
            # run evaluation!
            # 1) Re-initialize data loaders
            valid_base_sampler = SubsetRandomSampler(val_base_indices)
            validation_base_loader = torch.utils.data.DataLoader(dataset ,sampler=valid_base_sampler)
            # 2) Call evaluation Loop
            macro, weighted = eval_model(net, validation_base_loader, len(val_base_indices),show_progress=False, write_output=False)
            # 3) Append results to list    
            macro_IoU.append(macro)    
            weighted_IoU.append(weighted)
        if (epoch+1) % 2 == 0: # save checkpoint every 2 epochs
            torch.save({
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': loss.item(),
                }, "{}\\checkpoint{}".format(checkpoint_directory, (epoch+1)))
    print("Training Done!")
    plt.plot(loss_values)
    plt.show()
    #TODO:  better to save stuff in a csv file
    with open("weighted_IoU.txt", "w") as f:
        for value in weighted_IoU:
            f.write(str(value)+"\n")
    with open("loss.txt", "w") as f:
        for value in loss:
            f.write(str(value)+"\n")
    with open("macro_IoU.txt", "w") as f:
        for value in macro_IoU:
            f.write(str(value)+"\n")

Epoch 0: 100%|██████████| 6080/6080 [1:05:57<00:00,  1.54it/s, Loss=1.13]
Epoch 1: 100%|██████████| 6080/6080 [1:05:44<00:00,  1.54it/s, Loss=1.02]
Epoch 2: 100%|██████████| 6080/6080 [1:05:28<00:00,  1.55it/s, Loss=0.947]
Epoch 3: 100%|██████████| 6080/6080 [1:05:50<00:00,  1.54it/s, Loss=0.89] 
Epoch 4: 100%|██████████| 6080/6080 [1:05:07<00:00,  1.56it/s, Loss=0.841]
Epoch 5: 100%|██████████| 6080/6080 [1:05:07<00:00,  1.56it/s, Loss=0.803]
Epoch 6: 100%|██████████| 6080/6080 [1:05:06<00:00,  1.56it/s, Loss=0.759]
Epoch 7: 100%|██████████| 6080/6080 [1:05:06<00:00,  1.56it/s, Loss=0.704]
Epoch 8: 100%|██████████| 6080/6080 [1:05:07<00:00,  1.56it/s, Loss=0.67] 
Epoch 9: 100%|██████████| 6080/6080 [1:05:07<00:00,  1.56it/s, Loss=0.62] 
Epoch 10: 100%|██████████| 6080/6080 [1:05:51<00:00,  1.54it/s, Loss=0.582]
Epoch 11: 100%|██████████| 6080/6080 [1:05:21<00:00,  1.55it/s, Loss=0.542]
Epoch 12: 100%|██████████| 6080/6080 [1:05:33<00:00,  1.55it/s, Loss=0.483]
Epoch 13: 100%|█████████

KeyboardInterrupt: 

# General Utils (save/load/evaluate)

In [None]:
# Single model evaluation, write output to file.
# call validation function to evaluate model. Remember to re-initialize loader and sampler.
valid_noisy_sampler = SubsetRandomSampler(val_noisy_indices)
validation_noisy_loader = torch.utils.data.DataLoader(dataset ,sampler=valid_noisy_sampler)
macro,weighted = eval_model(net, validation_noisy_loader, len(val_base_indices),show_progress=True, write_output=True)

In [11]:
# Load entire model (inference)
net = torch.load("D:\\Models\\urnet3.3\\urnet3_3.pt")

In [None]:
# Load model checkpoint (to resume training)
checkpoint = torch.load("D:\\Models\\urnet3\\checkpoint")
net.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [15]:
# uncomment and run to save model in the speicifed path
# SAVE model (only weights). Save locally or on a gitignore directory
torch.save(net, "D:\\Models\\urnet3.3\\urnet3_3.pt")

In [14]:
# write loss values on file
with open("weighted_IoU.txt", "w") as f:
    for value in weighted_IoU:
        f.write(str(value)+"\n")

In [11]:
# save training checkpoint (optimizer state as well.)
torch.save({
            'epoch': 150,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': 0.0435,
            }, "D:\\Models\\urnet3.3\\checkpoint")