In [1]:
import json
import logging
import os
import shutil
import time
from datetime import timedelta

import soundfile as sf
import torch
import torchaudio
from einops import rearrange
from tqdm import tqdm

from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict
from utils.vimsketch_dataset import VimSketchDataset

# Constants
MODEL_CONFIG_PATH = "runs/audiocaps_finetune_ctrl/3qfv6n0i/checkpoints/model_config_small_custom.json"
MODEL_CKPT_PATH = "runs/audiocaps_finetune_ctrl/3qfv6n0i/checkpoints/epoch=2-step=50000.ckpt"
DATASET_ROOT = "/home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/"

# Inference Parameters
TRANSFER_STRENGTH = 0.75  # for small model from 0 to 1
GUIDANCE_SCALE = 1.0  # 1.0 for rf_denoiser
STEPS = 8  # 8 for rf_denoiser
SEED = 42
TTA = True  # Set to True for Text-to-Audio

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def load_model(model_config_path, model_ckpt_path, device="cuda"):
    print(f"Loading model config from {model_config_path}")
    with open(model_config_path) as f:
        model_config = json.load(f)

    print("Creating model from config")
    model = create_model_from_config(model_config)

    print(f"Loading model checkpoint from {model_ckpt_path}")
    copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))

    model.to(device).eval().requires_grad_(False)
    print("Model loaded successfully")
    return model, model_config

def save_wave(waveform, savepath, name="outwav", sample_rate=44100):
    if type(name) is not list:
        name = [name] * waveform.shape[0]

    for i in range(waveform.shape[0]):
        path = os.path.join(
            savepath,
            "%s.wav"
            % (
                os.path.basename(name[i])
                if (".wav" not in name[i])
                else os.path.basename(name[i]).split(".")[0]
            ),
        )
        print("Save audio to %s" % path)

        # Post-processing
        audio = waveform[i]

        # Peak normalize, clip, convert to int16
        audio = (
            audio.to(torch.float32)
            .div(torch.max(torch.abs(audio)))
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )

        torchaudio.save(path, audio, sample_rate)

In [3]:
# Target files to generate
target_files = [
    'dataset/Vim_Sketch_Dataset/vocal_imitations/06724_112-needle_strings-commercial_synthesizers.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07233_162-subsynth_2007-single_synthesizer.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07266_166-subsynth_2039-single_synthesizer.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07573_196-subsynth_9879-single_synthesizer.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07602_199-synth_metallic_stars-commercial_synthesizers.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07794_218-vibraphone_sustained-acoustic_instruments.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/07948_233-windgong-acoustic_instruments.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/11764_132-piano_playing-everyday.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/11873_143-rooster_calling-everyday.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/11883_144-sandpaper_rubbing-everyday.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/12048_200-tambourine-acoustic_instruments.wav',
    'dataset/Vim_Sketch_Dataset/vocal_imitations/12292_224-water_bubbling-everyday.wav'
]
# Extract just the filenames for matching
target_filenames = set(os.path.basename(f) for f in target_files)

# Setup paths
dataset_root = os.environ.get("DATASET_ROOT", DATASET_ROOT)

dataset = VimSketchDataset(dataset_root)
print(f"Dataset loaded with {len(dataset)} items.")

# Setup save path
if TTA:
    save_path = os.path.join(dataset_root, "test", "audios")
else:
    # Fallback/Not expected for this task
    save_path = os.path.join(
        dataset_root, "style_transfer_sao", f"transfer_strength_{TRANSFER_STRENGTH}"
    )

os.makedirs(save_path, exist_ok=True)
print(f"Saving to: {save_path}")

# Load Model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {device}")
model, model_config = load_model(MODEL_CONFIG_PATH, MODEL_CKPT_PATH, device=device)
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# Start timing
start_time = time.time()

print(f"Starting generation for {len(target_filenames)} specific files...")

processed_count = 0

