# cosmosage

See README for more information.

In [None]:
import pickle
import os
import scrape_arxiv
import analyze_asl_dict
import glob
import random
import multiprocessing
import json
import time
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
import glob
import random

## Step 1: Choose arXiv papers

In [None]:
cache_file = "datasets/arxiv_ids_cache.pkl"

if not os.path.exists(cache_file):
    # No cache yet: gather ids from every source, merge them and persist the result.

    # unique arXiv numbers from the asl database
    db_path = "datasets/dict_20231123.db"
    arxiv_id_asl_tagged = analyze_asl_dict.extract_unique_arxiv_numbers(db_path)

    # papers returned by the author search for de_Haan_T
    search_params = {"search_query": "au:de_Haan_T", "searchtype": "author"}
    arxiv_id_tdh = scrape_arxiv.get_arxiv_ids(search_params)

    # papers with "cosmic microwave background" in the abstract
    search_params = {"search_query": "abs:\"cosmic microwave background\""}
    arxiv_id_cmb = scrape_arxiv.get_arxiv_ids(search_params)

    # additional recommended papers
    arxiv_id_asl_rec = scrape_arxiv.other_arxiv_recommendation_ids()

    # concatenate the four sources and drop duplicates in one step
    arxiv_ids = list(
        set(arxiv_id_asl_tagged + arxiv_id_tdh + arxiv_id_cmb + arxiv_id_asl_rec)
    )

    # write the merged list so later runs can skip the queries above
    with open(cache_file, "wb") as f:
        pickle.dump(arxiv_ids, f)
else:
    # Cache hit: reuse the previously stored list of ids.
    with open(cache_file, "rb") as f:
        arxiv_ids = pickle.load(f)


## Step 2: Synthetic data generation on arXiv papers

Here, we generate synthetic data using the following:
 - instruction-tuned model to generate the QA pairs & summaries
 - vLLM server to load the model once and provide good throughput
 - langchain to handle 
   - gathering of papers
   - extracting from PDFs
   - chunking data
   - summarization

### Step 2.1: Download, parse, and cache the arXiv papers

In [None]:
# Build the list of ids that still have to be downloaded: skip ids that are
# already cached and ids whose two leading characters are not a number from
# 07 through 24, since the process will fail on these anyway.
new_arxiv_ids = []
for arxiv_id in arxiv_ids:
    outfile = f"datasets/arxiv_cache/{arxiv_id}.pkl"
    if os.path.exists(outfile):
        continue
    try:
        id_year = int(arxiv_id[0:2])
    except ValueError:
        # Ids that do not begin with two digits (e.g. old-style identifiers
        # such as "astro-ph/0603449") are among those this filter is meant to
        # drop; previously int() raised here and aborted the whole cell.
        continue
    if 7 <= id_year <= 24:
        new_arxiv_ids.append(arxiv_id)
print(f"Found {len(new_arxiv_ids)} arxiv ids from 2007 to 2024.")

print("Shuffling arxiv ids...")
random.shuffle(new_arxiv_ids)

def cache_arxiv_id(arxiv_id):
    """Download one arXiv paper and store it in the local cache.

    Nothing is fetched when ``datasets/arxiv_cache/<arxiv_id>.pkl`` already
    exists.  Otherwise the function sleeps for two seconds, builds a
    ``scrape_arxiv.ArxivPaper`` for the id and calls its ``save_to_cache()``.
    Any exception raised while doing so is printed rather than propagated, so
    one failing paper does not stop the remaining downloads.

    Args:
        arxiv_id: arXiv identifier used both for scraping and for the cache
            file name.

    Returns:
        None.
    """
    cache_path = f"datasets/arxiv_cache/{arxiv_id}.pkl"
    if not os.path.exists(cache_path):
        print(f"Processing {arxiv_id}\n")
        try:
            # fixed pause before every scrape
            print(f"Waiting 2 seconds before scraping {arxiv_id}...")
            time.sleep(2)
            scrape_arxiv.ArxivPaper(arxiv_id).save_to_cache()
        except Exception as err:
            print(f"Error processing {arxiv_id}: {err}\n")
        return
    print(f"File {cache_path} already exists, skipping.")

# Single-worker pool: downloads run one at a time in a separate process
pool = multiprocessing.Pool(processes=1)

# Queue one cache_arxiv_id task per id that still needs downloading
for arxiv_id in new_arxiv_ids:
    pool.apply_async(cache_arxiv_id, (arxiv_id,))

# Stop accepting new tasks, then block until every queued task has finished
pool.close()
pool.join()

### Step 2.2 Add summaries

For the generation of summaries, you can use the following code block. However, to speed things up, you can also run the script `generate_summaries_standalone.py` which can be run in parallel on e.g. a GPU cluster.

In [None]:
# add summaries to the pkl files
# Every pickle in the arXiv cache directory is a candidate; the list is
# shuffled in place so the files are visited in random order.
filenames = glob.glob("datasets/arxiv_cache/*.pkl")
random.shuffle(filenames)

