### Finetune DETR to detect female-ish faces in paintings

In [None]:
! pip install transformers datasets huggingface_hub pytorch_lightning

In [None]:
!huggingface-cli login

In [None]:
import datasets
import matplotlib.pyplot as plt
import os
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as T

from transformers import AutoModelForObjectDetection, DetrImageProcessor, Trainer, TrainingArguments

from PIL import Image as PImage

In [None]:
# Hugging Face Hub identifiers: the pretrained DETR checkpoint to start from,
# the dataset used for fine-tuning, and the model repository reloaded later
# for inference.
DETR_CHECKPOINT = "facebook/detr-resnet-50"
HF_DATASET = "thiagohersan/cordiais-faces"
HF_MODEL = "thiagohersan/detr-cordiais"

# Image processor taken from the base checkpoint, with both the shortest and
# the longest edge of its size setting fixed at 800.
detr_processor = DetrImageProcessor.from_pretrained(
  DETR_CHECKPOINT,
  size=dict(shortest_edge=800, longest_edge=800),
)

In [None]:
def to_coco_annotation(image_id, category, area, bbox):
  """Build COCO-style annotation dicts for the objects of one image.

  Args:
    image_id: identifier of the image, copied into every annotation.
    category: iterable of category ids, one per object.
    area: sequence of object areas, indexed in parallel with `category`.
    bbox: sequence of bounding boxes, indexed in parallel with `category`;
      each box is converted to a list.

  Returns:
    A list with one dict per entry of `category`, holding the keys
    "image_id", "category_id", "iscrowd", "area" and "bbox".

  Raises:
    IndexError: if `area` or `bbox` (as lists) have fewer entries than
      `category`.
  """
  # The crowd flag is spelled "iscrowd" (all lowercase) in the COCO format;
  # a differently-cased key such as "isCrowd" is not the one consumers read.
  return [
    {
      "image_id": image_id,
      "category_id": category_id,
      "iscrowd": 0,
      "area": area[i],
      "bbox": list(bbox[i]),
    }
    for i, category_id in enumerate(category)
  ]

In [None]:
def detr_annotate_augment_process(examples):
  """Turn a batch of dataset examples into DETR processor outputs.

  Reads the "image_id", "image" and "objects" columns of `examples`; every
  objects entry supplies "area", "bboxes" and "category". One COCO-style
  target is built per image, and the images plus targets are handed to
  `detr_processor`, whose result (requested as PyTorch tensors) is returned.
  """
  image_ids = examples["image_id"]
  # TODO: augment here
  paired = list(zip(examples["image"], examples["objects"]))
  images = [image for image, _ in paired]

  targets = []
  for id_, (_, objects) in zip(image_ids, paired):
    annotations = to_coco_annotation(
      id_, objects["category"], objects["area"], objects["bboxes"]
    )
    targets.append({"image_id": id_, "annotations": annotations})

  return detr_processor(images=images, annotations=targets, return_tensors="pt")

In [None]:
# Load the dataset and register detr_annotate_augment_process as its transform.
hf_dataset_ = datasets.load_dataset(HF_DATASET)
hf_dataset_ = hf_dataset_.with_transform(detr_annotate_augment_process)
# Hold out 20% of the "train" split as "test"; the fixed seed keeps the
# shuffled split reproducible.
hf_dataset = hf_dataset_["train"].train_test_split(seed=101010, shuffle=True, test_size=0.2)

In [None]:
# Class names are read from the "category" feature nested in the "objects" column.
labels = hf_dataset["train"].features["objects"].feature["category"].names
# Integer id -> class name, and the inverse class name -> integer id.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
def collate_fn(batch):
  """Collate dataset items into a single batch for the model.

  Args:
    batch: list of items, each a mapping with "pixel_values" and "labels".

  Returns:
    A dict with "pixel_values" and "pixel_mask" taken from the processor's
    padding step, and "labels" as the list of per-item labels (kept as a
    plain list, not stacked).
  """
  pixel_values = [item["pixel_values"] for item in batch]
  # `pad_and_create_pixel_mask` is deprecated in transformers in favour of
  # `pad`, which returns the padded pixel values together with a pixel mask.
  encoding = detr_processor.pad(pixel_values, return_tensors="pt")
  return {
    "pixel_values": encoding["pixel_values"],
    "pixel_mask": encoding["pixel_mask"],
    "labels": [item["labels"] for item in batch],
  }

In [None]:
# Start from the pretrained DETR checkpoint (its "no_timm" revision) and
# configure it for this dataset's classes with 16 object queries.
# ignore_mismatched_sizes is enabled, presumably because these settings differ
# from the shapes stored in the checkpoint.
model = AutoModelForObjectDetection.from_pretrained(
  DETR_CHECKPOINT,
  revision="no_timm",
  num_labels=len(id2label),
  num_queries=16,
  id2label=id2label,
  label2id=label2id,
  ignore_mismatched_sizes=True,
)

