# Train Rhythm Former

This notebook trains and fine-tunes the Rhythm Former model on the VitalCam dataset.

In [1]:
manifest = {}

In [2]:
import respiration.utils as utils
from respiration.dataset import (
    VitalCamSet,
    VitalCamLoader,
)

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['303_normalized_face'])

split_ratio = 0.8
manifest['split_ratio'] = split_ratio

training = scenarios_all[:int(len(scenarios_all) * split_ratio)]
manifest['training_scenarios'] = training

testing = scenarios_all[int(len(scenarios_all) * split_ratio):]
manifest['testing_scenarios'] = testing

device = utils.get_torch_device()
manifest['device'] = device

In [3]:
import torchvision.transforms as transforms


def preprocess_frames(frames):
    # Preprocess the frames to be in 128x128 with torch
    transform = transforms.Compose([
        transforms.ToPILImage(mode='RGB'),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Transform each frame
    transformed_frames = torch.stack([
        transform(frame) for frame in frames
    ])

    return transformed_frames.unsqueeze(0).to(device)

In [4]:
model_mmpd = utils.file_path('data', 'rhythm_former', 'MMPD_intra_RhythmFormer.pth')
model_pure = utils.file_path('data', 'rhythm_former', 'PURE_cross_RhythmFormer.pth')
model_ubfc = utils.file_path('data', 'rhythm_former', 'UBFC_cross_RhythmFormer.pth')

models = {
    'RhythmFormer': None,
    'RhythmFormer_MMPD': model_mmpd,
    'RhythmFormer_PURE': model_pure,
    'RhythmFormer_UBFC': model_ubfc,
}
manifest['models'] = models

In [5]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import scipy
from scipy.signal import butter, welch
from scipy.sparse import spdiags


def get_hr(y, sr=30, min=45, max=150):
    p, q = welch(y, sr, nfft=1e5 / sr, nperseg=np.min((len(y) - 1, 256)))
    return p[(p > min / 60) & (p < max / 60)][np.argmax(q[(p > min / 60) & (p < max / 60)])] * 60


def get_psd(y, sr=30, min=45, max=150):
    p, q = welch(y, sr, nfft=1e5 / sr, nperseg=np.min((len(y) - 1, 256)))
    return q[(p > min / 60) & (p < max / 60)]


def normal_sampling(mean, label_k, std):
    return math.exp(-(label_k - mean) ** 2 / (2 * std ** 2)) / (math.sqrt(2 * math.pi) * std)


def kl_loss(inputs, labels):
    criterion = nn.KLDivLoss(reduce=False)
    outputs = torch.log(inputs)
    loss = criterion(outputs, labels)
    #loss = loss.sum()/loss.shape[0]
    loss = loss.sum()
    return loss


def _detrend(input_signal, lambda_value):
    """Detrend PPG signal."""
    signal_length = input_signal.shape[0]
    # observation matrix
    H = np.identity(signal_length)
    ones = np.ones(signal_length)
    minus_twos = -2 * np.ones(signal_length)
    diags_data = np.array([ones, minus_twos, ones])
    diags_index = np.array([0, 1, 2])
    D = spdiags(diags_data, diags_index,
                (signal_length - 2), signal_length).toarray()
    detrended_signal = np.dot(
        (H - np.linalg.inv(H + (lambda_value ** 2) * np.dot(D.T, D))), input_signal)
    return detrended_signal


def calculate_hr(predictions, labels, fs=30, diff_flag=False):
    """Calculate video-level HR and SNR"""
    if diff_flag:  # if the predictions and labels are 1st derivative of PPG signal.
        predictions = _detrend(np.cumsum(predictions), 100)
        labels = _detrend(np.cumsum(labels), 100)
    else:
        predictions = _detrend(predictions, 100)
        labels = _detrend(labels, 100)
    [b, a] = butter(1, [0.75 / fs * 2, 2.5 / fs * 2], btype='bandpass')
    predictions = scipy.signal.filtfilt(b, a, np.double(predictions))
    labels = scipy.signal.filtfilt(b, a, np.double(labels))
    hr_pred = get_hr(predictions, sr=fs)
    hr_label = get_hr(labels, sr=fs)
    return hr_pred, hr_label


def calculate_psd(predictions, labels, fs=30, diff_flag=False):
    """Calculate video-level HR and SNR"""
    if diff_flag:  # if the predictions and labels are 1st derivative of PPG signal.
        predictions = _detrend(np.cumsum(predictions), 100)
        labels = _detrend(np.cumsum(labels), 100)
    else:
        predictions = _detrend(predictions, 100)
        labels = _detrend(labels, 100)
    [b, a] = butter(1, [0.75 / fs * 2, 2.5 / fs * 2], btype='bandpass')
    predictions = scipy.signal.filtfilt(b, a, np.double(predictions))
    labels = scipy.signal.filtfilt(b, a, np.double(labels))
    psd_pred = get_psd(predictions, sr=fs)
    psd_label = get_psd(labels, sr=fs)
    return psd_pred, psd_label


class TorchLossComputer(object):
    @staticmethod
    def compute_complex_absolute_given_k(output, k, N):
        two_pi_n_over_N = Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
        hanning = Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)

        k = k.type(torch.FloatTensor).cuda()
        two_pi_n_over_N = two_pi_n_over_N.cuda()
        hanning = hanning.cuda()

        output = output.view(1, -1) * hanning
        output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
        k = k.view(1, -1, 1)
        two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
        complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
                           + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2

        return complex_absolute

    @staticmethod
    def complex_absolute(output, Fs, bpm_range=None):
        output = output.view(1, -1)

        N = output.size()[1]

        unit_per_hz = Fs / N
        feasible_bpm = bpm_range / 60.0
        k = feasible_bpm / unit_per_hz

        # only calculate feasible PSD range [0.7,4]Hz
        complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)

        return (1.0 / complex_absolute.sum()) * complex_absolute  # Analogous Softmax operator

    @staticmethod
    def cross_entropy_power_spectrum_loss(inputs, target, Fs):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        #pdb.set_trace()

        #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1),  (target.item() - whole_max_idx.item()) ** 2
        return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(
            target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        #pdb.set_trace()
        criterion = FocalLoss(gamma=gamma)

        #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1),  (target.item() - whole_max_idx.item()) ** 2
        return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
        inputs = inputs.view(1, -1)
        bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        return whole_max_idx

    @staticmethod
    def Frequency_loss(inputs, target, diff_flag, Fs, std):
        hr_pred, hr_gt = calculate_hr(inputs.detach().cpu(), target.detach().cpu(), diff_flag=diff_flag, fs=Fs)
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(45, 150, dtype=torch.float).to(torch.device('cuda'))
        ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
        sa = ca / torch.sum(ca)

        target_distribution = [normal_sampling(int(hr_gt), i, std) for i in range(45, 150)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))

        hr_gt = torch.tensor(hr_gt - 45).view(1).type(torch.long).to(torch.device('cuda'))
        return F.cross_entropy(ca, hr_gt), kl_loss(sa, target_distribution)

    @staticmethod
    def HR_loss(inputs, target, diff_flag, Fs, std):
        psd_pred, psd_gt = calculate_psd(inputs.detach().cpu(), target.detach().cpu(), diff_flag=diff_flag, fs=Fs)
        pred_distribution = [normal_sampling(np.argmax(psd_pred), i, std) for i in range(psd_pred.size)]
        pred_distribution = [i if i > 1e-15 else 1e-15 for i in pred_distribution]
        pred_distribution = torch.Tensor(pred_distribution).to(torch.device('cuda'))
        target_distribution = [normal_sampling(np.argmax(psd_gt), i, std) for i in range(psd_gt.size)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
        return kl_loss(pred_distribution, target_distribution)


class Neg_Pearson(nn.Module):  # Pearson range [-1, 1] so if < 0, abs|loss| ; if >0, 1- loss
    def __init__(self):
        super(Neg_Pearson, self).__init__()

    def forward(self, preds, labels):  # all variable operation
        loss = 0
        for i in range(preds.shape[0]):
            sum_x = torch.sum(preds[i])  # x
            sum_y = torch.sum(labels[i])  # y
            sum_xy = torch.sum(preds[i] * labels[i])  # xy
            sum_x2 = torch.sum(torch.pow(preds[i], 2))  # x^2
            sum_y2 = torch.sum(torch.pow(labels[i], 2))  # y^2
            N = preds.shape[1]
            pearson = (N * sum_xy - sum_x * sum_y) / (
                torch.sqrt((N * sum_x2 - torch.pow(sum_x, 2)) * (N * sum_y2 - torch.pow(sum_y, 2))))
            loss += 1 - pearson

        loss = loss / preds.shape[0]
        return loss


class RhythmFormer_Loss(nn.Module):
    def __init__(self):
        super(RhythmFormer_Loss, self).__init__()
        self.criterion_Pearson = Neg_Pearson()

    def forward(self, pred_ppg, labels, epoch, FS, diff_flag):
        loss_time = self.criterion_Pearson(pred_ppg.view(1, -1), labels.view(1, -1))
        loss_CE, loss_distribution_kl = TorchLossComputer.Frequency_loss(
            pred_ppg.squeeze(-1),
            labels.squeeze(-1),
            diff_flag=diff_flag,
            Fs=FS,
            std=3.0)
        loss_hr = TorchLossComputer.HR_loss(
            pred_ppg.squeeze(-1),
            labels.squeeze(-1),
            diff_flag=diff_flag,
            Fs=FS,
            std=3.0)
        if torch.isnan(loss_time):
            loss_time = 0

        loss = 0.2 * loss_time + 1.0 * loss_CE + 1.0 * loss_hr
        return loss

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.stats as stats


class HRHybridLoss(nn.Module):
    def __init__(self, alpha=0.2, beta=1.0, gamma=1.0):
        super(HRHybridLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, prediction, ground_truth):
        # Temporal Loss (negative Pearson correlation)
        l_time = -self.pearson_correlation(prediction, ground_truth)
        print(f'l_time={l_time}')

        # Frequency Loss
        l_freq = self.frequency_loss(prediction, ground_truth)
        print(f'l_freq={l_freq}')

        # HR Distance Loss
        l_hr = self.hr_distance_loss(prediction, ground_truth)
        print(f'l_hr={l_hr}')

        # Combine losses
        # total_loss = self.alpha * l_time + self.beta * l_freq + self.gamma * l_hr
        total_loss = self.beta * l_freq + self.gamma * l_hr

        return total_loss

    def pearson_correlation(self, x, y):
        # Compute Pearson correlation
        x_mean = torch.mean(x)
        y_mean = torch.mean(y)

        num = torch.sum((x - x_mean) * (y - y_mean))
        den = torch.sqrt(torch.sum((x - x_mean) ** 2) * torch.sum((y - y_mean) ** 2))

        return num / den

    def frequency_loss(self, pred_bvp, gt_bvp):
        # Compute PSD
        pred_psd = torch.abs(torch.fft.fft(pred_bvp)) ** 2
        gt_psd = torch.abs(torch.fft.fft(gt_bvp)) ** 2

        # Find max index of ground truth PSD
        gt_max_idx = torch.argmax(gt_psd)

        # Compute cross-entropy
        return F.cross_entropy(pred_psd.unsqueeze(0), gt_max_idx.unsqueeze(0))

    def hr_distance_loss(self, pred_hr, gt_hr):
        # Assuming pred_hr and gt_hr are distributions
        return F.kl_div(pred_hr, gt_hr, reduction='batchmean')

In [11]:
from pytz import timezone
from datetime import datetime

# The timestamp is the unique identifier for this training run
zone = timezone('Europe/Berlin')
timestamp = datetime.now().astimezone(zone).strftime('%Y%m%d_%H%M%S')
manifest['timestamp'] = timestamp

epochs = 50
manifest['epochs'] = epochs

loss_fn = HRHybridLoss()
manifest['loss_fn'] = 'HRHybridLoss'

learning_rate = 0.000001
manifest['learning_rate'] = learning_rate

In [12]:
from tqdm.auto import tqdm
from respiration.extractor.rhythm_former import RhythmFormer

for model_name, model_path in models.items():
    model = RhythmFormer()
    # Fix model loading: Some key have an extra 'module.' prefix
    model = torch.nn.DataParallel(model)
    model.to(device)

    manifest = manifest.copy()
    manifest['model'] = model_name
    manifest['models'] = []

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    manifest['optimizer'] = 'AdamW'

    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device))

    model_dir = utils.dir_path(
        'models',
        'rhythm_former',
        timestamp,
        model_name,
        mkdir=True)
    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        training_loader = VitalCamLoader(training, parts=6, device=device)
        testing_loader = VitalCamLoader(testing, parts=6, device=device)

        train_loss = 0
        for (frames, target) in tqdm(training_loader, desc=f'Training'):
            frames = preprocess_frames(frames)
            target = target.unsqueeze(0)

            # Forward pass
            output = model(frames)
            # print(f'output.shape={output.shape}')

            # Compute the loss
            loss = loss_fn(output, target)
            train_loss += loss.item()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            del frames, target, output, loss

        testing_loss = 0
        model.eval()
        with torch.no_grad():
            for (frames, target) in tqdm(testing_loader, desc=f'Testing'):
                frames = preprocess_frames(frames)
                target = target.unsqueeze(0)

                output = model(frames)
                loss = loss_fn(output, target)
                testing_loss += loss.item()

                del frames, target, output, loss

        # Compute the average loss
        train_loss /= len(training)
        testing_loss /= len(testing)

        if testing_loss < best_loss:
            best_loss = testing_loss
            model_file = utils.file_path(model_dir, f'{model_name}_{epoch}.pth')

            manifest['best_loss'] = best_loss
            manifest['models'].append({
                'epoch': epoch,
                'model_file': model_file,
                'train_loss': train_loss,
                'test_loss': testing_loss,
            })

            torch.save(model.state_dict(), model_file)

        print(f'{model_name}[{epoch + 1}/{epochs}] '
              f'train-loss={train_loss} '
              f'test-loss={testing_loss}')

    # Save the manifest
    manifest_file = utils.file_path(model_dir, 'manifest.json')
    utils.write_json(manifest_file, manifest)

Training:   0%|          | 0/120 [00:00<?, ?it/s]

l_time=0.26790231466293335


RuntimeError: Expected target size [1, 600], got [1]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

xxx = torch.randn(3600)
yyy = torch.randn(3600)

# F.kl_div(xxx.log(), yyy)

normal_sampling(np.argmax(0.3), 0, 3)
# HRHybridLoss()(xxx, yyy)