In [None]:
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
import torch.nn as nn


In [None]:
####  DATASET IMPORT GOES HERE  ####

import copy
import os

from datasets import Dataset, DatasetDict, Image


def _list_sorted_paths(directory):
    """Return the sorted paths of every entry found in `directory`.

    `os.path.join` is used instead of string concatenation so the result is
    correct whether or not `directory` ends with a path separator.
    """
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))


def _split_train_validation(paths, train_fraction=0.8):
    """Split `paths` into contiguous (train, validation) lists.

    A single split index is used for both halves. Slicing the tail with
    `paths[-k:]` is avoided on purpose: when `k` is 0 it returns the whole
    list, and when the two truncated sizes do not add up to `len(paths)` it
    silently drops samples.
    """
    split_index = int(len(paths) * train_fraction)
    return paths[:split_index], paths[split_index:]


dir_path = "" # filepath here
#get file names
im_dir = _list_sorted_paths(dir_path)

#split by 80% to 20%
image_paths_train, image_paths_validation = _split_train_validation(im_dir)


#now for the labels
dir_path = "" #put path here
#get file names
lab_dir = _list_sorted_paths(dir_path)

# Images and labels are paired purely by their sorted position, so a count
# mismatch would pair images with the wrong masks.
if len(im_dir) != len(lab_dir):
    raise ValueError(
        f"Found {len(im_dir)} images but {len(lab_dir)} labels; "
        "every image needs exactly one label file."
    )

#split the same as the images
label_paths_train, label_paths_validation = _split_train_validation(lab_dir)
    
#create dataset from the two filepaths
def create_dataset(image_paths, label_paths):
    """Build a `Dataset` with an "image" and a "label" column.

    Both path collections are sorted independently before being stored, and
    each column is then cast to the `Image` feature.
    """
    columns = {
        "image": sorted(image_paths),
        "label": sorted(label_paths),
    }
    paired = Dataset.from_dict(columns)
    for column_name in ("image", "label"):
        paired = paired.cast_column(column_name, Image())
    return paired

# One Dataset per split, each built from the matching image/label path lists.
train_dataset = create_dataset(
    image_paths=image_paths_train,
    label_paths=label_paths_train,
)
validation_dataset = create_dataset(
    image_paths=image_paths_validation,
    label_paths=label_paths_validation,
)

# Bundle both splits so they can be looked up by split name.
ds = DatasetDict({"train": train_dataset, "validation": validation_dataset})



In [None]:


import json

# Class ids used in the masks: 0 = background, 1 = 1st mask, 2 = 2nd mask.
id2label = dict(enumerate(['BG', 'label1', 'label2']))

# Persist the id -> name mapping next to the notebook.
with open('id2label.json', 'w') as fp:
    json.dump(id2label, fp)

# Handles to the two splits of the DatasetDict.
train_ds, valid_ds = ds["train"], ds["validation"]

# Sanity check: read one validation example. The value is not displayed
# because this is not the last expression of the cell.
valid_ds[0]

# Colour jitter used as training-time augmentation.
jitter = transforms.ColorJitter(contrast=0.5, saturation=0.25)

# Make sure the ids are ints, then derive the reverse (name -> id) lookup.
id2label = {int(class_id): class_name for class_id, class_name in id2label.items()}
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
num_labels = len(id2label)

print(id2label)
print(label2id)


#PREPROCESSING

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained checkpoint shared by the image processor and the model.
checkpoint = "nvidia/mit-b0"

# reduce_labels=True is intentionally not passed so the background class is kept.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


def train_transforms(example_batch):
    """Prepare a training batch for the model.

    Every entry of `example_batch["image"]` is passed through `jitter`, then
    the jittered images and the entries of `example_batch["label"]` are handed
    to `image_processor`, whose result is returned unchanged.
    """
    augmented = list(map(jitter, example_batch["image"]))
    segmentation_maps = list(example_batch["label"])
    return image_processor(augmented, segmentation_maps)


def val_transforms(example_batch):
    """Prepare a validation batch for the model (no augmentation).

    The entries of `example_batch["image"]` and `example_batch["label"]` are
    copied into lists and handed to `image_processor`, whose result is
    returned unchanged.
    """
    raw_images = list(example_batch["image"])
    segmentation_maps = list(example_batch["label"])
    return image_processor(raw_images, segmentation_maps)

# Attach the preprocessing functions to each split. set_transform() is called
# for its side effect (its return value is not used), and train_ds / valid_ds
# are the very objects stored in ds["train"] / ds["validation"], so rows read
# through `ds` are affected as well.
# NOTE(review): presumably the transform runs lazily when rows are accessed;
# confirm against the datasets documentation.
train_ds.set_transform(train_transforms)
valid_ds.set_transform(val_transforms)

