In [None]:
import os
from glob import glob
import shutil
import random
import itertools
from tqdm.notebook import tqdm
from pathlib import Path, PosixPath
from ultralytics import YOLO

In [None]:
class Distiller:
    """Template for a two-stage knowledge-distillation workflow.

    Subclasses provide ``tune_model`` and ``label_dataset``; ``distil`` wires
    them together: tune the big model, tune the small model, let the tuned big
    model label a second dataset, then retune the small model on that dataset.
    """

    def __init__(self, tuning_parameters, labelling_parameters, name=None):
        """Store the parameter sets used by the tuning and labelling steps.

        Args:
            tuning_parameters: Settings made available to ``tune_model``.
            labelling_parameters: Settings made available to ``label_dataset``.
            name: Optional default job name, used by ``distil`` when no
                ``job_name`` is passed.
        """
        self.tuning_parameters = tuning_parameters
        self.labelling_parameters = labelling_parameters
        self.name = name

    def tune_model(self, model, dataset, name):
        """Tune ``model`` on ``dataset`` and return the tuned model.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement tune_model()")

    def label_dataset(self, model, unlabelled_dataset):
        """Label ``unlabelled_dataset`` using ``model``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement label_dataset()")

    def _run_name(self, job_name, suffix):
        """Build a run name, falling back to the instance name or the bare suffix."""
        prefix = job_name if job_name is not None else self.name
        # Without any prefix the suffix alone is used, instead of failing on None + str.
        return f"{prefix}-{suffix}" if prefix else suffix

    def distil(self, big_model, small_model, labelled_dataset_1, unlabelled_dataset_2, job_name=None):
        """Run the full distillation pipeline.

        Args:
            big_model: Teacher model handed to ``tune_model``.
            small_model: Student model handed to ``tune_model``.
            labelled_dataset_1: Dataset both models are first tuned on.
            unlabelled_dataset_2: Dataset labelled by the tuned big model and
                then used to retune the small model.
            job_name: Optional prefix for the three run names
                (``<job_name>-big-tune``, ``-small-tune``, ``-small-retune``).

        Returns:
            Whatever ``tune_model`` returns for the final retuning step.
        """
        print("Tuning Big Model")
        tuned_big_model = self.tune_model(
            big_model, labelled_dataset_1, name=self._run_name(job_name, "big-tune")
        )

        print("Tuning Small Model")
        tuned_small_model = self.tune_model(
            small_model, labelled_dataset_1, name=self._run_name(job_name, "small-tune")
        )

        print("Labelling Dataset")
        self.label_dataset(tuned_big_model, unlabelled_dataset_2)

        print("Retuning Small model")
        return self.tune_model(
            tuned_small_model, unlabelled_dataset_2, name=self._run_name(job_name, "small-retune")
        )

In [None]:
# Contents of the dataset description file for the FebSynth dataset,
# one entry per output line.
lines_model_labelled = [
    "path: datasets/FebSynth # dataset root dir",
    "train: train/images # train images (relative to 'path')",
    "val: val/images # val images (relative to 'path')",
    "test: test/images # test images (optional)",
    "# Classes",
    "names:",
    " 0: pedestrian",
    " 1: people",
    " 2: bicycle",
    " 3: car",
    " 4: van",
    " 5: truck",
    " 6: tricycle",
    " 7: awning-tricycle",
    " 8: bus",
    " 9: motor",
]

# Terminate every entry with a newline and write the file in a single call.
yaml_text = "".join(entry + '\n' for entry in lines_model_labelled)
with open("ultralytics/ultralytics/cfg/datasets/FebSynth.yaml", "wt") as yaml_out:
    yaml_out.write(yaml_text)

