## 1. Load required libraries

In [8]:
import os
import time
import json
import pandas as pd
import google.generativeai as genai

from typing import Union
from dotenv import load_dotenv
from google.generativeai.types import HarmCategory, HarmBlockThreshold


# Populate os.environ from a local .env file (when one exists), then pick up the
# Gemini API key; the value is None if GEMINI_API_KEY is not set anywhere.
load_dotenv()
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

## 2. Define LLM functions

In [9]:
genai.configure(api_key=GEMINI_API_KEY)

# Gemini model used by the keyword helpers below (alternative id: "gemini-1.0-pro").
# Each of the four listed harm categories gets the BLOCK_NONE threshold, i.e. no
# content is blocked for hate speech, harassment, dangerous or sexual content.
model = genai.GenerativeModel(
    "gemini-1.0-pro-001",
    safety_settings={
        category: HarmBlockThreshold.BLOCK_NONE
        for category in (
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        )
    },
)


def get_descriptive_words_from_gemini(abstracts: list[str], word_count: int=10,
                                               max_retries: int=3) -> list[dict[str | list]]:

    prompt_template = """
Provide a comma separated list of <word_count> words uniquely identifying this abstract. 
The words should only be unitary words.
Do not use hyphenated words. 
Do not use phrases.

Abstract: 
<abstract>
"""

    descriptive_words = []

    for abstract in abstracts:
        prompt = prompt_template.replace("<word_count>", str(word_count))
        prompt = prompt.replace("<abstract>", abstract)
        print(f"Prompt: {prompt}")

        retries = 0
        while retries < max_retries:
            try:
                model_results = model.generate_content(prompt)
                print(f"Completed {model_results.text} abstracts ...")
                descriptive_words.append(model_results.text)
                break
            except Exception as e:
                print(f"Exception occurred: {e}")
                retries += 1
                print(f"Retrying... ({retries}/{max_retries})")
                time.sleep(10)
        if retries == 3:
            return None
        time.sleep(3)
    return descriptive_words


def get_most_recent_file(directory: str) -> Union[str, None]:
    """Return the path of the regular file in ``directory`` with the newest ctime.

    Sub-directories are ignored. When several files share the newest ctime,
    the one listed first by ``os.listdir`` wins. Returns ``None`` when the
    directory contains no regular files.
    """
    candidates = (os.path.join(directory, name) for name in os.listdir(directory))
    regular_files = [path for path in candidates if os.path.isfile(path)]
    if len(regular_files) == 0:
        return None
    return max(regular_files, key=os.path.getctime)


def daily_processing(filename=None):
    """Produce descriptive words for the abstracts stored in one data file.

    Args:
        filename: Path of a records-oriented JSON file with an ``abstract``
            column. When omitted, the file with the newest ctime in
            ``../data`` is used.

    Returns:
        A DataFrame built from the descriptive words, or ``None`` when the file
        holds no records.

    Raises:
        FileNotFoundError: if no filename is given and ``../data`` contains no
            regular files.
    """
    if filename is None:
        filename = get_most_recent_file("../data")
        if filename is None:
            # get_most_recent_file returns None for an empty directory; fail with
            # a clear message instead of handing None to pandas.
            raise FileNotFoundError("No data files found in '../data'")
    df = pd.read_json(filename, orient="records")
    if df.empty:
        return None

    abstracts = df["abstract"].tolist()
    # NOTE(review): get_descriptive_words_from_gemini_by_chunk is not defined in
    # this section, and `.items()` below assumes it returns a dict — confirm
    # against its definition.
    descriptive_words = get_descriptive_words_from_gemini_by_chunk(abstracts)
    descriptive_words = [{key: value} for key, value in descriptive_words.items()]
    # NOTE(review): columns=["abstract"] keeps only a key literally named
    # "abstract"; if the keys are e.g. "abstract_1", "abstract_2", ... this column
    # will be all NaN — verify the intended frame layout.
    results = pd.DataFrame(descriptive_words, columns=["abstract"])
    return results

In [10]:
#df = daily_processing("../data/2024_03_08_cs.json")
# Load one day's records directly (records-oriented JSON; pandas infers the
# compression from the .gz extension) instead of going through daily_processing.
df = pd.read_json("../data/2024_03_13_cs.json.gz", orient="records")
# `abstracts` is also read again at module level by the Mistral code further down.
abstracts = df["abstract"].tolist()
# One Gemini request per abstract; may be None if a request kept failing.
descriptive_words = get_descriptive_words_from_gemini(abstracts)

Prompt: 
Provide a comma separated list of 10 words uniquely identifying this abstract. 
The words should only be unitary words.
Do not use hyphenated words. 
Do not use phrases.

