In [None]:
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
#import imageio
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
from torchvision.utils import make_grid
from albumentations.pytorch import ToTensorV2
import albumentations as A

%matplotlib inline

In [None]:
# Display the installed PyTorch version so the run environment is recorded in the notebook output.
torch.__version__

In [None]:
# Run-wide settings. The device is the first CUDA GPU when one is available,
# otherwise the CPU.
config = dict(
    device=torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
    batch_size=2,
    learning_rate=0.001,
)

# Short aliases used by the cells below.
device, batch_size, lr = (
    config[key] for key in ('device', 'batch_size', 'learning_rate')
)

In [None]:
# Preview one raw frame. The image is converted from BGR to RGB channel order
# before being handed to matplotlib.
frame_bgr = cv2.imread('../input/camseq-semantic-segmentation/0016E5_07961.png')
img = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
plt.imshow(img);

In [None]:
# Preview the colour-coded label image (the "_L" file) that belongs to the frame
# shown above. `img` is left bound to this RGB array for the next cell.
label_bgr = cv2.imread('../input/camseq-semantic-segmentation/0016E5_07961_L.png')
img = cv2.cvtColor(label_bgr, cv2.COLOR_BGR2RGB)
plt.imshow(img);

In [None]:
# Distinct channel values present in `img` (the label image when cells are run in order);
# np.unique flattens the array, so this lists values across all three channels.
np.unique(img)

In [None]:
# Class name -> RGB colour used for that class in the label images.
# The insertion order defines the integer class index (Animal = 0, ..., Wall = 31).
# Source: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamSeq01/
color_codes = OrderedDict([
    ('Animal',            [64, 128, 64]),
    ('Archway',           [192, 0, 128]),
    ('Bicyclist',         [0, 128, 192]),
    ('Bridge',            [0, 128, 64]),
    ('Building',          [128, 0, 0]),
    ('Car',               [64, 0, 128]),
    ('CartLuggagePram',   [64, 0, 192]),
    ('Child',             [192, 128, 64]),
    ('Column_Pole',       [192, 192, 128]),
    ('Fence',             [64, 64, 128]),
    ('LaneMkgsDriv',      [128, 0, 192]),
    ('LaneMkgsNonDriv',   [192, 0, 64]),
    ('Misc_Text',         [128, 128, 64]),
    ('MotorcycleScooter', [192, 0, 192]),
    ('OtherMoving',       [128, 64, 64]),
    ('ParkingBlock',      [64, 192, 128]),
    ('Pedestrian',        [64, 64, 0]),
    ('Road',              [128, 64, 128]),
    ('RoadShoulder',      [128, 128, 192]),
    ('Sidewalk',          [0, 0, 192]),
    ('SignSymbol',        [192, 128, 128]),
    ('Sky',               [128, 128, 128]),
    ('SUVPickupTruck',    [64, 128, 192]),
    ('TrafficCone',       [0, 0, 64]),
    ('TrafficLight',      [0, 64, 64]),
    ('Train',             [192, 64, 128]),
    ('Tree',              [128, 128, 0]),
    ('Truck_Bus',         [192, 128, 192]),
    ('Tunnel',            [64, 0, 64]),
    ('VegetationMisc',    [192, 192, 0]),
    ('Void',              [0, 0, 0]),
    ('Wall',              [64, 192, 0]),
])

In [None]:
class CamSeqDataset(Dataset):
    """Frames and class-index masks read from a single directory.

    Files whose stem ends in ``_L`` are treated as colour-coded label images;
    every other file except ``.txt`` files is treated as an input frame.
    Frames and labels are paired by position after sorting by file name.

    Args:
        img_dir: Directory holding both the frames and the ``_L`` label images.
        color_codes: Ordered mapping of class name -> RGB colour. The position
            of a class in the mapping is its integer class index.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries holding NumPy arrays.

    Raises:
        ValueError: If the number of frames and label images differ.
    """

    def __init__(self,
                 img_dir,
                 color_codes=color_codes,
                 transforms=None):

        super().__init__()

        def is_label(file_name):
            return os.path.splitext(file_name)[0].endswith('_L')

        # List the directory once; sorting keeps frame i aligned with label i.
        file_names = sorted(os.listdir(img_dir))
        self.images = [os.path.join(img_dir, name) for name in file_names
                       if not is_label(name) and not name.endswith('.txt')]
        self.masks = [os.path.join(img_dir, name) for name in file_names
                      if is_label(name)]

        # Positional pairing is only valid when both lists have equal length.
        if len(self.images) != len(self.masks):
            raise ValueError(
                'Found {} images but {} masks in {!r}'.format(
                    len(self.images), len(self.masks), img_dir))

        self.color_codes = color_codes
        self.num_classes = len(self.color_codes)
        self.transforms = transforms

    def __len__(self):
        """Return the number of frame/label pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx``.

        Returns:
            dict with ``'image'``, a float32 tensor in channel-first layout
            holding the unscaled pixel values, and ``'mask'``, a float32
            tensor of class indices with one value per pixel.
        """
        img = Image.open(self.images[idx])
        mask = Image.open(self.masks[idx])

        if img.mode != 'RGB':
            img = img.convert('RGB')
        if mask.mode != 'RGB':
            mask = mask.convert('RGB')

        img = np.asarray(img)
        mask = np.asarray(mask)

        # Convert the colour-coded label image into a 2-D map of class indices.
        # Pixels whose colour matches no entry in `color_codes` keep index 0.
        mask_channels = np.zeros(mask.shape[:2], dtype=np.int64)
        for class_idx, color in enumerate(self.color_codes.values()):
            mask_channels[np.all(mask == color, axis=-1)] = class_idx

        if self.transforms is not None:
            # `mask=` is the single-mask argument; `masks=` expects a list of
            # masks and would not treat this 2-D array as one mask.
            transformed = self.transforms(image=img, mask=mask_channels)
            img = transformed['image']
            mask_channels = transformed['mask']

        mask_channels = mask_channels.astype(np.float32)
        img = img.astype(np.float32)

        # HWC -> CHW, the layout expected by the convolutional model.
        instance = {'image': torch.from_numpy(img.transpose(2, 0, 1)),
                    'mask': torch.from_numpy(mask_channels)}

        return instance

