In [None]:
import random
import csv
import json
from typing import List, Literal, Optional
from tqdm import tqdm
import spacy
import os
import random

from gemini_aistudio import GenerativeModel

from dotenv import load_dotenv
load_dotenv()

from sacrebleu import corpus_chrf
from dataclasses import dataclass

from translator_utils import Translator, Glossary, Line, LineManager, Message

# --- Experiment configuration -------------------------------------------------
# MODE selects the pipeline: "translate" asks Gemini for a translation directly,
# "post-edit" asks it to improve an existing machine translation.
MODE: Literal["translate", "post-edit"] = "translate"
# Only consulted in "translate" mode: True sends no example pairs, False sends
# few-shot examples and uses the matching system instruction.
ZERO_SHOT: bool = True

if MODE == "post-edit":
    gemini = GenerativeModel(system_instruction="""You are an expert translator. I am going to give you relevant glossary entries, and relevant past translations, where the first is the English source, the second is a machine translation of the English to Bislama, and the third is the Bislama reference translation. The sentences will be written English: <sentence> MT: <machine translated sentence> Bislama: <translated sentence>. After the example pairs, I am going to provide another sentence in English and its machine translation, and I want you to translate it into Bislama. Give only the translation, and no extra commentary, formatting, or chattiness. Translate the text from English to Bislama.""")
elif MODE == "translate":
    if ZERO_SHOT:
        gemini = GenerativeModel(system_instruction="You are an expert translator. I am going to give you text in English, and would like you to translate it to Bislama. Give only the translation, and no extra commentary, formatting, or chattiness.")
    else:
        gemini = GenerativeModel(system_instruction="""You are an expert translator. I am going to give you some example pairs of text snippets where the first is in English and the second is a translation of the first snippet into Bislama. The sentences will be written English: <first sentence> Bislama: <translated first sentence> After the example pairs, I am going to provide another sentence in English and I want you to translate it into Bislama. Give only the translation, and no extra commentary, formatting, or chattiness. Translate the text from English to Bislama.""")
else:
    # Fail here rather than with a NameError on `gemini` several cells later.
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'translate' or 'post-edit'")

# Which machine-translation system's stored predictions to use as the MT baseline.
TRANSLATE_WITH: Literal["google", "madlad", "opusmt"] = "madlad"

# Attribute name on each Line that holds the selected system's prediction.
pred_key = f'tgt_pred_{TRANSLATE_WITH}'


In [None]:
# Create the glossary backed by the Bislama school dictionary JSON file and call
# its loader. `glossary.get_entries(...)` is used later when building post-edit
# prompts, so this cell must run before any post-editing request.
glossary = Glossary(file="datafiles/bislama_school_dictionary.json")
glossary.load_entries()

In [None]:
# select lines to translate

