In [None]:
from llama_index.llms.llama_api import LlamaAPI

# Load the Llama API key from a local file so the secret never appears in the notebook.
with open("../keys/llama.txt", "r") as f:
    # Key files usually end with a newline; strip surrounding whitespace so it
    # is not passed on as part of the credential.
    api_key = f.read().strip()

# temperature=0.0 is passed to keep generations as repeatable as the backend allows.
llm_llama = LlamaAPI(api_key=api_key, temperature=0.0)


In [None]:
from llama_index.llms.openai import OpenAI

# Load the OpenAI API key from a local file so the secret never appears in the notebook.
# NOTE: this rebinds `api_key`, which the Llama cell above also assigns.
with open("../keys/openai.txt", "r") as f:
    # Key files usually end with a newline; strip surrounding whitespace so it
    # is not passed on as part of the credential.
    api_key = f.read().strip()

# `model` is also used further down to name the output directory.
model = "gpt-3.5-turbo-0125"
# temperature=0.0 is passed to keep generations as repeatable as the backend allows.
llm_openai = OpenAI(model=model, api_key=api_key, temperature=0.0)


# ZeroShot

In [None]:
from enum import Enum
from typing import List

from pydantic import BaseModel, ConfigDict, Field

# Closed set of entity types the LLM may assign to an extracted music entity.
# Mixing in `str` makes each member a real string, so members compare equal to
# their plain-text value.
# NOTE(review): documented with comments rather than a docstring on purpose —
# a docstring may be picked up in the JSON schema generated for this enum.
class Label(str, Enum):
    title = 'title'  # a work: song title, album title, symphony, ...
    performer = 'performer'  # a performing musical artist

class MusicEntity(BaseModel):
    """Data model of a music entity"""
    # NOTE(review): the docstring above is kept verbatim. pydantic copies it into
    # the model's JSON schema, which is presumably what the extraction program
    # shows to the LLM — confirm before rewording it.

    # pydantic v2 configuration (this notebook already relies on the v2 API via
    # `model_dump`). Replaces the deprecated nested `class Config`.
    # use_enum_values=True stores the enum's plain value (e.g. 'title') in
    # `label` instead of the `Label` member itself.
    model_config = ConfigDict(use_enum_values=True)

    # The entity's surface form exactly as it appears in the input text.
    utterance: str
    # Entity type: 'title' or 'performer'.
    label: Label
    # The contextual phrase in the text that signals the entity.
    cue: str
        
# Top-level structured output: wraps all entities extracted from one text so the
# program can return them as a single object (passed as `output_cls` below).
class EntityList(BaseModel):
    """Data model for list of music entities."""
    # Entities found in the text; the extraction loop iterates over this field.
    content: List[MusicEntity]
    

In [None]:
from llama_index.program.openai import OpenAIPydanticProgram

# Zero-shot instruction prompt: no worked examples, only a description of the
# three attributes of a music entity. `{text}` is the single placeholder and is
# filled by calling the program with `text=...`.
# NOTE(review): the wording "a user requests" is a typo, but the prompt is part
# of the experiment — changing it may change model outputs, so it is left as is.
prompt_template = """\
From the following text which contains a user requests for music suggestions, extract all the music entities.
A music entity has the following attributes:
    - utterance: The utterance of the entity in the text. For example "the beatles" in "recommend me music like the beatles".
    - label: The label of the entity. It can either be 'title' (eg. a song title, an album title, a symphony) or it can be 'performer' which refers to a performing musical artist.
    - cue: The contextual cue which indicates the musical entity (eg. "music like" in "recommend me music like the beatles" indicating "the beatles")
Here is the text: {text}
"""

# Structured-extraction program backed by the OpenAI LLM configured above; it
# is asked to return an `EntityList` for each input text.
program = OpenAIPydanticProgram.from_defaults(
    output_cls=EntityList,
    llm=llm_openai,
    prompt_template_str=prompt_template,
    allow_multiple=False,  # one EntityList object per call
    verbose=False,
)


In [None]:
import sys
sys.path.append("..")
from tqdm import tqdm
from src.Utils import read_IOB_file, transform_to_dict, write_jsonlines
import os 

# Load the test split: parallel sequences of token lists and IOB tag lists.
dataset_id = str(1)
data_path = f"../baseline/music-ner-eacl2023/data/dataset{dataset_id}/test.bio"
texts, labels = read_IOB_file(data_path)

# zip() has no length, so tqdm could only show a running count. Materialising
# the pairs lets the progress bar display a total and an ETA.
samples = list(zip(texts, labels))

outputs = []
for tokens, iob in tqdm(samples):
    text = ' '.join(tokens)
    # presumably maps entity type ("Artist", "WoA") to the gold entity strings —
    # verify against src.Utils.transform_to_dict
    true_ents = transform_to_dict(tokens, iob)

    # extract with LLM; each pydantic entity is converted to a plain dict so the
    # record can be serialised
    ent_list = program(text=text)
    llm_ents = [ent.model_dump() for ent in ent_list.content]

    # One record per input: the text, the gold entities and the LLM extraction.
    # `or []` normalises a missing or empty entry to an empty list.
    output = {
        "text": text,
        "performers": true_ents.get("Artist") or [],
        "titles": true_ents.get("WoA") or [],
        "extracted": llm_ents,
    }
    outputs.append(output)

# Write all records to ../output/<model>/reddit<dataset_id>.jsonl
output_dir = os.path.join("..", "output", model)
os.makedirs(output_dir, exist_ok=True)

write_jsonlines(os.path.join(output_dir, f"reddit{dataset_id}.jsonl"), outputs)


# Mixtral

In [None]:
# Build the same extraction program on top of a Mixtral model served by Ollama.
# (No request is sent in this cell; it only constructs the LLM client and program.)
from llama_index.llms.ollama import Ollama
from llama_index.core.program import FunctionCallingProgram, LLMTextCompletionProgram

# presumably requires a running Ollama server with the "mixtral" model pulled — confirm
llm_mixtral = Ollama(model="mixtral")

# NOTE(review): this rebinds `program`, replacing the OpenAI-backed program
# defined earlier. Re-running the extraction loop after this cell would use
# Mixtral while the output directory is still named after `model` — confirm
# this is intended.
program = LLMTextCompletionProgram.from_defaults(
    output_cls=EntityList,
    llm=llm_mixtral,
    prompt_template_str=prompt_template,  # same prompt text as the OpenAI program
    allow_multiple=False,
    verbose=False,
)
