In [3]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [4]:
cd  "/content/drive/MyDrive/final_proj/task2"


/content/drive/MyDrive/final_proj/task2


In [5]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import json


In [6]:
label_map = {'r': 1, 'n': 2, 'b': 3, 'k': 4, 'q': 5, 'p': 6, 'R': 7, 'N': 8, 'B': 9, 'K': 10, 'Q': 11, 'P': 12}
reversed_label_map = {v: k for k, v in label_map.items()}


In [7]:
class ChessDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_list = sorted([file for file in os.listdir(os.path.join(root_dir)) if file.endswith('.png')])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.image_list[idx])
        json_name = os.path.splitext(img_name)[0] + ".json"

        image = Image.open(img_name).convert("RGB")
        with open(json_name) as f:
            annotation = json.load(f)

        # extract labels and boxes of chess pieces
        pieces_info = annotation['pieces']
        labels = [label_map[piece['piece']] for piece in pieces_info]
        boxes = [piece['box'] for piece in pieces_info]

        # convert box coordinates to (xmin, ymin, xmax, ymax) format
        boxes = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in boxes]

        if self.transform:
            image = self.transform(image)

        targets = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': torch.tensor(labels)
        }

        return image, targets,img_name


In [8]:
def collate_fn(batch):
    return tuple(zip(*batch))


In [9]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator


In [10]:
def calculate_iou(box1, box2):
    # get coordinates of intersection rectangle
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    # calculate intersection area
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # calculate areas of individual boxes
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # calculate union area
    union_area = area_box1 + area_box2 - intersection_area
    iou = intersection_area / union_area
    return iou


In [11]:

def find_golden_match(gts, pred, pred_idx, threshold=0.5, ious=None):
    golden_match_idx = -1
    best_iou = -1

    for idx, gt in enumerate(gts):
        iou = calculate_iou(gt, pred)

        if iou >= threshold and iou > best_iou:
            best_iou = iou
            golden_match_idx = idx

    return golden_match_idx


In [12]:
def calculate_metrics(gts, preds, threshold = 0.5, ious=None) -> float:
    n = len(preds)
    tp = 0
    fp = 0
    fns=[1 for i in range(len(gts))]
    for pred_idx in range(n):

        golden_match_gt_idx = find_golden_match(gts, preds[pred_idx], pred_idx,
                                                threshold=threshold, ious=ious)

        if golden_match_gt_idx >= 0:
            # true positive
            tp += 1
            # remove matched GT box
            fns[golden_match_gt_idx]=0
        else:
            # false positive, no matching GT box
            fp += 1

    # false negative, a gt box had no associated predicted box.
    fn = sum(fns)
    p=tp / (tp + fp + fn)
    r=tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1=2*p*r/(p+r) if (p+r)>0 else 0.0
    return p,r,f1


## Predict

In [13]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


In [14]:
def predict(model,test_loader,device):

    model.eval()

    precisions=[]
    recalls=[]
    f1s=[]
    with torch.no_grad():
        for step, (images, targets,image_names) in enumerate(test_loader):
            images = list(image.to(device) for image in images)
            outputs = model(images)

            for i, image in enumerate(images):
                detection_threshold = 0.05

                sample = image.permute(1,2,0).cpu().numpy()

                #image = image.cpu().numpy()
                boxes = outputs[i]['boxes'].data.cpu().numpy()
                scores = outputs[i]['scores'].data.cpu().numpy()
                labels = outputs[i]['labels'].data.cpu().numpy()
                valid_indices = scores >= detection_threshold
                boxes = boxes[valid_indices].astype(np.int32)
                scores = scores[valid_indices]
                labels = labels[valid_indices]
                gt_boxes = targets[i]['boxes'].cpu().numpy()
                preds_sorted_idx = np.argsort(scores)[::-1]
                preds_sorted = boxes[preds_sorted_idx]
                image_precision,image_recall,image_f1 = calculate_metrics(preds_sorted, gt_boxes)
                precisions.append(image_precision)
                recalls.append(image_recall)
                f1s.append(image_f1)

                fig, ax = plt.subplots(1, 1, figsize=(8, 8))

                for box, label, score in zip(boxes, labels, scores):
                    box_coords = (box[0], box[1]), box[2] - box[0], box[3] - box[1]
                    rect = Rectangle(*box_coords, fill=False, edgecolor='red', linewidth=2)
                    ax.add_patch(rect)
                    label=reversed_label_map[label]
                    label_text = f"{label}:{score:.2f}"
                    ax.text(box[0], box[1], label_text, bbox=dict(facecolor='white', alpha=0.5))

                ax.set_axis_off()
                ax.imshow(sample)
                plt.tight_layout(pad=0)
                plt.savefig('results/' + image_names[i][5:])
                print(image_names[i][5:] + ' saved')
                with open('results/'+image_names[i][5:][:-4]+'.json', 'w') as json_file:
                    json.dump({"boxes":boxes.tolist(),"labels":labels.tolist(),"scores":scores.tolist()}, json_file)
                print(image_names[i][5:][:-4]+'.json' + ' saved')
    test_prec = np.mean(precisions)
    test_recall = np.mean(recalls)
    test_f1=np.mean(f1s)
    return test_prec,test_recall,test_f1


In [15]:
batch_size=18
transform = transforms.Compose([transforms.ToTensor()])
test_directory="test/"
test_dataset = ChessDataset(test_directory, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


In [16]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 13 # 12 class + background

in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

params = [p for p in model.parameters() if p.requires_grad]



model.load_state_dict(torch.load("models/best_model_ep5_s100_f11.0.pth"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:01<00:00, 128MB/s]


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [None]:
test_prec,test_recall,test_f1=predict(model,test_loader,device)


In [None]:
print("Test results: precison " + str(test_prec) + " recall " + str(test_recall) + " f1 "+str(test_f1))


Test results: preicison 0.9536442835138548 recall 0.9539949483806488 f1 0.9538158512966417
