In [400]:
import os
import subprocess

# Ask git for the top-level directory of the checkout and make it the working
# directory, so the relative paths used in this notebook resolve from there.
_git_toplevel = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
os.chdir(_git_toplevel.strip().decode("utf-8"))
from tqdm import tqdm
from pathlib import Path
import pandas as pd

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

In [13]:
# Folder the notebook reads the files to correct from (relative to the working directory).
SOURCE_FOLDER = Path("climateGPT/data/sft_dataset/open_mixtral_8x7b")
# Sub-folders for the corrected files and for scratch files.
CLEAN_FOLDER = SOURCE_FOLDER / "clean"
TEMP_FOLDER = SOURCE_FOLDER / "temp"

# Create both sub-folders, including missing parents; existing folders are left alone.
for _folder in (CLEAN_FOLDER, TEMP_FOLDER):
    _folder.mkdir(parents=True, exist_ok=True)

In [87]:
def load_clean_and_raw_files():
    """List the JSON files in SOURCE_FOLDER that still need correcting.

    A file is considered already handled when a file with the same name
    exists in CLEAN_FOLDER. Every other ``*.json`` file located directly in
    SOURCE_FOLDER is returned. The number of pending files is printed as a
    side effect.

    Returns:
        list[str]: Bare file names (no directory part) of the pending files.
    """
    # A set keeps the membership test below constant-time per file.
    clean_names = {path.name for path in CLEAN_FOLDER.glob("*.json")}
    all_files = [
        path.name
        for path in SOURCE_FOLDER.glob("*.json")
        if path.name not in clean_names
    ]
    print(f"Number of files to correct: {len(all_files)}")
    return all_files

In [80]:
## File corrections

In [311]:
def identity(x: str) -> str:
    """Return ``x`` unchanged.

    Listed first in the correction pipeline of ``clean_and_save_ndjson`` so
    that the untouched text is validated before any fix is applied.
    """
    return x
def delete_unwanted_newlines_and_comma(x: str) -> str:
    """Remove stray newlines, commas and backslashes found around braces.

    The substitutions are applied in a fixed order, each on the result of
    the previous one:

    1. a newline right after an opening brace is dropped;
    2. a newline right before a closing brace is dropped;
    3. a newline right before the quoted ``answer`` key becomes a space;
    4. a comma right after a closing brace is dropped;
    5. a backslash right after a closing brace is dropped.

    Args:
        x: Text to clean.

    Returns:
        The text after all five substitutions.
    """
    substitutions = (
        ('{\n', '{'),
        ('\n}', '}'),
        ('\n"answer"', ' "answer"'),
        ("},", "}"),
        ("}\\", "}"),
    )
    cleaned = x
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return cleaned
def remove_duplicate_newline(x: str) -> str:
    """Collapse every run of consecutive newlines into a single newline.

    A single ``str.replace`` pass only halves each run, so three or more
    newlines in a row would still leave a duplicate behind. The replacement
    is therefore repeated until no doubled newline remains.

    Args:
        x: Text to clean.

    Returns:
        The text with no two newlines next to each other.
    """
    while "\n\n" in x:
        x = x.replace("\n\n", "\n")
    return x
def remove_newline_comma_newline(x: str) -> str:
    """Replace each comma that sits alone on its own line by a single newline.

    Args:
        x: Text to clean.

    Returns:
        The text where every newline-comma-newline sequence has been
        reduced to one newline.
    """
    lone_comma_line = "\n,\n"
    pieces = x.split(lone_comma_line)
    return "\n".join(pieces)
def add_last_brackets(x: str) -> str:
    """Append a closing brace when the text ends with a double quote.

    Args:
        x: Text to fix; may be empty.

    Returns:
        ``x`` followed by a closing brace if its last character is a double
        quote, otherwise ``x`` unchanged.
    """
    # endswith() is safe on an empty string, whereas x[-1] raises IndexError.
    if x.endswith('"'):
        return x + "}"
    return x
def add_comma_before_answer_key(x: str) -> str:
    """Insert the comma missing in front of the quoted ``answer`` key.

    Two cases are handled, in this order: the key preceded by a newline
    (the newline becomes a comma and a space), then the key preceded by a
    closing quote and a space (a comma is inserted after the quote).

    Args:
        x: Text to fix.

    Returns:
        The text after both substitutions.
    """
    newline_case_fixed = x.replace('\n"answer"', ', "answer"')
    space_case_fixed = newline_case_fixed.replace('" "answer"', '", "answer"')
    return space_case_fixed
def remove_open_end_bracket(x: str) -> str:
    """Drop a dangling opening brace left at the very end of the text.

    The last two characters are removed when they are a newline followed by
    an opening brace, or an opening brace followed by a newline.

    Args:
        x: Text to fix.

    Returns:
        The text without the dangling ending, or ``x`` unchanged when it
        does not end that way.
    """
    dangling_endings = ("\n{", "{\n")
    return x[:-2] if x.endswith(dangling_endings) else x
def add_missing_newline(x: str) -> str:
    """Put a newline between a closing brace and an opening brace that touch.

    Args:
        x: Text to fix.

    Returns:
        The text with a newline inserted inside every closing-brace /
        opening-brace pair.
    """
    glued_objects = x.split("}{")
    return "}\n{".join(glued_objects)