Abstract: 
We study ill-conditioned positive definite matrices that are disturbed by the sum of $m$ rank-one matrices of a specific form. We provide estimates for the eigenvalues and eigenvectors. When the condition number of the initial matrix tends to infinity, we bound the values of the coordinates of the eigenvectors of the perturbed matrix. Equivalently, in the coordinate system where the initial matrix is diagonal, we bound the rate of convergence of coordinates that tend to zero.

Completed ill-conditioned, positive, definite, matrices, rank-one, eigenvalues, eigenvectors, condition, number, convergence abstracts ...
Prompt: 
Provide a comma separated list of 10 words uniquely identifying this abstract. 
The words should only be unitary words.
Do not use hyphenated words. 
Do not use phrases.

Abstract

In [None]:
# MISTRAL_API_KEY was never defined in this notebook (NameError); read it from the
# environment (also populated by load_dotenv) like GEMINI_API_KEY.
# NOTE(review): MistralClient is not imported in the import cell; presumably it
# comes from the `mistralai` package (`from mistralai.client import MistralClient`)
# — add that import before running this cell.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
mistral_client = MistralClient(api_key=MISTRAL_API_KEY)


def get_embeddings_from_mistral(input: list[str]) -> list[list[float]]:
    """Embed every string in ``input`` with the ``mistral-embed`` model.

    Returns:
        One embedding vector per input string (presumably in input order, as
        returned by the API).
    """
    embeddings_batch_response = mistral_client.embeddings(
        model="mistral-embed",
        input=input,
    )
    # The client returns a response object, not the list promised by the
    # annotation; unwrap the per-input vectors so callers such as
    # process_embedding_batches_with_mistral can extend a plain list with them.
    return [item.embedding for item in embeddings_batch_response.data]


def process_embedding_batches_with_mistral(input: list[str], batch_size: int,
                                           embed_fn=None) -> list[list[float]]:
    """Embed ``input`` in consecutive batches of at most ``batch_size`` strings.

    Args:
        input: Strings to embed.
        batch_size: Maximum number of strings per embedding request.
        embed_fn: Callable mapping a list of strings to a list of vectors;
            defaults to ``get_embeddings_from_mistral``.

    Returns:
        The embeddings of all batches concatenated in batch order.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if embed_fn is None:
        embed_fn = get_embeddings_from_mistral

    embeddings = []
    for start in range(0, len(input), batch_size):
        # Slice the argument (not the global `abstracts`) and accumulate inside
        # the loop so every batch is kept, not just the last one.
        batch = input[start:start + batch_size]
        embeddings.extend(embed_fn(batch))
    return embeddings
    

def get_descriptive_words_from_mistral(input: str, model: str="open-mixtral-8x7b") -> object:
    """Send ``input`` as a single user message to a Mistral chat model.

    Args:
        input: The full prompt text (e.g. the batch prompt built below); it is
            passed as the content of one user message.
        model: Mistral model id.

    Returns:
        The raw response of ``mistral_client.chat``. The reply text is at
        ``.choices[0].message.content`` (as used by the commented-out call at
        the end of the cell).
    """
    # Other model ids tried with this helper:
    # open-mistral-7b
    # open-mixtral-8x7b
    # mistral-small-latest
    # mistral-medium-latest
    # mistral-large-latest

    messages = [
        ChatMessage(role="user", content=input)
    ]
    chat_response = mistral_client.chat(
        model=model,
        messages=messages,
        # response_format={"type": "json_object"},
    )
    return chat_response # [0].message.content

# Batch prompt: ask for WORD_COUNT identifying words for each of up to CHUNK_SIZE
# abstracts in one request, answered as a JSON object keyed "abstract_<n>".
CHUNK_SIZE = 10
WORD_COUNT = 10

prompt_template = """
For each abstract in the list of <abstract_count> abstracts, provide <word_count> words that uniquely identify the abstract.

List of abstracts: 
<list_of_abstracts>

Return json data with exactly <abstract_count> records. 
Example response:
{
    "abstract_1": [<word_count> words],
}
"""

chunk = {f"abstract_{i+1}": sentence for i, sentence in enumerate(abstracts[:CHUNK_SIZE])}

# Use the real chunk length: fewer than CHUNK_SIZE abstracts may be available, and
# the prompt must not promise more records than it actually sends.
prompt = prompt_template.replace("<abstract_count>", str(len(chunk)))
prompt = prompt.replace("<word_count>", str(WORD_COUNT))
# Serialize as JSON rather than Python's repr, since the model is asked to reply in JSON.
prompt = prompt.replace("<list_of_abstracts>", json.dumps(chunk, indent=2, ensure_ascii=False))
print(f"Prompt: {prompt}")


# embeddings = get_embeddings_from_mistral(["sentence one", "sentence two"])
#words = get_descriptive_words_from_mistral(prompt)
#json.loads(words.choices[0].message.content)