# Prompt tuning for translating English > Mambai

Translate an English sentence to Mambai in three steps:

1. Find the closest training sentences using TF-IDF and/or LASER embeddings
2. Find dictionary entries for words in sentence
3. Construct prompt, with a mix of example sentences and dict entries

TODO:

- Clean up Mambai corpus
  - Some dict entries are missing because extraction relies on font weight, which is not always OCRed correctly
    - others need to be split into separate entries (e.g. "sit; live")
  - Some sentences poorly aligned
- Get similar sentences based on syntactic similarity, instead of `get_sentences_starting_with_same_words`


In [None]:
import dotenv

# Load settings from a local .env file (if present) into the environment;
# OPENAI_API_KEY is read from os.environ when the OpenAI client is created.
dotenv.load_dotenv()

### Get Mambai corpus, split between sentences and dict entries


In [None]:
import csv
import json
import random

# Read the whole parallel corpus into memory as a list of dict rows.
# newline="" is what the csv module expects for file objects (quoted fields may
# contain line breaks), and the encoding is explicit so the text decodes the
# same way regardless of the platform's default locale.
with open("mambai_parallel_eng_mgm.csv", newline="", encoding="utf-8") as f:
    reader = csv.DictReader(f)
    data = list(reader)

print(f"Total of {len(data)} rows in the dataset.")

# Rows carry a "split" column; only the "train" rows are used for retrieval.
train_data = [r for r in data if r["split"] == "train"]

print(f"Total of {len(train_data)} rows in the training set.")

In [None]:
import json

# Dictionary entries; the code below reads the "entry" and "definition" keys of
# each item. JSON text is UTF-8, so decode it explicitly instead of relying on
# the platform's default encoding.
with open("eng_mgm.json", encoding="utf-8") as f:
    dict_entries = json.load(f)

In [None]:
# experiment tracking in https://docs.google.com/spreadsheets/d/1wP0tDiPqmS8UWNiyY4oSAZn3Mzw2Z4FTZZOQ5lf-NHQ/edit#gid=0

# Earlier configuration, kept for reference:
# config = {
#     "model": "gpt-4",
#     "train_rows": len(train_data),
#     "retrieval_sentences": {"tfidf": 5},
#     "retrieval_dict": True,
#     "bleu": 12.6,
#     "chrf": 32.4,
# }

# Active configuration. get_prompt() reads "retrieval_sentences" (retriever
# name -> number of example sentences) and "retrieval_dict" (include dictionary
# entries or not).
# NOTE(review): in the visible code "model" is not passed to the API call,
# which hard-codes "gpt-4-turbo-preview" - confirm which one is intended.
# NOTE(review): "bleu"/"chrf" look like scores recorded for the tracking
# sheet rather than inputs - confirm they belong to this configuration.
config = {
    "model": "gpt-4",
    "train_rows": len(train_data),
    "retrieval_sentences": {"tfidf": 10},
    "retrieval_dict": False,
    "bleu": 12.6,
    "chrf": 32.4,
}

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel


def find_top_k_tfidf(sentence, eng_sentences, top_k=5):
    """Return the training rows most TF-IDF-similar to ``sentence``.

    Args:
        sentence: English query sentence.
        eng_sentences: English sentences to rank against the query. Results are
            looked up by position in the module-level ``train_data``, so this
            list must be index-aligned with ``train_data``.
        top_k: Maximum number of rows to return.

    Returns:
        Up to ``top_k`` rows from ``train_data``, most similar first. An empty
        list when ``top_k`` is not positive or ``eng_sentences`` is empty.
    """
    # Without this guard, top_k == 0 would slice with [-0:], which selects
    # every candidate instead of none.
    if top_k <= 0 or not eng_sentences:
        return []

    # The query goes first so that row 0 of the matrix is its vector.
    documents = [sentence] + eng_sentences
    tfidf_matrix = TfidfVectorizer().fit_transform(documents)

    # TfidfVectorizer L2-normalises rows by default, so the linear kernel
    # (dot product) equals cosine similarity here.
    similarities = linear_kernel(tfidf_matrix[0:1], tfidf_matrix).flatten()

    # Dropping position 0 (the query against itself) leaves scores whose
    # positions line up with eng_sentences; take the k largest, best first.
    ranked = similarities[1:].argsort()[-top_k:][::-1]
    return [train_data[index] for index in ranked]


