# Fine-tuning EfficientPhys on the VitalCamSet dataset

In [None]:
import torch
import random
import numpy as np

# Seed every random number generator used in this notebook (Python's
# `random`, NumPy and PyTorch) with the same value so runs are reproducible.
seed_value = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed_value)

In [None]:
import os
import respiration.utils as utils
from datetime import datetime

# Initializing in a separate cell, so we can easily add more epochs to the same run.
# Capture the start time once so the run directory name and the manifest
# timestamp describe the same instant (two separate now() calls could differ).
run_started = datetime.now().astimezone()
timestamp = run_started.strftime('%Y%m%d_%H%M%S')

# NOTE(review): 'efficeint' looks like a typo for 'efficient'. It is left unchanged
# because other code may load models from this path; confirm before renaming.
output_dir = os.path.join('..', 'models', 'efficeint_phys_fine_tuned', timestamp)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)

# pretrained_model = os.path.join('..', 'data', 'rPPG-Toolbox', 'UBFC-rPPG_EfficientPhys.pth')
pretrained_model = os.path.join('..', 'data', 'rPPG-Toolbox', 'BP4D_PseudoLabel_EfficientPhys.pth')

device = utils.get_torch_device()

# Run manifest: collects the configuration and results of this run and is
# written to <output_dir>/manifest.json at the end of the notebook.
manifest = {
    'device': str(device),
    'timestamp': run_started.isoformat(),
    'model': 'EfficientPhys',
    'pretrained_model': pretrained_model,
    'dataset': 'VitalCamSet',
}

utils.pretty_print(manifest)

In [None]:
import os
from respiration.extractor.efficient_phys import EfficientPhys

dim = 72
frame_depth = 20

# Some keys in the pretrained checkpoint carry an extra 'module.' prefix, so the
# network is wrapped in nn.DataParallel before the weights are loaded.
model = torch.nn.DataParallel(EfficientPhys(img_size=dim, frame_depth=frame_depth))
model.to(device)

# load_state_dict reports missing/unexpected keys; display it to confirm the match.
key_matching = model.load_state_dict(
    torch.load(pretrained_model, map_location=device)
)
key_matching

## Dataset and dataloader

In [None]:
import respiration.dataset as repository

dataset = repository.from_default()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

# Deterministic split: the first ~80% of the scenarios (in the order returned
# by the repository) are used for training, the rest for testing.
split_ratio = 0.8
split_index = int(len(scenarios_all) * split_ratio)
training = scenarios_all[:split_index]
testing = scenarios_all[split_index:]

manifest.update({
    'split_ratio': split_ratio,
    'training': training,
    'testing': testing,
})

In [None]:
# The model expects a multiple of `frame_depth` (20) frames plus one extra frame:
# it works on frame differences, so its output is one sample shorter than its input.
chunk_size = 75 * frame_depth + 1
manifest['chunk_size'] = chunk_size

In [None]:
from scipy.signal import resample
from torch.utils.data import Dataset


class RespirationDataset(Dataset):
    """
    Map-style dataset yielding one (frames, ground-truth respiration waveform)
    pair per (subject, scenario) recording.

    Uses the notebook globals `chunk_size` (number of frames kept per recording)
    and `dim` (target frame size passed to `utils.down_sample_video`).
    """

    def __init__(self,
                 source: repository.Dataset,
                 scenarios: list[tuple[str, str]],
                 to: torch.device = torch.device('cpu')):
        """
        :param source: The dataset repository the recordings are read from
        :param scenarios: The (subject, scenario) pairs exposed by this dataset
        :param to: The device on which the returned tensors are created
        """
        self.dataset = source
        self.scenarios = scenarios
        self.device = to

    def __len__(self):
        return len(self.scenarios)

    def __getitem__(self, idx):
        """
        Load, align and truncate one recording.

        :param idx: Index into `self.scenarios`
        :return: (frames, gt_waveform) float32 tensors on `self.device`; frames are
                 permuted with (0, 3, 1, 2), presumably (T, H, W, C) -> (T, C, H, W)
        """
        subject, scenario = self.scenarios[idx]
        video, _ = self.dataset.get_video_rgb(subject, scenario, False)
        signal, _ = self.dataset.get_ground_truth_rr_signal(subject, scenario)

        # The ground truth is sampled at a higher rate than the video; resample it
        # to one sample per frame, then keep only the first `chunk_size` samples.
        signal = resample(signal, len(video))[:chunk_size]

        # Keep the same number of frames and shrink them to the model input size.
        video = utils.down_sample_video(video[:chunk_size], dim)

        frames = torch.tensor(video, dtype=torch.float32, device=self.device)
        gt_waveform = torch.tensor(signal, dtype=torch.float32, device=self.device)
        return frames.permute(0, 3, 1, 2), gt_waveform