def add_summary(filename):
    """Generate an LLM summary for one cached paper and store it in its pickle.

    The pickle at ``filename`` is loaded; if it already contains a
    ``"summary"`` entry nothing else happens.  Otherwise a map-reduce
    summarization chain is run against an OpenAI-compatible server at
    ``http://0.0.0.0:8000/v1`` and, when the resulting summary is longer than
    10 characters, the paper is written back with the summary attached.

    The input file is closed before the long-running LLM call, and the result
    is written to a temporary sibling file that then replaces the original,
    so an interrupted run cannot leave a truncated pickle behind.

    Args:
        filename: path of the cached paper pickle to update in place.

    Returns:
        None.
    """
    with open(filename, "rb") as f:
        print(f"Loading {filename}...")
        paper = pickle.load(f)

    if "summary" in paper:
        print(f"Skipping {filename}, already has summary.")
        return

    # make sure to start a vLLM ChatOpenAI server first
    inference_server_url = "http://0.0.0.0:8000/v1"

    llm = ChatOpenAI(
        model="/home/tijmen/cosmosage/packages/text-generation-webui/models/TheBloke_bagel-dpo-34b-v0.2-GPTQ_gptq-4bit-32g-actorder_True",
        openai_api_key="EMPTY",
        openai_api_base=inference_server_url,
        temperature=0.4,
    )

    print("Generating summary...")
    summarize_chain = load_summarize_chain(llm, chain_type="map_reduce")
    # NOTE(review): `paper` is indexed like a mapping above but `.pages` is
    # attribute access -- confirm the cached object supports both.
    summary = summarize_chain.run(paper.pages)
    print(f"Summary generated. Length {len(summary)}")

    paper["summary"] = summary

    # Very short outputs are not saved, so the paper is retried on a later run.
    if len(summary) <= 10:
        return

    # The ".tmp" suffix keeps the partial file out of "*.pkl" globs.
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            print("Saving summary...")
            pickle.dump(paper, f)
        os.replace(tmp_path, filename)
    finally:
        # Only present if dumping or replacing failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Single-worker pool: summaries are generated one file at a time in a
# separate process.
pool = multiprocessing.Pool(1)

# Queue add_summary for every cached paper and keep the result handles.
# add_summary does not catch its own exceptions, and apply_async only reports
# a failure when the result is fetched, so discarding the handles would hide
# every error (e.g. the inference server not running).
pending = [
    (filename, pool.apply_async(add_summary, args=(filename,)))
    for filename in filenames
]

# No more tasks will be submitted
pool.close()

# Wait for each task in submission order and report failures without
# aborting the remaining files.
for filename, result in pending:
    try:
        result.get()
    except Exception as e:
        print(f"Error processing {filename}: {e}\n")

pool.join()

### Step 2.3 Generate QA pairs

For generation of QA pairs, run `generate_synth_standalone.py`. This script can be run many times in parallel to speed up the process.

## Step 3: Synthetic data generation on textbooks

In [None]:
import glob
from extract_textbooks import TextBook
# One TextBook per plain-text file in the textbook directory
textbooks = [
    TextBook(filepath)
    for filepath in glob.glob("datasets/cosmology_textbooks/*.txt")
]

# Generate the QA pairs for each book and write them out as JSONL
for textbook in textbooks:
    textbook.generate_qa_pairs(multiprocess=True)
    textbook.save_dataset_jsonl()
    print(f"Saved {textbook.author} to jsonl")

In [None]:
import random

# collate all the JSONL files and shuffle them for good measure
textbook_jsonl_files_in = glob.glob("datasets/cosmology_textbooks_qa/*/*.jsonl")
textbook_jsonl_file_out = "datasets/cosmology_textbooks_qa.jsonl"

# Read every input before touching the output, so a failing read does not
# leave a truncated output file behind.
all_lines = []
for textbook_jsonl_file in textbook_jsonl_files_in:
    with open(textbook_jsonl_file, "r") as g:
        for line in g:
            record = line.rstrip("\r\n")
            # Skip blank lines and give every record exactly one newline: a
            # final line without a terminator would otherwise be fused with
            # whichever record follows it after shuffling.
            if record.strip():
                all_lines.append(record + "\n")

random.shuffle(all_lines)

with open(textbook_jsonl_file_out, "w") as f:
    f.writelines(all_lines)

## Step 4: Join, prepare the datasets

Depending on the choices we made in Steps 2 and 3, we will now have a bunch of synthetic summaries, QA pairs, and other data from arXiv papers and textbooks. We may also have other sources such as public datasets. 

The goal of Step 4 is to get all these datasets in JSONL format to get them ready for training with `fine_tune.py`, `fine_tune_lora.py`, or `axolotl`. We will go through and prepare any datasets that aren't yet in a useable JSONL format. One recurring theme will be the use of randomized choices in order to increase training set diversity. 

