<a href="https://colab.research.google.com/github/pgurazada/explore-dinov2/blob/main/emotion_classification_dinov2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

In this exercise, we train a classification model that uses the feature extraction capabilities of [Dino V2](https://arxiv.org/pdf/2304.07193.pdf).

# Imports

In [1]:
!pip install -q datasets flaml

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/521.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.6/521.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m491.5/521.2 kB[0m [31m7.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.2/295.2 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import torch
import io

import numpy as np
import torchvision.transforms as T

from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from flaml import AutoML

In [3]:
!pip show torch torchvision

Name: torch
Version: 2.1.0+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, triton, typing-extensions
Required-by: fastai, torchaudio, torchdata, torchtext, torchvision
---
Name: torchvision
Version: 0.16.0+cu118
Summary: image and video datasets and models for torch deep learning
Home-page: https://github.com/pytorch/vision
Author: PyTorch Core Team
Author-email: soumith@pytorch.org
License: BSD
Location: /usr/local/lib/python3.10/dist-packages
Requires: numpy, pillow, requests, torch
Required-by: fastai


# Data

In [4]:
emotion_ds = load_dataset("sxdave/emotion_detection")

Downloading readme:   0%|          | 0.00/116 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/157 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/45 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/157 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/7.03k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.32k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.98k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.63k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.22k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.01k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.26k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.67k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.33k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.51k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.37k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.03k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.51k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.00k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.85k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.86k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.37k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.89k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.72k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.01k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.22k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.67k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.30k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.70k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.94k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.48k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.18k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.31k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.96k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.94k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.71k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.39k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.43k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.24k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.11k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.01k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.33k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.09k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.60k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.08k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.23k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.96k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.61k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.40k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.66k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.99k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.17k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.49k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.17k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.85k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.31k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.08k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.50k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.38k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.93k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.55k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.67k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.88k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.37k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.84k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.29k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.95k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.61k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.32k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.78k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.29k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.24k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.21k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.92k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.33k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.31k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.32k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.37k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.80k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.05k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.68k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.53k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.91k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.97k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.48k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.66k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.54k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.44k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.86k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.49k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.66k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.18k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.65k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.26k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.99k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.82k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/45 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.29k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.61k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.44k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.80k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.93k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.43k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.82k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.86k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.70k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.74k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.35k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.89k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.58k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.22k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.41k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.22k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.98k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.79k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.59k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.43k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.97k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.30k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.87k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.70k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.93k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.76k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.36k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.92k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/23 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.26k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.53k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.04k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.82k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.04k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.08k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.99k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.01k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.54k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.18k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.29k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.48k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.70k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.32k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.12k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.33k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [5]:
train_data = list(emotion_ds['train'])

In [6]:
test_data = list(emotion_ds['test'])

# Model for Embeddings

In [7]:
dinov2_vit14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth
100%|██████████| 330M/330M [00:01<00:00, 248MB/s]


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
dinov2_vit14.to(device)

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [9]:
transform_image = T.Compose(
    [
        T.ToTensor(),
        T.Resize(244, antialias=True),
        T.CenterCrop(224),
        T.Normalize([0.5], [0.5])
    ]
)

In [10]:
def load_image(img: str) -> torch.Tensor:
    """
    Load an image and return a tensor that can be used as an input to DINOv2.
    """

    transformed_img = transform_image(img)[:3].unsqueeze(0)

    return transformed_img

In [11]:
def compute_embeddings(files: list) -> dict:
    """
    Create an index that contains all of the images in the specified list of files.
    """
    all_embeddings = {}

    with torch.no_grad():
      for i, file in enumerate(tqdm(files)):
        embeddings = dinov2_vit14(load_image(file['image']).to(device))

        all_embeddings[i] = np.array(embeddings[0].cpu().numpy()).reshape(1, -1).tolist()

    return all_embeddings

In [12]:
train_embeddings = compute_embeddings(train_data)

100%|██████████| 157/157 [00:12<00:00, 12.55it/s]


In [13]:
test_embeddings = compute_embeddings(test_data)

100%|██████████| 23/23 [00:00<00:00, 51.04it/s]


In [14]:
ytrain = np.array([file['label'] for file in train_data])
train_embedding_list = list(train_embeddings.values())
Xtrain = np.array(train_embedding_list).reshape(-1, dinov2_vit14.embed_dim)

ytest = np.array([file['label'] for file in test_data])
test_embedding_list = list(test_embeddings.values())
Xtest = np.array(test_embedding_list).reshape(-1, dinov2_vit14.embed_dim)

# Model for Classification

In [15]:
automl = AutoML()

In [16]:
automl.fit(
    X_train=Xtrain, y_train=ytrain,
    time_budget=240,
    log_file_name='emotions.log',
    task='classification',
    metric='accuracy',
    split_ratio=.3
)

[flaml.automl.logger: 12-13 14:03:18] {1679} INFO - task = classification
[flaml.automl.logger: 12-13 14:03:18] {1690} INFO - Evaluation method: cv
[flaml.automl.logger: 12-13 14:03:18] {1788} INFO - Minimizing error metric: 1-accuracy
[flaml.automl.logger: 12-13 14:03:18] {1900} INFO - List of ML learners in AutoML Run: ['lgbm', 'rf', 'xgboost', 'extra_tree', 'xgb_limitdepth', 'lrl1']
[flaml.automl.logger: 12-13 14:03:18] {2218} INFO - iteration 0, current learner lgbm
[flaml.automl.logger: 12-13 14:03:18] {2344} INFO - Estimated sufficient time budget=2820s. Estimated necessary time budget=65s.
[flaml.automl.logger: 12-13 14:03:18] {2391} INFO -  at 0.3s,	estimator lgbm's best error=0.4145,	best estimator lgbm's best error=0.4145
[flaml.automl.logger: 12-13 14:03:18] {2218} INFO - iteration 1, current learner lgbm
[flaml.automl.logger: 12-13 14:03:18] {2391} INFO -  at 0.5s,	estimator lgbm's best error=0.3831,	best estimator lgbm's best error=0.3831
[flaml.automl.logger: 12-13 14:03:

INFO:flaml.tune.searcher.blendsearch:No low-cost partial config given to the search algorithm. For cost-frugal search, consider providing low-cost values for cost-related hps via 'low_cost_partial_config'. More info can be found at https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune


[flaml.automl.logger: 12-13 14:07:18] {2391} INFO -  at 240.0s,	estimator lrl1's best error=0.2188,	best estimator lrl1's best error=0.2188




[flaml.automl.logger: 12-13 14:07:18] {2627} INFO - retrain lrl1 for 0.4s
[flaml.automl.logger: 12-13 14:07:18] {2630} INFO - retrained model: LogisticRegression(n_jobs=-1, penalty='l1', solver='saga')
[flaml.automl.logger: 12-13 14:07:18] {1930} INFO - fit succeeded
[flaml.automl.logger: 12-13 14:07:18] {1931} INFO - Time taken to find the best model: 240.04762935638428




In [17]:
print('Best ML leaner:', automl.best_estimator)
print('Best hyperparmeter config:', automl.best_config)
print(f'Best accuracy on validation data: {(1-automl.best_loss):.3f}')
print(f'Training duration of best run: {(automl.best_config_train_time):.4f} s')

Best ML leaner: lrl1
Best hyperparmeter config: {'C': 1.0}
Best accuracy on validation data: 0.781
Training duration of best run: 0.4214 s


In [18]:
automl.model.estimator

In [19]:
ypred = automl.predict(Xtest)

In [20]:
(ypred == ytest).mean()

0.6521739130434783