In [None]:
!pip install -u transformers
!pip install accelerate

In [None]:
import csv


# Load the test paragraphs as a list of dicts, one per CSV row, keyed by the
# header line. Later cells read the "id" and "text" columns.
# newline="" is what the csv module requires of the file object: without it,
# line endings embedded in quoted fields are translated on read, which would
# change the text that entity character offsets are computed against.
with open("../data/paragraphs_test.csv", "r", encoding="utf-8", newline="") as csv_f:
    data = csv.DictReader(csv_f)
    # Materialize while the file is still open; the reader is lazy.
    zibaldone_test = list(data)

In [None]:
import transformers
import torch

import os

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Build the chat text-generation pipeline: weights are loaded in bfloat16 and
# device placement is delegated to device_map="auto".
# The access token is read from the HF_TOKEN environment variable instead of
# being written into the notebook, so it cannot be committed by accident.
# NOTE(review): when HF_TOKEN is unset this passes token=None; presumably the
# library then falls back to credentials stored by `huggingface-cli login` --
# confirm against the installed transformers version.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)

In [None]:
from tqdm import tqdm
import re

# Instruction (in Italian) that is prepended to every paragraph before it is
# sent to the model as the user message. It asks the model to rewrite the
# input with inline <PER>, <LOC> and <WORK> tags around the references it
# finds, and to return the text unchanged when nothing is mentioned.
# The string is sent to the model verbatim, so any edit to its wording or
# whitespace can change the generated annotations.
prompt = """
Estrai i riferimenti a persone, luoghi e opere all'interno del testo in input. 
Riscrivi il testo in input con i riferimenti annotati come nell'esempio seguente: 'Il primo codice della <WORK>Divina Commedia</WORK> di <PER>Dante</PER> è conservato presso la <LOC>Biblioteca Passerini-Landi</LOC> di <LOC>Piacenza</LOC>.'
Se nessun entità è menzionata, ritorna il testo così com'è.
Input: 
"""

# Tag names the prompt asks the model to use; anything else is discarded.
ENTITY_TYPES = {"PER", "WORK", "LOC"}

# One inline annotation in the model's answer: <TYPE>surface form</TYPE>.
# The backreference forces the closing tag to repeat the opening tag name.
TAG_PATTERN = re.compile(r'<(?P<type>\w+)>(?P<surface_form>.*?)</\1>')


def _generation_budget(text):
    """Return the max_new_tokens to allow for `text`.

    The size estimate is the number of pieces obtained by splitting on single
    non-word characters, so it is only a rough word count (runs of
    punctuation/whitespace contribute empty pieces).
    """
    n_pieces = len(re.split(r"\W", text))
    if n_pieces > 1000:
        return 2048
    if n_pieces > 500:
        return 1024
    return 512


def _align_entities(row_id, text, response):
    """Map the tags found in `response` back to character offsets in `text`.

    Tags are processed in the order they appear in the response. Each one is
    anchored to the first literal occurrence of its surface form in `text`
    that ends after the previously anchored entity, so repeated mentions are
    assigned left to right.

    Args:
        row_id: value stored under "id" in every returned record.
        text: the original paragraph the offsets refer to.
        response: the model's rewritten paragraph with inline tags.

    Returns:
        A list of dicts with keys "id", "surface_form", "start_pos",
        "end_pos" and "type". Tags with an unknown type, an empty surface
        form, or a surface form that does not occur verbatim in `text` are
        skipped.
    """
    aligned = []
    curr_end = 0
    for tag in TAG_PATTERN.finditer(response):
        surface = tag.group("surface_form")
        entity_type = tag.group("type")
        # An empty surface form is "in" every string and would yield
        # zero-width entities, so it is rejected explicitly.
        if not surface or entity_type not in ENTITY_TYPES or surface not in text:
            continue
        # The surface form is model output, not a regex: escape it so that
        # characters such as "(", "[", "*" or "." are matched literally
        # instead of raising re.error or matching the wrong span.
        for occurrence in re.finditer(re.escape(surface), text):
            if occurrence.end() > curr_end:
                aligned.append({
                    "id": row_id,
                    "surface_form": surface,
                    "start_pos": occurrence.start(),
                    "end_pos": occurrence.end(),
                    "type": entity_type,
                })
                curr_end = occurrence.end()
                break
    return aligned


entities = []

# Iterating over the tqdm wrapper advances the bar once per paragraph and
# closes it when the data is exhausted.
pbar = tqdm(zibaldone_test)

for row in pbar:
    messages = [
        {"role": "system", "content": "Sei un utile sistema di annotazione di testi."},
        {"role": "user", "content": prompt+row["text"]},
    ]

    outputs = pipeline(
        messages,
        max_new_tokens=_generation_budget(row["text"]),
    )
    # The last message of the returned conversation is the model's answer.
    response = outputs[0]["generated_text"][-1]["content"]
    entities.extend(_align_entities(row["id"], row["text"], response))
    
    
    

In [None]:
import csv
import os

# Write the extracted entities to CSV, one row per entity.
output_path = "../results/llama3_1_generative/output.csv"

# exist_ok=True makes the cell re-runnable without a separate existence check.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Fixed column list (same order as the entity dicts built above) so that the
# header can still be written when no entity was extracted; indexing
# entities[0] would raise IndexError on an empty list.
keys = ["id", "surface_form", "start_pos", "end_pos", "type"]

# newline="" lets the csv module control line endings itself; without it an
# extra "\r" is emitted on platforms that translate "\n" on write.
# The with-block closes the file, so no explicit close() is needed.
with open(output_path, "w", encoding="utf-8", newline="") as f:
    dict_writer = csv.DictWriter(f, keys)
    dict_writer.writeheader()
    dict_writer.writerows(entities)