# Transformer Classifier

This notebook trains a Transformer-based classifier to predict inhaling and exhaling from video frames.

In [None]:
import respiration.utils as utils

from pytz import timezone
from datetime import datetime

# Every training run is identified by its start time (Europe/Berlin wall clock)
zone = timezone('Europe/Berlin')
model_id = datetime.now().astimezone(zone).strftime('%Y%m%d_%H%M%S')
device = utils.get_torch_device()

# Metadata of this training run; later cells add their settings to this dict
manifest = dict(
    id=model_id,
    device=str(device),
    timestamp_start=datetime.now().astimezone().isoformat(),
    dataset='VitalCamSet',
)
model_id

In [None]:
# Display the torch device selected by utils.get_torch_device()
device

## Define training and testing scenarios

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['303_normalized_face'])

# The first `split_ratio` share of the scenarios (in the order returned by the
# dataset) is used for training, the remainder for testing.
split_ratio = 0.8
split_index = int(len(scenarios_all) * split_ratio)
training, testing = scenarios_all[:split_index], scenarios_all[split_index:]

manifest['split_ratio'] = split_ratio
manifest['training_scenarios'] = training
manifest['testing_scenarios'] = testing

In [None]:
# Number of consecutive video frames per chunk; this is the default chunk size
# of ScenarioLoaderChunks (frames_per_segment).
num_frames = 300
manifest['num_frames'] = num_frames

## Define frame preprocessing

In [None]:
import torch


