# A2SB Audio Upsampling

This notebook upsamples an audio file using NVIDIA's A2SB (Audio-to-Audio Schrödinger Bridge) model. Run the cells in order, from top to bottom.

In [None]:
# @title Clone Repository and Install Dependencies
!git clone https://github.com/NVIDIA/diffusion-audio-restoration.git
%cd diffusion-audio-restoration
!pip install -q numpy scipy matplotlib jsonargparse librosa soundfile torch torchaudio einops pytorch_lightning rotary_embedding_torch ssr_eval ipywidgets

In [None]:
# @title Import Libraries
import os
import numpy as np
import yaml
from google.colab import files
import ipywidgets as widgets
from IPython.display import display, Audio
import librosa
import soundfile as sf
from subprocess import Popen, PIPE
from datetime import datetime

In [None]:
# @title Helper Functions

def load_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return the loaded object.

    Args:
        file_path: Path of the YAML file to read (opened in text mode).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

def save_yaml(data, prefix="./temp_config"):
    """Write ``data`` to a freshly named YAML file and return the file's path.

    The file name is ``{prefix}_{timestamp}_{random}.yaml``: a one-second
    resolution timestamp followed by a random 6-decimal fraction, so calls made
    within the same second are unlikely to collide.

    Args:
        data: Object to serialise with ``yaml.dump``.
        prefix: Path prefix of the file. Its directory part, if it has one, is
            created when missing.

    Returns:
        The path of the file that was written.
    """
    # A bare prefix such as "temp_config" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create real directories.
    parent_dir = os.path.dirname(prefix)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Fixed-width formatting keeps float artefacts (e.g. 0.12345599999999999)
    # out of the file name.
    random_suffix = f"{np.random.rand():.6f}"
    file_name = f"{prefix}_{timestamp}_{random_suffix}.yaml"
    with open(file_name, 'w') as f:
        yaml.dump(data, f)
    return file_name

def shell_run_cmd(cmd):
    """Run ``cmd`` through the shell, print its output, and return its exit code.

    The command's stdout and stderr are captured and printed only after the
    command has finished, so long-running commands show nothing until then.

    Args:
        cmd: Command line to execute. It is passed to the shell as-is
            (``shell=True``), so callers must quote any untrusted or
            space-containing arguments themselves.

    Returns:
        The process exit status (0 on success). A non-zero status is also
        reported on stdout; no exception is raised.
    """
    print('running:', cmd)
    p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True)
    stdout, stderr = p.communicate()
    # errors='replace': tool output is not guaranteed to be valid UTF-8, and a
    # decode error here would hide the command's actual output.
    print(stdout.decode(errors='replace'))
    print(stderr.decode(errors='replace'))
    if p.returncode != 0:
        print(f'command exited with status {p.returncode}')
    return p.returncode

### 1. Upload Audio File
Run the following cell to upload your audio file. Make sure it's in a format that `librosa` can read (e.g., .wav, .mp3).

In [None]:
uploaded = files.upload()
# `uploaded` is keyed by file name. If it is empty (nothing was uploaded),
# fail with a clear message instead of an IndexError on the lookup below.
if not uploaded:
    raise RuntimeError("No file was uploaded. Re-run this cell and choose an audio file.")
# Only the first uploaded file is used by the rest of the notebook.
input_filename = next(iter(uploaded))
print(f'Uploaded file: {input_filename}')

Saving test.wav to test.wav


<IPython.core.display.HTML object>

### 2. Download Pre-trained Model

In [None]:
CHECKPOINT_URL = "https://huggingface.co/nvidia/audio_to_audio_schrodinger_bridge/resolve/main/ckpt/A2SB_onesplit_0.0_1.0_release.ckpt?download=true"
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_FILENAME = "A2SB_onesplit_0.0_1.0_release.ckpt"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, CHECKPOINT_FILENAME)

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# A failed `wget -O` leaves an empty file behind, so an empty file must not be
# mistaken for an already-downloaded checkpoint.
if os.path.exists(CHECKPOINT_PATH) and os.path.getsize(CHECKPOINT_PATH) > 0:
    print("Checkpoint already exists.")
