In [1]:
import os, sys
import json
import jsonlines
import pandas as pd
from glob import glob
from tqdm import tqdm

import cohere
import requests
import base64
import pickle
import torch

import arabic_reshaper
from bidi.algorithm import get_display

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/samuelcahyawijaya_cohere_com/.config/sagemaker/config.yaml


In [2]:
# Cohere V2 client that talks to the *staging* API endpoint.
# The API key is taken from the CO_API_KEY_STAGING environment variable (never hard-coded).
staging_api_key = os.getenv("CO_API_KEY_STAGING")
co = cohere.ClientV2(api_key=staging_api_key, base_url="https://stg.api.cohere.com")

In [3]:
# Read the raw RepairBench Arabic queries.
# PS4 is skipped: it has only a single document chunk, which points to an external URL.
ar_raw_df = pd.read_csv('repairbench_ar.csv')
ar_raw_df = ar_raw_df.loc[ar_raw_df['product_name'] != 'PS4']

# Read the context-document chunks: each file holds a JSON list of chunk dicts.
DOCS_DIR = '/home/samuelcahyawijaya_cohere_com/repos_v3/retrieval-augmentation/repairbench/repairbench_ar_docs'

# NOTE(review): glob() order is filesystem-dependent, and the document-embedding cache
# is aligned to docs_data by position. Consider sorting once the cache is rebuilt.
docs_data = []
for path in glob(os.path.join(DOCS_DIR, '*')):
    if 'PS4' in path:
        continue
    # Parse each file once and close the handle (the previous code opened it twice).
    with open(path, 'r') as f:
        chunks = json.load(f)
    docs_data += chunks
    print(path, len(chunks))

/home/samuelcahyawijaya_cohere_com/repos_v3/retrieval-augmentation/repairbench/repairbench_ar_docs/Bose_quietcomfort-45_AR.pdf.json 56
/home/samuelcahyawijaya_cohere_com/repos_v3/retrieval-augmentation/repairbench/repairbench_ar_docs/Bose_og_tv-speaker_AR.pdf.json 48
/home/samuelcahyawijaya_cohere_com/repos_v3/retrieval-augmentation/repairbench/repairbench_ar_docs/FlyingTiger_advent_calendar_AR.pdf.json 5
/home/samuelcahyawijaya_cohere_com/repos_v3/retrieval-augmentation/repairbench/repairbench_ar_docs/Bose_frames-soprano-tenor_AR.pdf.json 53


In [4]:
# Build the document index: one embedding row per chunk in docs_data, in docs_data order.
# The result is cached on disk. The cache is only valid while docs_data holds the same
# chunks in the same order (checked by count below).
if os.path.exists('arabic_docs_embed.pkl'):
    print('loading docs embed from cache...')
    with open('arabic_docs_embed.pkl', 'rb') as f:
        # NOTE: unpickling can execute code -- only load caches this notebook produced.
        docs_embeds = pickle.load(f)
else:
    # Apply python-bidi's get_display to every line of each chunk before embedding.
    # The same per-line transform is applied to the text shown to annotators.
    texts = ['\n'.join(get_display(line) for line in doc['text'].split('\n')) for doc in docs_data]
    bs = 32

    docs_embeds = []
    # Step through by batch size so no empty batch is sent when len(texts) % bs == 0.
    for start in tqdm(range(0, len(texts), bs)):
        response = co.embed(
            model="embed-multilingual-v3.0",
            input_type="search_document",
            embedding_types=["float"],
            texts=texts[start:start + bs],
        )
        docs_embeds += response.embeddings.float_
    docs_embeds = torch.Tensor(docs_embeds)

    # Write the cache only after computing fresh embeddings, never after loading them.
    with open('arabic_docs_embed.pkl', 'wb') as f:
        pickle.dump(docs_embeds, f)

assert docs_embeds.shape[0] == len(docs_data), \
    'document embedding cache does not match docs_data; delete arabic_docs_embed.pkl'

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:12<00:00,  2.04s/it]


In [28]:
# Embed all queries, caching the result on disk.
# Row j of query_embeds belongs to the j-th row of ar_raw_df by *position*, not by index label.
QUERY_EMBED_BATCH = 96  # Cohere Embed API limit on texts per request

if os.path.exists('arabic_query_embed.pkl'):
    print('loading query embed from cache...')
    with open('arabic_query_embed.pkl', 'rb') as f:
        # NOTE: unpickling can execute code -- only load caches this notebook produced.
        query_embeds = pickle.load(f)
else:
    queries = ar_raw_df['query'].tolist()
    query_embeds = []
    for start in range(0, len(queries), QUERY_EMBED_BATCH):
        query_embeds += co.embed(
            model="embed-multilingual-v3.0",
            input_type="search_query",
            embedding_types=["float"],
            texts=queries[start:start + QUERY_EMBED_BATCH],
        ).embeddings.float_
    query_embeds = torch.Tensor(query_embeds)
    with open('arabic_query_embed.pkl', 'wb') as f:
        pickle.dump(query_embeds, f)

# A stale cache (e.g. after editing the CSV) would silently pair queries with the wrong embeddings.
assert query_embeds.shape[0] == len(ar_raw_df), \
    'query embedding cache does not match ar_raw_df; delete arabic_query_embed.pkl'

# Source-PDF filename (as stored in a chunk's metadata['filename'])
# -> human-readable product name, used as the title of each search result.
filename_to_product_map = dict([
    ('Bose_frames-soprano-tenor_AR.pdf', 'Bose Frames'),
    ('Bose_og_tv-speaker_AR.pdf', 'Bose TV Speaker'),
    ('Bose_quietcomfort-45_AR.pdf', 'Bose QuietComfort 45 Headphones'),
    ('FlyingTiger_advent_calendar_AR.pdf', 'Flying Tiger Advent Calendar'),
])

