In [1]:
import yaml
def _load_yaml(path):
    """Parse the YAML file at ``path`` and return the resulting Python object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes instead of lingering until garbage collection.
    """
    with open(path, "r") as stream:
        # NOTE(review): FullLoader is kept because these are local, trusted config
        # files; prefer yaml.safe_load if the configs can ever come from elsewhere.
        return yaml.load(stream, Loader=yaml.FullLoader)


# Preprocess / model / train configuration, in the order the project helpers expect.
config_p = _load_yaml("config/ICASSP/preprocess.yaml")
config_m = _load_yaml("config/ICASSP/model.yaml")
config_t = _load_yaml("config/ICASSP/train.yaml")
configs = (config_p, config_m, config_t)

In [44]:
import sys
sys.path.append("./scripts")
from scripts.utils.model import get_model, get_vocoder
import torch
import json

# Dataset statistics; the "energy" entry is used later as the stats argument of plot_mel.
with open("preprocessed_data/RWCP-SSD/latest/stats.json", "r") as f:
    stats_param = json.load(f)

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 200000 is the step passed to get_model (the cell output logs "restore_step 200000").
model = get_model(200000, configs, DEVICE)
# Vocoder used to turn predicted mel spectrograms into waveforms.
vocoder = get_vocoder(config_m, DEVICE)

restore_step 200000
Removing weight norm...


In [3]:
from scripts.dataset import Dataset
# Dataset built from "test.txt"; later cells read its `width`, `symbol_to_id`
# and `audiotype_map` attributes.
dataset_ = Dataset("test.txt", config_p, config_t, config_m)

In [4]:
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2

def pil2cv(pil_im, color=False):
    """Convert a PIL image to an OpenCV-style ``uint8`` NumPy array.

    Args:
        pil_im: Image (or array-like) to convert.
        color: For 3- and 4-channel input, keep colour (RGB -> BGR, RGBA -> BGRA)
            when true; otherwise reduce the image to grayscale.

    Returns:
        The converted array. 2-D (monochrome) input, and input whose channel
        count is neither 3 nor 4, is returned without any colour conversion.
    """
    pixels = np.array(pil_im, dtype=np.uint8)

    # Monochrome: nothing to convert.
    if pixels.ndim == 2:
        return pixels

    channels = pixels.shape[2]
    if channels == 3:  # colour
        code = cv2.COLOR_RGB2BGR if color else cv2.COLOR_RGB2GRAY
    elif channels == 4:  # colour with alpha
        code = cv2.COLOR_RGBA2BGRA if color else cv2.COLOR_RGBA2GRAY
    else:
        return pixels
    return cv2.cvtColor(pixels, code)


def img_pad(im, max_width):
    """Pad ``im`` on the left and right so that it becomes ``max_width`` wide.

    The image is pasted onto a new canvas of the same mode filled with
    ``(0, 0, 0)``; the height is unchanged. When the missing width is odd, the
    left side receives the extra pixel.

    Args:
        im: PIL image to pad.
        max_width: Target width in pixels.

    Returns:
        A new PIL image containing ``im`` with horizontal margins added.
    """
    slack = max_width - im.width
    # Float halves truncated by int(), exactly as the margins were computed before.
    left = int(slack / 2 + slack % 2)
    right = int(slack / 2)

    src_width, src_height = im.size
    padded = Image.new(im.mode, (src_width + right + left, src_height), (0, 0, 0))
    padded.paste(im, (left, 0))
    return padded

In [20]:
# Build a katakana string by clicking buttons in a Solara GUI.
import solara

# Text assembled so far; the Katakana component below reads and updates it.
result = solara.reactive("")

