In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img,img_to_array
import torch
import os
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
class SemanticDataset(Dataset):
    """Pairs of (RGB photo, colour-coded segmentation mask) from a VOC-style tree.

    Masks are read from ``root_dir/SegmentationClass/<stem>.<ext>`` and the
    matching photo from ``root_dir/JPEGImages/<stem>.jpg``.  Each item is the
    dict ``{'image': photo array scaled by 1/255, 'seg_image': mask array}``,
    optionally passed through ``transforms``.

    Args:
        root_dir: dataset root containing SegmentationClass/ and JPEGImages/.
        image_inp_size: ``target_size`` forwarded to ``load_img`` for photos.
        seg_img_inp_size: ``target_size`` forwarded to ``load_img`` for masks.
        transforms: optional callable applied to each sample dict.
    """

    def __init__(self, root_dir, image_inp_size, seg_img_inp_size, transforms=None):
        self.root_dir = root_dir
        self.image_inp_size = image_inp_size
        self.seg_img_inp_size = seg_img_inp_size
        self.transforms = transforms
        # List the mask directory once instead of on every __len__/__getitem__
        # call; sorting makes index -> file mapping reproducible across runs.
        self.seg_img_files = sorted(os.listdir(os.path.join(root_dir, 'SegmentationClass')))

    def __len__(self):
        return len(self.seg_img_files)

    def _sample_paths(self, idx):
        """Return ``(photo_path, mask_path)`` for item ``idx``."""
        seg_file = self.seg_img_files[idx]
        stem = os.path.splitext(seg_file)[0]
        seg_img_name = os.path.join(self.root_dir, 'SegmentationClass', seg_file)
        img_name = os.path.join(self.root_dir, 'JPEGImages', stem + '.jpg')
        return img_name, seg_img_name

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name, seg_img_name = self._sample_paths(idx)
        image = img_to_array(load_img(img_name, target_size=self.image_inp_size)) / 255
        # Nearest-neighbour resizing keeps mask pixels on exact palette colours,
        # which the colour -> label lookup relies on.
        seg_img = img_to_array(load_img(seg_img_name, target_size=self.seg_img_inp_size,
                                         interpolation='nearest'))
        sample = {'image': image, 'seg_image': seg_img}

        if self.transforms:
            sample = self.transforms(sample)

        return sample
    
class ToTensor(object):
    """Convert a SemanticDataset sample dict into tensors on ``device``.

    'image' (H x W x C array) becomes a C x H x W tensor; 'seg_image'
    (H x W x 3 colour mask) becomes an H x W ``torch.long`` tensor of class
    indices via ``voc_label_indices``.

    Args:
        device: target device for both tensors (default ``'cuda'``).
    """

    def __init__(self, device='cuda'):
        self.device = device
        # Colour -> label lookup table, built on first use and then reused:
        # it has 256**3 entries, far too large to rebuild per sample.
        self._colormap2label = None

    def __call__(self, sample):
        if self._colormap2label is None:
            self._colormap2label = build_colormap2label()
        image, seg_img = sample['image'], sample['seg_image']
        image = np.transpose(image, (2, 0, 1))
        labels = voc_label_indices(seg_img, self._colormap2label)
        return {'image': torch.from_numpy(image).to(self.device),
                'seg_image': torch.tensor(labels, dtype=torch.long).to(self.device)}
    
# Sample transform used by SemanticDataset.  The torchvision module is imported
# under its own alias so that re-running this cell still works even though the
# global name `transforms` is rebound to the Compose object below.
from torchvision import transforms as tv_transforms
transforms = tv_transforms.Compose([ToTensor()])

In [None]:
# Class palette and names.  Run this cell before any dataset sample is drawn:
# ToTensor (defined in an earlier cell) calls build_colormap2label and
# voc_label_indices, which are defined in this cell.
# VOC_COLORMAP[i] is the RGB colour that encodes class VOC_CLASSES[i] in the
# SegmentationClass masks; index 0 is background.  Colours absent from this
# list are mapped to index 0 by build_colormap2label's zero-initialised table.
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]
# Human-readable class names, aligned index-for-index with VOC_COLORMAP;
# len(VOC_CLASSES) is used as the network's number of output channels.
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

def build_colormap2label(colormap=None, dtype=np.int64):
    """Build a lookup table from packed RGB colour to class index.

    The entry at ``(r * 256 + g) * 256 + b`` holds the position ``i`` of that
    colour in ``colormap``; every other entry is 0.

    Args:
        colormap: sequence of ``[r, g, b]`` colours, one per class.  Defaults
            to the global VOC_COLORMAP.
        dtype: dtype of the table.  Integer by default because the values are
            class indices (pass ``np.float64`` for the previous behaviour).

    Returns:
        1-D numpy array of length 256**3.
    """
    if colormap is None:
        colormap = VOC_COLORMAP
    colormap2label = np.zeros(256 ** 3, dtype=dtype)
    for i, (r, g, b) in enumerate(colormap):
        colormap2label[(r * 256 + g) * 256 + b] = i
    return colormap2label

