In [None]:
import gc
import gzip
import json
import os
import pickle
import re
import sys
import time
from copy import deepcopy
from pathlib import Path
from typing import Dict, Literal, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch
import wfdb
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: F401
from torch_ecg.cfg import CFG, DEFAULTS
from torch_ecg.utils.misc import str2bool
from torch_ecg.utils.utils_nn import default_collate_fn as collate_fn
from tqdm.auto import tqdm

from cfg import ModelCfg, TrainCfg
from data_reader import CODE15, PTBXL, SamiTrop
from dataset import CINC2025Dataset
from helper_code import find_records, load_label
from models import CRNN_CINC2025, FM_CINC2025
from outputs import CINC2025Outputs
from trainer import CINC2025Trainer
from utils.samplers import BalancedBatchSampler

# Optional: prepend (presumably local development) checkouts of torch_ecg / bib_lookup to sys.path.
# NOTE(review): these lines run *after* the imports above, so uncommenting them here will not
# change the already-imported torch_ecg; they would need to precede the imports to take effect.
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/torch_ecg/")
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/bib_lookup/")

# Automatically reload edited modules (cfg, models, trainer, ...) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the challenge data. Set the CINC2025_DB_DIR environment variable to
# point elsewhere instead of editing the hard-coded default below.
db_dir = Path(os.environ.get("CINC2025_DB_DIR", "/home/wenh06/Jupyter/Hot-data/cinc2025/"))
# db_dir = Path("/home/wenh06/Jupyter/Hot-data/cinc2025-test/")

# Use the first GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Train models

In [None]:
TEST_FLAG = False

# Match torch's global default floating-point dtype (and the numpy DTYPE used alongside it)
# to the model config. `torch.set_default_tensor_type` is deprecated since PyTorch 2.1;
# `torch.set_default_dtype` is its documented replacement for the dtype part.
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off the `__DEBUG__` class flags of the project classes (presumably verbose debug output).
CINC2025Dataset.__DEBUG__ = False
CRNN_CINC2025.__DEBUG__ = False
FM_CINC2025.__DEBUG__ = False
CINC2025Trainer.__DEBUG__ = False

In [None]:
# Work on a copy so the module-level TrainCfg default stays untouched.
train_config = deepcopy(TrainCfg)
train_config.db_dir = db_dir
train_config.debug = True

# Run length and batch size.
train_config.n_epochs = 20
train_config.batch_size = 192  # 16G (Tesla T4)
# train_config.log_step = 20

# Learning rates -- CRNN only; comment these out when training FM_CINC2025.
# Other base LRs noted for this setup: 5e-4, 1e-3.
train_config.learning_rate = train_config.lr = 3e-4
train_config.max_lr = 9e-4

# One-cycle LR schedule; early stopping patience is a third of the epoch budget.
train_config.lr_scheduler = "one_cycle"
train_config.early_stopping.patience = train_config.n_epochs // 3

# Extra experiment: fraction of the data to use (other values noted: 0.10, 0.5, 1.0).
train_config.extra_experiment = True
train_config.subsample = 0.01

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)
# model_config

In [None]:
# change model architecture
# Inspect the CRNN's CNN backbone config: its available keys and the selected backbone name.
print(model_config.crnn.cnn.keys())
print(f"{model_config.crnn.cnn.name=}")
# To switch backbone, set the name here before the model is built from `model_config.crnn`, e.g.:
# model_config.crnn.cnn.name = "tresnetM"

In [None]:
# Build the CRNN with the demographic encoder in "film" mode ("concat" is the other option)
# and move it straight to the target device.
model_config.crnn.dem_encoder.mode = "film"  # concat
model = CRNN_CINC2025(config=model_config.crnn).to(device=DEVICE)

# Alternative: foundation-model backbone.
# model_config.fm.dem_encoder.mode = "film"  # concat
# model_config.fm.name = "st-mem"
# model_config.fm.backbone_cache_dir = "/home/wenh06/Jupyter/models/ST-MEM/st_mem_vit_base_encoder.pth"
# model = FM_CINC2025(config=model_config.fm).to(device=DEVICE)

# Report size and placement; a DataParallel wrapper exposes the model attributes on `.module`.
if not isinstance(model, DP):
    print("model size:", model.module_size, model.module_size_)
    print("Using device:", model.device)
else:
    print("model size:", model.module.module_size, model.module.module_size_)
    print("Using devices:", model.device_ids)

In [None]:
# Lazily loaded training and validation splits built from the same config.
ds_train = CINC2025Dataset(train_config, training=True, lazy=True)
ds_val = CINC2025Dataset(train_config, training=False, lazy=True)

