In [1]:
!pip install -q transformers datasets evaluate

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import os
import json
from PIL import Image

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from huggingface_hub import hf_hub_download, snapshot_download

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MODEL_SAVE_DIR = "../models/segmentation_model"
repo_id = "georgiisirotenko/circle-segmentation-finetune"
repo_dir = snapshot_download(repo_id, repo_type="dataset")

Fetching 217 files: 100%|██████████| 217/217 [00:00<00:00, 22267.55it/s]


In [4]:
class SemanticSegmentationDataset(Dataset):
    def __init__(self, root_dir, image_processor):
        self.root_dir = root_dir
        self.image_processor = image_processor
        self.masks_paths = os.listdir(os.path.join(root_dir, "masks"))

    def __len__(self):
        return len(self.masks_paths)

    def _get_image_name(self, idx):
      return self.masks_paths[idx][:-9] + ".png"

    def __getitem__(self, idx):

        image_name = self._get_image_name(idx)
        image = Image.open(os.path.join(self.root_dir, "images", image_name))

        segmentation_map = Image.open(os.path.join(self.root_dir, "masks", self.masks_paths[idx]))
        segmentation_map = Image.fromarray((np.array(segmentation_map) > 0.0)[..., 0].astype(np.uint8))

        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_()

        return encoded_inputs

In [5]:
image_processor = SegformerImageProcessor(reduce_labels=True)
train_dataset = SemanticSegmentationDataset(root_dir=repo_dir, image_processor=image_processor)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)



In [6]:
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", num_labels=2)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b5 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
epochs_num = 2
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(epochs_num):
   print("Epoch:", epoch)
   for idx, batch in enumerate(train_dataloader):

        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        with torch.no_grad():
          upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
          predicted = upsampled_logits.argmax(dim=1)

        if idx % 100 == 0:
          print("Loss:", loss.item())

Epoch: 0
Loss: 1.0449810028076172
Epoch: 1
Loss: 0.00045327944098971784


In [8]:
model.save_pretrained(MODEL_SAVE_DIR)