else:
    print("Downloading checkpoint...")
    # Download to a side file and rename on success, so an interrupted download
    # never ends up at CHECKPOINT_PATH.
    partial_path = CHECKPOINT_PATH + ".part"
    # The URL contains '?', a shell glob character, hence the quoting.
    exit_status = shell_run_cmd(f'wget -O "{partial_path}" "{CHECKPOINT_URL}"')
    # shell_run_cmd may not report an exit status (None); in that case the file
    # check alone decides whether the download worked.
    download_ok = (
        not exit_status
        and os.path.exists(partial_path)
        and os.path.getsize(partial_path) > 0
    )
    if not download_ok:
        raise RuntimeError(
            f"Checkpoint download failed; see the wget output above ({CHECKPOINT_URL})"
        )
    os.replace(partial_path, CHECKPOINT_PATH)
    print("Download complete.")

In [None]:
def compute_rolloff_freq(audio_file, roll_percent=0.99):
    """Return the mean spectral roll-off frequency of an audio file.

    Args:
        audio_file: Path of an audio file readable by ``librosa.load``.
        roll_percent: Fraction passed to ``librosa.feature.spectral_rolloff``
            as its ``roll_percent`` argument.

    Returns:
        The roll-off values averaged with ``np.mean`` and truncated to ``int``.
    """
    # sr=None is forwarded to librosa.load; per librosa's docs this keeps the
    # file's native sampling rate instead of resampling.
    y, sr = librosa.load(audio_file, sr=None)
    # [0] selects the first row of the returned array before averaging.
    rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr, roll_percent=roll_percent)[0]
    rolloff = int(np.mean(rolloff))
    # Report the percentage actually used rather than a hard-coded "99".
    print(f'{roll_percent * 100:g} percent rolloff:', rolloff)
    return rolloff

def upsample_one_sample(audio_filename, output_audio_filename, predict_n_steps=50):
    """Run the inference script on one audio file and write the result.

    Builds a temporary inference config from
    ``configs/inference_files_upsampling.yaml``, sets both the minimum and the
    maximum cutoff frequency of the upsampling mask to the file's 99% spectral
    roll-off, and invokes ``ensembled_inference_api.py predict`` in a
    subprocess. The temporary config is removed afterwards.

    Both paths are resolved against the current working directory, which is
    where the notebook uploads the input and later reads the output from.
    (Previously ``../`` was prepended to both, which pointed at the parent of
    the working directory and so did not match those cells.)

    Args:
        audio_filename: Path of the input audio file.
        output_audio_filename: Path the restored audio is written to; must
            differ from ``audio_filename``.
        predict_n_steps: Value passed as ``--model.predict_n_steps``.

    Raises:
        AssertionError: If input and output resolve to the same path.
    """
    import shlex  # local import: only needed to quote arguments for the shell

    # Absolute paths mean the same thing in this process and in the subprocess.
    input_path = os.path.abspath(audio_filename)
    output_path = os.path.abspath(output_audio_filename)
    assert output_path != input_path, "output filename cannot be input filename"

    inference_config = load_yaml('configs/inference_files_upsampling.yaml')
    inference_config['data']['predict_filelist'] = [
        {
            'filepath': input_path,
            'output_subdir': '.'
        }
    ]

    cutoff_freq = compute_rolloff_freq(input_path, roll_percent=0.99)
    inference_config['data']['transforms_aug'][0]['init_args']['upsample_mask_kwargs'] = {
        'min_cutoff_freq': cutoff_freq,
        'max_cutoff_freq': cutoff_freq
    }
    temporary_yaml_file = save_yaml(inference_config)

    try:
        # The command goes through a shell, so quote every path: uploaded file
        # names may contain spaces or parentheses (e.g. "test (1).wav").
        cmd = (
            f"python ensembled_inference_api.py predict "
            f"-c configs/onesplit.yaml "
            f"-c {shlex.quote(temporary_yaml_file)} "
            f"--model.predict_n_steps={predict_n_steps} "
            f"--model.output_audio_filename={shlex.quote(output_path)}"
        )
        shell_run_cmd(cmd)
    finally:
        # Remove the temporary config even if the inference command raises.
        os.remove(temporary_yaml_file)


### 4. Run Upsampling
Run the following cell to start the upsampling process.

In [None]:
# Name the result after the uploaded file, then run the upsampling on it.
output_filename = "restored_" + input_filename
upsample_one_sample(audio_filename=input_filename, output_audio_filename=output_filename)

### 5. Results

In [None]:
# Show a player for the input and for the restored output, one after the other.
for heading, audio_path in (
    ("Original Audio:", input_filename),
    ("Upscaled Audio:", output_filename),
):
    print(heading)
    display(Audio(audio_path))

# Hand the restored file to Colab's download helper.
print("Download the upscaled audio:")
files.download(output_filename)