# RGU-IIT CBR Generator

## Download and Import Libraries

### Download Libraries

In [1]:
!pip install mistral -q
!pip install mistralai -q
!pip install datasets

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/284.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.6/284.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m276.5/284.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m284.6/284.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.4/233.4 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.5/58.5 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m360.5/360.5 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m


### Import Libraries

In [None]:
from tqdm import tqdm
from typing import *
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

import numpy as np
import requests
import pandas as pd
import json
import numpy as np
import pandas as pd
import time

## Mistral AI setup

In [None]:
import os

# Read the API key from the environment instead of hard-coding it in the
# notebook, where it would end up in version control and shared copies.
# The "XXX" placeholder keeps the cell runnable when the variable is unset;
# requests made with it will presumably be rejected by the API.
api_key = os.environ.get("MISTRAL_API_KEY", "XXX")
model = "open-mistral-7b"
client = MistralClient(api_key=api_key)

## CBR system with Similarity Indexers

In [None]:
def _parse_embedding(value):
    """Convert a stored embedding into a numpy array without using `eval`.

    The embedding columns of the CSV hold list literals such as
    "[0.12, -0.03, ...]". `json.loads` parses that form quickly; anything it
    rejects falls back to `ast.literal_eval`, which accepts any Python literal
    (e.g. tuples) but, unlike `eval`, never executes code found in the data
    file. Values that are already sequences or arrays are converted directly.
    """
    import ast
    import json

    if isinstance(value, str):
        try:
            value = json.loads(value)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            value = ast.literal_eval(value)
    return np.array(value)


class Case:
    """One retrievable case: question, supporting snippet, keywords and answer,
    plus the embeddings used to match a query against it."""

    def __init__(self, index, question, matching_question_embeddings, text, retrieval_text_embeddings, keywords, retrieval_keywords_embeddings, answer):
        """Store the case fields, parsing the three embeddings into numpy arrays.

        Args:
            index: identifier of the case, stored as given.
            question: question text of the case.
            matching_question_embeddings: question embedding, as a list-literal
                string or an already-parsed sequence/array.
            text: snippet (citation) text of the case.
            retrieval_text_embeddings: snippet embedding (string or sequence).
            keywords: keywords of the case, stored as given.
            retrieval_keywords_embeddings: keyword embedding (string or sequence).
            answer: reference answer text.

        Raises:
            ValueError / SyntaxError: an embedding string is not a literal.
        """
        self.index = index
        self.question = question
        self.matching_question_embeddings = _parse_embedding(matching_question_embeddings)
        self.text = text
        self.retrieval_text_embeddings = _parse_embedding(retrieval_text_embeddings)
        self.keywords = keywords
        self.retrieval_keywords_embeddings = _parse_embedding(retrieval_keywords_embeddings)
        self.answer = answer

# NOTE(review): not referenced by any visible cell; kept in case later code uses it.
case_database = []

