# Fine-tuning EfficientPhys

This notebook fine-tunes the EfficientPhys model, pre-trained on the BP4D dataset, on the VitalCamSet dataset. The model is trained to predict the respiratory (breathing) signal from the video frames.

In [None]:
import torch
import random
import numpy as np

seed_value = 42

# Seed Python's, NumPy's and PyTorch's random number generators so that runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed_value)

In [None]:
from datetime import datetime
import respiration.utils as utils

# Read the clock once so that the run id and the recorded start time always describe the same instant.
start_time = datetime.now().astimezone()

# The timestamp is the unique identifier for this training run
model_id = start_time.strftime('%Y%m%d_%H%M%S')
device = utils.get_torch_device()

base_model_path = utils.file_path('data', 'rPPG-Toolbox', 'BP4D_PseudoLabel_EfficientPhys.pth')
model_dir = utils.dir_path('models', 'fine_tuned', model_id, mkdir=True)

# The manifest will store all the metadata for this training run
manifest = {
    'id': model_id,
    'device': str(device),
    'timestamp_start': start_time.isoformat(),
    'base_model': 'EfficientPhys',
    'base_model_path': base_model_path,
    'dataset': 'VitalCamSet',
}

utils.pretty_print(manifest)

In [None]:
import os
from respiration.extractor.efficient_phys import EfficientPhys

dim = 72
frame_depth = 20

# The checkpoint's parameter names carry an extra 'module.' prefix. Wrapping the network in DataParallel
# (which stores it as `.module`) makes the model's keys match before the weights are loaded.
model = torch.nn.DataParallel(EfficientPhys(img_size=dim, frame_depth=frame_depth))
model.to(device)

pretrained_state = torch.load(base_model_path, map_location=device)
key_matching = model.load_state_dict(pretrained_state)

# Display the key-matching report returned by load_state_dict
key_matching

## Dataset and dataloader

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

# Deterministic (unshuffled) split: the leading `split_ratio` share of the scenarios is used for training,
# the remainder for testing.
split_ratio = 0.8
split_index = int(len(scenarios_all) * split_ratio)
training, testing = scenarios_all[:split_index], scenarios_all[split_index:]

manifest.update({
    'split_ratio': split_ratio,
    'training_scenarios': training,
    'testing_scenarios': testing,
})

In [None]:
from scipy.signal import resample


class ScenarioLoader:
    """
    A data loader for a single VitalCamSet scenario (one subject in one setting).

    The scenario video is split into consecutive, non-overlapping chunks of ``chunk_size`` frames. A trailing
    partial chunk is dropped, so every chunk has the same length. The ground-truth breathing signal is resampled
    to one sample per video frame, and each chunk is paired with the slice of that signal covering the same frames.
    """

    def __init__(self, subject: str, setting: str, chunk_size: int):
        self.subject = subject
        self.setting = setting
        self.chunk_size = chunk_size

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

        # Resampled ground-truth signal. It is loaded lazily on first use and reused for every chunk.
        self._gt_waveform = None

    def __len__(self) -> int:
        # Floor division already excludes a trailing partial chunk, so every chunk is full-length.
        return self.total_frames // self.chunk_size

    def __iter__(self):
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index >= len(self):
            raise StopIteration
        item = self[self.current_index]
        self.current_index += 1
        return item

    def _ground_truth(self) -> 'np.ndarray':
        """
        Return the scenario's breathing signal resampled to ``total_frames`` samples (one per video frame).
        The result is computed once and cached.
        """
        if self._gt_waveform is None:
            # The ground truth has a higher sampling rate than the video and is resampled to the frame count
            waveform = dataset.get_breathing_signal(self.subject, self.setting)
            self._gt_waveform = resample(waveform, self.total_frames)
        return self._gt_waveform

    def __getitem__(self, index: int) -> 'tuple[torch.Tensor, torch.Tensor]':
        """
        Return the frames and the ground-truth signal of one chunk.
        :param index: The index of the chunk, in ``range(len(self))``
        :return: A ``(frames, gt_waveform)`` pair. ``frames`` is reordered with ``permute(0, 3, 1, 2)``
                 (presumably frames x height x width x channels -> frames x channels x height x width; confirm
                 against ``utils.read_video_rgb``). ``gt_waveform`` holds ``chunk_size`` samples aligned with
                 those frames.
        :raises IndexError: If ``index`` is negative or not smaller than ``len(self)``
        """
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        start_frame = index * self.chunk_size
        end_frame = start_frame + self.chunk_size
        frames, _ = utils.read_video_rgb(self.video_path, self.chunk_size, start_frame)

        # Slice the ground truth at the same offset as the frames so the labels stay aligned with the video
        gt_waveform = self._ground_truth()[start_frame:end_frame]

        # Down-sample the video frames to the desired dimension
        frames = utils.down_sample_video(frames, dim)

        # Create the frames tensor and the ground truth tensor
        frames = torch.tensor(frames,
                              dtype=torch.float32,
                              device=device).permute(0, 3, 1, 2)
        gt_waveform = torch.tensor(gt_waveform,
                                   dtype=torch.float32,
                                   device=device)

        return frames, gt_waveform

