In [1]:
import torch
import os
from MaskRCNN import *
from MaskRCNN import maskrcnn_FF_fpn_v2_own_backbone, MaskRCNNPredictor
from FastRCNN import  FastRCNNPredictor
import torchvision.transforms as T
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import cv2
from collections import defaultdict
from shapely.geometry import Polygon, MultiPolygon
from shapely.ops import unary_union
from shapely.validation import explain_validity
import csv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
num_classes = 2
name_model = "./final_mask.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# eval dataset
def resize_and_to_tensor():
    return T.Compose([
        T.ToTensor()
    ])
class evalImageDataset(Dataset):
    def __init__(self, img_folder, transform=None):
        self.img_folder = img_folder
        self.transform = transform
        self.image_files = sorted([f for f in os.listdir(img_folder) if os.path.isfile(os.path.join(img_folder, f))])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_folder, self.image_files[idx])
        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, self.image_files[idx]

In [4]:
eval_data = evalImageDataset(img_folder = "./test/image",transform=resize_and_to_tensor())
eval_data_loader = DataLoader(eval_data, batch_size=1, shuffle=False, num_workers=0)

In [5]:
# Define the model and load the weights
model = maskrcnn_FF_fpn_v2_own_backbone(backbone_type = "convnext", version="convnext_base.fb_in22k_ft_in1k_384",num_classes=num_classes,min_size=[300,350,400,450,500,550,600],max_size=700,
                                              image_mean = [0.4807, 0.4841, 0.4823],
                                              image_std = [0.2165, 0.2045, 0.2040],
                                              rpn_pre_nms_top_n_train = 10000,
                                              rpn_pre_nms_top_n_test = 10000,
                                              rpn_post_nms_top_n_train = 7500,
                                              rpn_post_nms_top_n_test = 7500,
                                              rpn_nms_thresh =0.6,
                                              rpn_fg_iou_thresh = 0.6,
                                              rpn_bg_iou_thresh = 0.2,
                                              rpn_batch_size_per_image =3000,
                                              rpn_positive_fraction = 0.7,
                                              rpn_score_thresh = 0.0,
                                              box_score_thresh = 0.03,
                                              box_nms_thresh = 0.5,
                                              box_detections_per_img = 700,
                                              box_fg_iou_thresh = 0.6,
                                              box_bg_iou_thresh = 0.5,
                                              box_batch_size_per_image =3000,
                                              box_positive_fraction = 0.5,
                                              )


in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256

model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
model.load_state_dict(torch.load(name_model, map_location=device))
model.to(device)


  model.load_state_dict(torch.load(name_model, map_location=device))


MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.4807, 0.4841, 0.4823], std=[0.2165, 0.2045, 0.204])
      Resize(min_size=[300, 350, 400, 450, 500, 550, 600], max_size=700, mode='bilinear')
  )
  (backbone): Backbone(
    (own_backbone): ConvNeXtBackbone(
      (convnext): FeatureListNet(
        (stem_0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (stem_1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
        (stages_0): ConvNeXtStage(
          (downsample): Identity()
          (blocks): Sequential(
            (0): ConvNeXtBlock(
              (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
              (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
              (mlp): Mlp(
                (fc1): Linear(in_features=128, out_features=512, bias=True)
                (act): GELU()
                (drop1): Dropout(p=0.0, inplace=False)
                (norm): Identity()
    

In [7]:
# Define the functions to post-process the model output
# The following functions are used to post-process the model output
def is_clockwise(points):
    sum = 0
    for i in range(len(points)):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % len(points)]
        sum += (x2 - x1) * (y2 + y1)
    return sum > 0
def fix_invalid_polygons(polygons):
    fixed_polygons = []
    for polygon in polygons:
        if not polygon.is_valid:
            fixed_polygon = polygon.buffer(0)
            if not fixed_polygon.is_valid:
                fixed_polygon = polygon.buffer(1e-10)  
                if not fixed_polygon.is_valid:
                    print(f"Invalid polygon could not be fixed: {explain_validity(polygon)}")
                    continue
            fixed_polygons.append(fixed_polygon)
        else:
            fixed_polygons.append(polygon)
    return fixed_polygons
def filter_small_polygons(polygons, min_area=1.0):
    return [p for p in polygons if p.area >= min_area]
def merge_polygons_with_topology(polygons):
    shapely_polygons = [Polygon(p) for p in polygons if len(p) >= 4]
    valid_polygons = fix_invalid_polygons(shapely_polygons)
    valid_polygons = filter_small_polygons(valid_polygons, min_area=10.0)
    merged_polygon = unary_union(valid_polygons)
    if isinstance(merged_polygon, Polygon):
        return [list(merged_polygon.exterior.coords)]
    elif isinstance(merged_polygon, MultiPolygon):
        return [list(p.exterior.coords) for p in merged_polygon.geoms]
    else:
        return []

# Run the model on the test set
model.eval()
score_threshold = 0.5
mask_threshold = 0.6
count = 0
results_dict = defaultdict(list)
output_csv = './submission.csv'
for  img_tensor, image_name in eval_data_loader:
    count+=1
    images = list(image.to(device) for image in img_tensor)
    with torch.no_grad():
        outputs = model(images,model='eval')
    for out in outputs:
      H, W = images[0].shape[1:]

      masks = out['masks']  # [N, 1, 28, 28]
      boxes = out['boxes']  # [N, 4]
      scores = out['scores']  # [N]

      keep = scores >= score_threshold
      masks = masks[keep]
      boxes = boxes[keep]
      scores = scores[keep]

      binary_masks = (masks > mask_threshold).squeeze(1).cpu().numpy()

      for i in range(len(binary_masks)):
          mask = binary_masks[i]
          contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

          for contour in contours:
                contour = contour[:, 0, :].tolist()
                if is_clockwise(contour):
                    contour.reverse()
                results_dict[int(str(image_name)[2:6])].append(contour)

# Merge the polygons and save the results
for image_id, polygons in results_dict.items():
    merged_polygons = merge_polygons_with_topology(polygons)
    results_dict[image_id] = merged_polygons

# Save the results to a CSV file
results = []

min_id = 0
max_id = 999

for image_id in range(min_id, max_id + 1):
    if image_id in results_dict:
        polygons = results_dict[image_id]
    else:
        polygons = []
    segment = '[' + ', '.join([str(p) for p in polygons]) + ']'
    image_id_str = str(image_id)
    results.append([image_id_str, segment])

with open(output_csv, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['ImageID', 'Coordinates'])
    for result in results:
        image_id = result[0]
        coordinates = result[1]
        writer.writerow([image_id, coordinates])


KeyboardInterrupt: 