### Fine-tune DETR to detect female-ish faces in paintings

In [None]:
! pip install --upgrade scipy transformers datasets huggingface_hub pytorch-lightning pycocotools

In [None]:
# NOTE(review): this duplicates the `notebook_login()` cell that follows. A `!`
# shell command may not be able to prompt for the token interactively in every
# notebook frontend — confirm; if it cannot, rely on `notebook_login()` alone.
!huggingface-cli login

In [None]:
# Authenticate against the Hugging Face Hub through the interactive notebook
# widget; the push_to_hub step later in this notebook needs a valid token.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import datasets
import matplotlib.pyplot as plt
import os
import pytorch_lightning as pl
import torch
import torchvision.transforms as T

from torch.utils.data import DataLoader
from transformers import AutoModelForObjectDetection, DetrImageProcessor, Trainer, TrainingArguments

from Cocordiais import CocordiaisData
from PIL import Image as PImage

In [None]:
# Hugging Face Hub identifiers:
#   DETR_MODEL - pretrained checkpoint to fine-tune from
#   HF_DATASET - annotated dataset used for training/testing
#   HF_MODEL   - repo for the fine-tuned model (also the Trainer's output_dir)
DETR_MODEL = "facebook/detr-resnet-50"
HF_DATASET = "thiagohersan/cordiais-faces"
HF_MODEL = "thiagohersan/detr-cordiais-autotrain"

In [None]:
# Image processor of the base checkpoint, configured with shortest_edge and
# longest_edge both set to 800.
detr_processor = DetrImageProcessor.from_pretrained(
  DETR_MODEL,
  size={"shortest_edge": 800, "longest_edge": 800},
)

# Project helper (Cocordiais module) wrapping the processor; `to_coco(train=...)`
# yields the transform each split applies lazily via `with_transform`.
cocordiais_data = CocordiaisData(detr_processor)

# Only the Hub "train" split is used; split it 80/20 with a fixed seed so the
# held-out set is reproducible across runs.
hf_dataset = datasets.load_dataset(HF_DATASET)["train"].train_test_split(
  test_size=0.2,
  shuffle=True,
  seed=101010,
)

train_dataset, test_dataset = (
  hf_dataset[split_name].with_transform(cocordiais_data.to_coco(train=is_train))
  for split_name, is_train in (("train", True), ("test", False))
)

print(f"Number of examples:\n  Train: {len(train_dataset)}\n  Test: {len(test_dataset)}")

In [None]:
# Class names of the `objects.category` feature; position i in the list is class id i.
labels = train_dataset.features["objects"].feature["category"].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Base DETR checkpoint (pinned to its `no_timm` revision) re-configured for our
# label set and 16 object queries. `ignore_mismatched_sizes=True` lets weights whose
# shapes no longer match the new config (e.g. the class head, query embeddings)
# be freshly initialized instead of raising.
model = AutoModelForObjectDetection.from_pretrained(
  DETR_MODEL,
  revision="no_timm",
  num_labels=len(id2label),
  id2label=id2label,
  label2id=label2id,
  num_queries=16,
  ignore_mismatched_sizes=True,
)

In [None]:
# Training configuration. fp16 mixed precision requires CUDA; enabling it
# unconditionally makes TrainingArguments raise on a CPU-only machine, so it is
# only turned on when a GPU is visible.
training_args = TrainingArguments(
  output_dir=HF_MODEL,
  per_device_train_batch_size=12,
  per_device_eval_batch_size=4,
  num_train_epochs=48,
  fp16=torch.cuda.is_available(),
  save_strategy="epoch",
  save_total_limit=2,
  logging_strategy="epoch",
  learning_rate=1e-5,
  weight_decay=1e-4,
  # Keep all dataset columns: the lazy `with_transform` transform presumably reads
  # raw columns (images/objects) that the model's forward() does not accept.
  remove_unused_columns=False,
)
# NOTE(review): no evaluation strategy is configured, so `eval_dataset` is only
# used if `trainer.evaluate()` is called explicitly.

