### Finetune DETR to detect female-ish faces in paintings

In [None]:
# Install or upgrade the packages this notebook depends on. scipy and pycocotools
# are not imported directly below — presumably they are needed by transformers
# or the Cocordiais helper (TODO confirm).
! pip install --upgrade scipy transformers datasets huggingface_hub pytorch-lightning pycocotools

In [None]:
# Log in to the Hugging Face Hub through the command-line client.
!huggingface-cli login

In [None]:
# Log in to the Hugging Face Hub from inside the notebook. The previous cell
# runs `huggingface-cli login` — presumably only one of the two is needed.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import datasets
import matplotlib.pyplot as plt
import os
import pytorch_lightning as pl
import torch
import torchvision.transforms as T

from torch.utils.data import DataLoader
from transformers import DetrConfig, DetrForObjectDetection, DetrImageProcessor

from Cocordiais import CocordiaisData
from PIL import Image as PImage

### Load dataset from HF and convert to COCO format

In [None]:
# Hugging Face Hub identifiers used throughout the notebook.
DETR_MODEL = "facebook/detr-resnet-50"      # checkpoint the processor and model start from
HF_DATASET = "thiagohersan/cordiais-faces"  # dataset passed to datasets.load_dataset
HF_MODEL = "thiagohersan/detr-cordiais"     # repo the fine-tuned model and processor are pushed to

### Create DataLoaders

In [None]:
# Image processor loaded from the DETR checkpoint, with both edge limits set to 800.
detr_processor = DetrImageProcessor.from_pretrained(
  DETR_MODEL, size=dict(shortest_edge=800, longest_edge=800)
)

# Project helper that provides the dataset transform and the batch collate function.
cocordiais_data = CocordiaisData(detr_processor)

# Take the Hub dataset's "train" split and carve a seeded 80/20 train/test split out of it.
hf_dataset = datasets.load_dataset(HF_DATASET)["train"].train_test_split(
  test_size=0.2, shuffle=True, seed=101010
)

# Attach the helper's transform to each split, with its `train` flag set to match.
train_dataset, test_dataset = (
  hf_dataset[split_name].with_transform(cocordiais_data.to_coco(train=is_train))
  for split_name, is_train in (("train", True), ("test", False))
)

print("Number of examples:\n  Train: %s\n  Test: %s" % (len(train_dataset), len(test_dataset)))

In [None]:
# Both loaders assemble batches with the helper's collate function.
# Training: batches of 12, shuffled. Validation: batches of 4, in dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True, collate_fn=cocordiais_data.collate_batch)
val_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=cocordiais_data.collate_batch)

In [None]:
# verify
# Each example is a dict; this unpacking relies on its values arriving in the
# order (pixel_values, <ignored>, target) — assumes the transform's key order, TODO confirm.
pixel_values, _, target = train_dataset[0].values()
print(pixel_values.shape)
print(target)

# Pull one collated batch and convert its first image tensor to a PIL image.
batch = next(iter(train_dataloader))
print(batch.keys())
pimg = T.ToPILImage()(batch["pixel_values"][0])
print(pimg.size)
# Bare expression: as the cell's last value it is displayed by the notebook.
pimg


### Train with PyTorchLightning

In [None]:
# Category names as declared in the dataset's feature schema.
labels = train_dataset.features["objects"].feature["category"].names
# Forward (index -> name) and reverse (name -> index) lookup tables.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}

