In [None]:
import sys
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/torch_ecg/")
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/bib_lookup/")

import os
import pickle
import time
from copy import deepcopy
from pathlib import Path
from typing import Dict, Union, Tuple, Sequence

import numpy as np
import torch
from torch.nn.parallel import (  # noqa: F401
    DistributedDataParallel as DDP,
    DataParallel as DP,
)  # noqa: F401
from torch_ecg.cfg import CFG
from torch_ecg.utils.misc import str2bool
from tqdm.auto import tqdm

from cfg import TrainCfg, ModelCfg
from dataset import CinC2024Dataset
from models import MultiHead_CINC2024, ECGWaveformDetector
from trainer import CINC2024Trainer

%load_ext autoreload
%autoreload 2

## Object detection model (ECG waveform detector)

In [None]:
TEST_FLAG = False

# Prefer the second GPU for this section, but fall back to GPU 0 on single-GPU hosts
# (``cuda:1`` is an invalid device ordinal there and ``.to(DEVICE)`` would fail), and to CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{min(1, torch.cuda.device_count() - 1)}")
else:
    DEVICE = torch.device("cpu")

# Keep torch's default floating dtype in sync with the configured model precision;
# DTYPE mirrors that precision as a numpy dtype.
# ``torch.set_default_tensor_type`` is deprecated since PyTorch 2.1; ``set_default_dtype`` replaces it.
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off the class-level ``__DEBUG__`` flags for this run.
CinC2024Dataset.__DEBUG__ = False
ECGWaveformDetector.__DEBUG__ = False
CINC2024Trainer.__DEBUG__ = False

In [None]:
train_config = deepcopy(TrainCfg)
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True
# Diagnosis (Dx) prediction is switched off for the detector run.
train_config.predict_dx = False

# Data location: defaults to the author's local PTB-XL copy; set the CINC2024_DB_DIR
# environment variable to run the notebook on another machine without editing this cell.
# NOTE(review): the Dx section uses a different default path (".../Jupyter/Hot-data/...") — confirm which is current.
train_config.db_dir = os.environ.get(
    "CINC2024_DB_DIR", "/home/wenh06/Jupyter/wenhao/Hot-Data/cinc2024/ptb-xl/"
)
# train_config.db_dir = "/home/wenh06/Jupyter/Hot-data/cinc2024/ptb-xl-subset/"

train_config.n_epochs = 5
train_config.batch_size = 16  # 16G (Tesla T4)
# train_config.log_step = 20
# # train_config.max_lr = 1.5e-3
# Early-stopping patience is one third of the epoch budget (5 // 3 == 1 epoch here).
train_config.early_stopping.patience = train_config.n_epochs // 3

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)

In [None]:
# Build the waveform detector and place it on DEVICE in one step
# (``nn.Module.to`` returns the module itself, so the binding is the same object).
# Multi-GPU wrapping (DataParallel / DistributedDataParallel) is not applied in this section.
model = ECGWaveformDetector(config=model_config).to(device=DEVICE)

In [None]:
# The size attributes live on the underlying network, so unwrap a DataParallel container first.
print(
    *(
        getattr(model.module if isinstance(model, DP) else model, attr_name)
        for attr_name in ("module_size", "module_size_")
    )
)

In [None]:
# Training split first, then the held-out split, both built from the same config.
ds_train, ds_test = (
    CinC2024Dataset(train_config, training=is_train, lazy=True)
    for is_train in (True, False)
)

In [None]:
tuple(len(split) for split in (ds_train, ds_test))  # (n_train, n_test)

In [None]:
# Trainer is created with ``lazy=True``; its dataloaders are attached explicitly in a later cell.
trainer = CINC2024Trainer(
    model=model,
    device=DEVICE,
    train_config=train_config,
    model_config=model_config,
    lazy=True,
)

In [None]:
# Attach the dataloaders explicitly, since the trainer was constructed with ``lazy=True``.
# NOTE(review): this calls the trainer's private ``_setup_dataloaders``; the argument order appears to be
# (train, validation). Passing ``(ds_test, ds_train)`` swaps the splits and was previously kept as a commented-out alternative.
trainer._setup_dataloaders(ds_train, ds_test)

In [None]:
# Run training; the return value is kept as ``best_model_state_dict`` (presumably the best epoch's weights — confirm).
best_model_state_dict = trainer.train()

In [None]:
# Checkpoint from one specific earlier training run — update this path after re-training.
# NOTE(review): the filename mentions a "-dx" backbone although it is loaded into ECGWaveformDetector — confirm it is the intended checkpoint.
CHECKPOINT_PATH = "./checkpoints/hf-facebook-convnextv2-atto-1k-224-dx__epoch0_08-04_20-04_epochloss_1962.66722_metric_0.71.pth.tar"
if not Path(CHECKPOINT_PATH).is_file():
    raise FileNotFoundError(
        f"Checkpoint not found: {CHECKPOINT_PATH!r}. "
        "It is specific to an earlier run; point CHECKPOINT_PATH at an existing checkpoint file."
    )
