In [None]:
import re
import scrape_arxiv
import tex_to_json
import analyze_asl_dict
import fine_tune

# Directory that holds the LaTeX sources downloaded from arXiv (see Step 2).
tex_files_path = "datasets/tex_files/"
# Paths referenced by the legacy JSON pipeline that is commented out in Steps 3-5.
json_file_path = "datasets/arxiv_tex.json"
cleaned_json_file_path = "datasets/combined_training_set.json"

In [None]:
# Step 1: Extract unique arXiv numbers from the asl database
db_path = "datasets/dict_20231123.db"

# IDs that cannot be processed, grouped by the reason they are skipped.
excluded_arxiv_ids = {
    # no tarball available
    "2211.00542",
    "2306.01274",
    "2206.03389",
    "2207.11937",
    "2001.01724",
    "1601.00125",
    # funky tex file
    "1712.07541",
    "1902.09640",
    "2003.03431",
    "1704.00884",
}

# De-duplicate first and drop the exclusions with a set difference. Unlike repeated
# list.remove() calls, this does not raise when an excluded ID is absent from the
# database, and it cannot leave behind a second copy of an excluded ID.
arxiv_id_asl_tagged = list(
    set(analyze_asl_dict.extract_unique_arxiv_numbers(db_path)) - excluded_arxiv_ids
)

In [None]:
# Add every paper returned by an author search for "de_Haan_T",
# newest submissions first.
search_params = dict(
    search_query="au:de_Haan_T",
    searchtype="author",
    sortBy="submittedDate",
    sortOrder="descending",
)
arxiv_id_tdh = scrape_arxiv.get_arxiv_ids(search_params)

In [None]:
# Add papers whose abstract contains the phrase "cosmic microwave background",
# newest submissions first.
search_params = dict(
    search_query='abs:"cosmic microwave background"',
    sortBy="submittedDate",
    sortOrder="descending",
)
arxiv_id_cmb = scrape_arxiv.get_arxiv_ids(search_params)

In [None]:
# more arxiv papers recommended for me by asl (but no tags)
# The result is concatenated with the other ID lists in the collation cell.
arxiv_id_asl_rec = scrape_arxiv.other_arxiv_recommendation_ids()

In [None]:
# Merge the four ID lists (asl-tagged, author search, asl recommendations,
# CMB abstract search) into a single list without duplicates.
arxiv_ids = list(
    {
        *arxiv_id_asl_tagged,
        *arxiv_id_tdh,
        *arxiv_id_asl_rec,
        *arxiv_id_cmb,
    }
)

In [None]:
# Step 2: Download the .tex files from arXiv
# Passes the collated, de-duplicated ID list and the target directory from the config cell.
scrape_arxiv.extract_tex(arxiv_ids, tex_files_path)

In [None]:
# Step 3: Parse the downloaded .tex files and save to JSONL

# original method using pydetex
# parsed_tex_files = tex_to_json.parse_tex_files(tex_files_path)
# tex_to_json.save_to_json(parsed_tex_files, json_file_path)

# new method using command-line detex utility
# Reuse tex_files_path (the directory Step 2 downloads into) instead of repeating the
# literal, so changing the path in the config cell keeps download and parsing in sync.
tex_to_json.detex_files(tex_files_path)
tex_to_json.detex_to_jsonl(tex_files_path, "datasets/arxiv_tex.jsonl")

In [None]:
# Step 4: Gather the training data
# ALTERNATIVE: use JSONL files as inputs for axolotl

# JSON method
# # clean arxiv json data a little more and include multiple copies
# num_copies_arxiv = 4
# json_data = tex_to_json.load_from_json(json_file_path)
# cleaned_data = []
# for _ in range(num_copies_arxiv):
#     for paper, data_list in json_data.items():
#         # remove any sequences enclosed in square brackets (e.g. [1])
#         cleaned_data.extend([re.sub(r"\[[^\]]*\]", "", data) for data in data_list])

# # add physics Q&A data
# physics_questions = tex_to_json.load_from_json("datasets/physics_clean.json")
# cleaned_data.extend(physics_questions)
# tex_to_json.save_to_json(cleaned_data, cleaned_json_file_path)

In [None]:
# Step 5: Train the model
# ALTERNATIVE: train using axolotl and its config.yml

# fine_tune.fine_tune(
#     pretrained_model_file_path="zephyr-7b-beta",
#     training_data=cleaned_json_file_path,
#     lr=5e-5,
#     gradient_clip=1.0,
#     num_epochs=1,
#     out_dir="zephyr-7b-beta_cosmosage_v1",
# )

In [None]:
# visualize loss during training
import plot_tf_log
# NOTE(review): this plots the log for "mistral_cosmosage_v3", while the evaluation
# cell below loads a "mistral_cosmosage_v4" model — confirm that is intentional.
log = plot_tf_log.most_recent_log("mistral_cosmosage_v3")
plot_tf_log.plot_loss([log], logsmooth=True)

In [None]:
# Step 6: Evaluate the fine-tuned model
import torch
from transformers import pipeline