class Detr(pl.LightningModule):
  """LightningModule that fine-tunes a pretrained DETR object detector.

  Wraps `DetrForObjectDetection`, logs the total loss plus every entry of the
  model's loss dict for training and validation steps, and serves the
  dataloaders it was constructed with.
  """

  def __init__(self, dl_train, dl_val, lr, lr_backbone, weight_decay):
    """Load the pretrained model and store hyper-parameters and loaders.

    Args:
      dl_train: training DataLoader, returned by `train_dataloader()`.
      dl_val: validation DataLoader, returned by `val_dataloader()`.
      lr: learning rate for parameters whose name does not contain "backbone".
      lr_backbone: learning rate for parameters whose name contains "backbone".
      weight_decay: weight decay passed to AdamW.
    """
    super().__init__()

    # replace COCO classification head with custom head:
    # the label count comes from the notebook-level id2label mapping
    self.model = DetrForObjectDetection.from_pretrained(
      DETR_MODEL,
      revision="no_timm",
      num_labels=len(id2label),
      num_queries=16,
      ignore_mismatched_sizes=True,
    )

    # optimizer settings
    self.lr, self.lr_backbone, self.weight_decay = lr, lr_backbone, weight_decay

    # loaders, and the batch sizes reported alongside every logged value
    self.dataloader_train, self.dataloader_val = dl_train, dl_val
    self.batch_size_train, self.batch_size_val = dl_train.batch_size, dl_val.batch_size

  def forward(self, pixel_values, pixel_mask):
    """Run the wrapped model without labels and return its outputs."""
    return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

  def common_step(self, batch, batch_idx):
    """Compute the loss for one batch; shared by training and validation.

    Returns:
      A `(loss, loss_dict)` pair taken from the model outputs.
    """
    # move every tensor of every per-image target onto this module's device
    targets = []
    for target in batch["labels"]:
      targets.append({key: tensor.to(self.device) for key, tensor in target.items()})

    result = self.model(
      pixel_values=batch["pixel_values"],
      pixel_mask=batch["pixel_mask"],
      labels=targets,
    )
    return result.loss, result.loss_dict

  def _log_step_losses(self, total_name, prefix, loss, loss_dict, batch_size):
    """Log the total loss as `total_name` and each component as `prefix + key`."""
    self.log(total_name, loss, batch_size=batch_size)
    for component, value in loss_dict.items():
      self.log(prefix + component, value.item(), batch_size=batch_size)

  def training_step(self, batch, batch_idx):
    """Compute, log and return the training loss for one batch."""
    loss, loss_dict = self.common_step(batch, batch_idx)
    self._log_step_losses("training_loss", "train_", loss, loss_dict, self.batch_size_train)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute, log and return the validation loss for one batch."""
    loss, loss_dict = self.common_step(batch, batch_idx)
    self._log_step_losses("validation_loss", "validation_", loss, loss_dict, self.batch_size_val)
    return loss

  def configure_optimizers(self):
    """Build AdamW with a separate learning rate for backbone parameters."""
    backbone_params, other_params = [], []
    for name, param in self.named_parameters():
      if not param.requires_grad:
        continue
      if "backbone" in name:
        backbone_params.append(param)
      else:
        other_params.append(param)

    # the first group falls back to the optimizer-wide lr; the second overrides it
    groups = [
      {"params": other_params},
      {"params": backbone_params, "lr": self.lr_backbone},
    ]
    return torch.optim.AdamW(groups, lr=self.lr, weight_decay=self.weight_decay)

  def train_dataloader(self):
    """Return the training DataLoader supplied at construction."""
    return self.dataloader_train

  def val_dataloader(self):
    """Return the validation DataLoader supplied at construction."""
    return self.dataloader_val


In [None]:
# verify the outputs
# Build the Lightning module, push the sample batch through a forward pass,
# and print the logits shape as a quick sanity check.
model = Detr(
  dl_train=train_dataloader,
  dl_val=val_dataloader,
  lr=1e-4,
  lr_backbone=1e-5,
  weight_decay=1e-4,
)
outputs = model(pixel_mask=batch["pixel_mask"], pixel_values=batch["pixel_values"])
print(outputs.logits.shape)

In [None]:
# Delete any existing lightning_logs directory (e.g. logs left by earlier runs).
!rm -rf lightning_logs
# Load the TensorBoard notebook extension and point it at that log directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Train for up to 48 epochs with gradient clipping at 0.1; the accelerator is chosen automatically.
# No loaders are passed to fit(): the module exposes its own train_dataloader()/val_dataloader().
trainer = pl.Trainer(
  accelerator="auto",
  gradient_clip_val=0.1,
  max_epochs=48,
)
trainer.fit(model)

In [None]:
# Upload the fine-tuned weights (the wrapped HF model, not the Lightning module)
# and the image processor to the same private Hub repo.
model.model.push_to_hub(HF_MODEL, private=True)
detr_processor.push_to_hub(HF_MODEL, private=True)

### Reload model

In [None]:
# Prefer the GPU when torch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Class-index -> name mapping handed to the reloaded model.
# NOTE(review): hard-coded; assumed to match the label order used during training — confirm.
id2label = {0: "female", 1: "not-female"}

In [None]:
# Reload the fine-tuned model and its processor from the Hub repo.
# Note: this rebinds `model` from the Lightning wrapper to the bare HF model.
model = DetrForObjectDetection.from_pretrained(HF_MODEL, id2label=id2label)
processor = DetrImageProcessor.from_pretrained(HF_MODEL)
# Move the model to the selected device (as the cell's last expression, the result is displayed).
model.to(device)

### Run on test data

In [None]:
# RGB triples (0-1 floats) used, in rotation, as box outline colours.
COLORS = [
  [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]
]

def plot_results(pil_img, scores, labels, boxes, id2label):
  """Draw detections on top of an image and show the figure.

  Args:
    pil_img: image handed to `plt.imshow`.
    scores: per-detection scores; must support `.tolist()`.
    labels: per-detection class ids; must support `.tolist()`.
    boxes: per-detection boxes as (xmin, ymin, xmax, ymax); must support `.tolist()`.
    id2label: mapping from class id to the name shown in each caption.
  """
  from itertools import cycle

  plt.figure(figsize=(16,10))
  plt.imshow(pil_img)
  ax = plt.gca()
  # cycle() repeats the palette indefinitely, so zip() is bounded only by the
  # detections; a fixed-length colour list would silently drop boxes beyond it.
  detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), cycle(COLORS))
  for score, label, (xmin, ymin, xmax, ymax), color in detections:
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
    text = f'{id2label[label]}: {score:0.2f}'
    ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
  plt.axis('off')
  plt.show()

In [None]:
# Each example is a dict, so unpacking it directly would yield its keys rather
# than the tensors. Unpack the values instead, in the same three-entry layout
# the training-set check uses: (pixel_values, <ignored>, target).
pixel_values, _, target = test_dataset[0].values()
# Add a leading batch dimension and move the tensor to the inference device.
pixel_values = pixel_values.unsqueeze(0).to(device)
print(pixel_values.shape)

In [None]:
# Make sure the model is on the inference device, then run a forward pass
# without tracking gradients.
model.to(device)
with torch.no_grad():
  # No pixel mask is supplied for this single-image batch.
  outputs = model(pixel_values=pixel_values, pixel_mask=None)

In [None]:
image_id = target["image_id"].item()
# pixel_values carries a leading batch dimension (added with unsqueeze(0)), but
# ToPILImage only accepts 2-D/3-D tensors, so drop that dimension first.
image = T.ToPILImage()(pixel_values.squeeze(0))

# PIL reports (width, height); post-processing expects (height, width) per image.
width, height = image.size
postprocessed_outputs = processor.post_process_object_detection(
  outputs,
  target_sizes=[(height, width)],
  threshold=0.5
)

# One result dict per input image; only one image was passed in.
results = postprocessed_outputs[0]
plot_results(image, results["scores"], results["labels"], results["boxes"], id2label)