In [1]:
# Make the project root the working directory: the config lookup further down
# resolves <cwd>/clearaudio/conf. The guard makes this cell safe to re-run; a
# bare relative `cd ..` would climb one more directory level on every execution.
import os
from pathlib import Path

if not (Path.cwd() / "clearaudio" / "conf").is_dir():
    os.chdir("..")
print(Path.cwd())

/home/kikohs/work/clearaudio


In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2



In [3]:
import io
import logging
import os
import matplotlib
import matplotlib.pyplot as plt

import torch
import torchaudio
import librosa
import librosa.display
import time
from pathlib import Path

from hydra import compose, initialize_config_dir
from omegaconf import DictConfig, open_dict  
from omegaconf import OmegaConf

from clearaudio.transforms import signal
from clearaudio.datasets import audio
from clearaudio.utils.utils import print_stats, play_audio_vscode
from clearaudio.viz import plots
from clearaudio import generate_wavenet

# Dark theme for every figure drawn after this point.
plt.style.use('dark_background')

# Stretch the default figure width by 2.5x; figures whose width is already
# 10 or more are left untouched.
width, height = matplotlib.rcParams['figure.figsize']
if width < 10:
    matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]

# Logger used by the helper functions defined in this notebook.
LOG = logging.getLogger(__name__)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def load_checkpoint_data(checkpoint_path: str):
    """Load a serialized checkpoint from disk.

    Args:
        checkpoint_path: Location of the checkpoint file. A leading ``~`` is
            expanded to the user's home directory.

    Returns:
        The object deserialized by ``torch.load``, or ``None`` (after logging
        an error) when nothing exists at the given path.
    """
    resolved_path = Path(checkpoint_path).expanduser()

    if resolved_path.exists():
        # NOTE(review): torch.load can unpickle arbitrary objects; only point
        # this at checkpoint files from a trusted source.
        return torch.load(resolved_path)

    LOG.error('Checkpoint not found.')
    return None


