# In [None]:
# Google Colab Whisper API Server - GPU Optimized Version
# This version properly utilizes T4 GPU and handles tensor errors

import warnings
warnings.filterwarnings('ignore')

import os
import sys
import logging

# Raise the root logger and werkzeug's logger to CRITICAL so routine log
# records do not clutter the notebook output.
for _quiet_name in (None, 'werkzeug'):
    logging.getLogger(_quiet_name).setLevel(logging.CRITICAL)
# NOTE(review): presumably set to suppress werkzeug's startup banner - confirm.
os.environ.update(WERKZEUG_RUN_MAIN='true')

# Probe for a CUDA device before anything is installed. DEVICE ends up as
# "cuda" or "cpu"; a missing PyTorch is treated as "cpu" for now.
print("🔍 Checking GPU availability...")
try:
    import torch
    if not torch.cuda.is_available():
        print("⚠️ No GPU detected, using CPU")
        DEVICE = "cpu"
    else:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"✅ GPU detected: {gpu_name}")
        print(f"🎯 CUDA version: {torch.version.cuda}")
        _total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"💾 GPU memory: {_total_gib:.1f} GB")
        DEVICE = "cuda"
except ImportError:
    print("📦 PyTorch not found, will install...")
    DEVICE = "cpu"

# Install packages with GPU support
print("📦 Installing required packages with GPU support...")
try:
    import subprocess

    # Install PyTorch with CUDA support first.
    # Only DEVICE is consulted here: DEVICE is "cuda" exactly when the earlier
    # probe imported torch and found a GPU. Touching `torch` directly would
    # raise NameError whenever that import failed, and the handler below would
    # then abort the whole setup before anything got installed.
    if DEVICE == "cuda":
        print("🚀 Installing PyTorch with CUDA support...")
        result = subprocess.run([
            sys.executable, '-m', 'pip', 'install',
            'torch', 'torchaudio', '--index-url', 'https://download.pytorch.org/whl/cu118'
        ], capture_output=True, text=True, timeout=300)
        if result.returncode == 0:
            print("✅ PyTorch with CUDA installed")
        else:
            print("⚠️ CUDA PyTorch installation had issues, using CPU version")

    # Install other packages
    result = subprocess.run([
        sys.executable, '-m', 'pip', 'install',
        'flask', 'flask-cors', 'openai-whisper', 'pyngrok', 'python-multipart',
        'requests', 'ffmpeg-python', 'numpy>=1.21.0'
    ], capture_output=True, text=True, timeout=300)

    if result.returncode != 0:
        print(f"⚠️ Package installation had warnings: {result.stderr}")
    else:
        print("✅ Packages installed successfully")

except Exception as e:
    # Covers pip timeouts and a missing interpreter; nothing below can work
    # without the packages, so stop here.
    print(f"❌ Package installation failed: {e}")
    sys.exit(1)

import time
import tempfile
import socket
from datetime import datetime
from pathlib import Path
import threading
import json
from werkzeug.serving import make_server
import numpy as np

try:
    from flask import Flask, request, jsonify
    from flask_cors import CORS
    import whisper
    import torch
    from pyngrok import ngrok
    import requests
    import ffmpeg
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("🔄 Retrying package installation...")
    subprocess.run([sys.executable, '-m', 'pip', 'install', '--force-reinstall', 'flask', 'flask-cors'])
    # The failing import may be any of the packages above, so make sure the
    # remaining ones are present too (plain install: no forced torch reinstall).
    subprocess.run([sys.executable, '-m', 'pip', 'install',
                    'openai-whisper', 'pyngrok', 'requests', 'ffmpeg-python'])
    # Re-import every name, not just the Flask ones: the first failing import
    # aborted the try block, so everything after it is still unbound. A package
    # that is still missing now fails here with a clear ImportError instead of
    # a NameError deep inside a request handler.
    from flask import Flask, request, jsonify
    from flask_cors import CORS
    import whisper
    import torch
    from pyngrok import ngrok
    import requests
    import ffmpeg

