<a href="https://colab.research.google.com/github/rezabonyadi/YuE/blob/main/YuE_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Git LFS (if not already installed)
!apt-get install git-lfs -y
!git lfs install

# Clone the main YuE repository
!git clone https://github.com/multimodal-art-projection/YuE.git

# Navigate into the inference directory and clone the xcodec_mini_infer repo
%cd YuE/inference
!git clone https://huggingface.co/m-a-p/xcodec_mini_infer
%cd ..
!pip install -r requirements.txt
%cd YuE/inference

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 18 not upgraded.
Git LFS initialized.
Cloning into 'YuE'...
remote: Enumerating objects: 369, done.[K
remote: Counting objects: 100% (107/107), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 369 (delta 101), reused 88 (delta 88), pack-reused 262 (from 1)[K
Receiving objects: 100% (369/369), 12.99 MiB | 9.66 MiB/s, done.
Resolving deltas: 100% (160/160), done.
/content/YuE/inference
Cloning into 'xcodec_mini_infer'...
remote: Enumerating objects: 203, done.[K
remote: Counting objects: 100% (199/199), done.[K
remote: Compressing objects: 100% (186/186), done.[K
remote: Total 203 (delta 8), reused 199 (delta 8), pack-reused 4 (from 1)[K
Receiving objects: 100% (203/203), 8.38 MiB | 1.14 MiB/s, done.
Resolving deltas: 100% (8/8), done.
Filtering content: 100

[Errno 2] No such file or directory: 'YuE/inference'
/content/YuE


In [1]:
# DO NOT RUN THIS before you run the above and restrat your session.
# This is to get FlashAttention to work on Colab. You need an older torch version.
# Make sure you restart your session again after running this.
!pip install torch=='2.4.1+cu121' torchvision=='0.19.1+cu121' torchaudio=='2.4.1+cu121' --index-url https://download.pytorch.org/whl/cu121
!pip install flash-attn
!pip install protobuf==3.20.1 # This is to fix the version problem of the one from YuE


Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m799.0/799.0 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m120.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.1+cu121)
  Downloading https://download.pytorch.org/whl/c

In [1]:
!ls

sample_data  YuE


In [2]:
%cd YuE/inference

/content/YuE/inference


In [3]:
import os
import sys

# Use the current working directory instead of __file__
base_path = os.getcwd()
sys.path.append(os.path.join(base_path, 'xcodec_mini_infer'))
sys.path.append(os.path.join(base_path, 'xcodec_mini_infer', 'descriptaudiocodec'))

import argparse

# Create a dummy args object with at least the required attributes.
dummy_args = argparse.Namespace(bw=4, cuda_idx=0)

import re
import random
import uuid
import copy
from tqdm import tqdm
from collections import Counter
import numpy as np
import torch
import torchaudio
from torchaudio.transforms import Resample
import soundfile as sf
from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
from omegaconf import OmegaConf

from codecmanipulator import CodecManipulator
from mmtokenizer import _MMSentencePieceTokenizer
from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model, process_audio
from post_process_audio import replace_low_freq_with_energy_matched


class BlockTokenRangeProcessor(LogitsProcessor):
    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

# Set random seed
def seed_everything(seed_val=42):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_val)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Audio utilities
def load_audio_mono(filepath, sampling_rate=16000):
    audio, sr = torchaudio.load(filepath)
    audio = torch.mean(audio, dim=0, keepdim=True)
    if sr != sampling_rate:
        resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
        audio = resampler(audio)
    return audio

def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
    if len(audio_prompt.shape) < 3:
        audio_prompt = audio_prompt.unsqueeze(0)
    with torch.no_grad():
        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=target_bw)
    raw_codes = raw_codes.transpose(0, 1)
    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
    return raw_codes

def split_lyrics(lyrics_content):
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics_content, re.DOTALL)
    structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
    return structured_lyrics

def stage2_generate(model, prompt, batch_size=16):
    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
    codec_ids = codectool.offset_tok_ids(
        codec_ids,
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    if batch_size > 1:
        codec_list = []
        for i in range(batch_size):
            idx_begin = i * 300
            idx_end = (i + 1) * 300
            codec_list.append(codec_ids[:, idx_begin:idx_end])
        codec_ids = np.concatenate(codec_list, axis=0)
        prompt_ids = np.concatenate(
            [
                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
                codec_ids,
                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
            ],
            axis=1
        )
    else:
        prompt_ids = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),
            np.array([mmtokenizer.stage_2])
        ]).astype(np.int32)
        prompt_ids = prompt_ids[np.newaxis, ...]
    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]
    block_list = LogitsProcessorList([
        BlockTokenRangeProcessor(0, 46358),
        BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)
    ])

    for frames_idx in range(codec_ids.shape[1]):
        cb0 = codec_ids[:, frames_idx:frames_idx+1]
        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
        input_ids = prompt_ids
        with torch.no_grad():
            stage2_output = model.generate(
                input_ids=input_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )
        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, \
            f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
        prompt_ids = stage2_output

    if batch_size > 1:
        output = prompt_ids.cpu().numpy()[:, len_prompt:]
        output_list = [output[i] for i in range(batch_size)]
        output = np.concatenate(output_list, axis=0)
    else:
        output = prompt_ids[0].cpu().numpy()[len_prompt:]
    return output