class CbrSystem:
    """Case-based retrieval over pre-computed case embeddings.

    Each case contributes three embeddings (question, snippet text, keywords).
    A query is scored against every case with a weighted sum of cosine
    similarities, and the highest-scoring cases are returned.
    """

    def __init__(self, cases):
        """Index `cases`: objects exposing `matching_question_embeddings`,
        `retrieval_text_embeddings` and `retrieval_keywords_embeddings`."""
        self.cases = cases
        # Per-case embedding lists, kept under their original public names.
        self.global_matching_question_embeddings = [case.matching_question_embeddings for case in self.cases]
        self.global_retrieval_text_embeddings = [case.retrieval_text_embeddings for case in self.cases]
        self.global_retrieval_keyword_embeddings = [case.retrieval_keywords_embeddings for case in self.cases]
        # Stacked (n_cases, dim) matrices, built lazily per component on first use.
        self._stacked = {}

    def _similarities(self, key, query_embeddings, case_embeddings):
        """Cosine similarity of one query vector against every case embedding in `case_embeddings`."""
        if len(case_embeddings) == 0:
            return np.zeros(0)
        matrix = self._stacked.get(key)
        if matrix is None:
            # Stack once and cache, so each query is one vectorised
            # cosine_similarity call instead of one sklearn call per case.
            # Only components with a non-zero weight are ever stacked.
            matrix = np.vstack(case_embeddings)
            self._stacked[key] = matrix
        return cosine_similarity([query_embeddings], matrix)[0]

    def retrieve_matches(self, matching_query_embeddings, retrieval_query_embeddings, document_count, question_weight, snippet_weight, keywords_weight):
        """Return the `document_count` best-matching cases for a query.

        Args:
            matching_query_embeddings: query vector compared with each case's
                question embedding.
            retrieval_query_embeddings: query vector compared with each case's
                snippet and keyword embeddings.
            document_count: number of cases to return; <= 0 returns none.
            question_weight, snippet_weight, keywords_weight: weights of the
                three similarities; a component whose weight is not positive
                contributes 0 and is not computed.

        Returns:
            (results, closest_case_indices): the matched cases, best first, and
            their positions in `self.cases` as a numpy integer array.
        """
        if document_count <= 0:
            # Guard: a `[-0:]` slice would select every case instead of none.
            return [], np.array([], dtype=np.intp)

        case_count = len(self.cases)
        question_similarities = np.zeros(case_count)
        snippet_similarities = np.zeros(case_count)
        keyword_similarities = np.zeros(case_count)

        if question_weight > 0:
            question_similarities = self._similarities('question', matching_query_embeddings, self.global_matching_question_embeddings)
        if snippet_weight > 0:
            snippet_similarities = self._similarities('snippet', retrieval_query_embeddings, self.global_retrieval_text_embeddings)
        if keywords_weight > 0:
            keyword_similarities = self._similarities('keyword', retrieval_query_embeddings, self.global_retrieval_keyword_embeddings)

        combined_similarities = (question_weight * question_similarities) + (snippet_weight * snippet_similarities) + (keywords_weight * keyword_similarities)
        # Ascending argsort -> take the last `document_count` -> reverse to best-first.
        closest_case_indices = np.argsort(combined_similarities)[-document_count:][::-1]

        results = [self.cases[index] for index in closest_case_indices]
        return results, closest_case_indices

## Load Assets from Pre-Built Sources

In [None]:
# List the available columns; nrows=0 parses only the header row instead of
# loading every embedding column of the whole file.
pd.read_csv("resources/cbr_dataset_embeddings.csv", nrows=0).columns

