# cosmosage

See README for more information.

In [None]:
import pickle
import os
import scrape_arxiv
import analyze_asl_dict
import extract_textbooks
import multiprocessing
import pandas as pd
import re
import json
import matplotlib.pyplot as plt
from datetime import datetime

# Dataset locations. In the visible part of this notebook these are only
# referenced from the commented-out "optional"/"alternative" cells.
# Directory passed to scrape_arxiv.extract_tex in the alternative step 2.
tex_files_path = "datasets/tex_files/"
# JSON file holding the parsed arXiv .tex content (alternative step 3).
json_file_path = "datasets/arxiv_tex.json"
# Output of the optional manual data-collation step; input to fine_tune.
cleaned_json_file_path = "datasets/combined_training_set.json"

## Step 1: Extract arXiv data

In [None]:
cache_file = "datasets/arxiv_ids_cache.pkl"

if not os.path.exists(cache_file):
    # No cache yet: gather IDs from every source, then persist the result.

    # Source 1: unique arXiv numbers found in the asl database
    db_path = "datasets/dict_20231123.db"
    arxiv_id_asl_tagged = analyze_asl_dict.extract_unique_arxiv_numbers(db_path)

    # Source 2: author search for de_Haan_T
    search_params = {"search_query": "au:de_Haan_T", "searchtype": "author"}
    arxiv_id_tdh = scrape_arxiv.get_arxiv_ids(search_params)

    # Source 3: abstracts containing the phrase "cosmic microwave background"
    search_params = {"search_query": 'abs:"cosmic microwave background"'}
    arxiv_id_cmb = scrape_arxiv.get_arxiv_ids(search_params)

    # Source 4: further papers recommended by asl
    arxiv_id_asl_rec = scrape_arxiv.other_arxiv_recommendation_ids()

    # Concatenate all sources and de-duplicate. The round-trip through a set
    # means the resulting order is arbitrary.
    arxiv_ids = list(
        set(arxiv_id_asl_tagged + arxiv_id_tdh + arxiv_id_cmb + arxiv_id_asl_rec)
    )

    # Persist so that later runs skip the scraping above
    with open(cache_file, "wb") as f:
        pickle.dump(arxiv_ids, f)
else:
    # Cache hit: reuse the previously collected ID list
    with open(cache_file, "rb") as f:
        arxiv_ids = pickle.load(f)


In [None]:
import numpy as np

# Index of every entry equal to "1707.07535", expressed as a fraction of the
# list length (i.e. how far into arxiv_ids that paper sits).
id_matches = np.array(arxiv_ids) == "1707.07535"
np.where(id_matches)[0] / len(arxiv_ids)

## Step 2: Synthetic data generation

Here, we generate synthetic data using the following:
 - instruction-tuned model to generate the QA pairs
 - vLLM server to load the model once and provide good throughput
 - langchain to handle 
   - gathering of papers
   - extracting from PDFs
   - chunking data

In [None]:
def process_arxiv_id(arxiv_id):
    """Run the synthetic-data pipeline for one arXiv paper.

    Creates a ``scrape_arxiv.arxiv_paper`` for ``arxiv_id`` and then calls its
    ``generate_summary``, ``generate_qa_pairs`` and ``save_dataset_jsonl``
    methods, in that order.

    Any ``Exception`` raised along the way is caught and reported on stdout
    together with the offending ID, so a failing paper does not propagate an
    error to the caller. Always returns ``None``.
    """
    pipeline_steps = ("generate_summary", "generate_qa_pairs", "save_dataset_jsonl")
    try:
        paper = scrape_arxiv.arxiv_paper(arxiv_id)
        for step_name in pipeline_steps:
            getattr(paper, step_name)()
    except Exception as e:
        # Report which paper failed and why, then carry on
        print(f"Error processing {arxiv_id}: {e}")

# Worker pool with the default number of processes
pool = multiprocessing.Pool()