@solara.component
def Katakana():
    """Clickable katakana keyboard that edits the module-level ``result`` text.

    Shows the current text, a Reset and a Repeat button, and five 20-column
    panels of character buttons. Buttons with an empty label are layout
    placeholders: clicking one appends the empty string.
    """

    def append_text(text):
        result.set(result.get() + text)

    def clear_text():
        result.set("")

    def double_text():
        current = result.get()
        result.set(current + current)

    def make_appender(label):
        # Bind the label now so every button appends its own character.
        return lambda: append_text(label)

    # One entry per panel; each panel is four groups of five labels (20 columns).
    keyboard = (
        (
            ("ア", "イ", "ウ", "エ", "オ"),
            ("カ", "キ", "ク", "ケ", "コ"),
            ("サ", "シ", "ス", "セ", "ソ"),
            ("タ", "チ", "ツ", "テ", "ト"),
        ),
        (
            ("ナ", "ニ", "ヌ", "ネ", "ノ"),
            ("ハ", "ヒ", "フ", "ヘ", "ホ"),
            ("マ", "ミ", "ム", "メ", "モ"),
            ("ヤ", "", "ユ", "", "ヨ"),
        ),
        (
            ("ラ", "リ", "ル", "レ", "ロ"),
            ("ワ", "ヲ", "ン", "", ""),
            ("ガ", "ギ", "グ", "ゲ", "ゴ"),
            ("ザ", "ジ", "ズ", "ゼ", "ゾ"),
        ),
        (
            ("ダ", "ヂ", "ヅ", "デ", "ド"),
            ("バ", "ビ", "ブ", "ベ", "ボ"),
            ("パ", "ピ", "プ", "ペ", "ポ"),
            ("ヴ", "", "", "", ""),
        ),
        (
            ("ァ", "ィ", "ゥ", "ェ", "ォ"),
            ("ャ", "", "ュ", "", "ョ"),
            ("ー", "", "", "", ""),
            ("", "", "", "", ""),
        ),
    )

    solara.Text(result.get())
    with solara.Columns([2]):
        # reset text
        solara.Button("Reset", on_click=clear_text)
        # repeat text
        solara.Button("Repeat", on_click=double_text)
    for panel in keyboard:
        with solara.Columns([20]):
            for group in panel:
                for label in group:
                    solara.Button(label, on_click=make_appender(label))

# Instantiate the keyboard component as this cell's output.
Katakana()

In [52]:
import solara
from IPython.display import Audio
import torchaudio
from scripts.utils.tools import to_device, plot_mel, expand

# Reactive state shared with the Page component below.
# Onomatopoeia text, bound to the "Your onomatopoeia" input field.
synth_onomatopoeia = solara.reactive("ピィピィ")
# Preview image shown under "Visual onomatopoeia"; starts from a bundled sample.
image_pil_show = solara.reactive(
    Image.open("./sample/ipaexg_24pt_c3-whistle3-000-0271-517-repeat2.png")
)
# Single-channel image handed to the model; None until one has been generated.
image_input = solara.reactive(None)
# NOTE(review): im_w is not read anywhere in the visible cells — confirm it is still needed.
im_w = solara.reactive(None)
# Waveform played back in the page; starts from the sample file on disk.
audio = solara.reactive(torchaudio.load("sample/tmp.wav")[0])
# Mel-spectrogram image shown in the page; starts from the sample file on disk.
mel_img = solara.reactive(
    Image.open("./sample/tmp.png")
)
# Selected sound class, bound to the toggle buttons in the page.
class_name = solara.reactive("whistle3")

