In [None]:
import os
import re
import json
import glob
import argparse

In [None]:
import torch
import accelerate
import transformers
import huggingface_hub
from tqdm.auto import tqdm

In [None]:
# Root of the project tree; the other paths in this notebook are built from it.
ROOT_PATH = "./"
# Model directory, derived from ROOT_PATH (rather than a literal ".") so that
# changing the root relocates the models together with the datasets.
MODEL_DIRECTORY = os.path.join(ROOT_PATH, "models")
# Create it eagerly; exist_ok makes re-running this cell a no-op.
os.makedirs(MODEL_DIRECTORY, exist_ok=True)

In [None]:
# Location of the JSON dataset file, resolved relative to ROOT_PATH.
DATASET_PATH = os.path.join(ROOT_PATH, "datasets", "cbp-lkg", "legalkg_dataset_prompts.json")

In [None]:
# Any run of whitespace characters (spaces, tabs, newlines, ...).
WHITESPACE_CHAR_REGEX = re.compile(r"(?ui)\s+")
# A cluster of "special" characters (neither word characters nor whitespace),
# including any whitespace sandwiched between them; always ends on a special.
SPECIAL_CHAR_REGEX = re.compile(r"(?ui)(?:[^\w\s]\s*)*[^\w\s]")

def preprocess(entity):
    """Normalise an entity string.

    Special-character clusters are removed first and whitespace is collapsed
    afterwards. Doing it in this order matters: deleting a special character
    that sits between two spaces (``"foo - bar"``) would otherwise leave a
    doubled space behind in the result.

    Args:
        entity: The raw string to clean.

    Returns:
        The string without special characters, with every whitespace run
        reduced to a single space and no leading/trailing whitespace.
    """
    entity = SPECIAL_CHAR_REGEX.sub("", entity)
    return WHITESPACE_CHAR_REGEX.sub(" ", entity).strip()