def get_config_from_chekckpoint_or_def(checkpoint_path: str,
    dataset_name: str = 'tan_scitas',
    trainer_name: str = "tan_wavenet_scitas",
    generator_name: str = "tan_generator"):
    """Load a checkpoint and the configuration that goes with it.

    The configuration embedded in the checkpoint (its ``'config'`` entry) takes
    precedence. When the checkpoint carries none, a default configuration is
    composed with Hydra from ``<cwd>/clearaudio/conf`` using the given
    dataset / trainer / generator group names.

    Args:
        checkpoint_path: Path to the checkpoint file (``~`` is expanded by
            ``load_checkpoint_data``).
        dataset_name: Value for the ``dataset`` override of the default config.
        trainer_name: Value for the ``trainer`` override of the default config.
        generator_name: Value for the ``generator`` override of the default config.

    Returns:
        A ``(cfg, ckp)`` tuple: the resolved configuration and the loaded
        checkpoint object.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    ckp = load_checkpoint_data(checkpoint_path)
    if ckp is None:
        # load_checkpoint_data signals a missing file with None; fail here with
        # an explicit message instead of a TypeError on the membership test.
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    if 'config' in ckp:
        # The checkpoint's own config overrides any default one.
        cfg = OmegaConf.create(ckp['config'])
    else:
        # The default config directory is resolved relative to the current
        # working directory, so the notebook must run from the project root.
        conf_path = str(Path(os.getcwd()).resolve() / "clearaudio" / "conf")
        with initialize_config_dir(config_dir=conf_path, job_name="test_app"):
            cfg = compose(
                config_name="config",
                overrides=[
                    f"dataset={dataset_name}",
                    f"trainer={trainer_name}",
                    f"generator={generator_name}",
                ],
            )

    return cfg, ckp

def get_eq_transform(cfg: DictConfig, eq_name: str = '0', subset_mode: str = 'train'):
    """Pick a single EQ transform by name from those built out of the config.

    Args:
        cfg: Configuration handed to ``signal.SoxEffectTransform.from_config``.
        eq_name: ``name`` attribute of the transform to select.
        subset_mode: Key selecting which group of transforms to search; a
            falsy value falls back to ``'train'``.

    Returns:
        The first transform in the selected group whose ``name`` equals
        ``eq_name``, or ``None`` (after logging an error) when there is none.
    """
    subset = subset_mode or 'train'
    candidates = signal.SoxEffectTransform.from_config(cfg)[subset]
    matches = [transform for transform in candidates if transform.name == eq_name]

    if matches:
        return matches[0]

    LOG.error(f"{eq_name} not found in config: \n {cfg.dataset.low_quality_effect}")
    return None


def create_lq_song_from_eq(song_path: str, output_folder: str, eq: signal.SoxEffectTransform):
    """Produce the low quality version of a song with one of the training EQs.

    Args:
        song_path: Audio file to process.
        output_folder: Folder handed to ``eq.process_file`` as the output location.
        eq: Transform whose ``process_file`` method does the work; it is
            called with ``override=True``.

    Returns:
        The value returned by ``eq.process_file``, wrapped in a ``Path``.
    """
    processed = eq.process_file(song_path, output_folder, override=True)
    return Path(processed)



In [7]:
# Checkpoint to generate from; the leading `~` is expanded when the file is loaded.
chkp = "~/test_data/audio/tanbur/checkpoint_epoch_6_meq.pt"
# cfg: the config stored in the checkpoint, or the composed default when the
#      checkpoint has no 'config' entry.
# ckp: the loaded checkpoint object.
# Both are reused by the generation cells further down.
cfg, ckp = get_config_from_chekckpoint_or_def(chkp)


In [9]:
# Create LQ versions of two reference songs with each of the EQs named
# "0".."9", then run WaveNet generation on every LQ file produced.
song1 = '~/data/audio/tanbur/DC44k/31 Zekrs_1.wav'
song2 = '~/data/audio/tanbur/DC44k/18 Yadegari_1.wav'
output_dir = Path("~/data/audio/tanbur/2022_03_24")

for i in range(10):
    eq = get_eq_transform(cfg, str(i))
    if eq is None:
        # get_eq_transform has already logged the missing EQ; skip it so one
        # absent entry does not abort the remaining (slow) generations.
        continue
    for song in (song1, song2):
        # `lq_song_path` and `output_dir` stay defined after the loop because
        # a later cell re-runs generation on the last LQ file.
        lq_song_path = create_lq_song_from_eq(song, output_dir, eq)
        generate_wavenet.generate(lq_song_path, output_dir, cfg, ckp)

Last activation :  Mish()


  return F.conv1d(input, weight, bias, self.stride,
189it [01:16,  2.47it/s]


Last activation :  Mish()


181it [01:13,  2.46it/s]


Last activation :  Mish()


189it [01:16,  2.46it/s]


Last activation :  Mish()


181it [01:13,  2.46it/s]


Last activation :  Mish()


189it [01:19,  2.39it/s]


Last activation :  Mish()


181it [01:15,  2.39it/s]


Last activation :  Mish()


189it [01:18,  2.41it/s]


Last activation :  Mish()


181it [01:14,  2.44it/s]


Last activation :  Mish()


189it [01:17,  2.45it/s]


Last activation :  Mish()


181it [01:13,  2.46it/s]


Last activation :  Mish()


189it [01:17,  2.44it/s]


Last activation :  Mish()


181it [01:14,  2.43it/s]


Last activation :  Mish()


189it [01:18,  2.42it/s]


Last activation :  Mish()


181it [01:15,  2.41it/s]


Last activation :  Mish()


189it [01:17,  2.43it/s]


Last activation :  Mish()


181it [01:14,  2.43it/s]


Last activation :  Mish()


189it [01:17,  2.43it/s]


Last activation :  Mish()


181it [01:15,  2.41it/s]


Last activation :  Mish()


189it [01:19,  2.38it/s]


Last activation :  Mish()


181it [01:15,  2.39it/s]


In [56]:
# Re-run generation on a single file. This relies on `lq_song_path` and
# `output_dir` left over from the loop cell above (i.e. the last LQ file it
# produced), and on `cfg` / `ckp` from the checkpoint-loading cell.
generate_wavenet.generate(lq_song_path, output_dir, cfg, ckp)

Last activation :  Mish()


181it [02:40,  1.13it/s]
