# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import gradio as gr
import numpy as np
import torch
import torchaudio
import random
import librosa
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.utils.common import set_all_random_seed
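# Gradio web demo for CosyVoice: pretrained-voice TTS, 3s zero-shot cloning,
# cross-lingual cloning, instruction control, and single-file or batch voice
# conversion with converted WAVs exported to disk.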

inference_mode_list = ['Pretrained Voice', '3s Fast Cloning', 'Cross-lingual Cloning', 'Instruction Control', 'Voice Conversion']
instruct_dict = {
    'Pretrained Voice': '1. Select a pretrained voice\n2. Click Generate',
    '3s Fast Cloning': '1. Upload or record a short (≤30s) prompt audio\n2. Enter the matching prompt text\n3. Click Generate',
    'Cross-lingual Cloning': '1. Upload or record a short (≤30s) prompt audio\n2. Enter text in a different language\n3. Click Generate',
    'Instruction Control': '1. Select a pretrained voice\n2. Enter instruction text\n3. Click Generate',
    'Voice Conversion': '1. Upload source audio (speech to convert)\n2. Upload target speaker audio (voice to imitate)\n3. Click Generate'
}
stream_mode_list = [('No', False), ('Yes', True)]
max_val = 0.8


def generate_seed():
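    """Return a Gradio update payload carrying a fresh random seed."""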
    seed = random.randint(1, 100000000)
    return {
        "__type__": "update",
        "value": seed
    }


def postprocess(speech, top_db=60, hop_length=220, win_length=440):
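    """Trim leading/trailing silence, peak-normalize to max_val, and pad 0.2s of trailing silence."""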
    speech, _ = librosa.effects.trim(
        speech, top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length
    )
    if speech.abs().max() > max_val:
        speech = speech / speech.abs().max() * max_val
    speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
    return speech


def change_instruction(mode_checkbox_group):
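    """Return the step-by-step instructions for the selected inference mode."""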
    return instruct_dict[mode_checkbox_group]


def generate_audio(tts_text, mode_checkbox_group, sft_dropdown,
                   prompt_text, prompt_wav_upload, prompt_wav_record,
                   instruct_text, seed, stream, speed,
                   source_wav_upload=None, target_wav_upload=None,
                   source_batch_upload=None, export_dir="exports"):
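    """Dispatch a synthesis request to the selected inference mode.

    Yields (sample_rate, waveform) tuples so Gradio can stream audio chunks.
    The voice-conversion arguments take a single source file, a batch of
    source files, or both; converted audio is also written under export_dir.
    """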
    # Pick prompt audio from upload or recording
    if prompt_wav_upload is not None:
        prompt_wav = prompt_wav_upload
    elif prompt_wav_record is not None:
        prompt_wav = prompt_wav_record
    else:
        prompt_wav = None

    # Instruction mode checks
    if mode_checkbox_group in ['Instruction Control']:
        if cosyvoice.instruct is False:
            gr.Warning('You are using Instruction Control mode, but the {} model does not support it. '
                       'Please use iic/CosyVoice-300M-Instruct.'.format(args.model_dir))
            yield (cosyvoice.sample_rate, default_data)
            return
        if instruct_text == '':
            gr.Warning('You are in Instruction Control mode, please enter instruction text.')
            yield (cosyvoice.sample_rate, default_data)
            return
        if prompt_wav is not None or prompt_text != '':
            gr.Info('In Instruction Control mode, prompt audio/text will be ignored.')

    # Cross-lingual mode checks
    if mode_checkbox_group in ['Cross-lingual Cloning']:
        if cosyvoice.instruct is True:
            gr.Warning('You are using Cross-lingual Cloning mode, but the {} model does not support it. '
                       'Please use iic/CosyVoice-300M.'.format(args.model_dir))
            yield (cosyvoice.sample_rate, default_data)
            return
        if instruct_text != '':
            gr.Info('In Cross-lingual Cloning mode, instruction text will be ignored.')
        if prompt_wav is None:
            gr.Warning('In Cross-lingual Cloning mode, please provide prompt audio.')
            yield (cosyvoice.sample_rate, default_data)
            return
        gr.Info('In Cross-lingual Cloning mode, make sure synthesis text and prompt text are in different languages.')

    # Zero-shot and cross-lingual require prompt audio
    if mode_checkbox_group in ['3s Fast Cloning', 'Cross-lingual Cloning']:
        if prompt_wav is None:
            gr.Warning('Prompt audio is empty. Did you forget to provide it?')
            yield (cosyvoice.sample_rate, default_data)
            return
        prompt_info = torchaudio.info(prompt_wav)
        if prompt_info.sample_rate < prompt_sr:
            gr.Warning('Prompt audio sample rate {} is below the required {}.'
                       .format(prompt_info.sample_rate, prompt_sr))
            yield (cosyvoice.sample_rate, default_data)
            return

    # Pretrained voice checks
    if mode_checkbox_group in ['Pretrained Voice']:
        if instruct_text != '' or prompt_wav is not None or prompt_text != '':
            gr.Info('In Pretrained Voice mode, prompt text/audio and instruction text will be ignored.')
        if sft_dropdown == '':
            gr.Warning('No pretrained voices available!')
            yield (cosyvoice.sample_rate, default_data)
            return

    # 3s Fast Cloning checks
    if mode_checkbox_group in ['3s Fast Cloning']:
        if prompt_text == '':
            gr.Warning('Prompt text is empty. Did you forget to enter it?')
            yield (cosyvoice.sample_rate, default_data)
            return
        if instruct_text != '':
            gr.Info('In 3s Fast Cloning mode, pretrained voice/instruction text will be ignored.')

    # Actual inference branches
    if mode_checkbox_group == 'Pretrained Voice':
        logging.info('Running pretrained voice inference')
        set_all_random_seed(seed)
        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())

    elif mode_checkbox_group == '3s Fast Cloning':
        logging.info('Running zero-shot inference')
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
        set_all_random_seed(seed)
        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k,
                                               stream=stream, speed=speed):
            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())

    elif mode_checkbox_group == 'Cross-lingual Cloning':
        logging.info('Running cross-lingual inference')
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
        set_all_random_seed(seed)
        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k,
                                                   stream=stream, speed=speed):
            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
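    # The Instruction Control mode is validated above but had no inference
    # branch, so selecting it produced no audio. A minimal sketch, assuming the
    # CosyVoice inference_instruct(tts_text, spk_id, instruct_text, ...) API of
    # the iic/CosyVoice-300M-Instruct model (CosyVoice2 instead exposes
    # inference_instruct2, which also takes prompt audio):
    elif mode_checkbox_group == 'Instruction Control':
        logging.info('Running instruct inference')
        set_all_random_seed(seed)
        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text,
                                              stream=stream, speed=speed):
            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())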


    elif mode_checkbox_group == 'Voice Conversion':
        logging.info('Running voice conversion inference')
        if target_wav_upload is None:
            gr.Warning('Please provide a target speaker audio file.')
            yield (cosyvoice.sample_rate, default_data)
        else:
            target_speech_16k = postprocess(load_wav(target_wav_upload, prompt_sr))
            set_all_random_seed(seed)

            os.makedirs(export_dir, exist_ok=True)

            if source_wav_upload is None and source_batch_upload is None:
                gr.Warning('Please provide source audio (a single file, a batch, or both).')
                yield (cosyvoice.sample_rate, default_data)

            # Case 1: single source file
            if source_wav_upload is not None:
                source_speech_16k = postprocess(load_wav(source_wav_upload, prompt_sr))
                for idx, i in enumerate(cosyvoice.inference_vc(source_speech_16k, target_speech_16k,
                                                               stream=stream, speed=speed)):
                    audio_out = i['tts_speech'].numpy().astype('float32').flatten()
                    # Index saved chunks so streaming output does not overwrite a single file
                    out_path = os.path.join(export_dir, 'converted_single_{}.wav'.format(idx))
                    torchaudio.save(out_path, torch.from_numpy(audio_out).unsqueeze(0), cosyvoice.sample_rate)
                    yield (cosyvoice.sample_rate, audio_out)

            # Case 2: batch of source files
            if source_batch_upload is not None:
                # gr.File with type="filepath" passes a list of plain path strings
                for file_path in source_batch_upload:
                    logging.info('Converting batch file: {}'.format(file_path))
                    source_speech_16k = postprocess(load_wav(file_path, prompt_sr))
                    base = os.path.splitext(os.path.basename(file_path))[0]
                    for idx, i in enumerate(cosyvoice.inference_vc(source_speech_16k, target_speech_16k,
                                                                   stream=stream, speed=speed)):
                        audio_out = i['tts_speech'].numpy().astype('float32').flatten()
                        # Save every chunk under the source basename and stream it to the UI
                        out_path = os.path.join(export_dir, '{}_converted_{}.wav'.format(base, idx))
                        torchaudio.save(out_path, torch.from_numpy(audio_out).unsqueeze(0), cosyvoice.sample_rate)
                        yield (cosyvoice.sample_rate, audio_out)


def main():
    with gr.Blocks() as demo:
        gr.Markdown("### Repository [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
                    Pretrained models: [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
                    [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
                    [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
        gr.Markdown("#### Enter text, choose inference mode, and follow the steps")

        tts_text = gr.Textbox(label="Input text", lines=1,
                              value="This is a demo of the CosyVoice generative speech model.")

        with gr.Row():
            mode_checkbox_group = gr.Radio(
                choices=inference_mode_list,
                label='Select inference mode',
                value=inference_mode_list[0]
            )
            instruction_text = gr.Text(
                label="Instructions",
                value=instruct_dict[inference_mode_list[0]],
                scale=0.5
            )
            sft_dropdown = gr.Dropdown(
                choices=sft_spk,
                label='Select pretrained voice',
                value=sft_spk[0],
                scale=0.25
            )
            stream = gr.Radio(
                choices=stream_mode_list,
                label='Streaming?',
                value=stream_mode_list[0][1]
            )
            speed = gr.Number(
                value=1,
                label="Speed (non-streaming only)",
                minimum=0.5,
                maximum=2.0,
                step=0.1
            )
            with gr.Column(scale=0.25):
                seed_button = gr.Button(value="\U0001F3B2")
                seed = gr.Number(value=0, label="Random seed", precision=0)

        # Existing prompt audio inputs
        with gr.Row():
            prompt_wav_upload = gr.Audio(
                sources='upload',
                type='filepath',
                label='Upload prompt audio (≥16kHz)'
            )
            prompt_wav_record = gr.Audio(
                sources='microphone',
                type='filepath',
                label='Record prompt audio'
            )

        # NEW: Voice Conversion inputs
        with gr.Row():
            source_wav_upload = gr.Audio(
                sources='upload',
                type='filepath',
                label='Upload source audio (speech to convert)'
            )
            target_wav_upload = gr.Audio(
                sources='upload',
                type='filepath',
                label='Upload target speaker audio'
            )
        # Batch conversion: allow multiple source files
        source_batch_upload = gr.File(
            file_types=[".wav", ".flac", ".m4a", ".mp3", ".opus"],  # accepted audio formats
            file_count="multiple",  # allow multiple uploads
            type="filepath",  # pass file paths to the handler
            label="Upload multiple source audio files (batch conversion)"
        )

        # Export directory
        export_dir = gr.Textbox(
            label="Export directory",
            value="exports",
            placeholder="Where to save converted files"
        )

        prompt_text = gr.Textbox(
            label="Prompt text",
            lines=1,
            placeholder="Enter text matching the prompt audio...",
            value=''
        )
        instruct_text = gr.Textbox(
            label="Instruction text",
            lines=1,
            placeholder="Enter instruction text...",
            value=''
        )

        generate_button = gr.Button("Generate Audio")

        audio_output = gr.Audio(label="Output audio", autoplay=True, streaming=True)

        seed_button.click(generate_seed, inputs=[], outputs=seed)

        # UPDATED: added source_wav_upload and target_wav_upload to inputs
        generate_button.click(
            generate_audio,
            inputs=[tts_text, mode_checkbox_group, sft_dropdown,
                    prompt_text, prompt_wav_upload, prompt_wav_record,
                    instruct_text, seed, stream, speed,
                    source_wav_upload, target_wav_upload,
                    source_batch_upload, export_dir],
            outputs=[audio_output]
        )

        mode_checkbox_group.change(
            fn=change_instruction,
            inputs=[mode_checkbox_group],
            outputs=[instruction_text]
        )

    demo.queue(max_size=4, default_concurrency_limit=2)
    demo.launch(server_name='0.0.0.0', server_port=args.port)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=8000)
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice2-0.5B',
                        help='local path or modelscope repo id')
    args = parser.parse_args()
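    # Try the CosyVoice loader first, then fall back to CosyVoice2; each
    # constructor raises when model_dir does not hold a matching model.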
    try:
        cosyvoice = CosyVoice(args.model_dir)
    except Exception:
        try:
            cosyvoice = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    sft_spk = cosyvoice.list_available_spks()
    if len(sft_spk) == 0:
        sft_spk = ['']
    prompt_sr = 16000
    default_data = np.zeros(cosyvoice.sample_rate, dtype=np.float32)  # one second of silence as placeholder output
    main()