# Foundation-model backbones register their own sampling rate and input length (keyed by
# backbone name), so both datasets are re-targeted to those values.
if isinstance(model, FM_CINC2025):
    print("Using FM_CINC2025 model, adjusting fs and input_len")
    for _ds in (ds_train, ds_val):
        _ds.reset_resample_fs(model_config.fm.fs[model_config.fm.name], reload=False)
        _ds.reset_input_len(model_config.fm.input_len[model_config.fm.name], reload=False)

In [None]:
# Trainer for the model above; created lazily (dataloaders are set up explicitly in the next cell).
trainer = CINC2025Trainer(
    model=model, device=DEVICE, lazy=True,
    model_config=model_config, train_config=train_config,
)

In [None]:
# Build the train/val dataloaders explicitly (the trainer was constructed with lazy=True).
trainer._setup_dataloaders(ds_train, ds_val)
# Alternative (presumably for quick debugging): train on the validation split only.
# trainer._setup_dataloaders(ds_val, None)

In [None]:
# Run the training loop. The return value is presumably the state dict of the best
# checkpoint -- confirm in CINC2025Trainer.train.
best_model_state_dict = trainer.train()

In [None]:
# Flush any buffered log output, then close the trainer's log manager.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Drop the training objects and hand cached GPU memory back to the driver.
# `gc.collect()` first frees objects kept alive only by reference cycles; tensors they
# still hold would otherwise be "in use" and `empty_cache` could not release them.
del trainer, model, best_model_state_dict, ds_train, ds_val
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Hard-coded score values, presumably copied from an earlier run for reference.
# NOTE(review): which model/metric/split these belong to is not recorded -- TODO document.
scores = [
    0.44057780695994747,
    0.4556795797767564,
    0.45502298095863425,
    0.45436638214051217,
    0.45370978332239004,
    0.44911359159553516,
]

## Evaluate

In [None]:
from torch_ecg._preprocessors import PreprocManager
from torch_ecg.utils import make_serializable

from helper_code import compute_accuracy, compute_auc, compute_challenge_score, compute_f_measure
from team_code import run_model

In [None]:
def load_model(
    model_path: Union[str, bytes, os.PathLike], model_arch: Literal["FM", "CRNN"], verbose: bool = True
) -> Dict[str, Union[dict, torch.nn.Module, PreprocManager]]:
    """Load a trained model together with its training config and preprocessor.

    Parameters
    ----------
    model_path : `path_like`
        The path to the trained model checkpoint (``str``, ``bytes`` or ``os.PathLike``).
    model_arch : {"FM", "CRNN"}
        Model architecture of the checkpoint.
    verbose : bool, default True
        Whether to print progress information.

    Returns
    -------
    model : Dict[str, Union[dict, nn.Module, PreprocManager]]
        ``{"model": ..., "train_config": ..., "preprocessor": ...}``: the model moved to
        ``DEVICE`` and switched to eval mode; its training configuration, with ``fs`` and
        ``resample.fs`` set to the model's sampling frequency; and the preprocessor
        manager built from that configuration.

    Raises
    ------
    ValueError
        If `model_arch` is not ``"FM"`` or ``"CRNN"``, or the checkpoint yields an
        unsupported model class.

    """
    model_classes = {"CRNN": CRNN_CINC2025, "FM": FM_CINC2025}
    if model_arch not in model_classes:
        # previously any unknown value silently fell through to FM_CINC2025
        raise ValueError(f"Unsupported model architecture {model_arch!r}, expected one of {sorted(model_classes)}.")

    # `os.fsdecode` accepts str, bytes and PathLike; `Path(bytes)` would raise TypeError.
    model_path = Path(os.fsdecode(model_path)).expanduser().resolve()

    if verbose:
        print("Loading the trained model...")

    model, train_config = model_classes[model_arch].from_checkpoint(model_path)
    model.to(DEVICE)
    # Inference only: eval mode makes layers such as dropout / batch-norm deterministic.
    model.eval()

    # Keep the preprocessing sampling frequency in sync with the model;
    # FM configs store one `fs` per backbone name.
    if isinstance(model, CRNN_CINC2025):
        if verbose:
            print("Using CRNN_CINC2025 model.")
        model_fs = model.config.fs
    elif isinstance(model, FM_CINC2025):
        if verbose:
            print("Using FM_CINC2025 model.")
        model_fs = model.config.fs[model.config.name]
    else:
        raise ValueError("Unsupported model class.")
    train_config.fs = model_fs
    train_config.resample.fs = model_fs

    # Preprocessing pipeline derived from the training config; `random=False` presumably
    # disables randomized steps for inference -- confirm in PreprocManager.
    ppm_config = CFG(random=False)
    ppm_config.update(deepcopy(train_config))
    ppm = PreprocManager.from_config(ppm_config)

    if verbose:
        print(f"Chagas classification model loaded from {model_path}")

    return {"model": model, "train_config": train_config, "preprocessor": ppm}

