In [None]:
# ========================================
# MAXIMUM ACCURACY JAPANESE TRANSCRIPTION
# With POST-TRANSCRIPTION Segmentation Adjustment
# ========================================

!pip install -q faster-whisper
!pip install -q tqdm
!apt install ffmpeg -y

from faster_whisper import WhisperModel
from google.colab import files
import os
import time
from datetime import timedelta
import re
import torch
import json
import glob

# ==================== CONFIGURATION ====================

class MaxAccuracyConfig:
    """Settings read by LongAudioTranscriber when loading the model and decoding.

    Every value is a plain class attribute; nothing is computed per instance.
    The first group is passed to the ``WhisperModel`` constructor, the rest
    is forwarded to ``WhisperModel.transcribe``.
    """

    # --- Model loading -----------------------------------------------------
    MODEL_SIZE = "large-v3"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is only requested when a CUDA device was detected.  On a CPU-only
    # runtime the backend rejects float16, so fall back to full precision
    # (keeps accuracy, at the cost of speed) instead of failing at load time.
    COMPUTE_TYPE = "float16" if DEVICE == "cuda" else "float32"
    NUM_WORKERS = 8
    CPU_THREADS = 8

    # --- Decoding / search -------------------------------------------------
    BEAM_SIZE = 10
    BEST_OF = 5
    PATIENCE = 2.0
    # Sequence of temperatures handed to transcribe() as-is.
    TEMPERATURE = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

    # --- Quality thresholds (forwarded unchanged to transcribe()) -----------
    COMPRESSION_RATIO_THRESHOLD = 2.4
    LOG_PROB_THRESHOLD = -1.0
    NO_SPEECH_THRESHOLD = 0.6

    # --- Language ----------------------------------------------------------
    LANGUAGE = "ja"
    # Japanese priming text: "Hello. The following is Japanese audio.
    # Please add punctuation accurately."  The literal was previously stored
    # mis-encoded (UTF-8 bytes read as Mac Roman), which handed the model
    # meaningless characters instead of a Japanese prompt.
    INITIAL_PROMPT = "こんにちは。以下は日本語の音声です。句読点を正確に付けてください。"

    # --- Voice activity detection ------------------------------------------
    VAD_FILTER = True
    # Passed as ``min_silence_duration_ms``, i.e. expressed in milliseconds.
    VAD_MIN_SILENCE = 2000
    VAD_THRESHOLD = 0.3

# ==================== UTILITY FUNCTIONS ====================

