In [None]:
"""modified from https://gist.github.com/endes0/0967d7c5bb1877559c4ae84be05e036c"""
from tika import parser

import torchaudio
import argparse
from sanitize_filename import sanitize
import re
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm

import torch
import json
from dataclasses import dataclass
# import pysbd
from typing import List
from loguru import logger
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter


In [None]:
# https://github.com/suno-ai/bark/blob/main/notebooks/long_form_generation.ipynb
from bark.generation import (
    generate_text_semantic,
    preload_models,
)
from bark.api import semantic_to_waveform
from bark import generate_audio, SAMPLE_RATE
# Bark voice preset passed as `history_prompt` for every generated chunk.
SPEAKER = "v2/en_speaker_6"
# A quarter second of silence at bark's output sample rate.
silence = np.zeros(SAMPLE_RATE // 4)

# Load bark's models up front (see bark.generation.preload_models) so the first
# generation call does not pay the loading cost.
preload_models()



In [None]:

@dataclass
class Writer:
    """Writes audio chapters as numbered ``.ogg`` files plus an M3U playlist.

    Chapters are saved as ``out_dir/1.ogg``, ``out_dir/2.ogg``, ... and each one
    is appended to ``out_dir/playlist.m3u`` as soon as it is written. Call
    ``close()`` (or use the writer as a context manager) when finished.
    """
    out_dir: Path

    def __post_init__(self):
        # The playlist stays open for the writer's lifetime and is closed by close().
        self.m3u = open(self.out_dir / 'playlist.m3u', 'w', encoding='utf-8')
        self.m3u.write('#EXTM3U\n')
        self.chapter = 1  # number of the next chapter file to write

    def write_chapter(self, waveforms: torch.Tensor, SAMPLE_RATE=24000):
        """Save ``waveforms`` as the next chapter and append it to the playlist.

        Args:
            waveforms: audio as a ``(channels, time)`` tensor, a 1-D ``(time,)``
                tensor (treated as mono), or a list/tuple of such chunks, which
                are concatenated along the time axis.
            SAMPLE_RATE: sample rate of ``waveforms`` in Hz.

        Returns:
            Path of the written ``.ogg`` file.

        Raises:
            ValueError: if the writer has already been closed.
        """
        # Check before saving so a closed writer does not leave an orphan .ogg behind.
        if self.m3u.closed:
            raise ValueError('Writer is closed; cannot write more chapters')
        if isinstance(waveforms, (list, tuple)):
            waveforms = torch.cat([torch.as_tensor(w) for w in waveforms], dim=-1)
        wav = torch.as_tensor(waveforms).detach().cpu()
        if wav.dim() == 1:
            # torchaudio.save requires a 2-D (channels, time) tensor.
            wav = wav.unsqueeze(0)
        wav_f = self.out_dir / f'{self.chapter}.ogg'
        torchaudio.save(str(wav_f), wav, SAMPLE_RATE)
        # Relative entry: M3U paths resolve against the playlist's own folder,
        # so the playlist keeps working if the output folder is moved.
        self.m3u.write(f'{wav_f.name}\n')
        # Flush so the playlist on disk is current even if a later chapter fails.
        self.m3u.flush()
        self.chapter += 1
        return wav_f

    def close(self):
        """Close the playlist file (closing an already-closed file is a no-op)."""
        self.m3u.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

def split_into_sentences(text, tokenizer, limit: int = 200) -> List[str]:
    """Split ``text`` into chunks of at most ``limit`` tokens, preferring natural breaks.

    Uses langchain's RecursiveCharacterTextSplitter, which tries the separators
    in order: paragraphs, lines, HTML-ish breaks, sentence punctuation, quotes,
    brackets, other punctuation, runs of spaces, single spaces and finally
    individual characters.

    Args:
        text: text to split. ``None`` or empty text yields ``[]``.
        tokenizer: any object with an ``encode(str)`` method; a chunk's length
            is ``len(tokenizer.encode(chunk))``.
        limit: maximum chunk length in tokens (default 200).

    Returns:
        List of chunks, each at most ``limit`` tokens long.

    Raises:
        AssertionError: if the splitter produced a chunk longer than ``limit``.
    """
    if not text:
        return []

    def token_len(s: str) -> int:
        return len(tokenizer.encode(s))

    splitter = RecursiveCharacterTextSplitter(
        length_function=token_len,
        chunk_size=limit,
        chunk_overlap=0,
        keep_separator=True,
        strip_whitespace=True,
        separators=[
            "\n\n", "\n", "\xa0", '<div>', '<p>', '<br>', "\r", ".", "!", "?",
            '"', "'", "‘", "’", "“", "”", "„", "‟",
            "(", ")", "[", "]", "{", "}",
            "…", ":", ";", "—", "   ",
            # " " and "" ensure there is always something to split by, so no
            # chunk has to exceed the limit. The comma after "   " matters:
            # without it Python concatenates the two literals into "    ".
            " ", "",
        ],
    )
    texts = splitter.split_text(text)
    if not texts:
        return []
    lengths = [token_len(x) for x in texts]
    logger.debug(f'split lengths {lengths}. max={max(lengths)} chunk_limit={limit}')
    assert all(n <= limit for n in lengths), 'all senteces should be below limit'
    return texts


In [None]:
# Jupyter does not define __file__, so set it to this notebook's path (relative
# to the kernel's working directory) and derive the project root from it.
__file__ = '../01_epub_tortise.ipynb'
root_dir = Path(__file__).resolve().parents[0]
root_dir


In [None]:
# Get the command line arguments. They are parsed from an empty list so the
# cell runs inside a notebook; the defaults below are the effective settings.
import shutil

parser2 = argparse.ArgumentParser()
parser2.add_argument('--epub', type=Path, 
                     default=root_dir/'data/A Short Guide to the Inner Citadel - Massimo Pigliucci.epub',
                    #  default=root_dir/'data/golden_saying_of_epictetus.epub',
                    help='EPUB file to read')
parser2.add_argument('-o', '--out', type=Path, default=None, help='Output folder')
parser2.add_argument('-f', '--force', action='store_true', default=False, help='Overwrite')
parser2.add_argument('-t', '--test', action='store_true', default=False,
                     help='Test mode: only use the first 1000 characters and run on CPU')
parser2.add_argument('-l', '--limit', type=int, default=400,
                    help='Maximum number of characters to synthesize at once')
parser2.add_argument('-m', '--model', type=str, 
                    default="tts_models/multilingual/multi-dataset/xtts_v1",
                    help='TTS model name')
parser2.add_argument('-s', '--speaker', type=Path, default=root_dir / "data/speakers/donaldrobertson.wav",
                    help='Speaker wav to use from the model')
args = parser2.parse_args([])

if args.out is None:
    # datetime.utcnow() is deprecated; an aware UTC datetime formats identically here.
    from datetime import datetime, timezone
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H-%M-%S')
    args.out = root_dir / 'out' / (sanitize(args.epub.stem).replace(' ', '_').lower() + timestamp)

# load epub
if not args.epub.is_file():
    raise FileNotFoundError(f'EPUB not found: {args.epub}')
parsed = parser.from_file(str(args.epub))
# tika may return content=None when it cannot extract any text.
text = parsed.get("content") or ""
if not text.strip():
    raise ValueError(f'No text could be extracted from {args.epub}')
if args.test:
    text = text[:1000]


# make output directory
out_dir = Path(args.out)
if out_dir.exists():
    if not args.force:
        logger.warning('Output folder already exists. Use -f to overwrite.')
        # SystemExit stops the cell (or script) without shutting down the Jupyter kernel.
        raise SystemExit(1)
    # rmtree also removes sub-directories, which per-entry unlink() cannot.
    shutil.rmtree(out_dir)
# parents=True: the top-level `out/` folder may not exist yet on a first run.
out_dir.mkdir(parents=True)
logger.info(f'Output folder: {out_dir}')


In [None]:


# write metadata to dir
from json_tricks import dump, dumps, load, loads, strip_comments
f_metadata = out_dir / 'metadata.json'
with open(f_metadata, 'w', encoding='utf-8') as fo:
    dump(dict(
        epub_metadata=parsed['metadata'],
        args=args.__dict__,
    ), fo, indent=4)

# Reference voice clip for tortoise-style voice cloning, which expects a list of
# waveform tensors; see
# https://github.com/neonbjb/tortoise-tts/blob/5bbb0e0b97ea2f62c12e90402e8ad4faee55e697/tortoise/api.py#L365C82-L365C140
# The bark pipeline selects its voice through SPEAKER instead.
REFERENCE_SECONDS = 12
ref, INPUT_SAMPLE_RATE = torchaudio.load(str(args.speaker))
# Keep only the last REFERENCE_SECONDS of the clip. The sample count comes from
# the file's own sample rate, so the duration is right for any input rate.
reference_clips = [ref[..., -int(REFERENCE_SECONDS * INPUT_SAMPLE_RATE):]]

# load model
use_cuda = False if args.test else torch.cuda.is_available()
logger.info(f'use_cuda {use_cuda}')


In [None]:
# Bark generates audio at its fixed SAMPLE_RATE (24 kHz), so chapters are written
# at that rate rather than a separately hard-coded copy of the number.
OUTPUT_SAMPLE_RATE = SAMPLE_RATE
# Last expression, so Jupyter actually displays both rates.
INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE


In [None]:
# The old tortoise `tts` object (and its tokenizer) no longer exists, so chunk
# length is measured in characters. With split_into_sentences' default limit of
# 200, each chunk stays short enough for a single bark generation.
class CharTokenizer:
    """Tokenizer stand-in: ``encode`` returns the string itself, so
    ``len(tokenizer.encode(s))`` is the number of characters in ``s``."""

    @staticmethod
    def encode(s: str) -> str:
        return s


# Write a chapter file once this many output samples have accumulated
# (2.5M samples is about 104 s of audio at 24 kHz).
CHAPTER_MAX_SAMPLES = 10_000_000 // 4


def flush_chapter(writer: Writer, parts: list) -> Path:
    """Concatenate 1-D audio chunks and write them as the writer's next chapter."""
    # torchaudio.save expects a (channels, time) tensor, so add a channel axis.
    wavs = torch.cat(parts, dim=-1).unsqueeze(0)
    wav_f = writer.write_chapter(wavs, OUTPUT_SAMPLE_RATE)
    logger.info(f"wrote chapter {wav_f}")
    return wav_f


tokenizer = CharTokenizer()
segs = split_into_sentences(text, tokenizer)
# bark returns numpy arrays; keep every piece as a float32 torch tensor so they concatenate.
silence_t = torch.from_numpy(silence).float()
waveforms = []
n_samples = 0  # running total, avoids re-summing the whole list every iteration
writer = Writer(out_dir)
try:
    for t in tqdm(segs, desc='chunks'):
        t = t.replace('\n', ' ').strip()
        # Skip empty text
        if not t:
            continue
        # check if contains words or numbers
        if not re.search('[a-zA-Z0-9]', t):
            logger.debug(f'Skipping text without words or numbers `{t}`')
            continue
        logger.debug(f'current sentence `{t}`')

        audio = generate_audio(t, history_prompt=SPEAKER)
        wav = torch.from_numpy(np.asarray(audio, dtype=np.float32)).flatten()
        # Quarter-second pause after each chunk, as in bark's long-form example.
        waveforms += [wav, silence_t]
        n_samples += wav.shape[-1] + silence_t.shape[-1]

        if n_samples > CHAPTER_MAX_SAMPLES:  # ~2 minutes of audio output
            flush_chapter(writer, waveforms)
            waveforms, n_samples = [], 0

    # Write whatever is left over as the final chapter.
    if waveforms:
        flush_chapter(writer, waveforms)
finally:
    # Always close the playlist, even if generation fails part-way.
    writer.close()


In [None]:
from IPython.display import Audio


In [None]:
# Sanity check: list the chapter files that were written and preview the first.
# The playlist writer has been closed by the synthesis cell, so this cell only
# reads the output rather than writing another chapter.
chapter_files = sorted(out_dir.glob('*.ogg'), key=lambda p: int(p.stem))
logger.info(f'{len(chapter_files)} chapter(s) in {out_dir}')
Audio(filename=str(chapter_files[0])) if chapter_files else None
