In [None]:
import gc
import os
import pickle
import sys
import time
from copy import deepcopy
from pathlib import Path
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import torch
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: F401
from torch_ecg.cfg import CFG
from torch_ecg.utils.misc import str2bool
from tqdm.auto import tqdm

from cfg import ModelCfg, TrainCfg
from dataset import CinC2024Dataset
from models import ECGWaveformDetector, ECGWaveformDigitizer, MultiHead_CINC2024
from trainer import CINC2024Trainer
from utils.misc import view_image_with_bbox

# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/torch_ecg/")
# sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/bib_lookup/")




# Enable IPython's autoreload so edits to the local modules imported above
# (cfg, dataset, models, trainer, utils) are picked up before each cell runs,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the CinC 2024 (ptb-xl) data. The hard-coded path is kept as
# the default; set the CINC2024_DB_DIR environment variable to point elsewhere.
db_dir = Path(os.environ.get("CINC2024_DB_DIR", "/home/wenh06/Jupyter/wenhao/Hot-Data/cinc2024/ptb-xl/"))

# Prefer the second GPU (as originally configured). On single-GPU hosts "cuda:1"
# does not exist and `model.to(DEVICE)` would fail with an invalid device ordinal,
# so fall back to the first GPU there, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")

## Object detection model

In [None]:
TEST_FLAG = False

# When the model config asks for float64, make float64 torch's default floating
# dtype; DTYPE is the matching numpy dtype. `torch.set_default_dtype` replaces the
# deprecated `torch.set_default_tensor_type(torch.DoubleTensor)` (PyTorch >= 2.1).
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off debug mode on the dataset, detector model, and trainer classes.
CinC2024Dataset.__DEBUG__ = False
ECGWaveformDetector.__DEBUG__ = False
CINC2024Trainer.__DEBUG__ = False

In [None]:
# Training configuration for the object detection (bounding-box) model.
# Deep-copy the shared TrainCfg so later sections start from pristine defaults.
train_config = deepcopy(TrainCfg)
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True
train_config.predict_dx = False

# Box grouping mode; the same value is copied into the model config below so
# the dataset (built from train_config) and the detector stay consistent.
train_config.bbox_mode = "merge_horizontal"

train_config.db_dir = db_dir

train_config.n_epochs = 12
train_config.batch_size = 10  # 16G (Tesla T4)
# train_config.log_step = 20
train_config.learning_rate = 2e-5  # 5e-4, 1e-3
# Keep the `lr` alias in sync with `learning_rate`.
train_config.lr = train_config.learning_rate
# Presumably the peak LR for the scheduler — TODO confirm against the trainer.
train_config.max_lr = 7e-5
train_config.early_stopping.patience = train_config.n_epochs // 3

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)

model_config.object_detection.bbox_mode = train_config.bbox_mode

In [None]:
# Instantiate the waveform detector and place it on the selected device in one step.
# Multi-GPU wrapping (DataParallel / DistributedDataParallel) is intentionally not applied.
model = ECGWaveformDetector(config=model_config).to(device=DEVICE)

In [None]:
# Report model size; unwrap `.module` only when the model is DataParallel-wrapped.
if not isinstance(model, DP):
    print(model.module_size, model.module_size_)
else:
    print(model.module.module_size, model.module.module_size_)

In [None]:
# Train split first, then test split, both built lazily from the same config.
ds_train, ds_test = (
    CinC2024Dataset(train_config, training=is_train, lazy=True)
    for is_train in (True, False)
)

In [None]:
# Sanity check: sizes of the train and test splits, displayed as a tuple.
len(ds_train), len(ds_test)

In [None]:
# Render the first training sample with its ground-truth boxes and class labels.
data = ds_train[0]
view_image_with_bbox(
    data["image"],
    data["bbox"]["annotations"],
    cat_names=model_config.object_detection.class_names,
    fmt=ds_train.config.bbox_format,
)