# Re-check GPU after PyTorch installation
try:
    if torch.cuda.is_available():
        DEVICE = "cuda"
        torch.cuda.empty_cache()  # Clear GPU cache
        print(f"🎯 Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        DEVICE = "cpu"
        print("📱 Using CPU")
except Exception:
    # Covers `torch` being unbound (NameError) as well as CUDA runtime errors.
    # Deliberately not a bare `except:` so that interrupting the kernel here
    # still stops the script instead of silently selecting the CPU.
    DEVICE = "cpu"
    print("📱 Using CPU (fallback)")

# Constants
MAX_FILE_SIZE = 300 * 1024 * 1024  # 300 MB
DEFAULT_MAX_LENGTH = 42  # default maximum characters per subtitle line
DEFAULT_MAX_LINES = 2  # default maximum lines per subtitle entry
VALID_AUDIO_EXTENSIONS = ['.mp3', '.wav', '.flac', '.m4a', '.ogg', '.webm', '.mp4', '.avi', '.mov']

# Configuration - Set your static domain here
STATIC_DOMAIN = "terrier-hip-sunbeam.ngrok-free.app"  # Replace with your actual static domain
# SECURITY: a credential committed to source is readable by anyone who can see
# this file. Prefer supplying it through the NGROK_AUTHTOKEN environment
# variable; the literal is kept only as a backward-compatible fallback and
# should be rotated.
NGROK_AUTH_TOKEN = os.environ.get(
    "NGROK_AUTHTOKEN",
    "6mQy4iw5BXtaV2BJQe7vR_3JBuZQKpB1pxweCMv92tQ",
)

# Global variables
whisper_models = {}  # cache of loaded models, keyed by model name
server_instance = None  # the running WSGI server, set by start_server()

def find_free_port():
    """Return a TCP port the OS reports as unused, or 5000 if probing fails.

    A throwaway socket is bound to port 0 so the OS picks the port; the socket
    is closed again before the caller gets the number.
    """
    probe = None
    try:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        probe.bind(('', 0))
        probe.listen(1)
        _, chosen_port = probe.getsockname()
        return chosen_port
    except Exception as e:
        print(f"⚠️ Error finding free port: {e}")
        return 5000  # fallback port
    finally:
        if probe is not None:
            probe.close()

def create_app():
    """Build the Flask application with CORS enabled and its logger muted."""
    flask_app = Flask(__name__)
    CORS(flask_app)

    # Mute the app's own logger: disable it and raise its threshold
    flask_logger = flask_app.logger
    flask_logger.disabled = True
    flask_logger.setLevel(logging.CRITICAL)

    return flask_app

def load_whisper_model(model_name="turbo"):
    """Return the Whisper model called *model_name*, loading it on first use.

    Loaded models are kept in the module-level ``whisper_models`` cache. The
    model is loaded on ``DEVICE``; if that fails while ``DEVICE`` is "cuda",
    one retry on the CPU is made. When every attempt fails, the original
    loading error is re-raised.
    """
    try:
        if model_name in whisper_models:
            return whisper_models[model_name]

        print(f"📥 Loading Whisper model: {model_name} on {DEVICE}")
        on_gpu = DEVICE == "cuda"

        # Free cached GPU memory before the load
        if on_gpu:
            torch.cuda.empty_cache()

        whisper_models[model_name] = whisper.load_model(model_name, device=DEVICE)

        if on_gpu:
            print(f"🎯 Model loaded on GPU: {torch.cuda.get_device_name(0)}")
            print(f"💾 GPU memory used: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
        else:
            print("📱 Model loaded on CPU")

        return whisper_models[model_name]
    except Exception as e:
        print(f"❌ Error loading model {model_name}: {e}")
        # Only a GPU failure gets a second chance, on the CPU
        if DEVICE == "cuda":
            print("🔄 Retrying on CPU...")
            try:
                whisper_models[model_name] = whisper.load_model(model_name, device="cpu")
                return whisper_models[model_name]
            except Exception as e2:
                print(f"❌ CPU fallback also failed: {e2}")
        # Re-raises the first (outer) error, not the fallback's
        raise

def is_valid_audio_file(filename):
    """Return True when *filename* ends in one of VALID_AUDIO_EXTENSIONS.

    The comparison is case-insensitive. A missing filename (``None`` or an
    empty string) is reported as invalid instead of raising ``TypeError``.
    """
    if not filename:
        return False
    return Path(filename).suffix.lower() in VALID_AUDIO_EXTENSIONS

def _discard_file(path):
    """Best-effort removal of *path*; a missing or undeletable file is ignored."""
    try:
        os.unlink(path)
    except OSError:
        pass


def preprocess_audio(file_path):
    """Convert *file_path* to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    Returns the path of the converted file (``file_path + "_processed.wav"``)
    on success. On any failure - missing or empty input, ffmpeg error, empty
    output - the original *file_path* is returned so the caller can still try
    it as-is. A partial or empty output file is removed before falling back,
    so failed conversions do not accumulate in the temp directory.
    """
    try:
        # Check if file exists and has content
        if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
            raise ValueError("Audio file is empty or doesn't exist")

        output_path = file_path + "_processed.wav"

        try:
            (
                ffmpeg
                .input(file_path)
                .output(output_path, acodec='pcm_s16le', ac=1, ar='16000')
                .overwrite_output()
                .run(capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            print(f"⚠️ FFmpeg processing failed: {e}, using original file")
            _discard_file(output_path)
            return file_path

        # Only trust the conversion if it produced a non-empty file
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path

        print("⚠️ Processed file is empty, using original")
        _discard_file(output_path)
        return file_path

    except Exception as e:
        print(f"⚠️ Audio preprocessing failed: {e}")
        return file_path

def convert_to_srt_time(seconds):
    """Convert a duration in seconds to an SRT timestamp (HH:MM:SS,mmm).

    The value is rounded to whole milliseconds first and then split into
    fields, so rounding carries into minutes and hours (59.9996 becomes
    ``00:01:00,000`` rather than ``00:00:60,000``). Negative inputs are
    clamped to zero.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3600 * 1000)
    minutes, remainder = divmod(remainder, 60 * 1000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def split_text_for_srt(text, max_length, max_lines):
    """Wrap *text* into at most *max_lines* lines of about *max_length* chars.

    Words are packed greedily. Once *max_lines* lines are filled, all
    remaining words are appended to the last line so no text is dropped, even
    if that line then exceeds *max_length*. A single word longer than
    *max_length* gets a line to itself. Empty text yields an empty list.
    """
    words = text.split()
    lines = []
    current_line = []

    # Track the position with enumerate: looking the word up by value would
    # find its first occurrence and duplicate text when a word repeats.
    for position, word in enumerate(words):
        candidate = ' '.join(current_line + [word])
        if len(candidate) <= max_length or not current_line:
            current_line.append(word)
            continue

        lines.append(' '.join(current_line))
        current_line = [word]

        if len(lines) >= max_lines:
            # Line budget used up: fold everything left into the last line
            lines[-1] += ' ' + ' '.join(words[position:])
            current_line = []
            break

    if current_line and len(lines) < max_lines:
        lines.append(' '.join(current_line))

    return lines[:max_lines]

def generate_srt(segments, max_length, max_lines):
    """Render *segments* as SRT text.

    Each segment is a mapping with 'text', 'start' and 'end' keys. Entries
    are numbered from 1, separated by a blank line, and the result is
    stripped of surrounding whitespace.
    """
    pieces = []

    for number, seg in enumerate(segments, start=1):
        wrapped = split_text_for_srt(seg['text'], max_length, max_lines)
        begin = convert_to_srt_time(seg['start'])
        finish = convert_to_srt_time(seg['end'])
        pieces += [str(number), f"{begin} --> {finish}", *wrapped, ""]

    return '\n'.join(pieces).strip()

# Create the module-level Flask app; the @app.route handlers below register on it
app = create_app()

@app.route('/health', methods=['GET'])
def health_check():
    """Report server status, active device, cached models and, on CUDA, GPU stats."""
    gpu_details = {}
    if DEVICE == "cuda":
        gib = 1024**3
        try:
            gpu_details = {
                'gpu_name': torch.cuda.get_device_name(0),
                'gpu_memory_allocated': f"{torch.cuda.memory_allocated(0) / gib:.2f} GB",
                'gpu_memory_total': f"{torch.cuda.get_device_properties(0).total_memory / gib:.1f} GB"
            }
        except:
            # Best-effort: any failure is reported as a field, never as an error
            gpu_details = {'gpu_error': 'Could not get GPU info'}

    payload = {
        'status': 'healthy',
        'timestamp': datetime.now().isoformat(),
        'version': '2.0.0-gpu',
        'device': DEVICE,
        'models_loaded': list(whisper_models.keys()),
    }
    payload.update(gpu_details)
    return jsonify(payload)

@app.route('/models', methods=['GET'])
def list_models():
    """List the selectable model identifiers plus default/recommended choices.

    Identifiers use the ``ggml-<name>.bin`` form that /transcribe accepts in
    its ``model`` field. ``default`` names the model /transcribe falls back to
    when the field is omitted, which is ``ggml-turbo.bin``.
    """
    available_models = ['tiny', 'base', 'small', 'medium', 'large', 'turbo']
    return jsonify({
        'models': [f'ggml-{model}.bin' for model in available_models],
        # Must match the fallback used by the /transcribe handler
        'default': 'ggml-turbo.bin',
        'loaded': list(whisper_models.keys()),
        'device': DEVICE,
        'recommended': 'ggml-turbo.bin' if DEVICE == "cuda" else 'ggml-medium.bin'
    })

@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    """Transcribe an uploaded audio file with Whisper and return JSON.

    Multipart form fields:
        audio: the file to transcribe (required).
        language: optional language code passed to Whisper.
        model: identifier such as ``ggml-turbo.bin``; unknown names fall back
            to ``turbo``.
        output_srt: ``true`` or ``1`` to add an ``srt`` field to the response.
        max_length, max_lines: positive integers controlling SRT wrapping.

    Responses:
        200 with the transcription payload.
        400 for a missing/unsupported file or invalid numeric parameters.
        413 when the declared request size exceeds MAX_FILE_SIZE.
        500 when preprocessing, model loading or transcription fails.
    """
    temp_path = None
    processed_file = None
    try:
        # Reject oversized uploads up front, based on the declared body size
        declared_size = request.content_length
        if declared_size is not None and declared_size > MAX_FILE_SIZE:
            return jsonify({
                'error': 'file_too_large',
                'message': f"Upload exceeds the {MAX_FILE_SIZE // (1024 * 1024)} MB limit"
            }), 413

        if 'audio' not in request.files:
            return jsonify({
                'error': 'missing_audio_file',
                'message': "No audio file provided in 'audio' field"
            }), 400

        upload = request.files['audio']
        # `not` also covers a filename of None, which `== ''` would let through
        if not upload.filename:
            return jsonify({
                'error': 'missing_audio_file',
                'message': "No audio file provided in 'audio' field"
            }), 400

        if not is_valid_audio_file(upload.filename):
            return jsonify({
                'error': 'invalid_file_type',
                'message': 'Supported formats: mp3, wav, flac, m4a, ogg, webm, mp4, avi, mov'
            }), 400

        # Parse parameters
        language = request.form.get('language', '')
        model = request.form.get('model', 'ggml-turbo.bin')
        output_srt = request.form.get('output_srt', '').lower() in ['true', '1']
        try:
            max_length = int(request.form.get('max_length', DEFAULT_MAX_LENGTH))
            max_lines = int(request.form.get('max_lines', DEFAULT_MAX_LINES))
        except (TypeError, ValueError):
            max_length = max_lines = 0  # rejected by the range check below
        if max_length < 1 or max_lines < 1:
            # A client mistake, so answer 400 rather than a 500
            return jsonify({
                'error': 'invalid_parameter',
                'message': 'max_length and max_lines must be positive integers'
            }), 400

        # Extract model name
        model_name = model.replace('ggml-', '').replace('.bin', '')
        if model_name not in ['tiny', 'base', 'small', 'medium', 'large', 'turbo']:
            model_name = 'turbo'

        if DEVICE == "cuda" and model_name == 'turbo':
            print("🎯 Using turbo model on GPU for better accuracy")

        try:
            # Save the upload to a temp file. This happens inside the
            # try/finally so the file is removed even when saving fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=Path(upload.filename).suffix) as temp_file:
                temp_path = temp_file.name
                upload.save(temp_path)

            # Preprocess audio to handle potential issues
            print(f"🔧 Preprocessing audio file: {upload.filename}")
            processed_file = preprocess_audio(temp_path)

            # Load model and transcribe
            print(f"🎯 Processing {upload.filename} with model {model_name} on {DEVICE}")
            model_obj = load_whisper_model(model_name)
            start_time = time.time()

            transcribe_options = {
                'fp16': DEVICE == "cuda",  # Use FP16 on GPU for speed
                'verbose': False,  # Reduce output
            }

            if language:
                transcribe_options['language'] = language

            # Clear GPU cache before transcription
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            result = model_obj.transcribe(processed_file, **transcribe_options)
            duration = time.time() - start_time

            print(f"✅ Transcription completed in {duration:.2f}s on {DEVICE}")

            # Clear GPU cache after transcription
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            # Validate result
            if not result or 'text' not in result:
                raise ValueError("Transcription returned empty result")

            # Prepare response
            response = {
                'text': result['text'].strip(),
                'language': result.get('language', language),
                'duration': duration,
                'timestamp': datetime.now().isoformat(),
                'filename': upload.filename,
                'format': 'srt' if output_srt else 'text',
                'device': DEVICE,
                'model': model_name
            }

            if output_srt and 'segments' in result:
                response['srt'] = generate_srt(result['segments'], max_length, max_lines)

            return jsonify(response)

        finally:
            # Cleanup temp files. processed_file may equal temp_path when
            # preprocessing fell back to the original; the exists() check
            # makes the second removal a no-op.
            for path in [temp_path, processed_file]:
                if path and os.path.exists(path):
                    try:
                        os.unlink(path)
                    except OSError:
                        pass

    except Exception as e:
        print(f"❌ Transcription error: {e}")

        # Clear GPU cache on error
        if DEVICE == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        return jsonify({
            'error': 'transcription_failed',
            'message': str(e),
            'device': DEVICE
        }), 500

def test_flask_locally(port):
    """Return True when GET http://localhost:<port>/health answers with 200.

    Any request failure (connection refused, timeout, ...) counts as "not
    running" and returns False. KeyboardInterrupt is deliberately not caught,
    so interrupting the kernel during a probe still reaches the caller.
    """
    try:
        response = requests.get(f"http://localhost:{port}/health", timeout=2)
        return response.status_code == 200
    except Exception:
        return False

def start_server():
    """Start the Flask server in a background thread and expose it via ngrok.

    Returns a ``(tunnel, url, port)`` tuple:
      * ``(tunnel, public_url, port)`` once the tunnel is up, whether or not
        the public health probe succeeded;
      * ``(None, "http://localhost:<port>", port)`` when ngrok setup fails but
        the local server is running;
      * ``(None, None, port)`` when Flask never answered its health check;
      * ``(None, None, None)`` when the server could not be created at all.
    """
    global server_instance
    port = find_free_port()
    print(f"🚀 Starting server on port {port}")

    try:
        server_instance = make_server('0.0.0.0', port, app, threaded=True)

        def serve():
            # Runs in the daemon thread started below
            try:
                print(f"🔧 Flask server starting on 0.0.0.0:{port}")
                server_instance.serve_forever()
            except Exception as e:
                print(f"❌ Server error: {e}")

        threading.Thread(target=serve, daemon=True).start()

        # Poll the local health endpoint: 20 tries, half a second apart
        print("⏳ Waiting for Flask to start...")
        for _ in range(20):
            if test_flask_locally(port):
                print("✅ Flask server is ready!")
                break
            time.sleep(0.5)
        else:
            print("❌ Flask server failed to start")
            return None, None, port

        try:
            print("🌐 Setting up ngrok tunnel with static domain...")
            ngrok.set_auth_token(NGROK_AUTH_TOKEN)

            if not STATIC_DOMAIN:
                print("🔗 Using dynamic domain")
                tunnel = ngrok.connect(port)
                url = tunnel.public_url
            else:
                print(f"🔗 Using static domain: {STATIC_DOMAIN}")
                tunnel = ngrok.connect(port, hostname=STATIC_DOMAIN)
                url = f"https://{STATIC_DOMAIN}"

            print(f"🌐 Public URL: {url}")

            print("🧪 Testing public endpoint...")
            time.sleep(3)  # Give ngrok time to establish tunnel

            # The probe is informational only; the tunnel is returned either way
            try:
                probe = requests.get(f"{url}/health", timeout=15)
                if probe.status_code == 200:
                    print("✅ Public endpoint is working!")
                else:
                    print(f"⚠️ Public endpoint returned: {probe.status_code}")
            except Exception as e:
                print(f"⚠️ Public endpoint test failed: {e}")
                print("🔧 This might be normal for static domains - they may take longer to propagate")

            return tunnel, url, port

        except Exception as e:
            print(f"❌ Error setting up ngrok: {e}")
            print("🔧 Server is still running locally for testing")
            return None, f"http://localhost:{port}", port

    except Exception as e:
        print(f"❌ Error starting server: {e}")
        return None, None, None

def _shutdown_services(tunnel):
    """Tear down the ngrok tunnel, the HTTP server and the GPU cache.

    Every step is best-effort and independent: a failure is reported and the
    remaining steps still run.
    """
    steps = []
    if tunnel:
        # pyngrok's disconnect() takes the tunnel's public URL, not the tunnel object
        steps.append(lambda: ngrok.disconnect(tunnel.public_url))
        steps.append(ngrok.kill)
    if server_instance:
        steps.append(server_instance.shutdown)
    if DEVICE == "cuda":
        steps.append(torch.cuda.empty_cache)

    for step in steps:
        try:
            step()
        except Exception as e:
            print(f"⚠️ Cleanup step failed: {e}")


def main():
    """Preload a Whisper model, start the server and keep it alive.

    Runs until the kernel is interrupted (KeyboardInterrupt), at which point
    the tunnel and server are shut down, or until the local health check
    fails while a tunnel is active. Returns early if Whisper cannot load a
    model or the server fails to start.
    """
    print("🎙️ Setting up GPU-Optimized Whisper API Server...")
    print(f"🎯 Device: {DEVICE}")

    if STATIC_DOMAIN:
        print(f"🔗 Static domain configured: {STATIC_DOMAIN}")
    else:
        print("⚠️ No static domain configured - will use dynamic domain")

    # Pre-load model to test installation
    try:
        print("🔍 Testing Whisper installation...")
        # turbo on GPU; on CPU the tiny model keeps this smoke test quick
        test_model = "turbo" if DEVICE == "cuda" else "tiny"
        load_whisper_model(test_model)
        print("✅ Whisper is working correctly")
    except Exception as e:
        print(f"❌ Whisper test failed: {e}")
        return

    tunnel, public_url, port = start_server()

    if not public_url:
        print("❌ Failed to start server")
        print("🔧 Check the error messages above for troubleshooting")
        return

    print("\n🎉 Server is running!")
    print(f"🔗 URL: {public_url}")
    print(f"🔗 Health check: {public_url}/health")
    print(f"🔗 Transcribe: {public_url}/transcribe")
    print(f"🔗 Models: {public_url}/models")
    print(f"🎯 Device: {DEVICE}")

    if DEVICE == "cuda":
        print("🚀 GPU acceleration enabled!")
        print("💡 Recommended model: turbo (better accuracy)")
    else:
        print("📱 CPU mode - consider smaller models for speed")

    print("\n📋 Example usage:")
    print(f'curl -X POST "{public_url}/transcribe" \\')
    print('  -F "audio=@your_file.mp3" \\')
    print('  -F "output_srt=true" \\')
    print('  -F "language=en" \\')
    print('  -F "model=ggml-turbo.bin"')

    print("\n📝 Server is ready to accept audio files!")
    print("⏹️ The server will run continuously. Interrupt the kernel to stop.")

    try:
        # Keep running with periodic GPU memory cleanup
        while True:
            time.sleep(30)

            # Handlers below catch Exception, not everything, so a
            # KeyboardInterrupt raised mid-iteration still reaches the
            # shutdown path instead of being swallowed.
            if DEVICE == "cuda":
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass

            # Health check (only meaningful while a tunnel is exposed)
            try:
                if tunnel and not test_flask_locally(port):
                    print("⚠️ Local server seems to have stopped")
                    break
            except Exception:
                pass
    except KeyboardInterrupt:
        print("\n🛑 Stopping server...")
        _shutdown_services(tunnel)
        print("✅ Server stopped")

# Run the server. The original guard called main() in both of its branches
# (executed as a script/notebook cell, or imported as a module), so a single
# unconditional call is equivalent.
main()