# In [None]:
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
import json
import re
from typing import List, Dict, Tuple
import os
from janome.tokenizer import Tokenizer
from tqdm import tqdm
import subprocess
import srt

# OpenAI credentials come from the environment; None when the variable is unset.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
# Default number of subtitle cues sent to the model in one request.
BATCH_SIZE: int = 30

def get_srt_files():
    """Return the names of the ``.srt`` files in the current working directory.

    The names are sorted because ``os.listdir`` returns entries in arbitrary
    order, and the caller shows them as a numbered menu that should be stable
    between runs. Only regular files are returned, so a directory whose name
    happens to end in ``.srt`` is never offered.

    Returns:
        list[str]: Sorted file names (relative to the current directory).
    """
    return sorted(
        name
        for name in os.listdir('.')
        if name.endswith('.srt') and os.path.isfile(name)
    )

def select_srt_file():
    """Ask the user to pick one ``.srt`` file from the current directory.

    The files from ``get_srt_files()`` are printed as a 1-based numbered list
    and the user is prompted until a valid number is entered.

    Returns:
        str | None: The chosen file name, or ``None`` when there are no
        ``.srt`` files or standard input ends before a valid choice is made.
    """
    srt_files = get_srt_files()
    if not srt_files:
        print("No .srt files found in the current directory.")
        return None

    print("Available .srt files:")
    for number, file in enumerate(srt_files, start=1):
        print(f"{number}. {file}")

    while True:
        try:
            raw_choice = input("Enter the number of the file you want to translate: ")
        except EOFError:
            # stdin closed/exhausted: no answer can ever arrive, so give up
            # the same way as when no files exist instead of crashing.
            print()
            return None

        try:
            choice = int(raw_choice)
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue

        if 1 <= choice <= len(srt_files):
            return srt_files[choice - 1]
        print("Invalid choice. Please enter a number between 1 and", len(srt_files))

def _escape_ffmpeg_filter_value(value):
    """Escape *value* so it survives as a single option value inside a ``-vf`` string.

    ffmpeg unescapes a filter option value twice: first when the filtergraph
    description is split into filters (special characters: backslash, single
    quote, ``[``, ``]``, ``,``, ``;``), then when the filter's own options are
    parsed (special characters: backslash, single quote, ``:``). Escaping for
    the inner level first and the outer level second keeps paths containing
    e.g. a Windows drive colon, quotes or commas intact.
    """
    for special in ("\\", "'", ":"):
        value = value.replace(special, "\\" + special)
    for special in ("\\", "'", "[", "]", ",", ";"):
        value = value.replace(special, "\\" + special)
    return value


def create_video_with_subtitles(video_file, subtitle_file, output_video_file,
                                video_codec='h264_amf', overwrite=False):
    """Burn a subtitle file into a video with ffmpeg.

    ffmpeg's progress lines (those containing ``frame=``) are echoed to the
    console. On failure, the exit code and ffmpeg's most recent other messages
    are printed.

    Args:
        video_file (str): Path of the input video.
        subtitle_file (str): Path of the subtitle file to render onto the video.
        output_video_file (str): Path of the video to write.
        video_codec (str): ffmpeg video encoder. The default ``h264_amf`` is
            AMD's hardware encoder; pass e.g. ``'libx264'`` on other machines.
        overwrite (bool): Replace ``output_video_file`` if it already exists.
            When False and the file exists, ffmpeg exits with an error instead.

    Returns:
        None
    """
    from collections import deque

    subtitle_filter = (
        f"subtitles={_escape_ffmpeg_filter_value(str(subtitle_file))}"
        ":force_style='FontName=Helvetica,FontSize=11'"
    )
    command = [
        'ffmpeg',
        # Never wait on an interactive prompt (e.g. "Overwrite? [y/N]"): it is
        # written to stderr, which is filtered below, so it would look like a hang.
        '-nostdin',
        '-y' if overwrite else '-n',
        '-hwaccel', 'auto',
        '-i', str(video_file),
        '-vf', subtitle_filter,
        '-c:v', video_codec,
        '-c:a', 'copy',
        str(output_video_file),
    ]

    # Keep the last few non-progress messages so a failure can be explained.
    recent_messages = deque(maxlen=20)
    # ffmpeg may emit UTF-8 (file names, metadata) regardless of the locale codec.
    with subprocess.Popen(command, stderr=subprocess.PIPE, universal_newlines=True,
                          encoding='utf-8', errors='replace') as process:
        # Text mode uses universal newlines, so ffmpeg's '\r'-terminated
        # status updates arrive as separate lines.
        for raw_line in process.stderr:
            line = raw_line.strip()
            if 'frame=' in line:
                print(line)
            elif line:
                recent_messages.append(line)
    # Leaving the with-block waits for ffmpeg, so returncode is set here.

    if process.returncode != 0:
        print(f"Error creating video: {process.returncode}")
        for message in recent_messages:
            print(message)