In [None]:
# Training configuration for the Hugging Face Trainer.
training_args = TrainingArguments(
  output_dir="detr-resnet-50_finetuned_cordiais",
  per_device_train_batch_size=4,
  num_train_epochs=32,
  # Mixed precision is only requested when CUDA is present, so the cell does
  # not fail on a machine without a GPU.
  fp16=torch.cuda.is_available(),
  save_steps=200,
  logging_steps=50,
  learning_rate=1e-5,
  weight_decay=1e-4,
  save_total_limit=2,
  # Keep every column: the dataset transform reads "image", "image_id" and
  # "objects", which are not model input names.
  remove_unused_columns=False,
  push_to_hub=True,
)

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Wire the model, training configuration, training split and collate function
# into a Trainer. The image processor is passed through the `tokenizer`
# argument. Only a training dataset is supplied.
trainer = Trainer(
  args=training_args,
  model=model,
  train_dataset=hf_dataset["train"],
  data_collator=collate_fn,
  tokenizer=detr_processor,
)

# Run the fine-tuning loop.
trainer.train()

In [None]:
# Push the trained model to the Hugging Face Hub.
trainer.push_to_hub()

### Reload model

In [None]:
# Class id -> display name used when reloading the model and labelling plots.
id2label = {0: "female", 1: "not-female"}
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Reload the fine-tuned model and its image processor from the Hub.
# AutoModelForObjectDetection is the detection class this notebook imports;
# DetrForObjectDetection is never imported, so using it raised a NameError.
model = AutoModelForObjectDetection.from_pretrained(HF_MODEL, id2label=id2label)
processor = DetrImageProcessor.from_pretrained(HF_MODEL)
model.to(device)

### Run on test data

In [None]:
# RGB palette (components in the 0-1 range) used for the box outlines.
COLORS = [
  [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]
]

def plot_results(pil_img, scores, labels, boxes, id2label):
  """Draw detections over an image, save the figure to "out.jpg", then show it.

  Args:
    pil_img: image passed to `plt.imshow`.
    scores: object exposing `.tolist()`, one score per detection.
    labels: object exposing `.tolist()`, one label value per detection.
    boxes: object exposing `.tolist()`; each box is unpacked as
      (xmin, ymin, xmax, ymax).
    id2label: mapping from label value to the name shown next to the box.
  """
  import itertools

  plt.figure(figsize=(16, 10))
  plt.imshow(pil_img)
  ax = plt.gca()
  # Cycle through the palette so every detection gets a color, however many
  # there are.
  detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), itertools.cycle(COLORS))
  for score, label, (xmin, ymin, xmax, ymax), color in detections:
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
    text = f'{id2label[label]}: {score:0.2f}'
    ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
  plt.axis('off')
  # Save before showing: once plt.show() returns, the current figure may
  # already have been released, and savefig would then write an empty canvas.
  plt.savefig("out.jpg")
  plt.show()

In [None]:
# NOTE(review): `test_dataset` is not defined anywhere in the visible notebook
# (the data loaded above is `hf_dataset`) — confirm where it should come from.
# Take the first test item as a (pixel_values, target) pair.
pixel_values, target = test_dataset[0]
# Add a leading dimension of size 1 and move the tensor to `device`.
pixel_values = pixel_values.unsqueeze(0).to(device)
print(pixel_values.shape)

In [None]:
# Run a forward pass on the prepared sample with gradient tracking disabled.
model.to(device)
with torch.no_grad():
  # No pixel mask is supplied for this single-image batch.
  outputs = model(pixel_values=pixel_values, pixel_mask=None)

In [None]:
# Look up the original image file for the sample and open it with PIL.
# NOTE(review): `test_dataset.coco` and `COCORDIAIS_PATH` are not defined in the
# visible notebook — confirm where they should come from before running this cell.
image_id = target["image_id"].item()
image = test_dataset.coco.loadImgs(image_id)[0]
image = PImage.open(os.path.join(COCORDIAIS_PATH, "test", image["file_name"]))

# PIL reports size as (width, height); the target size is passed as (height, width).
width, height = image.size
# Post-process the raw model outputs against the original image size, using a
# threshold of 0.5.
postprocessed_outputs = processor.post_process_object_detection(
  outputs,
  target_sizes=[(height, width)],
  threshold=0.5
)

# One image was processed, so take the first (only) result and plot it.
results = postprocessed_outputs[0]
plot_results(image, results["scores"], results["labels"], results["boxes"], id2label)