def preprocess_frames(frames: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize a tensor of frames to the range [0, 1].

    The minimum and maximum are taken over the whole tensor, not per frame or per channel.

    :param frames: The frames to normalize (any shape, must not be empty)
    :return: A tensor of the same shape. If all values of the input are identical, a tensor of zeros is returned.
    """
    lowest = frames.min()
    value_range = frames.max() - lowest

    # A constant input has a zero range; dividing by it would turn every value into NaN (0 / 0) and poison the
    # model weights on the next optimizer step.
    if value_range == 0:
        return torch.zeros_like(frames)

    return (frames - lowest) / value_range

In [None]:
# Value passed to the model as `image_size`
# (presumably the side length of the square input frames in pixels — TODO confirm)
image_size = 256
manifest['image_size'] = image_size

In [None]:
import math
import torch
import respiration.utils as utils


class ScenarioLoaderChunks:
    """
    A data loader for one scenario (subject + setting) of the VitalCamSet dataset. The video is split into consecutive
    chunks of `frames_per_segment` frames; the last chunk may be shorter. Each chunk is returned together with the
    ground truth breathing signal for the same index range.

    The ground truth signal is sliced with the frame indices of the chunk, which presumes that it has one sample per
    video frame — TODO confirm against dataset.get_breathing_signal.

    The loader is its own iterator: `iter(loader)` resets a cursor stored on the instance, so nested or concurrent
    iteration over the same instance is not supported.
    """
    subject: str
    setting: str
    frames_per_segment: int

    # Overall maximum and minimum of the ground truth signals in the dataset; their difference is the scale used to
    # normalize every scenario's signal.
    GT_OVERALL_MAX = 6680.352219172085
    GT_OVERALL_MIN = -6572.075174276201

    def __init__(self,
                 subject: str,
                 setting: str,
                 frames_per_segment: int = num_frames):
        """
        :param subject: The subject identifier (e.g. 'Proband22')
        :param setting: The recording setting (e.g. '303_normalized_face')
        :param frames_per_segment: The number of frames per chunk
        :raises ValueError: If frames_per_segment is not positive
        """
        if frames_per_segment <= 0:
            raise ValueError("frames_per_segment must be a positive integer")

        self.subject = subject
        self.setting = setting
        self.frames_per_segment = frames_per_segment

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

        # Iteration cursor; initialized here so that next() also works before iter() has been called
        self.current_index = 0

        # Normalized ground truth signal of the whole scenario, loaded on first access
        self._gt_waveform = None

    def __len__(self) -> int:
        """Return the number of chunks, counting a trailing partial chunk."""
        return math.ceil(self.total_frames / self.frames_per_segment)

    def __iter__(self):
        """Restart the iteration at the first chunk and return the loader itself."""
        self.current_index = 0
        return self

    def __next__(self):
        """Return the next chunk or raise StopIteration after the last one."""
        if self.current_index >= len(self):
            raise StopIteration

        item = self[self.current_index]
        self.current_index += 1
        return item

    def _normalized_signal(self) -> torch.Tensor:
        """Return the normalized ground truth signal of the whole scenario, loading it only once per loader."""
        if self._gt_waveform is None:
            raw_signal = dataset.get_breathing_signal(self.subject, self.setting)
            signal = torch.tensor(raw_signal.copy(), dtype=torch.float32, device=device)

            # Center the signal on its own mean and scale it by the overall value range of the dataset. This keeps
            # the values roughly within [-0.5, 0.5]; it is not a hard bound.
            value_range = self.GT_OVERALL_MAX - self.GT_OVERALL_MIN
            self._gt_waveform = (signal - signal.mean()) / value_range

        return self._gt_waveform

    def __getitem__(self, index) -> 'tuple[torch.Tensor, torch.Tensor]':
        """
        Return the frames and the ground truth signal for the given index
        :param index: The index of the chunk; negative values count from the last chunk
        :return: The frames and the ground truth signal
        :raises IndexError: If the index is out of range
        """

        chunk_count = len(self)
        if index < 0:
            # Follow the sequence convention instead of reading from a negative frame offset
            index += chunk_count
        if not 0 <= index < chunk_count:
            raise IndexError("Index out of range")

        start = index * self.frames_per_segment
        end = start + self.frames_per_segment
        # The last chunk only contains the frames that are left
        size = min(self.frames_per_segment, self.total_frames - start)

        # Load the video frames (the metadata returned alongside is not needed)
        frames, _ = utils.read_video_rgb(self.video_path, size, start)
        frames = torch.tensor(frames, dtype=torch.float32, device=device)

        # Permute the dimensions to match the expected input format (B, C, H, W);
        # assumes read_video_rgb returns frames as (B, H, W, C) — TODO confirm
        frames = frames.permute(0, 3, 1, 2)

        # Slicing clamps `end` to the signal length, so the last chunk is shorter here as well
        gt_waveform = self._normalized_signal()[start:end]

        return frames, gt_waveform

## Test the ScenarioLoader

In [None]:
# Smoke test: load one scenario with the default chunk size
loader = ScenarioLoaderChunks('Proband22', '303_normalized_face')

# Fetch the fifth chunk (index 4) and display the shapes of its frames and ground truth signal
chunk_frames, chunk_signal = loader[4]
chunk_frames.shape, chunk_signal.shape

## Model Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logs of this run go to outputs/logs/<model_id>
# (mkdir=True presumably creates the directory if it is missing — TODO confirm in utils.dir_path)
log_dir = utils.dir_path('outputs', 'logs', model_id, mkdir=True)
writer = SummaryWriter(log_dir=log_dir)

In [None]:
from vit_pytorch import SimpleViT

# Hyper-parameters of the vision transformer; each one is passed to SimpleViT below
image_patch_size = 16
depth = 6
heads = 16
mlp_dim = 2048
embedding_dim = 512
num_classes = 1

# Record the hyper-parameters in the manifest
manifest.update({
    'image_patch_size': image_patch_size,
    'depth': depth,
    'heads': heads,
    'mlp_dim': mlp_dim,
    'embedding_dim': embedding_dim,
    'num_classes': num_classes,
})

# Build the model and move it to the selected device
model = SimpleViT(
    image_size=image_size,
    patch_size=image_patch_size,
    num_classes=num_classes,
    dim=embedding_dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
).to(device)
manifest['base_model'] = 'simple_vit'

In [None]:
from tslearn.metrics import SoftDTWLossPyTorch

epochs = 30
manifest['epochs'] = epochs

learning_rate = 0.00001
manifest['learning_rate'] = learning_rate

# The manifest entries are derived from the objects themselves so that they cannot drift from the loss function
# and optimizer that are actually used (the manifest used to record 'MSELoss' for the soft-DTW loss).
loss_gamma = 0.1
loss_fn = SoftDTWLossPyTorch(gamma=loss_gamma)
manifest['loss_fn'] = type(loss_fn).__name__
manifest['loss_gamma'] = loss_gamma

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
manifest['optimizer'] = type(optimizer).__name__

In [None]:
def train_one_epoch(epoch_index: int) -> float:
    """
    Train the model for one pass over all training scenarios.

    Every scenario is processed chunk by chunk; the optimizer takes one step per chunk. The average chunk loss of
    each scenario is printed and logged to TensorBoard under 'Training_Loss'.

    :param epoch_index: The index of the current epoch, used as the step of the TensorBoard entries
    :return: The scenario losses averaged over all training scenarios
    """
    epoch_loss = 0.0

    # Iterate over the training scenarios
    for (subject, setting) in training:
        loader = ScenarioLoaderChunks(subject, setting)

        scenario_loss = 0.0

        # Iterate over the whole scenario video in chunks
        for idy, (frames, gt_classes) in enumerate(loader):
            frames = preprocess_frames(frames)

            # The model emits one value per frame (assumes num_classes == 1). Prediction and target are arranged as
            # a single sequence of shape (batch=1, length, features=1), which is the layout the soft-DTW loss
            # expects; element-wise losses accept it as well. reshape also keeps a one-frame chunk from collapsing
            # to a 0-d tensor, which squeeze() would do.
            outputs = model(frames).reshape(1, -1, 1)
            targets = gt_classes.reshape(1, -1, 1)

            # Compute the loss and its gradients. The loss may be returned per batch entry, so reduce it to a scalar.
            loss = loss_fn(outputs, targets).mean()

            # Optimize the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Gather data and report
            chunk_loss = loss.item()
            print(f'  {subject} #{idy:02d} loss={chunk_loss}')
            scenario_loss += chunk_loss

        scenario_loss /= len(loader)
        epoch_loss += scenario_loss

        print(f'  >> {subject} scenario_loss={scenario_loss}')
        writer.add_scalars('Training_Loss', {
            f'{subject}_{setting}': scenario_loss,
        }, epoch_index)
        writer.flush()

    return epoch_loss / len(training)

In [None]:
# Checkpoints and the manifest of this run are stored in models/transformer/<model_id>
model_dir = utils.dir_path('models', 'transformer', model_id, mkdir=True)

In [None]:
# One record per saved checkpoint (file name, epoch, loss); filled by the training loop
models = []
# Lowest average testing loss seen so far; starts at infinity so that the first epoch always counts as an improvement
best_loss = float('inf')

In [None]:
def save_manifest():
    """
    Add the current training results to the manifest and write it to `manifest.json` in the model directory.

    Reads the module-level `models`, `best_loss` and `model_dir`; the finish timestamp is refreshed on every call.
    """
    manifest.update({
        'trained_models': models,
        'best_testing_loss': float(best_loss),
        'timestamp_finish': datetime.now().astimezone().isoformat(),
    })

    manifest_path = os.path.join(model_dir, 'manifest.json')
    utils.write_json(manifest_path, manifest)

In [None]:
import os
from tqdm.auto import tqdm

# Train for `epochs` epochs. After every epoch the model is evaluated on the testing scenarios; whenever the average
# testing loss improves, the weights are saved and the manifest is rewritten.
# NOTE(review): the testing scenarios are also used to select the best checkpoint, so the reported testing loss is
# effectively a validation loss.
for epoch in tqdm(range(epochs)):
    print(f'Epoch {epoch}:')

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    running_loss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for (subject, setting) in testing:
            loader = ScenarioLoaderChunks(subject, setting)
            scenario_loss = 0.0

            for (frames, gt_classes) in loader:
                frames = preprocess_frames(frames)

                # Arrange prediction and target as one sequence of shape (batch=1, length, features=1), the layout
                # the soft-DTW loss expects (assumes num_classes == 1), and reduce the per-batch loss to a scalar.
                voutputs = model(frames).reshape(1, -1, 1)
                targets = gt_classes.reshape(1, -1, 1)
                scenario_loss += loss_fn(voutputs, targets).mean().item()

            scenario_loss /= len(loader)
            writer.add_scalars('Testing_Loss', {f'{subject}_{setting}': scenario_loss}, epoch)
            print(f'  >> {subject} loss={scenario_loss}')

            running_loss += scenario_loss

    # Average over all testing scenarios
    testing_loss = running_loss / len(testing)
    print(f'LOSS training={avg_loss} testing={testing_loss}')
    writer.add_scalars('Average_Loss', {
        'Training': avg_loss,
        'Testing': testing_loss,
    }, epoch)
    writer.flush()

    # Track the best performance, and save the model's state
    if testing_loss < best_loss:
        best_loss = testing_loss
        model_name = f'{model_id}_{epoch}.pth'

        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)

        models.append({
            'model': model_name,
            'epoch': epoch,
            'validation_loss': float(testing_loss),
        })
        save_manifest()

In [None]:
# Write the manifest once more after training so that 'timestamp_finish' reflects the end of the run
save_manifest()