In [None]:
# Install the notebook's dependencies into the environment of the running kernel.
# The explicit %pip magic is used so the cell does not depend on IPython automagic being enabled.
%pip install -r requirements.txt


In [None]:
# Run the dataset download script and store its output under ./dataset/camvid/.
# This is a shell command, so it needs the leading "!" to be valid in a notebook cell.
!python ./utils/download_dataset.py --savedir="./dataset/camvid/"


In [None]:
class CamvidDataset(Dataset):
  """Dataset pairing image files with their segmentation-mask files.

  Args:
    root_dir: Directory containing the image/mask files and the
      'label_colors11.txt' configuration file.
    image_filenames: Image file names, relative to `root_dir`.
    masks_filenames: Mask file names, relative to `root_dir`; entry `i`
      belongs to image `i`.
    feature_extractor: Object stored as-is and used to encode each
      (image, mask) pair.
    augment: Whether samples should be augmented when they are fetched.
    num_classes: Number of segmentation classes.
  """

  def __init__(self,
               root_dir,
               image_filenames,
               masks_filenames,
               feature_extractor,
               augment=False,
               num_classes=12) -> None:

    self.root_dir = root_dir
    self.image_filenames = image_filenames
    self.masks_filenames = masks_filenames
    self.num_classes = num_classes
    self.feature_extractor = feature_extractor
    # Store the flag: sample fetching reads `self.augment`.
    self.augment = augment

    conf_file = os.path.join(root_dir,'label_colors11.txt')

    # `_dataset_conf` is defined outside this cell; presumably it parses the
    # file into per-class colours and label names — TODO confirm.
    colors, labels = self._dataset_conf(conf_file)
    # zip() stops at the shorter sequence, so at most `num_classes` ids are mapped.
    self.id2label = dict(zip(range(self.num_classes),labels))
    # Use the colours parsed above (the previous code referenced the
    # undefined name `labelColors`).
    self.class_colors = colors


In [None]:
def __len__(self):
    """Return the number of samples, i.e. the number of image filenames.

    NOTE(review): defined at cell top level but takes `self`; presumably
    meant to be a method of CamvidDataset — confirm.
    """
    sample_count = len(self.image_filenames)
    return sample_count


In [None]:
def __getitem__(self,idx):
    """Load, optionally augment and encode the sample at position `idx`.

    Args:
        idx: Index into `self.image_filenames` / `self.masks_filenames`.

    Returns:
        The object produced by `self.feature_extractor(image, mask,
        return_tensors='pt')`, with every value squeezed in place.

    Raises:
        FileNotFoundError: If the image or the mask file cannot be read.
    """
    image_path = os.path.join(self.root_dir, self.image_filenames[idx])
    mask_path = os.path.join(self.root_dir, self.masks_filenames[idx])

    # cv2.imread returns None instead of raising when a file is missing or
    # unreadable; fail here with the offending path rather than later on.
    image = cv2.imread(image_path)  # BGR image
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_path}")

    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)  # BGR image
    if mask is None:
        raise FileNotFoundError(f"Could not read mask file: {mask_path}")

    # convert the mask from bgr to grayscale
    # (`bgr2gray12` is defined outside this cell; presumably it maps each
    # class colour to its class id — TODO confirm)
    mask = self.bgr2gray12(mask, self.class_colors)

    if self.augment:
        image, mask = self._data_augmentation(image, mask)

    encod_inputs = self.feature_extractor(image, mask, return_tensors='pt')

    # Remove size-1 dimensions in place on every returned tensor
    # (presumably the batch dimension added by the extractor — TODO confirm).
    for tensor in encod_inputs.values():
        tensor.squeeze_()

    return encod_inputs


In [None]:
def _data_augmentation(self, image, mask):
    """Run one pass of the augmentation pipeline on an image and its mask.

    Args:
        image: Image passed to the pipeline as `image`.
        mask: Mask passed to the pipeline as `mask`.

    Returns:
        Tuple `(augmented_image, augmented_mask)` taken from the pipeline
        result.
    """
    # OneOf group (p=1) over the geometric transforms.
    # Constant-border padding with the ignored class (value=255) is left
    # disabled for ShiftScaleRotate, as in the original configuration.
    geometric_choice = A.OneOf(
        [
            A.Affine(p=0.33, shear=(-5, 5), rotate=(-80, 90)),
            A.ShiftScaleRotate(
                shift_limit=0.2,
                scale_limit=0.2,
                rotate_limit=120,
                p=0.33,
            ),
            A.GridDistortion(p=0.33),
        ],
        p=1,
    )

    # OneOf group (p=1) over the colour/brightness transforms.
    photometric_choice = A.OneOf(
        [
            A.ColorJitter(p=0.33),
            A.RandomBrightnessContrast(p=0.33),
            A.RandomGamma(p=0.33),
        ],
        p=1,
    )

    pipeline = A.Compose(
        [
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            geometric_choice,
            A.CLAHE(p=0.8),
            photometric_choice,
        ]
    )

    result = pipeline(image=image, mask=mask)
    return result['image'], result['mask']


