## A notebook to generate synthetic grants for underrepresented RA categories

In [None]:
import pandas as pd
import json
from tqdm import tqdm

from langchain_community.llms import Ollama

In [None]:

# Ollama client for the "mixtral" model; llama_prompt() below sends every prompt through it.
# format='json' is what lets the generation loop feed the response straight to json.loads().
# NOTE(review): num_gpu=1 and keep_alive='1h' are passed through to Ollama as-is; presumably they
# control GPU usage and how long the model stays loaded — confirm against the Ollama docs.
llm = Ollama(model="mixtral", num_gpu=1, keep_alive='1h', format='json')
# quick manual smoke test of the connection:
# llm.invoke("Why is the sky blue?")

# default number of synthetic grants requested per prompt
n = 20


def llama_prompt(text, num_grants=None):
    """Ask the LLM to write synthetic grants for one selection-criteria text.

    Args:
        text: The grant selection criteria (category description) that is
            embedded in the <content> section of the prompt.
        num_grants: How many grants to request. Defaults to the module-level
            ``n`` so existing ``llama_prompt(text)`` calls behave as before.

    Returns:
        Whatever ``llm.invoke`` returns for the prompt. The prompt asks for
        entries with 'title' and 'abstract' keys, but the model does not
        always comply, so callers must validate the structure themselves.
    """
    if num_grants is None:
        # resolved at call time so re-assigning ``n`` in a later cell is honoured
        num_grants = n

    prompt = f"""
        <purpose>              
        please create {num_grants} separate grants from the following selection criteria. 
        </purpose>

        <instructions>
            <instruction> For each grant, I want a title and abstract, as a dictionary ('title' and 'abstract' as keys) which I can parse using json.loads() in Python. </instruction>
            <instruction> The output should look like a jsonl of length {num_grants} with each line being a dictionary with 'title' and 'abstract' keys </instruction>
            <instruction> Make sure to return EXACTLY {num_grants} grants. Do not include any explanations or apologies in your responses. </instruction>
            <instruction> Please make sure that the abstracts are coherent and relevant to the title. </instruction>
            <instruction> The grants should be written in a formal tone. </instruction>
            <instruction> the grant selection criteria can be found below. </instruction>
        </instructions>
    
        <content>
        {text}
        </content>


    """
    return llm.invoke(prompt)

In [None]:
# One JSON object per line; the generation loop reads the first key of each
# object as the category name and its value as the category description.
jsonl_path = "../data/label_names/ra_description.jsonl"

# Explicit UTF-8 keeps decoding independent of the platform default encoding.
# Blank lines (e.g. a trailing newline at the end of the file) are skipped
# instead of making json.loads() raise.
with open(jsonl_path, "r", encoding="utf-8") as f:
    cat_dict = [json.loads(line) for line in f if line.strip()]


In [None]:
# Preprocessed RA training set. Later cells reuse its column layout for the synthetic data:
# every column except the last is treated as a category label column, and the last column
# receives the grant text.
train = pd.read_parquet('../data/preprocessed/ra/train.parquet')

In [None]:
# Write a header-only CSV that mirrors the training columns; the generation
# loop reads this file back and appends the synthetic grants to it.
synthetic_columns = train.columns
pd_synthetic = pd.DataFrame(columns=synthetic_columns)
pd_synthetic.to_csv(
    "../data/synthetic/ra/train.csv",
    index=False,
)

In [None]:
# Set a limit on trial runs: sometimes the LLM does not return the structured
# output we were hoping for, so every category that still has no synthetic
# grants after one pass is attempted again on the next pass.
trial_runs = 3

# CSV checkpoint that holds the synthetic grants generated so far
synthetic_csv_path = "../data/synthetic/ra/train.csv"

# single-pass replacement of newlines / carriage returns / tabs by spaces
_WHITESPACE_TO_SPACE = str.maketrans({'\n': ' ', '\r': ' ', '\t': ' '})


def _extract_grants(parsed):
    """Normalise parsed LLM output into a flat list of valid grant dicts.

    Accepted shapes: a list of grants, a dict whose values are grants, a dict
    or list wrapping one list of grants (e.g. ``{"grants": [...]}``), or a
    single grant dict. Entries without string 'title' and 'abstract' values
    are dropped. Any other shape yields an empty list.
    """
    if isinstance(parsed, dict):
        if 'title' in parsed and 'abstract' in parsed:
            # the model returned one bare grant instead of a collection
            candidates = [parsed]
        else:
            candidates = list(parsed.values())
    elif isinstance(parsed, list):
        candidates = parsed
    else:
        return []

    # unwrap one level of nesting: {"grants": [...]} or [[...]]
    if candidates and isinstance(candidates[0], list):
        candidates = candidates[0]

    return [
        grant for grant in candidates
        if isinstance(grant, dict)
        and isinstance(grant.get('title'), str)
        and isinstance(grant.get('abstract'), str)
    ]


def _grant_text(grant):
    """Return the lower-cased, single-line 'title abstract' text of a grant."""
    combined = grant['title'] + ' ' + grant['abstract']
    return combined.translate(_WHITESPACE_TO_SPACE).lower()


for _ in range(trial_runs):
    for ra_category in tqdm(cat_dict):
        # each entry is a single-key mapping: {category name: description}
        category, cat_description = next(iter(ra_category.items()))

        # re-read the checkpoint so an interrupted run resumes where it stopped
        pd_synthetic = pd.read_csv(synthetic_csv_path)

        # Only label columns (everything but the trailing text column) can be
        # flagged; an unknown key would otherwise append a stray column after
        # the text column and break the layout the checkpoint relies on.
        if category not in pd_synthetic.columns[:-1]:
            tqdm.write(f"skipping '{category}': not a label column of the synthetic dataset")
            continue

        # skip categories that are done: a category is done once at least one
        # synthetic grant carries its label
        if pd_synthetic[category].sum() > 0:
            continue

        result = llama_prompt(cat_description)
        try:
            result_json = json.loads(result)
        except json.JSONDecodeError as err:
            # leave the category undone so a later trial run can pick it up
            tqdm.write(f"'{category}': LLM output is not valid JSON ({err}); skipping for this pass")
            continue

        grants = _extract_grants(result_json)
        if not grants:
            tqdm.write(f"'{category}': no usable grants in LLM output; skipping for this pass")
            continue

        # one row per grant: all labels 0, grant text in the last column
        num_label_columns = len(train.columns) - 1
        new_rows = pd.DataFrame(
            [[0] * num_label_columns + [_grant_text(grant)] for grant in grants],
            columns=train.columns,
        )
        # insert 1 at the column corresponding to the category
        new_rows[category] = 1

        # Append and save once per category, so the checkpoint never contains
        # a partially written category that would be mistaken for a done one.
        pd_synthetic = pd.concat([pd_synthetic, new_rows], ignore_index=True)
        pd_synthetic.to_csv(synthetic_csv_path, index=False)


 33%|███▎      | 4/12 [19:35<43:14, 324.30s/it]

In [None]:
# Combine the original training set with the synthetic grants saved by the
# generation loop.
# NOTE(review): no visible cell defines `train_enhanced`, so running the
# notebook top to bottom raised a NameError here; train + synthetic is the
# presumed intent — confirm.
train_enhanced = pd.concat(
    [train, pd.read_csv("../data/synthetic/ra/train.csv")],
    ignore_index=True,
)

# scramble rows so the synthetic grants are not clustered at the end
train_enhanced = train_enhanced.sample(frac=1).reset_index(drop=True)

# save
train_enhanced.to_parquet('../data/preprocessed/ra/train_enhanced.parquet')