In [None]:
import gc
import os
import pickle
import sys
import time
from copy import deepcopy
from pathlib import Path
from typing import Dict, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wfdb
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: F401
from torch_ecg.cfg import CFG
from torch_ecg.utils.misc import str2bool
from tqdm.auto import tqdm

from cfg import ModelCfg, TrainCfg
from dataset import CINC2025Dataset
from models import CRNN_CINC2025
from trainer import CINC2025Trainer

# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/torch_ecg/")
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/bib_lookup/")




# Enable IPython's autoreload so edits to the local project modules imported above
# (cfg, dataset, models, trainer) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Train models

In [None]:
# Root directory of the CinC 2025 data. Set the CINC2025_DB_DIR environment variable
# to point elsewhere instead of editing this machine-specific default path.
db_dir = Path(os.environ.get("CINC2025_DB_DIR", "/home/wenh06/Jupyter/wenhao/Hot-Data/cinc2025/"))
# Use the first GPU when CUDA is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# NOTE(review): TEST_FLAG is not read by any visible cell.
TEST_FLAG = False

# Follow the model config's precision: when it asks for float64, make float64 the
# torch default dtype. DTYPE is the matching numpy dtype (not used in the visible
# cells; presumably for numpy-side data).
if ModelCfg.torch_dtype == torch.float64:
    # `torch.set_default_tensor_type` is deprecated (PyTorch >= 2.1); for a CPU
    # DoubleTensor default, `set_default_dtype(torch.float64)` is the replacement.
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off the class-level debug switches of the project dataset, model and trainer.
CINC2025Dataset.__DEBUG__ = False
CRNN_CINC2025.__DEBUG__ = False
CINC2025Trainer.__DEBUG__ = False

In [None]:
# Hyper-parameters for this run, named so they can be tuned in one place.
N_EPOCHS = 20
BATCH_SIZE = 128  # fits in 16G (Tesla T4)
LEARNING_RATE = 1e-4  # alternatives: 5e-4, 1e-3
MAX_LR = 6e-4

# Work on a copy so the imported TrainCfg defaults stay untouched.
train_config = deepcopy(TrainCfg)
train_config.db_dir = db_dir
train_config.debug = True

train_config.n_epochs = N_EPOCHS
train_config.batch_size = BATCH_SIZE
# train_config.log_step = 20
train_config.learning_rate = LEARNING_RATE
# Mirror the learning rate under the `lr` key too.
# NOTE(review): presumably some consumer reads `lr` rather than `learning_rate` — confirm.
train_config.lr = train_config.learning_rate
train_config.max_lr = MAX_LR
# Early-stopping patience scales with the epoch budget (a third of n_epochs).
train_config.early_stopping.patience = train_config.n_epochs // 3

# augmentations configurations
# TODO: add augmentation configs

# Work on a copy so the imported ModelCfg defaults stay untouched.
model_config = deepcopy(ModelCfg)

In [None]:
# Change the model architecture: list the CNN sub-config keys and the current
# backbone name, then select the "tresnetM" backbone.
cnn_cfg = model_config.crnn.cnn
print(cnn_cfg.keys())
print(f"{model_config.crnn.cnn.name=}")
cnn_cfg.name = "tresnetM"

In [None]:
# Build the CRNN from the edited config and move it onto the chosen device.
model = CRNN_CINC2025(config=model_config).to(device=DEVICE)

# Report model size and device(s); a DataParallel wrapper keeps the real model in `.module`.
if not isinstance(model, DP):
    print("model size:", model.module_size, model.module_size_)
    print("Using device:", model.device)
else:
    print("model size:", model.module.module_size, model.module.module_size_)
    print("Using devices:", model.device_ids)

In [None]:
# Training split first, then validation split; both share train_config and lazy=True.
ds_train, ds_val = (CINC2025Dataset(train_config, training=is_train, lazy=True) for is_train in (True, False))

In [None]:
print("train size: {}, val size: {}".format(len(ds_train), len(ds_val)))

