In [None]:
!pip install transformers


In [None]:
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation


# Load Mask2Former (Swin-Tiny backbone) fine-tuned on Cityscapes semantic segmentation.
# The processor and the model are loaded from one shared checkpoint id so their
# preprocessing/config settings cannot drift apart.
CHECKPOINT = "facebook/mask2former-swin-tiny-cityscapes-semantic"
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(CHECKPOINT)




In [None]:
from google.colab import drive

# Mount Google Drive so the dataset files under MyDrive/open (used below) are readable.
drive.mount(mountpoint='/content/drive')

In [None]:
cd /content/drive/MyDrive/open

In [None]:
image_path = '/content/drive/MyDrive/open/test_image/TEST_0000.png'
# Force 3-channel RGB: a PNG may be stored as RGBA, palette ("P") or grayscale,
# and downstream code (the processor call and any RGB overlay) needs a
# consistent (H, W, 3) image. `convert` returns an in-memory copy, so the
# file handle can be closed by the `with` block.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

inputs = processor(images=image, return_tensors="pt")

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outputs = model(**inputs)

# The model returns, for every query:
#   class_queries_logits: (batch_size, num_queries, num_labels + 1); the extra
#                         entry is the "no object" class
#   masks_queries_logits: (batch_size, num_queries, height, width)
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits

# Merge the per-query predictions into one per-pixel label map at the original
# resolution. PIL's `size` is (width, height) while `target_sizes` expects
# (height, width), hence the reversal.
predicted_semantic_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
# For more visualization examples, see the demo notebooks linked in the
# "Resources" section of the Mask2Former docs.


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def extract_label(mask_path, target_label):
    """Return a binary image that highlights one class of a label mask.

    Args:
        mask_path: Path to a mask image whose pixel values are class labels.
        target_label: The class label to extract.

    Returns:
        An array with the same shape and dtype as the mask: 255 where the mask
        equals ``target_label`` and 0 everywhere else.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``mask_path``.
    """
    # Load the mask exactly as stored (no color conversion) so pixel values stay labels.
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread reports failure by returning None instead of raising.
    if mask is None:
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")

    # Start from an all-black image of the mask's shape/dtype and whiten the
    # pixels belonging to the requested class.
    class_image = np.zeros_like(mask)
    class_image[mask == target_label] = 255
    return class_image

# Mask file to inspect and an example class label to extract from it.
mask_path = '/content/drive/MyDrive/open/train_source_gt/TRAIN_SOURCE_0000.png'
target_label = 2  # example: assume the class of interest has label 2

# Extract that single class. It is kept under its own name so the loop below
# does not overwrite it.
extracted_label = extract_label(mask_path, target_label)

# Show the binary mask of every label 0..NUM_MASK_LABELS-1 in one grid figure
# rather than one figure per label.
NUM_MASK_LABELS = 12
N_COLS = 4
n_rows = (NUM_MASK_LABELS + N_COLS - 1) // N_COLS  # ceiling division

# squeeze=False keeps `axes` 2-D for any grid size, so `.flat` always works.
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 3 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also hides any unused trailing panels
for label, ax in zip(range(NUM_MASK_LABELS), axes.flat):
    label_image = extract_label(mask_path, label)
    # Fix the intensity range: with autoscaling, a constant image (e.g. a label
    # covering every pixel, all 255) would be normalized to 0 and drawn black.
    ax.imshow(label_image, cmap='gray', vmin=0, vmax=255)
    ax.set_title(f'mask{label}')
fig.tight_layout()
plt.show()

In [None]:
# Map each class id to a random but reproducible RGB color. The fixed seed keeps
# the overlay colors the same on every run.
PALETTE_SEED = 0
palette_rng = np.random.default_rng(PALETTE_SEED)
# `integers` excludes the upper bound, so channel values fall in 0..255.
# `.tolist()` yields plain Python ints (one [r, g, b] list per class).
color_palette = palette_rng.integers(0, 256, size=(len(model.config.id2label), 3)).tolist()
print(color_palette)

In [None]:
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# Per-pixel class ids as a NumPy array. np.asarray also accepts the CPU torch
# tensor returned by post_process_semantic_segmentation.
seg = np.asarray(predicted_semantic_map)
palette = np.array(color_palette, dtype=np.uint8)
# The vectorized lookup below requires every class id to have a palette entry.
assert seg.max() < len(palette), "class id outside the color palette"

# Color every pixel by its class in one lookup: palette[class_id] -> RGB.
# The palette, the image and matplotlib's display are all RGB, so no BGR
# channel flip is applied; the shown colors match `color_palette`.
color_seg = palette[seg]  # (height, width, 3)

# Blend the photo and the class colors.
ALPHA = 0.5  # weight of the segmentation colors
image_rgb = np.asarray(image.convert("RGB"))  # guard against RGBA/grayscale inputs
img = (image_rgb * (1 - ALPHA) + color_seg * ALPHA).astype(np.uint8)

fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(img)
ax.set_title('Mask2Former semantic segmentation overlay')
ax.axis('off')
plt.show()
