# 🎵 A2SB: Audio-to-Audio Schrödinger Bridge - Kaggle Edition

**High-Quality Audio Restoration with NVIDIA A2SB**

This notebook includes **everything** you need: model download, setup, and Gradio web interface!

## ⚠️ IMPORTANT: GPU Required

**This notebook requires Kaggle GPU acceleration!**

### GPU Requirements:
- 🎯 **GPU Memory**: Requires 15GB+ GPU VRAM (P100 or T4 x2)
- ⏱️ **Processing Time**: Long audio files need extended runtime
- 💾 **RAM**: At least 25GB system RAM recommended
- 🚀 **Performance**: Faster GPUs shorten processing time

### How to Enable GPU on Kaggle:
1. Click **Settings** (right sidebar)
2. Under **Accelerator**, select **GPU P100** or **GPU T4 x2**
3. Click **Save**
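After saving, you can confirm that the accelerator is active with a quick check like the one below. It is a sketch that assumes PyTorch is installed (it is preinstalled on Kaggle GPU images) and reads the 15GB guideline as decimal gigabytes (1 GB = 10^9 bytes). `meets_vram_requirement` and `report_gpus` are illustrative helper names, not part of the repository.

```python
def meets_vram_requirement(total_bytes, min_gb=15.0):
    """Return True if a device's memory (in bytes) is at least min_gb decimal GB."""
    return total_bytes / 1e9 >= min_gb


def report_gpus(min_gb=15.0):
    """Print each visible CUDA device and whether it meets the VRAM guideline.

    Returns a list of (name, total_gb, ok) tuples; empty if no GPU is visible.
    """
    try:
        import torch
    except ImportError:
        print("❌ PyTorch is not installed")
        return []
    if not torch.cuda.is_available():
        print("❌ No GPU detected: enable one under Settings > Accelerator")
        return []
    results = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        total_gb = props.total_memory / 1e9
        ok = meets_vram_requirement(props.total_memory, min_gb)
        print(f"{'✅' if ok else '⚠️'} GPU {i}: {props.name} ({total_gb:.1f} GB)")
        results.append((props.name, total_gb, ok))
    return results


report_gpus()
```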

### Kaggle Advantages:
- ✅ Free GPU access (30 hours/week)
- ✅ P100 GPU with 16GB VRAM
- ✅ 30GB RAM
- ✅ Persistent storage
- ✅ No subscription required

---

## 🌟 Features
- ✅ 44.1kHz high-resolution music restoration
- ✅ Bandwidth extension (high-frequency prediction)
- ✅ Audio inpainting (reconstruct missing segments)
- ✅ Support for long audio files (hours)
- ✅ End-to-end, no vocoder required
- ✅ **Gradio Web Interface** - User-friendly UI

## 📚 Resources
- 📄 [Paper](https://arxiv.org/abs/2501.11311)
- 💻 [GitHub Repository](https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-.git)
- 🎬 [Original NVIDIA Demo](https://research.nvidia.com/labs/adlr/A2SB/)
- 🤗 [Models](https://huggingface.co/nvidia/audio_to_audio_schrodinger_bridge)

---

**Usage:** Run the cells in order, restarting the kernel once after the installation cell when prompted. The launch cell in Section 4 starts the Gradio interface!

## 📦 1. Setup and Dependencies

**This will take 5-10 minutes. Please be patient!**

In [None]:
# Clone the optimized repository
print("📥 Cloning repository...\n")

# Change to /kaggle/working directory (writable)
import os
os.chdir('/kaggle/working')

# Remove if exists
!rm -rf diffusion-audio-restoration-colab-Kaggle-

# Clone repository
!git clone https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-.git
os.chdir('diffusion-audio-restoration-colab-Kaggle-')

print(f"\n✅ Repository cloned successfully!")
print(f"✓ Current directory: {os.getcwd()}")

In [None]:
# Combined installation & compatibility fixes
# Key fix: force-reinstall NumPy 1.x to avoid binary-extension conflicts.
# Run the installs, then RESTART THE KERNEL before verification/app launch.

print("📦 Installing dependencies and applying fixes...")
print("⏱️  5-10 minutes...\n")

import torch
torch_version = torch.__version__.split('+')[0]
cuda_version = torch.version.cuda or ''  # torch.version.cuda is None on CPU-only builds
cuda_tag = 'cu124' if '12.4' in cuda_version else 'cu118'
index_url = f"https://download.pytorch.org/whl/{cuda_tag}"

print(f"🔧 PyTorch {torch_version}, CUDA {cuda_version or 'none'} - Using {cuda_tag} wheels")

# Core audio/vision fixes
!pip uninstall -y torchaudio torchvision transformers huggingface-hub torchmetrics 2>/dev/null || true
major, minor, patch = (int(part) for part in torch_version.split('.')[:3])
if major == 2 and 4 <= minor <= 6:
    # torchaudio matches the torch version exactly; torch 2.<minor>.<patch> pairs with
    # torchvision 0.<minor + 15>.<patch>. The local tag (+cuXXX) must match the index URL.
    torchvision_version = f"0.{minor + 15}.{patch}"
    !pip install -q torchaudio=={torch_version}+{cuda_tag} torchvision=={torchvision_version}+{cuda_tag} --index-url {index_url}
else:
    !pip install -q torchaudio torchvision --index-url {index_url}

# Transformers & metrics
!pip install -q "huggingface-hub<1.0,>=0.24.0" "transformers>=4.44.0,<4.45.0" "torchmetrics>=1.4.0,<1.5.0"

# Main packages
!pip install -q "lightning>=2.5.0" librosa soundfile einops gradio "jsonargparse[signatures]>=4.0.0" rotary-embedding-torch pyyaml tqdm nest-asyncio

# RAPIDS alignment (full upgrade to 25.6 to reduce conflicts)
!pip install -q --upgrade "rmm-cu12==25.6.*" "libraft-cu12==25.6.*" "pylibraft-cu12==25.6.*" "libcugraph-cu12==25.6.*" "pylibcugraph-cu12==25.6.*" "cugraph-cu12==25.6.*" "cuml-cu12==25.6.*" "cuvs-cu12==25.6.*" "cudf-cu12==25.6.*" "pylibcudf-cu12==25.6.*"

# Dataset/Gradio fixes
!pip install -q --upgrade "pyarrow>=21.0.0" "pydantic>=2.0,<2.12"

# NumPy & scikit-learn (FORCE downgrade NumPy 1.x for extension compatibility)
print("🔧 Force-downgrading NumPy to 1.x...")
!pip uninstall -y numpy 2>/dev/null || true
!pip install -q --force-reinstall --no-deps "numpy==1.26.4"
!pip install -q "scikit-learn>=1.5.0,<1.6.0"

# Fix packages requiring NumPy 2.x (downgrade opencv if conflicting; ignore if unused)
!pip install -q "opencv-python==4.8.1.78"  # Supports NumPy 1.x

# Sentence-transformers (with deps now, post-NumPy fix)
!pip install -q sentence-transformers

# Optional SSR
!pip install -q ssr-eval 2>/dev/null || echo "⚠️ SSR skipped"

print("\n" + "="*50)
print("✅ Installs complete! CRITICAL: RESTART THE KERNEL/SESSION NOW before running anything else.")
print("Then paste the verification code below into a NEW cell and run it:")
print("="*50)
print("""
import nest_asyncio; nest_asyncio.apply()
import lightning as pl, gradio as gr, numpy as np
import torch

print(f'✓ PyTorch: {torch.__version__} | Lightning: {pl.__version__} | Gradio: {gr.__version__}')
print(f'✓ NumPy: {np.__version__} {"(1.x ✓)" if np.__version__.startswith("1.") else "(⚠️ NOT 1.x)"} | CUDA: {torch.cuda.is_available()}')

if torch.cuda.is_available():
    print(f'✓ GPU: {torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB)')

# Key imports
modules = [
    ('torchvision', 'TorchVision'), ('torchaudio', 'TorchAudio'), ('transformers', 'Transformers'),
    ('torchmetrics', 'TorchMetrics'), ('datasets', 'Datasets')
]
for mod_name, name in modules:
    try:
        __import__(mod_name); print(f'✓ {name}: OK')
    except Exception as e: print(f'⚠️ {name}: {str(e)[:50]}...')

try: from sentence_transformers import SentenceTransformer; print('✓ Sentence-Transformers: OK')
except Exception as e: print(f'⚠️ Sentence-Transformers: {str(e)[:50]}...')

try: from lightning.pytorch import LightningModule; print('✓ Lightning/PyTorch Import: OK')
except Exception as e: print(f'⚠️ Lightning Import: {str(e)[:50]}...')

import subprocess
result = subprocess.run(['pip', 'check'], capture_output=True, text=True)
print('✅ No conflicts!' if 'no broken' in result.stdout.lower() else f'⚠️ Conflicts: {result.stdout[:150]}...')

print('\\n🎉 If all ✓, launch your app! For spaCy: !pip install \\"thinc<8.3\\"')
""")
print("\n🚀 Restart now to load clean NumPy 1.x & resolve dynamo/circular imports.")

## 📥 2. Download Model Files

We'll download two model checkpoints:
- **One-split (0.0-1.0)**: Full time range (~1.5GB)
- **Two-split (0.5-1.0)**: Second time range (~1.5GB)

**Total download: ~3GB. This will take 5-10 minutes.**

In [None]:
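# Create the checkpoint directory inside the repository.
# A kernel restart resets the working directory, so change back into the repo first.
%cd /kaggle/working/diffusion-audio-restoration-colab-Kaggle-
!mkdir -p ckpt
print("✓ Checkpoint directory created")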

In [None]:
import os

print("📥 Downloading model checkpoints...\n")
print("⏱️  This will take 5-10 minutes depending on your connection.\n")

# Model files (paths are relative to the repository root)
models = {
    'onesplit': {
        'path': 'ckpt/A2SB_onesplit_0.0_1.0_release.ckpt',
        'url': 'https://huggingface.co/nvidia/audio_to_audio_schrodinger_bridge/resolve/main/ckpt/A2SB_onesplit_0.0_1.0_release.ckpt'
    },
    'twosplit': {
        'path': 'ckpt/A2SB_twosplit_0.5_1.0_release.ckpt',
        'url': 'https://huggingface.co/nvidia/audio_to_audio_schrodinger_bridge/resolve/main/ckpt/A2SB_twosplit_0.5_1.0_release.ckpt'
    }
}

failed = []
for name, info in models.items():
    path, url = info['path'], info['url']
    # Treat an empty file (e.g. left behind by an earlier failed download) as missing
    if os.path.exists(path) and os.path.getsize(path) > 0:
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"✓ {name} model already exists ({size_mb:.2f} MB)")
        continue
    print(f"⬇️  Downloading {name} model (~1.5GB)...")
    !wget -q --show-progress -O {path} {url}
    # wget -O creates the output file even when the download fails, so check its size too
    if os.path.exists(path) and os.path.getsize(path) > 0:
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"✅ {name} model downloaded ({size_mb:.2f} MB)")
    else:
        if os.path.exists(path):
            os.remove(path)
        failed.append(name)
        print(f"❌ Failed to download {name} model! Check your internet connection and try again.")

print("\n" + "="*50)
print("✅ Model download complete!" if not failed else f"⚠️ Download failed for: {', '.join(failed)}")
print("="*50)

## ⚙️ 3. Configuration

Update the configuration file with the correct model paths.

In [None]:
import yaml

print("⚙️  Updating configuration...\n")

# Update config file
config_path = 'configs/ensemble_2split_sampling.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

config['model']['pretrained_checkpoints'] = [
    'ckpt/A2SB_onesplit_0.0_1.0_release.ckpt',
    'ckpt/A2SB_twosplit_0.5_1.0_release.ckpt'
]

with open(config_path, 'w') as f:
    yaml.dump(config, f)

print("✅ Configuration updated successfully!")
print(f"\nModel paths:")
for i, path in enumerate(config['model']['pretrained_checkpoints'], 1):
    print(f"  {i}. {path}")

## 🎨 4. Launch Gradio Web Interface

### 🚀 Ready to restore audio!

**Features:**
- 📤 Drag-and-drop file upload
- 🎤 Microphone recording
- ⚙️ Advanced settings (sampling steps, cutoff frequency)
- 📊 Real-time progress tracking
- 🔊 Instant playback and comparison
- 📈 Spectral analysis visualization

**How to use:**
1. **Run the cell below** - Wait for the Gradio link to appear
2. **Click the public link** (usually ends with `.gradio.live`)
3. **Upload audio** or record from microphone
4. **Choose mode:**
   - **Bandwidth Extension**: Restore high frequencies (for low-quality MP3s)
   - **Inpainting**: Fill in missing audio segments
5. **Adjust settings** (optional):
   - Sampling Steps: 25-100 (higher = better quality, slower)
   - Auto Cutoff: Automatically detect cutoff frequency
   - Inpainting Length: 0.1-1.0 seconds
6. **Click "🚀 Restore"** and wait for processing
7. **Listen & Download** the restored audio

**Tips:**
- Start with default settings (50 steps, auto cutoff)
- For faster results: 25-30 steps
- For best quality: 75-100 steps
- Processing time: ~2-3 minutes per 10 seconds of audio (on P100)

**⚠️ Important:**
- Keep this notebook tab open during processing
- Don't close the Kaggle session
- If you get "Out of Memory" error, reduce sampling steps or audio length
- Gradio will create a public URL that expires after 72 hours

In [None]:
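# Optional: ngrok auth token (only needed if the app uses ngrok; Gradio's own --share link does not).
# Never hardcode tokens in a shared notebook: store it as a Kaggle Secret named NGROK_TOKEN
# (Add-ons > Secrets) and read it at runtime instead.
import os
from kaggle_secrets import UserSecretsClient

try:
    os.environ['NGROK_TOKEN'] = UserSecretsClient().get_secret('NGROK_TOKEN')
    print("✓ NGROK_TOKEN loaded from Kaggle Secrets")
except Exception as e:
    print(f"⚠️ NGROK_TOKEN not set ({e}); continuing without ngrok")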

In [None]:
# Launch Gradio interface with public URL
print("🚀 Launching Gradio interface...\n")
print("⏱️  Wait for the public link (usually ending in .gradio.live) to appear below, then click it.")
print("   The link is valid for up to 72 hours, but only while this cell keeps running.\n")
print("="*60)

%cd /kaggle/working/diffusion-audio-restoration-colab-Kaggle-

# Use the Kaggle-specific Gradio app if it exists
import os
if os.path.exists('gradio_app_kaggle.py'):
    print("✓ Using Kaggle-optimized Gradio app")
    !python gradio_app_kaggle.py --share
else:
    print("⚠️ Using default Gradio app")
    !python gradio_app.py --share

# The commands above block until the server stops, so this only prints after shutdown
print("\n" + "="*60)
print("🛑 Gradio server stopped. Re-run this cell to relaunch it.")
print("="*60)

In [None]:
%%bash
# Optional: update an existing clone to the latest commit on main.
# `git reset --hard` discards local edits (including the Section 3 config change),
# so re-run Section 3 and relaunch the Gradio cell afterwards.
set -euo pipefail

REPO_URL="https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-.git"
APP_DIR="/kaggle/working/diffusion-audio-restoration-colab-Kaggle-"

if [ -d "$APP_DIR/.git" ]; then
  echo "✅ Repo found, updating..."
  cd "$APP_DIR"
  git fetch --all --prune
  git reset --hard origin/main
else
  echo "📥 Repo not found, cloning..."
  rm -rf "$APP_DIR"
  git clone --depth=1 "$REPO_URL" "$APP_DIR"
  cd "$APP_DIR"
fi

git config --global --add safe.directory "$APP_DIR"
find "$APP_DIR" -type d -name "__pycache__" -exec rm -rf {} + || true

echo "✅ Current commit:"
git --no-pager log --oneline -1

## 📚 5. Tips and Troubleshooting

### ⚡ Performance Optimization

**GPU Requirements:**
- ✅ **Kaggle P100**: 16GB VRAM (Recommended)
- ✅ **Kaggle T4 x2**: 2x 16GB VRAM (Excellent)
- ⚠️ **Kaggle T4**: 16GB VRAM (May work)

**Processing Times (on P100 GPU):**
- 10 seconds audio, 50 steps: ~2-3 minutes
- 30 seconds audio, 50 steps: ~5-7 minutes
- 60 seconds audio, 50 steps: ~10-15 minutes
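These timings can be turned into a rough planning helper. This is only a sketch: it linearly interpolates the three P100 data points above, scales proportionally outside them, and assumes runtime grows linearly with the number of sampling steps, which the table itself does not confirm. `estimate_minutes` is a hypothetical helper name.

```python
# Reference timings quoted above for a P100 at 50 sampling steps:
# (audio seconds, (low minutes, high minutes))
P100_TIMINGS_50_STEPS = [
    (10, (2.0, 3.0)),
    (30, (5.0, 7.0)),
    (60, (10.0, 15.0)),
]


def estimate_minutes(audio_seconds, steps=50):
    """Rough (low, high) processing-time estimate in minutes.

    Linearly interpolates the reference timings, extrapolates proportionally
    outside them, and ASSUMES time scales linearly with sampling steps.
    """
    points = P100_TIMINGS_50_STEPS
    if audio_seconds <= points[0][0]:
        lo, hi = (v * audio_seconds / points[0][0] for v in points[0][1])
    elif audio_seconds >= points[-1][0]:
        lo, hi = (v * audio_seconds / points[-1][0] for v in points[-1][1])
    else:
        for (x0, (lo0, hi0)), (x1, (lo1, hi1)) in zip(points, points[1:]):
            if x0 <= audio_seconds <= x1:
                frac = (audio_seconds - x0) / (x1 - x0)
                lo = lo0 + frac * (lo1 - lo0)
                hi = hi0 + frac * (hi1 - hi0)
                break
    scale = steps / 50
    return lo * scale, hi * scale


low, high = estimate_minutes(45, steps=75)
print(f"45 s of audio at 75 steps: roughly {low:.0f}-{high:.0f} minutes")
```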

### 🎯 Quality Settings

**Sampling Steps:**
- **25-30:** Fast (good quality)
- **50-75:** Balanced (excellent quality) ⭐ Recommended
- **75-100:** Best (outstanding quality)

**Cutoff Frequency (Bandwidth Extension):**
- **Auto-detect**: Usually best ⭐ Recommended
- **Manual adjustment:**
  - Low-quality MP3: 2000-4000 Hz
  - Medium quality: 4000-8000 Hz
  - High quality: 8000+ Hz
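Auto-detection is handled by the app. The heuristic below only illustrates one common way to estimate such a cutoff and is not necessarily what A2SB or the Gradio app does. It averages Hann-windowed magnitude spectra across frames and reports the highest frequency whose level stays within `threshold_db` of the spectral peak. `estimate_cutoff_hz`, the -60 dB default, and the 4096-sample frame size are assumptions.

```python
import numpy as np


def estimate_cutoff_hz(y, sr, threshold_db=-60.0, n_fft=4096):
    """Estimate the highest frequency (Hz) carrying meaningful energy.

    Hypothetical heuristic: average the magnitude spectrum over Hann-windowed
    frames, express it in dB relative to its peak, and return the frequency of
    the highest bin still above threshold_db (0.0 for silent input).
    """
    y = np.asarray(y, dtype=np.float64)
    if y.ndim > 1:
        y = y.mean(axis=0)  # downmix (channels, samples) to mono
    if len(y) < n_fft:
        y = np.pad(y, (0, n_fft - len(y)))  # ensure at least one full frame
    hop = n_fft // 2
    window = np.hanning(n_fft)
    n_frames = 1 + (len(y) - n_fft) // hop
    frames = np.stack([y[i * hop:i * hop + n_fft] * window for i in range(n_frames)])
    spectrum = np.abs(np.fft.rfft(frames, axis=1)).mean(axis=0)
    spectrum_db = 20 * np.log10(spectrum / (spectrum.max() + 1e-12) + 1e-12)
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
    above = np.nonzero(spectrum_db > threshold_db)[0]
    return float(freqs[above[-1]]) if above.size else 0.0


# Example: tones at 1 kHz and 3 kHz with no content above them
sr = 44100
t = np.arange(2 * sr) / sr
y = np.sin(2 * np.pi * 1000 * t) + np.sin(2 * np.pi * 3000 * t)
print(f"Estimated cutoff: {estimate_cutoff_hz(y, sr):.0f} Hz")
```

The estimate lands slightly above the highest tone because window leakage spreads some energy into neighbouring bins; a stricter threshold narrows that margin.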

**Inpainting Length:**
- 0.1-0.3s: Small gaps or clicks
- 0.3-0.5s: Medium gaps
- 0.5-1.0s: Large missing segments

### 🔧 Troubleshooting

#### ❌ CUDA Out of Memory Error

**Solutions:**
1. **Reduce sampling steps** to 25-30
2. **Split audio** into shorter segments (10-20 seconds)
3. **Restart kernel**: Kernel > Restart Kernel
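4. **Clear GPU memory**: Run the snippet below in a code cell. It only frees memory cached by the notebook kernel; memory held by the Gradio app (a separate `python` process) is released only when you stop that cell.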
5. **Enable P100 GPU** in Kaggle settings

```python
# Clear GPU memory
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
print("✅ GPU memory cleared")
```
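For step 2 of the list above (splitting audio into 10-20 second segments), a minimal sketch using `soundfile` (installed in Section 1) could look like this. `split_audio` and `split_file` are hypothetical helpers, and the 15-second default is an arbitrary choice within the suggested range. The restored segments have to be re-joined afterwards, and hard cuts may leave audible seams at the boundaries.

```python
def split_audio(y, sr, segment_seconds=15.0):
    """Split a (samples,) or (samples, channels) array into consecutive chunks.

    The last chunk may be shorter than segment_seconds.
    """
    seg_len = int(round(segment_seconds * sr))
    if seg_len <= 0:
        raise ValueError("segment_seconds * sr must be at least one sample")
    return [y[start:start + seg_len] for start in range(0, len(y), seg_len)]


def split_file(path, out_prefix, segment_seconds=15.0):
    """Write path as out_prefix_000.wav, out_prefix_001.wav, ... and return the new paths."""
    import soundfile as sf  # imported here so split_audio works without soundfile

    y, sr = sf.read(path)  # (samples,) for mono, (samples, channels) otherwise
    paths = []
    for i, chunk in enumerate(split_audio(y, sr, segment_seconds)):
        out_path = f"{out_prefix}_{i:03d}.wav"
        sf.write(out_path, chunk, sr)
        paths.append(out_path)
    return paths
```

For example, `split_file('input.wav', 'input_part')` would produce `input_part_000.wav`, `input_part_001.wav`, and so on.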

#### ❌ Model Not Found Error

**Solutions:**
1. Re-run the model download cells (Section 2)
2. Check your internet connection
3. Verify files exist:
```python
!ls -lh ckpt/
```
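A slightly more thorough check also catches empty or truncated files, which a failed `wget -O` download can leave behind. This is a sketch: `check_checkpoints` is a hypothetical helper, the paths are the ones used in Sections 2-3 (relative to the repository root), and the 100 MB minimum is an arbitrary sanity threshold well below the ~1.5GB expected size.

```python
import os

# Paths used by the download and configuration cells (relative to the repository root)
CHECKPOINTS = (
    'ckpt/A2SB_onesplit_0.0_1.0_release.ckpt',
    'ckpt/A2SB_twosplit_0.5_1.0_release.ckpt',
)


def check_checkpoints(paths=CHECKPOINTS, min_mb=100.0):
    """Return {path: status} where status is 'ok', 'missing', or 'too small'."""
    status = {}
    for path in paths:
        if not os.path.exists(path):
            status[path] = 'missing'
        elif os.path.getsize(path) < min_mb * 1024 * 1024:
            status[path] = 'too small'  # likely an empty or interrupted download
        else:
            status[path] = 'ok'
    return status


for path, state in check_checkpoints().items():
    print(f"{'✅' if state == 'ok' else '❌'} {path}: {state}")
```

Any file reported as `missing` or `too small` can be deleted and re-downloaded by re-running Section 2.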

#### ❌ Gradio Interface Not Loading

**Solutions:**
1. Wait 30-60 seconds for the link to appear
2. Check if the cell is still running
3. Restart kernel and run all cells again
4. Make sure you're using `--share` flag for public URL

#### ❌ Audio Format Error

**Solution:** Convert to WAV format
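```python
import librosa
import soundfile as sf

# Convert any audio librosa can decode to a 44.1kHz WAV.
# mono=False keeps all channels (remove it to downmix to mono); multi-channel
# audio comes back as (channels, samples), while soundfile expects (samples, channels).
y, sr = librosa.load('input.mp3', sr=44100, mono=False)
sf.write('input.wav', y.T, sr)  # .T is a no-op for 1-D mono arrays
```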

#### ⚠️ Session Timeout

**Solutions:**
1. Kaggle provides 12 hours of continuous runtime
2. Keep the tab active
3. Process shorter audio files
4. Save intermediate results to `/kaggle/working/`

### 💡 Best Practices

1. **Start small**: Test with 10-20 second clips first
2. **Use defaults**: 50 steps, auto cutoff works well
3. **Monitor GPU**: Check `nvidia-smi` if issues occur
4. **Save outputs**: Download restored audio immediately
5. **Batch processing**: Process multiple files sequentially, one at a time, to stay within GPU memory
6. **Use Kaggle Datasets**: Upload your audio files as a Kaggle dataset for easier access
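To act on "Monitor GPU" from Python, a small helper can poll `nvidia-smi`. Unlike `torch.cuda.memory_allocated()` in the notebook kernel, `nvidia-smi` reports memory used by every process on the GPU, including the Gradio app launched with `!python`. The query fields and `--format=csv,noheader,nounits` (memory in MiB) are standard `nvidia-smi` options; `parse_gpu_stats` and `gpu_stats` are hypothetical helper names.

```python
import shutil
import subprocess

QUERY = [
    "nvidia-smi",
    "--query-gpu=index,name,memory.used,memory.total",
    "--format=csv,noheader,nounits",
]


def parse_gpu_stats(csv_text):
    """Parse nvidia-smi CSV output (noheader, nounits) into a list of dicts."""
    stats = []
    for line in csv_text.strip().splitlines():
        index, name, used, total = (field.strip() for field in line.split(","))
        stats.append({"index": int(index), "name": name,
                      "used_mib": int(used), "total_mib": int(total)})
    return stats


def gpu_stats():
    """Query all GPUs; returns [] when nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_gpu_stats(result.stdout)


for gpu in gpu_stats():
    print(f"GPU {gpu['index']} {gpu['name']}: {gpu['used_mib']}/{gpu['total_mib']} MiB used")
```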

### 📁 Kaggle File System

**Important directories:**
- `/kaggle/input/`: Read-only input data (datasets)
- `/kaggle/working/`: Writable directory (your work)
- `/kaggle/temp/`: Temporary files

**Upload audio files:**
1. Create a Kaggle dataset with your audio files
2. Add the dataset to your notebook
3. Access files from `/kaggle/input/your-dataset-name/`
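Once the dataset is attached, a recursive search like the one below shows which audio files are available under `/kaggle/input/`. It is a sketch: `find_audio_files` is a hypothetical helper and the extension set is an assumption, so adjust it to the formats you upload.

```python
from pathlib import Path

AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}


def find_audio_files(root="/kaggle/input", extensions=AUDIO_EXTENSIONS):
    """Recursively list audio files under root, sorted by path ([] if root is missing)."""
    root = Path(root)
    if not root.exists():
        return []
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in extensions
    )


for path in find_audio_files():
    print(path)
```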

**Download results:**
- Restored audio files are saved in `/kaggle/working/`
- Click the "Output" tab to download files
- Or use the Gradio interface to download directly

### 📖 License and Usage

- **Model:** NVIDIA OneWay NonCommercial License
- **Code:** NVIDIA Source Code License - Non Commercial
- **Commercial Use:** Contact NVIDIA for licensing
- **Research Use:** Free for academic and research purposes

### 🔗 Additional Resources

- **Paper:** [arXiv:2501.11311](https://arxiv.org/abs/2501.11311)
- **GitHub:** [test4373/diffusion-audio-restoration-colab-Kaggle-](https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-.git)
- **Original NVIDIA Repo:** [NVIDIA/diffusion-audio-restoration](https://github.com/NVIDIA/diffusion-audio-restoration)
- **Demo:** [NVIDIA Research](https://research.nvidia.com/labs/adlr/A2SB/)
- **Models:** [HuggingFace](https://huggingface.co/nvidia/audio_to_audio_schrodinger_bridge)

### 📧 Support

- **Issues:** [GitHub Issues](https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-/issues)
- **Original NVIDIA Issues:** [NVIDIA GitHub](https://github.com/NVIDIA/diffusion-audio-restoration/issues)
- **Kaggle Community:** [Kaggle Forums](https://www.kaggle.com/discussions)

---

### 🎉 Thank You!

Thank you for using this notebook on Kaggle!

**Citation:**
```bibtex
@article{kong2025a2sb,
  title={A2SB: Audio-to-Audio Schrodinger Bridges},
  author={Kong, Zhifeng and Shih, Kevin J and Nie, Weili and Vahdat, Arash and Lee, Sang-gil and Santos, Joao Felipe and Jukic, Ante and Valle, Rafael and Catanzaro, Bryan},
  journal={arXiv preprint arXiv:2501.11311},
  year={2025}
}
```

### ⭐ Support This Project

If you find this project useful:
- ⭐ Star the [GitHub repository](https://github.com/test4373/diffusion-audio-restoration-colab-Kaggle-.git)
- 🐛 Report bugs or suggest features
- 📢 Share with others who might benefit
- 👍 Upvote this Kaggle notebook

---

**Made with ❤️ for the audio restoration community**

**Optimized for Kaggle with GPU memory management and user-friendly interface**

### 🆓 Why Kaggle?

Kaggle offers several advantages:
- **Free GPU access**: 30 hours per week
- **Powerful hardware**: P100 GPU with 16GB VRAM
- **Generous resources**: 30GB RAM, 73GB disk
- **Persistent storage**: Save your work and models
- **Community**: Share and collaborate with others
- **No subscription**: Completely free!

Perfect for audio restoration projects! 🎵