def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
    stage2_result = []
    codectool_stage2 = CodecManipulator("xcodec", 0, 8)
    for file_path in tqdm(stage1_output_set):
        output_filename = os.path.join(stage2_output_dir, os.path.basename(file_path))
        if os.path.exists(output_filename):
            print(f'{output_filename} stage2 has done.')
            stage2_result.append(output_filename)
            continue
        prompt = np.load(file_path).astype(np.int32)
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6
        if num_batch <= batch_size:
            output = stage2_generate(model, prompt[:, :output_duration * 50], batch_size=num_batch)
        else:
            segments = []
            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
            for seg in range(num_segments):
                start_idx = seg * batch_size * 300
                end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
                current_batch_size = batch_size if (seg != num_segments-1 or num_batch % batch_size == 0) else num_batch % batch_size
                segment = stage2_generate(model, prompt[:, start_idx:end_idx], batch_size=current_batch_size)
                segments.append(segment)
            output = np.concatenate(segments, axis=0)
        if output_duration * 50 != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, output_duration * 50:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)
        output = codectool_stage2.ids2npy(output)
        fixed_output = copy.deepcopy(output)
        for i, line in enumerate(output):
            for j, element in enumerate(line):
                if element < 0 or element > 1023:
                    counter = Counter(line)
                    most_frequent = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                    fixed_output[i, j] = most_frequent
        np.save(output_filename, fixed_output)
        stage2_result.append(output_filename)
    return stage2_result

# Helper to save audio files
def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale_flag: bool = False):
    folder_path = os.path.dirname(path)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    limit = 0.99
    max_val = wav.abs().max()
    wav = wav * min(limit / max_val, 1) if rescale_flag else wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)


In [4]:
import os
import torch
import numpy as np
import uuid
import soundfile as sf
from tqdm import tqdm
from omegaconf import OmegaConf

# Assume these functions/classes are imported or defined elsewhere:
#   - _MMSentencePieceTokenizer, AutoModelForCausalLM, CodecManipulator,
#   - seed_everything, split_lyrics, load_audio_mono, encode_audio,
#   - LogitsProcessorList, BlockTokenRangeProcessor, stage2_inference,
#   - save_audio, build_codec_model, process_audio, replace_low_freq_with_energy_matched

def validate_prompt_options(use_audio_prompt, audio_prompt_path,
                            use_dual_tracks_prompt, vocal_track_prompt_path, instrumental_track_prompt_path):
    """Validate that if prompts are enabled, file paths are provided."""
    if use_audio_prompt and not audio_prompt_path:
        raise FileNotFoundError("Audio prompt is enabled but no audio_prompt_path was provided!")
    if use_dual_tracks_prompt and (not vocal_track_prompt_path or not instrumental_track_prompt_path):
        raise FileNotFoundError("Dual tracks prompt is enabled but vocal_track_prompt_path and instrumental_track_prompt_path are not provided!")