In [None]:
# Returns a non batched dataset
def get_dataset(data_path='/dataset/camvid/',
                val_split=0.2,
                random_state=42,
                feature_extractor_name='nvidia/segformer-b2-finetuned-cityscapes-1024-1024'):
  """Build the (unbatched) training and validation datasets.

  Args:
    data_path: Dataset directory, passed to `get_data_filenames` and used as
      the `root_dir` of both datasets.
      NOTE(review): the default is an absolute path, whereas the download
      cell saves to the relative './dataset/camvid/' — confirm.
    val_split: Fraction of the samples held out for validation.
    random_state: Seed forwarded to `train_test_split`.
    feature_extractor_name: Name passed to
      `SegformerImageProcessor.from_pretrained`.

  Returns:
    Tuple `(train_dataset, val_dataset)`; only the training dataset is
    created with `augment=True`.
  """
  processor = SegformerImageProcessor.from_pretrained(feature_extractor_name)

  # Override the pretrained pre-processing settings, in this order.
  processor_overrides = {
      "do_reduce_labels": False,
      "do_resize": True,
      "size": {"height":360, "width":480},
      "do_normalize": False,
      "do_rescale": True,
  }
  for option, value in processor_overrides.items():
    setattr(processor, option, value)

  all_images, all_masks = get_data_filenames(data_path)

  # Images and masks are split together so each pair stays aligned.
  split = train_test_split(
      all_images, all_masks,
      test_size=val_split, random_state=random_state, shuffle=True)
  train_imgs, val_imgs, train_masks, val_masks = split

  train_dataset = CamvidDataset(data_path, train_imgs, train_masks, processor,
                                augment=True, num_classes=12)
  val_dataset = CamvidDataset(data_path, val_imgs, val_masks, processor,
                              num_classes=12)
  return train_dataset, val_dataset


In [None]:
# counts number of samples in each class
def compute_class_distribution(dataset):
  """Count how many mask entries belong to each class over a whole dataset.

  Args:
    dataset: Iterable of samples exposing `num_classes`; each sample is
      indexed with 'labels' to obtain its mask.

  Returns:
    List of `dataset.num_classes` integers; entry `c` is the number of mask
    entries equal to `c`. Labels outside `[0, num_classes)` are not counted.
  """
  num_classes = dataset.num_classes
  summary = [0] * num_classes

  for inputs in dataset:
    mask = np.asarray(inputs['labels'])
    labels, counts = np.unique(mask, return_counts=True)
    for label, count in zip(labels.tolist(), counts.tolist()):
      label = int(label)
      # Skip out-of-range labels (e.g. an ignore label such as 255): they
      # would raise IndexError, and negative ones would silently wrap
      # around to the wrong bin.
      if 0 <= label < num_classes:
        summary[label] += count
  return summary

# computes the weight of each class
def compute_class_weights(total, class_counts):
  """Compute one weight per class as `total / class_count`.

  Args:
    total: Numerator shared by every class.
    class_counts: Iterable with the number of samples of each class.

  Returns:
    List of weights in the same order as `class_counts`. A class whose
    count is zero gets weight 0.0 instead of triggering a division by zero.
  """
  return [total / count if count else 0.0 for count in class_counts]


