In [None]:
# Build a VGG16-based FCN-8s segmentation model; all imports are gathered below.
import torch.nn as nn
import torch.nn.functional as F
from abc import ABCMeta
import torchvision.models as models
import torch
from tqdm import tqdm
import os
import pandas as pd
from torchvision.io import read_image
import torch
import torchvision.transforms.v2 as transforms
!pip install torchsummary
from torchsummary import summary
import random
import pickle
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
#load in the data
import os
import pandas as pd
import numpy as np
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode

In [None]:
# Load an ImageNet-pretrained VGG16 to use as the encoder of the FCN models.
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` used to select.
vgg = models.vgg16(weights = models.VGG16_Weights.IMAGENET1K_V1)
# "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],

# Prefer the GPU when one is visible, otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vgg.to(device)
import os
import pandas as pd
from torchvision.io import read_image
import torch
# Presumably the native CamVid frame height/width; not referenced by the
# visible code (the pipelines resize to 320x480) -- TODO confirm before removing.
h = 720
w = 960

In [None]:
def one_hot_from_label(class_dict ,segmented_img):
    """Convert a colour-coded label image into a stack of per-class boolean masks.

    Args:
        class_dict: mapping whose values are 1-D colour tensors (one entry per
            channel of ``segmented_img``); iteration order defines channel order.
        segmented_img: label image with the colour channels on dim 0
            (expected layout ``(3, H, W)`` -- assumed from the broadcasting used).

    Returns:
        Bool tensor of shape ``(len(class_dict), H, W)``; entry ``[k, y, x]`` is
        True when every channel of pixel ``(y, x)`` equals the k-th colour.
    """
    def _mask_for(rgb):
        # (C,) -> (C, 1, 1) so the colour broadcasts across the spatial dims.
        per_channel_match = torch.eq(rgb[:, None, None] , segmented_img)
        # A pixel belongs to the class only if all of its channels match.
        return torch.all(per_channel_match , dim = 0)

    return torch.stack([_mask_for(rgb) for rgb in class_dict.values()])

In [None]:
def rev_one_hot(class_dict , one_hot):
    """Paint a one-hot class map back into a colour image.

    Each class mask is multiplied by that class's colour and the results are
    summed, so a strictly one-hot input yields exactly one colour per pixel.

    Args:
        class_dict: mapping whose values are 1-D colour tensors; the k-th value
            colours channel ``k`` of ``one_hot``.
        one_hot: tensor of shape ``(channels, H, W)`` holding 1s and 0s.

    Returns:
        Float tensor of shape ``(3, H, W)`` located on the module-level ``device``.
    """
    _ , height , width = one_hot.size()
    rgb_image = torch.zeros(3 , height , width).to(device)

    for channel , color in enumerate(class_dict.values()):
        # (3,) -> (3, 1, 1) so the colour broadcasts over the (H, W) mask.
        rgb = color.unsqueeze(dim = 1).unsqueeze(dim = 2).to(device)
        rgb_image += torch.mul(one_hot[channel].to(device) , rgb)

    return rgb_image
                               
                        

In [None]:
def prob_to_one_hot(t):
    """Turn per-class scores into a hard one-hot map via a per-pixel argmax.

    The result is created on the same device as ``t`` rather than on a
    module-level ``device``, so the function works for CPU and GPU inputs
    alike and never mixes index tensors from different devices.

    Args:
        t: tensor of shape ``(channels, H, W)`` with one score per class/pixel.

    Returns:
        Float tensor of shape ``(channels, H, W)`` on ``t.device`` containing a
        single 1 per pixel (at the arg-max channel) and 0 elsewhere.

    Raises:
        ValueError: if ``t`` is not 3-dimensional (from the size unpacking).
    """
    channels , height , width = t.size()
    # winners has shape (H, W): index of the highest-scoring channel per pixel.
    winners = torch.argmax(t , dim = 0)
    one_hot = torch.zeros((channels , height , width) , device = t.device)
    # scatter_ writes 1 at one_hot[winners[y, x], y, x] for every pixel.
    one_hot.scatter_(0 , winners.unsqueeze(0) , 1.0)

    return one_hot

In [None]:
def intersection_over_union(one_hot_truth , one_hot_predicted):
    """Compute mean IoU between two one-hot class maps.

    Args:
        one_hot_truth: ground-truth masks, indexed by class on dim 0.
        one_hot_predicted: predicted masks with the same layout.

    Returns:
        Tuple ``(mean_over_all_classes, mean_over_present_classes)``. A class
        counts as "present" when its union is non-zero, i.e. it appears in the
        truth or in the prediction. Either mean is ``0.0`` when its
        denominator would be zero (no classes at all / no class present)
        instead of raising ``ZeroDivisionError``.
    """
    n_classes = len(one_hot_truth)
    n_present = n_classes
    iou_total = 0.0

    for i in range(n_classes):
        overlap = torch.sum(one_hot_truth[i] * one_hot_predicted[i]).item()
        union = torch.sum(one_hot_truth[i]).item() + torch.sum(one_hot_predicted[i]).item() - overlap

        if union == 0:
            # Class absent from both maps: it must not drag the present-class mean down.
            n_present -= 1
            continue

        iou_total += overlap / union

    mean_all = iou_total / n_classes if n_classes else 0.0
    mean_present = iou_total / n_present if n_present else 0.0
    return (mean_all , mean_present)
        

In [None]:
def pixel_accuracy(label, pred):
    """Fraction of pixels whose values agree on every channel.

    Args:
        label: reference image, channels on dim 0.
        pred: predicted image of shape ``(channels, d1, d2)``.

    Returns:
        Number of positions where ``label`` and ``pred`` are equal across all
        channels, divided by ``d1 * d2``.
    """
    _ , dim_a , dim_b = pred.size()
    # A pixel is correct only if all of its channels match.
    matching_pixels = torch.all(torch.eq(label , pred) , dim = 0)
    return torch.sum(matching_pixels).item() / (dim_a * dim_b)

    

In [None]:
def mean_accuracy(one_hot_truth , one_hot_label):
    """Average per-class recall between two one-hot class maps.

    For each class the score is (pixels set in both maps) / (pixels set in the
    truth). A class with no truth pixels contributes a score of 1.

    Args:
        one_hot_truth: ground-truth masks of shape ``(channels, d1, d2)``.
        one_hot_label: predicted masks with the same layout.

    Returns:
        Sum of the per-class scores divided by ``channels``.
    """
    channels , _ , _ = one_hot_truth.size()
    score_total = 0

    for idx in range(len(one_hot_truth)):
        truth_mask = one_hot_truth[idx]
        hits = torch.sum(torch.eq(truth_mask , 1) & torch.eq(one_hot_label[idx] , 1)).item()
        truth_pixels = torch.sum(truth_mask).item()

        # Classes absent from the truth are scored as perfectly classified.
        score_total += 1 if truth_pixels == 0 else hits / truth_pixels

    return score_total / channels
        

In [None]:
def dict_from_csv(file_path):
    """Read a class-colour table from CSV into a ``{name: colour tensor}`` dict.

    The first column is used as the key and columns 1-3 become a 3-element
    tensor (presumably R, G, B -- the code only relies on column position).

    Args:
        file_path: path (or anything ``pandas.read_csv`` accepts) of the table.

    Returns:
        Dict in row order; a repeated key keeps the value of its last row.
    """
    table = pd.read_csv(file_path)
    return {
        table.iloc[row , 0]: torch.tensor((table.iloc[row , 1] , table.iloc[row , 2] , table.iloc[row , 3]))
        for row in range(len(table))
    }

    

In [None]:

class vgg_fcn_32(nn.Module):
  """FCN-32s-style segmentation network built on a VGG feature extractor.

  The encoder output is passed through three 1x1 convolutions and then
  upsampled by five stride-2 transposed convolutions (32x in total) before a
  1x1 classifier and a per-pixel softmax over the class channels.

  NOTE: attribute names and the order of layers inside each ``nn.Sequential``
  determine the ``state_dict`` keys; changing them invalidates saved weights.
  """
  def __init__(self, vggnet, n_classes ):
    """Assemble the network.

    Args:
      vggnet: model exposing a ``features`` module whose output has 512
        channels (assumed to be a torchvision VGG -- confirm with caller).
      n_classes: number of output channels / segmentation classes.
    """
    super().__init__()
    # Encoder: the convolutional trunk of the supplied network (shared, not copied).
    self.block1 = vggnet.features
    
    # 1x1 convolutions widening 512 -> 1024 -> 2048 -> 2048 channels.
    # Only Dropout sits between them; there is no activation in this block.
    self.block3 = nn.Sequential(
        nn.Conv2d(512 , 1024 , 1 , 1 ),
        nn.Dropout(),
        nn.Conv2d(1024 , 2048 , 1 , 1 ),
        nn.Dropout(),
        nn.Conv2d(2048 , 2048, 1 , 1 ),
    )
    # In VGG the final feature map is 1/32 of the input in each spatial
    # dimension; the reduction comes primarily from the max-pool layers.
    # Decoder: each transposed conv (kernel 3, stride 2, padding 1,
    # output_padding 1) exactly doubles height and width, so five of them
    # undo the 1/32 reduction.
    self.block4 = nn.Sequential(
        nn.ConvTranspose2d(2048 , 2048 , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        # BatchNorm intentionally disabled after this stage:
        #nn.BatchNorm2d(2048),
        nn.ConvTranspose2d(2048 , 1024 , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(1024),
        nn.ConvTranspose2d(1024 , 1024, 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        # BatchNorm intentionally disabled after this stage:
        #nn.BatchNorm2d(1024),
        nn.ConvTranspose2d(1024, 512 , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        nn.ConvTranspose2d(512 , n_classes , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU()
    )
    # Per-pixel 1x1 classifier followed by a softmax across the class channels.
    self.classifier = nn.Conv2d(n_classes , n_classes , 1 )
    self.last = nn.Softmax2d()
  def forward(self , x):
    """Run encoder -> 1x1 convs -> 32x upsampling -> classifier -> softmax.

    Args:
      x: input batch; assumed ``(N, 3, H, W)`` as expected by a VGG trunk.

    Returns:
      Tensor with ``n_classes`` channels whose values sum to 1 over the
      channel dimension at every pixel.
    """
    x = self.block1(x)
    
    x = self.block3(x)
    
    x = self.block4(x)
    
    x = self.classifier(x)
    x = self.last(x)
    
    return x





    


In [None]:
class vgg_fcn_8(nn.Module):
    """FCN-8s-style segmentation network with skip connections from VGG.

    Three feature maps are fused by addition at the resolution of the first
    slice's output: that output itself, the second slice's output upsampled
    2x, and the processed third slice's output upsampled 4x. The fused map is
    then upsampled 8x and turned into per-pixel class probabilities.

    NOTE: attribute names and the order of layers inside each
    ``nn.Sequential`` determine the ``state_dict`` keys; changing them
    invalidates saved checkpoints.
    """
    def __init__(self, vggnet , n_classes):
        """Assemble the network.

        Args:
            vggnet: model exposing an indexable ``features`` module; the slice
                boundaries below assume the torchvision VGG16 layer layout
                -- TODO confirm if another backbone is passed.
            n_classes: number of output channels / segmentation classes.
        """
        super().__init__()
        # Fusion rule: pool3 output + 2x-upsampled pool4 output
        #              + 4x-upsampled 1x1 convolutions of the pool5 output.
        self.pool3 = vggnet.features[:17] # expected to end at pool3: 256 channels
        self.pool4 = vggnet.features[17:24] # expected to end at pool4: 512 channels
        self.pool5 = vggnet.features[24:] # remaining layers up to pool5: 512 channels
        # 1x1 convolutions widening 512 -> 1024 -> 2048 -> 2048 channels, with
        # Dropout (and no activation) in between.
        self.convolutions1x1 = nn.Sequential(
        nn.Conv2d(512 , 1024 , 1 , 1 ),
        nn.Dropout(),
        nn.Conv2d(1024 , 2048 , 1 , 1 ),
        nn.Dropout(),
        nn.Conv2d(2048 , 2048 , 1 , 1 ),
        )
        # Deepest branch: two stride-2 transposed convs (each doubles H and W)
        # followed by a stride-1 transposed conv that reduces to 256 channels.
        self.upsample_conv_7 = nn.Sequential(
        # first 2x upsampling step
        nn.ConvTranspose2d(2048 , 1024 ,3 , 2 , padding = 1 , output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(1024),
        # second 2x upsampling step
        nn.ConvTranspose2d(1024 , 512, 3, 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        # channel reduction at unchanged resolution (stride 1)
        nn.ConvTranspose2d(512 , 256 , 3 , 1 , padding = 1)
        
        )
        
        # Middle branch: one 2x upsampling step, then reduce to 256 channels.
        self.upsample_pool_4 = nn.Sequential(
        # 2x upsampling step
        nn.ConvTranspose2d(512 , 512 , 3 , 2, padding = 1 , output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        # channel reduction at unchanged resolution (stride 1)
        nn.ConvTranspose2d(512 , 256 , 3 , 1 , padding = 1)
        )
        # Head: three stride-2 transposed convs (8x upsampling in total), a 1x1
        # classifier and a softmax over the class channels.
        self.block1 = nn.Sequential(
        nn.ConvTranspose2d(256 , 1024 , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(1024),
        nn.ConvTranspose2d(1024 , 512 , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        nn.ConvTranspose2d(512 , n_classes , 3 , 2 , padding = 1, output_padding = 1),
        nn.ReLU(),            
        nn.Conv2d(n_classes , n_classes , 1),
        nn.Softmax2d()
        )
    def forward(self , x):
        """Fuse the three encoder stages and decode to class probabilities.

        Args:
            x: input batch; assumed ``(N, 3, H, W)``.
                NOTE(review): the additions need equally sized maps, which
                presumably requires H and W to be multiples of 32 -- confirm.

        Returns:
            Tensor with ``n_classes`` channels whose values sum to 1 over the
            channel dimension at every pixel.
        """
        pool3_output = self.pool3(x)
        pool4_output = self.pool4(pool3_output)
        pool5_output = self.pool5(pool4_output)
        output_conv_7 = self.convolutions1x1(pool5_output)
        # Bring both deeper branches to the pool3 resolution with 256 channels.
        upsampled_conv_7 = self.upsample_conv_7(output_conv_7)
        upsampled_pool_4 = self.upsample_pool_4(pool4_output)
        # Skip-connection fusion by element-wise addition.
        x = torch.add(pool3_output , torch.add( upsampled_conv_7 ,upsampled_pool_4))
        output = self.block1(x)
        return output
        
        
        
        

In [None]:
# Preprocessing pipelines. Inputs are scaled while being cast to float32 and
# resized with the default interpolation; targets keep their raw values
# (no scaling) and use nearest-exact resizing so mask/colour values are never
# blended into new ones.
_resize_hw = (320 , 480)

trans = transforms.Compose([
    transforms.ToDtype(torch.float32 , scale = True),
    transforms.Resize(_resize_hw),
])
trans_target = transforms.Compose([
    transforms.ToDtype(torch.float32),
    transforms.Resize(_resize_hw , interpolation = InterpolationMode.NEAREST_EXACT),
])

In [None]:


#custom dataset class
class LoadCamvid(Dataset):
    """Dataset yielding ``(image, one-hot mask, colour label)`` triples.

    Images and labels are paired by position after sorting the two directory
    listings (presumably image and label filenames sort identically -- verify
    for new data).

    Args:
        csv_file: CSV whose first column is the class name and whose columns
            1-3 are the class colour.
        label_dir: directory containing the colour-coded label images.
        img_dir: directory containing the input images.
        transform: optional callable applied to the image.
        target_transform: optional callable applied to both the one-hot mask
            and the colour label.
    """
    def __init__(self , csv_file , label_dir , img_dir , transform = None , target_transform = None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.csv_file = csv_file
        self.transform = transform
        self.target_transform = target_transform
        # Parsed once and reused for the colour table below.
        self.class_data = pd.read_csv(csv_file)
        self.img_files = sorted(os.listdir(self.img_dir))
        self.label_files = sorted(os.listdir(self.label_dir))

        table = self.class_data
        # {class name: 3-element colour tensor}, in CSV row order.
        self.dict = {
            table.iloc[i , 0]: torch.tensor((table.iloc[i , 1] , table.iloc[i , 2] , table.iloc[i , 3]))
            for i in range(len(table))
        }

    def __len__(self):
        """Number of samples, taken from the image directory listing."""
        return len(self.img_files)

    def __getitem__(self , idx):
        """Load sample ``idx`` and return ``(img, onehot, label)``.

        Tensors are left on the CPU; moving them to the accelerator is the
        training loop's job. (``Tensor.to`` is not in-place, so the discarded
        ``.to(device)`` calls that used to be here only wasted a copy.)
        """
        label = read_image(os.path.join(self.label_dir , self.label_files[idx]))
        onehot = one_hot_from_label(self.dict , label)
        img = read_image(os.path.join(self.img_dir , self.img_files[idx]))

        # Both transforms default to None, so only apply the ones supplied.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            onehot = self.target_transform(onehot)
            label = self.target_transform(label)

        return (img , onehot , label)
        
        
        

In [None]:
# Dataset locations (Kaggle CamVid layout); change CAMVID_ROOT to relocate.
CAMVID_ROOT = "/kaggle/input/camvid/CamVid"
CLASS_DICT_CSV = os.path.join(CAMVID_ROOT , "class_dict.csv")

training_data = LoadCamvid(CLASS_DICT_CSV , os.path.join(CAMVID_ROOT , "train_labels") , os.path.join(CAMVID_ROOT , "train") , transform = trans , target_transform = trans_target)
validation_data = LoadCamvid(CLASS_DICT_CSV , os.path.join(CAMVID_ROOT , "val_labels") , os.path.join(CAMVID_ROOT , "val") , transform = trans , target_transform = trans_target)

# Four-sample subset, handy for overfitting/debug runs.
partial_set = torch.utils.data.Subset(training_data , range(4))

train_dataloader = DataLoader(training_data , batch_size = 16 , shuffle = True)
val_dataloader = DataLoader(validation_data , batch_size = 16 , shuffle = True)

df = pd.read_csv(CLASS_DICT_CSV)

# {class name: colour tensor}; built with the shared helper instead of
# repeating its row loop here.
mappings = dict_from_csv(CLASS_DICT_CSV)

In [None]:
# Build the FCN-8s model (32 output classes), print a layer summary for a
# 3x320x480 input, then wrap it for multi-GPU execution.
fcn = vgg_fcn_8(vgg , 32)
fcn.to(device)
summary(fcn , input_size = (3 , 320 , 480))
fcn = nn.DataParallel(fcn)

# One (initially empty) history list per weight tensor; the training loop
# appends gradient norms to these.
layer_wise_gradient = {
    name: []
    for name , _ in fcn.named_parameters()
    if 'weight' in name
}

In [None]:
# Training loop.
# Every epoch first evaluates the current weights on the validation set and
# then runs one pass over the training set. Losses are averaged over the
# number of batches actually processed, and validation metrics over the number
# of samples actually scored, rather than over hard-coded dataset sizes.
epochs = 75
clip_value = 1  # maximum global gradient norm passed to clip_grad_norm_

train_loss = []
epoch_count = []
val_loss = []
val_pixelwise_accuracy = []
val_mean_accuracy = []
val_iou = []

# NOTE(review): the model already ends in Softmax2d, so BCELoss is applied to
# softmax probabilities against one-hot targets; kept as-is to preserve the
# training objective -- confirm this is intended.
loss_fnc = nn.BCELoss()
optim = torch.optim.Adam(fcn.parameters() , lr = 0.0001)
scheduler = ExponentialLR(optim, gamma=0.985)

for epoch in tqdm(range(epochs)):
    # Append all per-epoch slots up front so the history lists always have
    # equal lengths, even if the run is interrupted mid-epoch.
    train_loss.append(0)
    epoch_count.append(epoch)
    val_loss.append(0)
    val_pixelwise_accuracy.append(0)
    val_mean_accuracy.append(0)
    val_iou.append(0)

    # ---------------- validation ----------------
    fcn.eval()
    val_batches = 0
    val_samples = 0
    with torch.no_grad():
        for val_batch_idx , (img , one_hot , label) in enumerate(val_dataloader):
            img = img.to(device)
            one_hot = one_hot.to(device)
            label = label.to(device)

            output = fcn(img)

            val_loss[epoch] += loss_fnc(output , one_hot).item()
            val_batches += 1

            # Score every sample of the batch exactly once.
            for j in range(len(output)):
                one_hot_output = prob_to_one_hot(output[j])
                output_label = rev_one_hot(mappings , one_hot_output)

                val_iou[epoch] += intersection_over_union(one_hot[j] , one_hot_output)[1]
                val_pixelwise_accuracy[epoch] += pixel_accuracy(label[j] , output_label)
                val_mean_accuracy[epoch] += mean_accuracy(one_hot[j] , one_hot_output)
                val_samples += 1

    if val_samples:
        val_iou[epoch] /= val_samples
        val_pixelwise_accuracy[epoch] /= val_samples
        val_mean_accuracy[epoch] /= val_samples
    if val_batches:
        # The criterion returns a per-batch mean, so average over batches.
        val_loss[epoch] /= val_batches

    # ---------------- training ----------------
    fcn.train()
    train_batches = 0
    for batch_idx , (img , one_hot , label) in enumerate(train_dataloader):
        torch.cuda.empty_cache()
        img = img.to(device)
        one_hot = one_hot.to(device)
        output = fcn(img)
        loss_val = loss_fnc(output , one_hot)

        train_loss[epoch] += loss_val.item()
        train_batches += 1
        optim.zero_grad()
        loss_val.backward()
        if epoch%5 == 0:
            # Record pre-clipping gradient norms for every weight tensor.
            for name , param in fcn.named_parameters():
                if 'weight' in name and param.grad is not None:
                    layer_wise_gradient[name].append(param.grad.norm().item())

        torch.nn.utils.clip_grad_norm_(fcn.parameters(), clip_value)
        optim.step()

    scheduler.step()
    if train_batches:
        train_loss[epoch] /= train_batches
    if epoch%5 == 0:
        torch.save(fcn.state_dict(), "fcn8_" + str(epoch) +".pth")
    print(f"Learning Rate: {optim.param_groups[0]['lr']}")
    print(f"epoch - : {epoch_count[epoch]} ,training loss - : {train_loss[epoch]} , validation los - : {val_loss[epoch]}" )
    
    
        
        

In [None]:
# Release cached GPU memory that PyTorch's allocator is holding but not using.
torch.cuda.empty_cache()

In [None]:
# Plot the recorded gradient norms of every weight tensor, one line per layer.
# Values were appended once per training batch during every 5th epoch.
plt.figure(figsize=(12, 12))
for key, values in layer_wise_gradient.items():
    plt.plot(values, label=key)

# Adding labels and title
plt.xlabel('Logged training step')
plt.ylabel('Gradient norm')
plt.title('Per-layer weight gradient norm')
plt.legend()

In [None]:
# Training and validation loss per epoch, drawn on the same axes.
for series , series_name in ((train_loss , 'train_loss') , (val_loss , 'val_loss')):
    plt.plot(epoch_count , series , label = series_name)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# Validation IoU per epoch.
plt.plot(epoch_count , val_iou , label = 'val_iou')

plt.xlabel('Epochs')
plt.ylabel('IOU')
plt.title('IOU over Epochs')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
#plotting mean_accuracy over the validation dataset
# The legend entry names the series actually plotted (it used to say 'val_iou').
plt.plot(epoch_count, val_mean_accuracy , label = 'val_mean_accuracy')
plt.title('mean_accuracy over Epochs')
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('mean acc')

plt.grid(True)
plt.show()

In [None]:
# Validation pixel-wise accuracy per epoch.
plt.plot(epoch_count , val_pixelwise_accuracy , label = 'val_pixelwise_accuracy')

plt.xlabel('Epochs')
plt.ylabel('pixelwise acc')
plt.title('pixelwise_accuracy over Epochs')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# Persist the training history. The objects are written as consecutive
# pickles in this exact order, so a reader must call pickle.load six times
# in the same order.
records = (
    layer_wise_gradient ,
    train_loss ,
    val_loss ,
    val_iou ,
    val_mean_accuracy ,
    val_pixelwise_accuracy ,
)
with open('record.pkl' , 'wb') as f:
    for record in records:
        pickle.dump(record , f)

In [None]:
#manually checking some inputs from the test data
# Bound to `test_data` only, so `training_data` keeps referring to the
# training split.
test_data = LoadCamvid("/kaggle/input/camvid/CamVid/class_dict.csv" ,"/kaggle/input/camvid/CamVid/test_labels" ,  "/kaggle/input/camvid/CamVid/test" , transform = trans , target_transform = trans_target)

img , onehot , label = test_data[150]
# img , onehot , label
label = label.to(torch.int32)
# Input image, then ground-truth colour mask, then predicted colour mask.
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics (the training loop leaves the model in train mode); no_grad
# avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 100: input, ground truth, prediction.
img , onehot , label = test_data[100]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 200: input, ground truth, prediction.
img , onehot , label = test_data[200]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 25: input, ground truth, prediction.
img , onehot , label = test_data[25]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 125: input, ground truth, prediction.
img , onehot , label = test_data[125]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 225: input, ground truth, prediction.
img , onehot , label = test_data[225]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 80: input, ground truth, prediction.
img , onehot , label = test_data[80]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()

In [None]:
# Qualitative check on test sample 66: input, ground truth, prediction.
img , onehot , label = test_data[66]
# img , onehot , label
label = label.to(torch.int32)
plt.imshow(img.permute(1,2,0))
plt.show()
plt.imshow(label.permute(1,2,0))
plt.show()
# Inference mode: disables dropout and makes BatchNorm use its running
# statistics; no_grad avoids building an autograd graph.
fcn.eval()
with torch.no_grad():
    output = fcn(img.unsqueeze(dim = 0).to(device))
pred_mask = rev_one_hot(mappings ,prob_to_one_hot(output[0]))
pred_mask = pred_mask.to(torch.int32)
plt.imshow(pred_mask.permute(1,2,0).detach().to('cpu'))
plt.show()