# Example Notebook for Integrating with Hugging Face

TODO: Future work may include using `Arthur Bench` to compare LLMs.

TODO: Re-enable functionality for traditional text-summarization.

# Setup

In [None]:
import requests
import huggingface_hub as hfhub

In [None]:
from sciterra import Atlas
from sciterra import Cartographer
from sciterra.librarians import SemanticScholarLibrarian # or ADSLibrarian
from sciterra.vectorization import SciBERTVectorizer # among others

### Settings

In [None]:
# Settings (one entry per knob used by the cells below)
config = {
    "atlas_dirpath": "../atlas",  # directory of the saved sciterra Atlas to load
    "model": "Falconsai/text_summarization",  # model id intended for the Inference API path
    "endpoint": "llama-2-7b-chat-hf-mhj",  # name passed to hfhub.get_inference_endpoint
    "api_or_endpoint": "endpoint",  # 'api' or 'endpoint': selects which `query` gets defined
    "n_summarized": 10,  # how many of the most-similar publications to summarize
    "task": "summarization",  # selects `summary_key`
}

In [None]:
# Key under which each task's output text appears in Hugging Face response
# records. The transformers summarization pipeline (and the Inference API that
# serves it) returns records like {"summary_text": ...}; text generation
# returns {"generated_text": ...}.
summary_key = {
    "summarization": "summary_text",
    "text-generation": "generated_text",
}[config["task"]]

### Sciterra

In [None]:
# Load the saved sciterra Atlas from config['atlas_dirpath']; later cells read
# its `publications` (by identifier) and its `center`.
atl = Atlas.load(config['atlas_dirpath'])

In [None]:
# Create a cartographer with a Semantic Scholar librarian and a SciBERT vectorizer.
# It is used below (`crt.sort`) to rank the atlas's publications by similarity
# to the atlas center.
crt = Cartographer(
    librarian=SemanticScholarLibrarian(),
    vectorizer=SciBERTVectorizer(),
)

### HFHub

In [None]:
# Login: reuse a cached Hugging Face token if there is one, otherwise prompt.
token = hfhub.get_token()
if token is None:
    hfhub.login()
    token = hfhub.get_token()
if token is None:
    # In a notebook, `login()` may only display a login widget and return
    # before the user has submitted a token. Stop here rather than silently
    # building an "Authorization: Bearer None" header.
    raise RuntimeError(
        "No Hugging Face token found. Complete the login prompt, then re-run this cell."
    )

# Format for Inference API
headers = {"Authorization": f"Bearer {token}"}

In [None]:
if config['api_or_endpoint'] == 'api':

    # This serverless Inference API path is out of sync with the endpoint path
    # below: it takes ready-made prompts (no `preprompt`) and returns the raw
    # `requests.Response`, while later cells expect a list of strings.
    # Refuse to run it. Use `raise` rather than `assert` because assertions
    # are stripped under `python -O`, which would silently re-enable this path.
    raise NotImplementedError("This needs to be fixed up again.")

    def query(prompts):
        """POST ``prompts`` to the Hugging Face Inference API for config['model'].

        Args:
            prompts: Input(s) placed verbatim in the request's ``inputs`` field.

        Returns:
            The raw ``requests.Response``; the caller must check the status and
            decode the JSON body.
        """
        payload = {
            'inputs': prompts,
            'parameters': {
                'max_new_tokens': 250,
            },
        }

        # The model id lives in the config; there is no bare `model` name here.
        api_url = f"https://api-inference.huggingface.co/models/{config['model']}"
        response = requests.post(api_url, headers=headers, json=payload)
        return response

In [None]:
if config['api_or_endpoint'] == 'endpoint':

    endpoint = hfhub.get_inference_endpoint(config['endpoint'])

    def query(abstracts, preprompt="summarize in two sentences:", max_new_tokens=1000):
        """Generate one completion per input text using the Inference Endpoint.

        Each text is prefixed with ``preprompt`` to form its prompt, and the
        prompts are sent one at a time via ``endpoint.client.text_generation``.

        Args:
            abstracts: Iterable of input texts (abstracts, or any text to summarize).
            preprompt: Instruction placed before each text. A trailing colon is
                optional; exactly one ": " separator is inserted either way.
            max_new_tokens: Upper bound on generated tokens per prompt.

        Returns:
            A list with one generation per input, in input order. With the
            default (non-streaming, no details) arguments used here,
            ``text_generation`` returns plain strings.
        """
        # The default and the callers' preprompts already end in ":"; strip it
        # so the prompt gets a single ": " separator instead of "::".
        instruction = preprompt.rstrip().rstrip(":")
        prompts = [f"{instruction}: {abstract}" for abstract in abstracts]

        return [
            endpoint.client.text_generation(prompt, max_new_tokens=max_new_tokens)
            for prompt in prompts
        ]

