In [77]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from glob import glob
import sys
from ensemble_boxes import weighted_boxes_fusion as wbf

In [79]:
def xywhn2xyxy(norm_center_x, norm_center_y, norm_w, norm_h):
    norm_center_x = float(norm_center_x)
    norm_center_y = float(norm_center_y)
    norm_w = float(norm_w)
    norm_h = float(norm_h)
    x_min = float(0) if (norm_center_x - norm_w / 2) < 0 else (norm_center_x - norm_w / 2)
    y_min = float(0) if (norm_center_y - norm_h / 2) < 0 else (norm_center_y - norm_h / 2)
    x_max = float(1) if (norm_center_x + norm_w / 2) > 1 else (norm_center_x + norm_w / 2) 
    y_max = float(1) if (norm_center_y + norm_h / 2) > 1 else (norm_center_y + norm_h / 2)
    return (x_min, y_min, x_max, y_max)

def xyxyn2xywhn(bbox):
    x_min, y_min, x_max, y_max  = bbox
    norm_w = np.abs(x_max-x_min)
    norm_h = np.abs(y_max-y_min)
    norm_center_x = (x_max+x_min)/2
    norm_center_y = (y_max+y_min)/2
    return (norm_center_x, norm_center_y, norm_w, norm_h)

In [81]:
def make_df_list(results_path):
    df_list = []
    file_names = []

    for i, result in enumerate(os.listdir(results_path)):
        if result not in ["yolo10_a100", "dino_swin_re"]:
            continue
        data = []
        file_names = sorted(os.listdir(os.path.join(results_path, result)))
        for file_name in tqdm(file_names):
            file_path = os.path.join(results_path, result, file_name)
            with open(file_path, 'r') as file:
                for line in file:
                    label, norm_center_x, norm_center_y, norm_w, norm_h, cs = line.split()
                    x_min, y_min, x_max, y_max = xywhn2xyxy(norm_center_x, norm_center_y, norm_w, norm_h)
                    data.append([file_name, int(label), float(x_min), 
                                    float(y_min), float(x_max), float(y_max), float(cs)])
        df_list.append(pd.DataFrame(data, columns=["file_name", "label", "x_min", 
                                        "y_min", "x_max", "y_max", "cs"]))
    print(f"{i} dataframes were created")
    return df_list, file_names

In [82]:
def weights_box_fusion(results_path, save_dir):
    
    df_list, file_names = make_df_list(results_path)

    for file_name in tqdm(file_names):
        boxes_list=[]
        scores_list=[]
        labels_list=[]

        for df in df_list:
            sub_df = df[df['file_name']==file_name]
            boxes_list.append(sub_df[['x_min', 'y_min', 'x_max', 'y_max']].apply(lambda row: [row['x_min'], 
                                                                                              row['y_min'], 
                                                                                              row['x_max'], 
                                                                                              row['y_max']], 
                                                                                              axis=1).tolist())
            scores_list.append(sub_df['cs'].tolist())
            labels_list.append(sub_df['label'].tolist())
        
        if any(box_list for box_list in boxes_list if box_list):
            boxes, scores, labels = wbf(
                boxes_list, scores_list, labels_list,
                weights=None, iou_thr=0.55, skip_box_thr=0.001
            )
        
        with open(os.path.join(save_dir, file_name),'w') as f:
            for box, cs, label in zip(boxes, scores, labels):
                norm_center_x, norm_center_y, norm_w, norm_h=xyxyn2xywhn(box)
                f.write("%d %lf %lf %lf %lf %lf\n"%(label, norm_center_x, norm_center_y, norm_w, norm_h, cs))

In [83]:
results_path = '/workspace/traffic_light/submission'
save_dir = '/workspace/traffic_light/submission/wbf_dino_swin_epoch_2_yolo10_a100'

os.makedirs(save_dir, exist_ok=True)
weights_box_fusion(results_path, save_dir)

100%|██████████| 13505/13505 [00:00<00:00, 36795.75it/s]
100%|██████████| 13505/13505 [00:00<00:00, 25417.56it/s]


4 dataframes were created


100%|██████████| 13505/13505 [03:07<00:00, 72.01it/s]