def setup_output_directories(output_dir):
    """Set up and return directories for each pipeline stage."""
    stage1_output_dir = os.path.join(output_dir, "stage1")
    stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
    recons_output_dir = os.path.join(output_dir, "recons")
    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
    os.makedirs(stage1_output_dir, exist_ok=True)
    os.makedirs(stage2_output_dir, exist_ok=True)
    os.makedirs(recons_mix_dir, exist_ok=True)
    return stage1_output_dir, stage2_output_dir, recons_output_dir, recons_mix_dir

def load_stage1_model(stage1_model, device):
    """Load and compile the Stage 1 model."""
    model = AutoModelForCausalLM.from_pretrained(
        stage1_model,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2"  # Note: may not work well on Colab
    )
    model.to(device)
    model.eval()
    if torch.__version__ >= "2.0.0":
        model = torch.compile(model)
    return model

def load_codec_model(basic_model_config, resume_path, device):
    """Load the codec model used for encoding/decoding audio."""
    model_config = OmegaConf.load(basic_model_config)
    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
    parameter_dict = torch.load(resume_path, map_location='cpu', weights_only=False)
    codec_model.load_state_dict(parameter_dict['codec_model'])
    codec_model.to(device)
    codec_model.eval()
    return codec_model

def read_prompt_files(genre_txt, lyrics_txt):
    """Read the genre and lyrics prompt files and build the prompt texts."""
    with open(genre_txt, "r") as f:
        genres = f.read().strip()
    with open(lyrics_txt, "r") as f:
        lyrics_content = f.read()
    lyrics = split_lyrics(lyrics_content)
    full_lyrics = "\n".join(lyrics)
    # The first prompt is a full description; subsequent ones are segment-specific.
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
    prompt_texts += lyrics
    return genres, prompt_texts