# Queue one asynchronous task per paper that has no output file yet
for arxiv_id in arxiv_ids:
    if os.path.exists(f"datasets/arxiv_qa/{arxiv_id}.jsonl"):
        # Already generated on an earlier run; nothing to do
        continue
    pool.apply_async(process_arxiv_id, args=(arxiv_id,))

# Stop accepting new tasks, then block until every queued task has finished
pool.close()
pool.join()

The above code can take quite a while to run, so it is also available in script form as
`run_generate_synth.py`, which can be run inside e.g. a `screen` session.

A logger can be set up with `log_generate_synth.sh` to track the progress. The following code will plot the log file.

In [None]:
# Function to read and process log data
def read_and_process_log(file_path):
    """Parse the synthetic-data progress log into a DataFrame.

    A line is used when it contains
    ``YYYY-MM-DD HH:MM:SS - Folder Size: <int>M - File Count: <int>``;
    every other line is ignored.

    Args:
        file_path: Path of the log file to read.

    Returns:
        A ``pandas.DataFrame`` with the columns ``DateTime`` (parsed
        ``datetime``), ``FolderSizeMB`` (int) and ``FileCount`` (int), one row
        per matching line, in file order. The three columns are present even
        when no line matches, so callers can index them unconditionally.
    """
    columns = ['DateTime', 'FolderSizeMB', 'FileCount']
    # Compiled once up front instead of being looked up for every line.
    pattern = re.compile(
        r'(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) - Folder Size: (\d+)M - File Count: (\d+)'
    )

    rows = []
    with open(file_path, 'r') as file:
        # Stream the file; there is no need to hold every line in memory.
        for line in file:
            match = pattern.search(line)
            if match is None:
                continue
            timestamp, folder_size, file_count = match.groups()
            rows.append({
                'DateTime': datetime.strptime(timestamp, '%Y-%m-%d %H:%M:%S'),
                'FolderSizeMB': int(folder_size),
                'FileCount': int(file_count),
            })

    # Passing `columns` explicitly keeps the schema stable for an empty or
    # non-matching log; without it the frame has no columns and the plotting
    # code fails with KeyError: 'DateTime'.
    return pd.DataFrame(rows, columns=columns)

# Function to plot file count and folder size
def plot_file_count_and_folder_size(df):
    """Plot file count and folder size against time on a shared x-axis.

    ``df`` must provide the columns ``DateTime``, ``FileCount`` and
    ``FolderSizeMB``. File count is drawn in red against the left y-axis and
    folder size in blue against a twin right y-axis. The figure is shown via
    ``plt.show()``; nothing is returned.
    """

    def draw_series(axis, column, label, color):
        # Label, line and tick labels all share the series colour so each
        # curve can be matched to its own y-axis.
        axis.set_ylabel(label, color=color)
        axis.plot(df['DateTime'], df[column], color=color)
        axis.tick_params(axis='y', labelcolor=color)

    fig, left_axis = plt.subplots(figsize=(12, 6))
    left_axis.set_xlabel('Date and Time')

    # Left y-axis: number of files
    draw_series(left_axis, 'FileCount', 'File Count', 'tab:red')

    # Right y-axis (twin of the left one): folder size
    right_axis = left_axis.twinx()
    draw_series(right_axis, 'FolderSizeMB', 'Folder Size (MB)', 'tab:blue')

    plt.title('File Count and Folder Size Over Time')
    fig.tight_layout()
    plt.show()

# Location of the progress log to visualize
log_file_path = "datasets/arxiv_qa/generate_synth.log"

# Parse the log into a DataFrame, then draw both time series
df = read_and_process_log(log_file_path)
plot_file_count_and_folder_size(df)


## Step 3: Join and prepare the datasets