def voc_label_indices(colormap, colormap2label):
    """Map an H x W x 3 colour mask to per-pixel class labels.

    Each pixel's RGB triple is packed into the single integer
    ``r * 65536 + g * 256 + b`` (after casting to int32) and looked up in
    ``colormap2label``, as produced by build_colormap2label.
    """
    rgb = np.asarray(colormap, dtype=np.int32)
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    packed = red * 65536 + green * 256 + blue
    return colormap2label[packed]

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
    """U-Net-style fully convolutional network for semantic segmentation.

    Active layout (two encoder levels plus a bottleneck):

        conv1-conv2 (skip c1) -> down1 -> conv5-conv6 (skip c3) -> down3
        -> conv9-conv10 -> up2 + crop(c3) -> conv13-conv14
        -> up4 + crop(c1) -> conv17-conv18 -> out (1x1)

    Every 3x3 convolution is unpadded, so each one trims 2 pixels from H and W.
    A 3 x 256 x 256 input yields an n_categories x 216 x 216 logit map.  Skip
    connections are centre-cropped to the decoder's spatial size at run time,
    so other input sizes also work (e.g. 44 x 44 -> 4 x 4).

    Attribute names match the original layout so existing state_dict
    checkpoints still load.

    Args:
        n_categories: number of classes, i.e. output channels of ``out``.
    """

    def __init__(self, n_categories):
        super(Network, self).__init__()

        # Encoder level 1; the output of relu2 is skip connection c1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
        self.batchn1 = nn.BatchNorm2d(num_features=64)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.batchn2 = nn.BatchNorm2d(num_features=64)
        self.relu2 = nn.ReLU(inplace=True)

        # NOTE(review): both strided down-sampling convs squeeze their input to
        # a single channel, a very narrow bottleneck; kept unchanged so saved
        # weights remain compatible.
        self.down1 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=2, stride=2)

        # Encoder level 2; the output of relu6 is skip connection c3.
        self.conv5 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=3)
        self.batchn5 = nn.BatchNorm2d(num_features=256)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
        self.batchn6 = nn.BatchNorm2d(num_features=256)
        self.relu6 = nn.ReLU(inplace=True)

        self.down3 = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=2, stride=2)

        # Bottleneck.
        self.conv9 = nn.Conv2d(in_channels=1, out_channels=1024, kernel_size=3)
        self.batchn9 = nn.BatchNorm2d(num_features=1024)
        self.relu9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3)
        self.batchn10 = nn.BatchNorm2d(num_features=1024)
        self.relu10 = nn.ReLU(inplace=True)

        # Decoder level 2: upsample, then concatenate with cropped c3
        # (256 + 256 = 512 channels).
        self.up2 = nn.ConvTranspose2d(in_channels=1024, out_channels=256, kernel_size=2, stride=2)

        self.conv13 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3)
        self.batchn13 = nn.BatchNorm2d(num_features=256)
        self.relu13 = nn.ReLU(inplace=True)

        self.conv14 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
        self.batchn14 = nn.BatchNorm2d(num_features=256)
        self.relu14 = nn.ReLU(inplace=True)

        # Decoder level 1: upsample, then concatenate with cropped c1
        # (64 + 64 = 128 channels).
        self.up4 = nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=2, stride=2)

        self.conv17 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3)
        self.batchn17 = nn.BatchNorm2d(num_features=64)
        self.relu17 = nn.ReLU(inplace=True)

        self.conv18 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.batchn18 = nn.BatchNorm2d(num_features=64)
        self.relu18 = nn.ReLU(inplace=True)

        # Per-pixel class logits.
        self.out = nn.Conv2d(in_channels=64, out_channels=n_categories, kernel_size=1)

    @staticmethod
    def _center_crop(skip, target):
        """Centre-crop ``skip`` (N, C, H, W) to the spatial size of ``target``."""
        height, width = target.shape[-2:]
        top = (skip.shape[-2] - height) // 2
        left = (skip.shape[-1] - width) // 2
        return skip[:, :, top:top + height, left:left + width]

    def forward(self, t):
        """Return per-pixel logits of shape (N, n_categories, H', W')."""
        # Encoder level 1.
        t = self.relu1(self.batchn1(self.conv1(t)))
        c1 = self.relu2(self.batchn2(self.conv2(t)))
        t = self.down1(c1)

        # Encoder level 2.
        t = self.relu5(self.batchn5(self.conv5(t)))
        c3 = self.relu6(self.batchn6(self.conv6(t)))
        t = self.down3(c3)

        # Bottleneck.
        t = self.relu9(self.batchn9(self.conv9(t)))
        t = self.relu10(self.batchn10(self.conv10(t)))

        # Decoder level 2 (for 256x256 input the crop is c3[:, :, 4:118, 4:118]).
        t = self.up2(t)
        t = torch.cat((t, self._center_crop(c3, t)), dim=1)
        t = self.relu13(self.batchn13(self.conv13(t)))
        t = self.relu14(self.batchn14(self.conv14(t)))

        # Decoder level 1 (for 256x256 input the crop is c1[:, :, 16:236, 16:236]).
        t = self.up4(t)
        t = torch.cat((t, self._center_crop(c1, t)), dim=1)
        t = self.relu17(self.batchn17(self.conv17(t)))
        t = self.relu18(self.batchn18(self.conv18(t)))

        return self.out(t)
        
        