## Define loss functions

In [None]:
from torch import nn
from abc import abstractmethod
from tslearn.metrics import SoftDTWLossPyTorch


class _LossFunction(nn.Module):
    def __init__(self):
        super(_LossFunction, self).__init__()

    @abstractmethod
    def name(self):
        pass

    @abstractmethod
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        pass


class PearsonLoss(_LossFunction):
    """Loss defined as one minus the Pearson correlation between the predicted and the true signal."""

    def name(self):
        return 'pearson'

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute ``1 - r``, where ``r`` is the Pearson correlation coefficient of the two signals.
        :param inputs: The predicted signal
        :param targets: The true signal
        :return: The Pearson loss (0 for perfectly correlated, 2 for perfectly anti-correlated signals)
        """

        # Both signals must have the same number of samples along the first dimension
        assert targets.size(0) == inputs.size(0), "Signals must have the same length"

        centered_pred = inputs - inputs.mean()
        centered_true = targets - targets.mean()

        covariance = (centered_pred * centered_true).sum()
        norm_product = centered_pred.pow(2).sum().sqrt() * centered_true.pow(2).sum().sqrt()
        return 1 - covariance / norm_product


class MeanSquaredError(_LossFunction):
    """Mean squared error between the predicted and the true signal, averaged over all elements."""

    def name(self):
        return 'mse'

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the element-wise mean of the squared differences between the two signals.
        :param inputs: The predicted signal
        :param targets: The true signal
        :return: The Mean Squared Error loss
        """
        return nn.functional.mse_loss(inputs, targets)


