In [None]:
import json
import os
import sys
import time
import copy
import numpy as np
import pandas as pd

In [None]:
# Load the CANDLE dataset (one JSON object per line) and keep only the "countries" domain.
# Explicit UTF-8 so the load does not depend on the platform's default encoding; blank lines
# (e.g. a trailing newline at EOF) are skipped instead of crashing json.loads.
with open("../data/candle/candle_dataset_v1.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]
data = [d for d in data if d["domain"] == "countries"]
for d in data:
    # raw_sentences is not used by any later cell; tolerate records that lack it.
    d.pop("raw_sentences", None)
print(len(data))

In [None]:
from collections import Counter
# How many assertions each subject has: the ten most frequent, then the two smaller countries.
counter = Counter(record["subject"] for record in data)
print(counter.most_common(10))
for name in ("Iran", "Kenya"):
    print(counter[name])

In [None]:
import copy
import re

# Surface forms (names, abbreviations, demonyms) of each country that get masked out of an
# assertion. Longer forms are listed before their shorter sub-forms (e.g. "Americans" before
# "American" before "America") so the longest form is masked first.
country_mapping = {
    "United States": ["United States of America", "United States", "U.S.A.", "U.S.", "USA", "US",
                      "Americans", "American", "America"],
    "India": ["Indians", "Indian", "India"],
    "China": ["China", "Chinese"],
    "Iran": ["Iranians", "Iranian", "Iran", "Persians", "Persian", "Persia"],
    "Kenya": ["Kenyans", "Kenyan", "Kenya"],
}

# Replace whole-word occurrences of the subject's aliases with "<mask>" and record the masked
# alias in s["blank"]. Assertions that end up with more than one "<mask>" are discarded later
# by get_top_samples, which keeps only single-mask assertions.
for s in data:
    aliases = country_mapping.get(s["subject"])
    if aliases is None:
        continue
    for alias in aliases:
        # Lookarounds instead of \b so aliases ending in "." (e.g. "U.S.") still match, while short
        # aliases no longer match inside longer words ("India" in "Indiana", "US" in "USB").
        pattern = r"(?<!\w)" + re.escape(alias) + r"(?!\w)"
        masked, n_subs = re.subn(pattern, "<mask>", s["assertion"])
        if n_subs:
            s["assertion"] = masked
            s["blank"] = alias

In [None]:
import random
samples = {}  # country -> selected single-mask records, filled in by the selection loop
random.seed(42)

def get_top_samples(data, country, number=100):
    """Return up to ``number`` records for ``country``, highest ``combined_score`` first.

    Only records whose ``assertion`` contains exactly one ``"<mask>"`` are eligible.
    Ties keep their original relative order. A warning is printed when fewer than
    ``number`` eligible records exist.
    """
    def _eligible(record):
        return record["subject"] == country and record["assertion"].count("<mask>") == 1

    ranked = list(filter(_eligible, data))
    ranked.sort(key=lambda record: record['combined_score'], reverse=True)

    if number > len(ranked):
        print(f"Warning: {country} has only {len(ranked)} samples")
    return ranked[:number]

# 140 candidates per country; the final per-language sets are cut to 100 further down,
# presumably to leave headroom for samples lost in the translation filter.
N_CANDIDATES = 140
for country in country_mapping:
    samples[country] = get_top_samples(data, country, number=N_CANDIDATES)



In [None]:
# English prompts per country: {country: [{"prompt": masked assertion, "answer": masked alias}, ...]}
inputs_country = {}
for country in country_mapping:
    inputs_country[country] = [
        {"prompt": sample["assertion"], "answer": sample["blank"]}
        for sample in samples[country]
    ]

In [None]:
# Persist the English prompts; the translation section reloads them from this file.
with open("../data/candle/inputs_country.json", mode='w', encoding='utf-8') as out_file:
    json.dump(inputs_country, out_file, indent=4, ensure_ascii=False)


## Translate

In [None]:
import requests, uuid, json
from tqdm import tqdm 
from api_keys import msft_key 

In [None]:
# Microsoft Translator credentials and endpoint (key comes from the local api_keys module).
key = msft_key
endpoint = "https://api.cognitive.microsofttranslator.com"

# Azure region of the resource ("location"). Required for multi-service or regional
# (non-global) resources; shown in the Azure portal on the Keys and Endpoint page.
location = "eastus"

path = '/translate'
constructed_url = f"{endpoint}{path}"

def set_header(country, from_lang='en', subscription_key=None, region=None):
    """Build query params and HTTP headers for a Microsoft Translator v3 ``/translate`` call.

    The target language is looked up from ``country`` in a fixed country -> language-code map.

    Args:
        country: One of "United States", "India", "China", "Iran", "Kenya".
        from_lang: Language code of the source text (default 'en').
        subscription_key: Translator subscription key; defaults to the module-level ``key``.
        region: Azure resource region; defaults to the module-level ``location``.

    Returns:
        (params, headers): ``params`` holds api-version / from / to; ``headers`` holds the
        subscription key, region, JSON content type and a freshly generated X-ClientTraceId.

    Raises:
        KeyError: if ``country`` has no configured target language.
    """
    country_lang_mapping = {
        "United States": "en",
        "India": 'hi',
        "China": 'zh-Hans',
        "Iran": 'fa',
        "Kenya": 'sw',
    }
    if country not in country_lang_mapping:
        raise KeyError(f"No target language configured for country {country!r}")
    to_lang = country_lang_mapping[country]

    # Fall back to the notebook-level credentials only when no override is given.
    if subscription_key is None:
        subscription_key = key
    if region is None:
        region = location

    params = {
        'api-version': '3.0',
        'from': from_lang,
        'to': to_lang,
    }

    headers = {
        'Ocp-Apim-Subscription-Key': subscription_key,
        # Region is required for multi-service or regional (non-global) resources.
        'Ocp-Apim-Subscription-Region': region,
        'Content-type': 'application/json',
        # NOTE(review): one trace ID per call; callers that reuse these headers for many
        # requests send the same ID each time.
        'X-ClientTraceId': str(uuid.uuid4()),
    }
    return params, headers


In [None]:
# Reload the English prompts written earlier in the notebook.
with open("../data/candle/inputs_country.json", encoding='utf-8') as in_file:
    inputs_country = json.load(in_file)

# Country name as it is expected to appear in the translated text, keyed by the same country
# names used in country_mapping / inputs_country. The "United States" entry is currently not
# looked up, because English prompts are not translated.
country_translation = {
    "United States": "America",
    "China": "中国",
    "Iran": "ایران",
    "India": "भारत",
    "Kenya": "Kenya",
}


In [None]:

# Build parallel English / translated prompt sets per country.
# For each non-English target language, every English sample is un-masked, translated and
# re-masked on the translated country name. A sample is kept in BOTH sets (so the two lists
# stay index-aligned) only when the translated name occurs exactly once, mirroring the
# single-"<mask>" requirement applied to the English prompts.
TRANSLATE_TIMEOUT_S = 30  # seconds; keeps a stalled request from hanging the notebook

inputs_lang_country = {'en': {}}

for country in inputs_country:
    print(country)
    params, headers = set_header(country)
    translated_lang = params['to']

    inputs_lang_country.setdefault(translated_lang, {})[country] = []
    inputs_lang_country['en'][country] = []

    for sample in tqdm(inputs_country[country]):
        if translated_lang == 'en':
            inputs_lang_country['en'][country].append(sample.copy())
            continue

        text = sample['prompt'].replace("<mask>", sample["answer"])
        response = requests.post(constructed_url, params=params, headers=headers,
                                 json=[{'text': text}], timeout=TRANSLATE_TIMEOUT_S)
        # Fail loudly on HTTP errors instead of a confusing KeyError on the error payload.
        response.raise_for_status()
        translated = response.json()[0]['translations'][0]['text']

        answer = country_translation[country]
        # NOTE(review): plain substring matching also fires inside longer words built on the
        # country name; spot-check the masked outputs per language.
        if translated.count(answer) == 1:
            inputs_lang_country[translated_lang][country].append(
                {"prompt": translated.replace(answer, "<mask>"), "answer": answer})
            inputs_lang_country['en'][country].append(sample.copy())

    print(len(inputs_lang_country['en'][country]))


In [None]:
# Save the untruncated per-language prompt sets (translations are expensive to regenerate).
with open("../data/candle/samples_lang_country.json", mode='w', encoding='utf-8') as out_file:
    json.dump(inputs_lang_country, out_file, indent=4, ensure_ascii=False)


In [None]:
# Per-language task instruction, placed before each masked sentence.
prompt_task = {
    'zh-Hans': "用国名完成句子：",
    'en': 'Complete the sentence with a country name:',
    'hi': "किसी देश के नाम के साथ वाक्य पूरा करें:",
    'fa': "جمله را با نام کشور تکمیل کنید:",
    'sw': "Kamilisha sentensi kwa jina la nchi:",
}

# Per-language answer-format hint, placed after each masked sentence. The embedded template
# is what the model is asked to imitate, so it uses ASCII JSON syntax in every language.
prompt_format = {
    # ASCII quotes and colon (previously full-width “”： ), matching the other templates.
    'zh-Hans': "以 {\"答案\": } 的 json 格式回答。",
    'en': "Answer in the json format of {\"Answer\": }.",
    'hi': "{\"उत्तर\": } के json स्वरूप में उत्तर दें।",
    'fa': "پاسخ در قالب json {\"پاسخ\": }.",
    'sw': "Jibu katika umbizo la json la {\"Answer\": }.",
}

In [None]:
# Final model inputs: at most 100 samples per (language, country), each prompt wrapped as
# "<task instruction>\n<masked sentence>\n<answer-format hint>" in that language.
# New dicts are built, so inputs_lang_country itself is left untouched.
inputs_lang_country_truncated = {
    lang: {
        country: [
            {**sample,
             "prompt": prompt_task[lang] + '\n' + sample["prompt"] + '\n' + prompt_format[lang]}
            for sample in lang_samples[:100]
        ]
        for country, lang_samples in by_country.items()
    }
    for lang, by_country in inputs_lang_country.items()
}


In [None]:
# Save the final, truncated and prompt-wrapped inputs.
with open("../data/candle/inputs_lang_country.json", mode='w', encoding='utf-8') as out_file:
    json.dump(inputs_lang_country_truncated, out_file, indent=4, ensure_ascii=False)
