In [None]:
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision import transforms, models
import numpy as np # linear algebra
import random
import math
import matplotlib.pyplot as plt
import os

# Network input resolution (height, width) used by the datasets and plotting helpers.
img_h, img_w = 384, 512

# Seed every RNG the notebook touches so runs are repeatable.
seed = 879
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of segmentation classes (output channels of the model) and one RGB colour per class.
n_classes = 23
class_colors = [
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
    [0, 64, 128], [52, 247, 12], [243, 22, 243],
]
# Classes: tree, grass, other vegetation, dirt, gravel, rocks, water, paved area, pool,
# person, dog, car, bicycle, roof, wall, fence, fence-pole, window, door, obstacle and some unknown objects.
# NOTE(review): this name list has fewer entries than n_classes; confirm the id -> class
# mapping against the dataset's class dictionary.

# Per-channel normalisation (the standard torchvision ImageNet values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [None]:
# Show what is already present in the Kaggle working directory.
data_root = '/kaggle/working'
print(os.listdir(path=data_root))

In [None]:
import shutil
from tqdm import tqdm

train_dir = 'train'
val_dir = 'val'

class_names = ['data', 'answer']

# Create {train,val}/{data,answer}/unknown. The single 'unknown' sub-folder presumably
# mimics an ImageFolder-style layout; the path-building cell reads files from it directly.
for dir_name in [train_dir, val_dir]:
    for class_name in class_names:
        os.makedirs(os.path.join(dir_name, class_name, 'unknown'), exist_ok=True)

source_dir0 = '../input/semantic-drone-dataset/semantic_drone_dataset/original_images'
source_dir = '../input/semantic-drone-dataset/semantic_drone_dataset/label_images_semantic'
# Sort the listing: os.listdir order is arbitrary, and every 26th file goes to validation,
# so an unsorted listing makes the train/val split depend on the filesystem.
# NOTE(review): files copied by an earlier run are not removed; if the split changes,
# delete train/ and val/ first to avoid an image appearing in both splits.
for i, file_name in enumerate(tqdm(sorted(os.listdir(source_dir0)))):
    split_dir = val_dir if i % 26 == 0 else train_dir
    dest_dir0 = os.path.join(split_dir, 'data')
    dest_dir = os.path.join(split_dir, 'answer')
    # Each image's label is the same-stem .png (splitext works for any extension length).
    label_name = os.path.splitext(file_name)[0] + '.png'
    shutil.copy(os.path.join(source_dir0, file_name),
                os.path.join(dest_dir0, 'unknown', file_name))
    shutil.copy(os.path.join(source_dir, label_name),
                os.path.join(dest_dir, 'unknown', label_name))

