In [None]:
from sklearn.metrics import accuracy_score
import os
import pandas as pd
import cv2
import numpy as np
import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor

from PIL import Image
import matplotlib.pyplot as plt
from torch import nn

from string import digits
import warnings
warnings.filterwarnings('ignore')

In [None]:
# SegFormer checkpoint to load; inference runs on the GPU when CUDA is available.
model_path = "finetuned_city_scapes_gta.pth"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SegformerForSemanticSegmentation.from_pretrained(model_path)
model = model.to(device)

In [None]:

def run_model(model_path, class_map_path, img_path, rgb_path=False, mask_path=False, mask_type="bw", save_path=False):
    """Segment one image with the module-level ``model`` and display/save the result.

    Shows the input image (next to ``mask_path`` when given), predicts a per-pixel
    class map, colours it with the palette read from ``class_map_path`` and shows a
    50/50 blend of input and prediction next to the colour-coded prediction.

    Args:
        model_path: Unused; kept for backward compatibility. Inference uses the
            module-level ``model`` and ``device`` defined in an earlier cell.
        class_map_path: CSV file with columns ``r``, ``g``, ``b``; row ``i`` is the
            colour drawn for class id ``i``.
        img_path: Path of the image to segment.
        rgb_path: Unused; kept for backward compatibility.
        mask_path: Optional path of a reference mask displayed next to the input.
        mask_type: ``"bw"`` shows the reference mask in grayscale; any other value
            shows it as stored.
        save_path: If truthy, the blended overlay is written to ``save_path`` and the
            input image to ``og_<basename>`` in the same directory.
    """
    palette = pd.read_csv(class_map_path)[['r', 'g', 'b']].values
    # Force 3 channels: RGBA or grayscale inputs would otherwise make the
    # 3-channel blend below fail with a shape mismatch.
    image = Image.open(img_path).convert('RGB')

    _show_input(image, mask_path, mask_type)
    seg = _predict_segmentation(image)
    color_seg = _colorize(seg, palette)

    overlay = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

    # Decide on saving from save_path directly; the old `save` flag was never
    # assigned when save_path was falsy and raised UnboundLocalError.
    if save_path:
        out_dir, out_name = os.path.split(save_path)
        image.save(os.path.join(out_dir, "og_" + out_name))
        Image.fromarray(overlay).save(save_path)

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    axs[0].imshow(overlay)
    axs[1].imshow(color_seg)
    for ax in axs:
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop


def _show_input(image, mask_path, mask_type):
    """Display ``image`` alone, or side by side with the mask at ``mask_path``."""
    if not mask_path:
        plt.imshow(image)
        plt.show()
        plt.close()
        return

    mask = Image.open(mask_path)
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    axs[0].imshow(image)
    if mask_type == "bw":
        # A single-channel image is colour-mapped by imshow; use gray for a "bw" view.
        axs[1].imshow(mask.convert('L'), cmap='gray')
    else:
        axs[1].imshow(mask)
    for ax in axs:
        ax.axis('off')
    plt.show()
    plt.close(fig)


def _predict_segmentation(image):
    """Return the (H, W) numpy array of argmax class ids predicted for ``image``.

    Uses the module-level ``model`` and ``device``.
    """
    feature_extractor = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
    # .to(device) only, so inference also works on CPU-only machines.
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    model.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(pixel_values=pixel_values).logits.cpu()
    # Resize logits to the input resolution; PIL's size is (W, H), interpolate wants (H, W).
    upsampled = nn.functional.interpolate(
        logits, size=image.size[::-1], mode='bilinear', align_corners=False
    )
    return upsampled.argmax(dim=1)[0].numpy()


def _colorize(seg, palette):
    """Map an (H, W) class-id array to an (H, W, 3) uint8 RGB image.

    Pixels with id ``i`` get ``palette[i]``; ids without a palette row stay black.
    """
    color_seg = np.zeros((*seg.shape, 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label] = color
    return color_seg

    

### To run the model on every image in a directory, set the directory path in the `rgb_img_dir` variable below

In [None]:
rgb_img_dir = "dataset_AnomalyTrack/images"
class_map_path = "class_dict_seg.txt"

# Sort for a reproducible processing order (os.listdir order is arbitrary).
# Skip hidden files (e.g. .DS_Store) and sub-directories, which would make
# Image.open fail inside run_model.
file_names = sorted(
    name for name in os.listdir(rgb_img_dir)
    if not name.startswith(".") and os.path.isfile(os.path.join(rgb_img_dir, name))
)
file_full_paths = [os.path.join(rgb_img_dir, name) for name in file_names]

for img_path in file_full_paths:
    run_model(model_path, class_map_path, img_path)
    