def generate_stage1_outputs(prompt_texts, genres, model, mmtokenizer, codectool, codec_model, device,
                             use_audio_prompt, audio_prompt_path, use_dual_tracks_prompt,
                             vocal_track_prompt_path, instrumental_track_prompt_path,
                             prompt_start_time, prompt_end_time, max_new_tokens,
                             top_p, temperature, repetition_penalty, stage1_output_dir):
    """
    Process each prompt segment and generate Stage 1 outputs.
    This function builds the prompt tokens (handling audio/dual-track if enabled),
    runs generation, and then saves the raw output tokens into files.
    """
    raw_output = None
    stage1_output_set = []
    random_id = uuid.uuid4()

    # Here we assume the first prompt (index 0) is the head (full prompt)
    # and subsequent segments are generated one by one.
    run_n_segments = min(len(prompt_texts), len(prompt_texts))  # Adjust if needed
    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
        guidance_scale = 1.5 if i <= 1 else 1.2

        # Skip the first prompt for generation (or handle it as needed)
        if i == 0:
            continue

        # Build prompt tokens with optional audio/dual-track inputs
        if i == 1:
            if use_dual_tracks_prompt or use_audio_prompt:
                if use_dual_tracks_prompt:
                    vocals_ids = load_audio_mono(vocal_track_prompt_path)
                    instrumental_ids = load_audio_mono(instrumental_track_prompt_path)
                    vocals_ids = encode_audio(codec_model, vocals_ids, device, target_bw=0.5)
                    instrumental_ids = encode_audio(codec_model, instrumental_ids, device, target_bw=0.5)
                    vocals_ids = codectool.npy2ids(vocals_ids[0])
                    instrumental_ids = codectool.npy2ids(instrumental_ids[0])
                    # Interleave the two tracks
                    ids_segment_interleaved = np.concatenate([np.array(vocals_ids), np.array(instrumental_ids)], axis=0)
                    audio_prompt_codec = ids_segment_interleaved[int(prompt_start_time * 50 * 2): int(prompt_end_time * 50 * 2)]
                    audio_prompt_codec = audio_prompt_codec.tolist()
                elif use_audio_prompt:
                    audio_prompt = load_audio_mono(audio_prompt_path)
                    raw_codes = encode_audio(codec_model, audio_prompt, device, target_bw=0.5)
                    code_ids = codectool.npy2ids(raw_codes[0])
                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
                sentence_ids = (mmtokenizer.tokenize("[start_of_reference]") +
                                audio_prompt_codec_ids +
                                mmtokenizer.tokenize("[end_of_reference]"))
                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
            else:
                head_id = mmtokenizer.tokenize(prompt_texts[0])
            prompt_ids = (head_id +
                          mmtokenizer.tokenize('[start_of_segment]') +
                          mmtokenizer.tokenize(section_text) +
                          [mmtokenizer.soa] + codectool.sep_ids)
        else:
            prompt_ids = (mmtokenizer.tokenize('[end_of_segment]') +
                          mmtokenizer.tokenize('[start_of_segment]') +
                          mmtokenizer.tokenize(section_text) +
                          [mmtokenizer.soa] + codectool.sep_ids)

        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
        # Concatenate with previous outputs if not the first generation segment
        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if raw_output is not None else prompt_ids

        # Trim input if it exceeds the model context window
        max_context = 16384 - max_new_tokens - 1
        if input_ids.shape[-1] > max_context:
            print(f'Section {i}: input length {input_ids.shape[-1]} exceeds max context {max_context}, using last tokens.')
            input_ids = input_ids[:, -max_context:]

        with torch.no_grad():
            output_seq = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=100,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=LogitsProcessorList([
                    BlockTokenRangeProcessor(0, 32002),
                    BlockTokenRangeProcessor(32016, 32016)
                ]),
                guidance_scale=guidance_scale,
            )
            # Append eos token if missing
            if output_seq[0][-1].item() != mmtokenizer.eoa:
                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)

        # Accumulate raw output for subsequent segments
        if raw_output is not None:
            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
        else:
            raw_output = output_seq

    # Save the final raw output tokens as separate npy files for vocals and instrumentals
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
    if len(soa_idx) != len(eoa_idx):
        raise ValueError(f'Invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')

    vocals = []
    instrumentals = []
    # Adjust range based on whether an audio prompt was used
    range_begin = 1 if (use_audio_prompt or use_dual_tracks_prompt) else 0
    for i in range(range_begin, len(soa_idx)):
        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
        if codec_ids[0] == 32016:
            codec_ids = codec_ids[1:]
        # Ensure even number of tokens to split into two channels
        codec_ids = codec_ids[:2 * (len(codec_ids) // 2)]
        # Reshape assuming vocals and instrumentals are interleaved
        reshaped = np.reshape(codec_ids, (-1, 2))
        vocals_ids = codectool.ids2npy(reshaped[:, 0])
        instrumentals_ids = codectool.ids2npy(reshaped[:, 1])
        vocals.append(vocals_ids)
        instrumentals.append(instrumentals_ids)
    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)

    vocal_save_path = os.path.join(
        stage1_output_dir,
        f"{genres.replace(' ', '-')}_vtrack_{random_id}.npy"
    )
    inst_save_path = os.path.join(
        stage1_output_dir,
        f"{genres.replace(' ', '-')}_itrack_{random_id}.npy"
    )
    # Save npy files accordingly.
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)

    # vocal_save_path = os.path.join(
    #     os.path.dirname(stage1_output_set[0]) if stage1_output_set else ".",
    #     f"{genres.replace(' ', '-')}_vtrack_{random_id}.npy"
    # )
    # inst_save_path = os.path.join(
    #     os.path.dirname(stage1_output_set[0]) if stage1_output_set else ".",
    #     f"{genres.replace(' ', '-')}_itrack_{random_id}.npy"
    # )
    # np.save(vocal_save_path, vocals)
    # np.save(inst_save_path, instrumentals)
    stage1_output_set.extend([vocal_save_path, inst_save_path])
    return stage1_output_set

