In [1]:
! pip install -q umap-learn speechbrain

In [2]:
%%writefile loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelDifference(nn.Module):
    def __init__(self, distance_type="l1"):
        super(LabelDifference, self).__init__()
        self.distance_type = distance_type

    def forward(self, labels):
        # labels: [bs, label_dim]
        # output: [bs, bs]
        if self.distance_type == "l1":
            return torch.abs(labels[:, None, :] - labels[None, :, :]).sum(dim=-1)
        else:
            raise ValueError(self.distance_type)


class FeatureSimilarity(nn.Module):
    def __init__(self, similarity_type="l2"):
        super(FeatureSimilarity, self).__init__()
        self.similarity_type = similarity_type

    def forward(self, features):
        # labels: [bs, feat_dim]
        # output: [bs, bs]
        if self.similarity_type == "l2":
            return -(features[:, None, :] - features[None, :, :]).norm(2, dim=-1)
        else:
            raise ValueError(self.similarity_type)


class RnCLoss(nn.Module):
    def __init__(self, temperature=2, label_diff="l1", feature_sim="l2"):
        super(RnCLoss, self).__init__()
        self.t = temperature
        self.label_diff_fn = LabelDifference(label_diff)
        self.feature_sim_fn = FeatureSimilarity(feature_sim)

    def forward(self, features, labels):
        # features: [bs, 2, feat_dim]
        # labels: [bs, label_dim]

        features = torch.cat([features[:, 0], features[:, 1]], dim=0)  # [2bs, feat_dim]
        labels = labels.repeat(2, 1)  # [2bs, label_dim]

        label_diffs = self.label_diff_fn(labels)
        logits = self.feature_sim_fn(features).div(self.t)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits -= logits_max.detach()
        exp_logits = logits.exp()

        n = logits.shape[0]  # n = 2bs

        # remove diagonal
        logits = logits.masked_select((1 - torch.eye(n).to(logits.device)).bool()).view(
            n, n - 1
        )
        exp_logits = exp_logits.masked_select(
            (1 - torch.eye(n).to(logits.device)).bool()
        ).view(n, n - 1)
        label_diffs = label_diffs.masked_select(
            (1 - torch.eye(n).to(logits.device)).bool()
        ).view(n, n - 1)

        loss = 0.0
        for k in range(n - 1):
            pos_logits = logits[:, k]  # 2bs
            pos_label_diffs = label_diffs[:, k]  # 2bs
            neg_mask = (
                label_diffs >= pos_label_diffs.view(-1, 1)
            ).float()  # [2bs, 2bs - 1]
            pos_log_probs = pos_logits - torch.log(
                (neg_mask * exp_logits).sum(dim=-1)
            )  # 2bs
            loss += -(pos_log_probs / (n * (n - 1))).sum()

        return loss


Writing loss.py


In [3]:
%%writefile convnext.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, trunc_normal_


class Block(nn.Module):
    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, 4 * dim
        )  # pointwise conv, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, W) -> (N, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, W, C) -> (N, C, W)
        x = input + self.drop_path(x)
        return x


class ConvNeXt(nn.Module):
    def __init__(
        self,
        in_chans=1,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        head_init_scale=1.0,
    ):
        super().__init__()

        self.downsample_layers = (
            nn.ModuleList()
        )  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = (
            nn.ModuleList()
        )  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean(-1))  # global average pooling, (N, C, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None] * x + self.bias[:, None]
            return x