### Step 4.1 Make training data from the summaries

In [None]:
# convert summaries to QA format
# Each cached paper that has a summary becomes one ShareGPT conversation with
# a randomly chosen system prompt, user request and assistant preamble.
filenames = glob.glob("datasets/arxiv_cache/*.pkl")
random.shuffle(filenames)
outfile = "datasets/arxiv_summary3.jsonl"

# The system prompts do not depend on the paper, so build the list once
# instead of once per file.
system_prompts = [
    "You are an expert cosmologist. You provide answers to questions about one particular paper.",
    "You are a knowledgeable cosmologist aware of the latest research. Follow the user's request.",
    # fixed garbled wording ("an cosmology or cosmology paper")
    "You are a seasoned astrophysicist. Provide insights into an astrophysics or cosmology paper.",
    "As an expert in the field of cosmology, offer an explanation of whatever paper the user is asking about.",
    "You have a deep understanding of astrophysical research. Guide the user through the main points of whatever study they are asking about.",
    "You are adept in explaining complex astrophysical concepts. Explain about the user's requested topic.",
    "As a cosmology specialist, distill the essence of whatever the user is asking about."
]

with open(outfile, "w") as fout:
    for filename in filenames:
        # Close the pickle as soon as it is loaded; nothing below needs the handle.
        with open(filename, "rb") as f:
            paper = pickle.load(f)
        if "summary" not in paper:
            print(f"Warning: no summary found for {paper['arxiv_id']} in {filename}.")
            continue
        system_prompt = random.choice(system_prompts)
        user_prompts = [
            f"Please summarize {paper['shorthand_title']}.",
            f"What is {paper['arxiv_id']} about?",
            f"Can you explain the key findings of {paper['shorthand_title']}?",
            f"What are the major contributions of {paper['shorthand_title']} to cosmology?",
            f"I'm interested in the contents of the paper {paper['arxiv_id']}. Can you describe it?",
            f"What's {paper['shorthand_title']}?",
            f"Please explain the main points of {paper['shorthand_title']}.",
        ]
        user_prompt = random.choice(user_prompts)
        assistant_preambles = [
            f"{paper['shorthand_title']} is titled \"{paper['title']}\". Here is a summary of the paper. ",
            "", # no preamble
            f"In the paper titled \"{paper['title']}\", published in {paper['year']}, the authors explore an intriguing aspect of cosmology. Let's delve into the summary. ",
            f"The study \"{paper['title']}\" provides key insights. Here's an overview. ",
            f"Delving into \"{paper['title']}\", the summary goes as follows. ",
            f"This summary focuses on \"{paper['title']}\", a noteworthy paper with the arXiv ID {paper['arxiv_id']}. The key points are the following. ",
            f"Titled \"{paper['title']}\", this paper presents groundbreaking research in cosmology. Here's a summary. ",
            f"\"{paper['title']}\" is a paper about cosmology. It discusses the following. ",
            "Summary. ",
        ]
        assistant_preamble = random.choice(assistant_preambles)
        assistant_message = assistant_preamble+paper['summary']
        # sharegpt should take the form {"conversations": [{"from": "...", "value": "..."}]} where the message is from "system", "human", or "gpt".
        conversation_data = {
            "conversations": [
                {"from": "system", "value": system_prompt},
                {"from": "human", "value": user_prompt},
                {"from": "gpt", "value": assistant_message}
            ]
        }

        fout.write(json.dumps(conversation_data) + '\n')

### Step 4.2 Make training data just from the metadata alone

In [None]:
# Inputs and settings for the metadata-only dataset
filenames = glob.glob("datasets/arxiv_cache/*.pkl")
random.shuffle(filenames)
outfile = "datasets/arxiv_metadata_qa3.jsonl"
n_pairs_per_paper = 5  # QA pairs drawn per paper

# Pool of system prompts asking for short, metadata-style answers
system_prompts = [
    "You are an AI programmed to provide brief, factual answers about arXiv papers.",
    "Your responses should be concise and limited to the essential details from the arXiv database.",
    "Provide short, precise answers with just the key information from the arXiv papers.",
    "You are an efficient AI, capable of giving terse answers about specific arXiv papers.",
    "Deliver quick and factual responses about arXiv papers, focusing only on the core details.",
    "Your role is to provide succinct, accurate information from the arXiv papers in as few words as possible.",
    "As an AI, offer brief and direct answers about arXiv papers, omitting any extraneous details.",
    "You are configured to give short, to-the-point responses about arXiv paper details.",
    "Your task is to provide the most important information from the arXiv papers in a concise manner.",
    "Respond with only the essential facts from the arXiv papers, keeping your answers brief.",
    "You are an AI trained to deliver short and accurate summaries of information from arXiv papers.",
    "Provide compact, factual responses, focusing solely on the critical aspects of arXiv papers.",
    "You are an AI assistant. Provide only brief, metadata-based responses, without explaining the content.",
]

