In [None]:
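# 1. Load the detector's predictions.
# 2. Filter them by score (target limit: >= 0.5; see score_thr of get_prediction_from_coco_json).
# 3. Compute the IoU of each prediction with the ground-truth (GT) boxes.
# 4. Assign the class of the best-matching GT box to the prediction.
# 5. Crop the bbox with an expansion margin (crop_imagem's `aumento`, 30% by default).
# 6. Save the crop as an image and its label (class, bbox, score) as a JSON file.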
import json
import cv2
import os 
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run mode of this script:
#   'train' -> generate crops (labelled from the GT annotations) to train a classifier.
#   'pred'  -> generate crops from the detector predictions, to be classified by a
#              classifier (intended not to produce a ground-truth label).
#              NOTE(review): in this mode each prediction still receives the
#              category_id of its matched GT box — confirm that is intended.
mode = 'pred'

def get_prediction_from_coco_json(json_path, score_thr=0.0, run_mode=None):
    """Load detections (or GT annotations) from a COCO-style JSON file.

    Args:
        json_path: path of the JSON file to read.
        score_thr: keep only predictions whose ``score`` is strictly greater
            than this value. Not applied in 'train' mode, where the entries
            are GT annotations.
        run_mode: 'train' or 'pred'. When omitted, the module-level ``mode``
            is used (previous behaviour).

    Returns:
        In 'train' mode, the ``annotations`` list of the file. Otherwise the
        file is read as a list of prediction dicts, and the ones above
        ``score_thr`` are returned.
    """
    with open(json_path, 'r') as openfile:
        coco_json_result = json.load(openfile)
    # Only fall back to the global when no explicit mode was requested.
    effective_mode = mode if run_mode is None else run_mode
    if effective_mode == 'train':
        return coco_json_result['annotations']
    return [pred for pred in coco_json_result if pred['score'] > score_thr]

def get_coco_json(gt_path=r"/home/matheus_levy/workspace/dataset/mdetection_dataset/test/1_class_annotation.coco.json"):
    """Read a COCO ground-truth file and return its parsed content.

    Args:
        gt_path: location of the COCO JSON; defaults to the test-split annotation file.

    Returns:
        The decoded JSON object.
    """
    with open(gt_path, 'r') as gt_file:
        return json.load(gt_file)

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two COCO ``[x, y, width, height]`` boxes.

    Args:
        box1, box2: boxes given by their top-left corner and size.

    Returns:
        IoU in ``[0, 1]`` for well-formed boxes. Returns 0.0 when the union has
        no area (both boxes degenerate) instead of raising ZeroDivisionError.
    """
    x1, y1, w1, h1 = box1[0], box1[1], box1[2], box1[3]
    x2, y2, w2, h2 = box2[0], box2[1], box2[2], box2[3]

    # Intersection rectangle: its top-left is the larger of the two corners;
    # width/height are clipped at 0 when the boxes do not overlap.
    x_inter = max(x1, x2)
    y_inter = max(y1, y2)
    w_inter = max(0, min(x1 + w1, x2 + w2) - x_inter)
    h_inter = max(0, min(y1 + h1, y2 + h2) - y_inter)

    area_inter = w_inter * h_inter
    area_union = w1 * h1 + w2 * h2 - area_inter
    if area_union <= 0:
        # Zero-area boxes (e.g. width or height 0) have no meaningful overlap.
        return 0.0
    return area_inter / area_union

def get_ann_from_image_id(id, coco_preds, coco_gts, run_mode=None):
    """Select the predictions and (in 'pred' mode) the GT annotations of one image.

    Args:
        id: COCO image id to select.
        coco_preds: prediction dicts (GT annotations in 'train' mode).
        coco_gts: GT annotation dicts.
        run_mode: 'train' or 'pred'. When omitted, the module-level ``mode``
            is used (previous behaviour).

    Returns:
        ``(preds, gts)``: lists of the entries whose ``image_id`` equals ``id``.
        ``gts`` is None unless the mode is 'pred'.
    """
    preds = [pred for pred in coco_preds if pred['image_id'] == id]
    # Only fall back to the global when no explicit mode was requested.
    effective_mode = mode if run_mode is None else run_mode
    if effective_mode == 'pred':
        gts = [gt for gt in coco_gts if gt['image_id'] == id]
    else:
        gts = None
    return preds, gts

def match_preds_with_gt(preds, gts, iou_thr=0.5):
    """Label each prediction with the category of its best-overlapping GT box.

    For every prediction, the GT annotation with the highest IoU is chosen.
    If that IoU is at least ``iou_thr``, the prediction's ``category_id`` is set
    to the GT's ``category_id``. Otherwise it is set to -1 (no match).

    Args:
        preds: prediction dicts with a COCO ``bbox``; modified in place.
        gts: GT annotation dicts with ``bbox`` and ``category_id``. ``None``
            is treated as "no GT boxes".
        iou_thr: minimum IoU required for a match.

    Returns:
        The same ``preds`` list, with ``category_id`` set on every element.
    """
    for pred in preds:
        best_iou = 0
        best_ann = None
        for gt in gts or ():
            iou = calculate_iou(pred['bbox'], gt['bbox'])
            if iou > best_iou:
                best_iou, best_ann = iou, gt
        # Enforce iou_thr: without it, any overlap > 0 would count as a match.
        if best_ann is not None and best_iou >= iou_thr:
            pred['category_id'] = best_ann['category_id']
        else:
            pred['category_id'] = -1
    return preds

def qtd_imagens_in_results(coco_preds):
    """Return how many distinct ``image_id`` values occur in a list of predictions."""
    return len({prediction['image_id'] for prediction in coco_preds})

def qtd_imagens_in_gt(coco_gt):
    """Return the number of entries listed under ``images`` in a COCO GT dict."""
    return len(coco_gt['images'])


def image_id_to_name(id, coco_gt):
    """Return the ``file_name`` of the first image whose ``id`` matches, or None if absent."""
    matches = (entry['file_name'] for entry in coco_gt['images'] if entry['id'] == id)
    return next(matches, None)

def imagem_by_id(id, coco_gt, path_to_image):
    """Load, with cv2, the image registered under COCO id ``id``.

    Args:
        id: COCO image id, looked up in ``coco_gt['images']``.
        coco_gt: COCO GT dict providing the id -> file_name mapping.
        path_to_image: directory that contains the image files.

    Returns:
        The image returned by ``cv2.imread``. Returns None when the id is not
        listed in ``coco_gt`` or the file cannot be read (cv2.imread returns
        None in that case).
    """
    file_name = image_id_to_name(id, coco_gt)
    if file_name is None:
        # Unknown id: os.path.join(path, None) would raise TypeError.
        return None
    return cv2.imread(os.path.join(path_to_image, file_name))

def crop_imagem(imagem, bbox, aumento=0.3):
    """Crop a COCO ``[x, y, width, height]`` box from an image, enlarged by a margin.

    The box grows by ``aumento`` times its own width/height, split evenly between
    both sides, and each edge is then clipped to the image independently.

    Args:
        imagem: image array indexed as ``imagem[row, col]`` (e.g. from cv2.imread).
        bbox: COCO box ``[x, y, width, height]`` in pixels.
        aumento: total relative growth of width and height (0.3 = +30%).

    Returns:
        The cropped region as a slice of ``imagem`` (a view for NumPy arrays).
        It is empty when the box lies outside the image.
    """
    x, y, largura, altura = bbox[0], bbox[1], bbox[2], bbox[3]
    margem_x = largura * aumento / 2  # half of the extra width goes on each side
    margem_y = altura * aumento / 2   # half of the extra height goes on each side
    # Clamp every edge on its own. Clamping only the top-left corner while keeping
    # the full enlarged size would shift crops near the border right/down.
    # Clamping the far edges at 0 also prevents negative slice ends, which wrap around.
    x0 = max(0, int(x - margem_x))
    y0 = max(0, int(y - margem_y))
    x1 = max(0, int(x + largura + margem_x))
    y1 = max(0, int(y + altura + margem_y))
    return imagem[y0:y1, x0:x1]
    
# Input predictions (COCO results format) and output directories for the crops.
json_path = r'/home/matheus_levy/workspace/RPN_YOLO_Center_Retina/Cascade_R_CNN/predicts_bbox_cascade_swin.json.bbox.json'
save_crop_image_path = r"/home/matheus_levy/workspace/dataset/Crops/Cascade_R_CNN/test/images"
save_crop_label_path = r"/home/matheus_levy/workspace/dataset/Crops/Cascade_R_CNN/test/labels"

preds = get_prediction_from_coco_json(json_path)  # In 'train' mode this is the GT itself.
gts = get_coco_json()

# Iterate over the image ids that actually exist instead of range(count).
# COCO image ids need not start at 0 or be contiguous, so range(len(...)) could
# silently skip images. In 'train' mode, len(preds) counts annotations, not images.
if mode == 'pred':
    image_ids = [image['id'] for image in gts['images']]
else:
    image_ids = sorted({pred['image_id'] for pred in preds})
qtd_imagens = len(image_ids)

preds_classes = []
for image_id in image_ids:
    p, g = get_ann_from_image_id(image_id, preds, gts['annotations'])
    if mode == 'pred':
        # Attach the category of the best-matching GT box to each prediction.
        preds_classes.extend(match_preds_with_gt(p, g))
    else:
        preds_classes.extend(p)

with open("retina_without_thr.json", "w") as outfile:
    json.dump(preds_classes, outfile, indent=2)

# Directory holding the original images referenced by the GT file.
images_dir = r"/home/matheus_levy/workspace/dataset/mdetection_dataset/test"

# cv2.imwrite does not create missing directories; it just returns False, so
# without this every crop would be silently dropped.
os.makedirs(save_crop_image_path, exist_ok=True)
os.makedirs(save_crop_label_path, exist_ok=True)

i = 0  # index of the next crop to write; used as the output file stem
loaded_image_id = object()  # sentinel: equal to no real image id
imagem = None
for pred in tqdm(preds_classes):
    # Consecutive predictions usually share an image, so only re-read it from
    # disk when the image id changes.
    if pred['image_id'] != loaded_image_id:
        imagem = imagem_by_id(pred['image_id'], gts, path_to_image=images_dir)
        loaded_image_id = pred['image_id']
    if imagem is None:
        # The image file is missing or unreadable (cv2.imread returned None).
        continue
    crop = crop_imagem(imagem, pred['bbox'])
    if crop.size == 0:
        continue

    crop_file = os.path.join(save_crop_image_path, f"{i}.png")
    if not cv2.imwrite(crop_file, crop):
        raise OSError(f"Could not write crop image {crop_file}")

    json_label = {"image_id": pred['image_id'],
                  "category_id": pred["category_id"],
                  "bbox": pred["bbox"]}
    if mode != 'train':
        # Detector predictions carry a confidence score; GT annotations do not.
        json_label["score"] = pred["score"]
    with open(os.path.join(save_crop_label_path, f"{i}.json"), "w") as outfile:
        json.dump(json_label, outfile, indent=2)
    i += 1