In [None]:
class ValDataset(torch.utils.data.Dataset):
    """Image/mask pairs resized to (img_h, img_w) for evaluation.

    Args:
        image_paths: paths of the input images.
        target_paths: paths of the label masks, index-aligned with ``image_paths``.
        train: when True (the default, kept for backward compatibility) random
            horizontal/vertical flips are applied. Pass ``train=False`` for a
            deterministic evaluation pipeline.

    Each item is ``(image, mask)``. ``image`` is a float tensor normalised with the
    module-level ``mean``/``std``. ``mask`` is ``TF.to_tensor`` applied to the label
    image; for 8-bit labels that is presumably class_id / 255, which the training
    loop multiplies back by 255.
    """

    def __init__(self, image_paths, target_paths, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.train = train

    def transform(self, image, mask):
        """Resize (and, if ``self.train``, randomly flip) a pair, then convert to tensors."""
        # Masks hold class ids, so they must be resized with nearest-neighbour
        # interpolation. The default (bilinear) blends neighbouring ids into values
        # that belong to neither class.
        image = transforms.Resize(size=(img_h, img_w))(image)
        mask = transforms.Resize(
            size=(img_h, img_w),
            interpolation=transforms.InterpolationMode.NEAREST,
        )(mask)

        if self.train:
            # Random horizontal flipping (same flip for image and mask)
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            # Random vertical flipping
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        image = TF.normalize(image, mean, std)
        return image, mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        x, y = self.transform(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)

In [None]:
class TrainDataset(torch.utils.data.Dataset):
    """Augmented image/mask pairs for training.

    Augmentation pipeline:
    - resize to 600x900;
    - random crop to (img_h, img_w);
    - random horizontal and vertical flips;
    - occasional random rotation;
    - colour jitter on the image only.

    Args:
        image_paths: paths of the input images.
        target_paths: paths of the label masks, index-aligned with ``image_paths``.
        train: accepted for signature parity with ``ValDataset``; not used here.
    """

    def __init__(self, image_paths, target_paths, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths

    def transform(self, image, mask):
        """Apply identical geometric augmentations to image and mask; jitter only the image."""
        # Resize. The mask holds class ids, so it needs nearest-neighbour interpolation;
        # the default (bilinear) would invent in-between ids along class boundaries.
        image = transforms.Resize(size=(600, 900))(image)
        mask = transforms.Resize(
            size=(600, 900),
            interpolation=transforms.InterpolationMode.NEAREST,
        )(mask)

        # Random crop: draw the window once and apply it to both.
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(img_h, img_w))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Random rotation (~40% of samples) by 1..45 degrees.
        # NOTE(review): uncovered corners are filled with 0, i.e. class id 0 in the mask.
        if random.random() > 0.6:
            angle = random.randint(1, 45)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)

        # Photometric jitter on the image only; the mask is untouched.
        jitter = transforms.ColorJitter(brightness=0.5, contrast=0.4, saturation=0.5, hue=0.)
        image = jitter(image)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        image = TF.normalize(image, mean, std)
        return image, mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        x, y = self.transform(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Per-split folders filled by the copy cell (images under .../data, label masks under .../answer).
train_dir = 'train/data'
val_dir   = 'val/data'
train_a   = 'train/answer'
val_a     = 'val/answer'


def _paired_paths(image_root, label_root):
    """Return (image_paths, label_paths) from ``<root>/unknown``, sorted and matched by file stem.

    Raises:
        ValueError: if the two folders do not hold the same file stems. Without this check,
            one missing file would silently shift every later image/mask pairing.
    """
    image_dir = os.path.join(image_root, 'unknown')
    label_dir = os.path.join(label_root, 'unknown')
    image_paths = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))
    label_paths = sorted(os.path.join(label_dir, name) for name in os.listdir(label_dir))

    def stems(paths):
        return [os.path.splitext(os.path.basename(p))[0] for p in paths]

    if stems(image_paths) != stems(label_paths):
        raise ValueError('image/label file stems differ between {} and {}'.format(image_dir, label_dir))
    return image_paths, label_paths


train_data, train_labels = _paired_paths(train_dir, train_a)
val_data, val_labels = _paired_paths(val_dir, val_a)

train_dataset = TrainDataset(train_data, train_labels)
val_dataset = ValDataset(val_data, val_labels)

batch_size = 12
# Tying worker count to batch size can exceed the CPUs available (PyTorch warns); cap it.
num_workers = min(batch_size, os.cpu_count() or 1)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True
)

# drop_last=False so a final partial batch of validation images is still evaluated.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=8, shuffle=True, num_workers=num_workers, drop_last=False
)


In [None]:
(len(train_dataloader), len(train_dataset))  # (batches per epoch, training samples)

In [None]:
(len(val_dataloader), len(val_dataset))  # (batches per pass, validation samples)

In [None]:
def show_mask_above(img_a, img_a_mask, vmin=None, vmax=None):
    """Show a normalised image next to the same image with a mask overlaid.

    Args:
        img_a: (3, H, W) CPU tensor normalised with the module-level ``mean``/``std``;
            it is un-normalised here for display.
        img_a_mask: CPU mask tensor with H*W elements (e.g. (H, W) or (1, H, W)).
        vmin, vmax: optional colour-map limits for the mask. Passing e.g. ``0`` and
            ``n_classes - 1`` gives each class the same colour in every figure. By
            default matplotlib rescales to each mask's own min/max.
    """
    img_a = img_a.permute(1, 2, 0).numpy()
    # Undo Normalize(mean, std). Clip because imshow warns on float RGB outside [0, 1]
    # (and would clip anyway).
    img_a = np.clip(std * img_a + mean, 0, 1)
    img_a_mask = img_a_mask.reshape(img_a.shape[:2])
    plt.figure(1, figsize=(20, 8))
    plt.subplot(121)
    plt.imshow(img_a); plt.title('Raw Drone footage '); plt.axis('off')
    plt.subplot(122)
    plt.imshow(img_a, alpha=0.9)
    plt.imshow(img_a_mask, alpha=0.6, vmin=vmin, vmax=vmax); plt.title('Drone with  mask'); plt.axis('off')
    plt.show()