# Question/answer templates, written as (question, answer) pairs.  The
# placeholders {first_author}, {year}, {full_title} and {arxiv_id} are filled
# in with str.format when the dataset is written.
qa_templates = [
    {"question": question, "answer": answer}
    for question, answer in (
        (
            "Where can I find the arXiv paper titled {full_title} by {first_author} et al.?",
            "The paper titled '{full_title}' by {first_author} et al. can be found at http://arxiv.org/abs/{arxiv_id}",
        ),
        (
            "What is the arXiv ID for {first_author} et al.?",
            "The arXiv ID for '{full_title}' by {first_author} et al. is {arxiv_id}.",
        ),
        (
            "Who is the first author of the paper with arXiv ID {arxiv_id}?",
            "The first author of {full_title} is {first_author}.",
        ),
        (
            "What year was the arXiv paper {full_title} by {first_author} et al. published?",
            "{first_author} et al. was published in {year}.",
        ),
        (
            "Can you give me the link to the arXiv paper authored by {first_author} et al. in {year}?",
            "The paper authored by {first_author} et al. in {year} can be found at http://arxiv.org/abs/{arxiv_id}.",
        ),
        (
            "What is the topic of the paper by {first_author} et al. with arXiv ID {arxiv_id}?",
            "{first_author} et al. discuss '{full_title}' in their paper with arXiv ID {arxiv_id}.",
        ),
        (
            "How can I access {first_author} et al.'s {year} paper on the arXiv?",
            "You can access the {year} paper by {first_author} et al. on arXiv at http://arxiv.org/abs/{arxiv_id}.",
        ),
        (
            "Is the paper titled '{full_title}' available on arXiv?",
            "Yes, the paper titled '{full_title}' is available on arXiv under ID {arxiv_id}.",
        ),
        (
            "What research did {first_author} et al. present in {year} on arXiv?",
            "In {year}, {first_author} et al. presented research on '{full_title}', available on arXiv with ID {arxiv_id}.",
        ),
        (
            "Where can I read about {first_author} et al.'s findings in '{full_title}'?",
            "You can read about {first_author} et al.'s findings in '{full_title}' at http://arxiv.org/abs/{arxiv_id} on arXiv.",
        ),
        (
            "Can you provide a link to {first_author} et al.'s work titled '{full_title}' on arXiv?",
            "Sure, the link to '{full_title}' by {first_author} et al. on arXiv is http://arxiv.org/abs/{arxiv_id}.",
        ),
        (
            "What is the subject of the arXiv paper with ID {arxiv_id} by {first_author} et al.?",
            "The subject of the arXiv paper with ID {arxiv_id} by {first_author} et al. is '{full_title}'.",
        ),
        (
            "Who led the research for the paper on arXiv titled '{full_title}'?",
            "{first_author} led the research for the paper titled '{full_title}' on arXiv.",
        ),
        (
            "When was the paper titled '{full_title}' added to arXiv?",
            "The paper titled '{full_title}' was added to arXiv in {year}.",
        ),
        (
            "I need a copy of '{full_title}' by {first_author} et al. Can you help find it?",
            "Here's the link where you can get your hands on '{full_title}' by {first_author} et al.: http://arxiv.org/abs/{arxiv_id}",
        ),
    )
]
# Write n_pairs_per_paper randomly templated metadata conversations per paper
with open(outfile, "w") as fout:
    for filename in filenames:
        with open(filename, "rb") as f:
            paper = pickle.load(f)
            for _ in range(n_pairs_per_paper):
                system_prompt = random.choice(system_prompts)
                qa_template = random.choice(qa_templates)
                # The same substitution values fill the question and the answer
                fields = dict(
                    first_author=paper["first_author"],
                    year=paper["year"],
                    full_title=paper["title"],
                    arxiv_id=paper["arxiv_id"],
                    shorthand_title=paper["shorthand_title"],
                )
                question = qa_template["question"].format(**fields)
                answer = qa_template["answer"].format(**fields)
                # sharegpt should take the form {"conversations": [{"from": "...", "value": "..."}]} where the message is from "system", "human", or "gpt".
                turns = [
                    {"from": "system", "value": system_prompt},
                    {"from": "human", "value": question},
                    {"from": "gpt", "value": answer},
                ]
                conversation_data = {"conversations": turns}

                fout.write(json.dumps(conversation_data) + '\n')

### Step 4.3 Make training data from the generated QA pairs.

The QA information was generated by letting the model read chunks of text and write QA pairs. Then, in a second pass, the model was told to critique and grade each answer, write an alternative answer, and grade that as well. Let's start by analyzing these grades.