In [None]:
# Wire the detector and both configs into a (lazily initialised) trainer.
trainer = CINC2024Trainer(
    model=model,
    device=DEVICE,
    lazy=True,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Attach the train/test datasets to the trainer. This calls a private method
# directly, presumably because the trainer was created with lazy=True and does
# not set up dataloaders itself — TODO confirm.
trainer._setup_dataloaders(ds_train, ds_test)
# Swapped variant (test split used for training), kept for quick experiments.
# trainer._setup_dataloaders(ds_test, ds_train)

In [None]:
# Run training. The return value is presumably the state dict of the best
# checkpoint, going by its name — verify against CINC2024Trainer.train.
best_model_state_dict = trainer.train()

In [None]:
# Flush any pending log output, then close the trainer's log manager before
# the trainer is deleted in the cleanup cell.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Drop this section's trainer/model/weights references. Then run the cyclic
# garbage collector, so GPU tensors kept alive only by reference cycles are
# actually freed before the cached CUDA blocks are released.
del trainer, model, best_model_state_dict
gc.collect()
torch.cuda.empty_cache()

## Segmentation (Digitization) model

In [None]:
TEST_FLAG = False

# When the model config asks for float64, make float64 torch's default floating
# dtype; DTYPE is the matching numpy dtype. `torch.set_default_dtype` replaces the
# deprecated `torch.set_default_tensor_type(torch.DoubleTensor)` (PyTorch >= 2.1).
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off debug mode on the classes used in this section. This section trains
# ECGWaveformDigitizer, so that is the model class whose flag is set (previously
# this cell mistakenly set ECGWaveformDetector's flag).
CinC2024Dataset.__DEBUG__ = False
ECGWaveformDigitizer.__DEBUG__ = False
CINC2024Trainer.__DEBUG__ = False

In [None]:
# Training configuration for the segmentation (digitization) model.
# Deep-copy the shared TrainCfg so settings from the previous section don't leak in.
train_config = deepcopy(TrainCfg)
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True
# Task selection: mask prediction only (no Dx labels, no bounding boxes).
train_config.predict_dx = False
train_config.predict_bbox = False
train_config.predict_mask = True
# Presumably restricts inputs to the region of interest, with no extra padding — TODO confirm.
train_config.roi_only = True
train_config.roi_padding = 0.0

train_config.db_dir = db_dir

train_config.n_epochs = 10
train_config.batch_size = 6  # 16G (Tesla T4)
train_config.log_step = 120
train_config.learning_rate = 4e-5  # 5e-4, 1e-3
# Keep the `lr` alias in sync with `learning_rate`.
train_config.lr = train_config.learning_rate
train_config.max_lr = 9e-5
train_config.early_stopping.patience = train_config.n_epochs // 2

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)

In [None]:
# Instantiate the waveform digitizer and place it on the selected device in one step.
# Multi-GPU wrapping (DataParallel / DistributedDataParallel) is intentionally not applied.
model = ECGWaveformDigitizer(config=model_config).to(device=DEVICE)

In [None]:
# Report model size; unwrap `.module` only when the model is DataParallel-wrapped.
if not isinstance(model, DP):
    print(model.module_size, model.module_size_)
else:
    print(model.module.module_size, model.module.module_size_)

In [None]:
# Train split first, then test split, both built lazily from the same config.
ds_train, ds_test = (
    CinC2024Dataset(train_config, training=is_train, lazy=True)
    for is_train in (True, False)
)

In [None]:
# Sanity check: sizes of the train and test splits, displayed as a tuple.
len(ds_train), len(ds_test)

In [None]:
# Render the first training sample with its ground-truth mask overlaid.
data = ds_train[0]
view_image_with_bbox(
    data["image"],
    mask=data["mask"],
)