We now have many JSONL files, one for each arXiv paper. Each one has an initial question which asks to summarize the whole paper. The subsequent questions are specific questions about the paper. When I used Mistral-7B-v0.2 to generate these questions, it did not include much context and some of the questions are impossible to answer unless you know what context they are being asked in. For that reason, let's provide the summary as context to the Q&A.

Let's also collate the summaries into a single JSONL file that uses QA format, without context.

In [None]:
arxiv_qa_jsonl_files = [f"datasets/arxiv_qa/{arxiv_id}.jsonl" for arxiv_id in arxiv_ids]

# collate the first line of each JSONL file into a single summaries file
arxiv_qa_summarize_jsonl_file = "datasets/arxiv_qa_summarize.jsonl"
# also write the remaining lines as QA using the summary as context
arxiv_qa_with_context_jsonl_file = "datasets/arxiv_qa_with_context.jsonl"

# Counters so that dropped data is reported instead of vanishing silently.
skipped_papers = 0
skipped_qa_records = 0

with open(arxiv_qa_summarize_jsonl_file, "w") as f1, open(arxiv_qa_with_context_jsonl_file, "w") as f2:
    for arxiv_qa_jsonl_file in arxiv_qa_jsonl_files:
        # Papers that were never processed, or that produced an empty file,
        # contribute nothing.
        if not (os.path.exists(arxiv_qa_jsonl_file) and os.path.getsize(arxiv_qa_jsonl_file) > 0):
            continue

        with open(arxiv_qa_jsonl_file, "r") as g:
            first_line = g.readline()

            # Validate the summary record BEFORE writing it, so a malformed
            # first line neither pollutes the summaries file nor aborts the
            # whole collation run.
            try:
                summary = json.loads(first_line)["answer"]
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                print(f"Skipping {arxiv_qa_jsonl_file}: unusable summary record ({e})")
                skipped_papers += 1
                continue

            # readline() keeps the trailing newline; strip it so exactly one
            # newline is written and the output has no blank lines.
            f1.write(first_line.rstrip("\r\n") + "\n")

            for line in g:
                if not line.strip():
                    continue  # blank line, nothing to parse
                try:
                    line_json = json.loads(line)
                    question = line_json["question"]
                    answer = line_json["answer"]
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Malformed record: best effort, drop it and keep going.
                    skipped_qa_records += 1
                    continue
                if not (isinstance(question, str) and isinstance(answer, str)):
                    skipped_qa_records += 1
                    continue

                # this context will be repeated many times, but that's okay
                qa_dict = {"context": summary, "question": question, "answer": answer}
                f2.write(json.dumps(qa_dict) + "\n")

print(
    f"Skipped {skipped_papers} paper(s) with unusable summaries and "
    f"{skipped_qa_records} malformed QA record(s)"
)


## Step 4: Extract textbooks and create another JSONL file

In [None]:
# Hands the target path to the project helper; presumably it writes the
# textbook-derived dataset there as JSONL (confirm in extract_textbooks.py).
extract_textbooks.textbooks_to_jsonl("datasets/textbooks.jsonl")

## Step 5: Train the model

Now we have two options. The first is to keep control of the training loop ourselves: to do this, uncomment and run the following code. The second is to train on the JSONL files with the `axolotl` package, which has the advantage of coming with a lot of bells and whistles.

In [None]:
# OPTIONAL : manually collect training data, tokenize, and run training loop 

# JSON method for collating data
# # clean arxiv json data a little more and include multiple copies
# num_copies_arxiv = 4
# json_data = tex_to_json.load_from_json(json_file_path)
# cleaned_data = []
# for _ in range(num_copies_arxiv):
#     for paper, data_list in json_data.items():
#         # remove any sequences enclosed in square brackets (e.g. [1])
#         cleaned_data.extend([re.sub(r"\[[^\]]*\]", "", data) for data in data_list])

# # add physics Q&A data
# physics_questions = tex_to_json.load_from_json("datasets/physics_clean.json")
# cleaned_data.extend(physics_questions)
# tex_to_json.save_to_json(cleaned_data, cleaned_json_file_path)