In [None]:
# Best checkpoints of the extra (sub-sampling) experiments, as a sorted list of path strings.
ext_expr_dir = Path("./saved_models/extra-experiments/").resolve()
ext_expr_models = sorted(map(str, ext_expr_dir.glob("BestModel*.pth.tar")))
print(f"Found {len(ext_expr_models)} models")

In [None]:
# Fresh copy of the default training config for the evaluation section, on the same data root.
train_config = deepcopy(TrainCfg)
train_config.extra_experiment = True
train_config.debug = True
train_config.db_dir = db_dir

In [None]:
# Assemble the held-out test set from the fixed 64/16/20 data splits stored under `utils/`
# (instead of `CINC2025Dataset(train_config, training=False, lazy=True, part="test")`).
tqdm.pandas(dynamic_ncols=True)

PROJECT_DIR = "./"
SPLIT_DIR = Path(PROJECT_DIR) / "utils"


def _load_split(filename: str) -> dict:
    """Load one gzip-compressed JSON data-split file from ``SPLIT_DIR``."""
    with gzip.open(SPLIT_DIR / filename, "rt") as f:
        return json.load(f)


code_15_data_split = _load_split("code-15-data-split-64-16-20.json.gz")
ptb_xl_data_split = _load_split("ptb-xl-data-split-64-16-20.json.gz")
sami_trop_data_split = _load_split("sami-trop-data-split-64-16-20.json.gz")

# PTB-XL split entries are base record names; the matching records here carry an "_hr" suffix.
test_records = (
    code_15_data_split["test"]
    + [item + "_hr" for item in ptb_xl_data_split["test"]]
    + sami_trop_data_split["test"]
)

# Index every record under `db_dir` by the last component of its path, then select the test records.
all_records = find_records(db_dir)

df_test = pd.DataFrame(all_records, columns=["path"])
df_test["path"] = df_test["path"].progress_apply(lambda s: db_dir / s)
df_test["record"] = df_test["path"].progress_apply(lambda s: s.name)
df_test = df_test.set_index("record").loc[test_records].copy()
# With duplicate record names across databases `.loc` would return several rows per name,
# silently misaligning the test set with `test_records`.
assert len(df_test) == len(test_records), "record names are not unique across databases"
df_test["chagas"] = df_test["path"].progress_apply(lambda s: load_label(str(s)))

# Pin labels to an integer array; the return type of `load_label` is not fixed here.
labels = df_test["chagas"].astype(int).to_numpy()

print(f"Found {len(df_test)} test samples")

In [None]:
# Evaluate the extra-experiment checkpoints on the held-out test set.
# Checkpoint file names encode the training subsample ratio and the architecture,
# e.g. "...subsample-10%-CRNN_..." -> ("10%", "CRNN").
pattern = r"subsample-(\d+%)-([A-Z]+)_"
eval_results = []

