<a href="https://colab.research.google.com/github/seismosmsr/machine_learning/blob/main/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to fine-tune `SegFormerForSemanticSegmentation` on a custom **semantic segmentation** dataset, namely [RUGD](http://rugd.vision/). In semantic segmentation, the goal for the model is to label each pixel of an image with one of a list of predefined classes.

We load the encoder of the model with weights pre-trained on ImageNet-1k, and fine-tune it together with the decoder head, which starts with randomly initialized weights.

In [1]:
!pip install -q transformers datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m96.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m53.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m102.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

## Define PyTorch dataset and dataloaders

Here we define a [custom PyTorch dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html). Each item of the dataset consists of an image and a corresponding segmentation map.

In [3]:
!pip install gdown
import gdown
import zipfile
import os

# I switched to pngs and jpgs to try and use tensorflows native vectorization
#todo: get gdal working so you can just use geotiff
url = 'https://drive.google.com/uc?id=1SfjV4rwnK49hCf-zJBktcLdGd_1otKu2'

# https://drive.google.com/file/d/1ZOKNZIn1_jXYiC2dTvE_sdL4GzjLDwaD/view?usp=sharing
#https://drive.google.com/file/d/1SfjV4rwnK49hCf-zJBktcLdGd_1otKu2/view?usp=drive_link
output = 'colorado_land_use_png_jpg.zip'

gdown.download(url,output,quiet = False)


cwd = os.getcwd()
with zipfile.ZipFile(cwd+'/colorado_land_use_png_jpg.zip', 'r') as zip_ref:
    zip_ref.extractall(cwd+'/sample_data')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Downloading...
From: https://drive.google.com/uc?id=1SfjV4rwnK49hCf-zJBktcLdGd_1otKu2
To: /content/colorado_land_use_png_jpg.zip
100%|██████████| 32.7M/32.7M [00:00<00:00, 340MB/s]


In [7]:
import pandas as pd

# read in RUGD's color palette
color_map = pd.read_csv('/content/COLUCD_colormap.csv')
color_map.columns = ["label_idx", "label", "R", "G", "B"]
color_map.head()
     


Unnamed: 0,label_idx,label,R,G,B
0,1,Structures,0,0,0
1,2,Surfaces,108,64,20
2,3,Water,255,229,204
3,4,Grass,0,102,0
4,5,Scrub,0,255,0


In [29]:
# validation_dataset = ValidationDataset(root_dir=validation_root_dir, feature_extractor=feature_extractor)

In [80]:
import os
from PIL import Image
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from datasets import load_metric
import torch
from torch import nn
from tqdm.notebook import tqdm

class SemanticSegmentationDataset(Dataset):
    def __init__(self, root_dir, feature_extractor, train=True):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.img_dir = os.path.join(self.root_dir, "rgbNIR")
        self.ann_dir = os.path.join(self.root_dir, "labels")
        image_file_names = sorted([fname for fname in os.listdir(self.img_dir) if fname.lower().endswith(('.jpg', '.jpeg'))])
        annotation_file_names = sorted([fname for fname in os.listdir(self.ann_dir) if fname.lower().endswith('.png')])
        self.images = [os.path.join(self.img_dir, fname) for fname in image_file_names]
        self.annotations = [os.path.join(self.ann_dir, fname) for fname in annotation_file_names]
        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        annotation = Image.open(self.annotations[idx])
        annotation = np.array(annotation, dtype=np.uint8)
        annotation_2d = np.zeros((annotation.shape[0], annotation.shape[1]), dtype=np.uint8)
        # Convert annotation to 2D segmentation map
        if isinstance(annotation, bool):
            annotation_2d = np.zeros_like(image, dtype=np.uint8)
        else:
            annotation_2d = np.array(annotation)
        encoded_inputs = self.feature_extractor(image, Image.fromarray(annotation_2d), return_tensors="pt")
        for k, v in encoded_inputs.items():
            encoded_inputs[k].squeeze_()
        return encoded_inputs

