This repository contains Python code to retrieve Steam games with similar store banners, using Facebook's DINO.
Image similarity is assessed by the cosine similarity between image features encoded by DINO.
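For concreteness, here is a minimal sketch of that similarity computation, assuming the banner features have already been stacked into a single tensor; the tensor and index names are placeholders, not variables from this repository.

```python
import torch

def cosine_similarity_matrix(features: torch.Tensor) -> torch.Tensor:
    # features: one row per banner, e.g. shape (num_games, embedding_dim).
    # L2-normalize each row so that a dot product equals the cosine similarity.
    normed = torch.nn.functional.normalize(features, dim=-1)
    return normed @ normed.T

# Hypothetical usage: find the 10 banners most similar to the banner at index 0.
# sims = cosine_similarity_matrix(features)
# top_10 = sims[0].topk(k=11).indices[1:]  # skip the query banner itself
```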
DINO is a method to train self-supervised models, especially well-suited for Vision Transformers (ViT). Model checkpoints were pre-trained on ImageNet-1k (1.28M images with 1000 classes) without labels.
In this repository, image features are extracted:
- following different strategies,
- and based on different models (`ViT-S/16`, `ViT-S/8`, `ViT-B/16`, `ViT-B/8`), which can all be loaded as shown below.
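All four backbones are exposed by Facebook's DINO repository through torch.hub; the sketch below uses the entry-point names listed in the DINO README.

```python
import torch

# Hub entry points provided by Facebook's DINO repository.
DINO_MODELS = {
    "S16": "dino_vits16",
    "S8": "dino_vits8",
    "B16": "dino_vitb16",
    "B8": "dino_vitb8",
}

# For instance, load the ViT-B/16 backbone.
model = torch.hub.load("facebookresearch/dino:main", DINO_MODELS["B16"])
model.eval()
```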
The data is identical to that used in steam-CLIP.
It consists of vertical Steam banners (300x450 resolution), available for 29982 out of 48792 games, i.e. 61.4% of games.
Images are resized to 224x224 resolution and available in an archive (703 MB) as a release in this repository.
However, DINO has its own pre-processing pipeline, as in eval_linear.py and eval_knn.py:
- resize to 256 resolution, i.e. the smallest edge of the image will match this number,
- center-crop at 224 resolution, i.e. a square crop is made,
- normalize intensity.
```python
from torchvision import transforms as pth_transforms

preprocess = pth_transforms.Compose(
    [
        pth_transforms.Resize(
            256, interpolation=pth_transforms.InterpolationMode.BICUBIC
        ),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
```
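For completeness, here is a hedged sketch of how this pipeline would feed a single banner into a DINO model; the file path is a placeholder and `model` is assumed to be a DINO backbone loaded as shown above.

```python
from PIL import Image
import torch

image = Image.open("banner.jpg").convert("RGB")  # placeholder path
batch = preprocess(image).unsqueeze(0)           # shape (1, 3, 224, 224)

with torch.no_grad():
    features = model(batch)  # [CLS] features, e.g. shape (1, 768) for ViT-B
```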
Therefore, it would have been better:
- either to use 256 resolution for the input,
- or to use 224 resolution (as I did) but without resizing-then-center-cropping when calling DINO.
Choices of pre-processing are discussed in this GitHub issue of DINOv2. The second approach (no resize-then-center-crop) is used in eval_copy_detection.py:
```python
transform = pth_transforms.Compose([
    pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
```
Please note that the call to `Resize()` here leads to a square output, losing the aspect ratio of the original image.
This is also the case for eval_image_retrieval.py:
```python
transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
```
Run `match_steam_banners_with_DINO.ipynb`.
Results were obtained in different settings identified by a suffix, e.g. `ComplexB8`, where:
- `Simple` stands for the simple feature extraction, similar to the code in eval_knn.py,
- `Complex` stands for the complex feature extraction, similar to the code in eval_linear.py (both strategies are sketched after this list),
- `B8` stands for ViT-B/8: the `Base` architecture with patch resolution `8`.
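As a rough sketch of the difference between the two strategies, adapted from DINO's eval_knn.py and eval_linear.py; the `n_last_blocks` and `use_avgpool` defaults below are assumptions for illustration, not necessarily the exact settings used for each backbone in this repository.

```python
import torch

@torch.no_grad()
def extract_features_simple(model, batch):
    # "Simple": the [CLS] token returned by a plain forward pass, as in eval_knn.py.
    return model(batch)

@torch.no_grad()
def extract_features_complex(model, batch, n_last_blocks=4, use_avgpool=False):
    # "Complex": concatenate the [CLS] tokens of the last n blocks, and optionally
    # the average of the patch tokens of the last block, as in eval_linear.py.
    intermediate = model.get_intermediate_layers(batch, n_last_blocks)
    output = torch.cat([layer[:, 0] for layer in intermediate], dim=-1)
    if use_avgpool:
        avg_patch_tokens = torch.mean(intermediate[-1][:, 1:], dim=1)
        output = torch.cat((output, avg_patch_tokens), dim=-1)
    return output
```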
If we look for trucks in banners similar to Euro Truck Simulator 2's banner, results are:
- similar for `Simple` and `Complex`,
- more satisfactory with `B16` and `S8`, compared to `B8` or `S16`,
- slightly more satisfactory with `B16` compared to `S8`.
Qualitatively, I would rank the strategies, starting with the most satisfactory one:
1. `SimpleB16`
2. `ComplexB16`
3. `ComplexS8`
4. `SimpleS8`
5. `SimpleB8`
6. `ComplexB8`
7. `ComplexS16`
8. `SimpleS16`
In summary: `ViT-S/16` < `ViT-B/8` < `ViT-S/8` < `ViT-B/16`.
The ranking is compatible with the performance observed in the paper for the k-NN task, `ViT-S/16` < `ViT-B/16` < `ViT-B/8` < `ViT-S/8`, with the exception that `B/16` seems to be the best performing model in our few test cases.
NB: `B/8` is expected to under-perform w.r.t. `S/8`, as its hyperparameters could have been further optimized.
The following results are obtained with `ComplexB16`.
Settings are compared below for the same game, Euro Truck Simulator 2, sorted from the best to the worst output:
- `SimpleB16`
- `ComplexB16`
- `ComplexS8`
- `SimpleS8`
- `SimpleB8`
- `ComplexB8`
- `ComplexS16`
- `SimpleS16`
The following result is obtained with `ComplexB8`.
Other strategies for the creation of the image embedding would include:
- the concatenation of features extracted at multiple scales,
- the concatenation of the [CLS] token with GeM-pooled patch tokens, as for copy detection (sketched below).
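As an illustration of the second idea, here is a hedged sketch of GeM pooling applied to patch tokens and concatenated with the [CLS] token; the exponent `p` is a free parameter here, not necessarily the value used in DINO's copy-detection script.

```python
import torch

def gem_pool(patch_tokens: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    # Generalized mean (GeM) pooling over the patch dimension.
    # patch_tokens: shape (batch, num_patches, dim).
    return patch_tokens.clamp(min=eps).pow(p).mean(dim=1).pow(1.0 / p)

def cls_plus_gem(tokens: torch.Tensor) -> torch.Tensor:
    # tokens: full token sequence of shape (batch, 1 + num_patches, dim), [CLS] first.
    cls_token = tokens[:, 0]
    pooled_patches = gem_pool(tokens[:, 1:])
    return torch.cat([cls_token, pooled_patches], dim=-1)
```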
- Facebook's DINO:
- My usage of OpenAI's CLIP:
  - steam-CLIP: retrieve games with similar banners, using OpenAI's CLIP (resolution 224),
  - steam-image-search: retrieve games using natural language queries,
  - heroku-flask-api: serve the matching results through an API built with Flask on Heroku.