In [52]:
import os

In [53]:
from datasets import load_dataset
from datasets import Dataset
import numpy as np
from PIL import Image
from torchvision.transforms import ColorJitter
from transformers import SegformerFeatureExtractor
import json

In [54]:
data_split = ['train','val']

path = '/scratch/j/jcaunedo/umar1/segmentation/IDD_Segmentation/'

In [55]:
fl = open('label2id.json')
label2id = json.load(fl)
fl.close()

In [56]:
fl = open('id2label.json')
id2label = json.load(fl)
fl.close()

In [57]:
def populate(cpath):
    
    result = []
    
    folders = os.listdir(cpath)
    
    for each_folder in folders:
        
        combined_folder_path = cpath+'/'+each_folder

        files = os.listdir(combined_folder_path)
        
        for each in files:
            
            result.append(combined_folder_path+'/'+each)
    
    return result

In [58]:
def get_all_images_path(split,subset=0):
    
    combined_path = path+'leftImg8bit/'+split+'/'
    
    images = sorted(populate(combined_path))
    
    combined_path = path+'label_processed/'+split+'/'
    
    masks = sorted(populate(combined_path))
    
    if subset>0:
        images = images[:subset]
        masks = masks[:subset]
    
    data_dict = {
        "pixel_values" : images,
        "label" : masks
    }
    
    return data_dict
    

    

In [59]:
d_train =  get_all_images_path('train',1000)
d_val = get_all_images_path('val',200)
# Create a Hugging Face Dataset from the dictionary
dataset_train = Dataset.from_dict(d_train)
dataset_val = Dataset.from_dict(d_val)

In [60]:
def mapping_fn(example):
    result = {}
    result['pixel_values'] = Image.open(example['pixel_values']).convert('RGB')
    result['label'] = Image.open(example['label']).convert('L')
    return result

In [62]:


feature_extractor = SegformerFeatureExtractor()
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1) 

def train_transforms(example_batch):
    images = [jitter(Image.open(x).convert('RGB')) for x in example_batch['pixel_values']]
    labels = [Image.open(x).convert('L') for x in example_batch['label']]
    inputs = feature_extractor(images, labels)
    return inputs


def val_transforms(example_batch):
    images = [Image.open(x).convert('RGB') for x in example_batch['pixel_values']]
    labels = [Image.open(x).convert('L') for x in example_batch['label']]
    inputs = feature_extractor(images, labels)
    return inputs


# Set transforms
dataset_train.set_transform(train_transforms)
dataset_val.set_transform(val_transforms)

In [63]:
dataset_train[0]['labels']

array([[31, 31, 31, ..., 31, 31, 31],
       [31, 31, 31, ..., 31, 31, 31],
       [31, 31, 31, ..., 31, 31, 31],
       ...,
       [22, 22, 22, ...,  2,  2,  2],
       [22, 22, 22, ...,  2,  2,  2],
       [22, 22, 22, ...,  2,  2,  2]])

In [64]:
from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0" 
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    id2label=id2label,
    label2id=label2id
)


Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.2.proj.bias', 'decode_h

In [65]:
from transformers import TrainingArguments

epochs = 50
lr = 0.00006
batch_size = 2


training_args = TrainingArguments(
    "segformer-b0-finetuned-segments-sidewalk-outputs",
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True

)


In [66]:
import torch
from torch import nn
import evaluate

metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
  with torch.no_grad():
    logits, labels = eval_pred
    logits_tensor = torch.from_numpy(logits)
    # scale the logits to the size of the label
    logits_tensor = nn.functional.interpolate(
        logits_tensor,
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)

    pred_labels = logits_tensor.detach().cpu().numpy()
    # currently using _compute instead of compute
    # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
    metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=0,
            reduce_labels=feature_extractor.do_reduce_labels,
        )
    
    # add per category metrics as individual key-value pairs
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    metrics.update({f"accuracy_{id2label[str(i)]}": v for i, v in enumerate(per_category_accuracy)})
    metrics.update({f"iou_{id2label[str(i)]}": v for i, v in enumerate(per_category_iou)})
    
    return metrics


In [67]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    compute_metrics=compute_metrics,
)