class ValidationDataset(Dataset):
    def __init__(self, root_dir, feature_extractor):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.img_dir = os.path.join(self.root_dir, "rgbNIR")
        self.ann_dir = os.path.join(self.root_dir, "labels")
        image_file_names = sorted([fname for fname in os.listdir(self.img_dir) if fname.lower().endswith(('.jpg', '.jpeg'))])
        annotation_file_names = sorted([fname for fname in os.listdir(self.ann_dir) if fname.lower().endswith('.png')])
        self.images = [os.path.join(self.img_dir, fname) for fname in image_file_names]
        self.annotations = [os.path.join(self.ann_dir, fname) for fname in annotation_file_names]
        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        annotation = Image.open(self.annotations[idx])
        annotation = np.array(annotation, dtype=np.uint8)
        annotation_2d = np.zeros((annotation.shape[0], annotation.shape[1]), dtype=np.uint8)
        # Convert annotation to 2D segmentation map
        if isinstance(annotation, bool):
            annotation_2d = np.zeros_like(image, dtype=np.uint8)
        else:
            annotation_2d = np.array(annotation)
        encoded_inputs = self.feature_extractor(image, Image.fromarray(annotation_2d), return_tensors="pt")
        for k, v in encoded_inputs.items():
            encoded_inputs[k].squeeze_()
        return encoded_inputs

root_dir = '/content/sample_data/colorado_land_use_png_jpg_simplified/training'
validation_root_dir = '/content/sample_data/colorado_land_use_png_jpg_simplified/validation'

feature_extractor = SegformerFeatureExtractor(reduce_labels=True)
train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor)
validation_dataset = ValidationDataset(root_dir=validation_root_dir, feature_extractor=feature_extractor)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=False)

label2id = {label: id for id, label in zip(color_map.label_idx, color_map.label)}
id2label = {id: label for id, label in zip(color_map.label_idx, color_map.label)}
id2color = {id: [R, G, B] for id, (R, G, B) in zip(color_map.label_idx, zip(color_map.R, color_map.G, color_map.B))}

model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=9, id2label=id2label, label2id=label2id)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

metric = load_metric("accuracy")

for epoch in range(10):
    for idx, batch in enumerate(train_dataloader):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)
            predicted_np = predicted.detach().cpu().numpy()
            references_np = labels.detach().cpu().numpy()

            predicted_list = [float(np.mean(predicted_np))]  # Wrap in a list
            references_list = [float(np.mean(references_np))]  # Wrap in a list

            metric.add_batch(predictions=predicted_list, references=references_list)

    # Compute and print accuracy on the training dataset
    metrics = metric.compute()
    print("Epoch:", epoch)
    print("Training Loss:", loss.item())
    print("Training Accuracy:", metrics["accuracy"])

    # # Evaluate on the validation dataset
    # model.eval()  # Set the model to evaluation mode
    # with torch.no_grad():
    #     for idx, batch in enumerate(validation_dataloader):
    #         pixel_values = batch["pixel_values"].to(device)
    #         labels = batch["labels"].to(device)

    #         # Compute predictions
    #         outputs = model(pixel_values=pixel_values)
    #         predicted = outputs.logits.argmax(dim=1)
    #         predicted_np = predicted.detach().cpu().numpy()
    #         references_np = labels.detach().cpu().numpy()

    #         predicted_list = [float(np.mean(predicted_np))]  # Wrap in a list
    #         references_list = [float(np.mean(references_np))]  # Wrap in a list

    #         metric.add_batch(predictions=predicted_list, references=references_list)

    #     # Compute and print accuracy on the validation dataset
    #     validation_metrics = metric.compute()
    #     print("Validation Accuracy:", validation_metrics["accuracy"])

    model.train()  # Set the model back to training mode


Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.classifier.weight', 'decode_head.linear_c.2.proj.weight', 'decode_head.classifier.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_fuse.

Epoch: 0
Training Loss: 0.6174104809761047
Training Accuracy: 0.75390625
Epoch: 1
Training Loss: 0.4580656588077545
Training Accuracy: 0.91015625
Epoch: 2
Training Loss: 0.49875926971435547
Training Accuracy: 0.92578125
Epoch: 3
Training Loss: 0.535179853439331
Training Accuracy: 0.953125
Epoch: 4
Training Loss: 0.5624239444732666
Training Accuracy: 0.94140625
Epoch: 5
Training Loss: 0.3005901277065277
Training Accuracy: 0.92578125
Epoch: 6
Training Loss: 0.29305967688560486
Training Accuracy: 0.93359375
Epoch: 7
Training Loss: 0.1579374074935913
Training Accuracy: 0.9453125
Epoch: 8
Training Loss: 0.3870920240879059
Training Accuracy: 0.9453125


