<center>
  <h1>salsa-milk - remove music from media</h1>
</center>

This notebook extracts vocals from audio and video files using the Demucs AI model.

## Instructions

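1. Upload your audio or video files using the file browser in the left sidebar.
2. Optionally, paste one or more YouTube URLs (separated by spaces or new lines) into the `youtube_urls` field.
3. Run the cell below to process all files. If no media files are found, you will be prompted to upload some.
4. Download the processed files, which contain only the isolated vocals.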

**Note**: Large files will take longer to process. GPU acceleration will be used automatically if available.

In [None]:
#@title ## Extract Vocals from Media Files

import os
import sys
import subprocess
import glob
import shutil
import time
import re
from tqdm.notebook import tqdm
from IPython.display import display, HTML, Audio
from google.colab import files

# output/ receives the final vocals-only files; tmp/ holds Demucs stems.
for _workdir in ("output", "tmp"):
    os.makedirs(_workdir, exist_ok=True)
print("✓ Created working directories")

# Function to install and check dependencies with progress updates
def setup_environment():
    """Ensure FFmpeg and the Python packages used by this notebook are installed.

    FFmpeg is installed via ``apt-get`` only when ``ffmpeg -version`` cannot
    be run successfully. Each Python dependency is then installed with pip:
    entries marked ``"latest"`` are upgraded (``pip install -U``); any other
    value is treated as a minimum version (``pip install "pkg>=version"``).
    pip failures are reported but do not abort setup.

    Returns:
        True once setup has finished.

    Raises:
        subprocess.CalledProcessError: if the apt-get FFmpeg install fails.
    """
    print("Setting up environment...")

    major, minor, micro = sys.version_info[:3]
    print(f"Using Python {major}.{minor}.{micro}")

    # check=True turns a non-zero exit from an existing-but-broken ffmpeg into
    # CalledProcessError; without it that handler could never be reached and
    # a broken binary would be reported as installed.
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
        print("✅ FFmpeg already installed")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("🔄 Installing FFmpeg...")
        subprocess.run(["apt-get", "update", "-qq"], check=True)
        subprocess.run(["apt-get", "install", "-y", "-qq", "ffmpeg"], check=True)
        print("✅ FFmpeg installed successfully")

    # "latest" means always upgrade; any other value is a minimum version.
    dependencies = {
        "yt-dlp": "latest",  # Always get the latest version
        "demucs": "4.0.1",
        "torch": "2.6.0",
        "tqdm": "4.67.1",
        "ffmpeg-python": "0.2.0"
    }

    for package, version in dependencies.items():
        if version == "latest":
            print(f"🔄 Updating {package} to latest version...")
            pip_args = ["install", "-U", package]
        else:
            # Report the constraint actually passed to pip (a minimum, not a pin).
            print(f"🔄 Installing {package}>={version}...")
            pip_args = ["install", f"{package}>={version}"]

        result = subprocess.run(
            [sys.executable, "-m", "pip", *pip_args],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if result.returncode == 0:
            print(f"✅ {package} installed/updated successfully")
        else:
            print(f"⚠️ {package} installation had issues: {result.stderr[:100]}...")

    print("✅ Environment setup complete!")
    return True

# Function to create a download button
def create_download_button(file_path, button_text=None):
    """Return an HTML download button for *file_path* with its bytes inlined.

    The whole file is read and embedded as a base64 ``data:`` URI, so the
    link needs no file server. The notebook output therefore grows by about
    4/3 of the file size.

    Args:
        file_path: Path of the file to offer for download.
        button_text: Button label; defaults to ``"Download <basename>"``.
            It is HTML-escaped, so pass plain text, not markup.

    Returns:
        An ``IPython.display.HTML`` object wrapping the link markup.
    """
    import base64
    import html

    file_name = os.path.basename(file_path)
    if button_text is None:
        button_text = f"Download {file_name}"

    with open(file_path, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode()

    # File names come from user uploads and URLs; escape them so a quote or
    # angle bracket cannot break out of the attribute or inject markup.
    safe_name = html.escape(file_name, quote=True)
    safe_text = html.escape(button_text)

    button_html = f'''
    <a href="data:application/octet-stream;base64,{b64}" download="{safe_name}">
        <button style="font-size: 14px; padding: 5px 15px; background-color: #4CAF50; color: white; 
                 border: none; border-radius: 4px; cursor: pointer;">
            {safe_text}
        </button>
    </a>
    '''

    return HTML(button_html)

# Function to display video preview
def display_video(file_path):
    """Build an inline HTML5 video player for *file_path*.

    Files larger than 50 MB are not embedded; instead a notice pointing the
    user to the download button is returned. Smaller files are base64-encoded
    into a ``data:`` URI. The MIME type is ``video/mp4`` when the path ends
    in ``"mp4"`` and ``video/webm`` otherwise.

    Returns:
        An ``IPython.display.HTML`` object.
    """
    import base64
    from pathlib import Path

    size_bytes = Path(file_path).stat().st_size
    size_limit = 50 * 1024 * 1024  # 50 MB
    if size_bytes > size_limit:
        notice = f'''
        <div style="background-color: #f8f9fa; padding: 12px; border-radius: 4px; text-align: center;">
            <p>Video file is large ({size_bytes/1024/1024:.1f} MB). Please use the download button below.</p>
        </div>
        '''
        return HTML(notice)

    if file_path.endswith("mp4"):
        mime = "video/mp4"
    else:
        mime = "video/webm"

    with open(file_path, 'rb') as fh:
        encoded = base64.b64encode(fh.read()).decode()

    player = f'''
    <video width="640" height="360" controls>
        <source src="data:{mime};base64,{encoded}" type="{mime}">
        Your browser does not support the video tag.
    </video>
    '''
    return HTML(player)

# Function to download from YouTube
# YouTube video IDs are 11 characters drawn from [A-Za-z0-9_-]. Matching them
# explicitly, instead of splitting on "v=" / "youtu.be/", keeps URL fragments,
# extra path segments and stray query text out of the output file name. It
# also recognises /shorts/, /embed/ and /live/ links and "&v=" query order.
_YOUTUBE_ID_RE = re.compile(
    r"(?:[?&]v=|youtu\.be/|/shorts/|/embed/|/live/)([A-Za-z0-9_-]{11})"
)


def _youtube_video_id(url, index):
    """Return the YouTube ID in *url*, or a unique timestamp-based placeholder."""
    match = _YOUTUBE_ID_RE.search(url)
    if match:
        return match.group(1)
    # Appending the URL's position keeps placeholders distinct when several
    # unrecognised URLs are handled within the same second.
    return f"yt_{int(time.time())}_{index}"


def download_from_youtube(urls):
    """Download each YouTube URL in *urls* to ``<video_id>.mp4`` in the cwd.

    Args:
        urls: One or more URLs separated by any whitespace (spaces, new
            lines). An empty or whitespace-only string downloads nothing.

    Returns:
        List of downloaded file paths, in input order and without duplicates.
        URLs that fail to download are reported and skipped.
    """
    video_files = []

    # str.split() with no argument splits on whitespace runs and drops empties.
    for index, url in enumerate(urls.split()):
        print(f"Downloading from YouTube: {url}")

        video_id = _youtube_video_id(url, index)
        video_output = f"{video_id}.mp4"
        video_cmd = [
            "yt-dlp",
            "-f", "best",  # best single format that includes video
            "--output", video_output,
            "--no-check-certificate",
            "--geo-bypass",
            url,
        ]

        try:
            video_result = subprocess.run(video_cmd, capture_output=True, text=True)
        except (OSError, subprocess.SubprocessError) as e:
            # e.g. the yt-dlp executable is not on PATH
            print(f"❌ Error downloading {url}: {str(e)}")
            continue

        if video_result.returncode == 0 and os.path.exists(video_output):
            print(f"✅ Successfully downloaded {video_id}")
            # The same URL pasted twice maps to the same file; list it once.
            if video_output not in video_files:
                video_files.append(video_output)
        else:
            print(f"❌ Failed to download {url}, skipping...")
            print(f"Error: {video_result.stderr[:200]}...")

    return video_files

# Main process function
# Inputs with these extensions keep their picture; only the audio is replaced.
_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")

# ffmpeg encoder per audio output extension. The vocals stem is a .wav file.
# Stream-copying it is only valid for a .wav output: containers such as
# FLAC, MP3 or M4A need matching encoded audio (e.g. ffmpeg's FLAC muxer
# rejects copied PCM). Extensions not listed here fall back to "copy".
_AUDIO_ENCODERS = {
    "mp3": "libmp3lame",
    "aac": "aac",
    "m4a": "aac",
    "ogg": "libopus",
    "opus": "libopus",
    "flac": "flac",
}


def _locate_vocals(file_id):
    """Return the path of the vocals stem Demucs wrote for *file_id*, or None."""
    vocals_path = f"tmp/htdemucs/{file_id}/vocals.wav"
    if os.path.exists(vocals_path):
        return vocals_path
    # Fall back to any model sub-directory. Escape the ID so characters such
    # as '[' or '*' in a file name are matched literally, not as patterns.
    potential_paths = glob.glob(f"tmp/*/{glob.escape(file_id)}/vocals.wav")
    if potential_paths:
        print(f"Found vocals at alternate path: {potential_paths[0]}")
        return potential_paths[0]
    return None


def _build_ffmpeg_cmd(file_path, vocals_path, file_id, has_video):
    """Return ``(ffmpeg_cmd, output_path)`` that writes the vocals-only result."""
    if has_video:
        # Video files always output as MP4: copy the original video stream
        # and replace its audio with the vocals stem.
        output_path = f"output/{file_id}_vocals.mp4"
        print("Creating video with isolated vocals...")
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-i", file_path,
            "-i", vocals_path,
            "-c:v", "copy",
            "-c:a", "aac",
            "-b:a", "192k",
            "-map", "0:v:0",
            "-map", "1:a:0",
            "-shortest",
            output_path,
        ]
        return ffmpeg_cmd, output_path

    # Audio files keep their original extension (wav if there is none).
    original_ext = os.path.splitext(file_path)[1][1:].lower()
    output_ext = original_ext or "wav"
    output_path = f"output/{file_id}_vocals.{output_ext}"
    print("Creating audio file with isolated vocals...")

    codec = _AUDIO_ENCODERS.get(output_ext, "copy")
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-i", vocals_path,
        "-c:a", codec,
        "-b:a", "192k",
        output_path,
    ]
    return ffmpeg_cmd, output_path


def process_files(input_files):
    """Extract the vocals from each media file with Demucs.

    Each input is separated with the ``htdemucs`` model (two stems, vocals
    kept), writing stems under ``tmp/``. Video inputs are then remuxed to
    ``output/<id>_vocals.mp4`` with the vocals as the audio track. Audio
    inputs are written to ``output/<id>_vocals.<original ext>``.

    Args:
        input_files: Paths of the media files to process.

    Returns:
        A list of ``{"input", "output", "id"}`` dicts, one per file that was
        processed successfully. Failures are reported and skipped.
    """
    results = []

    # Imported here because torch is installed by setup_environment().
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    for file_path in tqdm(input_files, desc="Processing files"):
        print(f"\nProcessing {os.path.basename(file_path)}...")

        file_id = os.path.splitext(os.path.basename(file_path))[0]
        has_video = file_path.lower().endswith(_VIDEO_EXTENSIONS)

        demucs_cmd = [
            "demucs",
            "--two-stems", "vocals",
            "-n", "htdemucs",  # Always use htdemucs model
            "--device", device,
            "-o", "tmp",
            file_path,
        ]

        # Deliberately broad: one bad file should be reported, not abort the batch.
        try:
            subprocess.run(demucs_cmd, check=True)

            vocals_path = _locate_vocals(file_id)
            if vocals_path is None:
                print(f"⚠️ Could not find extracted vocals for {file_id}, skipping...")
                continue

            ffmpeg_cmd, output_path = _build_ffmpeg_cmd(
                file_path, vocals_path, file_id, has_video
            )
            subprocess.run(ffmpeg_cmd, check=True)

            results.append({
                "input": file_path,
                "output": output_path,
                "id": file_id,
            })

            print(f"✅ Successfully processed {file_id}")

        except Exception as e:
            print(f"❌ Error processing {file_id}: {str(e)}")

    return results

# Function to clean up temporary files
def cleanup_temp_files():
    """Empty the ``tmp`` scratch directory, recreating it afterwards.

    Nothing happens (and nothing is printed) when ``tmp`` does not exist.
    Removal errors are ignored.
    """
    scratch_dir = "tmp"
    if not os.path.exists(scratch_dir):
        return
    shutil.rmtree(scratch_dir, ignore_errors=True)
    os.makedirs(scratch_dir, exist_ok=True)
    print("✓ Cleaned up temporary files")

# Main execution

# Setup the environment
setup_environment()

#@markdown ### PRESS THE ▶️ ON THE LEFT TO START
#@markdown (Optional) Paste one or more YouTube URLs, separated by spaces or new lines
youtube_urls = "" #@param {type:"string"}

media_extensions = ['.mp3', '.wav', '.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4a', '.flac', '.ogg']
# str.endswith accepts a tuple; names are lower-cased before matching, so the
# directory scan and the upload filter are both case-insensitive.
media_suffixes = tuple(media_extensions)

# Collect files to process
files_to_process = []

# First, download any YouTube videos
if youtube_urls.strip():
    files_to_process.extend(download_from_youtube(youtube_urls))

# Now scan the working directory for media files (non-hidden regular files).
files_to_process.extend(
    name
    for name in sorted(os.listdir("."))
    if not name.startswith(".")
    and os.path.isfile(name)
    and name.lower().endswith(media_suffixes)
)

# YouTube downloads are saved in the working directory, so the scan above
# finds them a second time. Drop repeats (keeping first-seen order) so no
# file is separated twice.
files_to_process = list(dict.fromkeys(files_to_process))

# If no files found, prompt for upload
if not files_to_process:
    print("No media files found. Please upload some files:")
    uploaded = files.upload()
    files_to_process.extend(
        filename for filename in uploaded
        if filename.lower().endswith(media_suffixes)
    )

try:
    if files_to_process:
        print(f"Found {len(files_to_process)} files to process:")
        for file in files_to_process:
            print(f"  - {os.path.basename(file)}")

        results = process_files(files_to_process)

        if results:
            print(f"\nProcessed {len(results)} files successfully:")

            for i, result in enumerate(results):
                output_file = result["output"]
                print(f"\n{i+1}. {os.path.basename(output_file)}:")

                display(create_download_button(output_file))

                # Show preview for first 3 files: audio player for audio
                # outputs, video player for everything else.
                if i < 3:
                    if output_file.lower().endswith((".mp3", ".wav", ".ogg", ".flac", ".m4a")):
                        display(Audio(output_file))
                    else:
                        display(display_video(output_file))

            # Download all files option
            print("\nTo download all files:")
            for result in results:
                files.download(result["output"])
        else:
            print("No files were successfully processed. Please check the errors above.")
    else:
        print("No media files to process. Please upload files or provide YouTube URLs.")
finally:
    # Clean up even when processing raises, so stale stems don't pile up in tmp/.
    cleanup_temp_files()