In [1]:
import os
import time
import shutil
import logging
import argparse

from trainer.trainer import trainer
from utils.hparams import HParam
from utils.writer import MyWriter
from datasets.dataloader import create_dataloader


parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/best.yaml',
                    help="folder contain yaml files for configuration")
parser.add_argument('--clean_rerun', type=bool, default=False,
                    help="remove old checkpoint and log. Default: false")
parser.add_argument('-r', '--resume', type=bool, default=False,
                    help="resume from checkpoint. Default: false")
args = parser.parse_args(["-c", "config/prof_pse_dccrn_stft_big.yaml"])

# Parsed hyper-parameters; the "experiment" and "env" sections are used below.
config = HParam(args.config)
exp = config["experiment"]
env = config["env"]

# Raw YAML text of the same file, handed to `trainer` as `hp_str`
# (presumably so the exact configuration is recorded with the run -- confirm in trainer).
# YAML is UTF-8 by spec, so do not depend on the platform's locale encoding.
with open(args.config, 'r', encoding='utf-8') as f:
    hp_str = f.read()


# Per-experiment output locations: <base_dir>/<chkpt_dir | log_dir>/<experiment name>.
chkpt_dir = os.path.join(env.base_dir, env.log.chkpt_dir, exp.name)
log_dir = os.path.join(env.base_dir, env.log.log_dir, exp.name)
run_dirs = (chkpt_dir, log_dir)

# With --clean_rerun, first wipe whatever an earlier run of this experiment left
# behind; a directory that does not exist yet is not treated as an error.
if args.clean_rerun:
    for run_dir in run_dirs:
        shutil.rmtree(run_dir, ignore_errors=True)

# (Re)create both directories; a no-op when they already exist.
for run_dir in run_dirs:
    os.makedirs(run_dir, exist_ok=True)
    

# Send INFO+ records both to a per-run file in log_dir, named
# "<experiment name>-<integer unix time>.log", and to the console.
log_file_path = os.path.join(log_dir, '%s-%d.log' % (exp.name, time.time()))
log_handlers = [logging.FileHandler(log_file_path), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=log_handlers)
# NOTE(review): basicConfig does nothing if the root logger already has handlers,
# so re-running this cell in the same kernel keeps logging to the first run's file.
logger = logging.getLogger()

# Project summary writer targeting log_dir (exp.audio presumably carries the
# audio settings it needs -- confirm in utils.writer.MyWriter).
writer = MyWriter(exp.audio, log_dir)

def _build_loader(message, scheme):
    """Log ``message``, then build the dataloader for ``scheme`` from ``config``."""
    logger.info(message)
    return create_dataloader(config, scheme=scheme)


# The evaluation loader is built before the training loader.
testloader = _build_loader("Start making validate set", "eval")
trainloader = _build_loader("Start making train set", "train")

# Optionally resume from a checkpoint: "backup" is an alias for
# <chkpt_dir>/backup.pt; any other value is used as the checkpoint path.
if args.resume:
    # Resolve the alias *before* logging so the log names the file actually used.
    resume_path = args.resume
    if resume_path == "backup":
        resume_path = os.path.join(chkpt_dir, 'backup.pt')
    args.resume = resume_path
    logger.info(f"Resume training from checkpoint {resume_path}")
    # Surface a bad path here instead of deep inside the trainer.
    if not (isinstance(resume_path, str) and os.path.isfile(resume_path)):
        logger.warning(f"Checkpoint {resume_path} does not exist")
    exp["model"]["pretrained_chkpt"] = resume_path

2022-06-11 15:58:23,627 - INFO - Start making validate set
2022-06-11 15:58:23,678 - INFO - Start making train set


In [2]:
from torch.profiler import profile, record_function, ProfilerActivity

In [3]:
# Profile the complete training run on CPU and CUDA, recording tensor shapes,
# memory allocations and Python stacks (the stacks are what
# key_averages(group_by_stack_n=5) groups by).
# NOTE(review): no schedule / prof.step() is used, so the whole run is traced --
# expect heavy overhead and a very large trace.
profiled_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with profile(
    activities=profiled_activities,
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    trainer(exp, chkpt_dir, trainloader, testloader, writer, logger, hp_str)

2022-06-11 15:59:13,089 - INFO - Starting new training run
2022-06-11 15:59:13,872 - INFO - Wrote summary at step 1
2022-06-11 15:59:14,498 - INFO - Wrote summary at step 2
2022-06-11 15:59:15,105 - INFO - Wrote summary at step 3
2022-06-11 15:59:15,768 - INFO - Wrote summary at step 4
2022-06-11 15:59:16,425 - INFO - Wrote summary at step 5


In [4]:
# Aggregate profiler events per operator name, most expensive total CPU time first.
# NOTE(review): in the recorded output cudaHostAlloc (pinned host-memory allocation)
# takes ~49% of CPU time at ~1.1 s per call; worth checking where pinned buffers
# are allocated (e.g. DataLoader pin_memory) -- hypothesis, confirm.
op_summary = prof.key_averages()
print(op_summary.table(sort_by="cpu_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          cudaHostAlloc        49.05%      322.399s        49.05%      322.399s        1.147s       0.000us         0.00%       0.000us       0.000us     -37.00 Mb     -37.00 Mb     -92.60 Mb     -92.60 M

In [5]:
# Same aggregation, additionally split by the top 5 Python stack frames
# (requires with_stack=True while profiling) to show where each op was issued from.
stack_summary = prof.key_averages(group_by_stack_n=5)
print(stack_summary.table(sort_by="cpu_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Source Location                                                              
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
     

In [4]:
# Dump the full timeline in Chrome trace format (viewable in chrome://tracing or Perfetto).
CHROME_TRACE_PATH = "trace_precomp_dvec.json"
prof.export_chrome_trace(CHROME_TRACE_PATH)