# Build the text-generation pipeline; ask_cosmosage() calls it through the global `pipe`.
# NOTE(review): the model path is an absolute path inside one user's home directory,
# so anyone else running this notebook has to edit it.
pipe = pipeline(
    "text-generation",
    model="/home/tijmen/cosmosage/models/mistral_cosmosage_v4/relora_out/",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def ask_cosmosage(
    question,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
):
    """Ask the fine-tuned model a single question using a two-example few-shot prompt.

    The question is placed in the final USER slot of a fixed prompt (system
    message plus two worked examples) and the prompt is passed to the
    module-level ``pipe`` text-generation pipeline.

    Args:
        question: Text inserted verbatim after ``USER:`` in the QUESTION section.
        max_new_tokens: Forwarded to ``pipe`` unchanged.
        do_sample: Forwarded to ``pipe`` unchanged.
        temperature: Forwarded to ``pipe`` unchanged.
        top_k: Forwarded to ``pipe`` unchanged.
        top_p: Forwarded to ``pipe`` unchanged.

    Returns:
        Whatever ``pipe`` returns, unmodified. Presumably a list of dicts whose
        "generated_text" includes the prompt followed by the answer — verify
        against the transformers pipeline documentation.

    The defaults reproduce the sampling settings that were previously hard-coded,
    so ``ask_cosmosage(question)`` behaves as before.
    """
    # The prompt text is runtime data for the model; keep it byte-for-byte stable.
    prompt = f"""<s>SYSTEM: You are an expert cosmology assistant, able to answer questions on the cosmic microwave background, galaxy formation, large scale structure, theoretical cosmology, inflation, big bang nucleosynthesis, cosmology instrumentation, and other related topics. Please assume the user is fluent in scientific terminology and do not insult his or her intelligence. Elaborate where possible to give a complete answer along with its context. If you do not know, say you do not know. Give one answer after reading the two provided examples.
## EXAMPLE 1
USER: What is the ISW effect?
ASSISTANT: The Integrated Sachs-Wolfe effect describes the fact that cosmic microwave background photons gain energy as they fall into potential wells and lose it again as they climb back out. The origin of the effect is gravitational redshift. For static potential wells, there is no net effect. However, if the potentials grow or decay, there is a net effect, known as the Integrated Sachs-Wolfe effect. In a matter-dominated universe, linear potentials stay constant. The late-time ISW effect refers to the fact that at late times, dark energy causes a deviation from matter domination and linear potentials decay. This causes an effect in the cosmic microwave background which has been measured in cross-correlation with large-scale structure. The early-time ISW refers to the fact that right after recombination, radiation still contributed a non-negligible amount of the energy density of the universe, causing potentials to grow. This causes the photons to gain energy.
## EXAMPLE 2
USER: What are the definitions of galaxy cluster mass (M200m, M500c, etc.)?
ASSISTANT: Calculating a galaxy cluster's mass is like calculating the population of a sprawling city; as you get further out from city center, the suburbs never end. Eventually you have to stop counting and just say that from there on the people belong to a different town.\nIn cluster masses, we integrate from the cluster center out to some radius rΔ where Δ might be 200, 500, or some other number. The definition of this radius is that the density of the enclosed mass is Δ times the average density of the universe at that redshift. The associated mass is MΔ.\nOne more detail is that 'average density' can be defined in two ways. You can either include the energy density of dark energy - in which case you are comparing to the critical density at that redshift - or you can use the matter density only. This choice results in the c or m subscript (stands for critical or mean). Note that the critical density is higher than the mean density, so it gives an earlier cutoff in the integral. Therefore 'MΔc' masses are smaller than 'MΔm' ones.
## QUESTION
USER: {question}
ASSISTANT:"""
    return pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

Let's try using the fine-tuned model as an assistant.

In [None]:
# Sanity check with a basic factual question.
answer_cmb = ask_cosmosage("What is the temperature of the CMB?") # easy question

In [None]:
# Display the value returned by ask_cosmosage (the pipeline output, unmodified).
answer_cmb

In [None]:
# Digital Active Nulling is not mentioned in either few-shot example of the prompt,
# so a good answer here has to come from the model itself.
answer_dan = ask_cosmosage("What is Digital Active Nulling?") # see if it's read the arxiv paper

In [None]:
# Display the value returned by ask_cosmosage (the pipeline output, unmodified).
answer_dan

In [None]:
# NOTE(review): the ISW effect is already explained in EXAMPLE 1 of the prompt built by
# ask_cosmosage, so the model can answer this by restating the example — consider a
# question that the prompt does not cover.
answer_isw = ask_cosmosage("Explain the ISW effect.")  # hard question

In [None]:
# Display the value returned by ask_cosmosage (the pipeline output, unmodified).
answer_isw

In [None]:
# clear VRAM when not using it
import gc

import torch

# Drop the notebook's reference to the pipeline; tolerate re-running this cell
# after `pipe` has already been deleted.
try:
    del pipe
except NameError:
    pass

# `del` only removes the name. Collect remaining garbage, then ask PyTorch to release
# its cached CUDA blocks; otherwise the memory stays reserved by this kernel process.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()