In [1]:
import hashlib
import os
import tabulate

import numpy as np
import pandas as pd
import scipy.io
import torch

In [2]:
# Load the three MATLAB files. loadmat wraps structs in a 2-D object array,
# so [0, 0] (or [0] for the subject list) unwraps it to positional fields.
a = scipy.io.loadmat("wave2value.mat")["data_store"][0, 0]
b = scipy.io.loadmat("wave2wave.mat")["data_store"][0, 0]
c = scipy.io.loadmat("subject_idx_with_age.mat")["subjects"][0]
# Field 0 -> inputs: cast to float32 and swap the last two axes
# (the shapes printed below are after the swap).
x_a, x_b = (torch.tensor(rec[0]).to(torch.float).transpose(-1, -2) for rec in (a, b))
# Field 1 -> targets: only cast to float32, layout kept as stored.
y_a, y_b = (torch.tensor(rec[1]).to(torch.float) for rec in (a, b))
print(x_a.shape, y_a.shape)
print(x_b.shape, y_b.shape)

torch.Size([20444, 3, 1000]) torch.Size([20444, 2])
torch.Size([33635, 3, 1000]) torch.Size([33635, 1000])


In [None]:
def scalarize(arr, dtype=None):
    """Collapse each element of ``arr`` to a single scalar.

    Every element is converted with ``np.asarray`` and flattened in row-major
    order; its first value is kept, and empty elements become ``np.nan``.

    Args:
        arr: iterable of array-likes (e.g. a MATLAB cell column from loadmat).
        dtype: optional dtype for the resulting array; ``None`` lets numpy infer.

    Returns:
        1-D ``np.ndarray`` with one entry per element of ``arr``.

    NOTE: the NaN placeholder is cast along with everything else, so with
    ``dtype=str`` it shows up as the string "nan" and with ``dtype=int`` the
    cast fails.
    """
    firsts = []
    for item in arr:
        flat = np.asarray(item).ravel()
        firsts.append(flat[0] if flat.size else np.nan)
    # np.array(..., dtype=None) is the same as letting numpy infer the dtype.
    return np.array(firsts, dtype=dtype)
# Sample-level metadata for every waveform in wave2wave.mat: struct fields
# 2-5 hold the measurement id, group label, repeat flag and condition code.
profile = pd.DataFrame({
    "measurement": scalarize(b[2], dtype=str).flatten(),
    "group":       scalarize(b[3], dtype=str).flatten(),
    "repeat":      scalarize(b[4], dtype=bool).flatten(),
    "condition":   scalarize(b[5], dtype=int).flatten(),
})
# Subject id = measurement id up to the first underscore.
profile["subject"] = profile["measurement"].str.split("_").str[0]
# Group-derived flags: every group except "hypertensive" is marked healthy,
# every group except "original" is marked `system` (meaning of the labels
# comes from the dataset — TODO confirm).
profile["health"] = profile["group"].ne("hypertensive")
profile["system"] = profile["group"].ne("original")
# Pulse position within its measurement (0, 1, ..., k-1 in row order), and the
# same position min-max scaled to [0, 1]. Since `pulse` spans 0..k-1 in a group
# of k rows, the scaling is pulse / (k - 1); single-pulse groups get 0.0.
profile["pulse"] = profile.groupby("measurement").cumcount()
profile["pulse_norm"] = (
    profile["pulse"]
    / profile.groupby("measurement")["pulse"].transform("size").sub(1).clip(lower=1)
).round(4)
# arm and age
def scalar(x):
    """Unwrap a one-element array-like into a plain Python scalar.

    Args:
        x: anything ``np.asarray`` accepts (e.g. a loadmat struct field).

    Returns:
        ``arr.item()`` when the array has exactly one element; otherwise the
        ``np.ndarray`` itself (including empty arrays).
    """
    arr = np.asarray(x)
    if arr.size != 1:
        return arr
    return arr.item()
# Per-measurement attributes from subject_idx_with_age.mat: record field 0 is
# the measurement id, field 4 the arm flag, field 5 the age.
measurement = pd.DataFrame(
    [
        {
            "measurement": scalar(rec[0]),
            "arm": bool(scalar(rec[4])),
            "age": int(scalar(rec[5])),
        }
        for rec in c
    ]
)
# Left join keeps every sample row; validate= raises if a measurement id is
# repeated in the lookup table. Ids missing from the table get NaN arm/age.
profile = pd.merge(
    profile,
    measurement[["measurement", "arm", "age"]],
    how="left",
    on="measurement",
    validate="many_to_one",
)
# systole, diastole
def tensor_hash(sample: torch.Tensor) -> str:
    """Fingerprint a tensor by the SHA-1 of its raw element bytes.

    The tensor is made contiguous and moved to CPU, so the bytes are its values
    in row-major order. Neither the shape nor the dtype is hashed: two tensors
    with identical bytes but different shapes produce the same digest.

    Args:
        sample: tensor to fingerprint.

    Returns:
        40-character hexadecimal SHA-1 digest.
    """
    payload = sample.contiguous().cpu().numpy().tobytes()  # type: ignore
    return hashlib.sha1(payload).hexdigest()
