In [None]:
import os
from getpass import getpass
from typing import List, Optional

import instructor
import pymupdf4llm
import weave
import wget
from openai import OpenAI
from pydantic import BaseModel

In [None]:
# Reuse a key that is already present in the environment (so Restart & Run All does
# not block on input); otherwise prompt for it without echoing it to the notebook.
api_key = os.environ.get("OPENAI_API_KEY") or getpass("Enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = api_key

In [None]:
# Weave project that traces from the @weave.op functions below are logged under.
WEAVE_PROJECT_NAME = "arxiv-data-extraction"
weave.init(project_name=WEAVE_PROJECT_NAME)

In [None]:
@weave.op()
def get_markdown_from_arxiv(url):
    """Download the PDF at ``url`` and return its contents converted to Markdown.

    The file is written to the current working directory under the name
    ``wget.detect_filename`` derives from the URL, and it is always deleted
    afterwards, even if the Markdown conversion fails.

    Args:
        url: Direct link to a PDF (e.g. an arXiv ``/pdf/`` URL).

    Returns:
        The Markdown text produced by ``pymupdf4llm.to_markdown``.
    """
    file_path = wget.download(url=url, out=wget.detect_filename(url))
    try:
        return pymupdf4llm.to_markdown(file_path)
    finally:
        # Remove the download on success *and* on failure; previously an exception
        # in to_markdown left the PDF behind in the working directory.
        os.remove(file_path)

In [None]:
# Default system message for ArxivModel extraction requests. The bullet list mirrors
# the fields of the PaperInfo response model.
SYSTEM_PROMPT = """
You are a helpful assistant to a machine learning researcher who is reading a paper from arXiv.
You are to extract the following information from the paper:

- a list of main findings from the paper and their corresponding detailed explanations
- the list of names of the different novel methods proposed in the paper and their corresponding detailed explanations
- the list of names of the different existing methods used in the paper, their corresponding detailed explanations, and
    their citations
- the list of machine learning techniques used in the paper, such as architectures, optimizers, schedulers, etc., their
    corresponding detailed explanations, and their citations
- the list of evaluation metrics used in the paper, the benchmark datasets used, the values of the metrics, and their
    corresponding detailed observations in the paper
- the link to the GitHub repository of the paper if there is any
- the hardware or accelerators used to perform the experiments in the paper if any
- a list of possible further research directions that the paper suggests

Here are some rules to follow:
1. When looking for the main findings in the paper, you should look for the abstract.
2. When looking for the explanations for the main findings, you should look for the introduction and methods section of
    the paper.
3. When looking for the list of existing methods used in the paper, first look at the citations, and then try explaining
    how they were used in the paper.
4. When looking for the list of machine learning techniques used in the paper, first look at the citations, and then try
    explaining how they were used in the paper.
5. When looking for the evaluation metrics used in the paper, first look at the results section of the paper, and then
    try explaining the observations made from the results. Pay special attention to the tables to find the metrics,
    their values, the corresponding benchmark and the observation associated with the result.
6. If there are no GitHub repositories associated with the paper, simply return "None".
7. When looking for hardware and accelerators, pay special attention to the quantity of each type of hardware and
    accelerator. If there are no hardware or accelerators used in the paper, simply return "None".
"""


# One method / technique extracted from a paper (nested inside PaperInfo).
# NOTE: PaperInfo is passed as instructor's response_model, so these field names and
# annotations become the JSON schema shown to the LLM. Documented with `#` comments
# only, because a class docstring would be emitted as the schema "description".
class Method(BaseModel):
    method_name: str
    explanation: str
    # Optional[...] with no default: under pydantic v2 the key is still required but
    # may be null (e.g. novel methods that have nothing to cite).
    citation: Optional[str]


# One reported result: which metric, on which benchmark, its value, and what the paper
# says about it. Part of the LLM-facing schema (via PaperInfo.metrics), so documented
# with `#` comments only; a docstring would change the schema description.
class Evaluation(BaseModel):
    metric: str
    benchmark: str
    # Must validate as a float. NOTE(review): results written as ranges, "N/A", or with
    # units/percent signs would presumably fail validation and trigger an instructor retry.
    value: float
    observation: str


# Structured extraction target for one paper; used as `response_model` in
# ArxivModel.predict. Fields mirror the bullet list in SYSTEM_PROMPT, and
# arxiv_method_score reads novel_methods, existing_methods and
# machine_learning_techniques from the dumped dict. `#` comments only: a class
# docstring would alter the JSON schema sent to the LLM.
class PaperInfo(BaseModel):
    main_findings: List[str]
    # NOTE(review): presumably index-aligned with main_findings; nothing enforces
    # equal lengths.
    main_finding_explanations: List[str]
    novel_methods: List[Method]
    existing_methods: List[Method]
    machine_learning_techniques: List[Method]
    metrics: List[Evaluation]
    # The prompt asks for the literal string "None" when the paper has no repository.
    github_repository: str
    # The prompt asks for the literal string "None" when no hardware is reported.
    hardware: str
    further_research: List[str]


# No api_key argument: the client falls back to the OPENAI_API_KEY environment
# variable populated earlier in the notebook.
openai_client = OpenAI()
# instructor-patched client: chat.completions.create additionally accepts
# response_model= / max_retries= and returns a validated pydantic object.
structured_client = instructor.from_openai(client=openai_client)


class ArxivModel(weave.Model):
    """Weave model that turns an arXiv PDF URL into a structured ``PaperInfo`` dict.

    Attributes:
        model: OpenAI chat model name passed to the completion call (e.g. "gpt-4o").
        max_retries: Forwarded to instructor as ``max_retries`` (re-asks the model
            when the response does not validate against ``PaperInfo``).
        seed: Forwarded as ``seed`` on the chat completion request.
        system_prompt: System message sent ahead of the paper's Markdown text.
    """

    model: str
    max_retries: int = 5
    seed: int = 42
    system_prompt: str = SYSTEM_PROMPT

    @weave.op()
    def predict(self, url_pdf: str) -> dict:
        """Download ``url_pdf``, convert it to Markdown, and extract ``PaperInfo``.

        Args:
            url_pdf: Direct link to the paper's PDF.

        Returns:
            The validated ``PaperInfo`` as a plain dict (``model_dump()``), so scorers
            such as ``arxiv_method_score`` can index it by field name.
        """
        md_text = get_markdown_from_arxiv(url_pdf)
        paper_info = structured_client.chat.completions.create(
            model=self.model,
            response_model=PaperInfo,
            max_retries=self.max_retries,
            seed=self.seed,
            messages=[
                # Use the instance field, not the module constant; previously a custom
                # system_prompt passed to the constructor was silently ignored.
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": md_text},
            ],
        )
        return paper_info.model_dump()

In [None]:
@weave.op()
def arxiv_method_score(
    method: List[dict], model_output: Optional[dict]
) -> dict[str, float]:
    """Score the fraction of predicted methods that match a ground-truth method.

    A predicted method is correct when the lowercased ``name`` or ``full_name`` of
    any ground-truth method occurs as a substring of the prediction's method name
    and explanation (joined by a newline, lowercased). Each prediction is counted
    at most once, so the score always lies in [0, 1].

    Args:
        method: Ground-truth methods for the paper; each dict is read for its
            ``name`` and ``full_name`` entries (missing or empty ones are ignored).
        model_output: The dict returned by ``ArxivModel.predict`` (a dumped
            ``PaperInfo``); ``None`` is scored as 0.0.

    Returns:
        ``{"method_prediction_accuracy": score}``; 0.0 when there is no model output
        or the model predicted no methods.
    """
    if model_output is None:
        return {"method_prediction_accuracy": 0.0}
    predicted_methods = (
        model_output["novel_methods"]
        + model_output["existing_methods"]
        + model_output["machine_learning_techniques"]
    )
    if not predicted_methods:
        # Previously a ZeroDivisionError; no predictions means nothing was correct.
        return {"method_prediction_accuracy": 0.0}

    # Empty/missing names are dropped: `"" in text` is always True and would mark
    # every prediction as correct.
    gt_names = [
        name.lower()
        for gt_method in method or []
        for name in (gt_method.get("name"), gt_method.get("full_name"))
        if name
    ]
    num_correct_methods = 0
    for predicted_method in predicted_methods:
        # Built once per prediction instead of once per (ground truth, prediction) pair.
        predicted_text = (
            f"{predicted_method['method_name']}\n{predicted_method['explanation']}"
        ).lower()
        # any() counts a prediction once even if it matches several ground-truth
        # names; the old nested loop counted every matching pair and could exceed 1.0.
        if any(name in predicted_text for name in gt_names):
            num_correct_methods += 1
    return {"method_prediction_accuracy": num_correct_methods / len(predicted_methods)}

In [None]:
arxiv_parser_model = ArxivModel(model="gpt-4o")
eval_dataset_ref = weave.ref("cv-papers:v0").get()
evaluation = weave.Evaluation(
    dataset=eval_dataset_ref.rows[:5], scorers=[arxiv_method_score]
)
summary = await evaluation.evaluate(arxiv_parser_model.predict)