# Train the model
# ALTERNATIVE: train using axolotl and its config.yml

# fine_tune.fine_tune(
#     pretrained_model_file_path="zephyr-7b-beta",
#     training_data=cleaned_json_file_path,
#     lr=5e-5,
#     gradient_clip=1.0,
#     num_epochs=1,
#     out_dir="zephyr-7b-beta_cosmosage_v1",
# )

You can run 
```accelerate launch -m axolotl.cli.train config.yml --prepare_ds_only --debug```
to see examples of what data your model is being finetuned on. It is useful for knowing the exact prompt template to use during inference.

In [None]:
# visualize loss during training
import plot_tf_log
# Look up the most recent log for each named training run. most_recent_log is
# project code; what it returns is not visible from this notebook.
v8 = plot_tf_log.most_recent_log("mistral_cosmosage_v8")
v9 = plot_tf_log.most_recent_log("mistral_cosmosage_v9")
v6 = plot_tf_log.most_recent_log("mistral_cosmosage_v6")
v7 = plot_tf_log.most_recent_log("mistral_cosmosage_v7")
# Only v9 is plotted; v6-v8 are loaded but unused in this cell. plot_loss takes
# a list, so presumably more runs can be added to it for comparison.
plot_tf_log.plot_loss([v9], plot_type="detailed")

In [None]:
%matplotlib widget

## Step 6: Evaluate the fine-tuned model

In [None]:
import torch
from transformers import pipeline