@solara.component
def Page():
    """Solara page: type an onomatopoeia, render it as an image, synthesize a sound.

    Reads and writes the module-level reactive variables ``synth_onomatopoeia``,
    ``image_pil_show``, ``image_input``, ``audio``, ``mel_img`` and ``class_name``,
    and uses the globals ``config_p``, ``config_m``, ``config_t``, ``dataset_``,
    ``model``, ``vocoder``, ``stats_param`` and ``DEVICE`` from earlier cells.
    """

    def on_click_generate_visual_onomatopoeia():
        """Render the current text into the model-input image and the preview image."""
        text = synth_onomatopoeia.get()
        if not text:
            # Nothing to draw; keep the previous images instead of building a
            # zero-width canvas.
            return

        fs = config_p["visual_text"]["fontsize"]
        bgcolor = tuple(config_p["visual_text"]["color"]["background"])
        txtcolor = tuple(config_p["visual_text"]["color"]["text"])
        font = ImageFont.truetype(str(Path(config_p["path"]["font"])), fs)

        def render_strip(canvas_width, tile_width=None):
            """Draw one ``fs`` x ``fs`` tile per character onto a white canvas.

            When ``tile_width`` is given, each tile is first padded to that width
            with ``img_pad``.
            """
            canvas = Image.new("RGB", (canvas_width, fs), (255, 255, 255))
            for index, char in enumerate(text):
                tile = Image.new("RGB", (fs, fs), bgcolor)
                ImageDraw.Draw(tile).text((0, 0), char, fill=txtcolor, font=font)
                if tile_width is not None:
                    tile = img_pad(tile, tile_width)
                # NOTE(review): tiles are placed every `fs` pixels even when they
                # were padded to `dataset_.width`; kept as-is because the expected
                # layout of the training images is not visible here — confirm.
                canvas.paste(tile, (index * fs, 0))
            return canvas

        # Model input: padded tiles, converted to a single channel.
        padded_canvas = render_strip(len(text) * dataset_.width, dataset_.width)
        image_input.set(padded_canvas.convert("L"))
        # Preview: unpadded tiles, kept in RGB.
        image_pil_show.set(render_strip(fs * len(text)))

    def on_click_synthesize():
        """Run the model and vocoder on the current inputs and publish the results."""
        onomatopoeia = synth_onomatopoeia.get()
        if not onomatopoeia:
            # No symbols to synthesize from.
            return
        if image_input.get() is None:
            # The model input image has not been generated yet; build it from the
            # current text instead of passing None to the model.
            on_click_generate_visual_onomatopoeia()

        name = [onomatopoeia]
        class_id = np.array([dataset_.audiotype_map[class_name.get()]])
        text = np.array([[dataset_.symbol_to_id[t] for t in list(onomatopoeia)]])
        text_lens = np.array([len(onomatopoeia)])
        visualono = [image_input.get()]
        batch = (
            name,
            class_id,
            text,
            text_lens,
            max(text_lens),
            None, None, None, None, None, None,
            visualono, [None]
        )
        batch = to_device(batch, DEVICE)
        # Inference only: every output is detached below, so skip autograd bookkeeping.
        with torch.no_grad():
            output = model(*(batch[1:]), config_t["use_image"])

        from scripts.utils.model import vocoder_infer
        wav = vocoder_infer(
            mels = output[1].detach().transpose(1,2),
            vocoder = vocoder,
            model_config=config_m,
            preprocess_config=config_p,
            Normalize=False
        )
        torchaudio.save("sample/tmp.wav", torch.tensor(wav), 22050)
        # Publish the new waveform so the audio player re-renders; reloading the
        # file mirrors how `audio` was initialised.
        audio.set(torchaudio.load("sample/tmp.wav")[0])

        duration = output[5][0, :].detach().cpu().numpy()
        print(duration)
        # Running sum of the durations, excluding the last entry.
        energy_break = [duration[0]]
        for j in range(1,len(duration)-1):
            energy_break.append(energy_break[j-1]+duration[j])
        energy = output[2][0, :].detach().cpu().numpy()
        energy = expand(energy, duration)
        mel_prediction = output[1][0, :].detach().transpose(0, 1)
        data = [[mel_prediction.cpu().numpy(), energy, energy_break]]
        stats = [stats_param["energy"][0], stats_param["energy"][1]]
        im_np = plot_mel(data, stats, [""])
        # save
        im = Image.fromarray(im_np)
        im.save("sample/tmp.png")
        # Publish the new spectrogram so the page shows it instead of the stale one.
        mel_img.set(im)

    solara.InputText(label="Your onomatopoeia", value=synth_onomatopoeia, continuous_update=True)
    with solara.Columns([7]):
        solara.Button("Generate visual onomatopoeia", on_click=on_click_generate_visual_onomatopoeia)
        # The remaining buttons have no handlers attached yet.
        solara.Button("<-")
        solara.Button("->")
        solara.Button("expand")
        solara.Button("shrink")
        solara.Button("expand all")
        solara.Button("shrink all")

    solara.Text("Visual onomatopoeia")
    solara.Image(image_pil_show.get())

    solara.Text("Sound class")
    labels = list(dataset_.audiotype_map.keys())
    solara.ToggleButtonsSingle(value=class_name, values=labels)
    solara.Button("synthesize", on_click=on_click_synthesize)
    solara.Text("synthesized environmental sound")
    display(Audio(audio.get(), rate=22050))
    solara.Text("Mel spectrogram")
    solara.Image(mel_img.get())


# Instantiate the synthesis page as this cell's output.
Page()