# Transformer Classifier

This notebook trains a Transformer-based classifier to predict inhaling and exhaling from video frames, using the differences between consecutive frames as model input.

In [None]:
import respiration.utils as utils

from pytz import timezone
from datetime import datetime

# The Berlin-local start time doubles as the unique identifier of this training run
zone = timezone('Europe/Berlin')
model_id = f"{datetime.now().astimezone(zone):%Y%m%d_%H%M%S}"
device = utils.get_torch_device()

# Every piece of metadata about this training run is collected in the manifest
manifest = dict(
    id=model_id,
    device=str(device),
    timestamp_start=datetime.now().astimezone().isoformat(),
    dataset='VitalCamSet',
)
model_id

In [None]:
# Display the torch device chosen by utils.get_torch_device() for this run
device

## Define training and testing scenarios

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

# Unshuffled split: the first split_ratio share of the scenarios is used for
# training, the remainder for testing.
split_ratio = 0.8
manifest['split_ratio'] = split_ratio

split_index = int(len(scenarios_all) * split_ratio)
training, testing = scenarios_all[:split_index], scenarios_all[split_index:]
manifest['training_scenarios'] = training
manifest['testing_scenarios'] = testing

In [None]:
# Clip length and temporal patch size for this run, mirrored into the manifest.
# NOTE(review): frame_patch_size does not appear to be consumed by the model
# defined later in this notebook — confirm whether it is still needed.
num_frames, frame_patch_size = 300, 2
manifest.update(num_frames=num_frames, frame_patch_size=frame_patch_size)

## Test the Temporal Shift Module (TSM)

In [None]:
import torch
import torch.nn as nn


class TSMPPG(nn.Module):
    """
    Temporal Shift Module (TSM) implemented with slice assignment.

    The input ``x`` is a flat stack of ``nt`` frames shaped ``(nt, c, h, w)``.
    It is regrouped into ``nt // n_segment`` clips of ``n_segment`` consecutive
    frames, and within every clip, with ``fold = c // fold_div``:

    * channels ``[0, fold)`` are shifted left: frame ``t`` receives frame
      ``t + 1`` and the last frame is zero-filled;
    * channels ``[fold, 2 * fold)`` are shifted right: frame ``t`` receives
      frame ``t - 1`` and the first frame is zero-filled;
    * the remaining channels are copied unchanged.

    The output has the same shape as ``x``.

    :param n_segment: Number of consecutive frames per clip; ``nt`` must be a
        multiple of it.
    :param fold_div: Divisor that sets the size of each shifted channel group.
    :param verbose: Print intermediate shapes in ``forward`` (debugging aid).
    :raises ValueError: From ``forward`` if ``nt`` is not a multiple of
        ``n_segment``.
    """

    def __init__(self, n_segment=10, fold_div=3, verbose=False):
        super(TSMPPG, self).__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div
        # Off by default so the module does not spam output on every forward
        # pass when used inside a training loop.
        self.verbose = verbose

    def forward(self, x):
        nt, c, h, w = x.size()
        if nt % self.n_segment != 0:
            raise ValueError(
                f'Number of frames ({nt}) must be a multiple of n_segment ({self.n_segment})'
            )
        n_batch = nt // self.n_segment

        # reshape (rather than view) tolerates any input memory layout
        x = x.reshape(n_batch, self.n_segment, c, h, w)
        fold = c // self.fold_div

        if self.verbose:
            print(f'nt={nt} c={c} h={h} w={w} n_batch={n_batch}')
            print(f'x.shape={x.shape}')
            print(f'fold={fold}')

        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shifted

        return out.reshape(nt, c, h, w)

In [None]:
import torch
import torch.nn as nn