# NOTE: only the first third of the sorted checkpoints is evaluated here (presumably the
# remaining ones are evaluated in separate runs; the results file below ends in "-1").
models_to_eval = ext_expr_models[: len(ext_expr_models) // 3]

for model_path in tqdm(models_to_eval, desc="Evaluating models", dynamic_ncols=True):
    filename = Path(model_path).name
    match = re.search(pattern, filename)
    if match is None:
        # previously skipped silently
        print(f"Skipping {filename}: name does not match {pattern!r}")
        continue

    subsample_ratio, model_arch = match.groups()
    print(f"{subsample_ratio=}, {model_arch=}")
    model = load_model(model_path, model_arch)

    # Explicit dtypes rather than inheriting whatever dtype `labels` has (bool, object, ...).
    binary_outputs = np.zeros(len(labels), dtype=int)
    probability_outputs = np.zeros(len(labels), dtype=float)
    for idx, row in tqdm(
        enumerate(df_test.itertuples(index=False)), total=len(df_test), dynamic_ncols=True, desc="Evaluating test samples"
    ):
        binary_output, probability_output = run_model(str(row.path), model, verbose=False)
        binary_outputs[idx] = int(binary_output)
        probability_outputs[idx] = probability_output

    challenge_score = compute_challenge_score(labels, probability_outputs)
    auroc, auprc = compute_auc(labels, probability_outputs)
    accuracy = compute_accuracy(labels, binary_outputs)
    f_measure = compute_f_measure(labels, binary_outputs)

    eval_results.append(
        make_serializable(
            {
                "challenge_score": challenge_score,
                "auroc": auroc,
                "auprc": auprc,
                "accuracy": accuracy,
                "f_measure": f_measure,
                "subsample_ratio": subsample_ratio,
                "model_arch": model_arch,
            }
        )
    )

In [None]:
import json

In [None]:
# Persist this run's evaluation metrics as JSON (the cell displays the number of characters written).
eval_results_path = Path("./saved_models/ext-expr-eval-results-1.json")
eval_results_path.write_text(json.dumps(eval_results))

## Analyze and visualize results

In [None]:
# Per-run training logs, keyed by a short backbone label.
df_results = {}

In [None]:
# Per-epoch logs of the CRNN backbone comparison runs (label -> CSV path), loaded in order.
_result_csvs = {
    "resnet-nc": "./results/TorchECG_02-21_17-38_CRNN_CINC2025_resnet_nature_comm_bottle_neck_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "resnet-nc-se": "./results/TorchECG_02-22_01-33_CRNN_CINC2025_resnet_nature_comm_bottle_neck_se_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-m": "./results/TorchECG_02-22_09-09_CRNN_CINC2025_tresnetM_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-n": "./results/TorchECG_02-22_01-38_CRNN_CINC2025_tresnetN_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-f": "./results/TorchECG_02-22_09-13_CRNN_CINC2025_tresnetF_adamw_amsgrad_LR_0.0001_BS_128.csv",
}
for _label, _csv_path in _result_csvs.items():
    df_results[_label] = pd.read_csv(_csv_path)

In [None]:
# Train-set challenge score per epoch, one curve per run.
part = "train"
metric = "challenge_score"

# Plot through pandas (which drives matplotlib itself): `plt` is not imported in this notebook.
# The Axes returned by the first call is reused so all runs share one figure.
ax = None
for k, df in df_results.items():
    df_metric = df.loc[df.part == part, [metric, "epoch"]].dropna()
    if df_metric.empty:
        continue
    ax = df_metric.plot(x="epoch", y=metric, label=k, ax=ax, figsize=(16, 6))
if ax is not None:
    ax.set(title=f"{metric} ({part})", xlabel="epoch", ylabel=metric)
    ax.legend(loc="best")

In [None]:
# Validation-set challenge score per epoch, one curve per run.
part = "val"
metric = "challenge_score"

# Plot through pandas (which drives matplotlib itself): `plt` is not imported in this notebook.
# The Axes returned by the first call is reused so all runs share one figure.
ax = None
for k, df in df_results.items():
    df_metric = df.loc[df.part == part, [metric, "epoch"]].dropna()
    if df_metric.empty:
        continue
    ax = df_metric.plot(x="epoch", y=metric, label=k, ax=ax, figsize=(16, 6))
if ax is not None:
    ax.set(title=f"{metric} ({part})", xlabel="epoch", ylabel=metric)
    ax.legend(loc="best")

In [None]:
# Train-set Chagas F-measure per epoch, one curve per run.
part = "train"
metric = "chagas_f_measure"

# Plot through pandas (which drives matplotlib itself): `plt` is not imported in this notebook.
# The Axes returned by the first call is reused so all runs share one figure.
ax = None
for k, df in df_results.items():
    df_metric = df.loc[df.part == part, [metric, "epoch"]].dropna()
    if df_metric.empty:
        continue
    ax = df_metric.plot(x="epoch", y=metric, label=k, ax=ax, figsize=(16, 6))
if ax is not None:
    ax.set(title=f"{metric} ({part})", xlabel="epoch", ylabel=metric)
    ax.legend(loc="best")

In [None]:
# Validation-set Chagas F-measure per epoch, one curve per run.
part = "val"
metric = "chagas_f_measure"

# Plot through pandas (which drives matplotlib itself): `plt` is not imported in this notebook.
# The Axes returned by the first call is reused so all runs share one figure.
ax = None
for k, df in df_results.items():
    df_metric = df.loc[df.part == part, [metric, "epoch"]].dropna()
    if df_metric.empty:
        continue
    ax = df_metric.plot(x="epoch", y=metric, label=k, ax=ax, figsize=(16, 6))
if ax is not None:
    ax.set(title=f"{metric} ({part})", xlabel="epoch", ylabel=metric)
    ax.legend(loc="best")