In [None]:
from torch import optim
# Training configuration: model, optimiser and data pipeline.
device = torch.device("cuda")
CHECKPOINT_IN = '/kaggle/input/sem-seg-model/Segnet_pytorch.pkl'
DATA_ROOT = '/kaggle/input/pascalvoc2009/VOCdevkit/VOC2009'
EPOCHS = 200
BATCH_SIZE = 32
LEARNING_RATE = 0.0001

seg_model = Network(len(VOC_CLASSES))
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(CHECKPOINT_IN, map_location=device)
# strict=False silently skips mismatched keys; report them so a partially
# initialised model does not go unnoticed.
load_result = seg_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('Checkpoint mismatch - missing: {0}, unexpected: {1}'.format(
        load_result.missing_keys, load_result.unexpected_keys))
seg_model.to(device)
# The next cell trains, so BatchNorm must be in training mode (eval() here
# would freeze its running statistics for the whole training run).
seg_model.train()
optimizer = optim.Adam(seg_model.parameters(), lr=LEARNING_RATE)
# The unpadded network maps 256x256 inputs to 216x216 logits, so masks are
# loaded at 216x216 to line up pixel-for-pixel with the predictions.
seg_dataset = SemanticDataset(root_dir=DATA_ROOT, image_inp_size=(256, 256, 3),
                              seg_img_inp_size=(216, 216, 3), transforms=transforms)
seg_dataloader = DataLoader(seg_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Train for EPOCHS epochs and report the mean per-batch loss of each epoch.
seg_model.train()  # BatchNorm must use and update batch statistics while training
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for sample in seg_dataloader:
        X, y = sample['image'], sample['seg_image']
        optimizer.zero_grad()
        y_pred = seg_model(X)
        # cross_entropy already averages over the batch and all pixels
        # (reduction='mean'), so no extra division by BATCH_SIZE is needed.
        loss = F.cross_entropy(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    print('Epoch: {0}, Loss: {1}'.format(epoch, epoch_loss / max(n_batches, 1)))

In [None]:
# Persist the trained weights (parameters and BatchNorm buffers only).
checkpoint_path = '/kaggle/working/A.pkl'
trained_weights = seg_model.state_dict()
torch.save(trained_weights, checkpoint_path)

In [None]:
import torchvision
def get_seg_img(seg_model, n_samples):
    """Plot a random batch with its ground truth and return colourised predictions.

    Draws up to ``n_samples`` random items from the global ``seg_dataset``,
    shows the input images and their ground-truth masks, and returns the
    model's predicted masks as an (N, H, W, 3) array of VOC_COLORMAP colours.

    Args:
        seg_model: segmentation network; its train/eval mode is restored
            after inference.
        n_samples: batch size to draw (fewer are shown if the dataset is
            smaller).

    Returns:
        numpy array of shape (N, H, W, 3) with RGB values from VOC_COLORMAP.
    """
    palette = np.asarray(VOC_COLORMAP)

    dataloader = DataLoader(seg_dataset, batch_size=n_samples, shuffle=True)
    sample = next(iter(dataloader))
    X, y = sample['image'], sample['seg_image']

    # Inference: BatchNorm must use its running statistics and no autograd
    # graph is needed.  argmax of the logits equals argmax of their softmax.
    was_training = seg_model.training
    seg_model.eval()
    with torch.no_grad():
        labels_pred = torch.argmax(seg_model(X), dim=1)
    seg_model.train(was_training)

    # Vectorised label -> RGB lookup; works for any mask size.
    y_pred = palette[labels_pred.cpu().numpy()]
    y_true = palette[y.cpu().numpy()]
    images = X.cpu().permute(0, 2, 3, 1).numpy()
    n_shown = images.shape[0]

    plt.figure(figsize=(15, 15))
    for i in range(n_shown):
        plt.subplot(1, n_shown, i + 1)
        plt.imshow(images[i])
    plt.show()

    plt.figure(figsize=(15, 15))
    for i in range(n_shown):
        plt.subplot(1, n_shown, i + 1)
        plt.imshow(y_true[i])
    plt.show()

    return y_pred
    
# Show the colourised predicted masks next to each other, one panel per sample.
samples = get_seg_img(seg_model, 4)
n_panels = samples.shape[0]
plt.figure(figsize=(15, 15))
for panel, predicted_mask in enumerate(samples, start=1):
    plt.subplot(1, n_panels, panel)
    plt.imshow(predicted_mask)
plt.show()

In [None]:
!pip install torchsummary
from torchsummary import summary
seg_model=Network(len(VOC_CLASSES)).cuda()
summary(seg_model,(3,256,256))