In [None]:
from utils import *
from datasets import load_dataset
# Root of the Norway training split; images and XML annotations both live under it.
_NORWAY_TRAIN_DIR = "../datasets/Norway/train/"
path_labels = _NORWAY_TRAIN_DIR + "annotations/xmls/"
path_imgs = _NORWAY_TRAIN_DIR + "images/"

# NOTE(review): `create_metadata` comes from the star import of the local `utils` module;
# presumably it converts the XML annotations into the metadata that
# `load_dataset(path_imgs)` reads in the next cell — confirm in utils.
create_metadata(path_imgs, path_labels)

In [None]:
# Number of examples to load for a quick iteration run; set to None to load the full train split.
N_SAMPLES = 50
# NOTE: despite its name, `dataloader` holds a `datasets.Dataset` (a single split), not a
# torch DataLoader. Passing the image directory presumably makes `load_dataset` read it as
# an image folder using the metadata written by `create_metadata` — confirm.
dataloader = load_dataset(
    path_imgs,
    split=f"train[:{N_SAMPLES}]" if N_SAMPLES is not None else "train",
)

In [None]:
# Alias the loaded split under the name the rest of the notebook uses.
dataset = dataloader

# Fail fast if the generated metadata lacks the columns that `transforms` reads;
# otherwise the error only surfaces deep inside trainer.train().
missing_columns = {"image", "image_id", "annotations"} - set(dataset.column_names)
assert not missing_columns, f"dataset is missing columns: {sorted(missing_columns)}"

# Peek at one raw example to sanity-check the image and its annotations.
dataset[1]

In [None]:
from transformers import YolosFeatureExtractor

# Preprocessor matching the pretrained YOLOS-small checkpoint used for the model below.
# NOTE(review): size=(800, 800) presumably means a fixed height x width resize; how a
# tuple `size` is interpreted depends on the installed transformers version — confirm.
feature_extractor = YolosFeatureExtractor.from_pretrained(
    'hustvl/yolos-small',
    size=(800, 800),
)

In [None]:

def transforms(example_batch):
    """Convert a batch of raw dataset rows into YOLOS model inputs.

    Registered via ``with_transform`` so it runs lazily on every access and
    receives a *batch*: a dict mapping each column name to a list of values.

    Args:
        example_batch: batch with ``"image"``, ``"image_id"`` and
            ``"annotations"`` columns.

    Returns:
        The ``feature_extractor`` output as PyTorch tensors; ``collate_fn``
        later reads its ``"pixel_values"`` and ``"labels"`` entries.
    """
    images = example_batch["image"]
    # Pair every image id with that image's annotation list, in the target
    # format the feature extractor accepts.
    targets = [
        {"image_id": image_id, "annotations": annotations}
        for image_id, annotations in zip(example_batch["image_id"], example_batch["annotations"])
    ]
    return feature_extractor(images=images, annotations=targets, return_tensors="pt")


dataset = dataset.with_transform(transforms)

In [None]:
# Hold out 20% of the examples for evaluation. `train_test_split` shuffles by default,
# so without a fixed seed the split changes on every Restart & Run All.
SPLIT_SEED = 42
ds = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
from transformers import YolosForObjectDetection

# Class name -> class id. The ids presumably must match the category ids written by
# create_metadata — confirm before reordering. Add 'pothole': 4 here if that class is used.
label2id = {'D00': 0, 'D10': 1, 'D40': 2, 'D20': 3}
# Derive the inverse mapping so the two can never drift apart. Integer keys are the
# canonical form transformers configs use for id2label.
id2label = {class_id: label for label, class_id in label2id.items()}

# ignore_mismatched_sizes=True lets weights whose shapes differ from the checkpoint
# (the class head, resized for our labels) be freshly initialized instead of raising.
model = YolosForObjectDetection.from_pretrained('hustvl/yolos-small',
                                                id2label=id2label,
                                                label2id=label2id,
                                                ignore_mismatched_sizes=True)

In [None]:
import torch


def collate_fn(batch):
    """Merge transformed examples into one model-ready batch.

    Args:
        batch: list of dicts, each with a ``"pixel_values"`` tensor and a
            ``"labels"`` entry.

    Returns:
        dict with ``"pixel_values"`` stacked along a new leading batch dimension
        (all images must share a shape) and ``"labels"`` kept as a plain list,
        since each image can carry a different number of boxes.
    """
    return {
        "pixel_values": torch.stack([example["pixel_values"] for example in batch]),
        "labels": [example["labels"] for example in batch],
    }

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./results",
    # --- optimization schedule ---
    num_train_epochs=1,
    learning_rate=2e-4,
    per_device_train_batch_size=3,
    fp16=False,
    # --- evaluation / logging cadence ---
    # NOTE(review): `evaluation_strategy` is deprecated (renamed `eval_strategy`) in newer
    # transformers releases and later removed — confirm against the pinned version.
    # NOTE(review): eval_steps=1 evaluates after every training step; acceptable for this
    # tiny subset, but expensive on the full dataset.
    evaluation_strategy="steps",
    eval_steps=1,
    logging_steps=1,
    # --- checkpointing ---
    save_steps=100,
    save_total_limit=2,
    # Keep the raw columns (image, image_id, annotations): the on-the-fly `transforms`
    # reads them, so Trainer must not drop columns absent from the model's forward().
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

trainer.train()