# Extract frame features using a pretrained timm model

> Module for extracting per-frame feature vectors from pre-extracted video frames using pretrained `timm` backbones

In [None]:
# | default_exp models.timm_feature_extractor

In [None]:
%reload_ext nb_black
%reload_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import sys

__root = "../../"
sys.path.append(__root)

In [None]:
# | export
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)

from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo,
)
import pandas as pd
import timm
from torch_snippets import *

In [None]:
# | export

# Per-channel normalisation statistics (identical value for each of the 3 channels).
mean = [0.45] * 3
std = [0.225] * 3

# Normalises the tensor stored under the "video" key of a sample dict.
# NOTE(review): NormalizeVideo presumably expects clips laid out as (C, T, H, W);
# extract_features_for_all_frames permutes frames into that layout before calling this.
mean_transform = ApplyTransformToKey(
    key="video",
    transform=Compose([NormalizeVideo(mean, std)]),
)


def extract_features_for_all_frames(
    model,
    frames_folder,
    features_folder,
    device,
    batch_size=64,
    skip=("477.frames", "407.frames"),
):
    """Extract per-frame features for every frames file in `frames_folder`.

    A pretrained timm backbone named `model` is created with its classifier
    head removed (`num_classes=0`), and its outputs are used as features.
    Each frames file is loaded and normalised with `mean_transform`. It is then
    pushed through the backbone in chunks of `batch_size`, and the
    concatenated outputs are written to
    `features_folder / "<stem>.features.tensor"`.

    Args:
        model: timm model name passed to `timm.create_model`.
        frames_folder: folder globbed for frames files. File names presumably
            look like "<int>.frames.<ext>"; files are processed in order of
            that integer id.
        features_folder: output folder; must support `/` (e.g. a `pathlib.Path`).
        device: device the backbone and every batch are moved to.
        batch_size: number of frames per forward pass.
        skip: frames-file stems (e.g. "477.frames") to leave unprocessed.
            The default reproduces the previously hard-coded exclusions.
            NOTE(review): why those two clips are excluded is not recorded here.

    Files whose output already exists are skipped, so a run can be resumed.
    Any exception while extracting or saving a single file is reported via
    `Warn` and processing continues with the next file.
    """
    # eval() is essential: a freshly created module is in training mode, where
    # BatchNorm uses per-batch statistics (and mutates its running stats even
    # under no_grad) and dropout is active, giving wrong, batch-dependent features.
    feature_extractor = (
        timm.create_model(model, pretrained=True, num_classes=0).to(device).eval()
    )
    skip = set(skip or ())
    frames_files = sorted(Glob(frames_folder), key=lambda x: int(stem(stem(x))))
    for frames_path in (tracker := track2(frames_files)):
        item = stem(frames_path)
        if item in skip:
            continue
        to = features_folder / f"{item}.features.tensor"
        if exists(to):
            # Already extracted on a previous run.
            continue
        # NOTE(review): frames are presumably stored as (T, C, H, W); they are
        # permuted to (C, T, H, W) for NormalizeVideo and back again afterwards.
        frames = loaddill(frames_path).permute(1, 0, 2, 3)
        frames = mean_transform({"video": frames})["video"].permute(1, 0, 2, 3)
        # Report only the shape: formatting the tensor itself would render its
        # values into the progress message on every iteration.
        tracker.send(f"processing {item} @ {tuple(frames.shape)}")
        with torch.no_grad():
            try:
                preds = torch.cat(
                    [
                        feature_extractor(frame_batch.to(device)).cpu()
                        for frame_batch in frames.split(batch_size)
                    ]
                )
                dumpdill(preds, to, silent=True)
            except Exception as e:
                # Best-effort batch job: report the failing clip and move on.
                Warn(f"{e} @ {item}")

Usage

```python
from torch_snippets import *

root = P("/mnt/347832F37832B388/ml-datasets/ssbd")
annotations = pd.read_csv(f"{root}/annotations.csv")

MODELS = ["vgg19", "resnet18", "resnet50", "densenet121"]
for model in MODELS:
    frames_folder = root / "ssbd-frames/10fps"
    features_folder = root / f"ssbd-frames-features/10fps/{model}/"
    makedir(features_folder)
    extract_features_for_all_frames(model, frames_folder, features_folder, "cuda")
```