# Transformer

In [None]:
import respiration.utils as utils

# Identifiers of the tuned models to evaluate. Each identifier names a
# directory below models/transformer.
tuned_models = {
    '20240511_190518',
    '20240511_194544',
}

# Map each model identifier to the path of its weights file
models = {}

# Manifest of every model, appended in the same order as the entries of `models`
manifests = []

# A set of strings has no stable iteration order between interpreter runs, so
# iterate in sorted order to keep `models` and `manifests` reproducible.
for model_id in sorted(tuned_models):
    model_dir = utils.dir_path('models', 'transformer', model_id)

    # The manifest is a file, so build its path with join_paths, exactly like
    # the weights path below.
    manifest_path = utils.join_paths(model_dir, 'manifest.json')
    manifest = utils.read_json(manifest_path)

    # NOTE(review): the last entry of 'trained_models' is treated as the best
    # model -- confirm that the training code stores the entries in that order.
    best_model = manifest['trained_models'][-1]

    models[model_id] = utils.join_paths(model_dir, best_model['model'])
    manifests.append(manifest)

utils.pretty_print(models)

In [None]:
from respiration.dataset import VitalCamSet

# Inference settings: the torch device to run on and the edge length (in
# pixels) of the square frames.
image_size = 256
device = utils.get_torch_device()

# Restrict the evaluation to the '101_natural_lighting' scenarios.
dataset = VitalCamSet()
scenarios = dataset.get_scenarios(['101_natural_lighting'])

In [None]:
import torch


def temporal_shifting(frames: torch.Tensor) -> torch.Tensor:
    """
    Calculate the standardized, normalized difference between consecutive frames.

    For every pair of neighbouring frames along the first dimension the value
    (next - previous) / (next + previous + 1e-7) is computed. The whole result
    is then standardized with a single global mean and standard deviation.

    :param frames: Tensor whose first dimension indexes the frames. Integer
        tensors are converted to float first, because an unsigned integer
        subtraction would wrap around instead of becoming negative.
    :return: Tensor with one entry fewer along the first dimension than
        `frames`. Values that would be NaN (for example when all frames are
        identical and the standard deviation is zero) are replaced by 0.
    """
    if not (frames.is_floating_point() or frames.is_complex()):
        frames = frames.float()

    diff_frames = frames[1:] - frames[:-1]
    sum_frames = frames[1:] + frames[:-1]
    # The small constant avoids a division by zero for completely black pixels
    inputs = diff_frames / (sum_frames + 1e-7)
    inputs = (inputs - torch.mean(inputs)) / torch.std(inputs)

    # A zero standard deviation turns every element into 0 / 0 = NaN, which
    # would poison the model output; treat such inputs as "no change".
    inputs = torch.where(torch.isnan(inputs), torch.zeros_like(inputs), inputs)
    return inputs

In [None]:
import torch
import pandas as pd
import datetime as dt

from tqdm.auto import tqdm
from vit_pytorch import SimpleViT
from torchvision import transforms

# The preprocessing pipeline does not depend on the video, so build it once:
# every frame is resized to image_size x image_size and converted to a tensor.
preprocess = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

# Build and load every model once up front instead of re-reading the
# checkpoint from disk for each scenario.
loaded_models = {}
for (model_id, model_path) in models.items():
    # The architecture has to match the one stored in the checkpoint,
    # otherwise load_state_dict rejects the weights.
    model = SimpleViT(
        image_size=image_size,
        patch_size=32,
        num_classes=1,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048
    ).to(device)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    loaded_models[model_id] = model

# One record per (model, subject, setting) combination
predictions = []

for (subject, setting) in tqdm(scenarios):
    print(f"Processing {subject} - {setting}")

    video_path = dataset.get_video_path(subject, setting)

    frames, _ = utils.read_video_rgb(video_path)
    frames = torch.stack([preprocess(frame) for frame in frames], dim=0)
    frames = frames.to(device)
    frames = temporal_shifting(frames)

    for (model_id, model) in loaded_models.items():
        print(f"--> Using {model_id} model")

        start = dt.datetime.now()

        # All frames of the video are passed to the model as a single batch.
        # Gradients are not needed for inference.
        with torch.no_grad():
            prediction = model(frames).cpu().numpy().squeeze()

        predictions.append({
            'model': model_id,
            'subject': subject,
            'setting': setting,
            'duration': dt.datetime.now() - start,
            'signal': prediction.tolist(),
        })

    # Drop the reference to the frames before the next video is read
    del frames

predictions = pd.DataFrame(predictions)

# Store the predictions to csv
signals_dir = utils.dir_path('outputs', 'signals', mkdir=True)
signals_path = utils.join_paths(signals_dir, 'transformer_predictions.csv')

predictions.to_csv(signals_path, index=False)

In [None]:
# Show the first five prediction records as a quick sanity check
predictions.head(n=5)