In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gc

def show_mask(mask, ax, random_color=False, alpha=0.6):
    """Draw one binary mask as a translucent colored overlay on ``ax``.

    Args:
        mask: Array-like whose last two dimensions are (height, width); any
            leading dimensions must have size 1 so it can be reshaped to
            (h, w, 1). Truthy pixels get the color, falsy pixels stay fully
            transparent (all-zero RGBA). Presumably a NumPy bool array from
            the SAM pipeline.
        ax: Matplotlib-style axes; only its ``imshow`` method is used.
        random_color: If True, draw a random RGB color from NumPy's global
            RNG; otherwise use a fixed blue (30, 144, 255).
        alpha: Opacity of the colored pixels (default 0.6, the original value).
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30 / 255, 144 / 255, 255 / 255])
    color = np.append(rgb, alpha)
    h, w = mask.shape[-2:]
    # Broadcasting (h, w, 1) * (1, 1, 4) yields an RGBA image: the color where
    # the mask is set, zeros (transparent) elsewhere.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    # No del/gc.collect() here: dropping a local name frees nothing (the caller
    # still holds the array), and a full GC per mask only slows the loop.

def show_masks_on_image(raw_image, masks, save_path=None):
    """Overlay every mask on ``raw_image`` in a random color and display it.

    Args:
        raw_image: Image convertible with ``np.array`` (e.g. a PIL image).
        masks: Iterable of masks accepted by ``show_mask``; may be empty.
        save_path: Optional file path; when given, the figure is saved there
            before it is shown.
    """
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    # Keep the axes fitted to the base image while the overlays are drawn.
    ax.set_autoscale_on(False)
    for mask in masks:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    # Save before show(): show() may close/clear the current figure (it does
    # with the inline backend), so a savefig() afterwards writes a blank image.
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
    # No trailing `del mask`: the loop variable is unbound when `masks` is
    # empty, which made the original raise UnboundLocalError.

In [None]:
from transformers import pipeline

# Automatic mask generator backed by the SAM ViT-H checkpoint.
# device=-1 is the transformers pipeline convention for CPU; confirm this
# before switching to a GPU index.
generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-huge",
    device=-1,
)


In [None]:
# Local image test: load the sample image as RGB and preview it.
from PIL import Image

# The context manager closes the source file deterministically. convert()
# returns a new image, so raw_image stays usable after the file is closed.
with Image.open("crossfar.png") as source_image:
    raw_image = source_image.convert("RGB")

plt.imshow(raw_image)
plt.show()  # render the preview without echoing the AxesImage repr

In [None]:
# Run SAM automatic mask generation over the whole image.
# points_per_batch=64: presumably the number of prompt points sent through the
# model per forward pass (a speed vs. memory trade-off). Confirm this in the
# transformers MaskGenerationPipeline docs.
# The result is a dict; later cells read outputs["masks"] and outputs["scores"].
outputs = generator(raw_image, points_per_batch=64)

In [None]:
masks = outputs["masks"]


def show_masks_on_image(raw_image, masks, save_path=None):
    """Overlay every mask on ``raw_image`` in a random color, optionally save, then display.

    Args:
        raw_image: Image convertible with ``np.array`` (e.g. a PIL image).
        masks: Iterable of masks accepted by ``show_mask``; may be empty.
        save_path: Optional file path; when given, the figure is saved there.
    """
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    # Keep the axes fitted to the base image while the overlays are drawn.
    ax.set_autoscale_on(False)
    for mask in masks:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    # Save BEFORE show(): show() may close/clear the current figure (it does
    # with the inline backend), so saving afterwards wrote a blank PNG.
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
    # No trailing `del mask`: it raised UnboundLocalError for empty `masks`
    # and freed nothing, since `masks` still references every array.


show_masks_on_image(raw_image, masks, save_path="crossfar_mask.png")

In [None]:
# Sanity check on the SAM pipeline result: how many masks it proposed and the
# shape of its `scores` entry.
print("Number of masks: {}".format(len(outputs["masks"])))
print("Shape of scores: {}".format(outputs["scores"].shape))

In [101]:
# %% cell 7
# Classify each SAM mask with an ImageNet ResNet-50 and paint the predicted
# label onto a coarse GRID_SIZE x GRID_SIZE text map written to output.txt.
#
# NOTE(review): the classifier is fed only the binary mask silhouette (the
# mask converted to RGB), not the underlying raw_image pixels, so labels
# reflect mask shape alone. Consider applying the mask to raw_image instead.

import json
import urllib.request  # `import urllib` alone does not guarantee .request is loaded

import torch  # used below for no_grad/max; was never imported
from torchvision import models, transforms

GRID_SIZE = 20  # cells per side of the text map
CLASS_IDX_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'

# Load the pretrained model. IMAGENET1K_V1 is the weight set the deprecated
# `pretrained=True` flag selected (requires torchvision >= 0.13).
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.eval()

# Image transformations (standard ImageNet preprocessing)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Download the ImageNet labels (a JSON list indexed by class id). The context
# manager closes the HTTP response; the timeout avoids hanging forever.
with urllib.request.urlopen(CLASS_IDX_URL, timeout=30) as response:
    class_idx = json.loads(response.read())

# 2D grid of labels representing the image at GRID_SIZE x GRID_SIZE resolution
image_text = [[' '] * GRID_SIZE for _ in range(GRID_SIZE)]

# Pixel size of one grid cell. It is clamped to 1 so that images smaller than
# GRID_SIZE pixels don't cause a division by zero below.
cell_width = max(1, raw_image.size[0] // GRID_SIZE)
cell_height = max(1, raw_image.size[1] // GRID_SIZE)

for i, mask in enumerate(masks):
    mask_pil = Image.fromarray(mask)

    # Mask -> 3-channel image -> normalized (1, 3, 224, 224) batch
    batch = transform(mask_pil.convert("RGB")).unsqueeze(0)

    # Inference only: no_grad skips autograd bookkeeping for every mask.
    # Named `logits` so the SAM pipeline's `outputs` dict is not clobbered.
    with torch.no_grad():
        logits = model(batch)

    # Index of the highest-scoring class -> human-readable label
    _, predicted = torch.max(logits, 1)
    class_name = class_idx[predicted.item()]

    print(f"Mask {i}: Class {class_name}")

    # (left, upper, right, lower) of the non-zero pixels, or None when empty
    bbox = mask_pil.getbbox()
    if bbox is None:
        continue  # empty mask: nothing to paint

    # Grid cells covering the bbox: start rounded down, end rounded up
    cell_x1 = bbox[0] // cell_width
    cell_y1 = bbox[1] // cell_height
    cell_x2 = (bbox[2] + cell_width - 1) // cell_width
    cell_y2 = (bbox[3] + cell_height - 1) // cell_height

    # Later masks overwrite earlier ones where their boxes overlap
    for y in range(cell_y1, min(cell_y2, GRID_SIZE)):
        for x in range(cell_x1, min(cell_x2, GRID_SIZE)):
            image_text[y][x] = class_name

# Write the text map, one grid row per line
with open('output.txt', 'w', encoding='utf-8') as f:
    for row in image_text:
        f.write(' '.join(row))
        f.write('\n')



Mask 0: Class envelope
Mask 1: Class bulletproof vest
Mask 2: Class vulture
Mask 3: Class switch
Mask 4: Class mortar
Mask 5: Class match
Mask 6: Class T-shirt
Mask 7: Class match
Mask 8: Class switch
Mask 9: Class T-shirt
Mask 10: Class T-shirt
Mask 11: Class match
Mask 12: Class notebook computer
Mask 13: Class match
Mask 14: Class bulletproof vest
Mask 15: Class notebook computer
Mask 16: Class match
Mask 17: Class match
Mask 18: Class T-shirt
Mask 19: Class match
Mask 20: Class match
Mask 21: Class spotlight
Mask 22: Class toilet paper
Mask 23: Class match
Mask 24: Class match
Mask 25: Class quill
Mask 26: Class notebook computer
Mask 27: Class match
Mask 28: Class match
Mask 29: Class nematode
Mask 30: Class digital clock
Mask 31: Class match
Mask 32: Class nematode
Mask 33: Class analog clock
Mask 34: Class match
Mask 35: Class match
Mask 36: Class cleaver
Mask 37: Class match
Mask 38: Class match
Mask 39: Class match
Mask 40: Class match
Mask 41: Class envelope
Mask 42: Class te