for i in tqdm(range(len(dataset))):
    imitation_path = dataset[i]["imitation_path"]
    reference_path = dataset[i]["reference_path"]
    text = dataset[i]["text"]

    # Get output filename
    imitation_filename = os.path.basename(imitation_path)
    
    # FILTER: Check if this file is in our target list
    if imitation_filename not in target_filenames:
        continue
        
    output_file = os.path.join(save_path, imitation_filename)

    # Skip if already processed
    if os.path.exists(output_file):
        print(f"Skipping {imitation_filename}, already processed")
        processed_count += 1
        continue

    try:
        print(f"Processing: {imitation_filename}")
        
        # Get audio duration from imitation path
        info = torchaudio.info(imitation_path)
        duration = info.num_frames / info.sample_rate

        # Calculate sample_size based on duration
        target_sample_size = int(duration * sample_rate)

        # Ensure target_sample_size is valid for the model
        if model.pretransform is not None:
            downsampling_ratio = model.pretransform.downsampling_ratio
            target_sample_size = (
                target_sample_size // downsampling_ratio
            ) * downsampling_ratio
            
        # Load audio for control signals
        audio_tensor, sr = torchaudio.load(imitation_path)
        if sr != sample_rate:
            resampler = torchaudio.transforms.Resample(sr, sample_rate)
            audio_tensor = resampler(audio_tensor)
        
        # FIX: Ensure audio_tensor matches target_sample_size
        # The model's encoder expects the input audio to align with the target sample size
        if audio_tensor.shape[-1] > target_sample_size:
            audio_tensor = audio_tensor[..., :target_sample_size]
        elif audio_tensor.shape[-1] < target_sample_size:
            # Pad with zeros
            pad_size = target_sample_size - audio_tensor.shape[-1]
            audio_tensor = torch.nn.functional.pad(audio_tensor, (0, pad_size))

        # Prepare conditioning
        conditioning = [
            {
                "prompt": text,
                "seconds_start": 0,
                "seconds_total": duration,
                "audio": audio_tensor,
            }
        ]

        # Determine sampler parameters
        diffusion_objective = model.diffusion_objective
        if diffusion_objective == "rf_denoiser":
            sampler_type = "pingpong"
            sigma_min = 0.01
            sigma_max = 1.0
        elif diffusion_objective == "rectified_flow":
            sampler_type = "euler"
            sigma_min = 0.01
            sigma_max = 1.0
        else:
            sampler_type = "dpmpp-3m-sde"
            sigma_min = 0.03
            sigma_max = 500

        # TTA Logic
        # Calculate max duration from model config
        max_duration = sample_size / sample_rate

        if duration > max_duration:
            print(
                f"Requested duration {duration:.2f}s exceeds model max {max_duration:.2f}s, clipping."
            )
            duration = max_duration
            # Start over calculation if clipped
            target_sample_size = int(duration * sample_rate)
            if model.pretransform is not None:
                downsampling_ratio = model.pretransform.downsampling_ratio
                target_sample_size = (
                    target_sample_size // downsampling_ratio
                ) * downsampling_ratio

        # Generate TTA
        output = generate_diffusion_cond(
            model,
            steps=STEPS,
            cfg_scale=GUIDANCE_SCALE,
            conditioning=conditioning,
            sample_size=target_sample_size,
            seed=SEED,
            device=device,
            init_audio=None,
            init_noise_level=1.0,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
        )

        # Trimming
        output = output[..., :target_sample_size]

        save_wave(
            output,
            save_path,
            name=imitation_filename,
            sample_rate=sample_rate,
        )
        
        processed_count += 1
        print(f"Successfully generated {imitation_filename}")

    except Exception as e:
        print(f"Error processing {imitation_filename}: {str(e)}")
        # raise e # Uncomment for debugging
        continue

print(f"\nDone! Processed {processed_count}/{len(target_filenames)} target files.")

Dataset loaded with 12453 items.
Saving to: /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios
Loading model on cuda
Loading model config from runs/audiocaps_finetune_ctrl/3qfv6n0i/checkpoints/model_config_small_custom.json
Creating model from config


  WeightNorm.apply(module, name, dim)


Loading model checkpoint from runs/audiocaps_finetune_ctrl/3qfv6n0i/checkpoints/epoch=2-step=50000.ckpt
Model loaded successfully
Starting generation for 12 specific files...


  0%|          | 0/12453 [00:00<?, ?it/s]

