# Train Rhythm Former

This notebook trains and fine-tunes the Rhythm Former model on the VitalCam dataset.

In [None]:
import respiration.utils as utils
from respiration.dataset import VitalCamSet

# Every setting of this run is collected in `manifest`, which the training
# loop later writes next to the model files as manifest.json.
split_ratio = 0.8
manifest = {'split_ratio': split_ratio}

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

# The leading `split_ratio` share of the scenario list (in the order returned
# by the dataset, no shuffling) is used for training, the remainder for
# testing.
split_index = int(len(scenarios_all) * split_ratio)
training, testing = scenarios_all[:split_index], scenarios_all[split_index:]
manifest['training_scenarios'] = training
manifest['testing_scenarios'] = testing

# NOTE(review): `device` is stored exactly as returned by
# utils.get_torch_device(); confirm that utils.write_json can serialise it.
device = utils.get_torch_device()
image_size = (128, 128)
manifest.update(device=device, image_size=image_size)

In [None]:
from pytz import timezone
from datetime import datetime

from respiration.training import HybridLoss

# The current time in the Europe/Berlin zone, formatted as YYYYmmdd_HHMMSS,
# serves as the unique identifier of this training run.
zone = timezone('Europe/Berlin')
timestamp = datetime.now().astimezone(zone).strftime('%Y%m%d_%H%M%S')

epochs = 22
min_freq = 0.1
max_freq = 0.5
manifest.update(
    timestamp=timestamp,
    epochs=epochs,
    min_freq=min_freq,
    max_freq=max_freq,
)

# Both spectral weights of the hybrid loss are set to 0.0 for this run.
loss_fn = HybridLoss(
    min_freq=min_freq,
    max_freq=max_freq,
    spectral_magnitude_weight=0.0,
    spectral_convergence_weight=0.0,
)
manifest.update(loss_fn='HybridLoss', loss_fn_config=loss_fn.get_config())

# Alternative value kept for reference: 0.00001
learning_rate = 9e-3
manifest['learning_rate'] = learning_rate

print('timestamp:', timestamp)

In [None]:
import torch

from tqdm.auto import tqdm
from respiration.extractor.rhythm_former import RhythmFormer
from respiration.training import VitalCamLoader

# The network is wrapped in DataParallel, which makes its state-dict keys
# carry a 'module.' prefix -- presumably to stay compatible with checkpoints
# whose keys have that prefix (TODO confirm against the loading code).
model = torch.nn.DataParallel(RhythmFormer(image_size=image_size))
model.to(device)
manifest['model'] = 'RhythmFormer'

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
manifest['optimizer'] = 'Adam'

# Lowest test loss seen so far; any first result will beat infinity.
best_loss = float('inf')

# Output directory of this run: models/rhythm_former/<timestamp>/RhythmFormer
model_dir = utils.dir_path(
    'models', 'rhythm_former', timestamp, 'RhythmFormer', mkdir=True,
)

# Filled by the training loop: saved checkpoints and per-epoch losses.
manifest.update(models=[], losses=[])

In [None]:
# Main loop: one optimisation pass and one evaluation pass per epoch. A
# checkpoint is written whenever the test loss improves, and the manifest is
# rewritten at the end of every epoch.
for epoch in range(epochs):
    # A new pair of loaders is constructed at the start of every epoch.
    training_loader = VitalCamLoader(training, parts=2, device=device)
    testing_loader = VitalCamLoader(testing, parts=2, device=device)

    # --- Optimisation pass over the training scenarios ---
    model.train()
    train_loss = 0
    for clip, reference in tqdm(training_loader, desc='Training'):
        clip = utils.preprocess_frames(clip)

        # squeeze(0) drops the leading axis of the prediction if it has
        # size 1; the target is used exactly as supplied by the loader.
        prediction = model(clip).squeeze(0)
        batch_loss = loss_fn(prediction, reference)
        train_loss += batch_loss.item()

        # Reset stale gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Drop the references to this iteration's tensors right away.
        del clip, reference, prediction, batch_loss

    # --- Evaluation pass over the testing scenarios (no gradients) ---
    testing_loss = 0
    model.eval()
    with torch.no_grad():
        for clip, reference in tqdm(testing_loader, desc='Testing'):
            clip = utils.preprocess_frames(clip)
            prediction = model(clip).squeeze(0)

            batch_loss = loss_fn(prediction, reference)
            testing_loss += batch_loss.item()

            del clip, reference, prediction, batch_loss

    # Normalise the accumulated sums by the respective loader lengths.
    train_loss = train_loss / len(training_loader)
    testing_loss = testing_loss / len(testing_loader)

    # Keep a checkpoint only when the test loss reaches a new minimum.
    if testing_loss < best_loss:
        best_loss = testing_loss
        model_file = utils.file_path(model_dir, f'RF_{epoch}.pth')

        manifest['best_loss'] = best_loss
        checkpoint_entry = dict(
            epoch=epoch,
            model_file=model_file,
            train_loss=train_loss,
            test_loss=testing_loss,
        )
        manifest['models'].append(checkpoint_entry)

        torch.save(model.state_dict(), model_file)

    # Progress line uses a 1-based epoch counter; the manifest stores the
    # 0-based loop index.
    print(f'[{epoch + 1}/{epochs}] train-loss={train_loss:.3f} test-loss={testing_loss:.3f}')
    manifest['losses'].append(
        dict(epoch=epoch, train_loss=train_loss, test_loss=testing_loss)
    )

    # Rewrite the manifest after every epoch so that it always reflects the
    # most recently completed epoch.
    manifest['epoch'] = epoch
    manifest_file = utils.file_path(model_dir, 'manifest.json')
    utils.write_json(manifest_file, manifest)