In [None]:
from google import genai
import os

# Create the Gemini client; the API key is read from the GOOGLE_API_KEY
# environment variable (os.getenv returns None when it is not set).
client =  genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

# Upload the audio file so it can be passed to the model alongside the prompt.
# NOTE(review): this is an absolute path on one machine (presumably the vocals
# stem produced by a source-separation step) — adjust before running elsewhere.
myfile = client.files.upload(file='/home/skamalj/dev/tataplay/separated_audio/htdemucs/media_1_audio/vocals.wav')

# Ask the model to transcribe, translate to Tamil, diarize, and return JSON with
# one entry per speech segment (speaker, start/end time, texts, SSML).
# NOTE(review): the JSON template below shows start_time/end_time as quoted
# "<float_seconds>" placeholders, so the reply may carry them as numbers or as
# strings; code consuming these fields should coerce with float().
response = client.models.generate_content(
  model='gemini-2.5-flash',
  contents=[
    """You are given an audio clip. 
    Task:
    1. Transcribe the speech in its original language.
    2. Translate it into Tamil, but:
       - If the original speaker uses words from another language (e.g., English words in Hindi),
         preserve those words exactly in the translated text. Do not re-translate them.
       - The translation must feel natural and conversational.
    3. Assume there are 3 speakers (Speaker 1, Speaker 2, Speaker 3).
    4. Perform speaker diarization: identify continuous segments by the same speaker.
    5. Return output strictly in valid JSON with this structure:

    {
      "speakers_count": <int>,
      "speak_segments": [
        {
          "speaker": "<speaker_id>",
          "start_time": "<float_seconds>",
          "end_time": "<float_seconds>",
          "original_text": "<transcribed_text_in_original_language>",
          "translated_text": "<translated_text_in_targeted_language_with_code_switch_words_preserved>",
          "ssml": "<speak>...</speak>  // built from translated_text only"
        },
        ...
      ],
      "metadata": {
        "translation_language": "<targeted language",
        "total_duration": "<float_seconds>",
        "notes": "pitch, pauses, tone captured where possible; code-switch words preserved"
      }
    }

    Rules:
    - Each segment must include both original_text and translated_text.
    - The translated_text must preserve any words from another language exactly as spoken.
      Example: If the speaker says "वैसे कल आपका स्माइल करना मेरे लिए काफी हिस्टोरिकल था",
      then translated_text could be: "<translated text." (with "smile" and "historical" kept in English).
    - The ssml field must be built from translated_text only.
    - Each segment must have its own <speak> block with a <voice> tag unique to the speaker.
    - Use <break> and <prosody> to reflect pauses, pitch, and tone.
    - Do not include any explanation outside the JSON.
    """,
    myfile,
  ]
)

### Helper Functions

In [None]:
# @title Helper functions (just run that cell)

import contextlib
import wave
from IPython.display import Audio

# Running counter; play_audio_blob increments it to number the WAV files it writes.
file_index = 0

