# 🎙️ Qwen3-TTS for Kaggle (Dual T4 - 2x Speed!)

**Uses BOTH T4 GPUs** - processes 2 chunks in parallel for ~2x faster generation!

**Before running:**
1. Settings → Accelerator → **GPU T4 x2** (REQUIRED!)
2. Settings → Internet → **ON**
3. Run cells with **Shift+Enter** (one by one!)

In [None]:
!pip install -q qwen-tts flask flask-cors pyngrok soundfile numpy

In [None]:
import torch
from qwen_tts import Qwen3TTSModel

def _load_tts_on(gpu_index):
    """Load the Qwen3-TTS base checkpoint onto one CUDA device.

    Prints a progress line before loading and a confirmation line after,
    then returns the loaded model.
    """
    print(f"\nüì¶ Loading model on GPU {gpu_index}...")
    loaded = Qwen3TTSModel.from_pretrained(
        "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        device_map=f"cuda:{gpu_index}",
        torch_dtype=torch.bfloat16,
    )
    print(f"‚úÖ GPU {gpu_index} ready!")
    return loaded


print("üîÑ Loading Qwen3-TTS on BOTH GPUs...")
for gpu_index in (0, 1):
    print(f"   GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

# One independent model copy per GPU; the request handlers pick between
# them by GPU index.
model_0 = _load_tts_on(0)
model_1 = _load_tts_on(1)

print("\nüöÄ Both GPUs loaded and ready for parallel processing!")

In [None]:
# Paste your ngrok auth token between the quotes below.
# Tokens are issued at https://dashboard.ngrok.com/get-started/your-authtoken
NGROK_TOKEN = ""

from pyngrok import ngrok

# Only register the token when one was actually provided; otherwise remind
# the user to fill it in.
if not NGROK_TOKEN:
    print("‚ö†Ô∏è Enter your ngrok token above!")
else:
    ngrok.set_auth_token(NGROK_TOKEN)
    print("‚úÖ ngrok token set!")

In [None]:
from flask import Flask, request, send_file, jsonify, Response
from flask_cors import CORS
import soundfile as sf
import numpy as np
import io, base64, tempfile, os, json, re
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# Flask application object; the @app.route decorators below register their
# handlers on it.
app = Flask(__name__)
# Enable flask-cors on the whole app (library defaults).
CORS(app)

CHUNK_SIZE = 1000  # Maximum number of characters per synthesis chunk

# One lock per GPU (index 0 and 1); generate_chunk holds the matching lock
# while it runs a model call on that GPU.
gpu_locks = [threading.Lock(), threading.Lock()]

def split_text(text, max_chars=CHUNK_SIZE):
    """Split text into chunks of at most ``max_chars`` characters.

    Chunks are built greedily from whole sentences. A sentence boundary is
    whitespace that follows ``.``, ``!``, ``?`` or the Devanagari danda
    (U+0964). A sentence that is itself longer than ``max_chars`` is broken
    at word boundaries, and a single word longer than ``max_chars`` is cut
    into ``max_chars``-sized slices so no chunk ever exceeds the limit.

    Args:
        text: The text to split.
        max_chars: Upper bound on the length of every returned chunk.

    Returns:
        A list of chunk strings in reading order. If no chunk is produced
        (for example for an empty string) the result is ``[text]``.

    Raises:
        ValueError: If ``max_chars`` is smaller than 1.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be at least 1")

    def joined_len(prefix, piece):
        # Length of prefix + piece once joined, counting the single
        # separating space that is inserted when prefix is non-empty.
        return len(prefix) + len(piece) + (1 if prefix else 0)

    # \u0964 is the Devanagari danda, written as an escape so the pattern
    # does not depend on how this file is encoded or decoded.
    sentences = re.split(r'(?<=[.!?\u0964])\s+', text)
    chunks = []
    current = ""
    for sentence in sentences:
        if joined_len(current, sentence) <= max_chars:
            current += (" " if current else "") + sentence
            continue

        # The sentence does not fit behind what is buffered: flush first.
        if current:
            chunks.append(current.strip())
            current = ""

        if len(sentence) <= max_chars:
            current = sentence
            continue

        # The sentence alone is too long: fall back to packing words.
        for word in sentence.split():
            # A word longer than the limit (e.g. text without spaces) is
            # hard-sliced; only its final, short-enough tail is packed
            # together with the following words.
            while len(word) > max_chars:
                if current:
                    chunks.append(current.strip())
                    current = ""
                chunks.append(word[:max_chars])
                word = word[max_chars:]
            if joined_len(current, word) <= max_chars:
                current += (" " if current else "") + word
            else:
                if current:
                    chunks.append(current.strip())
                current = word
    if current:
        chunks.append(current.strip())
    return chunks if chunks else [text]

def generate_chunk(chunk_idx, chunk_text, language, ref_path, ref_text, gpu_id):
    """Synthesize one text chunk with the model assigned to ``gpu_id``.

    The call runs while holding that GPU's lock. When ``ref_text`` contains
    non-whitespace text it is passed along with the reference audio;
    otherwise the model is called with ``x_vector_only_mode=True``.

    Returns:
        ``(chunk_idx, audio, sample_rate, None)`` on success, where ``audio``
        is the first waveform returned by the model, or
        ``(chunk_idx, None, None, error_message)`` if anything raised.
    """
    tts = model_1 if gpu_id != 0 else model_0

    with gpu_locks[gpu_id]:
        try:
            # Choose the mode-specific keyword; everything else is shared.
            if ref_text and ref_text.strip():
                mode_kwargs = {"ref_text": ref_text}
            else:
                mode_kwargs = {"x_vector_only_mode": True}
            wavs, sr = tts.generate_voice_clone(
                text=chunk_text,
                language=language,
                ref_audio=ref_path,
                **mode_kwargs,
            )
            first_wave = wavs[0]
            return (chunk_idx, first_wave, sr, None)
        except Exception as exc:
            # Failures are reported to the caller as data, not raised.
            return (chunk_idx, None, None, str(exc))

@app.route('/health', methods=['GET'])
def health():
    """Return a JSON status document: model name, both GPU names, mode and chunk size."""
    gpu_names = {
        f"gpu_{idx}": torch.cuda.get_device_name(idx) for idx in (0, 1)
    }
    payload = {
        "status": "ok",
        "model": "Qwen3-TTS-1.7B",
        **gpu_names,
        "mode": "dual-gpu-parallel",
        "chunk_size": CHUNK_SIZE,
    }
    return jsonify(payload)

@app.route('/api/tts', methods=['POST'])
def generate_tts():
    """Handle a voice-clone TTS request and stream progress as server-sent events.

    Expected JSON body fields (all read with ``dict.get``):
        text: Text to synthesize (default ``'Hello'``).
        language: Language name passed to the model (default ``'English'``).
        ref_audio: Base64-encoded reference audio; required.
        ref_text: Optional transcript of the reference audio.
        stream: Must be truthy; non-streaming requests are rejected.

    Returns:
        With ``stream`` set, a ``text/event-stream`` response whose frames are
        JSON objects of type ``progress``, ``error`` or ``complete`` (the
        latter carrying the base64-encoded WAV). Otherwise a JSON error with
        HTTP status 400.
    """
    # silent=True makes a missing or malformed JSON body come back as None
    # instead of raising; any non-object body is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    text = data.get('text', 'Hello')
    language = data.get('language', 'English')
    ref_audio_b64 = data.get('ref_audio')
    ref_text = data.get('ref_text', '')
    stream = data.get('stream', False)

    def sse(payload):
        """Format one event dict as a server-sent-events ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    def generate_with_progress():
        """Yield SSE frames while the chunks are synthesized and merged."""
        ref_path = None  # Set once the temp reference file exists.
        try:
            if not ref_audio_b64:
                yield sse({'type': 'error', 'message': 'Reference audio is required'})
                return

            chunks = split_text(text, CHUNK_SIZE)
            total_chunks = len(chunks)

            print(f"üéôÔ∏è Dual-GPU: {len(text)} chars ‚Üí {total_chunks} chunks")

            yield sse({'type': 'progress', 'current': 0, 'total': total_chunks, 'percent': 0, 'status': f'Preparing {total_chunks} chunks (2 GPUs)...'})

            # Decode the reference audio into a temp file the models can read.
            audio_bytes = base64.b64decode(ref_audio_b64)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                # Record the path before writing so a failed write is still
                # cleaned up by the outer finally.
                ref_path = f.name
                f.write(audio_bytes)

            yield sse({'type': 'progress', 'current': 0, 'total': total_chunks, 'percent': 5, 'status': 'Reference audio loaded'})

            results = {}  # chunk index -> audio, so order can be restored
            sample_rate = None
            completed = 0

            with ThreadPoolExecutor(max_workers=2) as executor:
                # Alternate chunks between GPU 0 and GPU 1.
                futures = [
                    executor.submit(generate_chunk, i, chunk, language, ref_path, ref_text, i % 2)
                    for i, chunk in enumerate(chunks)
                ]
                try:
                    for future in as_completed(futures):
                        chunk_idx, audio, sr, error = future.result()

                        if error:
                            print(f"‚ùå Chunk {chunk_idx+1} failed: {error}")
                            yield sse({'type': 'error', 'message': f'Chunk {chunk_idx+1} failed: {error}'})
                            return

                        results[chunk_idx] = audio
                        if sample_rate is None:
                            sample_rate = sr

                        completed += 1
                        pct = int((completed / total_chunks) * 90) + 5
                        gpu_used = chunk_idx % 2
                        yield sse({'type': 'progress', 'current': completed, 'total': total_chunks, 'percent': pct, 'status': f'Chunk {chunk_idx+1} done (GPU{gpu_used}) - {completed}/{total_chunks}'})
                finally:
                    # On failure or client disconnect, drop chunks that have
                    # not started so the GPUs are not kept busy for nothing.
                    # cancel() is a no-op on finished or running futures; the
                    # executor still waits for the running ones on exit, so
                    # the reference file outlives every worker that uses it.
                    for pending in futures:
                        pending.cancel()

            yield sse({'type': 'progress', 'current': total_chunks, 'total': total_chunks, 'percent': 95, 'status': 'Merging audio in order...'})

            # Concatenate in chunk order (0, 1, 2, ...), not completion order.
            ordered_audio = [results[i] for i in range(total_chunks)]
            final_audio = np.concatenate(ordered_audio)

            buffer = io.BytesIO()
            sf.write(buffer, final_audio, sample_rate, format='WAV')
            buffer.seek(0)
            audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')

            duration = len(final_audio) / sample_rate
            print(f"‚úÖ Generated {duration:.1f}s from {total_chunks} chunks (dual-GPU)")

            yield sse({'type': 'complete', 'audio': audio_b64, 'duration': round(duration, 1), 'chunks': total_chunks})

        except Exception as e:
            import traceback; traceback.print_exc()
            yield sse({'type': 'error', 'message': str(e)})
        finally:
            # Always remove the temp reference file: on success, on any
            # error, and when the client disconnects mid-stream.
            if ref_path is not None:
                try:
                    os.unlink(ref_path)
                except OSError:
                    # Best-effort cleanup; a missing file is not an error.
                    pass

    if stream:
        return Response(generate_with_progress(), mimetype='text/event-stream', headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'})
    else:
        return jsonify({"error": "Use stream=true"}), 400

@app.route('/api/clone', methods=['POST'])
def voice_clone():
    """Serve POST /api/clone by delegating to generate_tts() and returning its response."""
    return generate_tts()

# Open an ngrok tunnel to local port 5000 and announce the resulting URL.
public_url = ngrok.connect(5000)

rule = "=" * 50
for banner_line in (
    "\n" + rule,
    f"üöÄ QWEN3-TTS URL: {public_url}",
    "   Mode: DUAL-GPU (2x speed!)",
    f"   Chunk size: {CHUNK_SIZE} chars",
    rule + "\n",
):
    print(banner_line)

# Start the Flask server on the port the tunnel forwards to.
app.run(port=5000)