<a href="https://colab.research.google.com/github/topkek777/grad_work/blob/master/vkr2023_ipynb%22.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Установка и подключение необходимых библиотек

In [None]:
pip install -qq datasets transformers evaluate timm albumentations
from tqdm import tqdm
import numpy as np
import os
from PIL import Image, ImageDraw

from google.colab import drive
drive.mount('/content/drive')

Подключение к порталу HuggingFace для сохранения обученной модели

In [None]:
from huggingface_hub import notebook_login

# Opens the interactive Hugging Face login prompt; paste the access token there.
notebook_login()
# SECURITY: never keep an access token in the notebook source.
# NOTE(review): a token was previously committed in this comment; treat it as compromised and revoke it.

Загрузка датасета

In [None]:
import datasets
#litter_data = load_dataset("Kili/plastic_in_river")  # Alternative: download from the Hub
litter_data = datasets.load_from_disk('/content/drive/MyDrive/dataset')  # Load the copy saved on Google Drive
#litter_data = datasets.load_from_disk('/content/drive/MyDrive/datasetnew')

In [None]:
# Display the column schema of the training split.
litter_data["train"].features

Преобразование датасета в формат COCO, с которым работает алгоритм DETR

In [None]:
# Sequential integer ids 0..N-1, one per training example (added later as the image_id column).
temp = list(range(len(litter_data["train"])))

In [None]:
# Convert every annotation box to COCO pixel format [x_min, y_min, width, height],
# rounded to one decimal. The arithmetic treats the stored values as centre-x, centre-y,
# width and height expressed as fractions of the image size.
new_bbox = []
for example in tqdm(litter_data["train"]):
    # Image.size is the real (width, height). getbbox() only reports the bounds of the
    # non-zero pixels (and None for an all-black image), so it is not a reliable size.
    img_w, img_h = example["image"].size

    new_bbox_i = []
    for x, y, w, h in example["litter"]["bbox"]:
        new_bbox_i.append([
            round((x - w / 2) * img_w, 1),
            round((y - h / 2) * img_h, 1),
            round(w * img_w, 1),
            round(h * img_h, 1),
        ])

    new_bbox.append(new_bbox_i)

In [None]:
# Display the raw annotations of the first training example.
litter_data["train"][0]["litter"]

In [None]:
# Area of every converted box (width * height), rounded to one decimal.
# Derived from new_bbox alone: indexing litter_data["train"][i] for each box would
# re-read the whole row (image included) only to count its labels.
new_area = [
    [round(box[2] * box[3], 1) for box in boxes]
    for boxes in tqdm(new_bbox)
]

In [None]:
# Preview the computed areas of the first ten examples.
new_area[:10]

In [None]:
# Attach the derived columns, rename the annotation struct to "objects", flatten it so
# its fields become top-level "objects.*" columns, and drop the original bbox column.
litter_data["train"] = (
    litter_data["train"]
    .add_column(name="image_id", column=temp)
    .add_column(name="bbox_new", column=new_bbox)
    .add_column(name="area", column=new_area)
    .rename_column("litter", "objects")
    .flatten()
    .remove_columns("objects.bbox")
)

In [None]:
# Class names; a name's position in this list is its integer class id.
categories = [
    "PLASTIC_BAG",
    "PLASTIC_BOTTLE",
    "OTHER_PLASTIC_WASTE",
    "NOT_PLASTIC_WASTE",
]

# Forward (id -> name) and reverse (name -> id) lookup tables.
id2label = dict(enumerate(categories))
label2id = {name: idx for idx, name in id2label.items()}

Код для проверки отрисовки рамок на изображениях

In [None]:
# Disabled debugging snippet, kept as a bare string literal so that it does not execute.
# NOTE(review): before re-enabling it, `draw` has to be created first (it is not defined
# at this point of the notebook), and class_idx reads label [0] for every box instead of label [i].
'''for i in range(len(litter_data["train"][3]["objects.label"])):
    box = litter_data["train"][3]["bbox_new"][i]
    class_idx = litter_data["train"][3]["objects.label"][0]
    x, y, w, h = tuple(box)
    draw.rectangle((x, y, x + w, y + h), outline="red", width=2)
    draw.text((x, y), id2label[class_idx], fill="white")'''

