# Transformer

This notebook trains a vision transformer (`SimpleViT` from `vit_pytorch`) to predict the respiration rate from a sequence of frames.

In [None]:
import respiration.utils as utils

from pytz import timezone
from datetime import datetime

# Capture the start instant exactly once: the run id and the recorded start
# timestamp are both derived from it, so they always describe the same moment
# (two separate now() calls could straddle a second boundary).
started_at = datetime.now().astimezone()

# The timestamp (expressed in Europe/Berlin time) is the unique identifier for
# this training run
zone = timezone('Europe/Berlin')
model_id = started_at.astimezone(zone).strftime('%Y%m%d_%H%M%S')
device = utils.get_torch_device()

# The manifest will store all the metadata for this training run.
# 'timestamp_start' keeps the machine's local UTC offset, as before.
manifest = {
    'id': model_id,
    'device': str(device),
    'timestamp_start': started_at.isoformat(),
    'dataset': 'VitalCamSet',
}
model_id

In [None]:
# Display the torch device selected by utils.get_torch_device() for this run
device

## Define training and testing scenarios

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

split_ratio = 0.8
manifest['split_ratio'] = split_ratio

# Compute the split point once so the two slices are guaranteed to be
# complementary. NOTE: the split is positional (no shuffling), so it depends
# on the order in which get_scenarios returns the scenarios.
split_index = int(len(scenarios_all) * split_ratio)

training = scenarios_all[:split_index]
manifest['training_scenarios'] = training

testing = scenarios_all[split_index:]
manifest['testing_scenarios'] = testing

# Fail early with a clear message: an empty split leaves nothing to train or
# evaluate on, and the per-epoch averages divide by the loader length.
# len() is used instead of truthiness so this also works for array-like results.
if len(training) == 0 or len(testing) == 0:
    raise ValueError(
        f'Scenario split produced an empty set '
        f'(total={len(scenarios_all)}, training={len(training)}, '
        f'testing={len(testing)}, split_ratio={split_ratio})'
    )

## Model training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this run go into a directory built from
# ('outputs', 'logs', model_id); mkdir=True presumably has dir_path create it.
log_dir = utils.dir_path('outputs', 'logs', model_id, mkdir=True)
writer = SummaryWriter(log_dir)

In [None]:
from vit_pytorch import SimpleViT

# Vision-transformer architecture hyper-parameters
image_size, image_patch_size = 256, 16
depth, heads = 6, 16
embedding_dim = 512
mlp_dim = 4 * embedding_dim  # feed-forward width: four times the embedding
num_classes = 1  # the model emits a single value per input frame

# Record every hyper-parameter in the manifest (insertion order matters for
# the written JSON, so keep it identical to the declaration order above)
manifest.update(
    image_size=image_size,
    image_patch_size=image_patch_size,
    depth=depth,
    heads=heads,
    embedding_dim=embedding_dim,
    mlp_dim=mlp_dim,
    num_classes=num_classes,
)

model = SimpleViT(
    image_size=image_size,
    patch_size=image_patch_size,
    num_classes=num_classes,
    dim=embedding_dim,
    heads=heads,
    mlp_dim=mlp_dim,
    depth=depth,
)
model = model.to(device)
manifest['base_model'] = 'simple_vit'

In [None]:
import torch
from respiration.training import HybridLoss

# Optimisation schedule
epochs = 22
manifest['epochs'] = epochs

# NOTE(review): 9e-3 is much higher than the learning rates commonly used
# with Adam for ViT-style models; confirm this value is intentional.
learning_rate = 9e-3
manifest['learning_rate'] = learning_rate

# Loss: record its name and its configuration for reproducibility
loss_fn = HybridLoss()
manifest['loss_fn'] = 'HybridLoss'
manifest['loss_fn_config'] = loss_fn.get_config()

# Adam over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
manifest['optimizer'] = 'Adam'

In [None]:
from respiration.training import VitalCamLoader


def train_model(scenario_split=4):
    """Run one training pass over the ``training`` scenarios.

    Each chunk yielded by ``VitalCamLoader`` is resized to
    ``image_size`` x ``image_size`` via ``utils.preprocess_frames``, fed
    through ``model``, scored with ``loss_fn`` and used for one optimizer step.
    The loss of every chunk is printed.

    Args:
        scenario_split: Forwarded to ``VitalCamLoader`` as ``parts``
            (presumably how many chunks each scenario is cut into).

    Returns:
        The summed chunk losses divided by ``len()`` of the loader.
    """
    model.train(True)

    loader = VitalCamLoader(training, parts=scenario_split, device=device)
    total_loss = 0.0

    for frames, target in loader:
        batch = utils.preprocess_frames(frames, (image_size, image_size), device)
        predictions = model(batch.squeeze()).squeeze()

        # Read the scalar once; it is used both for logging and accumulation
        loss = loss_fn(predictions, target)
        loss_value = loss.item()
        print(f'LOSS: {loss_value:.3f}')

        # Standard backward/step cycle
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss_value

    return total_loss / len(loader)


def test_model(scenario_split=15):
    """Evaluate ``model`` on the ``testing`` scenarios without gradients.

    Args:
        scenario_split: Forwarded to ``VitalCamLoader`` as ``parts``
            (presumably how many chunks each scenario is cut into).

    Returns:
        The summed chunk losses divided by ``len()`` of the loader.
    """
    model.eval()
    total_loss = 0.0
    loader = VitalCamLoader(testing, parts=scenario_split, device=device)

    # Inference only: no autograd graph is built, which also saves memory
    with torch.no_grad():
        for frames, target in loader:
            batch = utils.preprocess_frames(frames, (image_size, image_size), device)
            predictions = model(batch.squeeze()).squeeze()
            total_loss += loss_fn(predictions, target).item()

    return total_loss / len(loader)

In [None]:
# Lowest testing loss observed so far; any finite loss improves on it
best_loss = float('inf')
# One entry per checkpoint written whenever best_loss improves
models = []

In [None]:
# os is needed by save_manifest; import it here rather than relying on a
# later cell having been executed first.
import os

model_dir = utils.dir_path('models', 'transformer', model_id, mkdir=True)


def save_manifest():
    """Update ``manifest`` with training progress and write it to disk.

    Stores the checkpoint records from ``models``, the best testing loss seen
    so far, and the current time as ``timestamp_finish`` (overwritten on every
    call). Then writes the manifest as ``manifest.json`` inside ``model_dir``.
    """
    manifest['trained_models'] = models
    manifest['best_testing_loss'] = float(best_loss)
    manifest['timestamp_finish'] = datetime.now().astimezone().isoformat()
    utils.write_json(os.path.join(model_dir, 'manifest.json'), manifest)

In [None]:
from tqdm.auto import tqdm
import os

for epoch in tqdm(range(epochs)):
    print(f'Epoch {epoch}:')

    # One optimisation pass, then one evaluation pass
    training_loss = train_model()
    testing_loss = test_model()
    print(f'LOSS training={training_loss} testing={testing_loss}')

    # Both averages share one TensorBoard chart; flush so they appear at once
    writer.add_scalars(
        'Average_Loss',
        dict(Training=training_loss, Testing=testing_loss),
        epoch,
    )
    writer.flush()

    # Checkpoint only when the testing loss beats the best seen so far
    improved = testing_loss < best_loss
    if improved:
        best_loss = testing_loss
        model_name = f'{model_id}_{epoch}.pth'
        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)
        models.append(dict(
            model=model_name,
            epoch=epoch,
            validation_loss=float(testing_loss),
        ))

    # Persist the manifest every epoch so an interrupted run still leaves one
    manifest['epoch'] = epoch
    save_manifest()