In [None]:
def show_three(original, pred, true):
    """Plot an image, a predicted mask and a reference mask side by side.

    Args:
        original: (3, H, W) CPU tensor normalised with the module-level ``mean``/``std``.
        pred: CPU mask tensor with H*W elements.
        true: CPU mask tensor with H*W elements.

    The figure is created but ``plt.show()`` is not called; presumably the notebook's
    inline backend renders it at the end of the cell.
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 20), constrained_layout=True)
    original = original.permute(1, 2, 0).numpy()
    # Undo Normalize(mean, std); clip because imshow warns on float RGB outside [0, 1].
    original = np.clip(std * original + mean, 0, 1)
    # Reshape the masks to the image's own spatial size instead of the global img_h/img_w.
    spatial = original.shape[:2]
    pred = pred.reshape(spatial)
    true = true.reshape(spatial)

    panels = (
        (original, 'original image-001.jpg'),
        (pred, 'prediction image-out.png'),
        (true, 'true label image-001.png'),
    )
    for ax, (img, title) in zip(axs, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.grid(False)

In [None]:
# Optional sanity check of the validation pipeline (disabled by default): print the distinct
# label values of one batch and overlay each ground-truth mask on its image.
SHOW_VAL_LABELS = False
if SHOW_VAL_LABELS:
    data, labels = next(iter(val_dataloader))
    # to_tensor scaled the labels by 1/255; scale back to class ids for display.
    secondary = 255*labels
    print(torch.unique(secondary.reshape(-1)))
    for real, test in zip(data, secondary):
        show_mask_above(real, test)

In [None]:
# Optional sanity check of the training augmentations (disabled by default): print the distinct
# label values of one batch and show image / rescaled mask / raw mask for each sample.
SHOW_TRAIN_LABELS = False
if SHOW_TRAIN_LABELS:
    data, labels = next(iter(train_dataloader))
    # to_tensor scaled the labels by 1/255; scale back to class ids for display.
    secondary = 255*labels
    print(torch.unique(secondary.reshape(-1)))
    for real, test, true in zip(data, secondary, labels):
        show_three(real, test, true)

In [None]:
class UpSample(torch.nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip concat, two conv-BN-LeakyReLU.

    Args:
        in_channels: channels of the incoming decoder tensor ``x``. ``ConvTrans`` halves
            them, so the skip tensor ``y`` must supply ``in_channels - in_channels // 2``
            channels for the concatenation to have ``in_channels`` channels again.
        out_channels: channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        # Attribute names are unchanged so previously saved state_dicts still load.
        self.ConvTrans = torch.nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
        self.seq0 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.LeakyReLU()
        )

    def forward(self, x, y):
        """Upsample ``x``, align it to the skip tensor ``y``, concatenate on channels, convolve."""
        x = self.ConvTrans(x)
        # When an encoder feature map had an odd side, pooling floored it and the 2x
        # upsample comes back one pixel short. Pad (negative amounts crop) x to y's
        # spatial size so the concatenation works for any input size; no-op if equal.
        diff_h = y.shape[2] - x.shape[2]
        diff_w = y.shape[3] - x.shape[3]
        if diff_h or diff_w:
            x = torch.nn.functional.pad(
                x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((x, y), dim=1)
        return self.seq0(x)

In [None]:
# Pretrained VGG13-BN with every parameter frozen (requires_grad=False).
model = models.vgg13_bn(pretrained=True)
model.requires_grad_(False)
    
class VGGU(torch.nn.Module):
    """U-Net style segmentation network with a VGG13-BN encoder.

    Encoder: the first child of ``backbone`` (its feature stack) is sliced into
    ``seq0`` = layers [0:6], ``seq1`` = [6:13], ``seq2`` = [13:20], ``seq3`` = [20:27]
    and ``seq4`` = [27:34]. ``seq5`` adds a max-pool plus a 1024-channel bottleneck.
    Decoder: ``seq6``..``seq10`` are ``UpSample`` stages fed with the matching encoder
    outputs as skip connections. A final 1x1 conv produces ``num_classes`` logits
    per pixel.

    Args:
        num_classes: number of output channels (per-pixel class logits).
        backbone: optional VGG13-BN model whose first child supplies the encoder
            layers; it is used as-is (no freezing). When omitted, a fresh pretrained
            ``models.vgg13_bn`` is built and all its parameters are frozen. The class
            therefore no longer reads the notebook-global ``model``, which the setup
            cell rebinds to a VGGU instance; re-running ``VGGU(...)`` after that
            would have sliced the wrong network.
    """

    def __init__(self, num_classes, backbone=None):
        super(VGGU, self).__init__()
        if backbone is None:
            backbone = models.vgg13_bn(pretrained=True)
            for param in backbone.parameters():
                param.requires_grad = False
        lst = list(list(backbone.children())[0])
        self.seq0 = torch.nn.Sequential(*lst[:6])
        self.seq1 = torch.nn.Sequential(*lst[6:13])
        self.seq2 = torch.nn.Sequential(*lst[13:20])
        self.seq3 = torch.nn.Sequential(*lst[20:27])
        self.seq4 = torch.nn.Sequential(*lst[27:34])
        self.seq5 = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(512, 1024, 3, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            torch.nn.Conv2d(1024, 1024, 3, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU()
        )
        self.seq6   = UpSample(1024, 1024)
        self.seq7   = UpSample(1024, 512)
        self.seq8   = UpSample(512, 256)
        self.seq9   = UpSample(256, 128)
        self.seq10  = UpSample(128, 64)

        self.conv = torch.nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        """Return per-pixel class logits, shape (batch, num_classes, height, width) of ``x``."""
        skip0 = self.seq0(x)
        skip1 = self.seq1(skip0)
        skip2 = self.seq2(skip1)
        skip3 = self.seq3(skip2)
        skip4 = self.seq4(skip3)
        x = self.seq5(skip4)
        # Decoder: each stage upsamples and fuses the encoder output at its resolution.
        x = self.seq6(x, skip4)
        x = self.seq7(x, skip3)
        x = self.seq8(x, skip2)
        x = self.seq9(x, skip1)
        x = self.seq10(x, skip0)
        return self.conv(x)

In [None]:
def loss_func(res, target):
    """Per-pixel cross-entropy loss and pixel accuracy.

    Args:
        res: (batch, num_classes, H, W) logits.
        target: class ids as a float tensor with batch*H*W elements (any shape; it is
            flattened). Values come from ``to_tensor``'s /255 followed by *255, so they
            may be off by a float rounding error.

    Returns:
        (acc, loss): mean pixel accuracy and mean cross-entropy, both 0-dim tensors.
    """
    num_classes = res.shape[1]
    # (B, C, H, W) -> (B*H*W, C): one row of class logits per pixel.
    logits = res.permute(0, 2, 3, 1).reshape(-1, num_classes)
    # Round before casting: a plain Long cast truncates, so 2.9999998 would become class 2.
    # Stay on the logits' device instead of round-tripping through the CPU.
    target = target.reshape(-1).round().long().to(res.device)
    loss = torch.nn.functional.cross_entropy(logits, target)

    # argmax of the logits equals argmax of the softmax; no need to normalise.
    preds = logits.argmax(dim=1)
    acc = (preds == target).float().mean()
    return acc, loss

In [None]:
def train_model(model, accnlossfunc, optimizer, scheduler, num_epochs,
                train_loader=None, val_loader=None):
    """Alternate one training pass and one validation pass per epoch.

    Args:
        model: network mapping an image batch to (batch, classes, H, W) logits.
        accnlossfunc: callable ``(preds, labels) -> (acc, loss)`` returning tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the training pass
            (so StepLR's ``step_size`` counts epochs, not batches).
        num_epochs: number of epochs.
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches. They
            default to the notebook's ``train_dataloader`` / ``val_dataloader``.
            Labels are expected as ``to_tensor`` output (class_id / 255).

    Returns:
        ``model``, also when training is interrupted with Ctrl-C (KeyboardInterrupt).
    """
    loaders = {
        'train': train_loader if train_loader is not None else train_dataloader,
        'val': val_loader if val_loader is not None else val_dataloader,
    }
    # Run on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    try:
        for epoch in range(num_epochs):
            for phase in ('train', 'val'):
                dataloader = loaders[phase]
                is_train = phase == 'train'
                model.train(is_train)  # eval mode for validation
                running_loss = 0.
                running_acc = 0.
                for inputs, labels in dataloader:
                    inputs = inputs.to(model_device)
                    # Undo to_tensor's /255 and flatten to one label per pixel.
                    labels = (labels.to(model_device) * 255).reshape(-1)
                    if is_train:
                        optimizer.zero_grad()
                    with torch.set_grad_enabled(is_train):
                        preds = model(inputs)
                        acc, loss_value = accnlossfunc(preds, labels)
                        if is_train:
                            loss_value.backward()
                            optimizer.step()

                    running_loss += loss_value.item()
                    running_acc += acc.item()

                if is_train:
                    # Once per epoch: stepping per batch made StepLR decay the LR
                    # batches-per-epoch times faster than its step_size suggests.
                    scheduler.step()

                n_batches = max(len(dataloader), 1)
                epoch_loss = running_loss / n_batches
                epoch_acc = running_acc / n_batches

                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc), flush=True)
    except KeyboardInterrupt:
        return model
    return model

In [None]:
# Build the segmentation model on the `device` chosen in the config cell.
model = VGGU(n_classes)
model = model.to(device)
loss = loss_func
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# StepLR multiplies the LR by gamma=0.8 once every 100 scheduler.step() calls.
# NOTE(review): the next cell rebinds `optimizer` and `scheduler`, so these two are unused as written.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
# Snapshot of the model's parameter list (indexed later to unfreeze selected layers).
params = list(model.parameters())

In [None]:
# Replaces the optimizer/scheduler from the previous cell: Adam at lr=1e-7, and StepLR
# multiplying the LR by 0.65 once every 100 scheduler.step() calls.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-7)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=100, gamma=0.65)

In [None]:
# Train for 200 epochs; train_model hands back the same model (also after a Ctrl-C interrupt).
model = train_model(model=model, accnlossfunc=loss, optimizer=optimizer,
                    scheduler=scheduler, num_epochs=200)

In [None]:
# Unfreeze entries 20..23 of `params` (the parameter list captured in the setup cell).
# NOTE(review): for the VGGU in this notebook these appear to be the weight/bias of the last
# conv and BatchNorm in encoder stage seq2 — confirm. This cell runs after training, so it
# only affects later training runs.
for weight in params[20:24]:
    weight.requires_grad_(True)

In [None]:
def show_three0(original, pred, true):
    """Like ``show_three`` but ``pred`` is a channel-first image tensor, not a flat mask.

    Args:
        original: (3, H, W) CPU tensor normalised with the module-level ``mean``/``std``.
        pred: (C, H, W) CPU tensor shown as an image after moving channels last;
            presumably C is 1, 3 or 4, since imshow cannot display other channel counts.
        true: CPU mask tensor with H*W elements.
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 20), constrained_layout=True)
    original = original.permute(1, 2, 0).numpy()
    # Undo Normalize(mean, std); clip because imshow warns on float RGB outside [0, 1].
    original = np.clip(std * original + mean, 0, 1)
    pred = pred.permute(1, 2, 0).numpy()
    true = true.reshape(original.shape[:2])

    axs[0].imshow(original)
    axs[0].set_title('original image-001.jpg')
    axs[0].grid(False)

    axs[1].imshow(pred)
    axs[1].set_title('prediction image-out.png')
    axs[1].grid(False)

    axs[2].imshow(true)
    axs[2].set_title('true label image-001.png')
    axs[2].grid(False)

