In [None]:
from datetime import datetime
import respiration.utils as utils

# Side length (in pixels) of the square frames fed to the model
image_size = 256

# The run is identified by the wall-clock time at which this cell executes
model_id = datetime.now().strftime('%Y%m%d_%H%M%S')
device = utils.get_torch_device()

# Everything worth knowing about this training run is collected in the manifest
manifest = dict(
    id=model_id,
    device=str(device),
    timestamp_start=datetime.now().astimezone().isoformat(),
    base_model='SimpleViT',
    dataset='VitalCamSet',
    image_size=image_size,
)

## Dataset and Dataloader

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()
scenarios_all = dataset.get_scenarios(['101_natural_lighting'])

# Fraction of the scenarios (in dataset order) used for training; the rest is held out for testing
split_ratio = 0.8
split_index = int(len(scenarios_all) * split_ratio)

training, testing = scenarios_all[:split_index], scenarios_all[split_index:]

# Keep the split reproducible by recording it in the manifest
manifest['split_ratio'] = split_ratio
manifest['training_scenarios'] = training
manifest['testing_scenarios'] = testing

In [None]:
import math
import torch
from torchvision import transforms


class ScenarioLoader:
    """
    A chunked data loader for one VitalCamSet scenario (a subject/setting pair).

    Indexing or iterating yields ``(frames, gt_waveform)`` pairs. Every chunk holds ``chunk_size + 1`` consecutive
    video frames, so that differencing neighbouring frames yields ``chunk_size`` model inputs. The ground truth slice
    skips the first sample of the chunk and therefore has one entry per frame difference. Chunks do not overlap.

    The ground truth signal is L2-normalised over its whole length and indexed with frame numbers.
    NOTE(review): this assumes the signal has one sample per video frame; no resampling happens here -- confirm
    against the dataset implementation.

    The class relies on the notebook globals ``dataset``, ``utils``, ``device`` and ``image_size``.
    """

    def __init__(self, subject: str, setting: str, chunk_size: int):
        """
        :param subject: The subject identifier of the scenario
        :param setting: The recording setting of the scenario
        :param chunk_size: The number of frame differences (model inputs) per full chunk
        """
        self.subject = subject
        self.setting = setting
        # N frame differences need N + 1 frames
        self.chunk_size = chunk_size + 1

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

        # Filled on first use, so the signal is read and normalised once instead of once per chunk
        self._gt_waveform = None

    def __len__(self) -> int:
        """
        Return the number of usable chunks. A trailing chunk is only counted when it contains at least two frames,
        because a single frame produces no frame difference (and an empty batch turns the loss into NaN).
        """
        full_chunks, remainder = divmod(self.total_frames, self.chunk_size)
        return full_chunks + (1 if remainder >= 2 else 0)

    def __iter__(self):
        # The loader is its own iterator; iterating restarts from the first chunk
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index >= len(self):
            raise StopIteration

        item = self[self.current_index]
        self.current_index += 1
        return item

    def __getitem__(self, index) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Return the frames and the ground truth signal for the given index
        :param index: The index of the chunk; negative values count from the end
        :return: The frames and the ground truth signal
        :raises IndexError: If the index does not address an existing chunk
        """
        total_chunks = len(self)
        if index < 0:
            index += total_chunks
        if not 0 <= index < total_chunks:
            raise IndexError("Index out of range")

        start = index * self.chunk_size
        size = min(self.chunk_size, self.total_frames - start)
        # Clip the end to the frames actually read, so the ground truth slice never outgrows the frame differences
        end = start + size

        frames = self._load_frames(start, size)
        gt_waveform = self._ground_truth()[1 + start:end]

        return frames, gt_waveform

    def _load_frames(self, start: int, size: int) -> torch.Tensor:
        """Read ``size`` frames starting at frame ``start``, resize them and stack them into one tensor on ``device``."""
        frames, _ = utils.read_video_rgb(self.video_path, size, start)
        preprocess = transforms.Compose([
            transforms.ToPILImage(mode='RGB'),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])
        frames = torch.stack([preprocess(frame) for frame in frames], dim=0)
        return frames.to(device)

    def _ground_truth(self) -> torch.Tensor:
        """Return the L2-normalised ground truth signal of the scenario, loading it on first use."""
        cached = getattr(self, '_gt_waveform', None)
        if cached is None:
            signal = dataset.get_breathing_signal(self.subject, self.setting)
            signal = torch.tensor(signal, dtype=torch.float32, device=device)
            cached = torch.nn.functional.normalize(signal, dim=0)
            self._gt_waveform = cached
        return cached

In [None]:
# Demo loader for a single scenario; the requested chunk size of 128 is stored as 129 frames per chunk
data = ScenarioLoader('Proband16', '101_natural_lighting', 128)
# Number of chunks the loader provides for this video
print(len(data))

In [None]:
# Fetch chunk 27 of the demo scenario and display the shapes of the frames and the ground truth slice
frames, gt = data[27]
frames.shape, gt.shape

In [None]:
# Sanity check: the loader moves both tensors to the selected torch device
frames.device, gt.device

In [None]:
import matplotlib.pyplot as plt

# Draw the ground truth slice of the demo chunk; matplotlib needs the tensor on the CPU
fig, ax = plt.subplots()
ax.plot(gt.cpu())
plt.show()

In [None]:
import cv2

# Reorder the first frame from channels-first to channels-last, the layout OpenCV and matplotlib expect
frame_1 = frames[0].permute(1, 2, 0).cpu().numpy()
print(frame_1.shape)

frame_gray = cv2.cvtColor(frame_1, cv2.COLOR_RGB2GRAY)

# Plot the first frame; without an explicit gray colour map a single-channel image is drawn in false colours
plt.imshow(frame_gray, cmap='gray')
plt.show()

## Define frame preprocessing

In [None]:
def pre_procesing_big_small(frames: torch.Tensor) -> torch.Tensor:
    """
    Turn a stack of frames into standardised, normalised frame differences.

    For every pair of neighbouring frames the difference is divided by the sum of the two frames. The result is then
    standardised with the mean and standard deviation taken over the whole tensor.

    :param frames: The frames, stacked along the first dimension
    :return: A tensor with one entry less along the first dimension than ``frames``
    """
    diff_frames = frames[1:] - frames[:-1]
    sum_frames = frames[1:] + frames[:-1]
    # The small constant avoids a division by zero for black pixels
    inputs = diff_frames / (sum_frames + 1e-7)

    # A static clip has a standard deviation of zero; clamping keeps the result at zero instead of NaN
    std = torch.std(inputs).clamp_min(1e-7)
    return (inputs - torch.mean(inputs)) / std


# Alternative preprocessing, marked as "bad" by the author
def pre_procesing_normalize(frames: torch.Tensor) -> torch.Tensor:
    """
    L2-normalise the differences of neighbouring frames along the first dimension.

    :param frames: The frames, stacked along the first dimension
    :return: The normalised frame differences (one entry less along the first dimension)
    """
    later, earlier = frames[1:], frames[:-1]
    return torch.nn.functional.normalize(later - earlier, dim=0)

In [None]:
# Preprocess the demo chunk and display the resulting shape (differencing removes one entry along the first dimension)
inputs = pre_procesing_big_small(frames)
inputs.shape

In [None]:
# The standardised inputs are centred around zero, but imshow clips float RGB data to [0, 1].
# Rescale the first input to that range for display only; `inputs` itself is left untouched.
preview = inputs[0].permute(1, 2, 0).cpu().numpy()
preview_range = max(float(preview.max() - preview.min()), 1e-7)
preview = (preview - preview.min()) / preview_range

plt.imshow(preview)
plt.show()

## Model Training

In [None]:
from vit_pytorch import SimpleViT

# Vision transformer with a single output per input image
model = SimpleViT(
    image_size=image_size, patch_size=32, num_classes=1,
    dim=1024, depth=6, heads=16, mlp_dim=2048,
)
# Move the parameters to the selected torch device
model = model.to(device)

In [None]:
# Training hyper-parameters
epochs = 20
learning_rate = 0.001
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Record the choices in the manifest
manifest.update({
    'epochs': epochs,
    'learning_rate': learning_rate,
    'loss_fn': 'MSELoss',
    'optimizer': 'AdamW',
})

In [None]:
from tqdm.auto import tqdm

model.train()

for epoch in tqdm(range(epochs)):
    print(f"Epoch {epoch}:")
    epoch_loss = 0

    for (subject, setting) in training:
        loader = ScenarioLoader(subject, setting, 256)

        # Gradients of all chunks of one scenario are accumulated and applied in a single optimizer step,
        # so they are reset once per scenario
        optimizer.zero_grad()

        scenario_loss = 0

        # Iterate over the loader of the current scenario (the demo loader `data` always points at one fixed video)
        for frames, gt in loader:
            inputs = pre_procesing_big_small(frames)
            # Drop only the trailing class dimension; a bare squeeze() would turn a one-sample chunk into a scalar
            prediction = model(inputs).squeeze(-1)
            loss = loss_fn(prediction, gt)
            loss.backward()

            scenario_loss += loss.item()

        # max(..., 1) keeps the average defined for a video that is too short to yield a chunk
        scenario_loss /= max(len(loader), 1)
        epoch_loss += scenario_loss
        print(f"--> scenario_loss: {subject} {scenario_loss}")

        # Adjust learning weights
        optimizer.step()

    epoch_loss /= max(len(training), 1)
    print(f"--> epoch_loss: {epoch_loss}")

# Store the model
# NOTE(review): the manifest is filled throughout the notebook but not written to disk in the visible cells
model_path = utils.dir_path('models', 'transformer', model_id)
torch.save(model.state_dict(), utils.join_paths(model_path, f'{model_id}.pth'))

In [None]:
import numpy as np

model.eval()

testing_loss = 0

# One 1-D numpy array per testing scenario, in the order of `testing`. A list is used because the scenarios
# may differ in length, which a regular 2-D array cannot hold.
predictions = []

# Evaluation needs no gradients: skip the autograd bookkeeping and never call backward()
with torch.no_grad():
    for (subject, setting) in testing:
        loader = ScenarioLoader(subject, setting, 256)

        scenario_loss = 0

        parts = []

        # Iterate over the loader of the current scenario (the demo loader `data` always points at one fixed video)
        for frames, gt in loader:
            inputs = pre_procesing_big_small(frames)
            # Drop only the trailing class dimension so that every part stays one-dimensional
            prediction = model(inputs).squeeze(-1)
            loss = torch.nn.functional.mse_loss(prediction, gt)
            parts.append(prediction)

            scenario_loss += loss.item()

        # max(..., 1) keeps the average defined for a video that is too short to yield a chunk
        scenario_loss /= max(len(loader), 1)
        testing_loss += scenario_loss
        print(f"--> scenario_loss: {subject} {scenario_loss}")

        # Combine the parts into a single numpy array; torch.cat rejects an empty list
        prediction = torch.cat(parts, dim=0) if parts else torch.empty(0)
        predictions.append(prediction.cpu().numpy())

testing_loss /= max(len(testing), 1)
print(f"--> testing_loss: {testing_loss}")

In [None]:
import respiration.analysis as analysis

# The last prediction belongs to the last testing scenario; look the scenario up instead of hard-coding a subject
subject, setting = testing[-1]
prediction = predictions[-1]
gt_signal = dataset.get_breathing_signal(subject, setting)
# Trim the ground truth to the number of predicted samples
gt_signal = gt_signal[:prediction.shape[0]]

# NOTE(review): 30 is presumably the sampling rate of the signals -- confirm against SignalComparator
compare = analysis.SignalComparator(gt_signal, prediction, 30)
compare.errors()

In [None]:
# Display the signal distances reported by the comparator created in the previous cell
compare.signal_distances()

In [None]:
# Overlay ground truth and prediction of the comparator
plt.figure(figsize=(20, 5))
plt.plot(compare.ground_truth, label='Ground Truth')
plt.plot(compare.prediction, label='Prediction')
# The labels only show up once the legend is drawn
plt.legend()
plt.show()

In [None]:
from tqdm.auto import tqdm

# Collect the comparison metrics of every testing scenario
errors = []

progress = tqdm(enumerate(testing), total=len(testing))
for idx, (subject, setting) in progress:
    prediction = predictions[idx]

    # Trim the ground truth to the number of predicted samples
    gt_signal = dataset.get_breathing_signal(subject, setting)
    gt_signal = gt_signal[:prediction.shape[0]]

    compare = analysis.SignalComparator(gt_signal, prediction, 30)
    report = {
        'subject': subject,
        'setting': setting,
        'errors': compare.errors(),
        'distances': compare.signal_distances(),
    }
    errors.append(report)

utils.pretty_print(errors)