# Attach the systole/diastole targets of wave2value.mat (y_a) to the rows of
# `profile` (one row per wave2wave.mat sample) whose waveform is byte-identical.
# Rows are matched by content hash. Nothing guarantees the matched b rows appear
# in the same order as the rows of `a`, so targets are gathered through the
# explicit b -> a index mapping instead of being assigned positionally.
hash_to_index_a = {tensor_hash(x_a[i]): i for i in range(x_a.shape[0])}
overlap_indices_b = []
overlap_pairs = []  # (index_in_b, index_in_a)
for i in range(x_b.shape[0]):
    h = tensor_hash(x_b[i])
    if h in hash_to_index_a:
        overlap_indices_b.append(i)
        overlap_pairs.append((i, hash_to_index_a[h]))
overlap_indices_a = [j for _, j in overlap_pairs]
# Every row of `a` must be matched by exactly one row of `b`. Equal counts plus
# distinct a-indices make the mapping one-to-one, so no target is dropped or
# reused. This also rejects duplicate waveforms inside `a`, which the dict
# above would otherwise collapse silently.
assert len(overlap_indices_b) == x_a.shape[0]
assert len(set(overlap_indices_a)) == x_a.shape[0], (
    "wave2value.mat rows do not map one-to-one onto wave2wave.mat rows"
)
# NOTE(review): .loc is label-based; this relies on `profile` having the default
# RangeIndex produced by the merge above, so b positions equal row labels.
profile.loc[overlap_indices_b, 'systole'] = y_a[overlap_indices_a, 0].numpy()
profile.loc[overlap_indices_b, 'diastole'] = y_a[overlap_indices_a, 1].numpy()
# Samples without a match in wave2value.mat keep NaN.
profile['systole'] = profile['systole'].astype(float).round(4)
profile['diastole'] = profile['diastole'].astype(float).round(4)
# Fix the final column order, from coarsest (subject) to finest (single pulse).
profile = profile[
    ["subject", "group", "health", "system", "age"]                 # subject level
    + ["measurement", "repeat", "arm"]                              # measurement level
    + ["pulse", "pulse_norm", "condition", "systole", "diastole"]   # sample level
]

In [4]:
# Preview the first and last `n` rows as a GitHub-style markdown table. Float
# columns are fixed to 4 decimals (NaN printed as "nan"), and a row of "..."
# separates head and tail.
n = 4
df_fmt = profile.copy()
for col in df_fmt.select_dtypes(include=["float"]).columns:
    df_fmt[col] = [f"{v:.4f}" if pd.notna(v) else "nan" for v in df_fmt[col]]
df_show = pd.concat(
    [
        df_fmt.head(n),
        pd.DataFrame([["..."] * len(df_fmt.columns)], columns=df_fmt.columns, index=["..."]),
        df_fmt.tail(n),
    ],
    axis=0,
)
print(
    tabulate.tabulate(
        df_show,
        headers=df_show.columns,  # type: ignore
        showindex=True,
        tablefmt="github",
    )
)

|       | subject   | group        | health   | system   | age   | measurement   | repeat   | arm   | pulse   | pulse_norm   | condition   | systole   | diastole   |
|-------|-----------|--------------|----------|----------|-------|---------------|----------|-------|---------|--------------|-------------|-----------|------------|
| 0     | S001      | original     | True     | False    | 28    | S001          | False    | False | 0       | 0.0000       | 1           | nan       | nan        |
| 1     | S001      | original     | True     | False    | 28    | S001          | False    | False | 1       | 0.0023       | 1           | nan       | nan        |
| 2     | S001      | original     | True     | False    | 28    | S001          | False    | False | 2       | 0.0045       | 1           | 138.2190  | 92.0960    |
| 3     | S001      | original     | True     | False    | 28    | S001          | False    | False | 3       | 0.0068       | 1           | 139.8688  | 91.0879    |
| ..

In [5]:
# Write everything under data/raw/: the per-sample metadata as CSV, and the
# wave2wave.mat arrays x_b / y_b as float32 .npy files.
os.makedirs("data/raw", exist_ok=True)
profile.to_csv("data/raw/profile.csv", index=False)
np.save("data/raw/x.npy", np.asarray(x_b.numpy(), dtype=np.float32))
np.save("data/raw/y.npy", np.asarray(y_b.numpy(), dtype=np.float32))