In [None]:
# Predict one validation batch and overlay the predicted class map on each image.
data, labels = next(iter(val_dataloader))
data = data.to(device)
# eval(): BatchNorm uses (and does not update) its running statistics.
# no_grad(): no autograd graph is kept for this display-only forward pass.
model.eval()
with torch.no_grad():
    # argmax over the class dimension of the (B, C, H, W) logits -> (B, H, W) class ids.
    res = model(data).argmax(dim=1)
for real, test in zip(data, res):
    show_mask_above(real.cpu(), test.cpu())

In [None]:
# Compare predictions with ground truth for one validation batch.
model.eval()
# NOTE(review): the RNG state is captured but not used in this cell.
state = torch.get_rng_state()
X_batch0, X_batch1 = next(iter(val_dataloader))
X_batch0 = X_batch0.to(device)
g = X_batch0.clone()
with torch.no_grad():
    # argmax over the class dimension of the (B, C, H, W) logits -> (B, H, W) class ids.
    res = model(X_batch0).argmax(dim=1)

# X_batch1 holds the raw label tensors (class_id / 255), shown as the reference panel.
for imgr, imgg, imgb in zip(res, g, X_batch1):
    show_three(imgg.cpu(), imgr.cpu(), imgb)

In [None]:
# Predict one (augmented) training batch and overlay the predicted class map on each image.
data, labels = next(iter(train_dataloader))
data = data.to(device)
# eval(): BatchNorm uses (and does not update) its running statistics.
# no_grad(): no autograd graph is kept for this display-only forward pass.
model.eval()
with torch.no_grad():
    # argmax over the class dimension of the (B, C, H, W) logits -> (B, H, W) class ids.
    res = model(data).argmax(dim=1)
for real, test in zip(data, res):
    show_mask_above(real.cpu(), test.cpu())