class TSMOriginal(nn.Module):
    """
    Temporal Shift Module implemented with split/concatenate.

    The input ``x`` is a flat stack of ``nt`` frames shaped ``(nt, c, h, w)``,
    regrouped into clips of ``n_frame`` consecutive frames. With
    ``fold = c // fold_div``, inside every clip the first ``fold`` channels are
    shifted left (frame ``t`` takes frame ``t + 1``, last frame zero-filled),
    the next ``fold`` channels are shifted right (frame ``t`` takes frame
    ``t - 1``, first frame zero-filled), and the remaining ``c - 2 * fold``
    channels are kept unchanged.

    :param n_frame: Number of consecutive frames per clip; ``nt`` must be a
        multiple of it.
    :param fold_div: Divisor that sets the size of each shifted channel group;
        must be at least 2 so that two groups fit into ``c`` channels.
    :param verbose: Print intermediate shapes in ``forward`` (debugging aid).
    :raises ValueError: If ``fold_div < 2``, or (from ``forward``) if ``nt`` is
        not a multiple of ``n_frame``.
    """

    def __init__(self, n_frame, fold_div=3, verbose=False):
        super(TSMOriginal, self).__init__()
        if fold_div < 2:
            raise ValueError(f'fold_div must be at least 2, got {fold_div}')
        self.n_frame = n_frame
        self.fold_div = fold_div
        self.verbose = verbose

    def forward(self, x):
        nt, c, h, w = x.size()
        if nt % self.n_frame != 0:
            raise ValueError(
                f'Number of frames ({nt}) must be a multiple of n_frame ({self.n_frame})'
            )
        x = x.reshape(nt // self.n_frame, self.n_frame, c, h, w)

        fold = c // self.fold_div
        # Two groups of `fold` channels are shifted and the rest stays put, so
        # the static group has c - 2 * fold channels. This makes the split
        # sizes sum to c for every fold_div, not only for fold_div == 3.
        static_fold = c - 2 * fold
        left_group, right_group, static_group = torch.split(
            x, [fold, fold, static_fold], dim=2)

        if self.verbose:
            print(f'nt={nt} c={c} h={h} w={w}')
            print(f'x.shape={x.shape}')
            print(f'fold={fold} static_fold={static_fold}')
            print(f'channel.shape={left_group.shape}')

        # Shift left: frame t takes frame t + 1; the last frame is zero-filled.
        left_padding = torch.zeros_like(left_group[:, :1])
        shifted_left = torch.cat([left_group[:, 1:], left_padding], dim=1)

        # Shift right: frame t takes frame t - 1; the first frame is zero-filled.
        right_padding = torch.zeros_like(right_group[:, :1])
        shifted_right = torch.cat([right_padding, right_group[:, :-1]], dim=1)

        out = torch.cat([shifted_left, shifted_right, static_group], dim=2)
        return out.reshape(nt, c, h, w)

In [None]:
subject = 'Proband16'
setting = '101_natural_lighting'

# Load a num_frames-long clip starting at frame 200. The permute presumably
# turns (N, H, W, C) into (N, C, H, W) as expected by the TSM — confirm
# against dataset.get_video_rgb.
frames, meta = dataset.get_video_rgb(subject, setting, num_frames, start_position=200)
frames = torch.tensor(frames, dtype=torch.float32, device=device)
frames = frames.permute(0, 3, 1, 2)

# Treat the whole clip as a single segment. Using the actual frame count keeps
# the TSM regrouping valid even if fewer than num_frames frames were returned.
tsm = TSMPPG(frames.shape[0]).to(device)
tsm_frames = tsm(frames)

In [None]:
import matplotlib.pyplot as plt

# Channels-last numpy copy of TSM output frame 10, for matplotlib
test_frame = tsm_frames[10].permute(1, 2, 0).cpu().detach().numpy()

# Create a grid and display all three channels side by side
fig, ax = plt.subplots(1, 3, figsize=(20, 20))

for idx, channel_ax in enumerate(ax):
    channel_ax.imshow(test_frame[:, :, idx], cmap='gray')
    channel_ax.axis('off')

plt.show()

In [None]:
# Show the test frame as one RGB image. The division by 255 presumably maps
# 0-255 pixel values from dataset.get_video_rgb into [0, 1] — TODO confirm.
plt.imshow(test_frame / 255.0)
plt.axis('off')
plt.show()

In [None]:
# Show the diff frame between frame 0 and frame 10 of the clip
diff = frames[0] - frames[10]
diff = diff.permute(1, 2, 0).cpu().detach().numpy()

# Min-max normalize the diff frame into [0, 1]. If both frames are identical
# the range is zero, so only divide when it is positive (avoids NaNs).
diff = diff - diff.min()
diff_range = diff.max()
if diff_range > 0:
    diff = diff / diff_range

plt.imshow(diff)
plt.axis('off')
plt.show()

## Define preprocessing and data loading

In [None]:
import torch


def temporal_shifting_frames(frames: torch.Tensor) -> torch.Tensor:
    """
    Convert a frame stack into min-max normalized consecutive-frame differences.

    :param frames: Tensor whose first dimension indexes ``N`` frames.
    :return: ``N - 1`` difference frames ``frames[i + 1] - frames[i]``, rescaled
        jointly (over the whole stack, not per frame) into ``[0, 1]``. If all
        differences are equal (e.g. a static video) the result is all zeros
        instead of NaN. With fewer than two frames an empty tensor is returned.
    """
    # Plain differences between neighbouring frames
    diff_video = frames[1:] - frames[:-1]

    # min()/max() are undefined on empty tensors
    if diff_video.numel() == 0:
        return diff_video

    # Min-max normalization; skip the division when the range is zero
    diff_video = diff_video - diff_video.min()
    value_range = diff_video.max()
    if value_range > 0:
        diff_video = diff_video / value_range

    return diff_video


def temporal_shifting_signal(time_series: torch.Tensor) -> torch.Tensor:
    """
    Create a binary signal from the time series. The signal is 1 if the value is
    greater than the previous value, and 0 otherwise (flat or falling).

    :param time_series: Tensor whose first dimension is time.
    :return: ``torch.long`` tensor with one entry fewer along the first
        dimension, on the same device as ``time_series``.
    """
    # Differences are invariant to a constant offset, so shifting the signal to
    # be non-negative first is unnecessary (and could only lose precision).
    diff = time_series[1:] - time_series[:-1]

    # Rising -> 1, flat/falling -> 0. Comparing directly keeps the result on the
    # input's device instead of the notebook-global `device`.
    return (diff > 0).to(torch.long)

In [None]:
# Side length (pixels) that frames are resized to, and the stride used to split
# each video into interleaved chunks (every 10th frame per chunk).
image_size, downsample_factor = 224, 10
manifest.update(image_size=image_size, downsample_factor=downsample_factor)

In [None]:
import cv2
import math
import torch
import numpy as np
import respiration.utils as utils

from torchvision import transforms


class ScenarioLoader:
    """
    Data loader for one VitalCamSet scenario (subject + setting).

    The video is split into ``downsample_factor`` interleaved chunks: chunk
    ``index`` holds the frames at positions ``index``, ``index + downsample_factor``,
    ``index + 2 * downsample_factor``, ... Each chunk is returned together with the
    ground-truth breathing signal, sub-sampled with the same offset and stride.

    Relies on the notebook-level ``dataset``, ``device``, ``image_size`` and
    ``downsample_factor``. Iteration state is stored on the instance, so one
    loader should not be iterated by two loops at the same time.
    """
    subject: str
    setting: str

    def __init__(self,
                 subject: str,
                 setting: str):
        self.subject = subject
        self.setting = setting

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

    def __len__(self) -> int:
        # One chunk per sub-sampling offset
        return downsample_factor

    def __iter__(self):
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index >= len(self):
            raise StopIteration
        item = self[self.current_index]
        self.current_index += 1
        return item

    def _read_frames(self, index: int) -> list:
        """Read RGB frames at offset ``index`` with stride ``downsample_factor``."""
        cap = cv2.VideoCapture(self.video_path)
        frames = []
        try:
            for position in range(0, self.total_frames, downsample_factor):
                # Seek to the down-sampled position in the video
                cap.set(cv2.CAP_PROP_POS_FRAMES, position + index)

                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes to BGR
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if reading/decoding raises
            cap.release()
        return frames

    def __getitem__(self, index) -> 'tuple[torch.Tensor, torch.Tensor]':
        """
        Return the frames and the ground truth signal for the given chunk.

        :param index: The index of the chunk, ``0 <= index < len(self)``
        :return: ``(frames, gt_waveform)`` where ``frames`` is a stack of
            ``image_size x image_size`` tensors from ``transforms.ToTensor`` and
            ``gt_waveform`` is the L2-normalized signal sub-sampled like the frames
        :raises IndexError: If ``index`` is outside ``[0, len(self))``
        :raises RuntimeError: If no frame could be read for this chunk
        """
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        frames = self._read_frames(index)
        if not frames:
            raise RuntimeError(f'No frames could be read from {self.video_path} for chunk {index}')

        preprocess = transforms.Compose([
            transforms.ToPILImage(mode='RGB'),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])
        frames = torch.stack([preprocess(frame) for frame in frames], dim=0)
        frames = frames.to(device)

        # Get the ground truth signal for the scenario, L2-normalize the whole
        # signal, then sub-sample it with the same offset/stride as the frames
        gt_waveform = dataset.get_breathing_signal(self.subject, self.setting)
        gt_waveform = torch.tensor(gt_waveform, dtype=torch.float32, device=device)
        gt_waveform = torch.nn.functional.normalize(gt_waveform, dim=0)
        gt_waveform = gt_waveform[index::downsample_factor]

        return frames, gt_waveform

In [None]:
def count_classes(xxx):
    """
    Count how often each class index is predicted.

    :param xxx: Score tensor with classes along dimension 1 (presumably model
        logits).
    :return: ``(classes, counts)`` numpy arrays from ``np.unique(..., return_counts=True)``.
    """
    predicted = xxx.argmax(dim=1).cpu().detach().numpy()
    return np.unique(predicted, return_counts=True)

## Test the ScenarioLoader

In [None]:
loader = ScenarioLoader(subject='Proband16', setting='101_natural_lighting')

# Load the first interleaved chunk and display the shapes of frames and signal
chunk_frames, chunk_signal = loader[0]
(chunk_frames.shape, chunk_signal.shape)

In [None]:
# Difference between the first two frames of the chunk, channels-last for matplotlib
diff = chunk_frames[0] - chunk_frames[1]
diff = diff.permute(1, 2, 0).cpu().detach().numpy()

# Min-max normalize the diff frame into [0, 1]. If both frames are identical
# the range is zero, so only divide when it is positive (avoids NaNs).
diff = diff - diff.min()
diff_range = diff.max()
if diff_range > 0:
    diff = diff / diff_range

plt.imshow(diff)
plt.axis('off')
plt.show()

In [None]:
# Consecutive-frame differences of the chunk, min-max normalized by temporal_shifting_frames
shifted = temporal_shifting_frames(chunk_frames)
# Difference frame 10 as a channels-last numpy array for matplotlib
test_frame = shifted[10].permute(1, 2, 0).cpu().detach().numpy()

# Show the diff frame
plt.imshow(test_frame)
plt.axis('off')
plt.show()

## Model Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logs for this run; the directory presumably resolves to
# outputs/logs/<model_id> and is created because of mkdir=True — confirm in utils.dir_path
log_dir = utils.dir_path('outputs', 'logs', model_id, mkdir=True)
writer = SummaryWriter(log_dir=log_dir)

In [None]:
from vit_pytorch import SimpleViT

# 2 classes: inhale, exhale
# Labels come from temporal_shifting_signal: 1 = signal rising (presumably
# inhale), 0 = flat/falling (presumably exhale) — confirm signal polarity.
num_classes = 2
manifest['num_classes'] = num_classes

# Side length (pixels) of the square image patches fed to the ViT
image_patch_size = 16
manifest['image_patch_size'] = image_patch_size

# Number of transformer blocks actually used by SimpleViT below
depth = 6
manifest['depth'] = depth

heads = 16
manifest['heads'] = heads

mlp_dim = 2048
manifest['mlp_dim'] = mlp_dim

embedding_dim = 512
manifest['embedding_dim'] = embedding_dim

# NOTE(review): spatial_depth and temporal_depth are recorded in the manifest
# but are not passed to SimpleViT (nor used elsewhere in this notebook). The
# model below receives a batch of individual difference frames with one label
# per frame, so the manifest may suggest a spatio-temporal model that is not
# what gets trained here.
spatial_depth = 6
manifest['spatial_depth'] = spatial_depth

temporal_depth = 6
manifest['temporal_depth'] = temporal_depth

model = SimpleViT(
    image_size=image_size,
    patch_size=image_patch_size,
    num_classes=num_classes,
    dim=embedding_dim,
    heads=heads,
    mlp_dim=mlp_dim,
    depth=depth,
).to(device)
manifest['base_model'] = 'SimpleViT'

In [None]:
# Optimization settings; each value is mirrored into the manifest
epochs, learning_rate = 30, 1e-5
manifest.update(epochs=epochs, learning_rate=learning_rate)

# Targets are integer class indices (torch.long), as CrossEntropyLoss expects
loss_fn = torch.nn.CrossEntropyLoss()
manifest['loss_fn'] = 'CrossEntropyLoss'

optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
manifest['optimizer'] = 'AdamW'

In [None]:
def train_one_epoch(epoch_index: int):
    """
    Run one training pass over every scenario in ``training``.

    Each scenario is consumed chunk by chunk via ``ScenarioLoader``. For every
    chunk the frames become normalized frame differences (inputs) and the
    ground-truth signal becomes rising/falling class labels (targets), and the
    model takes one optimizer step on that chunk.

    Per-scenario mean accuracy and mean loss are printed and logged to
    TensorBoard under ``Training_Accuracy`` and ``Training_Loss``.

    :param epoch_index: Epoch number, used as the TensorBoard step.
    :return: Mean training accuracy over all scenarios as a float. Note that
        this is an accuracy, not a loss.
    :raises ValueError: If there are no training scenarios.
    """
    if not training:
        raise ValueError('No training scenarios available; check split_ratio')

    epoch_accuracy = 0.0

    # Iterate over the training scenarios
    for (subject, setting) in training:
        loader = ScenarioLoader(subject, setting)

        scenario_accuracy = 0.0
        scenario_loss = 0.0
        chunk_count = 0

        # Iterate over the whole scenario video in chunks
        for idy, (frames, gt_signal) in enumerate(loader):
            frames = temporal_shifting_frames(frames)
            gt_classes = temporal_shifting_signal(gt_signal)

            # Frames come from the video and labels from the dataset signal, so
            # their lengths may differ: trim both to the shorter one.
            length = min(frames.shape[0], gt_classes.shape[0])
            if length == 0:
                continue
            frames = frames[:length]
            gt_classes = gt_classes[:length]

            # Make predictions for this chunk
            outputs = model(frames)

            # Compute the loss and its gradients, then optimize the model
            loss = loss_fn(outputs, gt_classes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Gather data and report as Python floats (not device tensors)
            loss_value = loss.item()
            accuracy = (outputs.argmax(dim=1) == gt_classes).float().mean().item()
            print(f'  {subject} #{idy:02d} outputs.count={count_classes(outputs)} '
                  f'accuracy={accuracy} loss={loss_value}')
            scenario_accuracy += accuracy
            scenario_loss += loss_value
            chunk_count += 1

        # Average over the chunks that were actually used
        if chunk_count:
            scenario_accuracy /= chunk_count
            scenario_loss /= chunk_count
        epoch_accuracy += scenario_accuracy

        print(f'  >> {subject} accuracy={scenario_accuracy} loss={scenario_loss}')
        writer.add_scalars('Training_Accuracy', {
            f'{subject}_{setting}': scenario_accuracy,
        }, epoch_index)
        writer.add_scalars('Training_Loss', {
            f'{subject}_{setting}': scenario_loss,
        }, epoch_index)
        writer.flush()

    return epoch_accuracy / len(training)

In [None]:
# Directory for checkpoints and manifest.json (presumably models/transformer/<model_id>,
# created because of mkdir=True — confirm in utils.dir_path)
model_dir = utils.dir_path('models', 'transformer', model_id, mkdir=True)

In [None]:
def save_manifest(best_accuracy):
    """
    Write the run manifest to ``manifest.json`` inside ``model_dir``.

    Adds the saved checkpoints (notebook-level ``models`` list), the best
    testing accuracy and the finish timestamp before writing.

    :param best_accuracy: Best testing accuracy so far (float or 0-d tensor).
    """
    # `os` is otherwise only imported in a later cell; importing it here keeps
    # this function usable regardless of cell execution order.
    import os

    manifest['trained_models'] = models
    manifest['best_testing_accuracy'] = float(best_accuracy)
    manifest['timestamp_finish'] = datetime.now().astimezone().isoformat()
    utils.write_json(os.path.join(model_dir, 'manifest.json'), manifest)

In [None]:
import os
from tqdm.auto import tqdm

best_accuracy = 0.0
models = []


def _evaluate_scenario(subject, setting) -> float:
    """
    Return the mean per-chunk accuracy of ``model`` on one testing scenario.

    Uses the same preprocessing as training: frame differences as inputs and
    rising/falling labels from the ground-truth signal, trimmed to a common
    length. Intended to run inside ``torch.no_grad()`` with the model in eval mode.
    """
    loader = ScenarioLoader(subject, setting)
    accuracy_sum = 0.0
    chunk_count = 0
    for frames, gt_signal in loader:
        frames = temporal_shifting_frames(frames)
        gt_classes = temporal_shifting_signal(gt_signal)
        # Frames and labels may differ in length; compare on the common prefix
        length = min(frames.shape[0], gt_classes.shape[0])
        if length == 0:
            continue
        predictions = model(frames[:length]).argmax(dim=1)
        accuracy_sum += (predictions == gt_classes[:length]).float().mean().item()
        chunk_count += 1
    return accuracy_sum / chunk_count if chunk_count else 0.0


# Fail fast instead of dividing by zero after the first training epoch
if not testing:
    raise ValueError('No testing scenarios available; check split_ratio')

for epoch in tqdm(range(epochs)):
    print(f'Epoch {epoch}:')

    # Make sure gradient tracking is on, and do a pass over the data.
    # train_one_epoch returns the mean training accuracy (not a loss).
    model.train(True)
    avg_training_accuracy = float(train_one_epoch(epoch))

    # Set the model to evaluation mode for testing
    model.eval()

    running_accuracy = 0.0
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for subject, setting in testing:
            scenario_accuracy = _evaluate_scenario(subject, setting)
            writer.add_scalars('Testing_Accuracy', {f'{subject}_{setting}': scenario_accuracy}, epoch)
            print(f'  >> {subject} accuracy={scenario_accuracy}')
            running_accuracy += scenario_accuracy

    testing_accuracy = running_accuracy / len(testing)
    print(f'ACCURACY training={avg_training_accuracy} testing={testing_accuracy}')
    writer.add_scalars('Average_Accuracy', {
        'Training': avg_training_accuracy,
        'Testing': testing_accuracy,
    }, epoch)
    writer.flush()

    # Track the best performance, and save the model's state
    if testing_accuracy > best_accuracy:
        best_accuracy = testing_accuracy
        model_name = f'{model_id}_{epoch}.pth'

        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)

        models.append({
            'model': model_name,
            'epoch': epoch,
            'validation_accuracy': float(testing_accuracy),
        })
        save_manifest(testing_accuracy)

In [None]:
# Write the manifest once more after training, with the best testing accuracy seen
save_manifest(best_accuracy)