In [None]:
import glob
import random
import pickle
# Collect every cached paper that carries refined QA pairs.
# Nothing is written in this cell.  The previous version opened
# "datasets/arxiv_qa3.jsonl" for writing without writing to it, which emptied
# a file that the combine step later in this file lists as an input.
papers = []
filenames = glob.glob("datasets/arxiv_cache/*.pkl")
random.shuffle(filenames)
for filename in filenames:
    # Close the pickle as soon as it is loaded; the checks only need `paper`.
    with open(filename, "rb") as f:
        paper = pickle.load(f)
    if "qa" not in paper:
        print(f"Warning: no QA pairs found for {paper['arxiv_id']} in {filename}.")
        continue
    if len(paper["qa"]) < 1:
        print(f"Warning: empty QA pairs found for {paper['arxiv_id']} in {filename}.")
        continue
    # Only the first QA pair is inspected to decide whether the refinement
    # pass has been run for this paper.
    if "refined_answer" not in paper["qa"][0]:
        print(f"Warning: no refined answers found for {paper['arxiv_id']} in {filename}.")
        continue
    papers.append(paper)

In [None]:
# Raw (unparsed) student and teacher grades of every QA pair, in paper order
grade_pairs = [
    (qa['student_grade'], qa['teacher_grade'])
    for paper in papers
    for qa in paper['qa']
]
student_grades = [student for student, _ in grade_pairs]
teacher_grades = [teacher for _, teacher in grade_pairs]

def grade_to_int(grade):
    """Coerce a loosely structured grade into an integer between 0 and 100.

    Accepted shapes:
      * int or float in the range 0-100 (floats are truncated, so 89.9 does
        not count as 90);
      * str holding a number, optionally prefixed with ``'grade '`` and/or
        followed by ``'%'`` (e.g. ``"grade 92"``, ``"95%"``, ``"95.0"``);
      * dict, in which case the first of a fixed list of known keys that is
        present is converted recursively;
      * non-empty list, in which case the first element is converted
        recursively.

    Anything else, and any number outside 0-100, yields 0; most of those
    cases also print a diagnostic.  Numbers parsed from strings go through the
    same range check as plain ints.

    Args:
        grade: the raw grade value.

    Returns:
        int: the grade in the range 0-100, or 0 when it cannot be interpreted.
    """
    if isinstance(grade, int):
        if 0 <= grade <= 100:
            return grade
        print(f"Error: grade '{grade}' is an int, but not in the range 0-100.")
        return 0
    if isinstance(grade, float):
        # NaN and +/-inf fail this comparison and are rejected like any other
        # out-of-range value.
        if 0 <= grade <= 100:
            return int(grade)
        print(f"Error: grade '{grade}' is a float, but not in the range 0-100.")
        return 0
    if isinstance(grade, dict):
        # Keys are tried in this order; the first one present wins.
        for key in ['student', 'teacher', 'score', 'grade', 'value', 'complete_answer', 'total', 'percentile', 'int', 'raw', 'overall', 'explanation', 'age']:
            if key in grade:
                return grade_to_int(grade[key])  # Recursively process the extracted value
        return 0  # no known key present
    if isinstance(grade, list) and grade:
        return grade_to_int(grade[0])  # Recursively process the first item
    if isinstance(grade, str):
        text = grade.strip()
        if text.startswith('grade '):
            text = text[6:]  # Skip 'grade '
        text = text.strip().rstrip('%').strip()
        try:
            value = int(text)
        except ValueError:
            try:
                value = float(text)
            except ValueError:
                print(f"Cannot convert string '{grade}' to int.")
                return 0
        # Re-enter with the parsed number so the 0-100 check applies here too.
        return grade_to_int(value)
    print(f"Error: grade '{grade}' is a {type(grade)}, which cannot be converted to int.")
    return 0

# Parse the raw grades into integer arrays and report mean / standard deviation
import numpy as np
# grade_to_int already returns 0 for anything it cannot parse; nan_to_num is
# applied on top of that as in the original cell.
student_grades_arr = np.nan_to_num(np.array(list(map(grade_to_int, student_grades))))
teacher_grades_arr = np.nan_to_num(np.array(list(map(grade_to_int, teacher_grades))))
student_mean, student_std = np.mean(student_grades_arr), np.std(student_grades_arr)
print(f"The average student grade is {student_mean:.2f} with a standard deviation of {student_std:.2f}.")
teacher_mean, teacher_std = np.mean(teacher_grades_arr), np.std(teacher_grades_arr)
print(f"The average teacher grade is {teacher_mean:.2f} with a standard deviation of {teacher_std:.2f}.")

In [None]:
# Overlay the student and teacher grade histograms on one set of axes
import matplotlib.pyplot as plt

# 100 equal-width bins spanning 0 to 100, shared by both histograms
grade_bins = np.linspace(0, 100, 100+1)
for grades, series_label in (
    (student_grades_arr, 'Student'),
    (teacher_grades_arr, 'Teacher'),
):
    plt.hist(grades, bins=grade_bins, alpha=0.5, label=series_label)

