<center>
  <h1>salsa-milk - Isolate Vocals from Media</h1>
</center>

This notebook allows you to extract vocals from audio and video files using the Demucs AI model.

## How it Works

salsa-milk uses Meta (Facebook) AI's [Demucs](https://github.com/facebookresearch/demucs) to separate vocals from music and background noise. The process runs on Google Colab and uses a GPU when one is available for faster processing (it falls back to the CPU otherwise).

## How to Use

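1. **For YouTube Videos**: Because of YouTube's restrictions, the built-in **YouTube URL** option may fail. For reliable results, download the audio on your own machine first:
   - Install yt-dlp: `pip install yt-dlp` (extracting audio with `-x` also requires FFmpeg to be installed locally)
   - Download: `yt-dlp -x --audio-format wav "YOUR_YOUTUBE_URL"`
   - Upload the downloaded file(s) to this notebook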

2. **For Local Files**: Upload your audio or video files directly (choose **Upload File** as the input source)

3. **Batch Processing**: To process multiple files at once, place them in the `input` directory and choose **Scan Input Directory**

4. Run the processing cell, then download your extracted vocals

Let's get started!

In [None]:
#@title ## Extract Vocals from Media Files

# First, ensure setup is complete
def ensure_setup_complete():
    """Create the working folders and install any missing runtime dependencies.

    Side effects (relative to the current working directory / runtime):
      * creates ``input/``, ``output/`` and ``tmp/``;
      * upgrades ``yt-dlp`` via pip on every call;
      * pip-installs ``demucs``, ``torch`` and ``tqdm`` when any of them
        cannot be found by the import system;
      * apt-installs FFmpeg when ``ffmpeg -version`` cannot be run
        successfully.

    Installation problems are reported as warnings rather than raised, so
    the rest of the notebook can still attempt to run.

    Returns:
        True once every check has been performed.
    """
    import importlib
    import importlib.util
    import os
    import subprocess
    import sys

    print("Checking setup requirements...")

    # Create necessary directories
    for folder in ('input', 'output', 'tmp'):
        os.makedirs(folder, exist_ok=True)

    # Force update yt-dlp to latest version (crucial for YouTube downloads)
    print("Updating yt-dlp to latest version...")
    upgrade = subprocess.run([sys.executable, "-m", "pip", "install", "-U", "yt-dlp"],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if upgrade.returncode != 0:
        print("⚠️ Could not update yt-dlp; continuing with the installed version")

    # find_spec() checks that a package is importable without actually
    # importing it (torch is slow to import). tqdm is checked too because
    # the processing cell imports it directly.
    print("Checking for demucs...")
    missing = [name for name in ("torch", "demucs", "tqdm")
               if importlib.util.find_spec(name) is None]
    if not missing:
        print("✅ Core dependencies already installed")
    else:
        print("🔄 Installing dependencies...")
        install = subprocess.run([sys.executable, "-m", "pip", "install", "demucs", "torch", "tqdm"])
        if install.returncode != 0:
            print(f"⚠️ Dependency installation failed (missing: {', '.join(missing)})")
        # Packages installed after interpreter start-up can stay invisible to
        # `import` until the import system's finder caches are refreshed.
        importlib.invalidate_caches()

    # check=True makes a present-but-failing ffmpeg raise CalledProcessError,
    # so the reinstall branch below is reachable for that case too.
    try:
        subprocess.run(["ffmpeg", "-version"], stdout=subprocess.PIPE,
                       stderr=subprocess.PIPE, check=True)
        print("✅ FFmpeg already installed")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("🔄 Installing FFmpeg...")
        subprocess.run(["apt-get", "update", "-qq"])
        apt_install = subprocess.run(["apt-get", "install", "-y", "-qq", "ffmpeg"])
        if apt_install.returncode != 0:
            print("⚠️ FFmpeg installation failed")

    print("✅ Setup complete!")
    return True

# Helper functions for displaying and downloading results
def create_download_button(file_path, button_text=None):
    """Return an HTML download button that embeds a file as a base64 data URI.

    Args:
        file_path: Path of the file to embed.
        button_text: Plain-text label for the button. Defaults to
            ``"Download <basename>"``.

    Returns:
        An ``IPython.display.HTML`` object; pass it to ``display``.

    The whole file is read into memory and base64-encoded into the notebook
    output (about 4/3 of the file size), so large files make the output heavy.
    """
    from IPython.display import HTML
    import base64
    import html
    import os

    file_name = os.path.basename(file_path)
    if button_text is None:
        button_text = f"Download {file_name}"

    with open(file_path, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode('ascii')

    # File names come from user uploads / URLs. Escape them (and the label)
    # so '&', quotes or '<' cannot break out of the attribute or inject markup.
    safe_name = html.escape(file_name, quote=True)
    safe_text = html.escape(button_text, quote=True)

    # Create a download link
    button_html = f'''
    <a href="data:application/octet-stream;base64,{b64}" download="{safe_name}">
        <button style="font-size: 14px; padding: 5px 15px; background-color: #4CAF50; color: white;
                 border: none; border-radius: 4px; cursor: pointer;">
            {safe_text}
        </button>
    </a>
    '''

    return HTML(button_html)

def display_video(file_path, max_inline_bytes=50 * 1024 * 1024):
    """Return an HTML5 video player with the file inlined as base64.

    Files larger than ``max_inline_bytes`` are not embedded, because the
    base64 payload would bloat the notebook output. A short notice pointing
    to the download button is returned instead.

    Args:
        file_path: Path (str or path-like) of the video file.
        max_inline_bytes: Largest file size, in bytes, that is still embedded
            (default 50 MB).

    Returns:
        An ``IPython.display.HTML`` object; pass it to ``display``.
    """
    from IPython.display import HTML
    import base64
    from pathlib import Path

    path = Path(file_path)

    # For very large files, show a text notice instead of embedding the video
    file_size = path.stat().st_size
    if file_size > max_inline_bytes:
        return HTML(f'''
        <div style="background-color: #f8f9fa; padding: 12px; border-radius: 4px; text-align: center;">
            <p>Video file is large ({file_size/1024/1024:.1f} MB). Please use the download button below.</p>
        </div>
        ''')

    # Compare the real extension case-insensitively (".MP4" is still MP4);
    # every other extension keeps the original WebM fallback.
    video_type = "video/mp4" if path.suffix.lower() == ".mp4" else "video/webm"
    b64 = base64.b64encode(path.read_bytes()).decode('ascii')

    return HTML(f'''
    <video width="640" height="360" controls>
        <source src="data:{video_type};base64,{b64}" type="{video_type}">
        Your browser does not support the video tag.
    </video>
    ''')

# Run setup: create input/, output/ and tmp/ and install any missing
# dependencies *before* the imports below, because torch and tqdm may only
# have been installed by this call.
ensure_setup_complete()

# Import necessary libraries
import glob
import os
import shutil
import subprocess

import torch
from IPython.display import display, HTML, Audio
from google.colab import files
from tqdm.notebook import tqdm

#@markdown ### Select Input Source
input_source = "Scan Input Directory" #@param ["Scan Input Directory", "Upload File", "YouTube URL"]

# Initialize variables
input_files = []
video_id = None

# Recognised media extensions; matched case-insensitively (".MP3" counts).
media_extensions = ['.mp3', '.wav', '.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4a', '.flac', '.ogg']


def _save_uploads(uploaded):
    """Write files returned by ``files.upload()`` into ``input/``.

    Args:
        uploaded: Mapping of uploaded file name to its content bytes.

    Returns:
        List of the written paths, in upload order.
    """
    saved_paths = []
    for filename, content in uploaded.items():
        # basename() keeps a crafted name from escaping the input/ folder.
        file_path = os.path.join('input', os.path.basename(filename))
        # Always write: skipping existing paths would silently keep a stale
        # copy when a file with the same name is re-uploaded with new content.
        with open(file_path, 'wb') as f:
            f.write(content)
        saved_paths.append(file_path)
    return saved_paths


if input_source == "Scan Input Directory":
    #@markdown Files in the input directory will be automatically processed

    # Scan for media files in the input directory
    input_files = sorted(
        path for path in glob.glob("input/*")
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in media_extensions
    )

    if not input_files:
        print("No media files found in the input directory.")

        # Allow uploading if no files found
        print("Would you like to upload files now?")
        upload_now = True #@param {type:"boolean"}

        if upload_now:
            print("Please upload audio or video files:")
            input_files.extend(_save_uploads(files.upload()))

    # Show files that will be processed
    print(f"Found {len(input_files)} files to process:")
    for file in input_files:
        print(f"  - {os.path.basename(file)}")

elif input_source == "Upload File":
    #@markdown Upload audio or video files to process:
    print("Please upload audio or video files:")
    uploaded = files.upload()

    if not uploaded:
        raise ValueError("No files were uploaded")

    # Save the uploaded files to the input directory
    input_files = _save_uploads(uploaded)

    print(f"Uploaded {len(input_files)} files")

elif input_source == "YouTube URL":
    #@markdown Enter YouTube URL:
    youtube_url = "" #@param {type:"string"}

    if not youtube_url.strip():
        raise ValueError("Please enter a YouTube URL")

    print("⚠️ Note: YouTube downloads may fail due to restrictions.")
    print("For more reliable results, download the video locally and upload it.")
    print("See the instructions at the top of the notebook.")

    print(f"\nAttempting to download from YouTube: {youtube_url}")

    # Extract video ID for naming the downloaded file
    if "youtube.com/watch?v=" in youtube_url:
        video_id = youtube_url.split("youtube.com/watch?v=")[1].split("&")[0]
    elif "youtu.be/" in youtube_url:
        video_id = youtube_url.split("youtu.be/")[1].split("?")[0]
    elif "youtube.com/shorts/" in youtube_url:
        video_id = youtube_url.split("youtube.com/shorts/")[1].split("?")[0]
    else:
        video_id = "youtube_video"  # Fallback name
    # A URL such as "https://youtu.be/" would otherwise produce "input/.wav".
    video_id = video_id.strip() or "youtube_video"

    # Try downloading with yt-dlp
    input_file_path = f"input/{video_id}.wav"
    try:
        cmd = [
            "yt-dlp",
            "--extract-audio",
            "--audio-format", "wav",
            "--output", f"input/{video_id}.%(ext)s",
            # NOTE(review): this disables TLS certificate verification; confirm
            # it is still needed before keeping it.
            "--no-check-certificate",
            "--geo-bypass",
            youtube_url
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode == 0 and os.path.exists(input_file_path):
            print("✅ Successfully downloaded from YouTube")
            input_files = [input_file_path]
        else:
            print("❌ YouTube download failed with error:")
            print(result.stderr)
            print("\nPlease download the video locally and upload it instead.")
            input_files = []
    except Exception as e:
        print(f"❌ Failed to download from YouTube: {str(e)}")
        input_files = []

#@markdown ### Demucs Options
model = "htdemucs" #@param ["htdemucs", "htdemucs_ft", "htdemucs_6s"]
device = "cuda" if torch.cuda.is_available() else "cpu" #@param ["cpu", "cuda"]

#@markdown ### Output Options
create_video = True #@param {type:"boolean"}
output_format = "mp4" #@param ["mp4", "wav"]

# Containers whose picture track is reused when building an MP4 output.
VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")
# Only the first few results get an inline player, to keep the notebook tidy.
MAX_PREVIEWS = 3

if not input_files:
    print("No files to process. Please upload files or provide a valid YouTube URL.")
else:
    # Process all files in batch
    results = []

    for file_path in tqdm(input_files, desc="Processing files"):
        print(f"\nProcessing {os.path.basename(file_path)}...")

        # File name without extension: used to find Demucs' output folder
        # and to name our output file.
        file_id = os.path.splitext(os.path.basename(file_path))[0]

        # Check if it's a video file
        has_video = file_path.lower().endswith(VIDEO_EXTENSIONS)

        # Run Demucs, splitting into vocals and everything else
        demucs_cmd = [
            "demucs",
            "--two-stems", "vocals",
            "-n", model,
            "--device", device,
            "-o", "tmp",
            file_path
        ]

        try:
            subprocess.run(demucs_cmd, check=True)

            # Demucs writes under <out>/<model name>/<track>/, so look in the
            # folder of the model that was just run. A hard-coded "htdemucs"
            # path would pick up stale vocals from an earlier htdemucs run
            # when a different model is selected.
            vocals_path = os.path.join("tmp", model, file_id, "vocals.wav")

            # Check alternate paths if default not found
            if not os.path.exists(vocals_path):
                # glob.escape: names containing '[', '*' or '?' must match literally.
                potential_paths = glob.glob(
                    os.path.join("tmp", "*", glob.escape(file_id), "vocals.wav"))
                if potential_paths:
                    # Prefer the most recently written match over an arbitrary leftover.
                    vocals_path = max(potential_paths, key=os.path.getmtime)
                    print(f"Found vocals at alternate path: {vocals_path}")
                else:
                    print(f"⚠️ Could not find extracted vocals for {file_id}, skipping...")
                    continue

            # Create output file
            if create_video and output_format == "mp4" and has_video:
                output_path = f"output/{file_id}_vocals.mp4"

                # Keep the original picture, replace the audio with the vocals
                print(f"Creating video for {file_id} (replacing original audio)...")
                ffmpeg_cmd = [
                    "ffmpeg", "-y",
                    "-i", file_path,
                    "-i", vocals_path,
                    "-c:v", "copy",
                    "-c:a", "aac",
                    "-b:a", "192k",
                    "-map", "0:v:0",
                    "-map", "1:a:0",
                    "-shortest",
                    output_path
                ]

                subprocess.run(ffmpeg_cmd, check=True)
            elif create_video and output_format == "mp4":
                # Audio-only input: pair the vocals with a black video track
                output_path = f"output/{file_id}_vocals.mp4"

                print(f"Creating video for {file_id} (with black background)...")
                ffmpeg_cmd = [
                    "ffmpeg", "-y",
                    "-f", "lavfi",
                    "-i", "color=c=black:s=1280x720:r=30",
                    "-i", vocals_path,
                    "-c:v", "libx264",
                    # yuv420p keeps the H.264 stream playable in browsers,
                    # including the inline preview below.
                    "-pix_fmt", "yuv420p",
                    "-c:a", "aac",
                    "-b:a", "192k",
                    "-shortest",
                    output_path
                ]

                subprocess.run(ffmpeg_cmd, check=True)
            else:
                # Just copy the audio
                output_path = f"output/{file_id}_vocals.wav"
                print(f"Creating audio file for {file_id}...")
                shutil.copy(vocals_path, output_path)

            # Add to results
            results.append({
                "input": file_path,
                "output": output_path,
                "id": file_id
            })

            print(f"✅ Successfully processed {file_id}")

        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print(f"❌ Error processing {file_id}: {str(e)}")

    # Display results
    if results:
        print(f"\nProcessed {len(results)} files successfully:")
        for result in results:
            print(f"  - {os.path.basename(result['output'])}")

        # Create a summary of all the files
        print("\nResults:")
        for i, result in enumerate(results):
            output_file = result["output"]
            print(f"\n{i+1}. {os.path.basename(output_file)}:")

            # Create download button for each file
            display(create_download_button(output_file))

            # Show preview of first few files only to avoid cluttering the notebook
            if i < MAX_PREVIEWS:
                if output_file.endswith('.mp4'):
                    display(display_video(output_file))
                else:
                    display(Audio(output_file))

        # Provide option to download all files
        print("\nYou can download all files at once using the cell below:")
    else:
        print("No files were successfully processed. Please check the errors above and try again.")

In [None]:
#@title ## Download All Processed Files

import glob
import os
from google.colab import files

# Skip anything that is not a regular file (e.g. a sub-directory), and sort
# so the listing and download order is the same on every run; glob returns
# matches in arbitrary order.
output_files = sorted(path for path in glob.glob("output/*") if os.path.isfile(path))

if not output_files:
    print("No processed files found. Please run the processing cell first.")
else:
    print(f"Found {len(output_files)} files:")
    for file_path in output_files:
        print(f"  - {os.path.basename(file_path)}")

    # One files.download() call per output file
    print("\nDownloading files... (this may take a while for large files)")
    for file_path in output_files:
        files.download(file_path)

In [None]:
#@title ## Cleanup Temporary Files (Optional)

import shutil
import os

#@markdown Select which files to delete:
delete_input = False #@param {type:"boolean"}
delete_output = False #@param {type:"boolean"}
delete_temp = True #@param {type:"boolean"}

# (selected?, folder, message). Each selected folder is wiped and then
# recreated empty, so later runs can keep writing into it.
cleanup_plan = [
    (delete_input, "input", "Deleted input files"),
    (delete_output, "output", "Deleted output files"),
    (delete_temp, "tmp", "Deleted temporary files"),
]

for selected, folder, message in cleanup_plan:
    if not selected:
        continue
    shutil.rmtree(folder, ignore_errors=True)
    os.makedirs(folder, exist_ok=True)
    print(message)

print("Cleanup complete!")