In [None]:
# Returns batched dataset
def get_dataloader(dataset,
                   train_batch_size=10,
                   val_batch_size=7,
                   num_workers=2,
                   prefetch_factor=5):
  """Wrap a (train, validation) dataset pair into DataLoaders.

  Args:
    dataset: Indexable whose item 0 is the training dataset and item 1 the
      validation dataset.
    train_batch_size: Batch size of the training loader.
    val_batch_size: Batch size of the validation loader.
    num_workers: Number of worker processes used by both loaders.
    prefetch_factor: Forwarded to both loaders when `num_workers > 0`;
      ignored otherwise.

  Returns:
    Tuple `(train_dataloader, val_dataloader)`; only the training loader
    shuffles.
  """
  train_dataset, val_dataset = dataset[0], dataset[1]

  # DataLoader raises ValueError if prefetch_factor is given without worker
  # processes, so only forward it when workers are actually used.
  worker_kwargs = {"num_workers": num_workers}
  if num_workers > 0:
    worker_kwargs["prefetch_factor"] = prefetch_factor

  train_dataloader = DataLoader(train_dataset,
                                batch_size=train_batch_size,
                                shuffle=True,
                                **worker_kwargs)

  val_dataloader = DataLoader(val_dataset,
                              batch_size=val_batch_size,
                              **worker_kwargs)

  return train_dataloader, val_dataloader


In [None]:
class SegFormerFineTuned(pl.LightningModule):
  """LightningModule that fine-tunes a SegFormer semantic-segmentation model.

  Training and validation use a class-weighted cross-entropy loss that
  ignores label 255, and track mean IoU / mean accuracy with the
  `evaluate` "mean_iou" metric.

  Args:
    id2label: Mapping from class id to class name; its size sets the number
      of output classes.
    train_dl: DataLoader returned by `train_dataloader()`.
    val_dl: DataLoader returned by `val_dataloader()`.
    metrics_interval: Training metrics are computed and logged on every
      batch whose index is a multiple of this value.
    class_weights: Per-class weights passed to `CrossEntropyLoss`
      (presumably a tensor already on the model's device — TODO confirm).
    model_path: Name or path of the pretrained SegFormer checkpoint.
  """

  def __init__(self, id2label,
               train_dl,
               val_dl,
               metrics_interval,
               class_weights,
               model_path="nvidia/segformer-b2-finetuned-cityscapes-1024-1024"):

    super(SegFormerFineTuned, self).__init__()
    self.id2label = id2label
    self.metrics_interval = metrics_interval
    self.train_dl = train_dl
    self.val_dl = val_dl
    self.weights = class_weights
    self.model_path = model_path

    self.num_classes = len(id2label.keys())
    self.label2id = {v:k for k,v in self.id2label.items()}

    # return_dict=False makes the model return a tuple; the steps below read
    # its first element as the predictions.
    self.model = SegformerForSemanticSegmentation.from_pretrained(
        self.model_path,
        return_dict=False,
        num_labels=self.num_classes,
        id2label=self.id2label,
        label2id=self.label2id,
        ignore_mismatched_sizes=True,
    )

    self.train_mean_iou = evaluate.load("mean_iou")
    self.val_mean_iou = evaluate.load("mean_iou")
    self.test_mean_iou = evaluate.load("mean_iou")

    # Save the hyper-parameters
    # with the checkpoints
    self.save_hyperparameters()

  def forward(self, images, masks):
    """Run the wrapped model on `images`; `masks` is accepted but unused."""
    outputs = self.model(pixel_values=images)
    return (outputs)

  def _predict_and_score(self, batch):
    """Forward a batch and return `(loss, predicted_class_ids, masks)`."""
    images, masks = batch['pixel_values'], batch['labels']

    logits = self(images, masks)[0]

    # Upsample the predictions to the spatial size of the masks.
    # `align_corners` must not be passed with "nearest-exact":
    # torch only accepts it for the linear-family interpolation modes and
    # raises ValueError otherwise.
    logits = torch.nn.functional.interpolate(
        logits,
        size=masks.shape[-2:],
        mode="nearest-exact",
    )

    weighted_loss = CrossEntropyLoss(weight=self.weights, ignore_index=255)
    loss = weighted_loss(logits, masks)

    return loss, logits.argmax(dim=1), masks

  def training_step(self, batch, num_batch):
    """Compute the training loss and periodically log training metrics.

    Returns:
      `{'loss': loss}`, extended with "mean_iou" and "mean_accuracy" on
      batches whose index is a multiple of `metrics_interval`.
    """
    loss, predictions, masks = self._predict_and_score(batch)

    # Accumulate every batch; compute() below consumes what was accumulated.
    self.train_mean_iou.add_batch(
        predictions=predictions.detach().cpu().numpy(),
        references=masks.detach().cpu().numpy()
    )

    if num_batch % self.metrics_interval != 0:
      return({'loss': loss})

    metrics = self.train_mean_iou.compute(
        num_labels=self.num_classes,
        ignore_index=255,
        reduce_labels=False,
    )

    metrics = {'loss': loss, "mean_iou": metrics["mean_iou"], "mean_accuracy": metrics["mean_accuracy"]}

    for k,v in metrics.items():
      self.log(k,v)

    return(metrics)

  def validation_step(self, batch, num_batch):
    """Compute the validation loss and accumulate predictions for the epoch metrics."""
    loss, predictions, masks = self._predict_and_score(batch)

    self.val_mean_iou.add_batch(
        predictions=predictions.detach().cpu().numpy(),
        references=masks.detach().cpu().numpy()
    )

    return({'val_loss': loss})

  def validation_epoch_end(self,outputs):
    """Aggregate and log the validation loss, mean IoU and mean accuracy.

    NOTE(review): this hook name and its `outputs` argument belong to the
    pre-2.0 PyTorch Lightning API — confirm the pinned version.
    """
    metrics = self.val_mean_iou.compute(
        num_labels=self.num_classes,
        ignore_index=255,
        reduce_labels=False,
    )

    avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    val_mean_iou = metrics["mean_iou"]
    val_mean_accuracy = metrics["mean_accuracy"]

    metrics = {"val_loss": avg_val_loss, "val_mean_iou":val_mean_iou, "val_mean_accuracy":val_mean_accuracy}
    for k,v in metrics.items():
      self.log(k,v)

    return metrics

  def configure_optimizers(self):
    """Return an Adam optimizer over the parameters that require gradients."""
    return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

  def train_dataloader(self):
    """Return the training DataLoader given at construction."""
    return self.train_dl

  def val_dataloader(self):
    """Return the validation DataLoader given at construction."""
    return self.val_dl