# Both splits create their tensors directly on the training device.
training_data = RespirationDataset(dataset, training, to=device)
test_data = RespirationDataset(dataset, testing, to=device)

In [None]:
from torch.utils.data import DataLoader

# One recording per batch. Only the training data is shuffled: the averaged
# validation loss does not depend on sample order, so the test loader keeps a
# fixed order (makes per-sample validation inspection reproducible).
training_loader = DataLoader(training_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

In [None]:
# Display image and label.
#train_features, train_labels = next(iter(training_loader))
#print(f"Feature batch shape: {train_features.size()}")
#print(f"Labels batch shape: {train_labels.size()}")
#print(f"Device: {train_features.device}")

## Define loss functions

In [None]:
from abc import abstractmethod
from torch import nn


class _LossFunction(nn.Module):
    """
    Base class for the loss functions used in this notebook.

    Subclasses implement `name` (a short identifier that is recorded in the run
    manifest) and `forward`.

    `nn.Module` is not built on `ABCMeta`, so `@abstractmethod` alone does not stop
    an incomplete subclass from being instantiated. The default implementations
    therefore raise `NotImplementedError` instead of silently returning None.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def name(self) -> str:
        """
        :return: A short identifier of the loss function
        """
        raise NotImplementedError(f'{type(self).__name__} must implement name()')

    @abstractmethod
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between the predicted and the true signal.
        :param inputs: The predicted signal
        :param targets: The true signal
        :return: The loss
        """
        raise NotImplementedError(f'{type(self).__name__} must implement forward()')


class PearsonLoss(_LossFunction):
    """Loss equal to one minus the Pearson correlation coefficient of two signals."""

    def name(self):
        return 'pearson'

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Computes `1 - r`, where `r` is the Pearson correlation between the signals.

        The result is 0 for perfectly correlated signals, 1 for uncorrelated ones and
        2 for perfectly anti-correlated ones. If either signal is constant, the
        denominator is zero and the result is NaN.

        :param inputs: The predicted signal
        :param targets: The true signal
        :return: The Pearson loss
        """
        assert targets.size(0) == inputs.size(0), "Signals must have the same length"

        centered_pred = inputs - inputs.mean()
        centered_true = targets - targets.mean()

        covariance = (centered_pred * centered_true).sum()
        norm_pred = centered_pred.pow(2).sum().sqrt()
        norm_true = centered_true.pow(2).sum().sqrt()
        return 1 - covariance / (norm_pred * norm_true)


class MeanSquaredError(_LossFunction):
    """Mean squared error between a predicted and a true signal."""

    def name(self):
        return 'mse'

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Calculates the Mean Squared Error loss between the true and predicted signals using PyTorch.
        :param inputs: The predicted signal
        :param targets: The true signal
        :return: The mean of the element-wise squared differences (scalar tensor)
        """
        # nn.CrossEntropyLoss would interpret `inputs` as class logits; this is a
        # regression loss, so use the squared error averaged over all elements.
        return nn.functional.mse_loss(inputs, targets)


class SoftDTW(_LossFunction):
    """
    Soft Dynamic Time Warping (soft-DTW) loss between two 1-D signals.

    The hard minimum of classic DTW is replaced by a smooth soft-minimum with
    temperature `gamma`, which makes the alignment cost differentiable.

    NOTE: the dynamic program runs as a Python double loop over all n * m cells.
    For signals of `chunk_size - 1` = 1500 samples this is ~2.25M iterations per
    call, which is impractical for training.
    """

    def __init__(self, gamma=1.0):
        """
        :param gamma: Smoothing temperature of the soft-minimum
        """
        super().__init__()
        self.gamma = gamma

    def name(self):
        return 'soft_dtw'

    def forward(self, x, y):
        """
        :param x: The predicted signal (indexed along dim 0)
        :param y: The true signal (indexed along dim 0)
        :return: The soft-DTW alignment cost (scalar tensor)
        """
        n, m = x.size(0), y.size(0)
        # D[i, j] is the soft alignment cost of x[:i] with y[:j]. The first row and
        # column are +inf (no alignment possible), except D[0, 0] = 0.
        D = torch.zeros((n + 1, m + 1), device=x.device, dtype=torch.float)
        D[0, 1:] = float('inf')
        D[1:, 0] = float('inf')

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                # Squared distance between the two aligned samples.
                cost = (x[i - 1] - y[j - 1]).pow(2)
                # Apply the soft minimum operation
                D[i, j] = cost + self.softmin(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])

        return D[n, m]

    def softmin(self, *args):
        """
        Numerically stable soft-minimum: -gamma * log(sum(exp(-a / gamma))).

        The hard minimum is subtracted before exponentiating and added back afterwards.
        """
        values = torch.stack(args)
        minimum = torch.min(values)
        return -self.gamma * torch.log(
            torch.sum(torch.exp((-values + minimum) / self.gamma))
        ) + minimum

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logs are stored next to the checkpoints of this run.
log_dir = os.path.join(output_dir, 'logs')
writer = SummaryWriter(log_dir=log_dir)