Необходимая трансформация для дальнейшего обучения модели

In [None]:
import albumentations
import torch

from transformers import AutoImageProcessor

# Pretrained checkpoint that is fine-tuned below; the image processor is loaded from the same checkpoint.
checkpoint = "facebook/detr-resnet-50"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
# Augmentation pipeline: resize to 480x480, then a horizontal flip and a random
# brightness/contrast change, both with p=1.0 (always applied).
_augmentations = [
    albumentations.Resize(480, 480),
    albumentations.HorizontalFlip(p=1.0),
    albumentations.RandomBrightnessContrast(p=1.0),
]
# Boxes are given in COCO format; the "category" argument carries their labels.
_bbox_params = albumentations.BboxParams(format="coco", label_fields=["category"])
transform = albumentations.Compose(_augmentations, bbox_params=_bbox_params)

In [None]:
def formatted_anns(image_id, category, area, bbox):
    """Build the COCO-style annotation list for one image.

    Args:
        image_id: Identifier stored in every annotation.
        category: Sequence of class ids, one per object.
        area: Sequence of box areas, indexed in parallel with ``category``.
        bbox: Sequence of boxes, indexed in parallel with ``category``;
            each box is copied into a new list.

    Returns:
        A list with one dict per object, holding the keys ``image_id``,
        ``category_id``, ``iscrowd``, ``area`` and ``bbox``.
    """
    # "iscrowd" (all lower-case) is the COCO field name; the former "isCrowd"
    # spelling is not a COCO key.
    return [
        {
            "image_id": image_id,
            "category_id": cat_id,
            "iscrowd": 0,
            "area": area[i],
            "bbox": list(bbox[i]),
        }
        for i, cat_id in enumerate(category)
    ]

In [None]:
def transform_aug_ann(examples):
    """Augment a batch of examples and encode it with the image processor.

    Examples without boxes, or with any box whose width or height is zero,
    are left out of the returned batch.

    NOTE(review): this function is used as an on-the-fly dataset transform, so
    dropping examples makes the returned batch shorter than the requested one;
    filtering the dataset ahead of time would be more robust.

    Args:
        examples: Batch dict with the columns ``image_id``, ``image``,
            ``objects.label`` and ``bbox_new``.

    Returns:
        The output of ``image_processor`` for the kept images and their
        COCO-style annotations, requested as PyTorch tensors.
    """
    image_ids, images, bboxes, area, categories = [], [], [], [], []
    for image_id, image, labels, boxes in zip(
        examples["image_id"],
        examples["image"],
        examples["objects.label"],
        examples["bbox_new"],
    ):
        if len(boxes) == 0:
            continue
        # A zero-width or zero-height box disqualifies the whole example.
        if any(box[2] == 0 or box[3] == 0 for box in boxes):
            continue

        # Keep RGB channel order: at inference time the processor is fed plain RGB
        # images, so reversing the channels to BGR here would make training and
        # inference inputs disagree.
        image = np.array(image.convert("RGB"))
        out = transform(image=image, bboxes=boxes, category=labels)

        image_ids.append(image_id)
        images.append(out["image"])
        bboxes.append(out["bboxes"])
        categories.append(out["category"])
        # Areas are recomputed from the transformed boxes so they stay consistent
        # with the resized image and aligned with the returned boxes.
        area.append([round(box[2] * box[3], 1) for box in out["bboxes"]])

    targets = [
        {"image_id": id_, "annotations": formatted_anns(id_, cat_, ar_, box_)}
        for id_, cat_, ar_, box_ in zip(image_ids, categories, area, bboxes)
    ]
    return image_processor(images=images, annotations=targets, return_tensors="pt")

In [None]:
# Register transform_aug_ann as the split's transform (the datasets library applies it when rows are accessed).
litter_data["train"] = litter_data["train"].with_transform(transform_aug_ann)

Далее идет обучение модели