In [None]:
# Wire the digitizer and both configs into a (lazily initialised) trainer.
trainer = CINC2024Trainer(
    model=model,
    device=DEVICE,
    lazy=True,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Attach the train/test datasets to the trainer. This calls a private method
# directly, presumably because the trainer was created with lazy=True — TODO confirm.
trainer._setup_dataloaders(ds_train, ds_test)
# Swapped variant (test split used for training), kept for quick experiments.
# trainer._setup_dataloaders(ds_test, ds_train)

In [None]:
# Run training. The return value is presumably the state dict of the best
# checkpoint, going by its name — verify against CINC2024Trainer.train.
best_model_state_dict = trainer.train()

In [None]:
# Flush any pending log output, then close the trainer's log manager before
# the trainer is deleted in the cleanup cell.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Visual sanity check: run the trained digitizer on one test image and show the result.
# NOTE(review): `model` presumably still holds the final-epoch weights. Unless
# trainer.train() restores the best checkpoint itself, call
# model.load_state_dict(best_model_state_dict) first to inspect the best model.
# Also assumes ds_test has more than 10 samples.
model.inference(ds_test[10]["image"], show=True)

In [None]:
# Drop this section's trainer/model/weights references. Then run the cyclic
# garbage collector, so GPU tensors kept alive only by reference cycles are
# actually freed before the cached CUDA blocks are released.
del trainer, model, best_model_state_dict
gc.collect()
torch.cuda.empty_cache()

## Dx model

In [None]:
TEST_FLAG = False

# When the model config asks for float64, make float64 torch's default floating
# dtype; DTYPE is the matching numpy dtype. `torch.set_default_dtype` replaces the
# deprecated `torch.set_default_tensor_type(torch.DoubleTensor)` (PyTorch >= 2.1).
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Turn off debug mode on the dataset, Dx model, and trainer classes.
CinC2024Dataset.__DEBUG__ = False
MultiHead_CINC2024.__DEBUG__ = False
CINC2024Trainer.__DEBUG__ = False

In [None]:
# Training configuration for the Dx (diagnosis) model.
# Deep-copy the shared TrainCfg so settings from earlier sections don't leak in.
train_config = deepcopy(TrainCfg)
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True

train_config.db_dir = db_dir

train_config.n_epochs = 25
train_config.batch_size = 17  # 16G (Tesla T4)
# train_config.log_step = 20
train_config.learning_rate = 5e-5  # 5e-4, 1e-3
# Keep the `lr` alias in sync with `learning_rate`.
train_config.lr = train_config.learning_rate
train_config.max_lr = 0.0001
train_config.early_stopping.patience = train_config.n_epochs // 3

# Task selection: Dx prediction only; bbox and mask targets are disabled and
# roi_only is off.
train_config.predict_dx = True
train_config.predict_bbox = False
train_config.predict_mask = False
train_config.roi_only = False
train_config.roi_padding = 0.0

# The backbone is not frozen; this flag is mirrored into the model config below.
train_config.backbone_freeze = False

# augmentations configurations
# TODO: add augmentation configs

model_config = deepcopy(ModelCfg)
# Pretrained image backbone identifier (presumably a Hugging Face hub id — TODO confirm).
model_config.backbone_name = "facebook/convnextv2-nano-22k-384"
model_config.backbone_freeze = train_config.backbone_freeze

# Backbone input sizing (presumably the target shortest image edge in pixels — TODO confirm).
model_config.backbone_input_size = {"shortest_edge": 768}

In [None]:
# Instantiate the multi-head Dx model and place it on the selected device in one step.
# Multi-GPU wrapping (DataParallel / DistributedDataParallel) is intentionally not applied.
model = MultiHead_CINC2024(config=model_config).to(device=DEVICE)

In [None]:
# Report model size; unwrap `.module` only when the model is DataParallel-wrapped.
if not isinstance(model, DP):
    print(model.module_size, model.module_size_)
else:
    print(model.module.module_size, model.module.module_size_)

In [None]:
# Train split first, then test split, both built lazily from the same config.
ds_train, ds_test = (
    CinC2024Dataset(train_config, training=is_train, lazy=True)
    for is_train in (True, False)
)

In [None]:
# Sanity check: sizes of the train and test splits, displayed as a tuple.
len(ds_train), len(ds_test)

In [None]:
# Preview the first training image. Only the image is passed (no boxes or
# mask), since this section disables predict_bbox and predict_mask.
view_image_with_bbox(ds_train[0]["image"])

In [None]:
# Wire the Dx model and both configs into a (lazily initialised) trainer.
trainer = CINC2024Trainer(
    model=model,
    device=DEVICE,
    lazy=True,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Attach the train/test datasets to the trainer. This calls a private method
# directly, presumably because the trainer was created with lazy=True — TODO confirm.
trainer._setup_dataloaders(ds_train, ds_test)
# Swapped variant (test split used for training), kept for quick experiments.
# trainer._setup_dataloaders(ds_test, ds_train)

In [None]:
# Run training. The return value is presumably the state dict of the best
# checkpoint, going by its name — verify against CINC2024Trainer.train.
best_model_state_dict = trainer.train()

In [None]:
# Flush any pending log output, then close the trainer's log manager before
# the trainer is deleted in the cleanup cell.
trainer.log_manager.flush()
trainer.log_manager.close()

In [None]:
# Drop this section's trainer/model/weights references. Then run the cyclic
# garbage collector, so GPU tensors kept alive only by reference cycles are
# actually freed before the cached CUDA blocks are released.
del trainer, model, best_model_state_dict
gc.collect()
torch.cuda.empty_cache()