In [None]:
from datetime import datetime
import respiration.utils as utils

import torch

# Take a single timestamp so the run id and the manifest start time always agree
start_time = datetime.now()

# The timestamp is the unique identifier for this training run
model_id = start_time.strftime('%Y%m%d_%H%M%S')
device = utils.get_torch_device()

image_size = 256

# Seed torch's RNGs (CPU and CUDA) so weight initialisation and any stochastic layers are reproducible
seed = 42
torch.manual_seed(seed)

# The manifest will store all the metadata for this training run
manifest = {
    'id': model_id,
    'device': str(device),
    'timestamp_start': start_time.astimezone().isoformat(),
    'base_model': 'SimpleViT',
    'dataset': 'VitalCamSet',
    'image_size': image_size,
    'seed': seed,
}

## Dataset and Dataloader

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

split_ratio = 0.8
manifest['split_ratio'] = split_ratio

# Deterministic split by position (no shuffling): the first `split_ratio` fraction of the scenarios,
# in the order returned by get_scenarios, is used for training and the remainder for testing.
split_index = int(len(scenarios_all) * split_ratio)

training = scenarios_all[:split_index]
manifest['training_scenarios'] = training

testing = scenarios_all[split_index:]
manifest['testing_scenarios'] = testing

# Later cells average over len(training) / len(testing); fail fast here instead of deep in the training loop
assert training, 'split_ratio leaves no training scenarios'
assert testing, 'split_ratio leaves no testing scenarios'

In [None]:
import torch


