In [None]:
!pip uninstall -y typing-extensions typing-inspection starlette pydantic grpcio fastapi ipython torch torchvision torchaudio

In [None]:
!python -m pip install --upgrade pip setuptools wheel
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

In [None]:
import os
from datetime import datetime

# Root directory under which every input/output folder of this notebook lives.
BASE_DIR = "/workspace"

# Directory layout; all four are created below if they do not exist yet.
A2A_REPAINT_DIR = os.path.join(BASE_DIR, "text2audio_gen", "repaints")
T2A_INPUT_DIR = os.path.join(BASE_DIR, "text2audio_input")
A2A_DIR = os.path.join(BASE_DIR, "audio2audio_clips")
GEN_DIR = os.path.join(BASE_DIR, "text2audio_gen")

for d in (A2A_DIR, T2A_INPUT_DIR, GEN_DIR, A2A_REPAINT_DIR):
    os.makedirs(d, exist_ok=True)

# Source clip: ACE_SRC_AUDIO overrides the default. The default is derived from
# A2A_DIR so it keeps following BASE_DIR if that is ever changed.
SRC_AUDIO_PATH = os.getenv("ACE_SRC_AUDIO", os.path.join(A2A_DIR, "src_audio.wav"))

# Fixed-name output and the two text input files.
OUTPUT_WAV_PATH = os.path.join(GEN_DIR, "generated_music.wav")
PROMPT_FILE = os.path.join(T2A_INPUT_DIR, "prompt_input.txt")
LYRICS_FILE = os.path.join(T2A_INPUT_DIR, "lyrics_input.txt")

# Timestamp taken once per execution of this cell so each run writes new files.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")

# NOTE(review): REPAINT_WAV_PATH is placed in GEN_DIR, not in A2A_REPAINT_DIR,
# although that directory is created above — confirm which location is intended.
REPAINT_WAV_PATH = os.path.join(GEN_DIR, f"repaint_{ts}.wav")
EXTEND_WAV_PATH  = os.path.join(GEN_DIR, f"extend_{ts}.wav")

print("SRC_AUDIO_PATH:", SRC_AUDIO_PATH)
print("T2A_INPUT_DIR:", T2A_INPUT_DIR)
print("PROMPT_FILE  :", PROMPT_FILE)
print("LYRICS_FILE  :", LYRICS_FILE)
print("GEN_DIR:", GEN_DIR)

In [None]:
!git clone https://github.com/ace-step/ACE-Step.git
%cd ACE-Step

In [None]:
!pip install -r requirements.txt
!pip install -e .

In [None]:
import os
import sys

# Repository root: ACE_STEP_REPO_PATH when set, otherwise the directory the
# kernel is currently in. The kernel is then moved there.
repo_path = os.environ.get('ACE_STEP_REPO_PATH', os.getcwd())
os.chdir(repo_path)

# Append the repository to the import path unless it is already listed.
already_listed = repo_path in sys.path
if not already_listed:
    sys.path.append(repo_path)

print("Working directory:", os.getcwd())
print("In sys.path:", repo_path in sys.path)

In [None]:
import os
from huggingface_hub import snapshot_download

# Hub repository that is downloaded below.
repo_id = "ACE-Step/ACE-Step-v1-3.5B"

# Download cache: ACE_STEP_WEIGHTS_PATH when set, otherwise an
# "acestep_checkpoints" folder one level above repo_path.
cache_root = os.environ.get(
    'ACE_STEP_WEIGHTS_PATH',
    os.path.join(repo_path, '..', 'acestep_checkpoints'),
)
os.makedirs(cache_root, exist_ok=True)

# The returned value is kept as snapshot_dir and handed to the pipeline later.
snapshot_dir = snapshot_download(repo_id=repo_id, cache_dir=cache_root)
print("Snapshot directory:", snapshot_dir)

In [None]:
import torch
from acestep.pipeline_ace_step import ACEStepPipeline

# Device selection: GPU 0 when CUDA is available, otherwise -1 is passed.
use_cuda = torch.cuda.is_available()
device_id = 0 if use_cuda else -1
use_bf16 = True

# bfloat16 is only requested when running on CUDA; everything else uses float32.
dtype = torch.bfloat16 if (use_bf16 and use_cuda) else torch.float32

# NOTE(review): upstream ACE-Step examples pass the precision as a name string
# ("bfloat16" / "float32"); a torch.dtype object is presumably not recognised
# there and would silently select float32. Pass the name instead — confirm
# against the installed acestep version. `dtype` itself stays a torch.dtype.
dtype_name = "bfloat16" if dtype is torch.bfloat16 else "float32"

pipe = ACEStepPipeline(
    checkpoint_dir=snapshot_dir,
    dtype=dtype_name,
    device_id=device_id,
    torch_compile=False,
)

# Load the weights now so the first generation call does not pay for it.
pipe.load_checkpoint(checkpoint_dir=snapshot_dir)
print("Pipeline loaded.")

