In [None]:
import os
import torch
import requests
import torchvision
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from PIL import Image
from tqdm import tqdm
from IPython.display import display
from typing import Optional, Union, Tuple, List

from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import DetrImageProcessor, DetrForObjectDetection

In [None]:
# Load the pretrained DETR checkpoint and its matching image processor
pretrained_checkpoint = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(pretrained_checkpoint)
model = DetrForObjectDetection.from_pretrained(pretrained_checkpoint)

In [None]:
# Test image for inference (a COCO val2017 sample)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# A timeout keeps the cell from hanging forever on a stalled connection
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(response.raw)
display(image)

In [None]:
# Preprocess the image with the processor into model-ready tensors
inputs = processor(images=image, return_tensors="pt")

# Inference: make sure dropout is off (e.g. if this cell is re-run after fine-tuning put the
# model in train mode) and skip building the autograd graph, which is not needed here.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# Post-process the raw outputs with the processor helper.
# PIL's size is (width, height); target_sizes expects one (height, width) row per image.
width, height = image.size
target_sizes = torch.tensor([[height, width]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]

In [None]:
# Print every predicted detection: class name, confidence and box coordinates
label_names = model.config.id2label
for det_score, det_label, det_box in zip(results["scores"], results["labels"], results["boxes"]):
    rounded_box = [round(coord, 2) for coord in det_box.tolist()]
    class_name = label_names[det_label.item()]
    print(f"Label: {class_name} | Score: {round(det_score.item(), 4)} | Box: {rounded_box}")

In [None]:
# Visualize the detections
def plot_results(pil_img, scores, labels, boxes, id2label=None):
    """Draw predicted boxes, class names and scores on top of ``pil_img``.

    Args:
        pil_img: Image to draw on (anything ``ax.imshow`` accepts, e.g. a PIL image).
        scores: 1-D tensor of confidence scores, one per detection.
        labels: 1-D tensor of integer class ids, one per detection.
        boxes: (num_detections, 4) tensor of boxes in ``(xmin, ymin, xmax, ymax)`` format.
        id2label: Optional mapping from class id to class name. Defaults to
            ``model.config.id2label`` of the global ``model`` (the previous behavior); pass it
            explicitly so the plot does not depend on whatever ``model`` is currently bound to.
    """
    if id2label is None:
        id2label = model.config.id2label

    fig, ax = plt.subplots()
    ax.imshow(pil_img)
    # One distinct color per detection
    colors = plt.cm.rainbow(np.linspace(0, 1, len(scores)))

    for score, label, (xmin, ymin, xmax, ymax), color in zip(scores, labels, boxes.tolist(), colors):
        ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
        text = f"{id2label[label.item()]}: {score.item():0.4f}"
        # Push labels of boxes touching the top edge down so the text stays inside the image
        text_y = ymin + 20 if ymin < 20 else ymin
        ax.text(xmin, text_y, text, bbox={"facecolor": color, "alpha": 0.6}, clip_box=ax.clipbox, clip_on=True)

    ax.axis("off")
    plt.show()


plot_results(image, results["scores"], results["labels"], results["boxes"], id2label=model.config.id2label)

In [None]:
# COCO 2017 validation set의 이미지와 annotation 다운로드
%mkdir -p data
%wget http://images.cocodataset.org/zips/val2017.zip -P data/
%wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P data/
%unzip -q data/val2017.zip -d data/
%unzip -q data/annotations_trainval2017.zip -d data/

In [None]:
# Dataset that loads COCO images + annotations and runs the image processor on each sample
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset whose samples are already encoded by an image processor.

    Each item is ``(pixel_values, target)``: the processed image tensor without its batch
    dimension, and the processor's label entry for that image.
    """

    def __init__(
        self,
        image_directory_path: str,
        image_processor,
        train: bool = True,
        annotation_file_path: Optional[str] = None,
    ):
        """
        Args:
            image_directory_path: Directory holding the images, e.g. ``'data/val2017'``.
            image_processor: Callable (e.g. ``DetrImageProcessor``) applied to every sample.
            train: Accepted for API compatibility; currently unused (it does not select a split).
            annotation_file_path: COCO annotation JSON. Defaults to
                ``<parent of image_directory_path>/annotations/instances_val2017.json``.
        """
        if annotation_file_path is None:
            # The annotations folder is a sibling of the image directory. Splitting on '/' and
            # taking the first piece only worked for one-level relative paths like 'data/val2017';
            # an absolute path gave '' and therefore a bogus relative annotation path.
            parent_dir = os.path.dirname(os.path.normpath(image_directory_path))
            annotation_file_path = os.path.join(parent_dir, 'annotations', 'instances_val2017.json')
        super().__init__(image_directory_path, annotation_file_path)
        self.image_processor = image_processor

    def __getitem__(self, idx):
        image, annotations = super().__getitem__(idx)
        image_id = self.ids[idx]
        # The processor takes COCO-style {'image_id', 'annotations'} targets
        coco_target = {'image_id': image_id, 'annotations': annotations}
        encoding = self.image_processor(images=image, annotations=coco_target, return_tensors="pt")
        # Remove only the leading batch dimension of size 1 (not every size-1 axis)
        pixel_values = encoding["pixel_values"][0]
        target = encoding["labels"][0]

        return pixel_values, target

In [None]:
# Use the COCO val2017 images and their annotations as the training set
train_dataset = CocoDetection('data/val2017', processor, train=True)

num_train_examples = len(train_dataset)
print("Number of training examples:", num_train_examples)

In [None]:
# DataLoader collation for training
def collate_fn(batch):
    """Collate ``(pixel_values, target)`` samples into a DETR training batch.

    Images can have different spatial sizes, so they are padded to a common size with the
    global ``processor``. Targets hold a variable number of boxes per image and are therefore
    kept as a plain list.

    Returns:
        dict with ``'pixel_values'`` (padded batch), ``'pixel_mask'`` (real vs. padded pixels,
        as returned by ``processor.pad``) and ``'labels'`` (list of per-image targets).
    """
    pixel_values = [pixel_value for pixel_value, _ in batch]
    labels = [target for _, target in batch]
    encoding = processor.pad(pixel_values, return_tensors="pt")
    return {
        'pixel_values': encoding['pixel_values'],
        # Without this mask the model falls back to an all-ones mask and treats padding as image content
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }

# Shuffled loader of 4-image batches built with the custom collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
num_batches = len(train_loader)
print("Batch number of train loader:", num_batches)

In [None]:
# Fine-tuning engine
def finetune_model(model, train_loader, num_epochs, learning_rate, max_iters: Optional[int] = 30):
    """Fine-tune ``model`` with AdamW.

    Args:
        model: Detection model called as ``model(pixel_values=..., pixel_mask=..., labels=...)``
            whose output exposes ``.loss``.
        train_loader: DataLoader yielding dicts with ``'pixel_values'``, ``'labels'`` and
            optionally ``'pixel_mask'``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: AdamW learning rate.
        max_iters: Stop each epoch after batch index ``max_iters`` (i.e. ``max_iters + 1``
            batches) to keep the hands-on run short; ``None`` runs full epochs.
    """
    # Train on whatever device the model lives on instead of relying on a global `device`
    device = next(model.parameters()).device
    model.train()
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_steps = 0
        t = tqdm(enumerate(train_loader), total=len(train_loader), desc="Loss: ---")

        for i, batch in t:
            # pixel_values: images, targets['boxes']: GT box coordinates, targets['class_labels']: GT classes
            pixel_values = batch['pixel_values'].to(device)
            pixel_mask = batch.get('pixel_mask')
            if pixel_mask is not None:
                pixel_mask = pixel_mask.to(device)
            targets = [
                {'boxes': label['boxes'].to(device), 'class_labels': label['class_labels'].to(device)}
                for label in batch['labels']
            ]

            # Forward pass (predictions + loss)
            outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=targets)
            loss = outputs.loss

            # backward() alone only accumulates gradients; clearing them and stepping the
            # optimizer is what actually updates the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_steps += 1

            if i % 10 == 0:
                avg_loss = total_loss / num_steps
                print(f"\n  ##### Iteration {i}, Average Loss: {avg_loss:.4f}")
            t.set_description(f"Loss: {loss.item():.4f}")
            t.refresh()

            # Cap the iterations for a quick hands-on run
            if max_iters is not None and i >= max_iters:
                break

        # Average over the batches actually processed (the loop may stop early)
        print(f"\n----- Epoch {epoch}, loss: {total_loss / max(num_steps, 1)}")


# Use the GPU when available, then fine-tune the pretrained model on that device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

model = model.to(device)
finetune_model(model, train_loader, num_epochs=1, learning_rate=1e-4)

In [None]:
from transformers.models.detr.configuration_detr import DetrConfig

# Customize the model through its config
shared_settings = {
    'd_model': 256,                  # dimension of the layers
    'dropout': 0.1,                  # dropout probability used in the embeddings, encoder and pooler
    'activation_function': 'relu',   # activation function used in the encoder and pooler
    'num_labels': 91,                # num_classes of the training dataset
}
encoder_settings = {
    'encoder_layers': 6,             # number of encoder layers
    'encoder_attention_heads': 8,    # attention heads per encoder attention layer
    'encoder_ffn_dim': 2048,         # dimension of the FFN inside each encoder layer
}
decoder_settings = {
    'decoder_layers': 6,             # number of decoder layers
    'decoder_attention_heads': 8,    # attention heads per decoder attention layer
    'decoder_ffn_dim': 2048,         # dimension of the FFN inside each decoder layer
}
detr_config = DetrConfig(**shared_settings, **encoder_settings, **decoder_settings)

In [None]:
from transformers.models.detr.modeling_detr import DetrEncoder, DetrDecoder

# Standalone transformer encoder and decoder built from the custom config
encoder = DetrEncoder(config=detr_config)
encoder = encoder.to(device)
decoder = DetrDecoder(config=detr_config)
decoder = decoder.to(device)

In [None]:
# positional embedding
from transformers.models.detr.modeling_detr import DetrSinePositionEmbedding

def build_position_encoding(config):
    """Create DETR's 2-D sine positional encoding with ``config.d_model`` channels in total.

    The Transformer's positional encoding is generalized to 2-D: each spatial axis gets
    ``d_model // 2`` channels, which are concatenated back to ``d_model``.
    """
    channels_per_axis = config.d_model // 2
    return DetrSinePositionEmbedding(channels_per_axis, normalize=True)

In [None]:
# CNN backbone + positional embedding
from transformers.models.detr.modeling_detr import DetrConvEncoder, DetrConvModel

backbone_cnn = DetrConvEncoder(detr_config)                 # CNN feature extractor
position_embedding = build_position_encoding(detr_config)   # 2-D sine positional encoding

# Wrap both into DETR's backbone module and move it to the target device
backbone = DetrConvModel(backbone_cnn, position_embedding)
backbone = backbone.to(device)

In [None]:
# Test image
import torchvision.transforms as T

# Resize to a fixed 800x800, convert to a tensor and normalize with ImageNet mean/std
transforms = T.Compose([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Timeout + status check: fail fast instead of hanging or decoding an HTTP error page
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")
# NOTE: `input` shadows the Python builtin; the name is kept because the following cells use it.
input = transforms(image).unsqueeze(0).to(device)  # add a batch dimension of 1

print(input.shape)

In [None]:
# Pixel mask of all ones: every pixel of the (unpadded) image is used.
# Same (batch, height, width) shape, dtype and device as one channel of the input.
pixel_mask = torch.ones_like(input[:, 0])
print(f'Shape of input image: {input.shape}')
print(f'Shape of pixel mask: {pixel_mask.shape}')

# Backbone forward: a list of (feature map, mask) pairs plus matching positional embeddings
features, pos_embs = backbone(input, pixel_mask)
print(f'Num of output features: {len(features)}')

In [None]:
# Keep only the last backbone output: its feature map and the mask downsampled to match it
last_stage_output = features[-1]
feature_map, mask = last_stage_output
print(f'Shape of feature map: {feature_map.shape}')
print(f'Shape of downsampled mask: {mask.shape}')

In [None]:
# 1x1 convolution that maps the backbone channels to the transformer width
input_projection = nn.Conv2d(
    in_channels=feature_map.shape[1],
    # Must equal the d_model the encoder/decoder were built with, so read it from the config
    out_channels=detr_config.d_model,
    kernel_size=1,
).to(device)

projected_feature_map = input_projection(feature_map)
print(f'Shape of projected feature map: {projected_feature_map.shape}')

In [None]:
# Flatten the spatial grid into a token sequence for the transformer:
# (batch, channels, H, W) -> (batch, H*W, channels); the mask becomes (batch, H*W)
flattened_features = projected_feature_map.flatten(2).transpose(1, 2)
pos_emb = pos_embs[-1].flatten(2).transpose(1, 2)
flattened_mask = mask.flatten(1)

print(f'Shape of flattened_features: {flattened_features.shape}')
print(f'Shape of position embeddings: {pos_emb.shape}')
print(f'Shape of flattened_mask: {flattened_mask.shape}')

In [None]:
# Encoder over the flattened image tokens; the mask marks valid tokens and
# object_queries carries the positional embeddings of the tokens
encoder_outputs = encoder(inputs_embeds=flattened_features, attention_mask=flattened_mask, object_queries=pos_emb)

encoder_last_hidden_state = encoder_outputs[0]  # last_hidden_state
print(f'Shape of encoder_outputs: {encoder_last_hidden_state.shape}')

In [None]:
# Create N object queries for detection
N = 10
d_model = detr_config.d_model                # must match the decoder width
batch_size = flattened_features.shape[0]     # one set of queries per image in the batch

# Randomly initialized query position embeddings, repeated for every image in the batch
query_embedding_table = nn.Embedding(N, d_model)
query_position_embeddings = query_embedding_table.weight.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
# Query contents start at zero (already on `device`, like the tensor it mirrors)
queries = torch.zeros_like(query_position_embeddings)

print(f'Shape of queries: {queries.shape}')
print(f'Query: {queries[0]}')

In [None]:
decoder_outputs = decoder(
    inputs_embeds=queries,
    object_queries=pos_emb,                               # positional embedding used in cross-attention
    query_position_embeddings=query_position_embeddings,  # positional embedding used in self-attention
    encoder_hidden_states=encoder_outputs[0],
    encoder_attention_mask=flattened_mask,
)

# decoder_outputs[0]: last_hidden_state (previously the encoder output was printed here by mistake)
print(f'Shape of decoder_outputs: {decoder_outputs[0].shape}')

In [None]:
# Full module code

from transformers.models.detr.modeling_detr import DetrPreTrainedModel

class DetrModel(DetrPreTrainedModel):
    """DETR without prediction heads: CNN backbone -> transformer encoder -> decoder.

    ``forward`` returns the decoder's last hidden state (one ``d_model`` vector per object query).
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Backbone CNN + positional encoding
        backbone_cnn = DetrConvEncoder(config)
        position_embedding = build_position_encoding(config)
        self.backbone = DetrConvModel(backbone_cnn, position_embedding)

        # Projection from the last backbone stage's channels to d_model
        self.input_projection = nn.Conv2d(backbone_cnn.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        # Learnable position embeddings for the config.num_queries object queries
        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)

        # Transformer encoder and decoder
        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

    def forward(
        self,
        pixel_values,
        pixel_mask = None,
        decoder_attention_mask = None,
        encoder_outputs = None,
        inputs_embeds = None,
        decoder_inputs_embeds = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        """Run backbone, encoder and decoder.

        Args:
            pixel_values: (batch, channels, height, width) image batch.
            pixel_mask: (batch, height, width) mask of pixels to use; defaults to all ones.
            decoder_attention_mask: Forwarded to the decoder as ``attention_mask`` (None by default).
            encoder_outputs: Precomputed encoder output whose element 0 is the last hidden
                state; when given, the encoder is skipped.
            decoder_inputs_embeds: Initial query contents (batch, num_queries, d_model);
                defaults to zeros.
            inputs_embeds, output_attentions, output_hidden_states, return_dict: accepted for
                signature compatibility with the Hugging Face DETR API; currently ignored.

        Returns:
            The decoder's last hidden state.
        """
        batch_size, _, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones(((batch_size, height, width)), device=device)

        # 1. pixel_values + pixel_mask -> backbone features and positional embeddings
        features, pos_embs = self.backbone(pixel_values, pixel_mask)

        # Final backbone feature map and the mask downsampled to its resolution
        feature_map, mask = features[-1]

        # 2. 1x1 convolution compresses the channels to d_model
        projected_feature_map = self.input_projection(feature_map)

        # 3. Flatten spatial dims: features/pos -> (batch, H*W, d_model), mask -> (batch, H*W)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        pos_emb = pos_embs[-1].flatten(2).permute(0, 2, 1)
        flattened_mask = mask.flatten(1)

        # 4. Encoder over the image tokens (skipped when precomputed outputs are supplied)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=pos_emb
            )

        # 5. Object queries: learned position embeddings, zero (or caller-supplied) contents
        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        if decoder_inputs_embeds is not None:
            queries = decoder_inputs_embeds
        else:
            queries = torch.zeros_like(query_position_embeddings)

        # 6. Decoder: query position embeddings in self-attention, pos_emb in cross-attention
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=decoder_attention_mask,
            object_queries=pos_emb,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask
        )

        return decoder_outputs.last_hidden_state