plt.xlabel("Grade")
plt.ylabel("Count")
plt.title("Distribution of Student and Teacher Grades")
plt.legend()
plt.show()

From this, let's choose the following rules:
 - if the student gets a 90% or above, use their answer
 - else if the teacher gets 90% or above, use their answer
 - else disregard this QA pair

In [None]:
import re
import random
import json
# Turn the graded QA pairs into ShareGPT conversations.
#
# Selection rule: use the student's answer when it was graded >= 90,
# otherwise the teacher's refined answer when that was graded >= 90,
# otherwise drop the pair.  The selected answer is what gets measured and
# written; the previous version always used qa["answer"], so the refined
# answer was never emitted.
conversations = []

# The prompt pools do not depend on the paper, so they are built once.
system_prompts_short = [
    "You are an AI programmed to provide brief answers about arXiv papers.",
    "As an expert cosmologist, you provide concise answers to the user's questions about an arXiv paper.",
    "You are a specialized AI, offering succinct insights into cosmology research papers.",
    "Provide brief, yet informative insights from arXiv papers in cosmology.",
    "As a focused expert in cosmology, respond with short, precise answers about scientific papers.",
    "Deliver quick and concise explanations of complex concepts from recent cosmology papers.",
    "You are programmed to give short, clear responses to queries about arXiv cosmology publications.",
    "Offer concise and direct answers to technical questions on cosmology research.",
    "As an AI trained in cosmology, provide succinct responses to detailed scientific inquiries.",
    "You are an AI designed to give brief, accurate descriptions of arXiv cosmology papers.",
    "Quickly decipher and explain the key points from cosmology papers.",
    "As an expert in the field, deliver short and precise interpretations of cosmology research.",
    "Condense long cosmology papers into brief, understandable answers.",
    "You are a compact knowledge source for quick insights into cosmology papers.",
    "Provide to-the-point, accurate summaries of recent findings in cosmology.",
    "As an AI, offer concise, clear-cut answers about specific aspects of cosmology papers.",
    "Distill complex cosmology concepts from research papers into brief explanations.",
    "Deliver quick, expert responses to questions about detailed cosmology research.",
    "You are programmed to succinctly answer queries on advanced cosmology topics.",
    "Offer crisp, clear summaries of the key findings in new cosmology research papers.",
]
system_prompt_medium = [
    "In a few sentences, explain the main findings of arXiv papers in cosmology.",
    "Provide answers to the user's cosmology questions. Stick to a moderately long answer.",
    "Elaborate briefly on the methodologies and conclusions of recent cosmology papers from arXiv.",
    "Provide detailed, yet concise explanations of key theories and discoveries in recent cosmology papers.",
    "As a cosmology expert, deliver medium-length answers that clarify complex concepts in arXiv papers.",
    "Summarize the critical points of arXiv cosmology papers in a clear, moderately detailed manner.",
    "Explain the significance and implications of findings in recent cosmology research, in a few sentences.",
    "Interpret and convey the essence of cosmology papers from arXiv, aiming for moderate-length responses.",
    "In a moderate amount of detail, discuss the innovations and findings in contemporary cosmology papers.",
    "Delve into the core aspects of arXiv cosmology papers, providing answers that are informative yet succinct.",
    "Shed light on the complexities of cosmology research with moderately expansive answers.",
    "Analyze and explain the key aspects of arXiv cosmology papers in a clear, moderately lengthy format.",
]
system_prompt_long = [
    "Provide detailed answers to the user's questions about arXiv papers in cosmology.",
    "Explain the methodologies and conclusions of recent cosmology papers from arXiv.",
    "Provide detailed explanations of key theories and discoveries in recent cosmology papers.",
    "As a cosmology expert, deliver long answers that clarify complex concepts in arXiv papers.",
    "Explain the significance and implications of findings in recent cosmology research.",
    "Interpret and convey the essence of cosmology papers from arXiv, aiming for longer responses.",
    "In a detailed manner, discuss the innovations and findings in contemporary cosmology papers.",
    "Delve into the core aspects of arXiv cosmology papers, providing answers that are informative yet detailed.",
    "Shed light on the complexities of cosmology research with long answers.",
    "Analyze and explain the key aspects of arXiv cosmology papers in a clear, lengthy format.",
]


def replace_vague_phrases(text, replacements):
    """Replace each phrase in ``replacements`` with its substitute text.

    Phrases are applied in the mapping's iteration order and are treated as
    regular-expression patterns.  A phrase only matches when it is not glued
    to a neighbouring word character, so "the research" is replaced but
    "the researchers" is left alone.  The substitute is inserted literally, so
    backslashes in it are not interpreted as regex escapes.

    Args:
        text: the string to rewrite.
        replacements: mapping of phrase pattern -> replacement text.

    Returns:
        str: ``text`` with every replacement applied.
    """
    for phrase, replacement in replacements.items():
        pattern = rf"(?<!\w)(?:{phrase})(?!\w)"
        # A function replacement is used so `replacement` is taken verbatim.
        text = re.sub(pattern, lambda _match, _literal=replacement: _literal, text)
    return text