def correct_end_brackets(x: str) -> str:
    """Repair a missing or doubled closing brace.

    When the text ends with a double quote, a closing brace is appended.
    Otherwise every quote-brace-quote-brace sequence (anywhere in the text,
    not only at the end) is reduced to a single quote-brace.

    Args:
        x: Text to fix; may be empty.

    Returns:
        The repaired text.
    """
    # endswith() is safe on an empty string, whereas x[-1] raises IndexError.
    if x.endswith('"'):
        return x + '}'
    return x.replace('"}"}', '"}')
def remove_unwanted_characters(x: str) -> str:
    """Delete every occurrence of the ``</s>`` marker from the text.

    Args:
        x: Text to clean.

    Returns:
        The text without any ``</s>`` marker.
    """
    unwanted_marker = "</s>"
    return "".join(x.split(unwanted_marker))

In [312]:
## Testing format
def is_ndjson_format_correct(x: str) -> bool:
    """Tell whether ``x`` can be parsed by pandas as newline-delimited JSON.

    The text is parsed in memory through a ``StringIO`` buffer, so no
    scratch file is written and the platform's default file encoding plays
    no part in the check.

    Args:
        x: Candidate NDJSON content.

    Returns:
        True when ``pd.read_json(..., lines=True)`` parses the text without
        raising, False otherwise.
    """
    from io import StringIO  # local import keeps this definition self-contained

    try:
        pd.read_json(StringIO(x), lines=True)
    except Exception:
        # Any parsing failure means "not valid". Unlike a bare ``except``,
        # this lets KeyboardInterrupt and SystemExit propagate.
        return False
    return True

In [402]:
def clean_and_save_ndjson(min_length: int = 200):
    """Apply rule-based fixes to each pending file and save the ones that become valid.

    For every file returned by ``load_clean_and_raw_files``:

    * if its content is shorter than ``min_length`` characters, the file is
      DELETED from SOURCE_FOLDER and skipped;
    * otherwise the correction functions are applied one after another,
      cumulatively, and the content is validated after each step. As soon
      as it parses as NDJSON it is written under the same name in
      CLEAN_FOLDER and the remaining corrections are skipped.

    Files that never validate are left untouched in SOURCE_FOLDER. A summary
    line is printed at the end; deleted files still count in its total.

    Args:
        min_length: Minimum number of characters a file must contain to be
            kept. Defaults to 200, the threshold previously hard-coded.
    """
    all_files = load_clean_and_raw_files()
    # Order matters: each fix runs on the output of the previous ones.
    corrections = [
        identity,
        remove_open_end_bracket,
        delete_unwanted_newlines_and_comma,
        remove_duplicate_newline,
        remove_newline_comma_newline,
        add_last_brackets,
        add_comma_before_answer_key,
        add_missing_newline,
        correct_end_brackets,
        remove_unwanted_characters,
    ]
    count_correct = 0
    for file in all_files:
        source_path = SOURCE_FOLDER / file
        # Read and close the file first: deleting a file that is still open
        # fails on some platforms.
        with open(source_path, "r") as source:
            content = source.read()

        if len(content) < min_length:
            os.remove(source_path)
            continue

        for func in corrections:
            content = func(content)
            if is_ndjson_format_correct(content):
                with open(CLEAN_FOLDER / file, "w") as target:
                    target.write(content)
                count_correct += 1
                break

    # Avoid a ZeroDivisionError when there was nothing left to correct.
    ratio = count_correct / len(all_files) if all_files else 0.0
    print(f"Number of corrected files: {count_correct} representing {ratio: .2%} of the total")

In [None]:
# Run the rule-based corrections on every file that has no cleaned counterpart yet.
clean_and_save_ndjson()

In [387]:
## For remaining files, call Mistral large to correct them

In [396]:
# Template filled with str.format: doubled braces render as literal braces and
# {file} is replaced by the content of the file to correct. The example keys
# use straight double quotes so the format shown to the model is valid JSON.
prompt = """Can you correctly reformat the following file in NDJSON format?
The expected format needs to be (for a two lines example):
{{"instruction": "blablabla", "answer": "blablabla"}}\n{{"instruction": "blablabla", "answer": "blablabla"}}.
If some answers are incomplete, just remove the line.
Only output the NDJSON file.

The file:
{file}
"""

In [397]:
# Names of the files in SOURCE_FOLDER that still have no counterpart in CLEAN_FOLDER.
all_files = load_clean_and_raw_files()

Number of files to correct: 26


In [401]:
# The API key is read from the environment; a missing variable raises KeyError.
api_key = os.environ["MISTRAL_API_KEY"]
client = MistralClient(api_key=api_key)
model = "mistral-large-latest"

# Send each remaining file to the model and store its answer.
for file in tqdm(all_files):
    with open(SOURCE_FOLDER / file, "r") as f:
        tmp = f.read()
        
    # One user message: the prompt template with the file content filled in.
    messages = [
        ChatMessage(role="user", content=prompt.format(**{"file": tmp}))
    ]
    chat_response = client.chat(
        model=model,
        messages=messages,
    )
    
    # NOTE(review): the answer is written back over the original file in
    # SOURCE_FOLDER, so the uncorrected text is not kept — confirm this is intended.
    corrected_file = chat_response.choices[0].message.content
    with open(SOURCE_FOLDER / file, "w") as f:
        f.write(corrected_file)

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
100%|██████████| 26/26 [04:04<00:00,  9.39s/it]


In [405]:
# Re-run the cleaning so that the rewritten files which now parse as NDJSON
# are written to CLEAN_FOLDER.
clean_and_save_ndjson()

Number of files to correct: 1
Number of corrected files: 1 representing  100.00% of the total