# newline='' is what the csv module expects for files it reads (it keeps line
# breaks inside quoted fields intact), and the explicit encoding makes the read
# independent of the platform's default locale.
with open('datafiles/bislama_parallel.csv', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    lines: List[Line] = [
        Line(
            en=row['en'],
            tgt=row['tgt'],
            tgt_pred_madlad=row['tgt_pred_madlad'],
            tgt_pred_opusmt=row['tgt_pred_opusmt'],
        )
        for row in reader
    ]

print(f"Total of {len(lines)} lines")

# Keep only rows that have both an English source and a reference translation.
lines = [line for line in lines if line.en and line.tgt]
print(f"Total of {len(lines)} lines with en and tgt")


# Baseline: corpus-level chrF of the stored MT predictions over the whole corpus.
mt_chrf = corpus_chrf(
    [getattr(line, pred_key) for line in lines],
    [[line.tgt for line in lines]],
    word_order=2,
)
print(f"CHRF for MT: {mt_chrf.score:.2f}")

random.seed(42)
# The split is positional (first 80% train, remaining 20% test), not shuffled;
# compute the boundary once so both slices are guaranteed to agree.
split_index = int(len(lines) * 0.8)
train_lines = lines[:split_index]
print(f"{len(train_lines)} lines for training")
test_lines = lines[split_index:]
print(f"{len(test_lines)} lines for testing")


In [None]:
# Build the translator helper and index the training lines for BM25 retrieval
# (used to fetch similar sentences for post-edit prompts).
translator = Translator(translate_with=TRANSLATE_WITH)
translator.init_bm25(train_lines)

# Fixed few-shot examples for non-zero-shot translation. Re-seed right before
# sampling so the selection does not depend on earlier use of `random`, and
# clamp the size so a corpus with fewer than 10 training lines does not raise.
random.seed(42)
train_sample = random.sample(train_lines, min(10, len(train_lines)))

def format_messages_for_gemini(messages: List[Message]) -> List[dict]:
    """Convert prompt messages into the role/parts dictionaries sent to Gemini.

    Messages whose role is 'system' are dropped. A 'user' role is kept as
    "user"; every other remaining role is mapped to "model". Each message's
    content becomes the single element of its "parts" list.
    """
    formatted: List[dict] = []
    for message in messages:
        if message.role == 'system':
            continue
        gemini_role = "user" if message.role == "user" else "model"
        formatted.append({"role": gemini_role, "parts": [message.content]})
    return formatted

def get_post_edited_translation_gemini(input_text: str) -> str:
    """Ask Gemini for a post-edited translation of `input_text`.

    Looks up glossary entries for the text and the 10 most similar sentences
    via BM25, has the translator build a post-edit prompt from them, converts
    the prompt to Gemini's message format, and returns the model's response
    after calling `.strip()` on it.
    """
    # Supporting context for the prompt: glossary hits first, then BM25 neighbours.
    matched_entries = glossary.get_entries(input_text)
    neighbours = translator.get_top_similar_sentences_bm25(input_text, top_n=10)

    prompt = translator.construct_prompt_post_edit(
        input_text,
        neighbours,
        glossary_entries=matched_entries,
    )

    reply = gemini.generate_content(format_messages_for_gemini(prompt))
    return reply.strip()

def get_final_translation_gemini(input_text: str) -> str:
    """Ask Gemini for a direct translation of `input_text`.

    When ZERO_SHOT is true the prompt is built with no examples; otherwise the
    fixed `train_sample` lines are passed as examples. The prompt is converted
    to Gemini's message format and the model's response is returned after
    calling `.strip()` on it.
    """
    few_shot_examples = [] if ZERO_SHOT else train_sample
    prompt = translator.construct_prompt_translation(input_text, few_shot_examples)

    reply = gemini.generate_content(format_messages_for_gemini(prompt))
    return reply.strip()

In [None]:
# Smoke-test one held-out sentence before launching the full run. Use the prompt
# builder that matches MODE so the request agrees with the system instruction
# `gemini` was created with (the batch cell dispatches on MODE the same way).
smoke_test_fn = (
    get_post_edited_translation_gemini if MODE == "post-edit" else get_final_translation_gemini
)
print(f'src: {test_lines[1].en}')
print(f'pred: {smoke_test_fn(test_lines[1].en)}')

In [None]:
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Set to a positive int to process only the first N test lines; None (or 0)
# processes every test line.
MAX_LINES = None


if MODE == 'post-edit':
    def process_line(line):
        """Fill in `tgt_pred_post_edited` on `line` unless it is already set; return the line."""
        if not getattr(line, 'tgt_pred_post_edited', None):
            line.tgt_pred_post_edited = get_post_edited_translation_gemini(line.en)
        return line
elif MODE == 'translate':
    def process_line(line):
        """Fill in `tgt_pred_gemini` on `line` unless it is already set; return the line."""
        if not getattr(line, 'tgt_pred_gemini', None):
            line.tgt_pred_gemini = get_final_translation_gemini(line.en)
        return line


lines_to_process = test_lines[:MAX_LINES] if MAX_LINES else test_lines

with ThreadPoolExecutor(max_workers=10) as executor:
    futures = [executor.submit(process_line, line) for line in lines_to_process]

    # total must be the number of submitted jobs, not len(test_lines); otherwise
    # the progress bar never reaches 100% when MAX_LINES caps the run.
    for future in tqdm(as_completed(futures), total=len(futures)):
        try:
            future.result()
        except Exception as e:
            # Best-effort: report the failure and keep going. Because
            # process_line skips lines that already have a prediction,
            # re-running this cell retries only the failed ones.
            print(f"Error processing line: {e}")

In [None]:
def chrf_for_key(key: str):
    """Return the corpus-level chrF score of the predictions stored under `key`.

    Args:
        key: Name of the attribute on each test line that holds the hypothesis
            (for example 'tgt_pred_gemini').

    Returns:
        The `.score` of sacreBLEU's `corpus_chrf` computed against each line's
        `tgt` reference, with word_order=2 (the chrF++ variant).

    Test lines that have no value for `key` (not processed because of
    MAX_LINES, or their request failed) are left out of the score instead of
    breaking it; the number skipped is printed so a partial score is never
    mistaken for a full one. Empty-string predictions are still scored.
    """
    scored_lines = [line for line in test_lines if getattr(line, key, None) is not None]
    skipped = len(test_lines) - len(scored_lines)
    if skipped:
        print(f"Skipping {skipped} of {len(test_lines)} test lines without '{key}'")

    chrf = corpus_chrf(
        [getattr(line, key) for line in scored_lines],
        [[line.tgt for line in scored_lines]],
        word_order=2,
    )
    return chrf.score

# Report corpus-level chrF for whichever pipeline MODE selected.
if MODE == 'translate':
    gemini_chrf = chrf_for_key('tgt_pred_gemini')
    print(f"CHRF for Gemini: {gemini_chrf:.2f}")
elif MODE == 'post-edit':
    # Raw MT baseline on the test split first ...
    mt_chrf = chrf_for_key(pred_key)
    print(f"CHRF for MT: {mt_chrf:.2f}")

    # ... then the automatically post-edited output, for comparison.
    ape_chrf = chrf_for_key('tgt_pred_post_edited')
    print(f"CHRF for APE: {ape_chrf:.2f}")

In [None]:

# Spot-check a few random test lines against their references. Read the
# attribute that the current MODE actually fills in; translate mode never sets
# `tgt_pred_post_edited`.
if MODE == 'post-edit':
    sample_key, sample_label = 'tgt_pred_post_edited', 'APE'
else:
    sample_key, sample_label = 'tgt_pred_gemini', 'Gemini'

# `strip_response_tags` is not defined or imported anywhere in this notebook.
# Use it when it exists in the kernel namespace, otherwise show the prediction
# unchanged, so the cell survives Restart & Run All.
# NOTE(review): either import the helper explicitly or drop it for good.
_strip_tags = globals().get('strip_response_tags', lambda text: text)

for line in random.sample(test_lines, min(5, len(test_lines))):
    prediction = getattr(line, sample_key, None)
    print("English:", line.en)
    print(f"{sample_label}:", _strip_tags(prediction) if prediction else "<no prediction>")
    print("Reference:", line.tgt)
    print()


In [None]:

# Persist the test-set predictions. The Gemini column is whichever attribute the
# current MODE fills in; translate mode never sets `tgt_pred_post_edited`.
output_key = 'tgt_pred_post_edited' if MODE == 'post-edit' else 'tgt_pred_gemini'

# newline='' stops the csv module's own line endings from being translated
# again (which yields blank rows on Windows); the explicit encoding keeps the
# output independent of the platform default.
with open('datafiles/bislama_parallel_gemini.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=['en', 'tgt', 'tgt_pred_madlad', output_key])
    writer.writeheader()
    for line in test_lines:
        writer.writerow({
            'en': line.en,
            'tgt': line.tgt,
            'tgt_pred_madlad': line.tgt_pred_madlad,
            # Lines whose request failed or was skipped have no prediction;
            # csv writes None as an empty cell.
            output_key: getattr(line, output_key, None),
        })