# Example usage:
# English side of the training rows, index-aligned with train_data.
# get_prompt() also reads this module-level list.
eng_sentences = [row["English (eng)"] for row in train_data]
input_sentence = "We will be sitting there having coffee"

# Number of TF-IDF neighbours comes from the active config.
num_to_retrieve = config["retrieval_sentences"]["tfidf"]
top_tfidf = find_top_k_tfidf(input_sentence, eng_sentences, num_to_retrieve)

print(f"Top {num_to_retrieve} similar sentences for '{input_sentence}'")
for sentence in top_tfidf:
    print(sentence["English (eng)"])

### Get LASER encoder, encode English sentences from Mambai corpus


In [None]:
from laser_encoders import LaserEncoderPipeline

# LASER sentence encoder configured for English ("eng_Latn").
encoder = LaserEncoderPipeline(lang="eng_Latn")

# Embed the English side of every training row once, in train_data order;
# find_top_k_semantic_laser compares query embeddings against these.
embeddings = encoder.encode_sentences([row["English (eng)"] for row in train_data])

### Construct prompt


In [None]:
from sklearn.metrics.pairwise import cosine_similarity

import spacy

# spaCy English pipeline; get_relevant_dict_entries uses it to lemmatise input.
nlp = spacy.load("en_core_web_sm")


def find_top_k_semantic_laser(input, top_k=5):
    """Return the training rows whose LASER embedding is closest to ``input``.

    Args:
        input: English query sentence.
        top_k: Maximum number of rows to return.

    Returns:
        Up to ``top_k`` rows from ``train_data``, most similar first. An empty
        list when ``top_k`` is not positive.
    """
    # Without this guard, top_k == 0 would slice with [-0:], which selects
    # every row instead of none.
    if top_k <= 0:
        return []

    embedded_input = encoder.encode_sentences([input])
    # Row 0 holds the similarities of the single query against every
    # precomputed training embedding (same order as train_data).
    similarities = cosine_similarity(embedded_input, embeddings)[0]
    closest_indices = similarities.argsort()[-top_k:][::-1]
    return [train_data[i] for i in closest_indices]


def get_sentences_starting_with_same_words(input):
    """Yield training rows whose English text starts with the input's first two words.

    The first two whitespace-separated words of ``input`` are joined with a
    single space and used as a plain string prefix (case-sensitive).

    Args:
        input: English query sentence.

    Yields:
        Rows of ``train_data`` whose "English (eng)" value starts with that
        prefix. Nothing is yielded for empty or whitespace-only input.
    """
    input_words = input.split()
    # An empty prefix would match every row, so treat "no words" as "no match".
    if not input_words:
        return

    first_two_words = " ".join(input_words[:2])
    for row in train_data:
        if row["English (eng)"].startswith(first_two_words):
            yield row


def get_relevant_dict_entries(sent):
    """Yield dictionary rows whose "entry" equals the lemma of a token in ``sent``.

    The sentence is run through the spaCy pipeline ``nlp``; for each token, in
    order, every row of the module-level ``dict_entries`` whose "entry" matches
    the token's lemma exactly is yielded. A lemma that occurs several times in
    the sentence yields its rows once per occurrence.
    """
    token_lemmas = [tok.lemma_ for tok in nlp(sent)]
    for token_lemma in token_lemmas:
        yield from (
            candidate
            for candidate in dict_entries
            if candidate["entry"] == token_lemma
        )

In [None]:
# Prompt skeleton. {dict_section} is either empty or a complete
# "# Dictionary entries" section ending in a blank line, which is why it sits
# directly in front of the final "English:" line.
prompt_template = """You are a translator for the Mambai language, originally from Timor-Leste.

# Example sentences
{sentences_str}

{dict_section}English: {input}
Mambai:"""