In [10]:
def generate_music(text_prompt: str, duration: int = 30, model=None):
    """Generate a waveform for ``text_prompt`` by calling ``model.generate``.

    Args:
        text_prompt: Text passed to the model as ``text=``.
        duration: Value passed to the model as ``duration=`` (default 30).
        model: Object exposing ``generate(text=..., duration=...)``. When
            omitted, a module-level ``model`` is used, as before.

    Returns:
        Whatever ``model.generate`` returns.

    Raises:
        NameError: If ``model`` is not passed and no global ``model`` exists.

    NOTE(review): no visible cell of this notebook defines ``model`` (the
    pipeline object is called ``pipe``), so calling this without ``model=``
    fails on a fresh kernel — confirm whether this helper is still needed.
    """
    if model is None:
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError(
                "name 'model' is not defined: pass model=... to generate_music "
                "or define a global `model` first"
            ) from None
    return model.generate(text=text_prompt, duration=duration)

In [None]:
def _read(path: str) -> str:
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"[warn] Missing {path}; creating starter file")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write("")
        return ""
    except UnicodeDecodeError:
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read().strip()

# Read the prompt first, then the lyrics, from their input files.
prompt_text, lyrics_text = _read(PROMPT_FILE), _read(LYRICS_FILE)

# Report how much text was loaded; 0 means the file was empty or just created.
for _label, _text in (("Prompt chars:", prompt_text), ("Lyrics chars:", lyrics_text)):
    print(_label, len(_text))

In [None]:
import os, math, torch, torchaudio

# Sample rate (Hz) the source clip is brought to before measuring it.
SR = 44100
# Samples per frame used to quantize times below.
# NOTE(review): presumably matches the model's frame size — confirm.
HOP = 4096


def sec_to_frames(s):
    """Convert seconds to the nearest whole number of HOP-sample frames."""
    sample_position = s * SR
    return int(round(sample_position / HOP))


def frames_to_sec(f):
    """Convert a frame count to seconds at SR samples per second."""
    sample_position = f * HOP
    return sample_position / SR

# Fail early with a readable message: the clips directory is created empty by
# the configuration cell, so the source file is missing until the user adds it.
if not os.path.isfile(SRC_AUDIO_PATH):
    raise FileNotFoundError(
        f"Source audio not found: {SRC_AUDIO_PATH}. "
        "Place a clip there or set the ACE_SRC_AUDIO environment variable."
    )

wav, sr = torchaudio.load(SRC_AUDIO_PATH)
# Average over dim 0 when it holds more than one entry
# (assumed to be the channel axis of torchaudio.load's output — confirm).
if wav.dim() == 2 and wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
# Bring the clip to SR so sample counts and frame math agree.
if sr != SR:
    wav = torchaudio.functional.resample(wav, sr, SR)
    sr = SR

# Clip length in samples, in whole frames (rounded up), and that frame count
# expressed back in seconds.
L = wav.shape[-1]
total_frames = math.ceil(L / HOP)
duration_qs = frames_to_sec(total_frames)

# Requested repaint window, in seconds.
desired_start_s = 5.0
desired_span_s = 10.0

# Snap the window to frame boundaries: the start is clamped to
# [0, total_frames - 2], the span is at least 2 frames, and the end never
# exceeds total_frames.
start_f = max(0, min(total_frames - 2, sec_to_frames(desired_start_s)))
end_f = min(total_frames, start_f + max(2, sec_to_frames(desired_span_s)))
start_s_q = frames_to_sec(start_f)
end_s_q = frames_to_sec(end_f)

print("Source duration (q):", duration_qs, "s | repaint window:", start_s_q, "->", end_s_q, "s")

In [None]:
# Prompt, sampler and guidance settings for the pipeline call.
# NOTE(review): audio_duration is 10 while the repaint window computed earlier
# starts near 5 s and spans about 10 s — confirm how the pipeline reconciles them.
generation_settings = dict(
    audio_duration=10,
    prompt=prompt_text,
    lyrics=lyrics_text,
    infer_step=60,
    guidance_scale=15.0,
    scheduler_type="euler",
    cfg_type="apg",
    omega_scale=10.0,
    manual_seeds="",
    guidance_interval=0.5,
    guidance_interval_decay=0.0,
    min_guidance_scale=3.0,
    use_erg_tag=True,
    use_erg_lyric=True,
    use_erg_diffusion=True,
    oss_steps="",
    guidance_scale_text=0.0,
    guidance_scale_lyric=0.0,
    src_audio_path=SRC_AUDIO_PATH,
    save_path=REPAINT_WAV_PATH,
)

# Repaint-specific settings: the task name plus the frame-aligned window.
repaint_settings = dict(
    task="repaint",
    audio2audio_enable=True,
    ref_audio_strength=0.4,
    retake_variance=0.6,
    repaint_start=start_s_q,
    repaint_end=end_s_q,
    retake_seeds="",
)

# Run the pipeline on the source clip; save_path is REPAINT_WAV_PATH.
result = pipe(**generation_settings, **repaint_settings)