def run_stage2_inference(stage2_model, stage1_output_set, stage2_output_dir, stage2_batch_size, device):
    """Load and run Stage 2 model on the outputs from Stage 1."""
    model_stage2 = AutoModelForCausalLM.from_pretrained(
        stage2_model,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    model_stage2.to(device)
    model_stage2.eval()
    if torch.__version__ >= "2.0.0":
        model_stage2 = torch.compile(model_stage2)
    stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=stage2_batch_size)
    print("Stage 2 outputs:", stage2_result)
    return stage2_result

def reconstruct_tracks(stage2_result, stage1_output_dir, codec_model, device):
    """Decode the stage 2 npy files into audio tracks and save them."""
    tracks = []
    for npy_file in stage2_result:
        codec_result = np.load(npy_file)
        with torch.no_grad():
            decoded_waveform = codec_model.decode(
                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long)
                .unsqueeze(0).permute(1, 0, 2).to(device)
            )
        decoded_waveform = decoded_waveform.cpu().squeeze(0)
        save_path = os.path.join(stage1_output_dir,
                                 os.path.splitext(os.path.basename(npy_file))[0] + ".mp3")
        tracks.append(save_path)
        save_audio(decoded_waveform, save_path, 16000)
    return tracks

def mix_tracks(tracks, recons_mix_dir):
    """Mix vocal and instrumental tracks into a single audio file."""
    recons_mix = None
    for inst_path in tracks:
        try:
            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) and '_itrack' in inst_path:
                vocal_path = inst_path.replace('_itrack', '_vtrack')
                if not os.path.exists(vocal_path):
                    continue
                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('_itrack', '_mixed'))
                vocal_stem, sr = sf.read(inst_path)
                instrumental_stem, _ = sf.read(vocal_path)
                mix_stem = (vocal_stem + instrumental_stem) / 1
                sf.write(recons_mix, mix_stem, sr)
        except Exception as e:
            print(e)
    return recons_mix

def upsample_vocoder(stage2_result, codec_model, config_path, vocal_decoder_path, inst_decoder_path, output_dir, rescale):
    """Perform vocoder upsampling using decoder models to produce high-quality audio."""
    vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
    vocoder_output_dir = os.path.join(output_dir, 'vocoder')
    vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
    vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
    os.makedirs(vocoder_mix_dir, exist_ok=True)
    os.makedirs(vocoder_stems_dir, exist_ok=True)
    vocal_output = None
    instrumental_output = None

    for npy_file in stage2_result:
        if '_itrack' in npy_file:
            instrumental_output = process_audio(
                npy_file,
                os.path.join(vocoder_stems_dir, 'itrack.mp3'),
                rescale,
                dummy_args,
                inst_decoder,
                codec_model
            )
        else:
            vocal_output = process_audio(
                npy_file,
                os.path.join(vocoder_stems_dir, 'vtrack.mp3'),
                rescale,
                dummy_args,
                vocal_decoder,
                codec_model
            )
    try:
        mix_output = instrumental_output + vocal_output
        vocoder_mix = os.path.join(vocoder_mix_dir, 'mixed_output.mp3')
        save_audio(mix_output, vocoder_mix, 44100, rescale)
        print(f"Created mix: {vocoder_mix}")
    except RuntimeError as e:
        print(e)
        print(f"Mixing failed! instrumental shape: {instrumental_output.shape}, vocal shape: {vocal_output.shape}")
        vocoder_mix = None
    return vocoder_mix

def post_process_audio(recons_mix, vocoder_mix, output_dir):
    """Blend low frequencies from the reconstructed and vocoder outputs."""
    final_mix_path = os.path.join(output_dir, os.path.basename(recons_mix))
    replace_low_freq_with_energy_matched(
        a_file=recons_mix,     # 16kHz file
        b_file=vocoder_mix,    # 48kHz file
        c_file=final_mix_path,
        cutoff_freq=5500.0
    )
    print("Post processing complete.")
    return final_mix_path