# NOTE(review): ``from_checkpoint`` appears to return a sequence whose first element is the model — confirm.
trained_model = ECGWaveformDetector.from_checkpoint(CHECKPOINT_PATH)[0]

In [None]:
# Index of the sample used for the visual sanity check.
# NOTE(review): this draws from the training split; ``ds_test`` would show behaviour on held-out data.
SAMPLE_IDX = 20
output = trained_model.inference(ds_train[SAMPLE_IDX]["image"], show=True)

In [None]:
import matplotlib.pyplot as plt
# NOTE(review): this cell only creates an 8x12-inch figure and displays it; nothing is drawn on ``ax``,
# and the import belongs in the top import cell. Looks like leftover scaffolding — confirm whether it is still needed.
fig, ax = plt.subplots(figsize=(8,12))
plt.show()

## Dx (diagnosis) model

In [None]:
TEST_FLAG = False

# Default CUDA device (GPU 0) when available, else CPU. With DataParallel the wrapped module
# must live on the first GPU, which "cuda" resolves to.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep torch's default floating dtype in sync with the configured model precision;
# DTYPE mirrors that precision as a numpy dtype.
# ``torch.set_default_tensor_type`` is deprecated since PyTorch 2.1; ``set_default_dtype`` replaces it.
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off the class-level ``__DEBUG__`` flags for this run.
CinC2024Dataset.__DEBUG__ = False
MultiHead_CINC2024.__DEBUG__ = False
CINC2024Trainer.__DEBUG__ = False

In [None]:
train_config = deepcopy(TrainCfg)
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True

# Data location: defaults to the author's local PTB-XL copy; set the CINC2024_DB_DIR
# environment variable to run the notebook on another machine without editing this cell.
# NOTE(review): the detector section uses a different default path (".../Jupyter/wenhao/Hot-Data/...") — confirm which is current.
train_config.db_dir = os.environ.get(
    "CINC2024_DB_DIR", "/home/wenh06/Jupyter/Hot-data/cinc2024/ptb-xl/"
)
# train_config.db_dir = "/home/wenh06/Jupyter/Hot-data/cinc2024/ptb-xl-subset/"

train_config.n_epochs = 25
# NOTE(review): the memory note below matches the detector section's batch size of 16 — confirm it holds for 32.
train_config.batch_size = 32  # 16G (Tesla T4)
# train_config.log_step = 20
# # train_config.max_lr = 1.5e-3
# Early-stopping patience is one third of the epoch budget (25 // 3 == 8 epochs here).
train_config.early_stopping.patience = train_config.n_epochs // 3

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)
model_config.backbone_name = "facebook/convnextv2-nano-22k-384"

In [None]:
# Build the diagnosis model, replicate it with DataParallel when more than one GPU is visible,
# then move it to DEVICE.
model = MultiHead_CINC2024(config=model_config)
model = DP(model) if torch.cuda.device_count() > 1 else model
model = model.to(device=DEVICE)

In [None]:
# The size attributes live on the underlying network, so unwrap a DataParallel container first.
print(
    *(
        getattr(model.module if isinstance(model, DP) else model, attr_name)
        for attr_name in ("module_size", "module_size_")
    )
)

In [None]:
# Training split first, then the held-out split, both built from the same config.
ds_train, ds_test = (
    CinC2024Dataset(train_config, training=is_train, lazy=True)
    for is_train in (True, False)
)

In [None]:
tuple(len(split) for split in (ds_train, ds_test))  # (n_train, n_test)

In [None]:
# Trainer is created with ``lazy=True``; its dataloaders are attached explicitly in a later cell.
trainer = CINC2024Trainer(
    model=model,
    device=DEVICE,
    train_config=train_config,
    model_config=model_config,
    lazy=True,
)

In [None]:
# Attach the dataloaders explicitly, since the trainer was constructed with ``lazy=True``.
# NOTE(review): this calls the trainer's private ``_setup_dataloaders``; the argument order appears to be
# (train, validation). Passing ``(ds_test, ds_train)`` swaps the splits and was previously kept as a commented-out alternative.
trainer._setup_dataloaders(ds_train, ds_test)

In [None]:
# Run training; the return value is kept as ``best_model_state_dict`` (presumably the best epoch's weights — confirm).
best_model_state_dict = trainer.train()

In [None]:
# After training, flush and then close the trainer's log manager.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Show the saved model artifacts. ``Path.iterdir`` yields entries in arbitrary, filesystem-dependent
# order, so sort them to keep this cell's output stable across runs.
sorted(Path("./saved_models/").iterdir())