for paper in papers:
    # Dictionary of replacements. This will be iterated over, so the order matters. This is safe in python>=3.7
    # It depends only on the paper, so it is built once per paper.
    replacements = {
        "the analysis": paper["shorthand_title"],
        "The analysis": paper["shorthand_title"],
        "The passage": paper["shorthand_title"],
        "the passage": paper["shorthand_title"],
        "the PASSAGE": paper["shorthand_title"],
        "The PASSAGE": paper["shorthand_title"],
        "the paper": paper["shorthand_title"],
        "The paper": paper["shorthand_title"],
        "the study": paper["shorthand_title"],
        "The study": paper["shorthand_title"],
        "the research": paper["shorthand_title"],
        "The research": paper["shorthand_title"],
        "the authors' findings": paper['shorthand_title'],
        "The authors' findings": paper['shorthand_title'],
        "the authors": f"{paper['first_author']} et al.",
        "The authors": f"{paper['first_author']} et al.",
        "the findings": paper["shorthand_title"],
        "The findings": paper["shorthand_title"],
        "the method": f"the method used in {paper['shorthand_title']}",
        "The method": f"The method used in {paper['shorthand_title']}",
    }

    for qa in paper["qa"]:
        question = qa["question"]
        if grade_to_int(qa["student_grade"]) >= 90:
            answer = qa["answer"]
        elif grade_to_int(qa["teacher_grade"]) >= 90:
            answer = qa["refined_answer"]
        else:
            continue

        # Pick the system prompt family from the length of the answer that
        # will actually be emitted; very short answers are dropped.
        qa_length = len(answer)
        if qa_length < 40:
            continue
        elif qa_length < 200:
            system_prompt = random.choice(system_prompts_short)
        elif qa_length < 500:
            system_prompt = random.choice(system_prompt_medium)
        else:
            system_prompt = random.choice(system_prompt_long)

        # Apply replacements to user prompt and assistant message
        user_prompt = replace_vague_phrases(question, replacements)
        assistant_message = replace_vague_phrases(answer, replacements)

        # sharegpt should take the form {"conversations": [{"from": "...", "value": "..."}]} where the message is from "system", "human", or "gpt".
        conversation_data = {
            "conversations": [
                {"from": "system", "value": system_prompt},
                {"from": "human", "value": user_prompt},
                {"from": "gpt", "value": assistant_message}
            ]
        }

        conversations.append(json.dumps(conversation_data))

random.shuffle(conversations)

# NOTE(review): the combine step later in this file lists
# "datasets/arxiv_qa3.jsonl", not this path, among its inputs -- confirm which
# file is intended.
outfile = "datasets/qa_tune/arxiv_refined_qa.jsonl"
# Create the output directory if needed so the run does not fail at the very end.
os.makedirs(os.path.dirname(outfile), exist_ok=True)
with open(outfile, "w") as fout:
    for conversation in conversations:
        fout.write(conversation + '\n')

### Step 4.4 Add miscellaneous datasets

Here you can add whatever other datasets you might like.

### Step 4.5 Combine and shuffle

In [None]:
import json
import random