trainer = Trainer(
  model=model,
  args=training_args,
  data_collator=cocordiais_data.collate_batch,
  train_dataset=train_dataset,
  eval_dataset=test_dataset,
  # Saved alongside model checkpoints; newer transformers releases rename this
  # argument to `processing_class`.
  tokenizer=detr_processor,
)

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Run fine-tuning; per training_args, checkpoints are saved to HF_MODEL (the
# output_dir) every epoch, keeping at most two.
trainer.train()

In [None]:
# `Trainer.push_to_hub` takes a commit message (not a repo id) as its first argument
# and forwards extra kwargs to model-card creation, where `private` is not accepted.
# Push the fine-tuned model and its processor straight to the named repo instead;
# the reload cells below call `from_pretrained(HF_MODEL)` for both.
trainer.model.push_to_hub(HF_MODEL, private=True)
detr_processor.push_to_hub(HF_MODEL, private=True)

### Reload model

In [None]:
# Prefer the GPU for inference when one is visible; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Display names per class id; these should match the ids used during training.
id2label = {
  0: "female",
  1: "not-female",
}

In [None]:
# Reload the fine-tuned model and its processor from the Hub repo.
# `DetrForObjectDetection` was never imported in this notebook (NameError); the
# imported Auto class resolves to the right DETR model class from the checkpoint's config.
model = AutoModelForObjectDetection.from_pretrained(HF_MODEL, id2label=id2label)
processor = DetrImageProcessor.from_pretrained(HF_MODEL)
model.to(device)

### Run on test data

In [None]:
# RGB colors (0-1 floats) for detection boxes, reused cyclically.
COLORS = [
  [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]
]

def plot_results(pil_img, scores, labels, boxes, id2label, out_path="out.jpg"):
  """Draw detection boxes with "label: score" captions over an image, save it, and show it.

  Args:
    pil_img: image to draw on (anything `plt.imshow` accepts).
    scores, labels, boxes: per-detection values exposing `.tolist()`; each box is
      (xmin, ymin, xmax, ymax) in the image's coordinate space.
    id2label: mapping from integer class id to display name.
    out_path: file the rendered figure is written to (default "out.jpg").
  """
  fig = plt.figure(figsize=(16, 10))
  plt.imshow(pil_img)
  ax = plt.gca()
  detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
  for i, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
    # Cycle through the palette so any number of boxes gets a color.
    color = COLORS[i % len(COLORS)]
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
    text = f'{id2label[label]}: {score:0.2f}'
    ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
  plt.axis('off')
  # Save before show(): show() may close/clear the figure (e.g. non-interactive
  # backends), after which savefig would write a blank image.
  fig.savefig(out_path)
  plt.show()

In [None]:
# Take the first held-out example and add a batch dimension for the model.
# NOTE(review): `test_dataset` is a Hugging Face `datasets.Dataset` with a
# `with_transform` transform, so indexing it returns a dict (column -> value), not a
# `(pixel_values, target)` tuple — unpacking a dict iterates its keys. This looks
# carried over from a torchvision CocoDetection-style dataset. Confirm what
# `CocordiaisData.to_coco` returns (presumably a "pixel_values" entry plus a
# labels/target entry) and index by key instead.
pixel_values, target = test_dataset[0]
pixel_values = pixel_values.unsqueeze(0).to(device)
print(pixel_values.shape)

In [None]:
# Single forward pass without autograd tracking; no pixel mask for the lone image.
model = model.to(device)
with torch.no_grad():
  outputs = model(pixel_mask=None, pixel_values=pixel_values)

In [None]:
# Reload the original, un-transformed image so predictions can be drawn over it.
# NOTE(review): neither `test_dataset.coco` nor `COCORDIAIS_PATH` exists in the
# visible notebook — `test_dataset` is a `datasets.Dataset`, which has no `.coco`
# attribute. This also looks carried over from a torchvision CocoDetection setup;
# the raw image presumably needs to come from `hf_dataset["test"][0]` instead —
# verify the dataset's column names before changing it.
image_id = target["image_id"].item()
image = test_dataset.coco.loadImgs(image_id)[0]
image = PImage.open(os.path.join(COCORDIAIS_PATH, "test", image["file_name"]))

# Convert the raw model outputs into boxes scaled to this image's size, keeping
# detections above the 0.5 score threshold, then draw them.
width, height = image.size
postprocessed_outputs = processor.post_process_object_detection(
  outputs,
  threshold=0.5,
  target_sizes=[(height, width)],  # (height, width) per image in the batch
)

results = postprocessed_outputs[0]
plot_results(
  image,
  scores=results["scores"],
  labels=results["labels"],
  boxes=results["boxes"],
  id2label=id2label,
)