
# Prepare your own dataset for DiffSinger (MIDI-less version)

## 1 Overview

This Jupyter Notebook will guide you through preparing your own dataset for DiffSinger with a 44.1 kHz sampling rate.
Please read and follow the guidance carefully, take action wherever a <font color="red">manual action</font> notice appears, and pay attention to blocks marked with <font color="red">optional step</font>.

### 1.1 Introduction to this pipeline and MIDI-less version

This pipeline does not support customized phoneme dictionaries. It uses the [opencpop strict pinyin dictionary](../dictionaries/opencpop-strict.txt) by default.

The MIDI-less version is a simplified version of DiffSinger in which MIDI layers, word layers and slur layers are removed from the data labels. The model takes the raw phoneme sequence with durations as input and applies pitch embedding directly from the ground truth. Predictors for phoneme durations and the pitch curve are also removed. Below are some limitations and advantages of the MIDI-less version:

- The model will not predict phoneme durations and f0 sequence by itself. You must specify `ph_dur` and `f0_seq` at inference time.
- Performance of pitch control will be better than in the MIDI-A version, because MIDI keys are misleading information for the diffusion decoder when the f0 sequence is already embedded.
- MIDIs and slurs do not need to be labeled, so the labeling work is easier than in other versions.
- More varieties of data can be used as training materials, even including speech.

### 1.2 Install dependencies

Please run the following cell the first time you start this notebook.

**Note**: You should ensure you are in a Conda environment with Python 3.8 or 3.9 before you install dependencies of this pipeline.


In [None]:
!pip install -r requirements.txt
!conda install -c conda-forge montreal-forced-aligner --yes


### 1.3 Initializing environment

Please run the following cell every time you start this notebook.


In [None]:
import glob
import os
import shutil

import librosa
import matplotlib.pyplot as plt
import numpy as np
import parselmouth as pm
import soundfile
import textgrid as tg
import tqdm


def length(src: str):
    if os.path.isfile(src) and src.endswith('.wav'):
        return librosa.get_duration(filename=src) / 3600.
    elif os.path.isdir(src):
        total = 0
        for ch in [os.path.join(src, c) for c in os.listdir(src)]:
            total += length(ch)
        return total
    return 0


print('Environment initialized successfully.')


## 2 Raw recordings and audio slicing

### 2.1 Choose raw recordings

Your recordings must meet the following conditions:

1. They must be in one single folder. Files in sub-folders will be ignored.
2. They must be in WAV format.
3. They must have a sampling rate higher than 32 kHz.
4. They should contain only human voice from a single person, since multi-speaker training is not supported yet.
5. They should be clean voices with no significant noise or reverb.
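
The first three conditions can be checked automatically before running the cells below. The following is a minimal sketch that uses only the Python standard library; `check_recordings` is a hypothetical helper and not part of this pipeline. The standard `wave` reader only understands PCM WAV files, so a file reported as unreadable here may still load fine with `librosa`.

```python
import os
import wave


def check_recordings(folder, min_rate=32000):
    """Return a list of (filename, problem) for entries that break the rules above."""
    problems = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if os.path.isdir(path):
            problems.append((name, 'sub-folder, will be ignored'))
            continue
        if not name.endswith('.wav'):
            problems.append((name, 'not a WAV file'))
            continue
        try:
            with wave.open(path, 'rb') as w:
                rate = w.getframerate()
        except (wave.Error, EOFError) as e:
            # The standard-library reader only supports PCM WAV files.
            problems.append((name, f'unreadable: {e}'))
            continue
        # The sampling rate must be strictly higher than 32 kHz.
        if rate <= min_rate:
            problems.append((name, f'sampling rate {rate} Hz is too low'))
    return problems
```

Calling `check_recordings(raw_path)` after you set `raw_path` in the next cell returns an empty list when no problem is found. Conditions 4 and 5 still have to be judged by ear.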

<font color="red">Optional step</font>: The raw data must be sliced into parts of about 5-15 seconds. If you want to do this yourself, please skip to section 2.3. Otherwise, please edit paths in the following cell before you run it.


In [None]:
########################################

# Configuration for data paths
raw_path = r'path/to/your/raw/recordings'  # Path to your raw, unsliced recordings

########################################

assert os.path.exists(raw_path) and os.path.isdir(raw_path), 'The chosen path does not exist or is not a directory.'
print('Raw recording path:', raw_path)
print()
print('===== Recording List =====')
raw_filelist = glob.glob(f'{raw_path}/*.wav', recursive=True)
raw_length = length(raw_path)
if len(raw_filelist) > 5:
    print('\n'.join(raw_filelist[:5] + [f'... ({len(raw_filelist) - 5} more)']))
else:
    print('\n'.join(raw_filelist))
print()
print(f'Found {len(raw_filelist)} valid recordings with total length of {round(raw_length, 2)} hours.')


### 2.2 Audio slicing

We provide an audio slicer which automatically cuts recordings into short pieces.

The audio slicer is based on silence detection and has several arguments that have to be specified. You should modify these arguments according to your data.

For more details of each argument, see its [GitHub repository](https://github.com/openvpi/audio-slicer).
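
To give an intuition for what "based on silence detection" means, here is a simplified sketch of threshold-based silence detection. It is not the implementation of the slicer used in this pipeline: `find_silence` and its parameters are hypothetical, and the real slicer has additional rules (minimum length, kept silence and so on) described in its repository.

```python
import math


def find_silence(samples, sr, db_threshold=-40.0, frame_ms=20):
    """Return (start_sec, end_sec) ranges whose frame RMS is below db_threshold."""
    frame = max(1, int(sr * frame_ms / 1000))
    n_frames = len(samples) // frame
    ranges, start = [], None
    for k in range(n_frames):
        chunk = samples[k * frame:(k + 1) * frame]
        rms = math.sqrt(sum(x * x for x in chunk) / len(chunk))
        db = 20 * math.log10(max(rms, 1e-12))  # clip to avoid log10(0)
        if db < db_threshold:
            if start is None:
                start = k * frame / sr  # a silent range begins
        elif start is not None:
            ranges.append((start, k * frame / sr))  # a silent range ends
            start = None
    if start is not None:
        ranges.append((start, n_frames * frame / sr))
    return ranges


# One second of tone, half a second of silence, one second of tone (toy 1 kHz sampling rate).
sr = 1000
tone = [0.5 * math.sin(2 * math.pi * 100 * t / sr) for t in range(sr)]
signal = tone + [0.0] * 500 + tone
print(find_silence(signal, sr))  # [(1.0, 1.5)]
```

A recording can then be cut somewhere inside each detected range. Raising the threshold makes more frames count as silence and therefore produces more, shorter segments.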

Please edit paths and arguments in the following cell before you run it.


In [None]:
########################################

# Configuration for data paths
sliced_path = r'path/to/your/sliced/recordings'  # Path to hold the sliced segments of your recordings

# Slicer arguments
db_threshold_ = -40.
min_length_ = 5000
win_l_ = 400
win_s_ = 20
max_silence_kept_ = 500

# Number of threads
num_workers = 5  # based on your CPU cores

########################################

assert 'raw_path' in locals().keys(), 'Raw path of your recordings has not been specified.'
assert not os.path.exists(sliced_path) or os.path.isdir(sliced_path), 'The chosen path is not a directory.'
os.makedirs(sliced_path, exist_ok=True)
print('Sliced recording path:', sliced_path)

from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

from utils.slicer import Slicer


def slice_one(in_audio):
    audio, sr = librosa.load(in_audio, sr=None)
    slicer = Slicer(
        sr=sr,
        db_threshold=db_threshold_,
        min_length=min_length_,
        win_l=win_l_,
        win_s=win_s_,
        max_silence_kept=max_silence_kept_
    )
    chunks = slicer.slice(audio)
    # Name each chunk after its source file, e.g. `song_slice_0003.wav`
    stem = os.path.basename(in_audio).rsplit('.', maxsplit=1)[0]
    for i, chunk in enumerate(chunks):
        soundfile.write(os.path.join(sliced_path, f'{stem}_slice_{i:04d}.wav'), chunk, sr)


thread_pool = ThreadPoolExecutor(max_workers=num_workers)
tasks = []
for file in raw_filelist:
    tasks.append(thread_pool.submit(slice_one, file))
for task in tqdm.tqdm(tasks):
    wait([task], return_when=ALL_COMPLETED)
print()
print('===== Segment List =====')
sliced_filelist = glob.glob(f'{sliced_path}/*.wav', recursive=True)
sliced_length = length(sliced_path)
if len(sliced_filelist) > 5:
    print('\n'.join(sliced_filelist[:5] + [f'... ({len(sliced_filelist) - 5} more)']))
else:
    print('\n'.join(sliced_filelist))
print()
print(f'Sliced your recordings into {len(sliced_filelist)} segments with total length of {round(sliced_length, 2)} hours.')


### 2.3 Validating recording segments

In this section, we validate your recording segments.

<font color="red">Optional step</font>: If you skipped section 2.2, please specify the path to your sliced recordings in the following cell and run it. Otherwise, skip this cell.


In [None]:
########################################

# Configuration for data paths
sliced_path = r'path/to/your/sliced/recordings'  # Path to your sliced segments of recordings

########################################

assert os.path.exists(sliced_path) and os.path.isdir(sliced_path), 'The chosen path does not exist or is not a directory.'

print('Sliced recording path:', sliced_path)
print()
print('===== Segment List =====')
sliced_filelist = glob.glob(f'{sliced_path}/*.wav', recursive=True)
sliced_length = length(sliced_path)
if len(sliced_filelist) > 5:
    print('\n'.join(sliced_filelist[:5] + [f'... ({len(sliced_filelist) - 5} more)']))
else:
    print('\n'.join(sliced_filelist))
print()
print(f'Found {len(sliced_filelist)} valid segments with total length of {round(sliced_length, 2)} hours.')


Run the following cell to check if there are segments with an unexpected length (less than 2 seconds or more than 20 seconds).


In [None]:
reported = False
for file in tqdm.tqdm(sliced_filelist):
    wave_seconds = librosa.get_duration(filename=file)
    if wave_seconds < 2.:
        reported = True
        print(f'Too short! \'{file}\' has a length of {round(wave_seconds, 1)} seconds!')
    if wave_seconds > 20.:
        reported = True
        print(f'Too long! \'{file}\' has a length of {round(wave_seconds, 1)} seconds!')
if not reported:
    print('Congratulations! All segments have proper length.')


<font color="red">Manual action</font>: please consider removing segments that are too short and manually slicing segments that are too long, as reported above.

Move on when this is done or there are no segments reported.


## 3 Label your segments

### 3.1 Label syllable sequence

All segments should have their transcriptions (or lyrics) annotated. Run the following cell to see an example segment (from the Opencpop dataset) and its corresponding annotation.


In [None]:
from IPython.display import Audio

# noinspection PyTypeChecker
display(Audio(filename='assets/2001000001.wav'))
with open('assets/2001000001.lab', 'r') as f:
    print(f.read())


<font color="red">Manual action</font>: now your task is to annotate the transcription of each segment like the example shown above.

Each segment should have one annotation file with the same filename and the `.lab` extension, placed in the same directory. In the annotation file, write all syllables sung or spoken in the segment. Syllables should be separated by spaces, and only syllables that appear in the dictionary are allowed. In addition, all phonemes in the dictionary should be covered by the annotations.

**Special notes**: `AP` and `SP` should not appear in the annotation.
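
To make the relation between syllables and phonemes concrete, the sketch below expands an annotation into its phoneme sequence. The dictionary excerpt is a made-up example in the "syllable to phoneme list" shape that this notebook builds from the dictionary file, and `to_phonemes` is a hypothetical helper.

```python
def to_phonemes(lab_text, dictionary):
    """Expand a space-separated syllable annotation into a phoneme list."""
    phonemes = []
    for syllable in lab_text.split():
        if syllable not in dictionary:
            raise KeyError(f'syllable {syllable!r} is not in the dictionary')
        phonemes.extend(dictionary[syllable])
    return phonemes


# Hypothetical excerpt; the real entries come from the dictionary file.
example_dict = {'gan': ['g', 'an'], 'shou': ['sh', 'ou'], 'a': ['a']}
print(to_phonemes('gan shou a', example_dict))  # ['g', 'an', 'sh', 'ou', 'a']
```

Anything that is not a dictionary syllable, including `AP` and `SP`, raises an error here, which mirrors the rule that such marks must not appear in the `.lab` files.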

**News**: We developed [MinLabel](https://github.com/SineStriker/qsynthesis-revenge/tree/main/src/Test/MinLabel), a simple yet efficient tool to help finish this step. You can download the binary executable for Windows [here](https://diffsinger-1307911855.cos.ap-beijing.myqcloud.com/label/minlabel_latest.zip).

<font color="red">Optional step</font>: if you want us to help you create all empty `lab` files (instead of creating them yourself), please run the following cell.


In [None]:
for file in tqdm.tqdm(sliced_filelist):
    filename = os.path.basename(file)
    name_without_ext = filename.rsplit('.', maxsplit=1)[0]
    annotation = os.path.join(sliced_path, f'{name_without_ext}.lab')
    if not os.path.exists(annotation):
        with open(annotation, 'a'):
            ...
print('Creating missing lab files done.')


Run the following cell to see if all segments are annotated and all annotations are valid. If there are failed checks, please fix them and run again.

A summary of your phoneme coverage will be generated. If some phonemes have extremely few occurrences (for example, fewer than 20), it is highly recommended to add more recordings to cover them.
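
If you prefer a plain list over reading the chart, rare phonemes can also be flagged programmatically. This is a small sketch under the assumption that the annotations are available as strings; `rare_phonemes` and the toy data are hypothetical and not part of the pipeline.

```python
from collections import Counter


def rare_phonemes(annotations, dictionary, min_count=20):
    """Return phonemes that occur fewer than min_count times, including never."""
    # Start every phoneme of the dictionary at zero so uncovered ones are reported too.
    counts = Counter({ph: 0 for phs in dictionary.values() for ph in phs})
    for text in annotations:
        for syllable in text.split():
            counts.update(dictionary.get(syllable, []))
    return sorted(ph for ph, n in counts.items() if n < min_count)


# Toy data: 'p' occurs once, 'b' three times, 'a' four times.
toy_dict = {'ba': ['b', 'a'], 'pa': ['p', 'a']}
print(rare_phonemes(['ba ba pa', 'ba'], toy_dict, min_count=2))  # ['p']
```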


In [None]:
import utils.distribution as dist

# Load dictionary
dict_path = '../dictionaries/opencpop-strict.txt'
with open(dict_path, 'r', encoding='utf8') as f:
    rules = [ln.strip().split('\t') for ln in f.readlines()]
dictionary = {}
phoneme_set = set()
for r in rules:
    phonemes = r[1].split()
    dictionary[r[0]] = phonemes
    phoneme_set.update(phonemes)

# Run checks
check_failed = False
covered = set()
phoneme_map = {}
for ph in sorted(phoneme_set):
    phoneme_map[ph] = 0

segment_pairs = []

for file in tqdm.tqdm(sliced_filelist):
    filename = os.path.basename(file)
    name_without_ext = filename.rsplit('.', maxsplit=1)[0]
    annotation = os.path.join(sliced_path, f'{name_without_ext}.lab')
    if not os.path.exists(annotation):
        print(f'No annotation found for \'{filename}\'!')
        check_failed = True
        continue
    with open(annotation, 'r', encoding='utf8') as f:
        syllables = f.read().strip().split()
    if not syllables:
        print(f'Annotation file \'{annotation}\' is empty!')
        check_failed = True
    else:
        oov = []
        for s in syllables:
            if s not in dictionary:
                oov.append(s)
            else:
                for ph in dictionary[s]:
                    phoneme_map[ph] += 1
                covered.update(dictionary[s])
        if oov:
            print(f'Syllable(s) {oov} not allowed in annotation file \'{annotation}\'')
            check_failed = True

# Phoneme coverage
uncovered = phoneme_set - covered
if uncovered:
    print(f'The following phonemes are not covered!')
    print(sorted(uncovered))
    print('Please add more recordings to cover these phonemes.')
    check_failed = True

if not check_failed:
    print('Congratulations! All annotations are well prepared.')
    print('Here is a summary of your phoneme coverage.')

phoneme_list = sorted(phoneme_set)
phoneme_counts = [phoneme_map[ph] for ph in phoneme_list]
dist.draw_distribution(
    title='Phoneme Distribution Summary',
    x_label='Phoneme',
    y_label='Number of occurrences',
    items=phoneme_list,
    values=phoneme_counts
)
phoneme_summary = os.path.join(sliced_path, 'phoneme_distribution.jpg')
plt.savefig(fname=phoneme_summary,
            bbox_inches='tight',
            pad_inches=0.25)
plt.show()
print(f'Summary saved to \'{phoneme_summary}\'.')


### 3.2 Forced alignment

Given the transcriptions of each segment, we are able to align the phoneme sequence to its corresponding audio, thus obtaining position and duration information of each phoneme.

We use [Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to do forced phoneme alignment.

To run MFA alignment, please first run the following cell to resample all recordings to 16 kHz. The resampled recordings and copies of the phoneme labels will be saved at `pipelines/segments/`. Also, the folder `pipelines/textgrids/` will be created for temporarily storing aligned TextGrids.

<font color="yellow">Warning</font>: This will overwrite all files in `pipelines/segments/` and `pipelines/textgrids/`.


In [None]:
segments_dir = 'segments'
textgrids_dir = 'textgrids'
if os.path.exists(segments_dir):
    shutil.rmtree(segments_dir)
os.makedirs(segments_dir)
if os.path.exists(textgrids_dir):
    shutil.rmtree(textgrids_dir)
os.makedirs(textgrids_dir)
samplerate = 16000
for file in tqdm.tqdm(sliced_filelist):
    y, _ = librosa.load(file, sr=samplerate, mono=True)
    filename = os.path.basename(file)
    soundfile.write(os.path.join(segments_dir, filename), y, samplerate, subtype='PCM_16')
    name_without_ext = filename.rsplit('.', maxsplit=1)[0]
    annotation = os.path.join(sliced_path, f'{name_without_ext}.lab')
    shutil.copy(annotation, segments_dir)
print('Resampling and copying done.')


Run the following cell to download the MFA pretrained model and perform forced alignment. If the command fails, you can copy it into your terminal and run it manually.


In [None]:
import requests

mfa_zip = f'assets/mfa-opencpop-strict.zip'
mfa_uri = 'https://diffsinger-1307911855.cos.ap-beijing.myqcloud.com/mfa/mfa-opencpop-strict.zip'
if not os.path.exists(mfa_zip):
    # Download
    print('Model not found, downloading...')
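    response = requests.get(mfa_uri)
    response.raise_for_status()  # Fail loudly instead of saving an error page as the zip
    with open(mfa_zip, 'wb') as f:
        f.write(response.content)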
    print('Done.')
else:
    print('Pretrained model already exists.')

segments_dir = 'segments'
textgrids_dir = 'textgrids'
os.makedirs(textgrids_dir, exist_ok=True)
print('\nRun the following command in your terminal manually if it fails here:')
print(f'mfa align pipelines/{segments_dir} {dict_path[3:]} pipelines/{mfa_zip} pipelines/{textgrids_dir} --beam 100 --clean --overwrite')

!mfa align $segments_dir $dict_path $mfa_zip $textgrids_dir --beam 100 --clean --overwrite


### 3.3 Optimize and finish the TextGrids

In this section, we run some scripts to reduce errors for long utterances and detect `AP`s which have not been labeled before. The optimized TextGrids can be backed up for future use if you specify a backup directory in the clean-up cell of section 4.1.
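
The breath (`AP`) detection in the cell below boils down to two tests, which the following sketch restates in isolation. The function names are hypothetical and the defaults are copied from the configuration in that cell; the actual cell additionally slides a window over each unlabeled interval and merges consecutive candidate windows.

```python
import math


def is_breath_candidate(f0_frames, rms, f0_min=40.0, br_db=-60.0):
    """A window is a candidate if every frame is unvoiced but the window is not silent."""
    unvoiced = all(f0 < f0_min for f0 in f0_frames)
    rms_db = 20 * math.log10(min(max(rms, 1e-12), 1.0))
    return unvoiced and rms_db >= br_db


def accept_breath(start, end, mean_centroid, br_len=0.1, br_centroid=2000.0):
    """A candidate range becomes `AP` only if it is long enough and bright enough."""
    return end - start >= br_len and mean_centroid >= br_centroid


print(is_breath_candidate([0.0, 0.0, 0.0], rms=0.01))  # unvoiced at -40 dB: True
print(is_breath_candidate([0.0, 220.0, 0.0], rms=0.01))  # contains a voiced frame: False
print(accept_breath(0.5, 0.7, mean_centroid=1500.0))  # spectral centroid too low: False
```

Raising `br_db` or `br_centroid` makes the detection stricter, so fewer intervals are labeled as `AP`.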

Edit the path and adjust arguments according to your needs in the following cell before you run it. Optimized results will be saved at `pipelines/textgrids/revised/`.


In [None]:
########################################

# Configuration for voice arguments based on your dataset
f0_min = 40.  # Minimum value of pitch
f0_max = 1100.  # Maximum value of pitch
br_len = 0.1  # Minimum length of aspiration in seconds
br_db = -60.  # Threshold of RMS in dB for detecting aspiration
br_centroid = 2000.  # Threshold of spectral centroid in Hz for detecting aspiration

# Other arguments, do not edit unless you understand them
time_step = 0.005  # Time step for feature extraction
min_space = 0.04  # Minimum length of space in seconds
voicing_thresh_vowel = 0.45  # Threshold of voicing for fixing long utterances
voicing_thresh_breath = 0.6  # Threshold of voicing for detecting aspiration
br_win_sz = 0.05  # Size of sliding window in seconds for detecting aspiration

########################################

# import utils.tg_optimizer as optimizer

textgrids_revised_dir = 'textgrids/revised'
os.makedirs(textgrids_revised_dir, exist_ok=True)
for wavfile in tqdm.tqdm(sliced_filelist):
    name = os.path.basename(wavfile).rsplit('.', maxsplit=1)[0]
    textgrid = tg.TextGrid()
    textgrid.read(os.path.join(textgrids_dir, f'{name}.TextGrid'))
    words = textgrid[0]
    phones = textgrid[1]
    sound = pm.Sound(wavfile)
    f0_voicing_breath = sound.to_pitch_ac(
        time_step=time_step,
        voicing_threshold=voicing_thresh_breath,
        pitch_floor=f0_min,
        pitch_ceiling=f0_max,
    ).selected_array['frequency']
    f0_voicing_vowel = sound.to_pitch_ac(
        time_step=time_step,
        voicing_threshold=voicing_thresh_vowel,
        pitch_floor=f0_min,
        pitch_ceiling=f0_max,
    ).selected_array['frequency']
    y, sr = librosa.load(wavfile, sr=24000, mono=True)
    hop_size = int(time_step * sr)
    spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_size).squeeze(0)

    # Fix long utterances
    i = j = 0
    while i < len(words):
        word = words[i]
        phone = phones[j]
        if word.mark is not None and word.mark != '':
            i += 1
            j += len(dictionary[word.mark])
            continue
        if i == 0:
            i += 1
            j += 1
            continue
        prev_word = words[i - 1]
        prev_phone = phones[j - 1]
        # Extend length of long utterances
        while word.minTime < word.maxTime - time_step:
            pos = min(f0_voicing_vowel.shape[0] - 1, int(word.minTime / time_step))
            if f0_voicing_vowel[pos] < f0_min:
                break
            prev_word.maxTime += time_step
            prev_phone.maxTime += time_step
            word.minTime += time_step
            phone.minTime += time_step
        i += 1
        j += 1

    # Detect aspiration
    i = j = 0
    while i < len(words):
        word = words[i]
        phone = phones[j]
        if word.mark is not None and word.mark != '':
            i += 1
            j += len(dictionary[word.mark])
            continue
        if word.maxTime - word.minTime < br_len:
            i += 1
            j += 1
            continue
        ap_ranges = []
        br_start = None
        win_pos = word.minTime
        while win_pos + br_win_sz <= word.maxTime:
            all_noisy = (f0_voicing_breath[int(win_pos / time_step) : int((win_pos + br_win_sz) / time_step)] < f0_min).all()
            rms_db = 20 * np.log10(np.clip(sound.get_rms(from_time=win_pos, to_time=win_pos + br_win_sz), a_min=1e-12, a_max=1))
            # print(win_pos, win_pos + br_win_sz, all_noisy, rms_db)
            if all_noisy and rms_db >= br_db:
                if br_start is None:
                    br_start = win_pos
            else:
                if br_start is not None:
                    br_end = win_pos + br_win_sz - time_step
                    if br_end - br_start >= br_len:
                        centroid = spectral_centroid[int(br_start / time_step) : int(br_end / time_step)].mean()
                        if centroid >= br_centroid:
                            ap_ranges.append((br_start, br_end))
                    br_start = None
                    win_pos = br_end
            win_pos += time_step
        if br_start is not None:
            br_end = win_pos + br_win_sz - time_step
            if br_end - br_start >= br_len:
                centroid = spectral_centroid[int(br_start / time_step) : int(br_end / time_step)].mean()
                if centroid >= br_centroid:
                    ap_ranges.append((br_start, br_end))
        # print(ap_ranges)
        if len(ap_ranges) == 0:
            i += 1
            j += 1
            continue
        words.removeInterval(word)
        phones.removeInterval(phone)
        if word.minTime < ap_ranges[0][0]:
            words.add(minTime=word.minTime, maxTime=ap_ranges[0][0], mark=None)
            phones.add(minTime=phone.minTime, maxTime=ap_ranges[0][0], mark=None)
            i += 1
            j += 1
        for k, ap in enumerate(ap_ranges):
            if k > 0:
                words.add(minTime=ap_ranges[k - 1][1], maxTime=ap[0], mark=None)
                phones.add(minTime=ap_ranges[k - 1][1], maxTime=ap[0], mark=None)
                i += 1
                j += 1
            words.add(minTime=ap[0], maxTime=min(word.maxTime, ap[1]), mark='AP')
            phones.add(minTime=ap[0], maxTime=min(word.maxTime, ap[1]), mark='AP')
            i += 1
            j += 1
        if ap_ranges[-1][1] < word.maxTime:
            words.add(minTime=ap_ranges[-1][1], maxTime=word.maxTime, mark=None)
            phones.add(minTime=ap_ranges[-1][1], maxTime=phone.maxTime, mark=None)
            i += 1
            j += 1

    # Remove short spaces
    i = j = 0
    while i < len(words):
        word = words[i]
        phone = phones[j]
        if word.mark is not None and word.mark != '':
            i += 1
            j += (1 if word.mark == 'AP' else len(dictionary[word.mark]))
            continue
        if word.maxTime - word.minTime >= min_space:
            word.mark = 'SP'
            phone.mark = 'SP'
            i += 1
            j += 1
            continue
        if i == 0:
            if len(words) >= 2:
                words[i + 1].minTime = word.minTime
                phones[j + 1].minTime = phone.minTime
                words.removeInterval(word)
                phones.removeInterval(phone)
            else:
                break
        elif i == len(words) - 1:
            if len(words) >= 2:
                words[i - 1].maxTime = word.maxTime
                phones[j - 1].maxTime = phone.maxTime
                words.removeInterval(word)
                phones.removeInterval(phone)
            else:
                break
        else:
            words[i - 1].maxTime = words[i + 1].minTime = (word.minTime + word.maxTime) / 2
            phones[j - 1].maxTime = phones[j + 1].minTime = (phone.minTime + phone.maxTime) / 2
            words.removeInterval(word)
            phones.removeInterval(phone)
    textgrid.write(os.path.join(textgrids_revised_dir, f'{name}.TextGrid'))


TextGrids saved in `pipelines/textgrids/revised` can be edited with [Praat](https://github.com/praat/praat). You may examine these files and fix label errors yourself if you want a more accurate model with higher performance. However, this is not required, since manual labeling takes a lot of time.

Run the following cell to see a summary of the word-level pitch coverage of your dataset. (Data may not be accurate due to octave errors in pitch extraction.)
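
The pitch summary converts each f0 value in Hz to a MIDI key with the standard formula `12 * log2(f0 / 440) + 69` and then names the key. The sketch below shows that conversion on a few sample frequencies, using only the standard library; `hz_to_midi` is a hypothetical name, while `key_to_name` follows the function in the cell below.

```python
import math

NOTE_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']


def hz_to_midi(f0):
    """Convert a frequency in Hz to a (fractional) MIDI key, with A4 = 440 Hz = 69."""
    return 12.0 * math.log2(f0 / 440.0) + 69.0


def key_to_name(midi_key):
    """Name an integer MIDI key, e.g. 69 -> 'A4'."""
    return NOTE_NAMES[midi_key % 12] + str(midi_key // 12 - 1)


for hz in (220.0, 261.63, 440.0):
    key = int(hz_to_midi(hz))  # truncation, as done by the integer cast in the cell below
    print(hz, key, key_to_name(key))
```

Because the fractional key is truncated rather than rounded, a slightly flat note is counted one semitone lower, which is one more reason to read the chart as an approximate overview.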


In [None]:
import utils.distribution as dist


def key_to_name(midi_key):
    note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    return note_names[midi_key % 12] + str(midi_key // 12 - 1)


pit_map = {}
# Fall back to defaults if the configuration cell in section 3.3 has not been run
if 'f0_min' not in locals():
    f0_min = 40.
if 'f0_max' not in locals():
    f0_max = 1100.
if 'voicing_thresh_vowel' not in locals():
    voicing_thresh_vowel = 0.45
for wavfile in tqdm.tqdm(sliced_filelist):
    name = os.path.basename(wavfile).rsplit('.', maxsplit=1)[0]
    textgrid = tg.TextGrid()
    textgrid.read(os.path.join(textgrids_revised_dir, f'{name}.TextGrid'))
    timestep = 0.01
    f0 = pm.Sound(wavfile).to_pitch_ac(
        time_step=timestep,
        voicing_threshold=voicing_thresh_vowel,
        pitch_floor=f0_min,
        pitch_ceiling=f0_max,
    ).selected_array['frequency']
    pitch = 12. * np.log2(f0 / 440.) + 69.
    for word in textgrid[0]:
        if word.mark in ['AP', 'SP']:
            continue
        if word.maxTime - word.minTime < timestep:
            continue
        word_pit = pitch[int(word.minTime / timestep) : int(word.maxTime / timestep)]
        word_pit = np.extract(word_pit >= 0, word_pit)
        if word_pit.shape[0] == 0:
            continue
        counts = np.bincount(word_pit.astype(np.int64))
        midi = counts.argmax()
        if midi in pit_map:
            pit_map[midi] += 1
        else:
            pit_map[midi] = 1
midi_keys = sorted(pit_map.keys())
midi_keys = list(range(midi_keys[0], midi_keys[-1] + 1))
dist.draw_distribution(
    title='Pitch Distribution Summary',
    x_label='Pitch',
    y_label='Number of occurrences',
    items=[key_to_name(k) for k in midi_keys],
    values=[pit_map.get(k, 0) for k in midi_keys]
)
pitch_summary = os.path.join(sliced_path, 'pitch_distribution.jpg')
plt.savefig(fname=pitch_summary,
            bbox_inches='tight',
            pad_inches=0.25)
plt.show()
print(f'Summary saved to \'{pitch_summary}\'.')


## 4 Building the final dataset

Congratulations! If you have gone through all sections above with success, it means that you are now prepared for building your final dataset. There are only a few steps to go before you can run scripts to train your own model.

### 4.1 Name and format your dataset

Please provide a unique name for your dataset, usually the name of the singer/speaker (whether real or virtual). For example, `opencpop` will be a good name for the dataset. You can also add tags to represent dataset version, model capacity or improvements. For example, `v2` represents the version, `large` represents the capacity, and `fix_br` means you have fixed breaths since you last trained a model.

Please edit the following cell before you run it. Remember to use only letters, numbers and underscores (`_`).


In [None]:
########################################

# Name and tags of your dataset
dataset_name = '???'  # Required
dataset_tags = ''  # Optional

########################################

import random
import re

from textgrid import TextGrid

assert dataset_name != '', 'Dataset name cannot be empty.'
assert re.search(r'[^0-9A-Za-z_]', dataset_name) is None, 'Dataset name contains invalid characters.'
full_name = dataset_name
if dataset_tags != '':
    assert re.search(r'[^0-9A-Za-z_]', dataset_tags) is None, 'Dataset tags contain invalid characters.'
    full_name += f'_{dataset_tags}'
assert not os.path.exists(f'../data/{full_name}'), f'The name \'{full_name}\' already exists in your \'data\' folder!'

print('Dataset name:', dataset_name)
if dataset_tags != '':
    print('Tags:', dataset_tags)

formatted_path = f'../data/{full_name}/raw/wavs'
os.makedirs(formatted_path)
transcriptions = []
samplerate = 44100
min_sil = int(0.1 * samplerate)
max_sil = int(2. * samplerate)
for wavfile in tqdm.tqdm(sliced_filelist):
    name = os.path.basename(wavfile).rsplit('.', maxsplit=1)[0]
    y, _ = librosa.load(wavfile, sr=samplerate, mono=True)
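    # Use a name other than `tg`, which is the alias of the `textgrid` module imported earlier
    textgrid = TextGrid()
    textgrid.read(os.path.join(textgrids_revised_dir, f'{name}.TextGrid'))
    ph_seq = [ph.mark for ph in textgrid[1]]
    ph_dur = [ph.maxTime - ph.minTime for ph in textgrid[1]]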
    if random.random() < 0.5:
        len_sil = random.randrange(min_sil, max_sil)
        y = np.concatenate((np.zeros((len_sil,), dtype=np.float32), y))
        if ph_seq[0] == 'SP':
            ph_dur[0] += len_sil / samplerate
        else:
            ph_seq.insert(0, 'SP')
            ph_dur.insert(0, len_sil / samplerate)
    if random.random() < 0.5:
        len_sil = random.randrange(min_sil, max_sil)
        y = np.concatenate((y, np.zeros((len_sil,), dtype=np.float32)))
        if ph_seq[-1] == 'SP':
            ph_dur[-1] += len_sil / samplerate
        else:
            ph_seq.append('SP')
            ph_dur.append(len_sil / samplerate)
    ph_seq = ' '.join(ph_seq)
    ph_dur = ' '.join([str(round(d, 6)) for d in ph_dur])
    soundfile.write(os.path.join(formatted_path, f'{name}.wav'), y, samplerate)
    transcriptions.append(f'{name}|啊|{ph_seq}|rest|0|{ph_dur}|0')
with open(f'../data/{full_name}/raw/transcriptions.txt', 'w', encoding='utf8') as f:
    print('\n'.join(transcriptions), file=f)
print(f'All wavs and transcriptions saved at \'data/{full_name}/raw/\'.')


Now that the dataset and transcriptions have been saved, you can run the following cell to clean up all temporary files generated by pipelines above.

<font color="yellow">Warning</font>: This will remove the `pipelines/segments/` and `pipelines/textgrids/` folders. You should specify a directory in the following cell to back up your TextGrids if you want them for future use.
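
Before cleaning up, it may be worth sanity-checking the saved `transcriptions.txt`. Each line has seven `|`-separated fields; as written by the cell above, the first is the segment name, the third is the phoneme sequence and the sixth is the list of durations in seconds, while the remaining fields are fixed placeholders. The parser below is a hypothetical sketch and the example line is made up.

```python
def parse_transcription(line):
    """Split one line of transcriptions.txt into name, phonemes and durations."""
    fields = line.rstrip('\n').split('|')
    if len(fields) != 7:
        raise ValueError(f'expected 7 fields, got {len(fields)}')
    name = fields[0]
    ph_seq = fields[2].split()
    ph_dur = [float(d) for d in fields[5].split()]
    # Every phoneme needs exactly one duration.
    if len(ph_seq) != len(ph_dur):
        raise ValueError(f'{name}: {len(ph_seq)} phonemes but {len(ph_dur)} durations')
    return name, ph_seq, ph_dur


# Made-up line in the layout written by the cell above.
line = 'demo_0001|啊|SP g an SP|rest|0|0.5 0.08 0.32 0.4|0'
name, ph_seq, ph_dur = parse_transcription(line)
print(name, ph_seq, round(sum(ph_dur), 6))
```

The sum of the durations should be close to the length of the corresponding wav file in seconds, including any silence that was padded at either end.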


In [None]:
########################################

# Optional path to back up your TextGrids
textgrids_backup_path = r''  # If left empty, the TextGrids will not be backed up

########################################

assert textgrids_backup_path == '' or not os.path.exists(textgrids_backup_path) or os.path.isdir(textgrids_backup_path), 'The backup path is not a directory.'

if textgrids_backup_path != '':
    os.makedirs(textgrids_backup_path, exist_ok=True)
    # Avoid reusing the name `tg`, which is the alias of the `textgrid` module
    for tg_file in tqdm.tqdm(glob.glob(f'{textgrids_revised_dir}/*.TextGrid')):
        filename = os.path.basename(tg_file)
        shutil.copy(tg_file, os.path.join(textgrids_backup_path, filename))

shutil.rmtree(segments_dir)
shutil.rmtree(textgrids_dir)
print('Cleaning up done.')


### 4.2 Configuring parameters

Here you can configure some parameters for preprocessing, training and the neural networks. Read the explanations below and run the following cell.

#### `residual_channels` and `residual_layers`

These two hyperparameters refer to the width and the depth of the diffusion decoder network. Generally speaking, `384x20` represents a `base` model capacity and `512x20` represents a `large` model capacity. `384x30` is also a reasonable choice. Larger models consume more GPU memory and run slower at training and inference time, but they produce better results.

GPU memory required for training:

- Base model: at least 6 GB (12 GB recommended)
- Large model: at least 12 GB (24 GB recommended)

#### `test_prefixes`

All files whose names start with a prefix in this list will be put into the test set. Each time a checkpoint is saved, the program will first run inference on the test set and put the results on TensorBoard, so you can listen to these demos and judge the quality of your model. If you add fewer than 10 test cases, more cases will be randomly selected.
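
The prefix rule can be pictured with the following sketch, which splits a list of segment names into training and test cases and tops the test set up with random picks when it is too small. `split_cases` and the toy names are hypothetical; the cell below implements the same idea in its own way, and names are assumed to be unique.

```python
import random


def split_cases(names, prefixes, min_test=10, seed=0):
    """Split names by prefix, then add random extras until min_test is reached."""
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    test = [n for n in names if n.startswith(prefixes)]
    train = [n for n in names if not n.startswith(prefixes)]
    missing = min_test - len(test)
    if missing > 0:
        extra = random.Random(seed).sample(train, min(missing, len(train)))
        test += extra
        train = [n for n in train if n not in extra]
    return train, test


names = [f'song{s}_slice_{i:04d}' for s in (1, 2) for i in range(10)]
train, test = split_cases(names, ['song2_slice_0009'], min_test=3)
print(len(train), len(test), test[0])  # 17 3 song2_slice_0009
```

Note that a short prefix such as `song2` would move every segment of that recording into the test set, so full segment names are the safest choice.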

#### `max_tokens` and `max_sentences`

These two parameters jointly determine the batch size at training time: the former is the maximum number of frames in one batch and the latter limits the maximum number of sentences in one batch. Larger batches consume more GPU memory at training time. These values can be adjusted according to your GPU memory. Remember not to set them too low because the model may not converge with small batches.
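
As an illustration of how two caps can jointly bound a batch, here is a simplified greedy batching sketch. It is an assumption made for illustration, not the actual sampler of the training code: it counts the summed frame lengths, whereas a real implementation may count padded frames and sort segments by length first.

```python
def make_batches(frame_counts, max_tokens=80000, max_sentences=48):
    """Greedily group segment indices so that no batch exceeds either cap."""
    batches, current, used = [], [], 0
    for i, n in enumerate(frame_counts):
        # Close the current batch when adding this segment would break a cap.
        if current and (used + n > max_tokens or len(current) >= max_sentences):
            batches.append(current)
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        batches.append(current)
    return batches


# Short segments: the sentence cap is reached first.
print([len(b) for b in make_batches([1000] * 100)])  # [48, 48, 4]
# Long segments: the frame cap is reached first.
print([len(b) for b in make_batches([4000] * 100)])  # [20, 20, 20, 20, 20]
```

In this picture, lowering `max_tokens` is the more direct way to reduce GPU memory usage when your segments are long.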

#### `lr` and `decay_steps`

These two values refer to the learning rate and the number of steps between learning rate decays. If you decrease your batch size, you may consider using a smaller learning rate and more decay steps.
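
To see how these two values interact, the sketch below computes the learning rate of a step-decay schedule at a given training step. The decay factor `gamma` is an assumption chosen for illustration; the factor actually used by the training code is defined in its configuration, not in this notebook.

```python
def lr_at(step, lr=0.0004, decay_steps=50000, gamma=0.5):
    """Learning rate after `step` updates if it is multiplied by gamma every decay_steps."""
    return lr * gamma ** (step // decay_steps)


for step in (0, 49999, 50000, 320000):
    print(step, lr_at(step))
```

With these illustrative values, the learning rate is halved for the first time at step 50000 and has been halved six times by step 320000.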

#### `val_check_interval`, `num_ckpt_keep` and `max_updates`

These three values refer to the training steps between validating and saving checkpoints, the number of the most recent checkpoints reserved, and the maximum training steps. With default batch size and 5 hours of training data, 250k ~ 350k training steps is reasonable. If you decrease the batch size, you may increase the training steps.
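
A quick calculation shows what the default values imply. The sketch assumes that one checkpoint is written at every validation, as described above; `kept_checkpoints` is a hypothetical helper.

```python
def kept_checkpoints(max_updates=320000, val_check_interval=2000, num_ckpt_keep=5):
    """Return the number of validations and the steps of the checkpoints kept at the end."""
    steps = list(range(val_check_interval, max_updates + 1, val_check_interval))
    return len(steps), steps[-num_ckpt_keep:]


print(kept_checkpoints())  # (160, [312000, 314000, 316000, 318000, 320000])
```

With the defaults, the model is validated 160 times during training, and only the checkpoints from the last 8000 steps remain on disk when training finishes.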


In [None]:
########################################

residual_channels = 512
residual_layers = 20

test_prefixes = [

]

max_tokens = 80000
max_sentences = 48

lr = 0.0004
decay_steps = 50000

val_check_interval = 2000
num_ckpt_keep = 5
max_updates = 320000

########################################

import datetime
import random

import yaml

training_cases = [os.path.basename(w).rsplit('.', maxsplit=1)[0] for w in sliced_filelist]
valid_test_cases = []
i = 0
while i < len(training_cases):
    for prefix in test_prefixes:
        if training_cases[i].startswith(prefix):
            valid_test_cases.append(training_cases[i])
            training_cases.pop(i)
            i -= 1
            break
    i += 1
if len(valid_test_cases) < 10:
    test_prefixes += random.sample(training_cases, 10 - len(valid_test_cases))

configs = {
    'base_config': ['configs/naive/ds1000.yaml'],
    'raw_data_dir': f'data/{full_name}/raw',
    'binary_data_dir': f'data/{full_name}/binary',
    'residual_channels': residual_channels,
    'residual_layers': residual_layers,
    'test_prefixes': test_prefixes,
    'max_tokens': max_tokens,
    'max_sentences': max_sentences,
    'lr': lr,
    'decay_steps': decay_steps,
    'val_check_interval': val_check_interval,
    'num_ckpt_keep': num_ckpt_keep,
    'max_updates': max_updates,
}
with open(f'../data/{full_name}/config.yaml', 'w', encoding='utf8') as f:
    yaml.dump(configs, f, sort_keys=False, allow_unicode=True)

date = datetime.datetime.now().strftime('%m%d')
exp_name = f'{date}_{dataset_name}_ds1000'
if dataset_tags != '':
    exp_name += f'_{dataset_tags}'
print('Congratulations! All steps have been done and you are now prepared to train your own model.\n'
      'Before you start, please read and follow instructions in the repository README.\n'
      'Here are the commands you can copy to run preprocessing and training:\n')

print('============ Linux ============\n'
      'export PYTHONPATH=.\n'
      'export CUDA_VISIBLE_DEVICES=0\n'
      f'python data_gen/binarize.py --config data/{full_name}/config.yaml\n'
      f'python run.py --config data/{full_name}/config.yaml --exp_name {exp_name} --reset\n')

print('===== Windows (PowerShell) =====\n'
      '$env:PYTHONPATH="."\n'
      '$env:CUDA_VISIBLE_DEVICES=0\n'
      f'python data_gen/binarize.py --config data/{full_name}/config.yaml\n'
      f'python run.py --config data/{full_name}/config.yaml --exp_name {exp_name} --reset\n')

print('===== Windows (Command Prompt) =====\n'
      'set PYTHONPATH=.\n'
      'set CUDA_VISIBLE_DEVICES=0\n'
      f'python data_gen/binarize.py --config data/{full_name}/config.yaml\n'
      f'python run.py --config data/{full_name}/config.yaml --exp_name {exp_name} --reset\n')

print(f'If you want to train your model on another machine (like a remote GPU), please copy the whole \'data/{full_name}/\' folder.')