class SRTTranslator:
    """Translate English SRT subtitles to Japanese with an OpenAI chat model.

    Subtitle cues are sent to ``gpt-4o-mini`` in batches and the model is
    asked to answer with a JSON object ``{"translations": [...]}``. The class
    also provides Janome-based helpers that wrap translated text into short
    subtitle lines.
    """

    # Separator between SRT blocks: a blank line. ``\s*`` also absorbs
    # whitespace-only lines and runs of several blank lines.
    _BLOCK_SEPARATOR_RE = re.compile(r'\n\s*\n')
    # Leading "12. " numbering that the model may echo back.
    _NUMBER_PREFIX_RE = re.compile(r'^\d+\.\s*')
    # Sentence-ending punctuation after which a subtitle line is always broken.
    _SENTENCE_END_RE = re.compile(r'[。！？]')
    # Brackets and quotation marks removed from wrapped lines: ASCII and
    # full-width square brackets, corner brackets, ASCII and full-width
    # parentheses, ASCII and full-width braces, double and curly quotes.
    _BRACKETS_RE = re.compile(r'[\[\]［］「」()（）{}｛｝‘’“”"]')

    def __init__(self, api_key: str, batch_size: int = BATCH_SIZE):
        """
        Initialize the SRT translator.

        Args:
            api_key (str): OpenAI API key.
            batch_size (int): Number of subtitle entries to translate at once.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.chat = ChatOpenAI(
            temperature=0,
            openai_api_key=api_key,
            model_name="gpt-4o-mini"
        )
        self.batch_size = batch_size
        # Janome tokenizer, built on first use and then reused for every cue
        # rather than constructing a new one per subtitle.
        self._tokenizer = None

    def parse_srt(self, file_path: str) -> List[Dict[str, str]]:
        """
        Parse an SRT file into a list of ``{'index', 'timestamp', 'text'}`` dicts.

        Blocks are separated by blank lines; whitespace-only lines and runs of
        several blank lines are tolerated. Multi-line cue text is joined with
        single spaces. Blocks with fewer than three lines (no text) are skipped.
        """
        # utf-8-sig drops a leading byte-order mark, which would otherwise end
        # up glued to the first cue's index.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        subtitles = []
        for block in self._BLOCK_SEPARATOR_RE.split(content.strip()):
            lines = [line.strip() for line in block.split('\n')]
            if len(lines) < 3:
                continue
            subtitles.append({
                'index': lines[0],
                'timestamp': lines[1],
                'text': ' '.join(lines[2:]),
            })
        return subtitles

    def create_translation_prompt(self, subtitles: List[Dict[str, str]]) -> str:
        """
        Build the translation prompt for one batch of subtitles.

        Subtitles are numbered from 1, and the model is told exactly how many
        translations to return, because the reply is matched back to the cues
        by position.
        """
        text_block = '\n'.join(
            f"{number}. {sub['text']}" for number, sub in enumerate(subtitles, start=1)
        )
        count = len(subtitles)
        return (
            "Please translate the following English subtitles to Japanese.\n"
            f"There are exactly {count} numbered subtitles. Return exactly {count} "
            "translations in the same order, one per numbered subtitle, without the "
            "leading numbers. Do not merge or split subtitles.\n"
            "Respond with ONLY a valid JSON object in the following format:\n"
            "{\n"
            '    "translations": [\n'
            '        "Japanese translation 1",\n'
            '        "Japanese translation 2",\n'
            "        ...\n"
            "    ]\n"
            "}\n"
            "\n"
            "Subtitles to translate:\n"
            "\n"
            f"{text_block}\n"
        )

    def extract_translations_from_response(self, response_content: str) -> List[str]:
        """
        Extract the list of translations from a model reply.

        Strategies, in order:
        1. the whole reply as JSON: an object with a ``translations`` list, or a bare list;
        2. the outermost ``{...}`` span of the reply (e.g. inside a Markdown fence);
        3. a line-based fallback that skips JSON/Markdown scaffolding and strips
           numbering, trailing commas and surrounding quotes.

        Raises:
            ValueError: If no translations can be extracted.
        """
        candidates = [response_content]
        json_match = re.search(r'\{[\s\S]*\}', response_content)
        if json_match:
            candidates.append(json_match.group(0))

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(data, dict):
                data = data.get("translations")
            # Only a real list is usable; a string or dict would be mis-aligned later.
            if isinstance(data, list):
                return ["" if item is None else str(item) for item in data]

        translations = [
            line
            for line in (self._clean_fallback_line(raw) for raw in response_content.split('\n'))
            if line
        ]
        if translations:
            return translations

        raise ValueError("Could not extract translations from API response")

    @classmethod
    def _clean_fallback_line(cls, raw_line: str) -> str:
        """Turn one line of a non-JSON reply into a translation, or '' for scaffolding."""
        line = raw_line.strip()
        # Markdown fences, the JSON key line, and lines made only of JSON brackets/commas.
        if line.startswith('```') or line.startswith('"translations"') or not line.strip('{}[],'):
            return ''
        line = cls._NUMBER_PREFIX_RE.sub('', line)
        if line.endswith(','):
            line = line[:-1].rstrip()
        if len(line) >= 2 and line[0] == line[-1] == '"':
            try:
                line = json.loads(line)  # also decodes escapes such as \" or \u3042
            except json.JSONDecodeError:
                line = line[1:-1]
        return line.strip()

    def translate_batch(self, subtitles: List[Dict[str, str]]) -> List[str]:
        """
        Send one batch of subtitles to the model and return the parsed translations.

        The returned list may be shorter or longer than ``subtitles`` if the
        model did not follow the prompt.
        """
        prompt = self.create_translation_prompt(subtitles)
        response = self.chat([HumanMessage(content=prompt)])

        return self.extract_translations_from_response(response.content)

    def translate_file(self, input_path: str, output_path: str, max_attempts: int = 2):
        """
        Translate an entire SRT file and save the result to *output_path*.

        Every parsed cue appears in the output with its original index and
        timestamp. A batch that raises, or returns the wrong number of
        translations, is attempted up to *max_attempts* times in total. Any cue
        still lacking a non-blank translation keeps its original text instead
        of being dropped or left empty.
        """
        subtitles = self.parse_srt(input_path)

        translated_subtitles = []
        for start in range(0, len(subtitles), self.batch_size):
            batch = subtitles[start:start + self.batch_size]
            batch_number = start // self.batch_size + 1
            translations = self._translate_batch_with_retries(batch, batch_number, max_attempts)
            for subtitle, translation in zip(batch, translations):
                translated_subtitles.append({
                    'index': subtitle['index'],
                    'timestamp': subtitle['timestamp'],
                    'text': translation,
                })

        with open(output_path, 'w', encoding='utf-8') as f:
            for sub in translated_subtitles:
                f.write(f"{sub['index']}\n")
                f.write(f"{sub['timestamp']}\n")
                f.write(f"{sub['text']}\n\n")

    def _translate_batch_with_retries(self, batch: List[Dict[str, str]],
                                      batch_number: int, max_attempts: int) -> List[str]:
        """Translate *batch* with retries; always returns exactly ``len(batch)`` non-empty strings."""
        attempts = max(1, max_attempts)
        best = None
        for attempt in range(1, attempts + 1):
            try:
                translations = self.translate_batch(batch)
            except Exception as e:  # API, network and parse failures alike: report and retry
                print(f"Error processing batch {batch_number} (attempt {attempt}/{attempts}): {e}")
                continue
            best = translations
            if len(translations) == len(batch):
                break
            print(f"Warning: Got {len(translations)} translations for {len(batch)} "
                  f"subtitles in batch {batch_number} (attempt {attempt}/{attempts})")

        if best is None:
            print(f"Batch {batch_number} could not be translated; keeping the original text.")
            best = []

        result = []
        for position, subtitle in enumerate(batch):
            text = self._as_srt_text(best[position]) if position < len(best) else ""
            # Fall back to the source text so no cue is lost or written empty.
            result.append(text or subtitle['text'])
        return result

    @staticmethod
    def _as_srt_text(text) -> str:
        """Normalise a translation for SRT output: a blank line would end the cue early."""
        return '\n'.join(line.strip() for line in str(text).splitlines() if line.strip())

    def _get_tokenizer(self):
        """Return the shared Janome tokenizer, creating it on first use."""
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = Tokenizer()
            self._tokenizer = tokenizer
        return tokenizer

    def add_line_breaks(self, text, max_line_length=40, max_lines=2):
        """Wrap subtitle text into at most *max_lines* lines.

        The text is tokenized with Janome so that lines break between words.
        A line always ends after 。, ！ or ？; otherwise a new line starts when
        the next token would push the line past *max_line_length* characters.
        Brackets and quotation marks are removed. If more than *max_lines*
        lines result, the overflow is appended to the last permitted line
        rather than dropped, so no text is lost (that line may then exceed
        *max_line_length*).

        Args:
            text (str): Text to wrap.
            max_line_length (int): Preferred maximum number of characters per line.
            max_lines (int): Maximum number of lines.

        Returns:
            str: The wrapped text, newline-separated.
        """
        # Existing line breaks would otherwise become tokens embedded inside
        # a line; flatten them so the text is re-wrapped as a whole.
        flat_text = ' '.join(part.strip() for part in text.splitlines() if part.strip())

        lines = []
        current_line = ""
        for token in self._get_tokenizer().tokenize(flat_text, wakati=True):
            if self._SENTENCE_END_RE.match(token):
                current_line += token
                lines.append(current_line)
                current_line = ""
            elif len(current_line) + len(token) <= max_line_length:
                current_line += token
            else:
                # A single over-long token must not produce an empty line.
                if current_line:
                    lines.append(current_line)
                current_line = token

        if current_line:
            lines.append(current_line)

        lines = [self._BRACKETS_RE.sub('', line) for line in lines]
        lines = [line for line in lines if line.strip()]
        if max_lines > 0 and len(lines) > max_lines:
            # Joining raw pieces keeps any space tokens between Latin words.
            lines = lines[:max_lines - 1] + [''.join(lines[max_lines - 1:])]
        return '\n'.join(line.strip() for line in lines[:max_lines])

    def add_line_breaks_to_srt(self, srt_text):
        """Re-wrap the text of every cue in an SRT document.

        Args:
            srt_text (str): Contents of an SRT file.

        Returns:
            str: The document recomposed with ``srt.compose`` after each cue's
            content has been passed through ``add_line_breaks``.
        """
        subtitles = list(srt.parse(srt_text))
        for subtitle in subtitles:
            subtitle.content = self.add_line_breaks(subtitle.content)
        return srt.compose(subtitles)
    

def main():
    """Translate a user-selected .srt file to Japanese and burn it into its video.

    Steps:
    1. Let the user pick an ``.srt`` file in the current directory.
    2. Translate it into ``output_translated.srt`` and re-wrap its lines.
    3. If ``<same name>.mp4`` exists next to the subtitle file, render
       ``<same name>_jp.mp4`` with the translated subtitles burned in.
    """
    if not OPENAI_API_KEY:
        print("OPENAI_API_KEY is not set. Export it before running this script.")
        return

    translator = SRTTranslator(
        api_key=OPENAI_API_KEY,
        batch_size=BATCH_SIZE
    )

    input_path = select_srt_file()
    if input_path is None:
        return

    output_path = "output_translated.srt"
    translator.translate_file(input_path, output_path)

    # Add line breaks to SRT file
    with open(output_path, 'r', encoding='utf-8') as f:
        srt_text = f.read()
    srt_text_with_line_breaks = translator.add_line_breaks_to_srt(srt_text)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(srt_text_with_line_breaks)

    # Derive names from the extension only; str.replace would also rewrite
    # ".srt"/".mp4" appearing elsewhere in the file name.
    stem, _ = os.path.splitext(input_path)
    video_file = stem + ".mp4"
    output_video_file = stem + "_jp.mp4"
    if not os.path.isfile(video_file):
        print(f"Video file not found: {video_file}. Skipping subtitle burn-in.")
        return

    # Create video with subtitles
    create_video_with_subtitles(video_file, output_path, output_video_file)


if __name__ == "__main__":
    main()