def format_prompt(sentences_str, dict_str, input):
    """Fill ``prompt_template`` with example sentences, dictionary text and the query.

    Args:
        sentences_str: Pre-formatted example sentence pairs.
        dict_str: Pre-formatted dictionary entries; when falsy, the dictionary
            section (heading included) is left out entirely.
        input: English sentence to translate.

    Returns:
        The complete prompt, ending in "Mambai:" for the model to continue.
    """
    if dict_str:
        dict_section = "# Dictionary entries\n{}\n\n".format(dict_str)
    else:
        dict_section = ""

    fields = {
        "sentences_str": sentences_str,
        "dict_section": dict_section,
        "input": input,
    }
    return prompt_template.format(**fields)


def get_sentences_str(rows):
    """Format parallel-corpus rows as English/Mambai example pairs for the prompt.

    Args:
        rows: Iterable of mappings with "English (eng)" and "Mambai (mgm)" keys.

    Returns:
        One "English: ...\\nMambai: ...\\n\\n" chunk per row, concatenated in
        order; the empty string when ``rows`` is empty.
    """
    # Join once instead of growing the string with += inside the loop.
    return "".join(
        f"English: {row['English (eng)']}\nMambai: {row['Mambai (mgm)']}\n\n"
        for row in rows
    )


def get_dict_str(dict_entries):
    """Format dictionary rows as English/Mambai pairs for the prompt.

    Args:
        dict_entries: Iterable of mappings with "entry" and "definition" keys;
            "entry" is printed on the "English:" line and "definition" on the
            "Mambai:" line.

    Returns:
        One "English: ...\\nMambai: ...\\n\\n" chunk per row, concatenated in
        order; the empty string when there are no rows.
    """
    # Join once instead of growing the string with += inside the loop.
    return "".join(
        f"English: {row['entry']}\nMambai: {row['definition']}\n\n"
        for row in dict_entries
    )


def get_prompt(input):
    """Build the translation prompt for ``input`` according to ``config``.

    Example sentences are gathered from every retriever named in
    ``config["retrieval_sentences"]`` ("tfidf" and/or "semantic_laser", each
    mapped to the number of sentences to fetch). Dictionary entries are added
    only when ``config["retrieval_dict"]`` is truthy.

    Args:
        input: English sentence to translate.

    Returns:
        The formatted prompt string.
    """
    retrieval = config["retrieval_sentences"]

    candidates = []
    if "tfidf" in retrieval:
        candidates.extend(find_top_k_tfidf(input, eng_sentences, retrieval["tfidf"]))
    if "semantic_laser" in retrieval:
        candidates.extend(find_top_k_semantic_laser(input, retrieval["semantic_laser"]))

    # Both retrievers hand back row objects from train_data, so a row found by
    # both would otherwise appear twice in the prompt. Rows are dicts
    # (unhashable), hence the de-duplication by object identity; first
    # occurrence wins so the original ordering is kept.
    seen_ids = set()
    sentences = []
    for row in candidates:
        if id(row) not in seen_ids:
            seen_ids.add(id(row))
            sentences.append(row)

    # Named differently from the module-level ``dict_entries`` to avoid
    # shadowing the full dictionary inside this function.
    if config["retrieval_dict"]:
        relevant_entries = list(get_relevant_dict_entries(input))
    else:
        relevant_entries = []

    return format_prompt(
        sentences_str=get_sentences_str(sentences),
        dict_str=get_dict_str(relevant_entries),
        input=input,
    )

In [None]:
# Show the full prompt produced for a sample sentence under the active config.
print(get_prompt("We will be sitting there having coffee"))

In [None]:
import os
from openai import AsyncOpenAI

# Async OpenAI client shared by all translation requests below.
client = AsyncOpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)