Processing: 06724_112-needle_strings-commercial_synthesizers.wav
42


  info = torchaudio.info(imitation_path)
  return AudioMetaData(
  with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
  + 2 * np.log10(f_sq)
100%|██████████| 8/8 [00:00<00:00, 48.63it/s]
  info = torchaudio.info(imitation_path)
  return AudioMetaData(
  with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
  + 2 * np.log10(f_sq)


Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/06724_112-needle_strings-commercial_synthesizers.wav
Successfully generated 06724_112-needle_strings-commercial_synthesizers.wav
Processing: 11764_132-piano_playing-everyday.wav
42


100%|██████████| 8/8 [00:00<00:00, 76.74it/s]
 77%|███████▋  | 9564/12453 [00:02<00:00, 4063.07it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/11764_132-piano_playing-everyday.wav
Successfully generated 11764_132-piano_playing-everyday.wav
Processing: 11873_143-rooster_calling-everyday.wav
42


100%|██████████| 8/8 [00:00<00:00, 78.06it/s]
 79%|███████▉  | 9895/12453 [00:02<00:00, 3734.80it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/11873_143-rooster_calling-everyday.wav
Successfully generated 11873_143-rooster_calling-everyday.wav
Processing: 11883_144-sandpaper_rubbing-everyday.wav
42


100%|██████████| 8/8 [00:00<00:00, 81.43it/s]
 81%|████████▏ | 10148/12453 [00:02<00:00, 3277.38it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/11883_144-sandpaper_rubbing-everyday.wav
Successfully generated 11883_144-sandpaper_rubbing-everyday.wav
Processing: 07233_162-subsynth_2007-single_synthesizer.wav
42


100%|██████████| 8/8 [00:00<00:00, 78.48it/s]
 84%|████████▎ | 10414/12453 [00:02<00:00, 2803.63it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07233_162-subsynth_2007-single_synthesizer.wav
Successfully generated 07233_162-subsynth_2007-single_synthesizer.wav
Processing: 07266_166-subsynth_2039-single_synthesizer.wav
42


100%|██████████| 8/8 [00:00<00:00, 78.89it/s]
 85%|████████▍ | 10583/12453 [00:03<00:00, 2338.84it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07266_166-subsynth_2039-single_synthesizer.wav
Successfully generated 07266_166-subsynth_2039-single_synthesizer.wav
Processing: 07573_196-subsynth_9879-single_synthesizer.wav
42


100%|██████████| 8/8 [00:00<00:00, 78.61it/s]
 89%|████████▉ | 11117/12453 [00:03<00:00, 2355.68it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07573_196-subsynth_9879-single_synthesizer.wav
Successfully generated 07573_196-subsynth_9879-single_synthesizer.wav
Processing: 07602_199-synth_metallic_stars-commercial_synthesizers.wav
42


100%|██████████| 8/8 [00:00<00:00, 51.43it/s]


Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07602_199-synth_metallic_stars-commercial_synthesizers.wav
Successfully generated 07602_199-synth_metallic_stars-commercial_synthesizers.wav
Processing: 12048_200-tambourine-acoustic_instruments.wav
42


100%|██████████| 8/8 [00:00<00:00, 81.05it/s]
 91%|█████████ | 11289/12453 [00:04<00:00, 1322.51it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/12048_200-tambourine-acoustic_instruments.wav
Successfully generated 12048_200-tambourine-acoustic_instruments.wav
Processing: 07794_218-vibraphone_sustained-acoustic_instruments.wav
42


100%|██████████| 8/8 [00:00<00:00, 78.28it/s]
 94%|█████████▍| 11768/12453 [00:04<00:00, 1384.51it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07794_218-vibraphone_sustained-acoustic_instruments.wav
Successfully generated 07794_218-vibraphone_sustained-acoustic_instruments.wav
Processing: 12292_224-water_bubbling-everyday.wav
42


100%|██████████| 8/8 [00:00<00:00, 77.37it/s]
 96%|█████████▌| 11972/12453 [00:04<00:00, 1145.71it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/12292_224-water_bubbling-everyday.wav
Successfully generated 12292_224-water_bubbling-everyday.wav
Processing: 07948_233-windgong-acoustic_instruments.wav
42


100%|██████████| 8/8 [00:00<00:00, 77.81it/s]
100%|██████████| 12453/12453 [00:04<00:00, 2499.48it/s]

Save audio to /home/paul/OneDrive/Master/practical_work/Practical-Work-AI-Master/dataset/Vim_Sketch_Dataset/test/audios/07948_233-windgong-acoustic_instruments.wav
Successfully generated 07948_233-windgong-acoustic_instruments.wav

Done! Processed 12/12 target files.