# Build a text-generation pipeline from a checkpoint directory on local disk.
# The model path is an absolute, machine-specific path: adjust it before
# running elsewhere. The weights are requested in bfloat16, and
# device_map="auto" leaves device placement to the library.
pipe = pipeline(
    "text-generation",
    model="/home/tijmen/cosmosage/models/mistral_cosmosage_v4/relora_out/",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def ask_cosmosage(question):
    """Ask the fine-tuned model a single question using a two-shot prompt.

    The prompt consists of a system message, two worked USER/ASSISTANT
    examples and finally ``question``, ending in an open ``ASSISTANT:`` turn
    for the model to complete. Generation goes through the module-level
    ``pipe`` object.

    Args:
        question: Text inserted verbatim as the final USER turn.

    Returns:
        Whatever ``pipe(...)`` returns, unmodified. NOTE(review): for a
        Hugging Face text-generation pipeline this is normally a list of
        dicts whose ``generated_text`` includes the prompt; confirm.
    """
    prompt = f"""<s>SYSTEM: You are an expert cosmology assistant, able to answer questions on the cosmic microwave background, galaxy formation, large scale structure, theoretical cosmology, inflation, big bang nucleosynthesis, cosmology instrumentation, and other related topics. Please assume the user is fluent in scientific terminology and do not insult his or her intelligence. Elaborate where possible to give a complete answer along with its context. If you do not know, say you do not know. Give one answer after reading the two provided examples.
## EXAMPLE 1
USER: What is the ISW effect?
ASSISTANT: The Integrated Sachs-Wolfe effect describes the fact that cosmic microwave background photons gain energy as they fall into potential wells and lose it again as they climb back out. The origin of the effect is gravitational redshift. For static potential wells, there is no net effect. However, if the potentials grow or decay, there is a net effect, known as the Integrated Sachs-Wolfe effect. In a matter-dominated universe, linear potentials stay constant. The late-time ISW effect refers to the fact that at late times, dark energy causes a deviation from matter domination and linear potentials decay. This causes an effect in the cosmic microwave background which has been measured in cross-correlation with large-scale structure. The early-time ISW refers to the fact that right after recombination, radiation still contributed a non-negligible amount of the energy density of the universe, causing potentials to grow. This causes the photons to gain energy.
## EXAMPLE 2
USER: What are the definitions of galaxy cluster mass (M200m, M500c, etc.)?
ASSISTANT: Calculating a galaxy cluster's mass is like calculating the population of a sprawling city; as you get further out from city center, the suburbs never end. Eventually you have to stop counting and just say that from there on the people belong to a different town.\nIn cluster masses, we integrate from the cluster center out to some radius rΔ where Δ might be 200, 500, or some other number. The definition of this radius is that the density of the enclosed mass is Δ times the average density of the universe at that redshift. The associated mass is MΔ.\nOne more detail is that 'average density' can be defined in two ways. You can either include the energy density of dark energy - in which case you are comparing to the critical density at that redshift - or you can use the matter density only. This choice results in the c or m subscript (stands for critical or mean). Note that the critical density is higher than the mean density, so it gives an earlier cutoff in the integral. Therefore 'MΔc' masses are smaller than 'MΔm' ones.
## QUESTION
USER: {question}
ASSISTANT:"""
    # Sampled (non-greedy) decoding, capped at 512 newly generated tokens.
    outputs = pipe(
        prompt,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    return outputs

Let's try using the fine-tuned model as an assistant.

In [None]:
# Sanity check with a basic factual question.
answer_cmb = ask_cosmosage("What is the temperature of the CMB?") # easy question

In [None]:
# Display the raw pipeline output for the CMB question.
answer_cmb

In [None]:
# Probe whether the model picked up content from a specific arXiv paper.
answer_dan = ask_cosmosage("What is Digital Active Nulling?") # see if it's read the arxiv paper

In [None]:
# Display the raw pipeline output for the Digital Active Nulling question.
answer_dan

In [None]:
# Note: the ISW effect is also the first worked example inside the prompt
# built by ask_cosmosage, so the model sees a reference answer for it.
answer_isw = ask_cosmosage("Explain the ISW effect.")  # hard question

In [None]:
# Display the raw pipeline output for the ISW question.
answer_isw

In [None]:
# clear VRAM when not using it
import gc

import torch

# `del` only removes this name. The memory is returned to the GPU only once
# the model object is actually collected and PyTorch's caching allocator is
# asked to release its now-unused blocks.
# NOTE(review): any other live reference to the pipeline still keeps the
# weights allocated.
del pipe
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Step 7: Push model to huggingface

In [None]:
from huggingface_hub import HfApi
api = HfApi()

# Upload all the content from the local folder to the remote model repository
# (repo_type="model", i.e. a model repo rather than a Space).
# By default, files are uploaded at the root of the repo
# The folder path below is an absolute, machine-specific path.
api.upload_folder(
    folder_path="/QUPMLcommon/tijmen/cosmosage_v0.1",
    repo_id="tijmen2/cosmosage_v0.1",
    repo_type="model",
)

## ALTERNATIVE STEP 2: Download the .tex files from arXiv

In [None]:
# # # sequential version (one thread):
# # scrape_arxiv.extract_tex(arxiv_ids, tex_files_path)

# # multithreaded version:
# from multiprocessing import Pool
# def download_papers(arxiv_id_list):
#     scrape_arxiv.extract_tex(arxiv_id_list, tex_files_path)
# n_processes = 12
# random.shuffle(arxiv_ids)
# arxiv_id_split = [arxiv_ids[i::n_processes] for i in range(n_processes)]
# with Pool(n_processes) as p:
#     p.map(download_papers, arxiv_id_split)

## ALTERNATIVE Step 3: Parse the downloaded .tex files and save to JSONL

In [None]:
# # # method using pydetex
# # parsed_tex_files = tex_to_json.parse_tex_files(tex_files_path)
# # tex_to_json.save_to_json(parsed_tex_files, json_file_path)

# # method using command line detex 
# tex_to_json.detex_files("datasets/tex_files/")
# # manual regular expressions to clean up .detex and save to a single JSONL file
# tex_to_json.detex_to_jsonl("datasets/tex_files/", "datasets/arxiv_tex.jsonl")