In [1]:
import warnings
warnings.filterwarnings('ignore')

from glob import glob
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.autonotebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerFeatureExtractor
from transformers import SegformerForSemanticSegmentation
from huggingface_hub import cached_download, hf_hub_url
from datasets import load_metric

# --- Reproducibility -------------------------------------------------------
# A single seed drives numpy, torch (CPU) and the current CUDA device.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# --- Hardware --------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Currently using "{device.upper()}" device.')

# --- Training hyper-parameters ---------------------------------------------
batch_size = 2
num_classes = 24  # also the number of class_dict_seg.csv rows read for the palette
epochs = 20
path = r'segform_model.pth'  # checkpoint file name

This notebook is adapted from Niels Rogge's [SegFormer fine-tuning tutorial](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb#scrollTo=MbNeV9xdw7rm).

In [2]:
images_path = r'../input/semantic-drone-dataset/dataset/semantic_drone_dataset/original_images/'
color_masks_path = r'../input/semantic-drone-dataset/RGB_color_image_masks/RGB_color_image_masks/'
masks_path = r'../input/semantic-drone-dataset/dataset/semantic_drone_dataset/label_images_semantic/'

# glob already returns str paths; sorting is the only thing that pairs each image with its masks.
images = sorted(glob(images_path + '*.jpg'))
colored_masks = sorted(glob(color_masks_path + '*.png'))
masks = sorted(glob(masks_path + '*.png'))

# Rows are paired purely by sort order, so fail loudly if the folders do not line up.
assert images, f'No .jpg images found under {images_path}'
assert len(images) == len(colored_masks) == len(masks), (
    f'File counts differ: {len(images)} images, {len(colored_masks)} colour masks, {len(masks)} masks')
assert [Path(p).stem for p in images] == [Path(p).stem for p in masks], \
    'Image and label-mask file names do not pair up'

path_df = pd.DataFrame({'image': images, 'color_mask': colored_masks, 'mask': masks})
path_df.sample(2, random_state=seed)

In [3]:
labels = pd.read_csv('../input/semantic-drone-dataset/class_dict_seg.csv')

# The dataset uses SegformerFeatureExtractor(reduce_labels=True): mask value 0 becomes 255
# (ignored) and every other value is shifted down by one (evaluate_test reproduces this on
# the ground truth). Model id k therefore corresponds to CSV row k + 1. Rotate the first
# row (mask value 0) to the end so names and colours line up with the reduced ids while
# the tables keep num_classes entries, matching the model head size.
row_order = list(range(1, num_classes)) + [0]
class_rows = labels.iloc[row_order]

label_to_id = {name: idx for idx, name in enumerate(class_rows['name'])}
id_to_label = {idx: name for name, idx in label_to_id.items()}

# One colour per model id, in the CSV's column order (all columns after 'name').
palette = class_rows.iloc[:, 1:].values.tolist()

In [4]:
# Integer test_size values are absolute row counts: 10 rows for the qualitative test
# cell, then 40 of the remainder for validation. Use the notebook-wide seed.
train, test = train_test_split(path_df, test_size=10, shuffle=True, random_state=seed)
train, valid = train_test_split(train, test_size=40, shuffle=True, random_state=seed)

# DroneDataset and evaluate_test index rows with 0..n-1, so give each split a fresh index.
train = train.reset_index(drop=True)
valid = valid.reset_index(drop=True)
test = test.reset_index(drop=True)
assert len(train) + len(valid) + len(test) == len(path_df)

print(f'Train size: {len(train)}, validation size: {len(valid)} and test size: {len(test)}')

In [5]:
# Show one image next to its RGB colour mask and its single-channel label mask.
sample_row = path_df.sample(1, random_state=seed).iloc[0]
panels = [('Image', 'image'), ('Colour mask', 'color_mask'), ('Label mask', 'mask')]

fig, axes = plt.subplots(1, 3, figsize=(8, 4))
for ax, (title, column) in zip(axes, panels):
    # Read the pixels while the file is open, then let the context manager close it.
    with Image.open(sample_row[column]) as img:
        ax.imshow(np.array(img))
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

In [6]:
class DroneDataset(Dataset):
    """Torch dataset pairing drone images with their semantic label masks.

    Args:
        dataframe: table with 'image' and 'mask' columns holding file paths.
        feature_extractor: callable invoked as
            ``feature_extractor(image, mask, return_tensors="pt")`` that returns a
            dict-like of tensors carrying a leading batch dimension of size 1.
    """

    def __init__(self, dataframe, feature_extractor):
        self.dataframe = dataframe
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, ix):
        """Return the encoded image/label tensors for row ``ix`` without the batch dim."""
        # Positional lookup: the DataLoader asks for 0..len-1 whatever the frame's index is.
        row = self.dataframe.iloc[ix]
        # Close both files once the extractor has read their pixels (avoids leaking handles).
        with Image.open(row['image']) as image, Image.open(row['mask']) as mask:
            encoded_inputs = self.feature_extractor(image, mask, return_tensors="pt")

        # Drop only the leading batch dimension added for this single sample.
        for key in list(encoded_inputs.keys()):
            encoded_inputs[key] = encoded_inputs[key].squeeze(0)

        return encoded_inputs

In [7]:
# reduce_labels=True: mask value 0 (background) is mapped to 255 and ignored,
# every remaining class id is shifted down by one.
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

train_dataset, valid_dataset = (
    DroneDataset(train, feature_extractor),
    DroneDataset(valid, feature_extractor),
)

In [8]:
# Shuffle only the training split; validation order is kept fixed.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size)