import evaluate
# Mean intersection-over-union metric, used by compute_metrics during evaluation.
metric = evaluate.load("mean_iou")

In [None]:
#finding the metrics for the model
def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for one evaluation pass.

    Args:
        eval_pred: a `(logits, labels)` pair of NumPy arrays. The logits are
            resized bilinearly to the last two dimensions of `labels` and
            reduced to class ids with an argmax over dimension 1, so they are
            expected to be laid out as (batch, classes, height, width).

    Returns:
        A dict with the values reported by `metric.compute`, where every
        NumPy array has been converted to a plain Python list.
    """
    logits, labels = eval_pred

    # No gradients are needed while scoring predictions.
    with torch.no_grad():
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).detach().cpu().numpy()

    metrics = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=num_labels,
        ignore_index=255,
        reduce_labels=False,
    )

    # Per-class results come back as arrays; turn them into lists so the
    # returned dict only holds plain Python values.
    return {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in metrics.items()
    }

In [None]:
from transformers import TrainingArguments, Trainer,  AutoModelForSemanticSegmentation
# Semantic-segmentation model initialised from the pretrained checkpoint, with
# the label maps attached to its configuration.
model = AutoModelForSemanticSegmentation.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
)
# Move the weights to the selected device.
model = model.to(device)


In [None]:
# Hyperparameters handed to the Trainer.
# NOTE(review): output_dir is an absolute, machine-specific path; point it at a
# location that exists on the machine running this notebook.
training_args = TrainingArguments(
    output_dir="/Users/zachderse/Documents",
    # optimisation
    learning_rate=6e-5,
    num_train_epochs=7, #was at 50
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # evaluate and checkpoint on the same 20-step cadence
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    eval_accumulation_steps=5,
    # logging / misc
    logging_steps=1,
    remove_unused_columns=False,
    use_cpu=False,
)

# Trainer wiring: model, arguments, both splits and the metric callback.
trainer = Trainer(
    compute_metrics=compute_metrics,
    eval_dataset=valid_ds,
    train_dataset=train_ds,
    args=training_args,
    model=model,
)

trainer.train()

In [None]:
from datasets import load_dataset

# checking the model visually through the validation set
# ds["validation"] is the same object as valid_ds, which had val_transforms
# attached with set_transform(). Rows read through it only contain the image
# processor outputs, so indexing ["image"] would raise a KeyError.
# with_format(None) gives a copy with default formatting, which exposes the
# decoded image again without touching the dataset used by the Trainer.
raw_validation = ds["validation"].with_format(None)

#open an example image from the validation set
image = raw_validation[0]["image"]

#using GPU or CPU depending on availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#encode the image
encoding = image_processor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)

#print the image
image

In [None]:
#use the trained model to generate a mask
# Training can leave the model in train mode, where training-only layers
# behave differently. Switch to eval mode so the prediction is deterministic,
# and skip gradient tracking since nothing is back-propagated here.
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()
print(outputs[0].shape)

# Resize the logits to the input image. The size tuple is reversed because the
# image reports (width, height) while interpolate expects (height, width).
upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

# Class id per pixel for the first (only) image in the batch.
pred_seg = upsampled_logits.argmax(dim=1)[0]
print(pred_seg.shape)

# make each label a separate color, and leave the background as no color
def ade_palette():
    """Return the colour table used to paint the predicted mask.

    Row `i` holds the three channel values for class id `i`; the background
    (row 0) is all zeros.
    """
    background = [0, 0, 0]
    first_label = [0, 250, 0]
    second_label = [250, 0, 250]
    return np.asarray([background, first_label, second_label])

In [None]:
### map ###

import matplotlib.pyplot as plt
import numpy as np

# Work on a plain NumPy copy of the predicted class ids so the boolean masks
# below are NumPy arrays.
pred_mask = np.asarray(pred_seg)

# Paint every class id with its palette colour.
color_seg = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
palette = np.array(ade_palette())
for label, color in enumerate(palette):
    color_seg[pred_mask == label, :] = color
# The mask is kept in RGB order: it is blended with the RGB input image and
# shown with matplotlib, which interprets 3-channel arrays as RGB. Flipping the
# channels to BGR here would swap red and blue for any non-symmetric colour.

img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

In [None]:
import cv2

# Clean up the colour mask with cv2: an opening pass (10x10 rectangle, twice)
# to drop very small labelled regions, followed by a closing pass (5x5
# rectangle, four times).
opened_seg = cv2.morphologyEx(
    color_seg,
    cv2.MORPH_OPEN,
    cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10)),
    iterations=2,
)

kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
color_seg_alt = cv2.morphologyEx(opened_seg, cv2.MORPH_CLOSE, kernel, iterations=4)

# Blend the cleaned mask with the input image at equal weight.
overlay = 0.5 * np.array(image) + 0.5 * color_seg_alt
img = overlay.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()