In [5]:
lyric = """
[verse]
A new dawn is rising, sparks ignite the sky
Dreams once locked in pages, now they come alive
We're painting with data, a vision so bright
Turning the unknown into clear insight

[chorus]
Every code we write, pushing past the line
Every thought we chase, breaking space and time
You can't stop the future now
We won't slow down
The world is changing, hear the sound
We won't slow down

[verse]
They say it's just machines, but they don't understand
Behind the circuits, there's a guiding hand
We see the unseen, weave logic and light
Expanding the world with a keystroke at night

[chorus]
Every code we write, pushing past the line
Every thought we chase, breaking space and time
You can't stop the future now
We won't slow down
The world is changing, hear the sound
We won't slow down

[bridge]
Imagination unbound, the limits erased
A symphony of learning, the patterns embraced
From vision to motion, from whispers to speech
The edge of tomorrow is right within reach

[outro]
Every dream we chase, rewriting the rules
A world built on knowledge, breaking through
You can't stop the future now
We won't slow down
With AI rising all around
We won't slow down
"""

genre = """female blues airy vocal bright vocal piano sad romantic guitar jazz"""

with open('../prompt_egs/lyrics_AI.txt', 'w') as f:
    f.write(lyric)
with open('../prompt_egs/genre_AI.txt', 'w') as f:
    f.write(genre)


In [6]:
# This should take about 30 min

genre_txt="../prompt_egs/genre_AI.txt"
lyrics_txt="../prompt_egs/lyrics_AI.txt"
use_audio_prompt=False         # or True if using an audio prompt
use_dual_tracks_prompt=False     # or True if using dual tracks
output_dir="./my_output"
# Adjust other parameters as needed…


# Model and generation configuration
stage1_model = "m-a-p/YuE-s1-7B-anneal-en-cot"
stage2_model = "m-a-p/YuE-s2-1B-general"
max_new_tokens = 3000
run_n_segments = 2
stage2_batch_size = 4
# Prompt file paths
# Audio prompt options
audio_prompt_path = ""
prompt_start_time = 0.0
prompt_end_time = 30.0
# Dual track prompt options
vocal_track_prompt_path = ""
instrumental_track_prompt_path = ""
# Output and miscellaneous options
keep_intermediate = False
disable_offload_model = False
cuda_idx = 0
seed = 42
# Paths for xcodec and upsampler
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
config_path = './xcodec_mini_infer/decoders/config.yaml'
vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'
rescale = False
"""Main pipeline function that calls modular sub-functions for each step."""

# Validate prompt options
print("Validating prompt options...")
validate_prompt_options(use_audio_prompt, audio_prompt_path,
                        use_dual_tracks_prompt, vocal_track_prompt_path, instrumental_track_prompt_path)

# Set up directories for outputs
print("Setting up directories...")
stage1_output_dir, stage2_output_dir, recons_output_dir, recons_mix_dir = setup_output_directories(output_dir)

# Set device and seed for reproducibility
device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
seed_everything(seed)

# Initialize tokenizer and models
print("Loading models for stage 1...")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
model = load_stage1_model(stage1_model, device)
codec_model = load_codec_model(basic_model_config, resume_path, device)
codectool = CodecManipulator("xcodec", 0, 1)  # Instantiate codec helper

# Read prompt files and build prompt texts
genres, prompt_texts = read_prompt_files(genre_txt, lyrics_txt)

# # Stage 1: Generate raw outputs and save npy files for vocals/instrumentals
print("Performing stage 1...")
stage1_output_set = generate_stage1_outputs(
    prompt_texts, genres, model, mmtokenizer, codectool, codec_model, device,
    use_audio_prompt, audio_prompt_path, use_dual_tracks_prompt,
    vocal_track_prompt_path, instrumental_track_prompt_path,
    prompt_start_time, prompt_end_time, max_new_tokens,
    top_p=0.93, temperature=1.0, repetition_penalty=1.2, stage1_output_dir=stage1_output_dir
)