In [None]:
# Assemble the trainer from the target device, the model and both configs.
# With lazy=True, the dataloaders are attached in a separate cell.
trainer = CINC2025Trainer(
    device=DEVICE,
    lazy=True,
    model=model,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Hand the train/val datasets to the trainer (judging by the name, this builds its dataloaders).
# NOTE(review): `_setup_dataloaders` is a private (underscore) trainer method; presumably
# it must be called explicitly because the trainer was constructed with lazy=True — confirm.
trainer._setup_dataloaders(ds_train, ds_val)

In [None]:
# Run training with the configuration set up above.
# NOTE(review): the return value is presumably the state dict of the best checkpoint
# (judging by the variable name) — confirm against CINC2025Trainer.train.
best_model_state_dict = trainer.train()

In [None]:
# Finish logging for this run: flush the trainer's log manager first, then close it.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Drop the training objects and release cached GPU memory.
# Popping from globals() (instead of a bare `del`) keeps this cell safe to re-run:
# once the names are gone, `del` would raise NameError.
for _name in ("trainer", "model", "best_model_state_dict"):
    globals().pop(_name, None)
del _name
# Collect any reference cycles first so tensors they hold are actually freed
# before the CUDA caching allocator is asked to release unused memory.
gc.collect()
torch.cuda.empty_cache()

## Debug

In [None]:
# Load every validation sample once to surface data-loading/preprocessing errors.
# Index explicitly over range(len(ds_val)): the loop then stops at len(ds_val) even if
# __getitem__ does not raise IndexError past the end (which plain iteration relies on).
# Because `idx` is assigned before loading, after an exception it holds the failing index.
for idx in tqdm(range(len(ds_val)), desc="val samples"):
    input_data = ds_val[idx]

## Analyze and visualize results

In [None]:
# Run label -> training-log DataFrame; filled in the next cell.
df_results = {}

In [None]:
# Per-run training logs (CSV) to compare, keyed by the label used in the plot legends.
# Dict insertion order is preserved, so it also fixes the legend order.
RESULTS_DIR = Path("./results")
RESULT_FILES = {
    "resnet-nc": "TorchECG_02-21_17-38_CRNN_CINC2025_resnet_nature_comm_bottle_neck_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "resnet-nc-se": "TorchECG_02-22_01-33_CRNN_CINC2025_resnet_nature_comm_bottle_neck_se_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-m": "TorchECG_02-22_09-09_CRNN_CINC2025_tresnetM_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-n": "TorchECG_02-22_01-38_CRNN_CINC2025_tresnetN_adamw_amsgrad_LR_0.0001_BS_128.csv",
    "tresnet-f": "TorchECG_02-22_09-13_CRNN_CINC2025_tresnetF_adamw_amsgrad_LR_0.0001_BS_128.csv",
}

for label, filename in RESULT_FILES.items():
    df_results[label] = pd.read_csv(RESULTS_DIR / filename)

In [None]:
def plot_metric_curves(results, part, metric, figsize=(16, 6)):
    """Plot one metric against epoch for every run in ``results`` on a new figure.

    Parameters
    ----------
    results : dict of str -> pandas.DataFrame
        Run label -> training log. Each frame must have ``part``, ``epoch``
        and ``metric`` columns.
    part : str
        Value of the ``part`` column to keep (this notebook uses "train" / "val").
    metric : str
        Column plotted on the y axis; rows where it (or ``epoch``) is NaN are dropped.
    figsize : tuple of float, optional
        Passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding one line per run, titled and labelled.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for label, df in results.items():
        df_metric = df[df["part"] == part][[metric, "epoch"]].dropna()
        ax.plot(df_metric["epoch"], df_metric[metric], label=label)
    # Title and axis labels make the figure readable when the notebook is skimmed.
    ax.set(title=f"{metric} ({part})", xlabel="epoch", ylabel=metric)
    ax.legend(loc="best")
    return ax


# Same four comparisons, in the same order, each as a separate 16x6 figure.
for part, metric in [
    ("train", "challenge_score"),
    ("val", "challenge_score"),
    ("train", "chagas_f_measure"),
    ("val", "chagas_f_measure"),
]:
    plot_metric_curves(df_results, part=part, metric=metric)