class SoftDWT(_LossFunction):
    """
    Soft-DTW (soft dynamic time warping) loss between a predicted and a true 1-D signal.

    The identifier returned by ``name`` keeps its historical spelling ('soft_dwt') so that values already
    recorded by earlier runs stay comparable.
    """

    def __init__(self, gamma: float = 0.1):
        """
        :param gamma: Smoothing parameter passed to ``SoftDTWLossPyTorch`` (default 0.1)
        """
        super().__init__()
        self.gamma = gamma
        # Build the criterion once instead of on every forward pass
        self.criterion = SoftDTWLossPyTorch(gamma=gamma)

    def name(self):
        return 'soft_dwt'

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Calculate the Soft-DTW loss between the true and predicted signals using PyTorch.
        :param inputs: The predicted signal, a 1-D tensor
        :param targets: The true signal, a 1-D tensor
        :return: The Soft-DTW loss as a scalar tensor
        """

        # Reshape each signal to (batch_size=1, sequence_length, dimension=1), the layout SoftDTWLossPyTorch takes
        inputs = inputs.reshape(1, inputs.shape[0], 1)
        targets = targets.reshape(1, targets.shape[0], 1)

        # Average the per-batch-element losses into a single scalar
        return self.criterion(inputs, targets).mean()

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Give every run its own log sub-directory (named by model_id). Otherwise TensorBoard treats the event files of
# all runs in the shared directory as one run and mixes their curves.
log_dir = utils.dir_path('outputs', 'logs', model_id, mkdir=True)
writer = SummaryWriter(log_dir=log_dir, filename_suffix=model_id)

In [None]:
learning_rate = 0.001
loss_fn = SoftDWT()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Chunk length in frames: the model expects a multiple of `frame_depth` frames plus one extra frame
num_frames = frame_depth * 20 + 1

# Record the training hyperparameters in the run manifest
manifest.update({
    'learning_rate': learning_rate,
    'loss_fn': loss_fn.name(),
    'optimizer': 'AdamW',
    'chunk_size': num_frames,
})

In [None]:
def train_one_epoch(epoch_index: int):
    """
    Run one training pass over all training scenarios.

    Gradients are accumulated over all chunks of a scenario, and the optimizer takes one step per scenario.
    Each chunk's loss is printed. Each scenario's average chunk loss is logged to TensorBoard.
    Scenarios whose loader yields no complete chunk are skipped.

    :param epoch_index: The epoch number, used as the TensorBoard step
    :return: The mean of the per-scenario average losses, or ``nan`` if no scenario could be trained on
    """
    epoch_loss = 0.0
    trained_scenarios = 0

    # Iterate over the training scenarios
    for (subject, setting) in training:
        loader = ScenarioLoader(subject, setting, num_frames)
        num_chunks = len(loader)

        if num_chunks == 0:
            # Nothing to train on; skipping also avoids dividing by zero when averaging below
            print(f'  {subject}_{setting}: skipped, no complete chunk of {num_frames} frames')
            continue

        # Reset the gradients once per scenario: they are accumulated over all of its chunks
        optimizer.zero_grad()

        scenario_loss = 0.0

        # Iterate over the whole scenario video in chunks
        for idy, (frames, gt_labels) in enumerate(loader):
            # Make predictions for this chunk
            outputs = model(frames).squeeze()

            # The models calculate the difference between the frames.
            # Hence, the output is one element shorter than the input.
            gt_labels = gt_labels[1:]

            # L2-normalize the ground-truth signal
            gt_labels = torch.nn.functional.normalize(gt_labels, dim=0)

            # Compute the loss and accumulate its gradients
            loss = loss_fn(outputs, gt_labels)
            loss.backward()

            # Gather data and report
            chunk_loss = loss.item()
            print(f'  {subject}_{setting}[{idy}]: loss={chunk_loss}')
            scenario_loss += chunk_loss

        # Adjust learning weights using the gradients accumulated over the whole scenario
        optimizer.step()

        scenario_loss /= num_chunks
        epoch_loss += scenario_loss
        trained_scenarios += 1

        print(f'  >> scenario_loss={scenario_loss}')
        writer.add_scalars('Training_Loss', {
            f'{subject}_{setting}': scenario_loss,
        }, epoch_index)
        writer.flush()

    if trained_scenarios == 0:
        return float('nan')
    return epoch_loss / trained_scenarios

In [None]:
epochs = 5
manifest['epochs'] = epochs

best_vloss = 100_000_000.
models = []


def _validate_one_epoch(epoch_index: int) -> float:
    """
    Evaluate the model on all testing scenarios without tracking gradients.

    Each scenario's average chunk loss is logged to TensorBoard. Scenarios whose loader yields no complete
    chunk are skipped.

    :param epoch_index: The epoch number, used as the TensorBoard step
    :return: The mean of the per-scenario average losses, or ``inf`` if no scenario could be evaluated
    """
    running_vloss = 0.0
    validated_scenarios = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for (subject, setting) in testing:
            loader = ScenarioLoader(subject, setting, num_frames)
            num_chunks = len(loader)

            if num_chunks == 0:
                print(f'  {subject}_{setting}: skipped, no complete chunk of {num_frames} frames')
                continue

            testing_loss = 0.0
            for (frames, gt_labels) in loader:
                voutputs = model(frames).squeeze()

                # The output is one element shorter than the input, so drop the first ground-truth sample,
                # then L2-normalize it exactly as during training
                vlabels_batch = torch.nn.functional.normalize(gt_labels[1:], dim=0)

                # Accumulate plain floats rather than device tensors
                testing_loss += loss_fn(voutputs, vlabels_batch).item()

            testing_loss /= num_chunks
            writer.add_scalars(
                'Testing_Loss',
                {
                    f'{subject}_{setting}': testing_loss,
                },
                epoch_index)

            running_vloss += testing_loss
            validated_scenarios += 1

    if validated_scenarios == 0:
        return float('inf')
    return running_vloss / validated_scenarios


for epoch in range(epochs):
    print(f'Epoch {epoch}:')

    # Put the model in training mode and do a pass over the training data
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    # Switch to evaluation mode (affects layers such as dropout / batch normalization, if present)
    model.eval()
    avg_vloss = _validate_one_epoch(epoch)

    print(f'LOSS train {avg_loss} valid {avg_vloss}')
    writer.add_scalars(
        'Average_Loss',
        {
            'Training': avg_loss,
            'Validation': avg_vloss,
        },
        epoch)
    writer.flush()

    # Track the best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_name = f'model_{model_id}_{epoch}.pth'
        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)

        models.append({
            'model': model_name,
            'epoch': epoch,
            'validation_loss': float(avg_vloss),
        })

In [None]:
# Record the training results and the finish time, then persist the manifest next to the saved models
manifest.update({
    'tuned_models': models,
    'best_validation_loss': float(best_vloss),
    'timestamp_finish': datetime.now().astimezone().isoformat(),
})
manifest_path = os.path.join(model_dir, 'manifest.json')
utils.write_json(manifest_path, manifest)