In [None]:
def make_deeplab(out_channels=32):
    """Create a DeepLabV3-ResNet50 whose classifier head predicts `out_channels` classes.

    The network is requested with ``pretrained=True`` and its ``classifier``
    is then replaced by a new ``DeepLabHead`` taking 2048 input channels.

    Args:
        out_channels: Number of output classes for the new head.

    Returns:
        The model, switched to training mode.
    """
    segmentation = models.segmentation
    net = segmentation.deeplabv3_resnet50(pretrained=True)
    net.classifier = segmentation.deeplabv3.DeepLabHead(
        2048, num_classes=out_channels)
    # Module.train() returns the module itself.
    return net.train()

In [None]:
def train_model(model, 
                train_loader, 
                val_loader, 
                criterion= nn.CrossEntropyLoss(),
                num_epochs=1,
                device=device):
    """Train `model` with Adam and report the mean train/validation loss per epoch.

    The learning rate is read from the notebook-level `lr` variable.

    Args:
        model: Network whose forward pass returns a dict with an ``'out'`` entry.
        train_loader: Iterable of batches, each a dict with ``'image'`` and ``'mask'``.
        val_loader: Iterable of batches in the same format, used only for evaluation.
        criterion: Loss called as ``criterion(prediction.float(), target.long())``.
        num_epochs: Number of passes over `train_loader`.
        device: Device the model and each batch are moved to.

    Returns:
        The trained model, left in training mode.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        tr_loss = []
        val_loss = []
        print('Epoch {}/{}'.format(epoch, num_epochs))

        # ---- training pass ----
        model.train()
        for sample in tqdm(train_loader):
            # Skip single-sample batches (presumably because batch-norm layers
            # cannot compute batch statistics from one sample in training mode)
            # without abandoning the rest of the epoch.
            if sample['image'].shape[0] == 1:
                continue
            inputs = sample['image'].to(device)
            masks = sample['mask'].to(device)

            optimizer.zero_grad()
            y_pred = model(inputs)['out']
            loss = criterion(y_pred.float(), masks.long())
            loss.backward()
            optimizer.step()
            # Keep a plain float so the loss tensor is not retained across the epoch.
            tr_loss.append(loss.item())

        print(f'Train loss: {torch.mean(torch.Tensor(tr_loss))}')

        # ---- validation pass ----
        # Evaluation mode stops batch-norm statistics from being updated with
        # validation data and disables dropout. No optimizer step is taken
        # here, so the weights are not changed by validation.
        model.eval()
        with torch.no_grad():
            for sample in tqdm(val_loader):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)

                y_pred = model(inputs)['out']
                loss = criterion(y_pred.float(), masks.long())
                val_loss.append(loss.item())

        print(f'Validation loss: {torch.mean(torch.Tensor(val_loss))}')

    # Hand the model back in training mode, matching the state it has during training.
    model.train()
    return model

In [None]:
# Build the dataset and hold out 15% of it for validation.
dataset = CamSeqDataset(img_dir='../input/camseq-semantic-segmentation')
train_size = int(len(dataset)*0.85)
# A seeded generator makes the split identical on every Restart & Run All,
# so validation losses from different runs refer to the same held-out samples.
train_set, val_set = random_split(
    dataset,
    [train_size, len(dataset)-train_size],
    generator=torch.Generator().manual_seed(0))

# Reshuffle the training samples each epoch; validation order does not matter.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)

model = make_deeplab()
model = train_model(model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    num_epochs=5)

In [None]:
# Predict on the first validation batch.
# `img` is the batch dict (keys 'image' and 'mask'); `image` is the model's
# 'out' tensor for that batch. Both names are read by the cells below.
model.eval()
# Inference only: no_grad avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    img = next(iter(val_loader))
    image = model(img['image'].to(device))['out']

In [None]:
# Show the first input image of the batch: move it to the CPU, reorder
# channel-first to channel-last, and cast to uint8 for display.
first_input = img['image'][0].cpu()
plt.imshow(first_input.permute(1, 2, 0).numpy().astype(np.uint8));

In [None]:
# One panel per class for the first image of the batch: white where the
# predicted class probability exceeds `thresh`, black elsewhere.
thresh = 0.3
# The model's 'out' tensor holds unnormalised scores; softmax over the class
# axis turns them into probabilities so `thresh` is a probability cut-off.
res = torch.softmax(image[0].detach(), dim=0).cpu().numpy()
class_names = list(color_codes.keys())

ncols = 8
nrows = max(1, -(-len(res) // ncols))  # ceiling division: enough rows for every class
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                         figsize=(16, 5 * nrows), squeeze=False)
axes_list = list(axes.ravel())

for name, mask, ax in zip(class_names, res, axes_list):
    # Fix the colour range so an all-0 and an all-255 mask are not both
    # autoscaled to the same shade.
    ax.imshow(np.where(mask > thresh, 255, 0), cmap='gray', vmin=0, vmax=255)
    ax.set_title(name)

# Remove panels that did not receive a class.
for ax in axes_list[min(len(class_names), len(res)):]:
    ax.remove()

plt.tight_layout()