@contextlib.contextmanager
def wave_file(filename, channels=1, rate=24000, sample_width=2):
    """Open ``filename`` for WAV writing and yield the configured writer.

    Args:
        filename: Destination path (or file object) handed to ``wave.open``.
        channels: Number of audio channels written to the header.
        rate: Frame rate in Hz.
        sample_width: Bytes per sample.

    Yields:
        The ``wave`` writer, already configured; it is closed on exit.
    """
    writer = wave.open(filename, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(rate)
        yield writer
    finally:
        # Closing finalises the header, exactly as leaving a `with` block would.
        writer.close()

def play_audio_blob(blob):
    """Write ``blob.data`` to a freshly numbered WAV file and return a player for it.

    Each call bumps the module-level ``file_index`` so the file is named
    ``audio_<n>.wav`` and earlier clips are not overwritten.

    Args:
        blob: Object whose ``data`` attribute holds the raw frames to write.

    Returns:
        An ``IPython.display.Audio`` widget for the new file, set to autoplay.
    """
    global file_index
    file_index += 1

    wav_name = f'audio_{file_index}.wav'
    with wave_file(wav_name) as writer:
        writer.writeframes(blob.data)

    return Audio(wav_name, autoplay=True)

def play_audio(response):
    """Play the inline audio carried by the first part of the first candidate.

    Args:
        response: Object exposing ``candidates[0].content.parts[0].inline_data``.

    Returns:
        Whatever ``play_audio_blob`` returns for that inline data.
    """
    first_part = response.candidates[0].content.parts[0]
    return play_audio_blob(first_part.inline_data)

In [None]:
import json

# Raw text of the Gemini reply; it is often wrapped in a Markdown code fence.
text = response.text
if text is None:
    raise ValueError("Gemini returned no text; cannot parse the diarization JSON")

# Trim surrounding whitespace first so the fence checks below still match when
# the reply starts or ends with blank lines.
text = text.strip()

# Remove a leading ```json (any letter case) or bare ``` fence, then the trailing ```.
if text[:7].lower() == "```json":
    text = text[len("```json"):].strip()
elif text.startswith("```"):
    text = text[3:].strip()
if text.endswith("```"):
    text = text[:-3].strip()

# Now parse as JSON
gemini_data = json.loads(text)

# Persist the parsed result so later cells (and other environments) can reload it.
# ensure_ascii=False keeps non-ASCII script readable in the file.
with open('translated_diarized_data.json', 'w', encoding='utf-8') as f:
    json.dump(gemini_data, f, indent=2,ensure_ascii=False)


In [None]:
def create_wave_file(filename, pcm, channels=1, rate=24000, sample_width=2):
    """Write ``pcm`` frames to ``filename`` as a WAV file.

    Args:
        filename: Destination path (or file object) handed to ``wave.open``.
        pcm: Bytes-like object containing the audio frames to write.
        channels: Number of audio channels written to the header.
        rate: Frame rate in Hz.
        sample_width: Bytes per sample.
    """
    writer = wave.open(filename, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(rate)
        writer.writeframes(pcm)
    finally:
        # Closing finalises the header, exactly as leaving a `with` block would.
        writer.close()

In [None]:
from google.genai import types
from pathlib import Path

def synthesize_segment_tts(client, segment, output_dir="translated_segments", voice_name="Puck"):
    """
    Generate TTS audio for a single segment and save it as a WAV file.

    The segment duration is computed from its timestamps and stated in the
    prompt so the model is asked to keep the timing for lip-sync.

    client: gemini client
    segment: dict with keys
        - ssml: SSML string
        - start_time: seconds (number, or a string that parses as a float)
        - end_time: seconds (number, or a string that parses as a float)
        - speaker: string (optional, defaults to "Speaker")
        - index: int used in the output file name (optional, defaults to 0)
    output_dir: where to save the segment audio (created if missing)
    voice_name: prebuilt Gemini TTS voice

    Returns:
        Path of the written WAV file, or None when the response held no audio.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The diarization JSON may carry the times as strings, so coerce before subtracting.
    duration = float(segment["end_time"]) - float(segment["start_time"])

    # Output file name: "<index>_<speaker with spaces replaced by underscores>.wav"
    seg_index = segment.get("index", 0)
    speaker = segment.get("speaker", "Speaker")
    out_file = out_dir / f"{seg_index}_{speaker.replace(' ', '_')}.wav"

    print(f"Processing segment {seg_index} with duration {duration:.3f} seconds to {out_file}")

    # Wrap the SSML into the TTS prompt
    tts_prompt = f"""
    You are given ssml with timing metadata. 
    Identify the target language from text. 
    Must keep original pauses and durations, which is {duration:.3f} seconds, intact for lip-sync. 
    Generate **audio output**.

    Here is the SSML:

    {segment['ssml']}
    """

    response = client.models.generate_content(
        model='gemini-2.5-pro-preview-tts',
        contents=tts_prompt,
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        )
    )

    # Walk the response defensively: a blocked or empty reply can come back
    # without candidates, content, or parts.
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = (content.parts or []) if content else []
    audio_bytes = None
    for part in parts:
        inline = getattr(part, "inline_data", None)
        if inline is not None and inline.data:
            audio_bytes = inline.data
            break

    if audio_bytes is None:
        finish_reason = getattr(candidates[0], "finish_reason", None) if candidates else None
        print(f"[warning] No audio returned for segment {seg_index} (finish_reason={finish_reason}); skipping")
        return None

    # The audio bytes arrive as inline data on a content part; wrap them in a WAV container.
    create_wave_file(str(out_file), audio_bytes)
    return out_file


In [None]:
!rm -f translated_segments/*.wav

In [None]:
# Assume `gemini_data` is your parsed Gemini JSON with speak_segments.
# Each segment is tagged in place with its position; synthesize_segment_tts
# reads that "index" key to build the output file name.
for idx, seg in enumerate(gemini_data["speak_segments"]):
    seg["index"] = idx
    synthesize_segment_tts(client, seg, voice_name="Puck")


In [None]:
import json
from pathlib import Path
import librosa
import soundfile as sf

# Paths
json_path = "translated_diarized_data.json"
tts_dir = Path("translated_segments")          # TTS-generated segments
adjusted_dir = Path("adjusted_segments")      # Save adjusted segments
adjusted_dir.mkdir(parents=True, exist_ok=True)

# Load JSON
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

segments = data["speak_segments"]

# Time-stretch every TTS clip so its length matches the segment's time slot.
# NOTE(review): the later conversion cells read from "translated_segments",
# not from this output directory — confirm which set is meant to be mixed.
for idx, seg in enumerate(segments):
    speaker = seg.get("speaker", "Speaker")

    tts_file = tts_dir / f"{idx}_{speaker.replace(' ', '_')}.wav"
    if not tts_file.exists():
        print(f"Missing file: {tts_file}")
        continue

    # Target duration from JSON; the times may be stored as strings, so coerce first.
    target_duration = float(seg["end_time"]) - float(seg["start_time"])  # in seconds
    if target_duration <= 0:
        # A zero/negative slot would cause a division by zero or an invalid stretch rate.
        print(f"Skipping segment {idx}: non-positive target duration ({target_duration:.3f}s)")
        continue

    # Load audio at its native sample rate
    y, sr = librosa.load(tts_file, sr=None)
    current_duration = len(y) / sr
    if current_duration == 0:
        print(f"Skipping segment {idx}: TTS file is empty ({tts_file})")
        continue

    # Skip if already close enough
    if abs(current_duration - target_duration) < 0.01:
        adjusted_audio = y
    else:
        # rate > 1 shortens the clip, rate < 1 lengthens it
        rate = current_duration / target_duration
        adjusted_audio = librosa.effects.time_stretch(y, rate=rate)

    # Save adjusted segment
    out_file = adjusted_dir / f"adjusted_seg_{idx}_{speaker}.wav"
    sf.write(out_file, adjusted_audio, sr)
    print(f"Saved adjusted segment: {out_file} (target: {target_duration:.2f}s, actual: {current_duration:.2f}s)")


### Switch to OpenVoice Environment

In [None]:
import json
from pathlib import Path

# ----------- Load Gemini segments -----------
# Re-read the saved diarization JSON so this cell works in a fresh kernel.
gemini_data = json.loads(
    Path("translated_diarized_data.json").read_text(encoding="utf-8")
)

# A missing "speak_segments" key is treated as an empty list.
segments = gemini_data.get("speak_segments", [])

print(f"Loaded {len(segments)} segments from file")

In [None]:
import torch
from openvoice.api import ToneColorConverter

# Directory holding the converter's config.json and checkpoint.pth.
ckpt_converter = 'checkpoints_v2/converter'

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the converter from its config, then load the checkpoint into it.
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

In [None]:
import re
from pathlib import Path
from collections import defaultdict

from pydub import AudioSegment
import torch
import warnings

# OpenVoice se_extractor import (make sure OpenVoice is on PYTHONPATH)
from openvoice import se_extractor

def _sanitize_speaker(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_\-]", "_", s).strip("_")

def build_reference_embeddings_from_diarization(
    original_audio_path: str,
    gemini_data: dict,
    out_dir: str = "reference_segments",
    join_silence_ms: int = 100,
    sample_rate: int = 24000,
    min_total_duration_sec: float = 2.5,
    tone_color_converter=None,
):
    """
    Concatenate all segments for each speaker (in chronological order),
    export a single WAV reference per speaker, then create + save OpenVoice embeddings.

    Args:
      original_audio_path: audio file the diarization timestamps refer to.
      gemini_data: dict with a "speak_segments" list; each entry needs
        "start_time" and "end_time" (seconds, numbers or numeric strings) and
        may carry "speaker" (defaults to "unknown").
      out_dir: directory for the reference WAVs and embeddings (created if missing).
      join_silence_ms: silence inserted between consecutive clips; 0 disables it.
      sample_rate: frame rate of the exported mono reference WAV.
      min_total_duration_sec: below this total speech length a warning is issued.
      tone_color_converter: passed to se_extractor.get_se when not None.

    Returns:
      dict mapping speaker -> {"ref_wav": str, "embedding": str,
                               "duration_s": float, "num_segments": int}
    """

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load full original audio once (pydub uses ffmpeg behind the scenes)
    audio = AudioSegment.from_file(str(original_audio_path))
    full_duration_ms = len(audio)
    print(f"[info] Loaded original audio '{original_audio_path}' ({full_duration_ms/1000:.2f}s)")

    # Group segments by speaker (keeps order)
    speaker_segments = defaultdict(list)
    for seg in gemini_data.get("speak_segments", []):
        try:
            start_ms = int(round(float(seg["start_time"]) * 1000))
            end_ms   = int(round(float(seg["end_time"]) * 1000))
        except (KeyError, TypeError, ValueError, OverflowError) as e:
            print(f"[warning] skipping segment with bad times: {seg} -> {e}")
            continue
        # Clamp within the audio BEFORE the length check, so a segment that lies
        # entirely past the end of the file is rejected instead of becoming an
        # empty clip.
        start_ms = max(0, min(start_ms, full_duration_ms))
        end_ms   = max(0, min(end_ms, full_duration_ms))
        if end_ms <= start_ms:
            print(f"[warning] skipping zero/negative-length segment: start={start_ms} end={end_ms}")
            continue
        dur_ms = end_ms - start_ms
        speaker = seg.get("speaker", "unknown")
        speaker_segments[speaker].append({"start_ms": start_ms, "end_ms": end_ms, "duration_ms": dur_ms})

    results = {}

    for speaker, segs in speaker_segments.items():
        # sort by start time to keep natural order
        segs_sorted = sorted(segs, key=lambda x: x["start_ms"])
        concat = AudioSegment.silent(duration=0, frame_rate=sample_rate)

        # Short gap placed BETWEEN clips to avoid gluing words together. Adding
        # it only between clips means there is no trailing silence to trim, which
        # also avoids the `concat[:-0]` slice that would empty the audio when
        # join_silence_ms is 0.
        gap = None
        if join_silence_ms > 0:
            gap = AudioSegment.silent(duration=join_silence_ms, frame_rate=sample_rate)

        total_ms = 0
        for position, s in enumerate(segs_sorted):
            if position > 0 and gap is not None:
                concat += gap
            # slice original audio and append it
            clip = audio[s["start_ms"]:s["end_ms"]]
            concat += clip
            total_ms += len(clip)

        total_s = total_ms / 1000.0

        # Warn if too short
        if total_s < min_total_duration_sec:
            warnings.warn(
                f"[warn] total concatenated duration for speaker '{speaker}' is short ({total_s:.2f}s). "
                "Embeddings may be poor. Consider merging more segments or increasing min_total_duration_sec."
            )

        # Normalize export settings
        concat = concat.set_frame_rate(sample_rate).set_channels(1)

        # Save reference wav
        speaker_safe = _sanitize_speaker(speaker)
        ref_path = out_dir / f"ref_{speaker_safe}.wav"
        concat.export(str(ref_path), format="wav")
        print(f"[ok] Saved ref audio for '{speaker}' -> {ref_path} ({total_s:.2f}s)")

        # Create embedding via se_extractor
        try:
            # se_extractor.get_se may accept different signatures; attempt with tone_color_converter first
            if tone_color_converter is not None:
                source_se, audio_name = se_extractor.get_se(str(ref_path), tone_color_converter, vad=True)
            else:
                # fallback signature
                source_se, audio_name = se_extractor.get_se(str(ref_path), vad=True)
        except TypeError:
            # try alternate call (some builds expect only path)
            source_se, audio_name = se_extractor.get_se(str(ref_path))

        emb_path = out_dir / f"{speaker_safe}_embedding.pt"
        torch.save(source_se, str(emb_path))
        print(f"[ok] Saved embedding for '{speaker}' -> {emb_path}")

        results[speaker] = {
            "ref_wav": str(ref_path),
            "embedding": str(emb_path),
            "duration_s": total_s,
            "num_segments": len(segs_sorted),
        }

    return results


In [None]:
# Build one reference WAV plus one saved embedding per diarized speaker.
# `gemini_data` and `tone_color_converter` come from the earlier cells.

ref_map = build_reference_embeddings_from_diarization(
    original_audio_path="media_1_audio.mp3",
    gemini_data=gemini_data,
    out_dir="reference_segments",
    join_silence_ms=100,
    sample_rate=24000,
    min_total_duration_sec=5.0,
    tone_color_converter=tone_color_converter  # or None
)

print(ref_map)
# Expected shape of the result:
# {'Speaker 1': {'ref_wav': 'reference_segments/ref_Speaker_1.wav', 'embedding': 'reference_segments/Speaker_1_embedding.pt', ...}, ...}


In [None]:
from pydub import AudioSegment
import torch
from pathlib import Path

def build_src_embedding_from_tts(tts_dir: str, gemini_data: dict, tone_color_converter=None, max_len_sec=10):
    """
    Combine TTS segments (up to max_len_sec) to create one source embedding.

    Args:
        tts_dir: directory holding the per-segment TTS WAV files.
        gemini_data: dict with a "speak_segments" list; the list position and
            each entry's "speaker" determine the file name that is looked up.
        tone_color_converter: passed to se_extractor.get_se when not None.
        max_len_sec: stop adding segments once this much audio is collected.

    Returns:
        (src_se, path) — the extracted embedding and the path of the combined
        WAV it was extracted from.

    Raises:
        FileNotFoundError: if no TTS audio could be found for any segment.
    """
    tts_root = Path(tts_dir)
    combined = AudioSegment.silent(duration=0)
    total_ms = 0

    for idx, seg in enumerate(gemini_data["speak_segments"]):
        speaker = seg.get("speaker", "Speaker")
        # synthesize_segment_tts writes "<idx>_<speaker>.wav" with spaces in the
        # speaker replaced by underscores; look for that name first and fall
        # back to the raw speaker string.
        candidates = (
            tts_root / f"{idx}_{speaker.replace(' ', '_')}.wav",
            tts_root / f"{idx}_{speaker}.wav",
        )
        tts_path = next((p for p in candidates if p.exists()), None)
        if tts_path is None:
            continue

        seg_audio = AudioSegment.from_wav(str(tts_path))
        combined += seg_audio
        total_ms += len(seg_audio)

        if total_ms >= max_len_sec * 1000:
            break

    if total_ms == 0:
        # Fail early with a clear message instead of handing an empty clip to
        # the embedding extractor.
        raise FileNotFoundError(
            f"No TTS segment audio found in '{tts_root}'; cannot build a source embedding"
        )

    # Save temporary combined TTS clip
    tmp_path = tts_root / "tts_src_reference.wav"
    combined.export(str(tmp_path), format="wav")

    # Extract embedding once
    if tone_color_converter is not None:
        src_se, _ = se_extractor.get_se(str(tmp_path), tone_color_converter, vad=True)
    else:
        src_se, _ = se_extractor.get_se(str(tmp_path), vad=True)

    print(f"[ok] Source embedding created from {total_ms/1000:.1f}s of TTS audio -> {tmp_path}")
    return src_se, str(tmp_path)


In [None]:
# Step 2: Build a **single src embedding** from the TTS segments.
# It is reused as the source embedding for every segment conversion below.
src_se, tts_ref_path = build_src_embedding_from_tts(
    tts_dir="translated_segments",
    gemini_data=gemini_data,
    tone_color_converter=tone_color_converter
)

In [None]:
def convert_tts_segments_with_refs(tts_dir, gemini_data, ref_map, out_dir, src_se, tone_color_converter):
    """
    Convert each TTS segment to its speaker's reference tone colour.

    Args:
        tts_dir: directory holding the per-segment TTS WAV files.
        gemini_data: dict with a "speak_segments" list.
        ref_map: speaker -> dict with an "embedding" path (a file saved with torch.save).
        out_dir: directory for the converted WAVs (created if missing).
        src_se: source embedding shared by all TTS segments.
        tone_color_converter: object whose ``convert`` method performs the conversion.

    Returns:
        dict mapping segment index -> path of the converted WAV. Segments with
        no TTS audio or no reference for their speaker are skipped.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    tts_root = Path(tts_dir)

    # Each speaker's target embedding is loaded from disk once and then reused.
    target_embeddings = {}

    results = {}
    for idx, seg in enumerate(gemini_data.get("speak_segments", [])):
        speaker = seg.get("speaker", "Speaker")

        # synthesize_segment_tts writes "<idx>_<speaker>.wav" with spaces in the
        # speaker replaced by underscores; look for that name first and fall
        # back to the raw speaker string.
        candidates = (
            tts_root / f"{idx}_{speaker.replace(' ', '_')}.wav",
            tts_root / f"{idx}_{speaker}.wav",
        )
        tts_wav = next((p for p in candidates if p.exists()), None)
        if tts_wav is None:
            print(f"[skip] No TTS audio for seg {idx}")
            continue

        if speaker not in ref_map:
            print(f"[skip] No reference for {speaker}")
            continue

        if speaker not in target_embeddings:
            target_embeddings[speaker] = torch.load(ref_map[speaker]["embedding"])
        tgt_se = target_embeddings[speaker]

        # Output name keeps the raw speaker string; the mixing step looks the
        # file up under exactly this name.
        out_wav = out_dir / f"converted_seg_{idx}_{speaker}.wav"

        tone_color_converter.convert(
            audio_src_path=str(tts_wav),
            src_se=src_se,   # <-- reuse the same source embedding
            tgt_se=tgt_se,
            output_path=str(out_wav),
            message="@MyShell",
        )
        print(f"[ok] Converted seg {idx} ({speaker}) -> {out_wav}")
        results[idx] = str(out_wav)

    return results


In [None]:
# Convert every TTS segment to the tone colour of its original speaker.
# `converted` maps segment index -> path of the converted WAV.
converted = convert_tts_segments_with_refs(
    tts_dir="translated_segments",
    gemini_data=gemini_data,
    ref_map=ref_map,
    out_dir="converted_segments",
    src_se=src_se,
    tone_color_converter=tone_color_converter
)

In [None]:
import json
from pathlib import Path
from pydub import AudioSegment

def mix_segments_with_background(gemini_json_path, converted_dir, background_audio_path, output_path):
    """
    Overlay the converted speech segments onto the background track.

    Args:
        gemini_json_path: JSON file with a "speak_segments" list; each entry
            needs "speaker" and "start_time" (seconds, number or numeric string).
        converted_dir: directory holding "converted_seg_<idx>_<speaker>.wav" files.
        background_audio_path: background audio file in any format ffmpeg can read.
        output_path: destination of the final mix, written as WAV.
    """
    # Load segment metadata
    with open(gemini_json_path, "r", encoding="utf-8") as f:
        gemini_data = json.load(f)

    speak_segments = gemini_data["speak_segments"]

    # Load background audio. No format is forced, so the container is detected
    # from the file itself; forcing "mp3" would misread a WAV background.
    bg_audio = AudioSegment.from_file(str(background_audio_path))

    # Overlay each converted segment
    for idx, seg in enumerate(speak_segments):
        seg_file = Path(converted_dir) / f"converted_seg_{idx}_{seg['speaker']}.wav"
        if not seg_file.exists():
            print(f"Skipping missing segment: {seg_file}")
            continue

        segment_audio = AudioSegment.from_file(str(seg_file), format="wav")
        # The times may be stored as strings; coerce before converting to
        # milliseconds, and never place a clip before the start of the track.
        start_ms = max(0, int(round(float(seg["start_time"]) * 1000)))
        bg_audio = bg_audio.overlay(segment_audio, position=start_ms)
        print(f"Overlayed segment {idx} ({seg['speaker']}) at {start_ms} ms")

    # Export final mix
    bg_audio.export(output_path, format="wav")
    print(f"Final audio mix saved to {output_path}")

# Example usage: overlay the converted segments on the "no_vocals" track and
# write the result to final_mix.wav.
mix_segments_with_background(
    gemini_json_path="translated_diarized_data.json",
    converted_dir="converted_segments",
    background_audio_path="separated_audio/htdemucs/media_1_audio/no_vocals.wav",
    output_path="final_mix.wav"
)


In [None]:
from moviepy import VideoFileClip, AudioFileClip

# Load the video and the mixed audio. The context managers close both clips
# once the export has finished, so no reader handles are left open.
with VideoFileClip("media_1_video.mp4") as video, AudioFileClip("final_mix.wav") as audio:
    # Set the new audio to the video
    final_video = video.with_audio(audio)

    # Export the combined file
    final_video.write_videofile("output_video_with_audio.mp4")