In [None]:
def train_model(train_dataloader,
                val_dataloader,
                class_weights,
                id2label,
                hf_model_name="nvidia/segformer-b2-finetuned-cityscapes-1024-1024",
                ckpt_path='/checkpoints/',
                accelerator_mode='gpu',
                devices=1,
                max_epochs=300,
                log_every_n_steps=8,
                last_ckpt_path=None,
                resume=False
            ):
    """Create a SegFormerFineTuned model and fit it with a Lightning Trainer.

    Args:
        train_dataloader: Training DataLoader handed to the model.
        val_dataloader: Validation DataLoader handed to the model.
        class_weights: Per-class loss weights; converted to a tensor and
            moved to CUDA when `accelerator_mode` is "gpu".
        id2label: Mapping from class id to class name.
        hf_model_name: Pretrained checkpoint name passed as `model_path`.
        ckpt_path: Trainer `default_root_dir`.
        accelerator_mode: Trainer accelerator.
        devices: Trainer devices.
        max_epochs: Upper bound on the number of training epochs.
        log_every_n_steps: Trainer logging interval, also used as the
            model's `metrics_interval`.
        last_ckpt_path: Checkpoint to resume from.
        resume: Resume from `last_ckpt_path` when both are truthy.

    Returns:
        Tuple `(trainer, model)`.
    """
    weights = torch.Tensor(class_weights)
    if accelerator_mode == "gpu":
        weights = weights.cuda()

    model = SegFormerFineTuned(
        id2label,
        train_dl=train_dataloader,
        val_dl=val_dataloader,
        metrics_interval=log_every_n_steps,
        class_weights=weights,
        model_path=hf_model_name
    )

    # Stop once "val_loss" has not improved for 3 validation checks.
    stop_on_plateau = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode="min",
    )
    # Keep only the checkpoint with the best "val_loss".
    keep_best = ModelCheckpoint(save_top_k=1, monitor="val_loss")

    # A sample-visualisation callback (VisualizeSampleCallback) is left disabled.

    trainer = pl.Trainer(
        default_root_dir=ckpt_path,
        accelerator=accelerator_mode,
        devices=devices,
        callbacks=[stop_on_plateau, keep_best],
        max_epochs=max_epochs,
        log_every_n_steps=log_every_n_steps,
        # validate after as many training batches as the loader yields
        val_check_interval=len(train_dataloader),
    )

    # Pass a checkpoint to fit() only when resuming was requested and a path exists.
    fit_kwargs = {"ckpt_path": last_ckpt_path} if (resume and last_ckpt_path) else {}
    trainer.fit(model, **fit_kwargs)

    return trainer, model


In [None]:
# Run validation with the checkpoint selected by ckpt_path="best" and keep the returned results.
# NOTE(review): `trainer` is not created in any cell shown here — presumably it comes from a
# train_model(...) call; confirm that cell runs first.
res = trainer.validate(ckpt_path="best")
