In [1]:
from stable_audio_tools.models.pretrained import get_pretrained_model

model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")


No module named 'flash_attn'
flash_attn not installed, disabling Flash Attention


  WeightNorm.apply(module, name, dim)


In [2]:
import sys
sys.path.append("/workspace/stable-audio-tools")


In [3]:
!pwd

/workspace


In [4]:
import torchaudio
import numpy as np
from pathlib import Path
import torch

torch.set_float32_matmul_precision('high')
# Make sure model has a pretransform encoder
assert hasattr(model, "pretransform") and model.pretransform is not None, "Your model must have a .pretransform encoder"
model.pretransform.to("cuda").eval()

# Config
AUDIO_DIR = Path("/workspace/data3")
OUTPUT_DIR = Path("/workspace/data3_preencoded_overlap")
SAMPLE_RATE = 44100
SEGMENT_DURATION = 1.49
SEGMENT_SAMPLES = int(SAMPLE_RATE * SEGMENT_DURATION)
OVERLAP_RATIO = 0.5
STEP_SIZE = int(SEGMENT_SAMPLES * (1 - OVERLAP_RATIO))
MAX_SEGMENTS_PER_FILE = 20

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("🔁 Encoding and saving pre-encoded latent segments...")

total_saved = 0

for file in AUDIO_DIR.glob("*.wav"):
    try:
        audio, sr = torchaudio.load(str(file))

        # Resample if needed
        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)
            audio = resampler(audio)

        # 🔧 Force stereo if mono
        if audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)

        start = 0
        seg_idx = 0
        while start + SEGMENT_SAMPLES <= audio.shape[1] and seg_idx < MAX_SEGMENTS_PER_FILE:
            segment = audio[:, start:start + SEGMENT_SAMPLES].unsqueeze(0).to("cuda")  # [1, 2, T]

            with torch.no_grad():
                latent = model.pretransform.encode(segment)  # [1, D, T']
                latent = latent.squeeze(0).cpu().numpy()    # [D, T']

            out_path = OUTPUT_DIR / f"{file.stem}_ov{seg_idx}.npy"
            np.save(out_path, latent)

            start += STEP_SIZE
            seg_idx += 1
            total_saved += 1

    except Exception as e:
        print(f"❌ Error processing {file.name}: {e}")

print(f"✅ Done. Saved {total_saved} pre-encoded .npy segments to: {OUTPUT_DIR}")


🔁 Encoding and saving pre-encoded latent segments...
✅ Done. Saved 80 pre-encoded .npy segments to: /workspace/data3_preencoded_overlap


In [5]:
from stable_audio_tools.data.dataset import create_dataloader_from_config
import json

with open("/workspace/dataset_config.json") as f:
    dataset_config = json.load(f)

train_loader = create_dataloader_from_config(
    dataset_config,
    batch_size=1,
    sample_size=65536,
    sample_rate=44100,
    audio_channels=1,
    num_workers=8
)


Found 80 files


In [6]:
from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper
import pytorch_lightning as pl

training_wrapper = DiffusionCondTrainingWrapper(
    model=model,
    lr=1e-4,
    pre_encoded=True
)

trainer = pl.Trainer(
    max_steps=1200,                     # <-- total steps to train
    accumulate_grad_batches=4,          # <-- simulate larger batch
    precision=16,                        # <-- mixed precision (faster)
    log_every_n_steps=600,               # <-- print loss every 10 steps
    enable_progress_bar=True,           
    enable_checkpointing=False,
    val_check_interval=None,
    strategy='auto',
    devices=1
)

trainer.fit(training_wrapper, train_dataloaders=train_loader)
training_wrapper.export_model("/workspace/stable-audio-tools/saved/final_model.pt")


/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                             | Params
-------------------------------------------------------------------
0 | diffusion     | ConditionedDiffusionModelWrappe

Training: |                                           | 0/? [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
`Trainer.fit` stopped: `max_steps=1200` reached.


In [7]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [20]:
from stable_audio_tools.inference.generation import generate_diffusion_cond

conditioning = [
    {
        "prompt": "Atlanta-style trap beat with heavy bass and synth lead",
        "seconds_total": 20.0  # leave this as float
    }
]

DURATION_SEC = 10
SAMPLE_RATE = 44100
SAMPLE_SIZE = DURATION_SEC * SAMPLE_RATE

output = generate_diffusion_cond(
    model=model.to("cuda"),       # make 100% sure model is on GPU
    steps=1000,
    cfg_scale=2.0,
    conditioning=conditioning,
    sample_size=SAMPLE_SIZE,       # 9 seconds
    device="cuda"                 # force everything onto GPU
)


1725299224


1000it [02:08,  7.77it/s]


In [21]:
import torchaudio
from einops import rearrange

# Rearrange: [B, C, T] -> [C, T] for saving
waveform = rearrange(output, "b c n -> c (b n)")

# Peak normalize and convert to 16-bit PCM
waveform = waveform.to(torch.float32).div(torch.max(torch.abs(waveform))).mul(32767).clamp(-32768, 32767).to(torch.int16).cpu()


In [22]:
torchaudio.save("remix_output_prompt_1000_steps.wav", waveform, sample_rate=44100)
print("✅ Saved to remix_output_prompt_1000_steps.wav")


✅ Saved to remix_output_prompt_1000_steps.wav