def read_jsonl(file_path):
    """Read a JSON Lines file and return a list of JSON objects.

    Lines that are empty or contain only whitespace (for example a trailing
    blank line) are skipped instead of being handed to the JSON parser.

    Args:
        file_path: path of a UTF-8 encoded JSON Lines file.

    Returns:
        list: one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def deduplicate_jsonl(json_objects):
    """Return the objects with duplicates removed, keeping first occurrences.

    Two objects count as duplicates when their JSON serialisation with sorted
    keys is identical, so key order does not matter.  The relative order of
    the surviving objects is preserved.
    """
    first_by_fingerprint = {}
    for item in json_objects:
        # setdefault stores only the first object seen for each fingerprint;
        # dicts keep insertion order, which preserves the original ordering.
        first_by_fingerprint.setdefault(json.dumps(item, sort_keys=True), item)
    return list(first_by_fingerprint.values())

def shuffle_and_write_jsonl(json_objects, output_path):
    """Shuffle ``json_objects`` in place and write them as JSON Lines.

    The caller's list is reordered by ``random.shuffle``; the file at
    ``output_path`` is overwritten with one JSON document per line, UTF-8
    encoded, in the shuffled order.
    """
    random.shuffle(json_objects)
    with open(output_path, 'w', encoding='utf-8') as sink:
        sink.writelines(json.dumps(entry) + '\n' for entry in json_objects)

# Datasets to merge, and where the merged file goes
input_files = [
    "datasets/arxiv_qa3.jsonl",
    "datasets/arxiv_metadata_qa3.jsonl",
    "datasets/arxiv_summary3.jsonl",
]
output_path = "datasets/arxiv_sharegpt2.jsonl"

# Load every input in the order listed, then drop exact duplicates
combined_json_objects = deduplicate_jsonl(
    [obj for file_path in input_files for obj in read_jsonl(file_path)]
)

# Randomise the order and write the merged dataset
shuffle_and_write_jsonl(combined_json_objects, output_path)


## Step 5: Train the model

Now we have two options. The first is to keep control of the training loop; to do this, uncomment and run the following code. The second is to train on the JSONL files with the `axolotl` package, which has the advantage of coming with a lot of bells and whistles.

In [None]:
# OPTION 5.1: keep control of the training loop
# fine_tune.fine_tune(
#     pretrained_model_file_path="zephyr-7b-beta",
#     training_data=cleaned_json_file_path,
#     lr=5e-5,
#     gradient_clip=1.0,
#     num_epochs=1,
#     out_dir="zephyr-7b-beta_cosmosage_v1",
# )

# OPTION 5.2: train using axolotl and its config.yml

#### Axolotl 

You can run 
```accelerate launch -m axolotl.cli.train config.yml --prepare_ds_only --debug```
to see examples of what data your model is being finetuned on. It is useful for knowing the exact prompt template to use during inference.

In [None]:
# visualize loss during training
import plot_tf_log
# Call plot_tf_log.most_recent_log for three run names.  Only v15 is passed to
# plot_loss below; v16 and v14 are assigned but not plotted in this cell.
v16 = plot_tf_log.most_recent_log("mistral_cosmosage_v16")
v14 = plot_tf_log.most_recent_log("mistral_cosmosage_v14")
v15 = plot_tf_log.most_recent_log("mistral_cosmosage_v15")
# plot_loss receives a list of logs, so more entries can be added to the list.
plot_tf_log.plot_loss([v15], plot_type="detailed", detailed_pts_per_eval=5)

## Step 6: Evaluate the fine-tuned model

In [None]:
# Directory holding the fine-tuned model and its tokenizer
model_path = "models/cosmosage_v2/"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model, then move it to the GPU converted to bfloat16
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device, dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path)

def ask_cosmosage(question):
    """Ask the loaded model a single question and return its answer text.

    The question is embedded in the fixed system/USER/ASSISTANT prompt, the
    module-level ``model`` samples a continuation (temperature 0.7, at most
    1024 tokens including the prompt), and only the newly generated tokens
    are decoded.

    Decoding just the continuation avoids locating the answer by splitting the
    full text on "ASSISTANT: ", which returned the entire prompt whenever the
    completion did not start with a space and misbehaved when the answer
    itself contained that marker.

    Args:
        question: the user's question, inserted verbatim into the prompt.

    Returns:
        str: the generated answer with surrounding whitespace stripped.
    """
    prompt = f"You are cosmosage, an AI programmed to provide excellent and detailed answers to the user's question. You are an expert cosmology assistant, able to answer questions on the cosmic microwave background, galaxy formation, large scale structure, theoretical cosmology, inflation, big bang nucleosynthesis, cosmology instrumentation, and other related topics. Please assume the user is fluent in scientific terminology. Elaborate where possible to give a complete answer. If you do not know, say you do not know.▁ USER: {question}▁ ASSISTANT:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = model.generate(input_ids, max_length=1024, do_sample=True, temperature=0.7, top_k=None, pad_token_id=tokenizer.eos_token_id)
    # For a causal LM the returned sequence starts with the prompt tokens;
    # drop them so only the model's continuation is decoded.
    new_token_ids = generated_ids[0][input_ids.shape[-1]:]
    answer = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return answer.strip()

In [None]:
# Spot check: a question about a specific cited paper and the datasets it used
print(ask_cosmosage("What is the temperature of the CMB according to Fixsen (2009)? What datasets did he use to derive the value?"))

In [None]:
# Spot check: definition of a named technique
print(ask_cosmosage("What is Digital Active Nulling?"))

In [None]:
# Spot check: explanation of a named physical effect
print(ask_cosmosage("Explain the ISW effect."))

In [None]:
# Spot check: a question relating two concepts
print(ask_cosmosage("How does the time of matter-radiation equality affect the damping tail?"))

In [None]:
# Spot check: a request to describe a calculation procedure
print(ask_cosmosage("Explain how one would calculate the helium fraction at the surface of last scattering."))

## Step 7: Push model to huggingface

In [None]:
from huggingface_hub import HfApi
# Upload the contents of a local folder to a model repository on the Hub.
# NOTE(review): folder_path is "models/cosmosage_qa", while the evaluation
# cell in this file loads "models/cosmosage_v2/" and the target repository is
# named cosmosage_v2 -- confirm this is the folder intended for upload.
api = HfApi()
api.upload_folder(
    folder_path="models/cosmosage_qa",
    repo_id="tijmen2/cosmosage_v2",
    repo_type="model",
)