# Transformer

In [1]:
import respiration.utils as utils

# Identifiers of the fine-tuned runs to evaluate (one sub-directory per run
# below models/transformer). A list is used instead of a set so the iteration
# order, and therefore the order of `models`, `manifests` and the prediction
# rows, stays the same across kernel restarts.
tuned_models = [
    # '20240511_190518',
    # '20240511_194544',
    '20240512_211825',
]

# Map model ids to the path of the checkpoint that will be evaluated
models = {}

# Manifest of every selected run, in the same order as `tuned_models`
manifests = []

for model_id in tuned_models:
    model_dir = utils.dir_path('models', 'transformer', model_id)

    # manifest.json is a file, so build its path the same way as the
    # checkpoint path below.
    manifest_path = utils.join_paths(model_dir, 'manifest.json')
    manifest = utils.read_json(manifest_path)

    # The last entry of 'trained_models' is the checkpoint that gets evaluated.
    # NOTE(review): presumably this is the final epoch rather than a validated
    # "best" checkpoint - confirm against the training script.
    best_model = manifest['trained_models'][-1]

    model_path = utils.join_paths(model_dir, best_model['model'])
    models[model_id] = model_path
    manifests.append(manifest)

utils.pretty_print(models)

{
  "20240512_211825": "/app/models/transformer/20240512_211825/20240512_211825_19.pth"
}


In [2]:
from respiration.dataset import VitalCamSet

# Side length (in pixels) the frames are resized to before inference
image_size = 256

# Device that the frames and the models are moved to
device = utils.get_torch_device()

# Restrict the evaluation to the scenarios of the natural lighting setting
dataset = VitalCamSet()
scenarios = dataset.get_scenarios(['101_natural_lighting'])

In [3]:
import torch


def temporal_shifting_frames(frames: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """
    Convert a stack of frames into normalized frame-to-frame differences.

    Consecutive entries along the first dimension are subtracted and divided by
    their sum. The result is then standardized with the mean and standard
    deviation computed over the whole tensor.

    :param frames: Tensor whose first dimension indexes the frames.
    :param eps: Small constant added to the sum of consecutive frames, and used
        as a lower bound for the standard deviation, to avoid division by zero.
    :return: Tensor with one entry fewer along the first dimension than
        `frames`.
    """
    diff_frames = frames[1:] - frames[:-1]
    sum_frames = frames[1:] + frames[:-1]
    inputs = diff_frames / (sum_frames + eps)

    # A clip without any change between frames has a standard deviation of
    # zero; clamp it so the result is all zeros instead of NaN (0 / 0).
    std = torch.std(inputs).clamp_min(eps)
    return (inputs - torch.mean(inputs)) / std


def temporal_shifting_signal(time_series: torch.Tensor) -> torch.Tensor:
    """
    Return the first-order difference of the signal along its first dimension.

    The result has one entry fewer along that dimension than `time_series`.
    """
    following = time_series[1:]
    preceding = time_series[:-1]
    return following - preceding

In [4]:
import torch
import pandas as pd
import datetime as dt

from tqdm.auto import tqdm
from vit_pytorch import ViT
from torchvision import transforms

# The frame preprocessing is identical for every video, so build it once.
preprocess = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])


def _load_vit(model_path) -> ViT:
    """Build the ViT used for fine-tuning, load the checkpoint and switch to eval mode."""
    model = ViT(
        image_size=image_size,
        patch_size=32,
        num_classes=1,
        dim=128,
        depth=6,
        heads=16,
        mlp_dim=2048
    ).to(device)
    # NOTE(review): torch.load unpickles the file, only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


# Load every checkpoint once instead of rebuilding the model and re-reading
# the weights from disk for each video.
loaded_models = {
    model_id: _load_vit(model_path)
    for (model_id, model_path) in models.items()
}

# One record per (model, subject, setting); converted to a DataFrame below
prediction_records = []

for (subject, setting) in tqdm(scenarios):
    print(f"Processing {subject} - {setting}")

    video_path = dataset.get_video_path(subject, setting)

    frames, _ = utils.read_video_rgb(video_path)
    frames = torch.stack([preprocess(frame) for frame in frames], dim=0)
    frames = frames.to(device)
    frames = temporal_shifting_frames(frames)

    for (model_id, model) in loaded_models.items():
        print(f"--> Using {model_id} model")

        # Only the forward pass and the copy back to the host are timed
        start = dt.datetime.now()

        with torch.no_grad():
            # Squeeze only the class dimension so that a single-row output
            # still yields a list instead of a bare scalar.
            prediction = model(frames).squeeze(-1).cpu().numpy()

        prediction_records.append({
            'model': model_id,
            'subject': subject,
            'setting': setting,
            'duration': dt.datetime.now() - start,
            'signal': prediction.tolist(),
        })

    # Release the frame tensor before the next video is loaded
    del frames

predictions = pd.DataFrame(prediction_records)

# Store the predictions to csv
signals_dir = utils.dir_path('outputs', 'signals', mkdir=True)
signals_path = utils.join_paths(signals_dir, 'transformer_predictions.csv')

predictions.to_csv(signals_path, index=False)

  0%|          | 0/26 [00:00<?, ?it/s]

Processing Proband01 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband02 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband03 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband04 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband05 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband06 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband07 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband08 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband09 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband10 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband11 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband12 - 101_natural_lighting
--> Using 20240512_211825 model
Processing Proband13 - 101_natural_lighting
--> Using 20240512_211825 model
Processing P

In [5]:
# Preview the first rows of the prediction table
predictions.head(n=5)

Unnamed: 0,model,subject,setting,duration,signal
0,20240512_211825,Proband01,101_natural_lighting,0 days 00:00:00.702146,"[-0.001394517719745636, 0.005532152950763702, ..."
1,20240512_211825,Proband02,101_natural_lighting,0 days 00:00:00.729610,"[-0.0075696781277656555, 0.011222995817661285,..."
2,20240512_211825,Proband03,101_natural_lighting,0 days 00:00:00.729791,"[0.00974736362695694, 0.0017803534865379333, -..."
3,20240512_211825,Proband04,101_natural_lighting,0 days 00:00:00.735905,"[-0.006696484982967377, 0.005128942430019379, ..."
4,20240512_211825,Proband05,101_natural_lighting,0 days 00:00:00.642627,"[0.008742131292819977, -0.005681194365024567, ..."