# For each query: retrieve the top-k chunks by dot-product score against docs_embeds,
# then build one annotation-dataset record.
TOP_K = 15  # number of retrieved chunks attached to each query


def _to_display(text):
    """Return `text` with python-bidi's `get_display` applied line by line.

    This matches the per-line transform applied to the chunks before they were embedded.
    """
    return '\n'.join(get_display(line) for line in text.split('\n'))


formatted_data = []
for pos, (i, row) in enumerate(tqdm(ar_raw_df.iterrows(), total=len(ar_raw_df))):
    # `i` is the DataFrame index label, which can have gaps after the PS4 filter.
    # query_embeds rows follow ar_raw_df's positional order, so look them up by `pos`.
    query_embed = query_embeds[pos, :]
    scores = query_embed @ docs_embeds.T
    k = min(TOP_K, scores.shape[0])  # topk() raises if k exceeds the number of chunks
    topk_indices = scores.topk(k).indices.tolist()

    search_results = []
    for rank, idx in enumerate(topk_indices):
        doc = docs_data[idx]
        filename = doc['metadata']['filename']
        display_text = _to_display(doc['text'])  # computed once, reused for 3 fields
        search_results.append({
            'unique_id': f"{filename}_{doc['metadata']['page_number']}",
            'rank': rank,
            'snippet': display_text,
            'source': filename,
            'text': display_text,
            'url': '',
            'title': filename_to_product_map[filename],
            'is_relevant': None,
            'html_view': '<!DOCTYPE html>\n<html>\n<body><div style="text-align: right;">\n'
                         + display_text
                         + '</div></body>\n</html>',
            'chunk_idx': idx,
        })

    formatted_data.append({
        'unique_id': f"Q{i}",  # keeps the original index-label-based ids
        'turn_no': 0,
        'source_dataset': 'repair_bench_2024-09-30_arabic',
        'history': [],
        'question': row['query'],
        'search_query': [{'type': 'direct-injected-document'}],
        'long_answer': None,
        'rationale': None,
        'waypoints': [0],
        'init_plan_rationale': None,
        'rephrased_answer': None,
        'search_results': [search_results],
        'gold_search_results': [[]],
        'closed_book_answer': None,
        'short_answer': None,
        'database_name': None,
        'stop_tool_training': [True],
        'follow_up_questions': None,
        'follow_up_answers': None,
        'timestamp_override': None,
        'preamble_override': None,
        'structured_preamble': None,
        'original_document': None,
        'chunk_strategy': None,
        'chunk_size': None,
        'is_sensitive': False,
        'sensitive_category': None,
        'metadata': {
            'label': row['label'],
            'question_type': row['question_type'],
            'product_name': row['product_name'],
            'failure_mode': row['failure_mode'],
            'category': row['category'],
        },
        'no_search_required': None,
        'custom_tool_definitions': None,
    })

# Persist the records as JSON Lines: one JSON object per line, in formatted_data order.
with jsonlines.open('arabic-annotation-dataset.jsonl', mode='w') as jsonl_writer:
    for record in formatted_data:
        jsonl_writer.write(record)

loading query embed from cache...


50it [00:03, 13.14it/s]


In [27]:
pd.unique(ar_raw_df['product_name'])  # distinct product names, in first-seen order

array(['Bose Frames', 'Bose QuietComfort 45 Headphones',
       'Bose TV Speaker', 'Flying Tiger Advent Calendar'], dtype=object)

In [25]:
{chunk['metadata']['filename'] for chunk in docs_data}  # distinct source PDFs among loaded chunks

{'Bose_frames-soprano-tenor_AR.pdf',
 'Bose_og_tv-speaker_AR.pdf',
 'Bose_quietcomfort-45_AR.pdf',
 'FlyingTiger_advent_calendar_AR.pdf'}

In [30]:
[result['title'] for result in formatted_data[0]['search_results'][0]]  # titles of the first query's hits, by rank

['Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose Frames',
 'Bose TV Speaker',
 'Bose Frames',
 'Bose QuietComfort 45 Headphones']

In [33]:
# Print the ranked search-result titles of every record in the published GCS copy of the dataset.
published_df = pd.read_json('gs://cohere-data/retrieval_augmentation/repair_bench/20240930/output/arabic-annotation-dataset.jsonl', lines=True)
for i, row in published_df.iterrows():
    print([hit['title'] for hit in row['search_results'][0]])

['Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose TV Speaker', 'Bose Frames', 'Bose QuietComfort 45 Headphones']
['Bose Frames', 'Bose QuietComfort 45 Headphones', 'Bose Frames', 'Bose QuietComfort 45 Headphones', 'Bose Frames', 'Bose Frames', 'Bose TV Speaker', 'Bose QuietComfort 45 Headphones', 'Bose Frames', 'Bose Frames', 'Bose TV Speaker', 'Bose TV Speaker', 'Bose Frames', 'Bose Frames', 'Bose Frames']
['Bose Frames', 'Bose Frames', 'Bose TV Speaker', 'Bose QuietComfort 45 Headphones', 'Bose Frames', 'Bose QuietComfort 45 Headphones', 'Bose QuietComfort 45 Headphones', 'Bose QuietComfort 45 Headphones', 'Bose TV Speaker', 'Bose Frames', 'Bose Frames', 'Bose TV Speaker', 'Bose TV Speaker', 'Bose TV Speaker', 'Bose TV Speaker']
['Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose Frames', 'Bose QuietComfort 45 Headphones', 'Bose TV Speaker', 'Bose Q