async def get_gpt4_translation(sentence):
    """Ask the chat model to translate ``sentence`` into Mambai.

    The prompt built by ``get_prompt`` is sent as a single user message.

    Args:
        sentence: English sentence to translate.

    Returns:
        The model's reply with any leading "Mambai: " label removed and
        surrounding whitespace stripped; the empty string if the reply carries
        no text content.
    """
    # NOTE(review): config["model"] says "gpt-4" while this call uses
    # "gpt-4-turbo-preview" - confirm which one the experiments should record.
    chat_completion = await client.chat.completions.create(
        messages=[{"role": "user", "content": get_prompt(sentence)}],
        model="gpt-4-turbo-preview",
    )
    # The API may return a message whose content is None; fall back to ""
    # so the string handling below cannot fail.
    translation = chat_completion.choices[0].message.content or ""
    # If the model echoed the "Mambai: " label, keep only what follows it.
    translation = translation.split("Mambai: ")[-1]
    # Stray leading/trailing whitespace would otherwise leak into the metrics.
    return translation.strip()

In [None]:
# Held-out rows (split == "dev") used to evaluate the translations.
dev_data = [r for r in data if r["split"] == "dev"]
print(f"Total of {len(dev_data)} rows in the validation set.")

In [None]:
import asyncio


async def process_batch(sentences):
    """Translate all ``sentences`` concurrently.

    Returns the translations as a list in the same order as ``sentences``.
    """
    pending = (get_gpt4_translation(text) for text in sentences)
    return await asyncio.gather(*pending)


async def translate_data(dev_data, batch_size=10):
    """Translate every row of ``dev_data`` in place, ``batch_size`` rows at a time.

    Each batch is translated concurrently via ``process_batch``; batches run
    one after another.

    Args:
        dev_data: List of rows with an "English (eng)" key. Each row gets its
            translation stored under "mgm_translation".
        batch_size: Number of rows sent concurrently per batch.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    total = len(dev_data)
    for start in range(0, total, batch_size):
        # Clamp the end so the final, possibly shorter, batch is reported
        # with its real upper bound.
        end = min(start + batch_size, total)
        print(f"Processing batch {start + 1} to {end}")
        batch = dev_data[start:end]
        translations = await process_batch([row["English (eng)"] for row in batch])
        # for each row, add the translation under key 'mgm_translation'
        for row, translation in zip(batch, translations):
            row["mgm_translation"] = translation

In [None]:
# Function to run asyncio tasks from Jupyter Notebook
async def run(coroutine):
    """Await ``coroutine`` and return its result.

    This is itself a coroutine, so its body only ever executes inside an
    already-running event loop; no loop lookup or creation is needed here.
    In a notebook it is driven by a top-level ``await``.
    """
    return await coroutine


# Top-level await (supported by the notebook kernel): fills in
# row["mgm_translation"] for every dev row.
await run(translate_data(dev_data))

In [27]:
import evaluate

# Metric implementations loaded through the `evaluate` library.
chrf = evaluate.load("chrf")
bleu = evaluate.load("bleu")

# One model translation per dev row; each row contributes exactly one
# reference, wrapped in its own list.
predictions = [row["mgm_translation"] for row in dev_data]
references = [[row["Mambai (mgm)"]] for row in dev_data]

# Calculate metrics
bleu_results = bleu.compute(predictions=predictions, references=references)
chrf_results = chrf.compute(predictions=predictions, references=references)
# Same chrF metric with word_order=2, reported below as chrF++.
chrfpp_results = chrf.compute(
    predictions=predictions, references=references, word_order=2
)

# NOTE(review): the recorded output prints BLEU as a fraction (0.1355...) while
# chrF is on a 0-100 scale and config stores "bleu": 12.6 - presumably a
# percentage; confirm before copying numbers into the tracking sheet.
print(f"BLEU score: {bleu_results['bleu']}")
print(f"ChrF score: {chrf_results['score']}")
print(f"Chrf++ score: {chrfpp_results['score']}")

BLEU score: 0.13551247392372268
ChrF score: 38.21808456442234
Chrf++ score: 37.240239918245415