KeyboardInterrupt: ignored

In [None]:
import os
from PIL import Image
import numpy as np
from torchvision.transforms import ToTensor
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch

test_root_dir = '/content/sample_data/colorado_land_use_png_jpg_simplified/training/rgbNIR'
output_dir = '/content/predicted'
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)
# model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
model.eval()

image_file_names = sorted([fname for fname in os.listdir(test_root_dir) if fname.lower().endswith(('.jpg', '.jpeg'))])

for image_file in image_file_names:
    image_path = os.path.join(test_root_dir, image_file)
    image = Image.open(image_path)

    # Convert image to tensor
    image_tensor = ToTensor()(image)

    # Apply feature extraction on image tensor
    encoded_inputs = feature_extractor(images=image_tensor.unsqueeze(0))

    # Convert encoded inputs to PyTorch tensors and move them to the appropriate device
    encoded_inputs = {key: torch.tensor(value).to(device) for key, value in encoded_inputs.items()}

    # Compute predictions
    with torch.no_grad():
        outputs = model(**encoded_inputs)
        predicted = outputs.logits.argmax(dim=1)

    # Convert predicted tensor to numpy array
    predicted_np = predicted.detach().cpu().numpy()

    # Save predicted segmentation maps as JPEG images
    img_name = os.path.splitext(os.path.basename(image_file))[0]
    output_path = os.path.join(output_dir, img_name + '.jpg')
    pred_img = Image.fromarray(predicted_np[0].astype(np.uint8))
    pred_img.save(output_path)




In [56]:
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
# move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(200):  # loop over the dataset multiple times
   print("Epoch:", epoch)
   for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        
        loss.backward()
        optimizer.step()

        # evaluate
        with torch.no_grad():
          upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
          predicted = upsampled_logits.argmax(dim=1)
          
          # note that the metric expects predictions + labels as numpy arrays
          metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

        # let's print loss and metrics every 100 batches
        if idx % 100 == 0:
          # we use _compute for now which fixes an issue in speed
          # see this Github thread for more info: 
          metrics = metric._compute(num_labels=len(id2label), 
                                   ignore_index=255,
                                   reduce_labels=False, # we've already reduced the labels before)
          )

          print("Loss:", loss.item())
          print("Mean_iou:", metrics["mean_iou"])
          print("Mean accuracy:", metrics["mean_accuracy"])

Epoch: 0


  0%|          | 0/256 [00:00<?, ?it/s]

TypeError: ignored

## Inference

Finally, let's check whether the model has really learned something. Let's test the trained model on an image:

In [None]:
image = Image.open('/content/drive/MyDrive/SegFormer/Notebooks/Tutorial notebooks/RUGD/RUGD_sample-data/images/creek_00001.png')
image

In [None]:
# prepare the image for the model
encoding = feature_extractor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)
print(pixel_values.shape)

In [None]:
# forward pass
outputs = model(pixel_values=pixel_values)

In [None]:
# logits are of shape (batch_size, num_labels, height/4, width/4)
logits = outputs.logits.cpu()
print(logits.shape)

In [None]:
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# First, rescale logits to original image size
upsampled_logits = nn.functional.interpolate(logits,
                size=image.size[::-1], # (height, width)
                mode='bilinear',
                align_corners=False)

# Second, apply argmax on the class dimension
seg = upsampled_logits.argmax(dim=1)[0]
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
for label, color in id2color.items():
    color_seg[seg == label, :] = color

# Show image + mask
img = np.array(image) * 0.5 + color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

Let's print the pure predicted segmentation map:

In [None]:
Image.fromarray(color_seg)

Compare this to the ground truth segmentation map:

In [None]:
map = Image.open('/content/drive/MyDrive/SegFormer/Notebooks/Tutorial notebooks/RUGD/RUGD_sample-data/annotations/creek_00001.png') 
map 