def temporal_shifting(frames: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """
    Compute the normalized difference between consecutive frames.

    For every pair of neighbouring frames this returns ``(f[t+1] - f[t]) / (f[t+1] + f[t] + eps)``. The whole
    result is then standardized to zero mean and unit standard deviation.

    :param frames: Tensor whose first dimension indexes frames; trailing dimensions are kept as-is.
    :param eps: Small constant that guards both divisions against zero.
    :return: Tensor with one fewer entry along the first dimension than ``frames``.
    """
    diff_frames = frames[1:] - frames[:-1]
    sum_frames = frames[1:] + frames[:-1]
    inputs = diff_frames / (sum_frames + eps)

    # A static clip yields std == 0; clamping avoids 0/0 -> NaN and returns all zeros instead
    std = torch.std(inputs).clamp_min(eps)
    return (inputs - torch.mean(inputs)) / std

In [None]:
import math
import torch
from torchvision import transforms

import respiration.utils as utils


class ScenarioLoader:
    """
    A data loader for a single VitalCamSet scenario (one subject in one setting).

    The scenario video is read in consecutive segments of ``frames_per_segment`` frames; the last segment may be
    shorter. Each frame is resized to ``image_size`` x ``image_size``, converted to a tensor and moved to ``device``.
    The ground truth breathing signal is L2-normalized over the whole recording once, and each segment's labels
    are sliced with exactly the same frame bounds as that segment's video frames.

    NOTE(review): slicing labels by frame index assumes the breathing signal has one sample per video frame —
    confirm against ``VitalCamSet.get_breathing_signal``.
    """
    subject: str
    setting: str
    frames_per_segment: int

    def __init__(self,
                 subject: str,
                 setting: str,
                 frames_per_segment: int = 1024):
        if frames_per_segment < 1:
            raise ValueError(f"frames_per_segment must be >= 1, got {frames_per_segment}")

        self.subject = subject
        self.setting = setting
        self.frames_per_segment = frames_per_segment

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

        # Built once per loader instead of once per segment
        self.preprocess = transforms.Compose([
            transforms.ToPILImage(mode='RGB'),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

        # Load and normalize the full signal once; normalizing before slicing keeps every segment on the same scale
        gt_waveform = dataset.get_breathing_signal(subject, setting)
        gt_waveform = torch.tensor(gt_waveform, dtype=torch.float32, device=device)
        self.gt_waveform = torch.nn.functional.normalize(gt_waveform, dim=0)

    def __len__(self) -> int:
        """Number of segments, counting a trailing partial segment."""
        return math.ceil(self.total_frames / self.frames_per_segment)

    def __iter__(self):
        """Restart iteration from the first segment."""
        self.current_index = 0
        return self

    def __next__(self):
        """Return the next segment or raise StopIteration after the last one."""
        if self.current_index >= len(self):
            raise StopIteration
        item = self[self.current_index]
        self.current_index += 1
        return item

    def __getitem__(self, index: int) -> 'tuple[torch.Tensor, torch.Tensor]':
        """
        Return the frames and the ground truth signal for the given segment.

        :param index: The index of the segment, ``0 <= index < len(self)``
        :return: The preprocessed frames stacked along dim 0, and the matching ground truth slice
        :raises IndexError: If ``index`` is negative or past the last segment
        """
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        start = index * self.frames_per_segment
        # The last segment can be shorter than frames_per_segment
        size = min(self.frames_per_segment, self.total_frames - start)
        # Use the actual segment size so frames and labels cover the same range on the last segment
        end = start + size

        # Load the video frames
        frames, _ = utils.read_video_rgb(self.video_path, size, start)
        frames = torch.stack([self.preprocess(frame) for frame in frames], dim=0)
        frames = frames.to(device)

        gt_waveform = self.gt_waveform[start:end]

        return frames, gt_waveform

## Model Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Directory for this run's TensorBoard event files, built from ('outputs', 'logs', model_id);
# mkdir=True asks utils.dir_path to create it (presumably only if missing — confirm in respiration.utils)
log_dir = utils.dir_path(
    'outputs',
    'logs',
    model_id,
    mkdir=True,
)
writer = SummaryWriter(log_dir=log_dir)

In [None]:
from vit_pytorch import SimpleViT

# Architecture hyperparameters, kept in one dict so they are also recorded in the run manifest
model_parameters = {
    'image_size': image_size,
    'patch_size': 32,
    'num_classes': 1,  # one regression output per input sample
    'dim': 1024,
    'depth': 6,
    'heads': 16,
    'mlp_dim': 2048,
}
manifest['model_parameters'] = model_parameters

model = SimpleViT(**model_parameters).to(device)

In [None]:
epochs = 2
manifest['epochs'] = epochs

learning_rate = 0.001
manifest['learning_rate'] = learning_rate

# PyTorch's AdamW default, made explicit so it is recorded with the other hyperparameters
weight_decay = 0.01
manifest['weight_decay'] = weight_decay

loss_fn = torch.nn.MSELoss()
# Derive the recorded names from the objects so the manifest cannot drift from the code
manifest['loss_fn'] = type(loss_fn).__name__

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
manifest['optimizer'] = type(optimizer).__name__

In [None]:
def train_one_epoch(epoch_index: int) -> float:
    """
    Train the model for one pass over all training scenarios.

    Each scenario video is processed in chunks by ``ScenarioLoader``. ``temporal_shifting`` turns N frames into N-1
    samples, so the first ground truth value of each chunk is dropped to keep predictions and labels aligned.
    One optimizer step is taken per chunk, and per-scenario average losses are logged to TensorBoard.

    :param epoch_index: The current epoch, used as the TensorBoard global step.
    :return: Mean of the per-scenario average chunk losses, or ``nan`` if no scenario had a usable chunk.
    """
    epoch_loss = 0.0
    trained_scenarios = 0

    # Iterate over the training scenarios
    for (subject, setting) in training:
        loader = ScenarioLoader(subject, setting)

        scenario_loss = 0.0
        trained_chunks = 0

        # Iterate over the whole scenario video in chunks
        for frames, gt_labels in loader:
            # temporal_shifting needs >= 2 frames; a trailing 1-frame chunk would give an empty batch and a NaN loss
            if frames.shape[0] < 2:
                continue

            frames = temporal_shifting(frames)
            gt_labels = gt_labels[1:]

            # Zero your gradients for every batch!
            optimizer.zero_grad()

            # squeeze(-1) drops only the trailing output dim (num_classes=1), so a 1-sample chunk stays 1-D
            outputs = model(frames).squeeze(-1)

            # Compute the loss and its gradients
            loss = loss_fn(outputs, gt_labels)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            scenario_loss += loss.item()
            trained_chunks += 1

        if trained_chunks == 0:
            print(f'  >> {subject}_{setting}: no usable chunks, skipped')
            continue

        scenario_loss /= trained_chunks
        epoch_loss += scenario_loss
        trained_scenarios += 1

        print(f'  >> {subject} loss={scenario_loss}')
        writer.add_scalars('Training_Loss', {
            f'{subject}_{setting}': scenario_loss,
        }, epoch_index)
        writer.flush()

    return epoch_loss / trained_scenarios if trained_scenarios else float('nan')

In [None]:
# Directory for this run's checkpoints and manifest, built from ('models', 'transformer', model_id)
model_dir = utils.dir_path(
    'models',
    'transformer',
    model_id,
    mkdir=True,
)

In [None]:
import os
from tqdm.auto import tqdm

best_vloss = 100_000_000.
models = []

for epoch in tqdm(range(epochs)):
    print(f'Epoch {epoch}:')

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    running_vloss = 0.0
    evaluated_scenarios = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for subject, setting in testing:
            loader = ScenarioLoader(subject, setting)
            testing_loss = 0.0
            evaluated_chunks = 0

            for frames, gt_labels in loader:
                # temporal_shifting needs >= 2 frames; a 1-frame chunk would give an empty batch and a NaN loss
                if frames.shape[0] < 2:
                    continue

                frames = temporal_shifting(frames)
                vlabels_batch = gt_labels[1:]
                # squeeze(-1) keeps a 1-sample chunk 1-D so it matches the label shape
                voutputs = model(frames).squeeze(-1)
                # .item() accumulates plain floats instead of device tensors
                testing_loss += loss_fn(voutputs, vlabels_batch).item()
                evaluated_chunks += 1

            if evaluated_chunks == 0:
                continue

            testing_loss /= evaluated_chunks
            writer.add_scalars('Testing_Loss', {f'{subject}_{setting}': testing_loss}, epoch)

            running_vloss += testing_loss
            evaluated_scenarios += 1

    # NaN never compares less than best_vloss, so an epoch without usable validation data is never saved
    avg_vloss = running_vloss / evaluated_scenarios if evaluated_scenarios else float('nan')
    print(f'LOSS train {avg_loss} valid {avg_vloss}')
    writer.add_scalars('Average_Loss', {
        'Training': avg_loss,
        'Validation': avg_vloss,
    }, epoch)
    writer.flush()

    # Track the best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_name = f'{model_id}_{epoch}.pth'

        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)

        models.append({
            'model': model_name,
            'epoch': epoch,
            'validation_loss': float(avg_vloss),
        })

In [None]:
manifest['trained_models'] = models
# Checkpoints are only appended on strict improvement, so the minimum is the best validation loss.
# With no saved checkpoint this records None rather than the initial best_vloss sentinel.
manifest['best_validation_loss'] = min((entry['validation_loss'] for entry in models), default=None)
manifest['timestamp_finish'] = datetime.now().astimezone().isoformat()
utils.write_json(os.path.join(model_dir, 'manifest.json'), manifest)