In [None]:
class TriplesAsQADataset(torch.utils.data.Dataset):
    """Dataset of (prompt, answer) string pairs read from a JSON file.

    Every JSON entry is expected to carry a two-element ``'pair'``; its first
    element becomes the prompt (prefixed with ``"Answer in a few words: "``)
    and its second element the target text. Tokenisation is deferred to
    ``_collate_fn`` so that each batch is padded only to its own longest item.
    """

    def __init__(self, path, args, max_triples=-1):
        """Load the tokenizer named by ``args.base_model`` and read ``path``.

        Args:
            path: Path of the JSON dataset file.
            args: Namespace-like object exposing ``base_model``.
            max_triples: Maximum number of entries to keep. The default ``-1``
                never matches the running count, so everything is read.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model)
        self._read_data(path, max_triples)

    def _read_data(self, path, max_triples):
        """Populate ``self.data`` with parallel ``'inputs'``/``'outputs'`` lists."""
        inputs, outputs = [], []
        # Parse the whole file first so the handle is closed before iterating.
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        for entry in tqdm(data):
            if len(inputs) == max_triples:
                break
            if len(entry['pair']) != 2:
                # NOTE(review): a malformed entry is printed and reading stops,
                # so every later entry is dropped too. Kept as in the original;
                # confirm whether skipping just this entry was intended.
                print(entry)
                break
            _in, _out = entry['pair']
            inputs.append("Answer in a few words: " + _in)
            outputs.append(_out)
        self.data = { 'inputs': inputs, 'outputs': outputs }

    def __len__(self):
        """Number of (prompt, answer) pairs that were loaded."""
        return len(self.data['inputs'])

    def __getitem__(self, index):
        """Return the untokenised ``(prompt, answer)`` pair at ``index``."""
        data = self.data
        # Local names deliberately avoid shadowing the builtin ``input``.
        prompt = data['inputs' ][index]
        answer = data['outputs'][index]
        return prompt, answer

    def _collate_fn(self, items):
        """Tokenise a batch of ``(prompt, answer)`` pairs into padded tensors.

        Prompts are truncated to 128 tokens and answers to 32.

        Returns:
            ``(input_ids, attention_mask, labels, labels_attention_mask)``;
            padding positions in ``labels`` are replaced by ``-100``.
        """
        inputs_tokenized  = self.tokenizer(
            [item[0] for item in items],
            padding=True, truncation=True,
            max_length=128, return_tensors="pt"
        )
        outputs_tokenized = self.tokenizer(
            [item[1] for item in items],
            padding=True, truncation=True,
            max_length=32, return_tensors="pt"
        )
        input_ids, attention_mask     = inputs_tokenized.input_ids, inputs_tokenized.attention_mask
        labels, labels_attention_mask = outputs_tokenized.input_ids, outputs_tokenized.attention_mask
        # Mask label padding with -100 (the default ignore_index of
        # torch.nn.CrossEntropyLoss). The pad id is taken from the tokenizer
        # instead of assuming 0, so tokenizers with another pad id work too.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        return input_ids, attention_mask, labels, labels_attention_mask

    def decode(self, token_ids):
        """Concatenate the tokenizer's token strings for ``token_ids``.

        The tokens are joined with no separator and without further cleanup,
        exactly as ``convert_ids_to_tokens`` returns them.
        """
        return ''.join(self.tokenizer.convert_ids_to_tokens(token_ids))

In [None]:
# Run configuration, kept in a Namespace so it reads like parsed CLI arguments.
args = argparse.Namespace(**{
    "base_model_prefix":      "google/flan-t5",
    "model_size":             "small",
    "model_stem":             "cbp-lkg-qa",
    "epochs":                 10,
    "batch_size":             32,
    "save_checkpoint":        5000,
    "loss_checkpoint":        500,
    "num_workers":            3,
    "checkpoint":             0,
    "max_checkpoints":        5,
    "resume_from_checkpoint": True,
    "gradient_checkpointing": True,
    "skip_batches":           0,
    "finetuning":             False,
})

# Derived names: the local model name drops the "google/" organisation part
# (via basename) and gains a "-finetuned" suffix only when finetuning is on.
args.suffix = "" if not args.finetuning else "-finetuned"
args.model = os.path.basename(
    f"{args.base_model_prefix}-{args.model_stem}-{args.model_size}{args.suffix}"
)
args.model_directory = os.path.join(MODEL_DIRECTORY, args.model)
args.base_model = f"{args.base_model_prefix}-{args.model_size}"

if args.resume_from_checkpoint:
    # Checkpoint files are named chkpt_<N>.pt; resume from the largest N found,
    # or from 0 when the model directory holds no checkpoint yet.
    files = glob.glob(os.path.join(args.model_directory, "chkpt_*.pt"))
    args.checkpoint = max(
        (int(os.path.basename(file)[len("chkpt_"):-len(".pt")]) for file in files),
        default=0,
    )

dataset = TriplesAsQADataset(DATASET_PATH, args)

In [None]:
# Show the first (prompt, answer) pair as a quick check that loading worked.
dataset[0]

In [None]:
from dotenv import load_dotenv
from genai.client import Client
from genai.credentials import Credentials
from genai.schema import TextGenerationParameters, TextGenerationReturnOptions
from genai.text.generation import CreateExecutionOptions

In [None]:
import json
from typing import List, Dict

def heading(text: str) -> str:
    """Build a banner line for ``text``.

    The text gets one space on each side, is centred in an 80-character line
    filled with ``=``, and the result is wrapped in a leading and a trailing
    newline.
    """
    banner = " {} ".format(text).center(80, "=")
    return "".join(("\n", banner, "\n"))

# Load variables from a local .env file (if present) into the environment.
load_dotenv()

# Take the credentials from the environment that load_dotenv() just populated
# instead of hard-coding them: the API key must never live in the notebook.
# The defaults reproduce the previous literal values.
GENAI_KEY = os.environ.get("GENAI_KEY", "")
GENAI_API = os.environ.get("GENAI_API", "https://bam-api.res.ibm.com")

# Shared SDK client used by QuestionAnsweringModel.llm_response_sdk below.
credentials = Credentials(api_key=GENAI_KEY, api=GENAI_API)
client = Client(credentials=credentials)

class QuestionAnsweringModel:
    """Evaluates a hosted text-generation model on triple-grounded questions.

    For each dataset item the model is prompted with the item's triple and
    question, its semicolon-separated reply is treated as a ranked answer
    list, and Hits@1/5/10 are computed against the ground-truth answer.
    """

    def __init__(self, model_name):
        """Remember the id of the model to query."""
        self.model_name = model_name

    def llm_response_sdk(self, prompt):
        """Send ``prompt`` to the model through the module-level ``client``.

        Uses greedy decoding with 1 to 200 new tokens. The generated text is
        cut at the first occurrence of ``"Input"`` and stripped.

        Returns:
            The generated text, or the literal ``"Error processing request."``
            when the request raises (the error is printed, not re-raised, so a
            single failed call does not abort a long evaluation).
        """
        try:
            parameters = TextGenerationParameters(
                max_new_tokens=200,
                min_new_tokens=1,
                decoding_method="greedy",
                return_options=TextGenerationReturnOptions(
                    input_text=True,
                ),
            )

            response = client.text.generation.create(
                model_id=self.model_name,
                inputs=[prompt],
                parameters=parameters,
            )

            # Only one prompt was sent, so the first result is the one we want.
            result = next(response).results[0]
            return result.generated_text.split("Input")[0].strip()
        except Exception as e:
            print(f"Error during SDK request: {e}")
            return "Error processing request."

    def read_dataset(self, file_path) -> List[Dict]:
        """Load and return the JSON dataset stored at ``file_path``."""
        # Explicit UTF-8 so the result does not depend on the platform default.
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def generate_prompt(self, triple, question):
        """Build the prompt from a three-element ``triple`` and a ``question``."""
        return f"Given the following triples {triple[0]}, {triple[1]}, and {triple[2]}, you need to generate answer for the following question. Give as many answers as possible each separated by a semicolon. Question: {question}"

    @staticmethod
    def _split_answers(generated_text):
        """Split a semicolon-separated reply into stripped, non-empty answers."""
        return [part.strip() for part in generated_text.split(";") if part.strip()]

    def evaluate(self, file_path, max_items=None, output_path="./results/metrics.txt"):
        """Compute Hits@1/5/10 over the dataset and write them to a text file.

        The prompt asks for answers separated by semicolons, so the reply is
        split on ``;`` and the ground truth counts as a hit at k when it equals
        one of the first k answers. (Slicing the raw reply string, as was done
        before, compared against its first k *characters*.)

        Args:
            file_path: JSON dataset; each item needs ``"triple"`` and a
                two-element ``"pair"`` of (question, ground truth).
            max_items: Optional cap on the number of items evaluated;
                ``None`` (default) evaluates every item.
            output_path: Where the metrics are written. Missing parent
                directories are created.

        Returns:
            Dict with keys ``"hits@1"``, ``"hits@5"``, ``"hits@10"`` holding
            the fraction of evaluated items that were hits (0 when no item
            was evaluated).
        """
        dataset = self.read_dataset(file_path)
        cutoffs = (1, 5, 10)
        hits = dict.fromkeys(cutoffs, 0)
        total = 0  # items actually evaluated; the denominator of every metric

        for item in dataset:
            if max_items is not None and total >= max_items:
                break
            triple = item["triple"]
            question, ground_truth = item["pair"]
            prompt = self.generate_prompt(triple, question)
            print(prompt)
            generated_answers = self.llm_response_sdk(prompt)
            print(generated_answers)

            ranked_answers = self._split_answers(generated_answers)
            target = ground_truth.strip()
            for k in cutoffs:
                if target in ranked_answers[:k]:
                    hits[k] += 1
            total += 1

        accuracy = {k: (hits[k] / total if total > 0 else 0) for k in cutoffs}

        # Create the output directory first: failing here would throw away the
        # results of every model call made above.
        directory = os.path.dirname(output_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as file:
            for k in cutoffs:
                file.write(f"Hits at {k} Accuracy: {accuracy[k]*100:.2f}%\n")

        return {f"hits@{k}": accuracy[k] for k in cutoffs}

# Id of the hosted model to evaluate.
model_name = "codellama/codellama-34b-instruct"
qa_model = QuestionAnsweringModel(model_name)
# Reuse DATASET_PATH instead of repeating the path literal, so the evaluation
# always runs on the same file the rest of the notebook is configured with.
qa_model.evaluate(DATASET_PATH)