def convnext_tiny(**kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model


def convnext_small(**kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    return model


def convnext_base(**kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    return model


def convnext_large(**kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    return model


def convnext_xlarge(**kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    return model


Writing convnext.py


In [4]:
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [5]:
import glob
import os

import mne
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

data_dir = "/kaggle/input/ppg-dalia-processed"

def get_dataset(test_idx):
    test_files = [os.path.join(data_dir, f"S{test_idx}.npz")]
    train_files = [
        os.path.join(data_dir, f"S{train_idx}.npz")
        for train_idx in range(1, 16)
        if train_idx != test_idx
    ]

    files = {
        "train": train_files,
        "test": test_files,
    }
    ecgs = {"train": [], "test": []}
    ppgs = {"train": [], "test": []}

    labels = {"train": [], "test": []}

    for split in ["train", "test"]:
        for file in files[split]:
            data = np.load(file)
            ecg = data["ecg"]
            ppg = data["ppg"]
            label = data["label"]
            ecg_resampled = mne.filter.resample(ecg.astype(float), down=700 / 64)

            ecgs[split].append(ecg_resampled)
            ppgs[split].append(ppg)
            labels[split].append(label)

        ecgs[split] = np.concatenate(ecgs[split], axis=0)
        ppgs[split] = np.concatenate(ppgs[split], axis=0)
        labels[split] = np.concatenate(labels[split], axis=0)

    mean, std = labels["train"].mean(), labels["train"].std()


    def normalize_labels(label):
        return (label - mean) / std


    def unnormalize_labels(label):
        return std * label + mean


    class PPG_DaLiA(Dataset):

        def __init__(self, ecgs, ppgs, labels):
            self.ecgs = (ecgs - np.min(ecgs)) / (np.max(ecgs) - np.min(ecgs))
            self.ppgs = (ppgs - np.min(ppgs)) / (np.max(ppgs) - np.min(ppgs))
            self.labels = normalize_labels(labels)

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, index):
            ecg = torch.as_tensor(self.ecgs[index], dtype=torch.float32)
            ppg = torch.as_tensor(self.ppgs[index], dtype=torch.float32)
            label = torch.as_tensor(self.labels[index], dtype=torch.float32)

            ecg = ecg.unsqueeze(0)
            ppg = ppg.unsqueeze(0)
            label = label.unsqueeze(0)

            return ecg, ppg, label


    dataset = {
        "train": PPG_DaLiA(ecgs["train"], ppgs["train"], labels["train"]),
        "test": PPG_DaLiA(ecgs["test"], ppgs["test"], labels["test"]),
    }
    return dataset, unnormalize_labels

In [6]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.aggregation import MeanMetric
from tqdm.auto import tqdm
from speechbrain.nnet.CNN import SincConv

from convnext import ConvNeXt
from loss import RnCLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32

session_wise_mae = []

for test_idx in range(1, 16):
    print(f"=== TEST INDEX - {test_idx} ===")
    dataset, unnormalize_labels = get_dataset(test_idx)
    train_loader = DataLoader(dataset["train"], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset["test"], batch_size=batch_size)
    sincnet = SincConv(
        out_channels=8,
        in_channels=1,
        kernel_size=25,
        stride=2,
        sample_rate=64,
        min_low_hz=0.5,
        min_band_hz=1,
    )

    sincnet.to(device)

    ecg_enc = ConvNeXt(depths=[3, 3, 9, 3], dims=[32, 64, 128, 256], num_classes=1)
    ecg_enc.to(device)

    ppg_enc = ConvNeXt(depths=[3, 3, 9, 3], dims=[32, 64, 128, 256], num_classes=1,in_chans = 8)
    ppg_enc.to(device)

    model_size = 0
    for param in ppg_enc.parameters():
        model_size += param.data.nelement()
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    # Pretraining: Rank-N-Constrast
    epochs = 10
    lr = 3e-4
    optimizer_sincnet = torch.optim.Adam(sincnet.parameters(), lr=lr)
    optimizer_ppg = torch.optim.Adam(ppg_enc.parameters(), lr=lr)
    optimizer_ecg = torch.optim.Adam(ecg_enc.parameters(), lr=lr)
    criterion = RnCLoss(temperature=2, label_diff="l1", feature_sim="l2")

    avg_loss = MeanMetric()
    train_losses = []

    pbar = tqdm(total=epochs*len(train_loader))

    for epoch in range(epochs):
        for ecgs, ppgs, labels in train_loader:
            ecgs = ecgs.to(device)
            ppgs = ppgs.to(device)
            labels = labels.to(device)
            
            ppgs_filtered = sincnet(ppgs.mT).mT
            ecg_feats = ecg_enc.forward_features(ecgs)
            ppg_feats = ppg_enc.forward_features(ppgs_filtered)
            feats = torch.stack([ecg_feats, ppg_feats], dim=1)
            loss = criterion(feats, labels)

            optimizer_ecg.zero_grad()
            optimizer_ppg.zero_grad()
            loss.backward()
            optimizer_ecg.step()
            optimizer_ppg.step()

            train_losses.append(loss.item())
            avg_loss.update(loss.item())
            pbar.update()

        print(f"Epoch={epoch}, Loss={avg_loss.compute():.2f}")
        avg_loss.reset()

    pbar.close()

    # Fine-Tuning

    model = ppg_enc
    epochs = 3
    lr = 3e-4
    optimizer = torch.optim.Adam(model.head.parameters(), lr=lr)

    train_losses = []
    val_mae = []

    for epoch in tqdm(range(epochs)):
        for _, inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                inputs = sincnet(inputs.mT).mT
                feats = model.forward_features(inputs)
            preds = model.head(feats)
            loss = F.mse_loss(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

    # Test
    batched_errors = []
    for _, inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            inputs = sincnet(inputs.mT).mT
            preds = model(inputs)

        errors = F.l1_loss(
            unnormalize_labels(preds), unnormalize_labels(labels), reduction='none',
        )
        batched_errors.append(errors.cpu().numpy())

    batched_errors = np.concatenate(batched_errors, axis=0)
    mae = np.mean(batched_errors)

    print("Test MAE:", round(mae, 2))

    session_wise_mae.append(mae)

=== TEST INDEX - 1 ===
Model params: 2.86 M


  0%|          | 0/18780 [00:00<?, ?it/s]

Epoch=0, Loss=3.05
Epoch=1, Loss=2.61
Epoch=2, Loss=2.50
Epoch=3, Loss=2.43
Epoch=4, Loss=2.37
Epoch=5, Loss=2.31
Epoch=6, Loss=2.26
Epoch=7, Loss=2.20
Epoch=8, Loss=2.15
Epoch=9, Loss=2.10


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 5.69
=== TEST INDEX - 2 ===
Model params: 2.86 M


  0%|          | 0/18940 [00:00<?, ?it/s]

Epoch=0, Loss=3.01
Epoch=1, Loss=2.53
Epoch=2, Loss=2.42
Epoch=3, Loss=2.35
Epoch=4, Loss=2.30
Epoch=5, Loss=2.24
Epoch=6, Loss=2.18
Epoch=7, Loss=2.13
Epoch=8, Loss=2.06
Epoch=9, Loss=2.01


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 4.89
=== TEST INDEX - 3 ===
Model params: 2.86 M


  0%|          | 0/18860 [00:00<?, ?it/s]

Epoch=0, Loss=3.02
Epoch=1, Loss=2.56
Epoch=2, Loss=2.45
Epoch=3, Loss=2.38
Epoch=4, Loss=2.32
Epoch=5, Loss=2.26
Epoch=6, Loss=2.20
Epoch=7, Loss=2.15
Epoch=8, Loss=2.09
Epoch=9, Loss=2.04


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 2.9
=== TEST INDEX - 4 ===
Model params: 2.86 M


  0%|          | 0/18790 [00:00<?, ?it/s]

Epoch=0, Loss=2.96
Epoch=1, Loss=2.55
Epoch=2, Loss=2.45
Epoch=3, Loss=2.38
Epoch=4, Loss=2.33
Epoch=5, Loss=2.27
Epoch=6, Loss=2.21
Epoch=7, Loss=2.16
Epoch=8, Loss=2.11
Epoch=9, Loss=2.05


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 7.12
=== TEST INDEX - 5 ===
Model params: 2.86 M


  0%|          | 0/18770 [00:00<?, ?it/s]

Epoch=0, Loss=2.97
Epoch=1, Loss=2.60
Epoch=2, Loss=2.52
Epoch=3, Loss=2.45
Epoch=4, Loss=2.38
Epoch=5, Loss=2.33
Epoch=6, Loss=2.27
Epoch=7, Loss=2.22
Epoch=8, Loss=2.16
Epoch=9, Loss=2.10


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 13.57
=== TEST INDEX - 6 ===
Model params: 2.86 M


  0%|          | 0/19400 [00:00<?, ?it/s]

Epoch=0, Loss=2.97
Epoch=1, Loss=2.59
Epoch=2, Loss=2.50
Epoch=3, Loss=2.44
Epoch=4, Loss=2.38
Epoch=5, Loss=2.34
Epoch=6, Loss=2.27
Epoch=7, Loss=2.22
Epoch=8, Loss=2.18
Epoch=9, Loss=2.12


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 4.43
=== TEST INDEX - 7 ===
Model params: 2.86 M


  0%|          | 0/18760 [00:00<?, ?it/s]

Epoch=0, Loss=2.97
Epoch=1, Loss=2.59
Epoch=2, Loss=2.52
Epoch=3, Loss=2.47
Epoch=4, Loss=2.42
Epoch=5, Loss=2.38
Epoch=6, Loss=2.33
Epoch=7, Loss=2.27
Epoch=8, Loss=2.22
Epoch=9, Loss=2.16


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 2.87
=== TEST INDEX - 8 ===
Model params: 2.86 M


  0%|          | 0/18960 [00:00<?, ?it/s]

Epoch=0, Loss=3.04
Epoch=1, Loss=2.54
Epoch=2, Loss=2.44
Epoch=3, Loss=2.38
Epoch=4, Loss=2.33
Epoch=5, Loss=2.26
Epoch=6, Loss=2.20
Epoch=7, Loss=2.15
Epoch=8, Loss=2.09
Epoch=9, Loss=2.04


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 11.25
=== TEST INDEX - 9 ===
Model params: 2.86 M


  0%|          | 0/18890 [00:00<?, ?it/s]

Epoch=0, Loss=2.93
Epoch=1, Loss=2.52
Epoch=2, Loss=2.43
Epoch=3, Loss=2.37
Epoch=4, Loss=2.32
Epoch=5, Loss=2.27
Epoch=6, Loss=2.22
Epoch=7, Loss=2.17
Epoch=8, Loss=2.12
Epoch=9, Loss=2.06


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 11.12
=== TEST INDEX - 10 ===
Model params: 2.86 M


  0%|          | 0/18560 [00:00<?, ?it/s]

Epoch=0, Loss=2.98
Epoch=1, Loss=2.55
Epoch=2, Loss=2.45
Epoch=3, Loss=2.38
Epoch=4, Loss=2.32
Epoch=5, Loss=2.27
Epoch=6, Loss=2.22
Epoch=7, Loss=2.16
Epoch=8, Loss=2.11
Epoch=9, Loss=2.06


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 3.85
=== TEST INDEX - 11 ===
Model params: 2.86 M


  0%|          | 0/18810 [00:00<?, ?it/s]

Epoch=0, Loss=2.94
Epoch=1, Loss=2.58
Epoch=2, Loss=2.48
Epoch=3, Loss=2.42
Epoch=4, Loss=2.36
Epoch=5, Loss=2.31
Epoch=6, Loss=2.25
Epoch=7, Loss=2.19
Epoch=8, Loss=2.16
Epoch=9, Loss=2.10


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 6.12
=== TEST INDEX - 12 ===
Model params: 2.86 M


  0%|          | 0/18990 [00:00<?, ?it/s]

Epoch=0, Loss=2.94
Epoch=1, Loss=2.56
Epoch=2, Loss=2.46
Epoch=3, Loss=2.38
Epoch=4, Loss=2.33
Epoch=5, Loss=2.26
Epoch=6, Loss=2.22
Epoch=7, Loss=2.16
Epoch=8, Loss=2.12
Epoch=9, Loss=2.07


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 6.42
=== TEST INDEX - 13 ===
Model params: 2.86 M


  0%|          | 0/18800 [00:00<?, ?it/s]

Epoch=0, Loss=2.96
Epoch=1, Loss=2.59
Epoch=2, Loss=2.50
Epoch=3, Loss=2.44
Epoch=4, Loss=2.40
Epoch=5, Loss=2.35
Epoch=6, Loss=2.29
Epoch=7, Loss=2.24
Epoch=8, Loss=2.18
Epoch=9, Loss=2.13


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 2.75
=== TEST INDEX - 14 ===
Model params: 2.86 M


  0%|          | 0/18820 [00:00<?, ?it/s]

Epoch=0, Loss=2.98
Epoch=1, Loss=2.56
Epoch=2, Loss=2.47
Epoch=3, Loss=2.41
Epoch=4, Loss=2.37
Epoch=5, Loss=2.31
Epoch=6, Loss=2.26
Epoch=7, Loss=2.21
Epoch=8, Loss=2.15
Epoch=9, Loss=2.11


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 3.7
=== TEST INDEX - 15 ===
Model params: 2.86 M


  0%|          | 0/18980 [00:00<?, ?it/s]

Epoch=0, Loss=3.01
Epoch=1, Loss=2.56
Epoch=2, Loss=2.48
Epoch=3, Loss=2.42
Epoch=4, Loss=2.37
Epoch=5, Loss=2.31
Epoch=6, Loss=2.26
Epoch=7, Loss=2.21
Epoch=8, Loss=2.16
Epoch=9, Loss=2.11


  0%|          | 0/3 [00:00<?, ?it/s]

Test MAE: 3.73


In [7]:
mean_mae = np.mean(session_wise_mae)
std_mae = np.std(session_wise_mae)

print(f"MAE [bpm, session-wise]: {mean_mae:.2f} ± {std_mae:.2f}")

MAE [bpm, session-wise]: 6.03 ± 3.28