In [None]:
# Area of each box
def box_area(boxes):
    """Return the area of every box in an (N, 4) tensor of (x0, y0, x1, y1) corners."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights


# IoU (Intersection over Union) between every pair of boxes from two sets
def box_iou(boxes1, boxes2):
    """Pairwise IoU between (N, 4) ``boxes1`` and (M, 4) ``boxes2`` in (x0, y0, x1, y1) format.

    Returns:
        (iou, union): two (N, M) tensors; the union is returned as well because GIoU reuses it.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Intersection rectangle of every (i, j) pair via broadcasting [N,1,2] against [M,2]
    inter_top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])        # [N,M,2]
    inter_bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])    # [N,M,2]
    inter_wh = (inter_bottom_right - inter_top_left).clamp(min=0)         # no overlap -> 0
    inter = inter_wh[..., 0] * inter_wh[..., 1]                           # [N,M]

    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = area1.unsqueeze(1) + area2 - inter

    return inter / union, union


# GIoU (Generalized Intersection over Union) between every pair of boxes from two sets
def generalized_box_iou(boxes1, boxes2):
    """Pairwise Generalized IoU between (N, 4) ``boxes1`` and (M, 4) ``boxes2``.

    Boxes are in corner format (x0, y0, x1, y1).
    GIoU = IoU - (|C| - |A ∪ B|) / |C|, where C is the smallest box enclosing both A and B.

    Returns:
        (N, M) tensor of GIoU values.

    Raises:
        ValueError: if a box has x1 < x0 or y1 < y0 (inverted corners give meaningless values).
    """
    # Inverted corners silently produce negative areas and garbage GIoU; fail loudly instead
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")

    iou, union = box_iou(boxes1, boxes2)

    # Corners of the smallest enclosing box C for every pair
    enclosing_top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])      # [N,M,2]
    enclosing_bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Area of C = width * height
    enclosing_wh = (enclosing_bottom_right - enclosing_top_left).clamp(min=0)  # [N,M,2]
    enclosing_area = enclosing_wh[:, :, 0] * enclosing_wh[:, :, 1]

    return iou - (enclosing_area - union) / enclosing_area