# Optionally offload Stage 1 model
if not disable_offload_model:
    model.cpu()
    del model
    torch.cuda.empty_cache()



Validating prompt options...
Setting up directories...
Loading models for stage 1...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/2.60G [00:00<?, ?B/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

  WeightNorm.apply(module, name, dim)


Performing stage 1...


  0%|          | 0/7 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
 86%|████████▌ | 6/7 [26:52<05:02, 302.22s/it]

Section 6: input length 15357 exceeds max context 13383, using last tokens.


100%|██████████| 7/7 [32:41<00:00, 280.25s/it]


In [7]:
# This would take around 40 min

stage1_output_set = [stage1_output_dir+'/' + a for a in os.listdir(stage1_output_dir)]

# Stage 2: Run inference on Stage 1 outputs
print("Running stage 2 inferences...")
stage2_result = run_stage2_inference(stage2_model, stage1_output_set, stage2_output_dir, stage2_batch_size, device)

# Reconstruction: Decode the Stage 2 outputs to audio tracks
print("Reconstructing the track...")
tracks = reconstruct_tracks(stage2_result, stage1_output_dir, codec_model, device)

# Mixing: Combine vocal and instrumental tracks into a mix
print("Mixing the track...")
recons_mix = mix_tracks(tracks, recons_mix_dir)

# Vocoder upsampling: Enhance audio quality using the vocoder
print("Upsampling the vocoder...")
vocoder_mix = upsample_vocoder(stage2_result, codec_model, config_path, vocal_decoder_path, inst_decoder_path, output_dir, rescale)

# Post processing: Blend low frequencies between mixes
print("Post processing the audio...")
final_mix = post_process_audio(recons_mix, vocoder_mix, output_dir)

print("Music pipeline completed successfully.")
results =  {
    "stage1_outputs": stage1_output_set,
    "stage2_outputs": stage2_result,
    "reconstructed_tracks": tracks,
    "vocoder_mix": vocoder_mix,
    "final_mix": final_mix
}
print("Pipeline completed. Final mix is saved at:", results["final_mix"])

Running stage 2 inferences...


config.json:   0%|          | 0.00/683 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.93G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

100%|██████████| 2/2 [1:12:09<00:00, 2164.62s/it]


Stage 2 outputs: ['./my_output/stage2/female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_vtrack_0a3f313a-658a-4244-9562-09403590b8ff.npy', './my_output/stage2/female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_itrack_0a3f313a-658a-4244-9562-09403590b8ff.npy']
Reconstructing the track...
Mixing the track...
Upsampling the vocoder...


  vocal_decoder.load_state_dict(torch.load(vocal_decoder_path))
  inst_decoder.load_state_dict(torch.load(inst_decoder_path))


Processing ./my_output/stage2/female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_vtrack_0a3f313a-658a-4244-9562-09403590b8ff.npy
Compressed shape: (8, 8781)


  compressed = torch.tensor(compressed).to(f"cuda:{args.cuda_idx}")


Decoded in 0.32s (542.56x RTF)
Saved: ./my_output/vocoder/stems/vtrack.mp3
Processing ./my_output/stage2/female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_itrack_0a3f313a-658a-4244-9562-09403590b8ff.npy
Compressed shape: (8, 8781)
Decoded in 0.08s (2219.57x RTF)
Saved: ./my_output/vocoder/stems/itrack.mp3
Created mix: ./my_output/vocoder/mix/mixed_output.mp3
Post processing the audio...
Successfully created 'female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_mixed_0a3f313a-658a-4244-9562-09403590b8ff.mp3' with matched low-frequency energy.
Post processing complete.
Music pipeline completed successfully.
Pipeline completed. Final mix is saved at: ./my_output/female-blues-airy-vocal-bright-vocal-piano-sad-romantic-guitar-jazz_mixed_0a3f313a-658a-4244-9562-09403590b8ff.mp3


In [8]:
from IPython.display import Audio, display

sound_file = results["final_mix"]
display(Audio(sound_file))
