# TS-CAN: rPPG Estimation

This notebook demonstrates the use of the TS-CAN model for remote photoplethysmography (rPPG) estimation. The model is based on the paper [Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement](https://arxiv.org/abs/2006.03790) by Xin Liu, Josh Fromm, Shwetak Patel, and Daniel McDuff. The model is implemented in PyTorch.

The rPPG signal is later used to estimate the respiratory rate of the subject.

## Load the TS-CAN model

In [None]:
import torch
import respiration.utils as utils
from respiration.extractor.ts_can import TSCAN

device = torch.device('cpu')

# Side length (pixels) of the square frames the model is built for; the same value
# is passed to the preprocessing step further down.
dim = 72

model = TSCAN(img_size=dim)

# Wrap in DataParallel before loading: presumably the checkpoint was saved from a
# DataParallel model, so its parameter names carry a `module.` prefix.
model = torch.nn.DataParallel(model).to(device)

model_path = utils.file_path('data', 'rPPG-Toolbox', 'BP4D_PseudoLabel_TSCAN.pth')
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model.load_state_dict(torch.load(model_path, map_location=device))

# Unwrap the plain TSCAN module again and switch it to evaluation mode. Newly
# constructed modules are in training mode (load_state_dict does not change that),
# which would keep any dropout / batch-norm layers in training behaviour at inference.
model = model.module.to(device).eval()

## Load the test video

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()

# Recording (subject + recording setting) analysed in this notebook
subject, setting = 'Proband21', '101_natural_lighting'

# RGB frames of the recording and its metadata (meta.fps is used as the sample rate later)
frames, meta = dataset.get_video_rgb(subject, setting, show_progress=True)

# Reference-sensor signals for the same recording, fetched in order: pleth, then thorax/abdomen
pleth, breath = (
    dataset.get_vital_sign(subject, setting, vital_sign)
    for vital_sign in (utils.VitalSigns.pleth, utils.VitalSigns.thorax_abdomen)
)

In [None]:
import respiration.extractor.mtts_can.preprocess as preprocess

# Preprocess the video into two frame stacks (raw and difference frames) at dim x dim
raw, diff = preprocess.preprocess_video_frames(frames, dim)

# TS-CAN consumes frames in groups of `frame_depth`, so drop the trailing remainder
frame_depth = 20
cut_off = frame_depth * (len(raw) // frame_depth)

# Truncate each stack and reorder (T, H, W, C) -> (T, C, H, W)
input_resized, input_normalized = (
    torch.tensor(stack[:cut_off]).permute(0, 3, 1, 2)
    for stack in (raw, diff)
)

print(input_resized.shape, input_normalized.shape)

# Concatenate along the channel axis: difference channels first, raw channels second
frames_chunk = torch.cat((input_normalized, input_resized), dim=1).to(device)
frames_chunk.shape

In [None]:
import matplotlib.pyplot as plt


def _min_max_scale(image: torch.Tensor) -> torch.Tensor:
    """Linearly rescale ``image`` to [0, 1] for display.

    A constant image (max == min) is mapped to zeros instead of producing NaN/inf
    from a division by zero.
    """
    lowest, highest = image.min(), image.max()
    value_range = highest - lowest
    if value_range == 0:
        return torch.zeros_like(image)
    return (image - lowest) / value_range


fig, axs = plt.subplots(1, 2, figsize=(20, 6))

# First frame of each stack, permuted from (C, H, W) to (H, W, C) for imshow
diff_frame = _min_max_scale(input_normalized[0].permute(1, 2, 0))
raw_frame = _min_max_scale(input_resized[0].permute(1, 2, 0))

axs[0].imshow(diff_frame)
axs[0].set_title('Diff channel')

axs[1].imshow(raw_frame)
axs[1].set_title('Raw channel')

plt.show()

In [None]:
# Evaluation mode disables training-only behaviour (e.g. dropout). Calling it here is
# idempotent and keeps this cell correct even if the model was switched to train mode.
model.eval()

with torch.no_grad():
    output = model(frames_chunk)

# No autograd graph is recorded inside no_grad(), so .detach() is unnecessary.
prediction = output.cpu().numpy().squeeze()
prediction.shape

## Compare the predicted rPPG signal with the ground-truth breathing signal

In [None]:
import respiration.analysis as analysis

# Reference breathing signal aligned to the prediction, starting at its second sample.
# NOTE(review): presumably the one-sample offset compensates for the frame-difference
# input dropping a frame -- confirm against the preprocessing.
breath_reference = breath[1:1 + len(prediction)]

comparator = analysis.SignalComparator(
    prediction,
    breath_reference,
    sample_rate=meta.fps,
    # Assuming Hz: 0.08-0.5 Hz corresponds to 4.8-30 cycles per minute
    lowpass=0.08,
    highpass=0.5,
    detrend_tarvainen=False,
    filter_signal=True,
    normalize_signal=True,
)

In [None]:
# Error metrics between the processed prediction and the reference signal
comparator_errors = comparator.errors()
comparator_errors

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(20, 6))

# Both traces are read from the comparator, i.e. presumably after its optional
# filtering / normalisation (filter_signal=True, normalize_signal=True).
ax.plot(comparator.prediction, label='Predicted rPPG signal')
# The comparator's reference is the thorax/abdomen breathing signal (`breath`),
# not the pleth signal.
ax.plot(comparator.ground_truth, label='Ground truth breathing signal')

ax.set_title('rPPG estimation using TS-CAN')
ax.set_xlabel('Frame')
ax.set_ylabel('Signal value')
ax.legend()

plt.show()