In [None]:
# Hungarian algorithm으로 최적의 조합을 찾아주는 matcher

from scipy.optimize import linear_sum_assignment
from transformers.image_transforms import center_to_corners_format

class DetrHungarianMatcher(nn.Module):
    """One-to-one matching between predictions and ground-truth objects (Hungarian algorithm).

    The matching cost is a weighted sum of a classification cost, an L1 box cost and a GIoU
    cost; ``linear_sum_assignment`` then finds, per image, the minimum-cost assignment.

    Args:
        class_cost: Weight of the classification cost.
        bbox_cost: Weight of the L1 box-coordinate cost.
        giou_cost: Weight of the GIoU cost.

    Raises:
        ValueError: if all three weights are 0 (every assignment would cost the same).
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()
        # Weights of the three terms in the matching cost
        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs: dict with ``"logits"`` (batch, num_queries, num_classes) and
                ``"pred_boxes"`` (batch, num_queries, 4) in (cx, cy, w, h) format.
            targets: list (one per image) of dicts with ``"class_labels"`` and ``"boxes"``
                (cx, cy, w, h).

        Returns:
            List of ``(query_indices, target_indices)`` int64 tensor pairs, one per image.
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # Flatten the batch so the costs of all images are computed at once
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Concatenate the targets of all images as well
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # 1. Classification cost, approximated by -(probability of the target class)
        class_cost = -out_prob[:, target_ids]

        # 2-(1). L1 distance between box coordinates
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # 2-(2). Negative GIoU (computed on corner-format boxes)
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # 3. Matching cost = weighted sum of the three terms -> [batch_size, num_queries, total_targets]
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        # Split the target axis per image; image i is matched only against its own targets (c[i])
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

In [None]:
# Final Hungarian loss computed from the optimal matching indices
class DetrLoss(nn.Module):
    """DETR set-prediction loss.

    Predictions are first matched one-to-one to targets by ``matcher``; the requested losses
    are then computed on the matched pairs.

    Args:
        matcher: Callable ``(outputs, targets) -> list of (source_idx, target_idx)`` per image.
        num_classes: Number of object classes, excluding the extra "no object" class.
        eos_coef: Classification weight of the "no object" class.
        losses: Names of the losses to compute: ``"labels"`` and/or ``"boxes"``.
    """

    def __init__(self, matcher, num_classes, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        self.losses = losses
        # Cross-entropy class weights: 1 for real classes, eos_coef for the trailing "no object" slot
        class_weights = torch.ones(num_classes + 1)
        class_weights[num_classes] = eos_coef
        self.register_buffer("empty_weight", class_weights)

    # 1. classification loss
    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Weighted cross-entropy over all queries; unmatched queries are trained to "no object"."""
        logits = outputs["logits"]
        batch_idx, query_idx = self._get_source_permutation_idx(indices)

        # Every query defaults to the "no object" class (index num_classes) ...
        target_classes = torch.full(logits.shape[:2], self.num_classes, dtype=torch.int64, device=logits.device)
        # ... except matched queries, which take the class of their ground-truth object
        matched_classes = torch.cat([t["class_labels"][tgt] for t, (_, tgt) in zip(targets, indices)])
        target_classes[batch_idx, query_idx] = matched_classes

        # cross_entropy expects the class dimension second
        loss_ce = nn.functional.cross_entropy(logits.transpose(1, 2), target_classes, weight=self.empty_weight)
        return {"loss_ce": loss_ce}

    # 2. bbox loss
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """L1 and GIoU losses on the matched boxes, each normalized by ``num_boxes``."""
        matched = self._get_source_permutation_idx(indices)
        pred = outputs["pred_boxes"][matched]
        gt = torch.cat([t["boxes"][tgt] for t, (_, tgt) in zip(targets, indices)], dim=0)

        # 2-(1). L1 loss per coordinate
        l1 = nn.functional.l1_loss(pred, gt, reduction="none")
        # 2-(2). GIoU loss: only the diagonal (prediction k vs. its own target k) is relevant
        giou = torch.diag(generalized_box_iou(center_to_corners_format(pred), center_to_corners_format(gt)))

        return {
            "loss_bbox": l1.sum() / num_boxes,
            "loss_giou": (1 - giou).sum() / num_boxes,
        }

    # Flatten the per-image matches into (batch index, query index) tensors usable for indexing
    def _get_source_permutation_idx(self, indices):
        batch_idx = torch.cat([torch.full_like(src, img) for img, (src, _) in enumerate(indices)])
        source_idx = torch.cat([src for src, _ in indices])
        return batch_idx, source_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        """Dispatch to the loss registered under ``loss`` (``"labels"`` or ``"boxes"``)."""
        loss_fn = {"labels": self.loss_labels, "boxes": self.loss_boxes}[loss]
        return loss_fn(outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets):
        """Match predictions to targets, then compute and merge every requested loss."""
        indices = self.matcher(outputs, targets)

        # Normalize by the number of target boxes in the batch (at least 1)
        total_boxes = sum(len(t["class_labels"]) for t in targets)
        device = next(iter(outputs.values())).device
        num_boxes = torch.clamp(torch.as_tensor([total_boxes], dtype=torch.float, device=device), min=1).item()

        losses = {}
        for name in self.losses:
            losses.update(self.get_loss(name, outputs, targets, indices, num_boxes))
        return losses