def format_timestamp(seconds):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond exactly once and then
    split with integer arithmetic.  Splitting the float directly loses a
    millisecond on values such as 1.001 (whose fractional part is stored as
    0.000999...), which made cue times drift from the audio.

    Args:
        seconds: Offset from the start of the audio, in seconds.  Negative
            values are clamped to zero because a subtitle timestamp cannot
            be negative.

    Returns:
        The formatted timestamp; hours grow beyond two digits when needed.
    """
    total_ms = max(0, round(float(seconds) * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def format_duration(seconds):
    """Render a span of seconds as ``HH:MM:SS``, discarding any fraction."""
    fields = (
        seconds // 3600,        # whole hours
        (seconds % 3600) // 60,  # whole minutes left in the current hour
        seconds % 60,           # seconds left in the current minute
    )
    return ":".join(f"{int(field):02d}" for field in fields)

# Delimiters that separate "chunks" of text: the Japanese comma, full stop,
# full-width exclamation and question marks, and any whitespace.
_CHUNK_DELIMITERS = re.compile(r'[、。！？\s]+')

# Characters after which a subtitle may be broken, most preferred first.
_BREAK_MARKS = ('。', '！', '？', '、', '」', ' ')

# How far (in characters) either side of the ideal cut a break mark is sought.
_BREAK_WINDOW = 10


def count_japanese_chunks(text):
    """Count the non-empty pieces of *text* between punctuation or whitespace.

    The delimiter set is ``、 。 ！ ？`` plus whitespace.  These literals used
    to be stored mis-encoded, so real Japanese punctuation never matched and
    every sentence counted as a single chunk.
    """
    return sum(1 for piece in _CHUNK_DELIMITERS.split(text) if piece.strip())


def _ceil_div(numerator, denominator):
    """Integer division rounded up."""
    return -(-numerator // denominator)


def _find_cut(text, target, max_chars):
    """Return the index at which to cut *text*, preferring a break mark.

    A mark is searched within ``_BREAK_WINDOW`` characters of *target*, but
    never so far right that the piece would exceed *max_chars*, and never at
    index 0 (that would emit a piece made of a lone punctuation mark).
    Falls back to *target* when no mark is in range.
    """
    lowest = max(1, target - _BREAK_WINDOW)
    highest = min(target + _BREAK_WINDOW, max_chars)
    for mark in _BREAK_MARKS:
        position = text.find(mark, lowest, highest)
        if position != -1:
            return position + 1  # keep the mark at the end of this piece
    return target


def _split_text(text, num_parts, max_chars):
    """Cut stripped, non-empty *text* into at least *num_parts* pieces where possible.

    Cutting continues past *num_parts* whenever the remainder is still longer
    than *max_chars*, so every returned piece fits the character limit.
    """
    pieces = []
    remaining = text
    while remaining:
        pieces_needed = max(num_parts - len(pieces),
                            _ceil_div(len(remaining), max_chars))
        if pieces_needed <= 1:
            pieces.append(remaining)
            break
        # Re-balance on every pass so an early punctuation cut does not
        # leave an over-long tail.
        target = _ceil_div(len(remaining), pieces_needed)
        cut = _find_cut(remaining, target, max_chars)
        pieces.append(remaining[:cut].strip())
        remaining = remaining[cut:].strip()
    return pieces


def smart_segment_split(segments, max_chunks, max_seconds, max_chars):
    """Re-segment transcription segments so they fit subtitle limits.

    Segments that already satisfy all three limits are passed through as the
    same dict objects.  Others are cut into several pieces; the segment's
    time span is shared equally between the pieces.

    Guarantees for pieces produced by a split:
      * text length never exceeds ``max_chars``;
      * displayed duration never exceeds ``max_seconds`` (the end time is
        pulled in if necessary, which can leave a gap before the next cue).
    The chunk limit only influences how many pieces are requested; a piece
    may still contain more than ``max_chunks`` chunks.  A segment that is
    over ``max_seconds`` by at most 0.1 s and otherwise within limits is kept
    whole.

    Args:
        segments: Iterable of dicts with ``'start'``, ``'end'`` (seconds) and
            ``'text'`` keys.  Segments whose text is blank are dropped.
        max_chunks: Maximum chunks per segment (see count_japanese_chunks).
        max_seconds: Maximum duration of a segment, in seconds.
        max_chars: Maximum number of characters per segment (an integer).

    Returns:
        A new list of segment dicts.  Pieces created by a split carry only
        the ``'start'``, ``'end'`` and ``'text'`` keys.

    Raises:
        ValueError: If any limit is zero or negative.
    """
    if max_chunks <= 0 or max_seconds <= 0 or max_chars <= 0:
        raise ValueError(
            "max_chunks, max_seconds and max_chars must all be greater than zero"
        )

    new_segments = []

    for segment in segments:
        text = segment['text'].strip()
        if not text:
            continue

        start_time = segment['start']
        duration = segment['end'] - start_time
        chunks = count_japanese_chunks(text)

        if (chunks <= max_chunks and
                duration <= max_seconds and
                len(text) <= max_chars):
            new_segments.append(segment)
            continue

        # How many pieces each limit asks for; the 0.1 s slack avoids cutting
        # a segment that overshoots the time limit only marginally.
        parts_by_duration = max(
            1,
            int(duration / max_seconds) + (1 if duration % max_seconds > 0.1 else 0),
        )
        parts_by_chars = _ceil_div(len(text), max_chars)
        parts_by_chunks = _ceil_div(chunks, max_chunks)
        num_parts = max(parts_by_duration, parts_by_chars, parts_by_chunks)

        if num_parts == 1:
            new_segments.append(segment)
            continue

        text_parts = _split_text(text, num_parts, max_chars)
        time_per_part = duration / len(text_parts)

        for i, part in enumerate(text_parts):
            part_start = start_time + (i * time_per_part)
            part_end = start_time + ((i + 1) * time_per_part)

            # Fewer pieces than the duration called for (very short text):
            # cap the display time rather than exceed the limit.
            if part_end - part_start > max_seconds:
                part_end = part_start + max_seconds

            new_segments.append({
                'start': part_start,
                'end': part_end,
                'text': part,
            })

    return new_segments

def create_srt(segments, output_file):
    """Write *segments* to *output_file* as SubRip (.srt) cues.

    Each cue is a 1-based counter line, a ``start --> end`` timing line built
    with format_timestamp, the segment text, and a blank separator line.
    Returns *output_file* unchanged.
    """
    with open(output_file, 'w', encoding='utf-8') as handle:
        cue_number = 0
        for cue in segments:
            cue_number += 1
            print(cue_number, file=handle)
            begin = format_timestamp(cue['start'])
            finish = format_timestamp(cue['end'])
            print(f"{begin} --> {finish}", file=handle)
            print(cue['text'], end="\n\n", file=handle)
    return output_file

def create_vtt(segments, output_file):
    """Write *segments* to *output_file* as a WebVTT (.vtt) document.

    The file starts with the ``WEBVTT`` header; each cue is a timing line
    followed by the segment text and a blank line.  Timestamps come from
    format_timestamp with the comma swapped for a dot.  Returns *output_file*.
    """
    def vtt_time(value):
        # Same layout as the SRT timestamp, but with '.' before the millis.
        return format_timestamp(value).replace(',', '.')

    with open(output_file, 'w', encoding='utf-8') as handle:
        handle.write("WEBVTT\n\n")
        for cue in segments:
            begin = vtt_time(cue['start'])
            finish = vtt_time(cue['end'])
            handle.write(f"{begin} --> {finish}\n")
            handle.write(f"{cue['text']}\n\n")
    return output_file

# ==================== FILE SELECTION ====================

def select_audio_file():
    """Ask the user which audio/video file to transcribe.

    Files with a known media extension are looked up in the working
    directory, ``/content`` and ``/content/drive/MyDrive``.  When any are
    found the user picks one by number, types a path to some other existing
    file, or enters ``0`` to upload a new file; otherwise an upload is
    requested straight away.

    Returns:
        The chosen file path (or uploaded file name), or ``None`` when
        nothing was found and the upload produced no file.
    """
    print("\n" + "=" * 70)
    print(" AUDIO FILE SELECTION")
    print("=" * 70)

    audio_extensions = ['*.aac', '*.mp3', '*.wav', '*.m4a', '*.flac', '*.ogg', '*.wma', '*.mp4', '*.mkv', '*.webm']
    # '' stands for the working directory (joining it leaves the pattern as is).
    search_dirs = ['', '/content', '/content/drive/MyDrive']

    # Keyed by absolute path so a file is listed once even when the working
    # directory is also one of the searched directories; the first spelling
    # seen (the relative one) is the one kept and returned.
    found = {}
    for directory in search_dirs:
        if directory and not os.path.exists(directory):
            continue
        for ext in audio_extensions:
            # Match lower- and upper-case extensions in every location.
            for pattern in (ext, ext.upper()):
                for path in glob.glob(os.path.join(directory, pattern)):
                    found.setdefault(os.path.abspath(path), path)

    existing_files = sorted(found.values())

    if not existing_files:
        print("\n Please upload your audio file...")
        uploaded = files.upload()
        if uploaded:
            return list(uploaded.keys())[0]
        print(" No file was uploaded")
        return None

    print("\n Found existing audio files:")
    for i, f in enumerate(existing_files, 1):
        size_mb = os.path.getsize(f) / (1024 * 1024)
        print(f"  [{i}] {os.path.basename(f)} ({size_mb:.2f} MB)")
    print(f"\n  [0] Upload a NEW file")

    while True:
        choice = input("\n Enter your choice: ").strip()

        if choice == '0':
            uploaded = files.upload()
            if uploaded:
                return list(uploaded.keys())[0]
            continue  # nothing uploaded: ask again

        try:
            choice_num = int(choice)
        except ValueError:
            # Not a number: accept it if it names an existing file.
            if os.path.isfile(choice):
                return choice
            print(" Invalid input")
            continue

        if 1 <= choice_num <= len(existing_files):
            return existing_files[choice_num - 1]
        print(f" Please enter a number between 0 and {len(existing_files)}")

# ==================== TRANSCRIPTION ENGINE ====================

class LongAudioTranscriber:
    """Transcribe audio once, then produce subtitle files from the result.

    The raw segments are kept on the instance (and dumped to
    ``<base_name>_raw_transcription.json``) so that subtitle files with
    different segmentation limits can be generated without running the model
    again.

    Attributes:
        model: The loaded WhisperModel, or None until load_model() runs.
        config: A MaxAccuracyConfig instance holding all settings.
        raw_segments: List of segment dicts, or None before transcription.
        full_text: Concatenation of all segment texts, or None.
        info: Second value returned by the model's transcribe(), or None.
        base_name: Input file name without extension; prefix of every output.
    """

    # Marker appended to the audio base name for the raw JSON dump.
    _RAW_SUFFIX = "_raw_transcription"

    def __init__(self):
        self.model = None
        self.config = MaxAccuracyConfig()
        self.raw_segments = None
        self.full_text = None
        self.info = None
        self.base_name = None

    def load_model(self):
        """Create the WhisperModel described by the configuration."""
        print(f" Loading {self.config.MODEL_SIZE} model...")
        self.model = WhisperModel(
            self.config.MODEL_SIZE,
            device=self.config.DEVICE,
            compute_type=self.config.COMPUTE_TYPE,
            num_workers=self.config.NUM_WORKERS,
            cpu_threads=self.config.CPU_THREADS
        )
        print(" Model loaded!\n")

    def transcribe(self, audio_file):
        """Transcribe *audio_file*, store the raw segments and save them as JSON.

        The model is loaded on demand if load_model() has not been called.

        Args:
            audio_file: Path of the audio/video file to transcribe.

        Returns:
            The list of raw segment dicts (also kept in ``self.raw_segments``),
            each with ``start``, ``end``, ``text``, ``avg_logprob`` and
            ``no_speech_prob`` keys.
        """
        self.base_name = os.path.splitext(os.path.basename(audio_file))[0]

        # Calling transcribe() first used to fail on ``None.transcribe``.
        if self.model is None:
            self.load_model()

        print(f"  Transcribing: {audio_file}")
        print(" This may take a while for long audio...\n")

        start_time = time.time()

        segments_generator, self.info = self.model.transcribe(
            audio_file,
            language=self.config.LANGUAGE,
            beam_size=self.config.BEAM_SIZE,
            best_of=self.config.BEST_OF,
            patience=self.config.PATIENCE,
            temperature=self.config.TEMPERATURE,
            compression_ratio_threshold=self.config.COMPRESSION_RATIO_THRESHOLD,
            log_prob_threshold=self.config.LOG_PROB_THRESHOLD,
            no_speech_threshold=self.config.NO_SPEECH_THRESHOLD,
            condition_on_previous_text=True,
            initial_prompt=self.config.INITIAL_PROMPT,
            vad_filter=self.config.VAD_FILTER,
            vad_parameters=dict(
                min_silence_duration_ms=self.config.VAD_MIN_SILENCE,
                threshold=self.config.VAD_THRESHOLD
            ),
            word_timestamps=True,
            hallucination_silence_threshold=2.0
        )

        self.raw_segments = []
        self.full_text = ""
        text_pieces = []

        # Results are consumed one segment at a time, which also lets
        # progress be reported during long files.
        for segment in segments_generator:
            self.raw_segments.append({
                'start': segment.start,
                'end': segment.end,
                'text': segment.text,
                'avg_logprob': segment.avg_logprob,
                'no_speech_prob': segment.no_speech_prob
            })
            text_pieces.append(segment.text)

            segment_count = len(self.raw_segments)
            if segment_count % 50 == 0:
                print(f"  Processed: {format_duration(segment.end)} | {segment_count} segments")

        # Join once instead of growing a string inside the loop.
        self.full_text = "".join(text_pieces)

        processing_time = time.time() - start_time

        # Save raw transcription
        raw_json = f"{self.base_name}{self._RAW_SUFFIX}.json"
        with open(raw_json, 'w', encoding='utf-8') as f:
            json.dump({
                'segments': self.raw_segments,
                'full_text': self.full_text,
                'language': self.info.language,
                'language_probability': self.info.language_probability
            }, f, ensure_ascii=False, indent=2)

        print(f"\n Transcription complete!")
        print(f" Time: {format_duration(processing_time)}")
        print(f" Segments: {len(self.raw_segments)}")
        print(f" Raw data saved: {raw_json}")

        return self.raw_segments

    def load_existing_transcription(self, json_file):
        """Load a raw transcription previously saved by transcribe().

        Args:
            json_file: Path to a JSON file with a ``segments`` list and,
                optionally, a ``full_text`` string.  When ``full_text`` is
                absent it is rebuilt from the segment texts.

        Returns:
            The loaded list of segment dicts.

        Raises:
            KeyError: If the file has no ``segments`` entry.
        """
        print(f" Loading existing transcription: {json_file}")
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.raw_segments = data['segments']

        full_text = data.get('full_text')
        if full_text is None:
            full_text = "".join(seg.get('text', '') for seg in self.raw_segments)
        self.full_text = full_text

        # Remove only the trailing marker, so a base name that happens to
        # contain the same text elsewhere is left intact.
        stem = os.path.splitext(os.path.basename(json_file))[0]
        if stem.endswith(self._RAW_SUFFIX):
            stem = stem[:-len(self._RAW_SUFFIX)]
        self.base_name = stem

        print(f" Loaded {len(self.raw_segments)} segments")
        return self.raw_segments

    def generate_subtitles(self, max_chunks, max_seconds, max_chars, suffix=""):
        """Write .srt and .vtt files using custom segmentation limits.

        Args:
            max_chunks: Maximum chunks per subtitle.
            max_seconds: Maximum duration per subtitle, in seconds.
            max_chars: Maximum characters per subtitle.
            suffix: Optional tag appended to the output file names.

        Returns:
            ``(srt_file, vtt_file, segments)``, or ``None`` when no
            transcription is available.
        """
        if not self.raw_segments:
            print(" No transcription loaded! Run transcribe() first or load existing.")
            return None

        print(f"\n  Applying segmentation: {max_chunks} chunks | {max_seconds}s | {max_chars} chars")

        custom_segments = smart_segment_split(
            self.raw_segments,
            max_chunks,
            max_seconds,
            max_chars
        )

        # Generate filenames
        suffix_str = f"_{suffix}" if suffix else ""
        srt_file = f"{self.base_name}_custom{suffix_str}.srt"
        vtt_file = f"{self.base_name}_custom{suffix_str}.vtt"

        create_srt(custom_segments, srt_file)
        create_vtt(custom_segments, vtt_file)

        print(f" Created: {srt_file} ({len(custom_segments)} segments)")
        print(f" Created: {vtt_file}")

        return srt_file, vtt_file, custom_segments

    def generate_original_srt(self):
        """Write an .srt file using the segments exactly as transcribed.

        Returns:
            The .srt file name, or ``None`` when no transcription is available.
        """
        if not self.raw_segments:
            print(" No transcription loaded!")
            return None

        srt_file = f"{self.base_name}_original.srt"
        create_srt(self.raw_segments, srt_file)
        print(f" Created: {srt_file} ({len(self.raw_segments)} segments)")
        return srt_file

# ==================== INTERACTIVE MENU ====================

def interactive_segmentation(engine):
    """Menu loop for generating subtitle files from an existing transcription.

    The user can generate subtitles with custom or preset limits, write the
    original segmentation, download every generated file, or exit.  The loop
    only returns when option 5 is chosen.

    Args:
        engine: A LongAudioTranscriber; its ``raw_segments`` and ``base_name``
            are read, and its generate_* methods are called.
    """
    # (max_chunks, max_seconds, max_chars, file suffix) per preset key.
    presets = {
        'a': (5, 4, 35, 'compact'),
        'b': (10, 8, 45, 'standard'),
        'c': (15, 12, 60, 'spacious'),
        'd': (3, 2, 25, 'ultrashort'),
        'e': (1, 1, 42, 'minimal')
    }
    download_extensions = ('srt', 'vtt', 'json')

    while True:
        print("\n" + "=" * 70)
        print("  SEGMENTATION ADJUSTMENT MENU")
        print("=" * 70)
        # raw_segments is None until something has been transcribed or loaded.
        print("\nCurrent transcription has", len(engine.raw_segments or []), "original segments")
        print("\nOptions:")
        print("  [1] Generate subtitles with CUSTOM settings")
        print("  [2] Generate subtitles with PRESET settings")
        print("  [3] Generate ORIGINAL Whisper segmentation")
        print("  [4] Download all generated files")
        print("  [5] Exit")
        print("-" * 70)

        choice = input("\n👉 Enter your choice (1-5): ").strip()

        if choice == '1':
            print("\n Enter your custom settings:")
            print("   (Press Enter for defaults)")

            try:
                max_chunks = int(input("   Max chunks per segment [10]: ") or 10)
                max_seconds = float(input("   Max seconds per segment [8]: ") or 8)
                max_chars = int(input("   Max characters per segment [45]: ") or 45)
                # Zero or negative limits would divide by zero further down;
                # "not (x > 0)" also rejects NaN.
                if max_chunks < 1 or max_chars < 1 or not (max_seconds > 0):
                    raise ValueError("all limits must be greater than zero")
                suffix = input("   File suffix (optional, e.g., 'v1'): ").strip()

                result = engine.generate_subtitles(
                    max_chunks, max_seconds, max_chars, suffix
                )
                # generate_subtitles returns None (after printing why) when
                # there is nothing to segment.
                if result is None:
                    continue
                segments = result[2]

                # Preview
                print(f"\n Preview (first 5 segments):")
                print("-" * 50)
                for i, seg in enumerate(segments[:5], 1):
                    duration = seg['end'] - seg['start']
                    text = seg['text']
                    # Only mark the text as shortened when it really was.
                    ellipsis = "..." if len(text) > 50 else ""
                    print(f"{i}. [{duration:.1f}s] {text[:50]}{ellipsis}")

            except ValueError as e:
                print(f" Invalid input: {e}")

        elif choice == '2':
            print("\n SELECT A PRESET:")
            print("  [a] Compact:   5 chunks | 4s  | 35 chars (short subtitles)")
            print("  [b] Standard:  10 chunks | 8s | 45 chars (recommended)")
            print("  [c] Spacious:  15 chunks | 12s | 60 chars (longer subtitles)")
            print("  [d] Ultra-short: 3 chunks | 2s | 25 chars (very fast)")
            print("  [e] Minimal:   1 chunk | 1s | 42 chars (your original request)")

            preset = input("\n Enter preset (a-e): ").strip().lower()

            if preset in presets:
                chunks, secs, chars, name = presets[preset]
                engine.generate_subtitles(chunks, secs, chars, name)
            else:
                print("Invalid preset")

        elif choice == '3':
            engine.generate_original_srt()

        elif choice == '4':
            print("\n Downloading all generated files...")
            # Escape the base name: characters such as '[' or '*' in the
            # audio file name would otherwise be read as glob wildcards and
            # the generated files would not be found.
            prefix = glob.escape(str(engine.base_name))
            for extension in download_extensions:
                for f in glob.glob(f"{prefix}*.{extension}"):
                    print(f"   ⬇️  {f}")
                    files.download(f)
            print(" Download complete!")

        elif choice == '5':
            print("\n Goodbye!")
            break

        else:
            print("Invalid choice. Please enter 1-5.")

# ==================== MAIN ====================

def main():
    """Run the whole workflow: reuse or create a transcription, then open the menu.

    If raw transcription JSON files exist in the working directory the user
    may load one and go straight to the segmentation menu.  Otherwise (or when
    the answer is ``0`` or not a listed number) an audio file is selected,
    the model is loaded, the file is transcribed and the menu is opened.
    """
    print("=" * 70)
    print(" JAPANESE TRANSCRIPTION WITH POST-PROCESSING")
    print("   Transcribe once, adjust segmentation unlimited times!")
    print("=" * 70)

    engine = LongAudioTranscriber()

    # Check for existing transcription (sorted so the numbering is stable
    # from one run to the next).
    existing_json = sorted(glob.glob("*_raw_transcription.json"))

    if existing_json:
        print("\n Found existing transcription(s):")
        for i, f in enumerate(existing_json, 1):
            print(f"  [{i}] {f}")
        print(f"  [0] Start NEW transcription")

        choice = input("\n Load existing or start new? ").strip()

        # Anything that is not a listed number falls through to a new
        # transcription, exactly like '0'.
        try:
            idx = int(choice) - 1
        except ValueError:
            idx = -1

        if 0 <= idx < len(existing_json):
            engine.load_existing_transcription(existing_json[idx])
            interactive_segmentation(engine)
            return

    # New transcription
    input_file = select_audio_file()

    # select_audio_file() yields None when the upload produced no file.
    if not input_file:
        print("\n No audio file selected - nothing to transcribe.")
        return

    print(f"\n Selected: {input_file}")
    print(f" Size: {os.path.getsize(input_file) / (1024*1024):.2f} MB")

    engine.load_model()
    engine.transcribe(input_file)

    # Enter interactive mode
    interactive_segmentation(engine)

# ==================== RUN ====================

# Entry point: start the interactive workflow when this file is executed
# directly rather than imported as a module.
if __name__ == "__main__":
    main()