### Other

In [None]:
import textwrap
def wrap(text):
    """Re-flow ``text`` into newline-separated lines of at most 80 characters."""
    return textwrap.fill(text, width=80)

# Exploration

In [None]:
# Find the publications most-similar to the original
# Find the publications most-similar to the original: rank the atlas relative
# to its center. `sorted_keys` are publication identifiers (used below to index
# `atl.publications`); `sorted_values` are presumably the matching similarity
# scores — unused in this notebook, TODO confirm against sciterra.
sorted_keys, sorted_values = crt.sort(atl, center=atl.center)

In [None]:
# Get the abstracts for the most-similar publications
import re

# Get the abstracts for the most-similar publications. Build both
#   * `prompt`: one combined prompt covering every paper (for LLM-style models), and
#   * `abstracts`: one text per paper (for summarization models, and later as
#     the ROUGE references).
prompt_header = (
'''The following text are abstracts from several publications that are most
similar to the original publication. We will share each one, and then we will
summarize them.
'''
)
prompt_parts = [prompt_header]
abstracts = []
for i, identifier in enumerate(sorted_keys[:config['n_summarized']]):

    # One sentence per line. Split only where a period is followed by
    # whitespace, so sentence-ending periods are kept and decimals such as
    # "0.5" are not broken apart.
    raw_abstract = atl.publications[identifier].abstract
    abstract = "\n".join(re.split(r"(?<=\.)\s+", raw_abstract.strip()))

    # Combined prompt; used for LLMs
    prompt_parts.append(f"This is the abstract for paper {i+1}:\n")
    prompt_parts.append(abstract + "\n")

    # Individual prompts; used for summarization models
    abstracts.append(abstract)

prompt_parts.append("The summary of the papers, with one sentence per paper, is as follows:")
prompt = "".join(prompt_parts)

print(prompt)

In [None]:
# Make predictions
# Make predictions: one model output per abstract, via whichever `query` the
# Hugging Face cells above defined for config['api_or_endpoint'].
predictions = query(abstracts)
predictions  # bare expression so the notebook displays the raw outputs

In [None]:
from requests import HTTPError
# Condense the per-paper summaries into one overview of the research field.
# Number each summary by its position in `predictions`.
overall_summary_input = "\n\n".join(
    f"Summary for paper {i}: {prediction}"
    for i, prediction in enumerate(predictions, start=1)
)
try:
    overall_prediction = query(
        [overall_summary_input],
        preprompt="Summarize in a few sentences the current state of the field studied by the following research papers:",
    )[0]
except HTTPError as err:
    # Best effort: the per-paper results are still useful without the
    # overview, but report why it is missing instead of hiding the failure.
    print(f"Overall summary request failed: {err}")
    overall_prediction = ""
print(overall_prediction)

In [None]:
# Score
from evaluate import load
# Score each per-paper summary (prediction) against its source abstract
# (reference) with ROUGE. use_aggregator=False keeps one score per pair, which
# the pretty-print cell indexes per paper (metrics['rouge2'][i]).
eval_module = load("rouge")
metrics = eval_module.compute(predictions=predictions, references=abstracts, use_aggregator=False)

In [None]:
# Pretty print
output_str = wrap(overall_prediction) + "\n\n"
for i in range(config['n_summarized']):

    output_str += f"Paper {i+1} summary ("
    output_str += f"n_char_orig: {len(abstracts[i])}, "
    output_str += f"n_char_summ: {len(predictions[i])}, "
    output_str += f"rouge2: {metrics['rouge2'][i]:.3g}):\n"
    output_str += "-------------------------------------------------------------\n"
    output_str += wrap(predictions[i])
    output_str += "\n\n"

print(output_str)