Index(['question', 'answer', 'snippet', 'snippet_sentences', 'keywords',
       '__index_level_0__', 'referenced_acts',
       'question_normal_bert_matching_embeddings',
       'question_legal_bert_matching_embeddings',
       'question_angle_bert_matching_embeddings',
       'question_normal_bert_retrieval_embeddings',
       'question_legal_bert_retrieval_embeddings',
       'question_angle_bert_retrieval_embeddings',
       'answer_normal_bert_matching_embeddings',
       'answer_legal_bert_matching_embeddings',
       'answer_angle_bert_matching_embeddings',
       'answer_normal_bert_retrieval_embeddings',
       'answer_legal_bert_retrieval_embeddings',
       'answer_angle_bert_retrieval_embeddings',
       'keyword_normal_bert_matching_embeddings',
       'keyword_legal_bert_matching_embeddings',
       'keyword_angle_bert_matching_embeddings',
       'keyword_normal_bert_retrieval_embeddings',
       'keyword_legal_bert_retrieval_embeddings',
       'keyword_angle_bert_retrie

In [None]:
df = pd.read_csv("resources/cbr_dataset_embeddings.csv")

# Text fields first, then every embedding column: for each source field, the
# matching and then retrieval variant of each of the three BERT encoders.
selected_columns = ['question', 'answer', 'snippet', 'keywords'] + [
    f'{source}_{encoder}_bert_{purpose}_embeddings'
    for source in ('question', 'answer', 'keyword', 'snippet')
    for purpose in ('matching', 'retrieval')
    for encoder in ('normal', 'legal', 'angle')
]

# Keep one row per distinct question and renumber the rows 0..n-1.
main_df = (
    df.loc[:, selected_columns]
    .drop_duplicates(subset='question')
    .reset_index(drop=True)
)
print(len(main_df))

2084


In [None]:
# One list of Case objects per BERT embedding family.
cases_normal_bert, cases_legal_bert, cases_angle_bert = [], [], []

In [None]:
# For every row, build one Case per embedding family; the text fields are
# shared and only the embedding columns differ between families.
family_targets = (
    ('normal', cases_normal_bert),
    ('legal', cases_legal_bert),
    ('angle', cases_angle_bert),
)
for index, row in tqdm(main_df.iterrows()):
    for family, target in family_targets:
        target.append(Case(
            index,
            row['question'],
            row[f'question_{family}_bert_matching_embeddings'],
            row['snippet'],
            row[f'snippet_{family}_bert_retrieval_embeddings'],
            row['keywords'],
            row[f'keyword_{family}_bert_retrieval_embeddings'],
            row['answer'],
        ))


2084it [01:34, 22.12it/s]


In [None]:
# Sanity check: each family should hold one case per row of main_df.
for case_list in (cases_normal_bert, cases_legal_bert, cases_angle_bert):
    print(len(case_list))

2084
2084
2084


In [None]:
# One retrieval system per embedding family, built in normal/legal/angle order.
normal_bert_cbr_system, legal_bert_cbr_system, angle_bert_cbr_system = (
    CbrSystem(case_list)
    for case_list in (cases_normal_bert, cases_legal_bert, cases_angle_bert)
)

print('\n Cases loaded into the systems')


 Cases loaded into the systems


## Mistral Inference Component

In [None]:
def get_mistral_response(request):
    """Return Mistral's reply to a single-turn chat prompt.

    Sends `request` as the only user message to the module-level `client`,
    using the module-level `model`, and returns the text of the first choice.
    A fixed 0.5 s pause precedes each call (presumably a crude rate limit;
    confirm against the API plan).
    """
    time.sleep(0.5)
    conversation = [ChatMessage(role="user", content=request)]
    reply = client.chat(model=model, messages=conversation)
    first_choice = reply.choices[0]
    return first_choice.message.content

## Algorithms

In [None]:
def execute_pipeline_no_rag(question, is_simulation=False):
    """Baseline pipeline: ask Mistral `question` directly, with no retrieved context.

    Args:
        question: question text, sent verbatim as the prompt.
        is_simulation: when True, return None without contacting the API,
            matching the retrieval pipelines' simulation mode.

    Returns:
        The model's reply text, or None in simulation mode.
    """
    if is_simulation:
        # Skip the API call entirely: its reply would be discarded anyway, and
        # each call costs quota plus the request throttle delay.
        return None
    return get_mistral_response(question)

def _parse_query_embedding(value):
    """Turn a stored query embedding into a numpy array without using `eval`.

    Strings are parsed with `json.loads`, falling back to `ast.literal_eval`
    for other Python literals; already-parsed sequences/arrays are converted
    directly. Non-literal strings raise instead of being executed.
    """
    import ast
    import json

    if isinstance(value, str):
        try:
            value = json.loads(value)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            value = ast.literal_eval(value)
    return np.array(value)


def _retrieve_cases(cbr_system, question_matching_embeddings, question_retrieval_embeddings, k, w1, w2, w3):
    """Parse both query embeddings and run the weighted top-k CBR retrieval."""
    return cbr_system.retrieve_matches(
        _parse_query_embedding(question_matching_embeddings),
        _parse_query_embedding(question_retrieval_embeddings),
        k, w1, w2, w3,
    )


def execute_pipeline_x_snippet(cbr_system, question, question_matching_embeddings, question_retrieval_embeddings, k, w1, w2, w3, is_simulation=False):
    """RAG pipeline whose context is the snippet text of the top-k retrieved cases.

    Args:
        cbr_system: retrieval system exposing `retrieve_matches`.
        question: question text inserted into the prompt.
        question_matching_embeddings, question_retrieval_embeddings: query
            embeddings, as list-literal strings or sequences/arrays.
        k: number of cases to retrieve.
        w1, w2, w3: question / snippet / keyword similarity weights.
        is_simulation: when True, only retrieve; skip prompt building and the LLM call.

    Returns:
        (response, context, indexes); response and context are None in simulation mode.
    """
    matched_cases, indexes = _retrieve_cases(cbr_system, question_matching_embeddings, question_retrieval_embeddings, k, w1, w2, w3)
    if is_simulation:
        return None, None, indexes
    context = ' | '.join('"' + matched_case.text + '"' for matched_case in matched_cases)
    response = get_mistral_response('Answer "' + question + '" by using the following contexts: { ' + context + ' }')
    return response, context, indexes


def execute_pipeline_x_case(cbr_system, question, question_matching_embeddings, question_retrieval_embeddings, k, w1, w2, w3, is_simulation=False):
    """RAG pipeline whose context is the full (question, citation, answer) of the top-k cases.

    Arguments and return value are as for `execute_pipeline_x_snippet`.
    """
    matched_cases, indexes = _retrieve_cases(cbr_system, question_matching_embeddings, question_retrieval_embeddings, k, w1, w2, w3)
    if is_simulation:
        return None, None, indexes
    # NOTE(review): the quoting below is irregular ('"Question: "<q>", citation: ...');
    # kept byte-identical so generations stay comparable with earlier runs.
    context = ' | '.join(
        '"Question: "' + matched_case.question + '", citation: "' + matched_case.text + '" and " answer: "' + matched_case.answer + '"'
        for matched_case in matched_cases
    )
    response = get_mistral_response('Answer "' + question + '" as a simple string (with no structure) by using the following question, citation and answer tuples as context: { ' + context + ' }')
    return response, context, indexes

## Run Mistral Generation Tests

In [None]:
# NOTE(review): `k` is not used by the visible cells below (each pipeline call
# passes its own top-k); 2084 matches the de-duplicated case count printed earlier.
k = 2084
test_csv_path = "resources/test_dataset_embeddings.csv"
df_test = pd.read_csv(test_csv_path)
print(df_test.shape[0])

35


In [None]:
# Retrieval weights shared by every hybrid run: question, snippet and keyword
# similarity respectively.
QUESTION_WEIGHT, SNIPPET_WEIGHT, KEYWORDS_WEIGHT = 0.25, 0.4, 0.35
TOP_K_VALUES = (1, 3)

# Embedding family -> CBR system built from that family's case embeddings.
CBR_SYSTEMS = (
    ('normal_bert', normal_bert_cbr_system),
    ('legal_bert', legal_bert_cbr_system),
    ('angle_bert', angle_bert_cbr_system),
)
# Context style -> pipeline that builds the prompt from the retrieved cases.
PIPELINES = (
    ('snippet', execute_pipeline_x_snippet),
    ('case', execute_pipeline_x_case),
)

# Output columns: k=1 then k=3; within each, families in the order above;
# within each family, snippet then case; each with indexes/context/result.
RESULT_COLUMNS = ['case_index', 'no_rag_pipeline_result'] + [
    f'{family}_hybrid_{style}_k{top_k}_{field}'
    for top_k in TOP_K_VALUES
    for family, _ in CBR_SYSTEMS
    for style, _ in PIPELINES
    for field in ('indexes', 'context', 'result')
]


def _run_all_pipelines(row, case_index):
    """Run the no-RAG baseline and every (k, family, pipeline) variant for one test row.

    Returns a dict with all of the row's own fields plus `case_index`,
    `no_rag_pipeline_result` and one indexes/context/result triple per variant.
    """
    record = row.to_dict()
    record['case_index'] = case_index
    record['no_rag_pipeline_result'] = execute_pipeline_no_rag(row['question'])
    for top_k in TOP_K_VALUES:
        for family, cbr_system in CBR_SYSTEMS:
            for style, pipeline in PIPELINES:
                result, context, indexes = pipeline(
                    cbr_system,
                    row['question'],
                    row[f'question_{family}_matching_embeddings'],
                    row[f'question_{family}_retrieval_embeddings'],
                    top_k, QUESTION_WEIGHT, SNIPPET_WEIGHT, KEYWORDS_WEIGHT,
                )
                prefix = f'{family}_hybrid_{style}_k{top_k}'
                record[f'{prefix}_indexes'] = indexes
                record[f'{prefix}_context'] = context
                record[f'{prefix}_result'] = result
    return record


# Collect plain dicts and build the frame once: DataFrame.append was removed in
# pandas 2.0, and it re-copied the whole frame on every row.
records = [
    _run_all_pipelines(row, index)
    for index, row in tqdm(df_test.iterrows(), total=len(df_test))
]
new_df = pd.DataFrame(records, index=df_test.index)
# Result columns first, then the test rows' own columns (question, embeddings, ...).
new_df = new_df.reindex(
    columns=RESULT_COLUMNS + [column for column in new_df.columns if column not in RESULT_COLUMNS]
)

In [None]:
# Inspect the collected pipeline outputs: one row per test question, with an
# indexes/context/result column triple per pipeline variant.
new_df

In [None]:
import os

# Authenticate through the HF_TOKEN environment variable instead of pasting a
# token into a shell command, where it leaks into the notebook and shell history.
# With token=None, push_to_hub falls back to any cached `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")

new_df = new_df.reset_index(drop=True)

# Split name encodes the retrieval weights used above (0.25 / 0.40 / 0.35).
output_dataset = DatasetDict({'w025w040w035': Dataset.from_pandas(new_df)})
output_dataset.push_to_hub('XXX', token=hf_token)