In [None]:
# MLP prediction head for the bbox coordinates
class DetrMLPPredictionHead(nn.Module):
    """Feed-forward network of ``num_layers`` Linear layers with ReLU in between.

    Widths go input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim; the last layer
    has no activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(in_features, out_features) for in_features, out_features in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = nn.functional.relu(layer(x))
        return output_layer(x)

In [None]:
class DetrForObjectDetection(DetrPreTrainedModel):
    """DETR detector: ``DetrModel`` plus a class-logit head and a box head.

    Note: this reuses the name of the Hugging Face ``DetrForObjectDetection`` imported at the
    top of the notebook and shadows it from this cell on.
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR's base encoder-decoder model
        self.model = DetrModel(config)

        # Prediction heads
        # Linear head for class labels (num_classes + 1, including "no object")
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )
        # MLP head for the bbox coordinates
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

    def forward(
        self,
        pixel_values,
        pixel_mask = None,
        decoder_attention_mask = None,
        encoder_outputs = None,
        inputs_embeds = None,
        decoder_inputs_embeds = None,
        labels = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        """Predict classes and boxes and, when ``labels`` are given, the Hungarian loss.

        Args:
            pixel_values, pixel_mask, decoder_attention_mask, encoder_outputs, inputs_embeds,
            decoder_inputs_embeds, output_attentions, output_hidden_states: forwarded to ``DetrModel``.
            labels: Optional list (one per image) of dicts with ``"class_labels"`` and ``"boxes"``.
            return_dict: If truthy, return a dict with ``"loss"``, ``"loss_dict"``, ``"logits"``
                and ``"pred_boxes"`` (the loss entries are None without labels). Otherwise only
                the loss is returned (None without labels), which is what the training loop uses.
        """
        # 1. Run the image through the base DETR model (encoder + decoder)
        sequence_output = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 2. Prediction heads: class logits and boxes (sigmoid keeps box values in (0, 1))
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict = None, None
        if labels is not None:
            # 3. Optimal matching between predictions and ground truth
            matcher = DetrHungarianMatcher(
                class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
            )
            # 4. Classification and bbox losses on the matched pairs
            losses = ["labels", "boxes"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=self.config.num_labels,
                eos_coef=self.config.eos_coefficient,
                losses=losses,
            )
            # Move the class-weight buffer to the model's device
            criterion.to(self.device)
            outputs_loss = {}
            outputs_loss["logits"] = logits
            outputs_loss["pred_boxes"] = pred_boxes

            # (steps 3 and 4 actually run together here)
            loss_dict = criterion(outputs_loss, labels)

            # 5. Final loss = weighted sum of all loss terms
            weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = self.config.giou_loss_coefficient
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if return_dict:
            # Expose the predictions too, so the model can be used for inference
            return {"loss": loss, "loss_dict": loss_dict, "logits": logits, "pred_boxes": pred_boxes}
        return loss

In [None]:
# Build the custom detector from the config (the training cell builds a fresh one on `device` again)
model = DetrForObjectDetection(config=detr_config)

In [None]:
# Training the custom DETR detector
def train_model(model, train_loader, num_epochs=1, learning_rate=1e-4, max_iters: Optional[int] = 20):
    """Train ``model`` with AdamW.

    Args:
        model: Detector called as ``model(pixel_values=..., pixel_mask=..., labels=...)`` that
            returns the scalar training loss.
        train_loader: DataLoader yielding dicts with ``'pixel_values'``, ``'labels'`` and
            optionally ``'pixel_mask'``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: AdamW learning rate.
        max_iters: Stop each epoch after batch index ``max_iters`` (``max_iters + 1`` batches)
            for a quick run; ``None`` trains on full epochs.
    """
    # Train on the model's own device instead of relying on a global `device`
    device = next(model.parameters()).device
    model.train()

    optimizer = AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_steps = 0
        t = tqdm(enumerate(train_loader), total=len(train_loader), desc="Loss: ---")
        for i, batch in t:
            pixel_values = batch['pixel_values'].to(device)
            pixel_mask = batch.get('pixel_mask')
            if pixel_mask is not None:
                pixel_mask = pixel_mask.to(device)
            targets = [
                {'boxes': label['boxes'].to(device), 'class_labels': label['class_labels'].to(device)}
                for label in batch['labels']
            ]

            loss = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_steps += 1

            if i % 50 == 0:
                avg_loss = total_loss / num_steps
                print(f"\n  ##### Iteration {i}, Average Loss: {avg_loss:.4f}")
            t.set_description(f"Loss: {loss.item():.4f}")
            t.refresh()

            # Cap the iterations for a quick hands-on run
            if max_iters is not None and i >= max_iters:
                break

        # Average over the batches actually processed (the loop may stop early)
        print(f"\n----- Epoch {epoch}, loss: {total_loss / max(num_steps, 1)}")


# Use the GPU when available, then train a freshly built custom detector on it
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

detector = DetrForObjectDetection(detr_config)
model = detector.to(device)
train_model(model, train_loader, num_epochs=1, learning_rate=1e-4)