In [9]:
# Pretrained MiT-b0 encoder with a freshly initialised head sized for our classes.
checkpoint = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(
    checkpoint,
    num_labels=num_classes,
    id2label=id_to_label,
    label2id=label_to_id,
)

In [10]:
# Separate metric objects so training and validation batches are accumulated independently.
metric_train = load_metric("mean_iou")
metric_valid = load_metric("mean_iou")

# Move the model before building the optimizer, as the torch.optim docs recommend.
# Assigning the result also keeps the module repr out of the cell output.
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5)

In [41]:
def _colorize(label_map, np_palette):
    """Return an RGB uint8 image where each pixel takes the palette colour of its class id.

    Pixels whose value has no palette entry (e.g. the 255 ignore value) stay black.
    """
    color_seg = np.zeros((label_map.shape[0], label_map.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(np_palette):
        color_seg[label_map == label, :] = color
    return color_seg


@torch.no_grad()
def evaluate_test(model=model, test=test):
    """Plot the prediction for one random test image next to its ground-truth overlay.

    Args:
        model: segmentation model called as ``model(pixel_values=...)``; its ``.logits``
            are upsampled to the image size and arg-maxed per pixel.
        test: DataFrame with 'image' and 'mask' path columns and a 0..n-1 index.

    Uses the module-level ``feature_extractor``, ``device`` and ``palette``. Both overlays
    blend the class colours 50/50 with the original image.
    """
    model.eval()
    idx = np.random.randint(len(test))

    # Read everything needed from the files, then let the context managers close them.
    with Image.open(test.loc[idx, 'image']) as image, Image.open(test.loc[idx, 'mask']) as gt_mask:
        pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
        target_size = image.size[::-1]  # PIL size is (width, height); interpolate wants (h, w)
        image_rgb = np.array(image)
        gt_map = np.array(gt_mask)

    logits = model(pixel_values=pixel_values).logits.cpu()
    upsampled_logits = nn.functional.interpolate(logits,
                                                 size=target_size,
                                                 mode='bilinear',
                                                 align_corners=False)
    seg = upsampled_logits.argmax(dim=1)[0].numpy()

    np_palette = np.array(palette)
    # The palette is used in its CSV channel order and matplotlib's imshow expects RGB,
    # so no BGR flip is applied here. NOTE(review): assumes class_dict_seg.csv lists r, g, b.
    pred_overlay = (image_rgb * 0.5 + _colorize(seg, np_palette) * 0.5).astype(np.uint8)

    # Apply the same reduction as reduce_labels to the ground truth:
    # 0 -> 255 (ignored), every other id shifted down by one.
    gt_map[gt_map == 0] = 255
    gt_map = gt_map - 1
    gt_map[gt_map == 254] = 255
    gt_overlay = (image_rgb * 0.5 + _colorize(gt_map, np_palette) * 0.5).astype(np.uint8)

    titles = ('Predicted Image', 'GT segmentation mask', 'Original Image')
    pictures = (pred_overlay, gt_overlay, image_rgb)
    fig, axes = plt.subplots(1, 3, figsize=(12, 8))
    for ax, title, picture in zip(axes, titles, pictures):
        ax.set_title(title)
        ax.imshow(picture)

    plt.tight_layout()
    plt.show()
    plt.pause(0.01)

In [11]:
def _report_metrics(metric, mean_loss):
    """Print mean loss plus mean IoU / mean accuracy over all batches added to ``metric``.

    ``metric.compute`` consumes the accumulated batches, so each call covers exactly one
    phase of one epoch (no batches leak into the next report).
    """
    scores = metric.compute(num_labels=len(id_to_label),
                            ignore_index=255,
                            reduce_labels=False)
    print("Loss:", mean_loss)
    print("Mean_iou:", scores["mean_iou"])
    print("Mean accuracy:", scores["mean_accuracy"])
    print('-'*50)


def _run_epoch(dataloader, metric, train):
    """Run one pass over ``dataloader``, feeding per-pixel predictions into ``metric``.

    With ``train=True`` the model is in training mode and updated through the module-level
    ``optimizer``. Otherwise it runs in eval mode with autograd disabled, so validation
    builds no graph. Returns the mean batch loss.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(dataloader, leave=False):
            pixel_values = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=targets)
            if train:
                optimizer.zero_grad()
                outputs.loss.backward()
                optimizer.step()
            total_loss += outputs.loss.item()

            with torch.no_grad():
                # Bring logits to the label resolution before the per-pixel argmax.
                upsampled_logits = nn.functional.interpolate(outputs.logits,
                                                             size=targets.shape[-2:],
                                                             mode="bilinear",
                                                             align_corners=False)
                predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.cpu().numpy(),
                             references=targets.cpu().numpy())
    return total_loss / max(len(dataloader), 1)


for epoch in range(epochs):
    print("Epoch:", epoch+1)
    train_loss = _run_epoch(train_dataloader, metric_train, train=True)
    _report_metrics(metric_train, train_loss)

    print('-'*30, 'Validation', '-'*30)
    valid_loss = _run_epoch(valid_dataloader, metric_valid, train=False)
    _report_metrics(metric_valid, valid_loss)

    # The visualisation is best-effort: report a failure, but don't hide it or stop training.
    try:
        evaluate_test()
    except Exception as exc:
        print(f'evaluate_test failed: {exc!r}')

The full training logs and results are in the [Kaggle notebook run](https://www.kaggle.com/code/pankratozzi/pytorch-hf-segformer-quadcopter?scriptVersionId=101556804).