In [None]:
trainer.train()



Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Curb,Accuracy Caravan,Accuracy Road,Accuracy Out of roi,Accuracy Tunnel,Accuracy License plate,Accuracy Sidewalk,Accuracy Bus,Accuracy Trailer,Accuracy Vehicle fallback,Accuracy Obs-str-bar-fallback,Accuracy Autorickshaw,Accuracy Polegroup,Accuracy Animal,Accuracy Car,Accuracy Building,Accuracy Drivable fallback,Accuracy Train,Accuracy Rectification border,Accuracy Pole,Accuracy Ego vehicle,Accuracy Non-drivable fallback,Accuracy Motorcycle,Accuracy Fallback background,Accuracy Billboard,Accuracy Wall,Accuracy Parking,Accuracy Rider,Accuracy Rail track,Accuracy Traffic sign,Accuracy Guard rail,Accuracy Sky,Accuracy Fence,Accuracy Traffic light,Accuracy Bridge,Accuracy Person,Accuracy Bicycle,Accuracy Vegetation,Accuracy Truck,Iou Curb,Iou Caravan,Iou Road,Iou Out of roi,Iou Tunnel,Iou License plate,Iou Sidewalk,Iou Bus,Iou Trailer,Iou Vehicle fallback,Iou Obs-str-bar-fallback,Iou Autorickshaw,Iou Polegroup,Iou Animal,Iou Car,Iou Building,Iou Drivable fallback,Iou Train,Iou Rectification border,Iou Pole,Iou Ego vehicle,Iou Non-drivable fallback,Iou Motorcycle,Iou Fallback background,Iou Billboard,Iou Wall,Iou Parking,Iou Rider,Iou Rail track,Iou Traffic sign,Iou Guard rail,Iou Sky,Iou Fence,Iou Traffic light,Iou Bridge,Iou Person,Iou Bicycle,Iou Vegetation,Iou Truck
20,3.1046,3.319595,0.068692,0.14578,0.629247,,0.0,0.858256,,,,0.059471,0.002642,,0.0,0.094687,0.139915,0.0,0.0,0.774963,0.212985,0.042682,,,0.000649,,0.000679,0.011637,0.0,0.006853,0.026912,,0.001425,,0.0,0.001253,0.734803,0.0,0.0,0.0,0.01179,0.268755,0.967694,0.009581,0.0,0.0,0.785948,0.0,0.0,0.0,0.040557,0.002443,0.0,0.0,0.057372,0.050504,0.0,0.0,0.288023,0.151072,0.038429,0.0,0.0,0.000512,0.0,0.000676,0.009673,0.0,0.006656,0.019779,0.0,0.001416,0.0,0.0,0.000903,0.726635,0.0,0.0,0.0,0.009409,0.000148,0.479354,0.009464
40,2.384,2.228946,0.080715,0.151299,0.708144,,0.0,0.938151,,,,0.001769,0.001566,,0.0,0.126907,0.000923,0.0,0.0,0.878234,0.543514,0.003226,,,0.0,,0.0,0.003685,0.0,0.006559,0.003198,,0.0,,0.0,0.0,0.916363,0.0,0.0,0.0,0.0,0.0,0.962076,0.001507,0.0,0.0,0.83779,0.0,,0.0,0.001753,0.001549,,0.0,0.068841,0.000898,0.0,0.0,0.21599,0.263016,0.003214,0.0,0.0,0.0,,0.0,0.003629,0.0,0.006496,0.00297,0.0,0.0,0.0,0.0,0.0,0.863155,0.0,0.0,0.0,0.0,0.0,0.634927,0.0015
60,1.9153,1.738569,0.09708,0.15701,0.722363,,0.0,0.948476,,,,0.008551,7.5e-05,,0.0,0.04215,0.0,0.0,0.0,0.893346,0.686104,0.037382,,,0.0,,0.0,0.000267,0.0,0.000326,0.032097,,0.0,,0.0,0.0,0.951521,0.0,0.0,0.0,0.0,0.0,0.952551,0.000431,0.0,0.0,0.858397,,,,0.008423,7.5e-05,,0.0,0.033825,0.0,0.0,0.0,0.238424,0.239243,0.035316,0.0,,0.0,,0.0,0.000267,0.0,0.000326,0.021513,,0.0,,0.0,0.0,0.881658,0.0,0.0,0.0,0.0,0.0,0.69158,0.00043
80,2.4513,1.529708,0.10384,0.159235,0.733267,,0.0,0.970947,,,,0.000802,0.000105,,0.0,0.041008,4.7e-05,0.0,0.0,0.845162,0.731828,0.072458,,,0.0,,0.0,0.005856,0.0,0.000246,0.038066,,2e-06,,0.0,0.0,0.966301,0.0,0.0,0.0,0.0,0.0,0.94425,0.000741,0.0,0.0,0.86491,,,,0.000801,0.000105,,0.0,0.032414,4.7e-05,0.0,0.0,0.304003,0.222926,0.062533,,,0.0,,0.0,0.005824,0.0,0.000245,0.02796,,2e-06,,0.0,0.0,0.885599,0.0,0.0,0.0,0.0,0.0,0.707095,0.000738
100,1.832,1.398859,0.111024,0.168524,0.742555,,0.0,0.974081,,,,0.0,0.0,,0.0,0.019712,0.0,0.0,0.0,0.840146,0.731802,0.294505,,,0.0,,0.0,0.003948,0.0,0.014428,0.085117,,0.0,,0.0,0.0,0.968514,0.0,0.0,0.0,0.0,0.0,0.952049,0.00288,0.0,0.0,0.88557,,,,0.0,0.0,,0.0,0.0187,0.0,0.0,0.0,0.32068,0.229705,0.215114,,,0.0,,0.0,0.003934,0.0,0.013891,0.053142,,0.0,,0.0,0.0,0.897659,0.0,0.0,0.0,0.0,0.0,0.689448,0.002869
120,1.5902,1.35209,0.117773,0.17995,0.744033,,0.0,0.973485,,,,0.0,0.0,,0.0,0.073017,0.0,0.0,0.0,0.887801,0.80601,0.439873,,,0.0,,0.0,0.005772,0.0,0.004758,0.175923,,0.0,,0.0,0.0,0.976151,0.0,0.0,0.0,0.0,0.0,0.875542,0.000225,0.0,0.0,0.894957,,,,0.0,0.0,,0.0,0.057015,0.0,0.0,0.0,0.296478,0.241426,0.29244,,,0.0,,0.0,0.00573,0.0,0.004725,0.098066,,0.0,,0.0,0.0,0.892292,0.0,0.0,0.0,0.0,0.0,0.749838,0.000225
140,1.4577,1.327814,0.122169,0.186357,0.752678,,0.0,0.964756,,,,0.0,0.0,,0.0,0.079006,0.0,0.0,0.0,0.877215,0.802457,0.651419,,,0.0,,0.0,0.004178,0.0,0.020828,0.108709,,3.3e-05,,0.0,0.0,0.974689,0.0,0.0,0.0,1e-05,0.0,0.92102,3.1e-05,0.0,0.0,0.907183,,,,0.0,0.0,,0.0,0.058723,0.0,0.0,0.0,0.328123,0.243589,0.380757,,,0.0,,0.0,0.004161,0.0,0.019939,0.067381,,3.3e-05,,0.0,0.0,0.9049,0.0,0.0,0.0,9e-06,0.0,0.750256,3.1e-05
160,1.475,1.235858,0.125188,0.191066,0.757457,,0.0,0.975428,,,,0.0,0.0,,0.0,0.077113,0.0,0.0,0.0,0.835319,0.756202,0.620272,,,0.0,,0.0,0.014833,0.0,0.026057,0.318088,,0.000172,,0.0,0.0,0.967057,0.0,0.0,0.0,0.0,0.0,0.950344,1.9e-05,0.0,0.0,0.898615,,,,0.0,0.0,,0.0,0.058839,0.0,0.0,0.0,0.329116,0.28396,0.379362,,,0.0,,0.0,0.014709,0.0,0.024491,0.125519,,0.000172,,0.0,0.0,0.910481,0.0,0.0,0.0,0.0,0.0,0.730356,1.9e-05
180,1.3704,1.172522,0.125364,0.194847,0.754752,,0.0,0.955305,,,,0.0,0.0,,0.0,0.031239,0.0,0.0,0.0,0.865994,0.736894,0.753209,,,0.0,,0.0,0.06451,0.0,0.039101,0.279854,,0.002519,,0.0,0.0,0.974872,0.0,0.0,0.0,0.00067,0.0,0.946017,0.000393,0.0,0.0,0.895589,,,,0.0,0.0,,0.0,0.02894,0.0,0.0,0.0,0.297716,0.312133,0.390446,,,0.0,,0.0,0.060661,0.0,0.035489,0.130083,,0.002518,,0.0,0.0,0.907507,0.0,0.0,0.0,0.00067,0.0,0.698771,0.000393
200,1.594,1.248758,0.137983,0.213739,0.756484,,0.0,0.94239,,,,0.0,0.0,,0.0,0.09549,0.0,0.0,0.0,0.866534,0.837862,0.790539,,,0.0,,0.0,0.31219,0.0,0.042639,0.406953,,0.011516,,0.0,0.0,0.971771,0.0,0.0,0.0,8.1e-05,0.0,0.92046,9e-06,0.0,0.0,0.908446,,,,0.0,0.0,,0.0,0.064144,0.0,0.0,0.0,0.410089,0.259538,0.384404,,,0.0,,0.0,0.232452,0.0,0.037214,0.153949,,0.011465,,0.0,0.0,0.915217,0.0,0.0,0.0,8.1e-05,0.0,0.762479,9e-06


  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_are

In [48]:
id2label

{'0': 'curb',
 '1': 'caravan',
 '2': 'road',
 '3': 'out of roi',
 '4': 'tunnel',
 '5': 'license plate',
 '6': 'sidewalk',
 '7': 'bus',
 '8': 'trailer',
 '9': 'vehicle fallback',
 '10': 'obs-str-bar-fallback',
 '11': 'autorickshaw',
 '12': 'polegroup',
 '13': 'animal',
 '14': 'car',
 '15': 'building',
 '16': 'drivable fallback',
 '17': 'train',
 '18': 'rectification border',
 '19': 'pole',
 '20': 'ego vehicle',
 '21': 'non-drivable fallback',
 '22': 'motorcycle',
 '23': 'fallback background',
 '24': 'billboard',
 '25': 'wall',
 '26': 'parking',
 '27': 'rider',
 '28': 'rail track',
 '29': 'traffic sign',
 '30': 'guard rail',
 '31': 'sky',
 '32': 'fence',
 '33': 'traffic light',
 '34': 'bridge',
 '35': 'person',
 '36': 'bicycle',
 '37': 'vegetation',
 '38': 'truck'}