In [None]:
def collate_fn(batch):
    """Collate encoded examples into one model-ready batch.

    Args:
        batch: List of items, each with the keys ``pixel_values`` and ``labels``.

    Returns:
        A dict with ``pixel_values`` and ``pixel_mask`` as produced by the image
        processor's padding routine, plus ``labels`` as a list with one entry per item.
    """
    pixel_values = [item["pixel_values"] for item in batch]

    # `pad_and_create_pixel_mask` was deprecated in favour of `pad` and is absent from
    # newer transformers releases; prefer `pad` and fall back to the old name otherwise.
    pad_fn = getattr(image_processor, "pad", None)
    if pad_fn is None:
        pad_fn = image_processor.pad_and_create_pixel_mask
    encoding = pad_fn(pixel_values, return_tensors="pt")

    return {
        "pixel_values": encoding["pixel_values"],
        "pixel_mask": encoding["pixel_mask"],
        "labels": [item["labels"] for item in batch],
    }

In [None]:
from transformers import AutoModelForObjectDetection

# Load the pretrained detector and attach the litter label maps.
# NOTE(review): ignore_mismatched_sizes=True is presumably needed because the checkpoint's
# classification head was sized for a different number of classes than the four used here.
model = AutoModelForObjectDetection.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
from transformers import TrainingArguments

# Hyper-parameters of the fine-tuning run.
# NOTE(review): output_dir is "azaza1", while the inference cells load
# "TopKek/detr-resnet-50_plastic_in_river_10ep"; confirm which repository the push targets.
training_args = TrainingArguments(
    output_dir="azaza1",
    per_device_train_batch_size=8,
    num_train_epochs=10,
    fp16=True,
    save_steps=200,
    logging_steps=50,
    learning_rate=1e-4,
    weight_decay=1e-4,
    save_total_limit=2,
    # The on-the-fly transform reads the raw dataset columns, so they must not be dropped.
    remove_unused_columns=False,
    push_to_hub=True,
)

In [None]:
from transformers import Trainer

# Assemble the Trainer from the model, the arguments, the custom collator and the transformed split.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=litter_data["train"],
    # The image processor is passed through the `tokenizer` argument.
    tokenizer=image_processor,
)

# Run the fine-tuning.
trainer.train()

Загрузка на портал HuggingFace

In [None]:
# Upload the trained model to the Hugging Face Hub (requires the login performed earlier).
trainer.push_to_hub()

Использование модели для детекции пластика

In [None]:
from transformers import pipeline
import requests

# Download one sample picture of the dataset.
url = "https://datasets-server.huggingface.co/assets/Kili/plastic_in_river/--/default/train/59/image/image.jpg"
# Fail fast on a hung connection or an HTTP error instead of handing PIL a non-image body.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Run the fine-tuned detector through the high-level pipeline API and display its predictions.
obj_detector = pipeline("object-detection", model="TopKek/detr-resnet-50_plastic_in_river_10ep")
obj_detector(image)

In [None]:
# Same inference done by hand: load the fine-tuned processor and model, run a forward
# pass and convert the raw outputs into boxes in the coordinates of the input picture.
image_processor = AutoImageProcessor.from_pretrained("TopKek/detr-resnet-50_plastic_in_river_10ep")
model = AutoModelForObjectDetection.from_pretrained("TopKek/detr-resnet-50_plastic_in_river_10ep")

with torch.no_grad():
    # `image` is the picture downloaded in the previous cell (the former name `imagee` was never defined).
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    # PIL reports size as (width, height); it is reversed here to (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]

# Report every detection that passed the score threshold.
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )

In [None]:
# Draw on a copy so that `image` itself stays untouched: re-running the detection cells
# must see the original picture, not one that already has boxes painted on it.
annotated = image.copy()
draw = ImageDraw.Draw(annotated)

# Boxes are corner coordinates (x_min, y_min, x_max, y_max); the coordinates themselves
# are already printed by the detection cell, so they are not printed again here.
for label, box in zip(results["labels"], results["boxes"]):
    x, y, x2, y2 = (round(coord, 2) for coord in box.tolist())
    draw.rectangle((x, y, x2, y2), outline="red", width=1)
    draw.text((x, y), model.config.id2label[label.item()], fill="white")

annotated