In [None]:
# Loss function and optimizer used for fine-tuning; their identifiers are
# recorded in the manifest.
learning_rate = 0.005
loss_fn = MeanSquaredError()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

manifest.update({
    'learning_rate': learning_rate,
    'loss_fn': loss_fn.name(),
    'optimizer': 'AdamW',
})

In [None]:
def train_one_epoch(epoch_index: int) -> float:
    """
    Runs one training pass over `training_loader`.

    Every sample of a batch is fed to the model separately and gets its own
    optimizer step. The mean loss of each batch is printed and logged to
    TensorBoard as 'Training_Loss'.

    Uses the notebook globals `model`, `optimizer`, `loss_fn`, `training_loader`
    and `writer`.

    :param epoch_index: Zero-based epoch number, used to compute the TensorBoard step
    :return: The mean loss over all training samples of the epoch (0.0 if the loader is empty)
    """
    epoch_loss = 0.0
    epoch_samples = 0

    # enumerate() gives the batch index for intra-epoch reporting.
    for idx, (inputs, labels) in enumerate(training_loader):
        batch_loss = 0.0
        batch_size = inputs.size(0)

        for idy in range(batch_size):
            # Zero the gradients for every sample, since each gets its own step.
            optimizer.zero_grad()

            outputs = model(inputs[idy]).squeeze()

            # The model calculates differences between frames, so its output is one
            # element shorter than the input; drop the first label to align them.
            gt_labels = labels[idy][1:]

            # L2-normalize the ground truth. NOTE(review): the model output is not
            # normalized; confirm this is intended for scale-sensitive losses (MSE).
            gt_labels = torch.nn.functional.normalize(gt_labels, dim=0)

            loss = loss_fn(outputs, gt_labels)
            loss.backward()
            optimizer.step()

            batch_loss += loss.item()

        last_loss = batch_loss / batch_size
        epoch_loss += batch_loss
        epoch_samples += batch_size

        print(f'  batch {idx + 1} loss: {last_loss}')
        tb_x = epoch_index * len(training_loader) + idx
        writer.add_scalar('Training_Loss', last_loss, tb_x)
        writer.flush()

    # Return the epoch-wide mean rather than just the last batch's loss; the caller
    # reports this value as the average training loss next to the validation average.
    return epoch_loss / epoch_samples if epoch_samples else 0.0

In [None]:
EPOCHS = 10
manifest['epochs'] = EPOCHS

best_vloss = 100_000_000.
models = []

# The validation average divides by the number of validation samples. Fail early
# with a clear message instead of an error after the first training epoch.
if len(test_loader) == 0:
    raise RuntimeError('The test split is empty; cannot compute a validation loss.')

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch + 1}:')

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    running_vloss = 0.0
    n_vsamples = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in test_loader:
            for idy in range(vinputs.size(0)):
                voutputs = model(vinputs[idy]).squeeze()

                # Remove the first element from the ground truth signal (the model
                # output is one sample shorter), then normalize it as in training.
                vlabels_batch = vlabels[idy][1:]
                vlabels_batch = torch.nn.functional.normalize(vlabels_batch, dim=0)

                # Accumulate plain floats so averages, logs and comparisons
                # do not carry device tensors around.
                running_vloss += loss_fn(voutputs, vlabels_batch).item()
                n_vsamples += 1

    # Average over validation samples (not batches).
    avg_vloss = running_vloss / n_vsamples
    print(f'LOSS train {avg_loss} valid {avg_vloss}')
    writer.add_scalars(
        'Average_Loss',
        {
            'Training': avg_loss,
            'Validation': avg_vloss,
        },
        epoch)
    writer.flush()

    # Track the best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_name = f'model_{timestamp}_{epoch}.pth'
        model_path = os.path.join(output_dir, model_name)
        torch.save(model.state_dict(), model_path)

        models.append({
            'model': model_name,
            'epoch': epoch,
            'validation_loss': float(avg_vloss),
        })

In [None]:
# Record the outcome of the run and persist the manifest next to the checkpoints.
manifest.update({
    'models': models,
    'best_validation_loss': float(best_vloss),
    'exit': datetime.now().astimezone().isoformat(),
})
manifest_path = os.path.join(output_dir, 'manifest.json')
utils.write_json(manifest_path, manifest)