In [1]:
import datasets
from datasets import load_dataset, load_from_disk
from in_context_ssl.reasoning.template import *
import os
import openai
from openai import OpenAI
from tqdm import tqdm
import numpy as np
from pydantic import BaseModel, Field
import json
from in_context_ssl.reasoning.utils import *
from in_context_ssl.reasoning.dataset import *
import re
import pandas as pd
from in_context_ssl.reasoning.utils import *
import torchmetrics
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The OpenAI client takes its credentials from the OPENAI_API_KEY environment
# variable; export it before starting the kernel instead of pasting a key here.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

## Preprocessing

In [29]:
def create_translation_dataset(target_lang, stage,
                               out_dir="in_context_ssl/reasoning/data",
                               embed_batch_size=100):
    """Build an English -> ``target_lang`` FLORES+ dataset with question embeddings.

    English sentences become ``question``, the ``target_lang`` sentence with the
    same FLORES+ ``id`` becomes ``answer``, and the English ``topic`` becomes
    ``group``. Each question is embedded with ``text-embedding-3-large`` through
    the notebook-global OpenAI ``client``. The result is saved to
    ``{out_dir}/flores_{target_lang}_{stage}.hf``.

    Args:
        target_lang: ``iso_639_3`` code of the target language (e.g. ``"bem"``).
        stage: ``"train"`` reads the FLORES+ ``dev`` split; any other value reads
            ``devtest``. It is also used in the output file name.
        out_dir: Directory the dataset is saved under.
        embed_batch_size: Number of questions sent per embeddings request.

    Returns:
        The saved ``datasets.Dataset`` with question/answer/group/embedding columns.

    Raises:
        ValueError: If no ``target_lang`` rows pair with English rows (e.g. a
            mistyped language code), instead of silently saving an empty dataset.
    """
    split = "dev" if stage == "train" else "devtest"
    flores = load_dataset("openlanguagedata/flores_plus")[split].to_pandas()

    english = flores[flores["iso_639_3"] == "eng"]
    target = flores[flores["iso_639_3"] == target_lang]
    # Both frames share column names, so pandas suffixes them _x (English) / _y (target).
    pairs = (
        pd.merge(english, target, on="id", how="inner")[["text_x", "text_y", "topic_x"]]
        .rename(columns={"text_x": "question", "text_y": "answer", "topic_x": "group"})
    )
    if pairs.empty:
        raise ValueError(
            "No FLORES+ {} rows for iso_639_3={!r} pair with English".format(split, target_lang)
        )

    def add_embeddings(batch):
        # One request per batch instead of one per row. Each returned item carries
        # the position of its input, so sort by it before zipping back.
        response = client.embeddings.create(
            input=list(batch["question"]),
            model="text-embedding-3-large",
        )
        ordered = sorted(response.data, key=lambda item: item.index)
        return {"embedding": [item.embedding for item in ordered]}

    ds = datasets.Dataset.from_pandas(pairs)
    ds = ds.map(add_embeddings, batched=True, batch_size=embed_batch_size)
    ds.save_to_disk("{}/flores_{}_{}.hf".format(out_dir, target_lang, stage))
    return ds


In [33]:
# Quick look at the Bemba rows of the FLORES+ devtest split.
# Note: despite its name, `df_source` holds target-side (Bemba) sentences.
ds = load_dataset("openlanguagedata/flores_plus")["devtest"]
df = ds.to_pandas()
df_source = df.loc[df["iso_639_3"].eq("bem")]

In [37]:
# Draw a fixed-size, reproducible evaluation subset from each language's test split.
# The previous version referenced an undefined `lang` (NameError on a fresh kernel)
# and shuffled without a seed, so the chosen 200 sentences were not reproducible.
target_langs = ["bem"]
n_eval = 200
for lang in target_langs:
    ds = load_from_disk("in_context_ssl/reasoning/data/flores_{}_test.hf".format(lang))
    ds = ds.shuffle(seed=42)  # same seed as the demonstration sampling elsewhere
    ds = ds.select(range(min(n_eval, len(ds))))  # tolerate splits smaller than n_eval
    ds.save_to_disk("in_context_ssl/reasoning/data/flores_{}_test_new.hf".format(lang))

Saving the dataset (1/1 shards): 100%|██████████| 200/200 [00:00<00:00, 3683.71 examples/s]


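## Inference

Evaluate gpt-4o-mini on the Bemba translation dataset, prompting with ground-truth plus pseudo-labelled demonstrations, and report the mean sentence-level chrF. The pseudo-labelled demonstration file is produced by the Naive-SemiICL section below, so run that section first on a fresh kernel.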

In [None]:
# Demonstration budget: k_gt presumably counts ground-truth examples and
# k_total - k_gt the pseudo-labelled ("psl") ones — TODO confirm in get_demonstrations.
k_total = 100
k_gt = 16

ds = TranslationDatasetBem()
# NOTE(review): this file is written by the Naive-SemiICL cell further down, so
# run that cell first on a fresh kernel. The "fij" in its name does not match the
# Bemba dataset used here. The name is kept so both cells stay in sync; rename
# them together.
# Only the return value is printed here; whether the call also configures the
# demonstrations used by the queries below is not visible — confirm.
print(ds.get_demonstrations(
    "in_context_ssl/reasoning/data/flores_fij_psl_k={}_entropy.hf".format(k_gt),
    k=k_total - k_gt, k_gt=k_gt,
    style="psl", answer=True, rationale=False, quantile=0.9, topk=False, seed=42,
))

preds = []
gold = []
messages = []  # raw completions, kept for manual inspection
for inst in tqdm(ds):
    choice = query_openai(
        client, inst["query"], model="gpt-4o-mini", n=1,
        structured_output=False, confidence=False, logprobs=True,
    )[0]
    o = parse_output_translation("Bemba", choice.message.content)
    messages.append(choice.message.content)
    preds.append(o["answer"])
    gold.append(inst["answer"])

# With return_sentence_level_score=True, compute() returns
# (corpus score, per-sentence scores); report the mean per-sentence chrF.
chrf = torchmetrics.CHRFScore(return_sentence_level_score=True)
chrf.update(preds, gold)
score = chrf.compute()[1].mean()
score

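## Naive-SemiICL

Pseudo-label the training questions with gpt-4o-mini using k ground-truth demonstrations. Save two copies of the pseudo-labelled set: one scored with the model's verbalized confidence and one with entropy-based confidence.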

In [None]:
# Pseudo-label the training questions with gpt-4o-mini and store two copies of the
# result, differing only in how "confidence" is computed (verbalized vs. entropy).
# The inference cell's preds/gold/messages are deliberately not reset here, so its
# results survive running this cell.
ds = TranslationDatasetBem()


def parse_bemba(text):
    """Parse one model completion with the Bemba translation parser."""
    return parse_output_translation("Bemba", text)


k = 16
new_ds_verbalized = []
new_ds_entropy = []
# Ground-truth demonstration pool. It must be the Bemba split, to match
# TranslationDatasetBem and the "Bemba" parser; it can be built with
# create_translation_dataset("bem", "train").
for inst in tqdm(ds.train_iter(
    "in_context_ssl/reasoning/data/flores_bem_train.hf",
    k=k, answer=True, rationale=False, seed=42
)):
    choices = query_openai(
        client, inst["query"], n=1, model="gpt-4o-mini",
        structured_output=False, confidence=True, logprobs=True,
    )
    o_verbalized = aggregate(choices, parser=parse_bemba, confidence="verbalized", rationale=False)
    o_entropy = aggregate(choices, parser=parse_bemba, confidence="entropy", rationale=False)

    d_verbalized = {
        "question": inst["question"],
        "answer": o_verbalized["answer"],
        "group": inst["group"],
        "confidence": o_verbalized["confidence"],
    }
    new_ds_verbalized.append(d_verbalized)
    # Same record; only the confidence score differs.
    new_ds_entropy.append({**d_verbalized, "confidence": o_entropy["confidence"]})

# NOTE(review): the output names keep the historical "fij" prefix because the
# Inference cell reads them under that name. Rename both cells together.
for method, records in (("verbalized", new_ds_verbalized), ("entropy", new_ds_entropy)):
    datasets.Dataset.from_pandas(pd.DataFrame(records)).save_to_disk(
        "in_context_ssl/reasoning/data/flores_fij_psl_k={}_{}.hf".format(k, method)
    )