In [None]:
def batched(iterable, n):
    """Yield successive tuples of at most ``n`` items taken from ``iterable``.

    Example: ``batched('ABCDEFG', 3)`` yields ``ABC``, ``DEF``, ``G``.

    Raises:
        ValueError: If ``n`` is smaller than one (raised when iteration starts).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            # The source is exhausted; an empty tuple is never yielded.
            return
        yield chunk

class YOLOv8Distiller(Distiller):
    """Distiller whose tuning and labelling steps run through Ultralytics ``YOLO``.

    Models are passed between steps as weight paths (or model names) and are
    loaded into ``YOLO`` objects inside each step.
    """

    def __init__(self, tuning_parameters, labelling_parameters, name=None):
        """Forward the parameter sets (and optional name) to ``Distiller``."""
        super().__init__(tuning_parameters, labelling_parameters, name)

    def tune_model(self, model, dataset, name):
        """Train ``model`` on ``dataset`` and return the path to its best weights.

        Args:
            model: Weights path or model name accepted by ``YOLO``.
            dataset: Value passed as ``data=`` to ``YOLO.train``.
            name: Run name passed as ``name=`` to ``YOLO.train``.

        Returns:
            ``<save_dir>/weights/best.pt`` as a ``Path``, where ``save_dir`` is
            the trainer's output directory.
        """
        yolo_model = YOLO(model)
        yolo_model.train(data=dataset, name=name, **self.tuning_parameters)
        # Path (not PosixPath) so the result is built with the running platform's flavour.
        return Path(yolo_model.trainer.save_dir) / "weights" / "best.pt"

    def label_dataset(self, model, unlabelled_dataset):
        """Write one label file per training image using ``model``'s predictions.

        Images are read from ``datasets/<name>/train/images/`` and label files
        are written to ``datasets/<name>/train/labels/``, where ``<name>`` is
        ``unlabelled_dataset`` up to its first ``.``. Each line of a label file
        is ``<class> <x> <y> <w> <h>`` taken from ``boxes.cls`` and
        ``boxes.xywhn``. Images are processed in batches of
        ``labelling_parameters["batch_size"]``.

        Args:
            model: Weights path or model name accepted by ``YOLO``.
            unlabelled_dataset: Dataset file name such as ``"FebSynth.yaml"``.
        """
        yolo_model = YOLO(model)
        split_dir = "datasets/"+unlabelled_dataset.split(".")[0]+"/train/"
        labels_dir = split_dir+"labels/"
        # Make sure the output folder exists before any label file is opened.
        os.makedirs(labels_dir, exist_ok=True)

        images = glob(split_dir+"images/*")
        batches = list(batched(images, self.labelling_parameters["batch_size"]))
        for batch in tqdm(batches):
            results = yolo_model(batch)
            for image_path, result in zip(batch, results):
                lines = [
                    f"{int(cls)}" + "".join(f" {float(value):6.6f}" for value in box)
                    for cls, box in zip(result.boxes.cls, result.boxes.xywhn)
                ]

                # Path.stem drops only the final extension and copes with any
                # path separator, so the label name always matches its image.
                label_filename = labels_dir+Path(image_path).stem+".txt"

                with open(label_filename, "wt") as f:
                    f.writelines(line+"\n" for line in lines)

In [None]:
#raise Exception("Comment out this exception to run the labelling")
model_labelled_dir = "datasets/FebSynth"

images = glob("//Feb Data//*.png")
# Every image ends up in the same folder, so the order does not affect the
# result; the shuffle is kept so the global RNG state is consumed as before.
random.shuffle(images)

# os.makedirs creates missing parents, so the leaf folders are sufficient.
for split_name in ("train", "val"):
    for content in ("images", "labels"):
        os.makedirs(f"{model_labelled_dir}/{split_name}/{content}", exist_ok=True)

# Only train/images is populated by this cell; the val folders are created
# but left empty.
for image in images:
    shutil.copy(image, model_labelled_dir+"/train/images")

# Results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Per-run training logs, keyed by the label used in the plot legends.
result_files = {
    "big-tune": "runs/detect/kd-big-tune/results.csv",
    "small-tune": "runs/detect/kd-small-tune/results.csv",
    "small-retune": "runs/detect/kd-small-retune/results.csv",
}

training_results = {}
for run_label, csv_path in result_files.items():
    # Strip whitespace from the column headers so they can be addressed by name.
    training_results[run_label] = pd.read_csv(csv_path).rename(columns=lambda x: x.strip())

metric_columns = (
    "metrics/precision(B)",
    "metrics/recall(B)",
    "metrics/mAP50(B)",
    "metrics/mAP50-95(B)",
)

# One figure per metric, with one line per training run.
for column in metric_columns:
    # "metrics/mAP50(B)" -> "mAP50"
    metric = column.split("/")[1].split("(")[0]
    per_run_frames = [
        pd.DataFrame({
            "epochs": list(run_df.index),
            metric: list(run_df[column]),
            "model": [run_label] * len(run_df),
        })
        for run_label, run_df in training_results.items()
    ]
    # ignore_index gives the combined frame a fresh 0..N-1 index.
    data = pd.concat(per_run_frames, ignore_index=True)
    sns.lineplot(data=data, x="epochs", y=metric, hue="model")
    plt.show()

In [None]:
# Best weights of each training run, keyed by run label.
models = {
    "big-tune":"runs/detect/kd-big-tune/weights/best.pt",
    "small-tune":"runs/detect/kd-small-tune/weights/best.pt",
    "small-retune":"runs/detect/kd-small-retune/weights/best.pt",
}

separator = "---------------"
for key, weights_path in models.items():
    # Banner: separator, title, separator, each on its own line.
    print(separator, f"EVALUATING {key} ON TEST SET", separator, sep="\n")
    evaluated_model = YOLO(weights_path)
    # NOTE(review): no `data=` is passed, so each checkpoint presumably uses the
    # dataset recorded with its own training run; confirm all three runs are
    # scored on the same test split.
    evaluated_model.val(split="test")