<a href="https://colab.research.google.com/github/sebastianrohr/AFAE_exercises/blob/main/t5_flan_rag_no_outputs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the corpus and the dev/test files
# referenced later in the notebook can be read from MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install rank-bm25 nltk transformers sentencepiece

In [None]:
import pandas as pd
import re
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import nltk
from transformers import T5ForConditionalGeneration, T5Tokenizer
import string

# Resources needed by preprocess(): the Punkt tokenizer models for
# word_tokenize and the English stop-word list.
# NLTK >= 3.8.2 loads Punkt from the 'punkt_tab' resource instead of the
# pickled 'punkt' one, so fetch both to work with old and new releases.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

In [None]:
def parse_docs(file_path):
    """Parse a dump of ``<doc id=... url=... title=...>...</doc>`` elements.

    Args:
        file_path: Path to a UTF-8 text file containing the ``<doc>`` elements.

    Returns:
        A list with one dict per matched element, in file order, with the keys
        ``'id'``, ``'url'``, ``'title'`` and ``'text'``. The text is the element
        body with surrounding whitespace stripped.
    """
    # DOTALL lets the lazy body group span multiple lines.
    doc_pattern = re.compile(r'<doc id="([^"]+)" url="([^"]+)" title="([^"]+)">(.*?)</doc>', re.DOTALL)

    with open(file_path, 'r', encoding='utf-8') as handle:
        raw = handle.read()

    parsed = []
    for found in doc_pattern.finditer(raw):
        doc_id, url, title, body = found.groups()
        parsed.append({'id': doc_id, 'url': url, 'title': title, 'text': body.strip()})
    return parsed

# Build a DataFrame from the parsed documents
def create_dataframe(docs):
    """Wrap the parsed document records in a pandas DataFrame.

    Args:
        docs: List of dicts, as returned by ``parse_docs``.

    Returns:
        A ``pd.DataFrame`` with one row per dict.
    """
    frame = pd.DataFrame(data=docs)
    return frame

def is_only_special_chars(word):
    """Return True when every character of ``word`` is ASCII punctuation.

    The punctuation set is ``string.punctuation``. An empty ``word`` yields
    True, since it contains no non-punctuation character.
    """
    # Subset test: no character of the word falls outside the punctuation set.
    return set(word) <= set(string.punctuation)

# Preprocess text: tokenize, remove stopwords and punctuation, then stem
def _preprocess_resources():
    """Return the shared ``(stemmer, stop_words, special_chars)`` triple.

    The triple is built on the first call and cached on the function object, so
    the stop-word corpus is read once rather than once per document.
    """
    cached = getattr(_preprocess_resources, '_cached', None)
    if cached is None:
        cached = (
            PorterStemmer(),
            frozenset(stopwords.words('english')),
            frozenset(string.punctuation),
        )
        _preprocess_resources._cached = cached
    return cached


def preprocess(text):
    """Turn raw text into a list of stemmed tokens.

    Steps: lowercase the text, tokenize it with ``word_tokenize``, drop English
    stop words and tokens made up solely of ``string.punctuation`` characters,
    then Porter-stem what remains.

    Args:
        text: The string to preprocess.

    Returns:
        List of stemmed tokens in their original order.
    """
    stemmer, stop_words, special_chars = _preprocess_resources()

    # Tokenize and convert to lowercase
    tokens = word_tokenize(text.lower())

    return [
        stemmer.stem(token)
        for token in tokens
        if token not in stop_words
        and not all(char in special_chars for char in token)
    ]

# Path to the corpus dump that parse_docs() reads.
# NOTE(review): this path is relative and goes through `apps/`, whereas the
# dev/test paths later in the notebook are absolute under
# /content/drive/MyDrive/data-anlp/ — confirm it points at the right file.
file_path = 'drive/MyDrive/apps/data-anlp/starwarsfandomcom-removed-space-and.txt'  # Replace with the actual path to your file
# Parse the dump into records and load them into a DataFrame.
docs = parse_docs(file_path)
df = create_dataframe(docs)
# Pre-tokenize every document once; this column is the corpus handed to BM25Okapi.
df['tokenized_text'] = df['text'].apply(preprocess)
# Drop documents whose text is identical to their title. The index is not
# reset here; later lookups use positional indexing (df.iloc), which lines up
# with the order of df['tokenized_text'].tolist().
df = df[df['title'] != df['text']]

In [None]:
# Load the pretrained `google/flan-t5-base` model and its matching tokenizer.
# Both are used as globals by generate_response().
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')

In [None]:
## dev set

# Directory holding the dev split; 'questions.txt' and 'answers.txt' are read
# from it below. The trailing slash matters because file names are appended
# by plain string concatenation.
dev_set_path = '/content/drive/MyDrive/data-anlp/star_wars_dataset_dev/'

# Function to read a file and return a list of lines
def read_file_to_list(file_path, encoding='utf-8'):
    """Read a text file into a list of whitespace-stripped lines.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale, matching how
            ``parse_docs`` reads the corpus.

    Returns:
        One string per line of the file, in order, with leading and trailing
        whitespace removed. Blank lines are kept as empty strings, so line
        positions stay aligned between parallel files.
    """
    with open(file_path, 'r', encoding=encoding) as file:
        return [line.strip() for line in file]

# Load the data: one question / one answer per line. The two lists are paired
# by position later (via zip), so line i of answers.txt is treated as the gold
# answer for line i of questions.txt.
dev_questions = read_file_to_list(dev_set_path + 'questions.txt')
dev_answers = read_file_to_list(dev_set_path + 'answers.txt')

In [None]:
def robs_eval(gold, pred):
    """
    Lenient correctness check based on token overlap.

    The prediction counts as correct when at least half of the distinct
    gold tokens also occur in the prediction. This is a shortcut metric
    and it favors long answers.
    """
    def _token_set(answer):
        # Same normalisation for both sides: trim, lowercase, drop periods,
        # then split on single spaces.
        return set(answer.strip().lower().replace('.', '').split(' '))

    gold_tokens = _token_set(gold)
    shared = gold_tokens & _token_set(pred)
    # Integer form of "overlap >= half of the gold tokens".
    return 2 * len(shared) >= len(gold_tokens)

In [None]:
def query_bm25(query, num_results=10):
    """Retrieve the best-scoring documents for ``query`` and join their text.

    Relies on the module-level ``bm25`` index and ``df`` frame.

    Args:
        query: The question string; it is run through ``preprocess``.
        num_results: How many top-scoring documents to keep.

    Returns:
        The ``text`` of the selected documents, best first, joined by a single
        space.
    """
    query_tokens = preprocess(query)
    print(query_tokens)
    scores = bm25.get_scores(query_tokens)
    # Rank positions by score, highest first; the stable sort keeps the
    # original document order among equal scores.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    top_positions = [position for position, _ in ranked[:num_results]]
    selected_texts = df.iloc[top_positions]['text']
    return selected_texts.str.cat(sep=' ')

def generate_response(query, context):
    """Answer ``query`` with the global seq2seq model, conditioned on ``context``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        query: The question to answer.
        context: Retrieved passage text to ground the answer in.

    Returns:
        A tuple ``(response, context, truncated_input_text)``: the decoded
        model output, the context exactly as passed in, and the prompt as the
        model saw it after truncation (decoded without special tokens).
    """
    # Concatenate the query and context
    input_text = f'Answer this question: "{query}". Based off the following context: "{context}".'

    # Let the tokenizer truncate to the model's maximum length. Slicing the
    # id tensor by hand would also cut off the end-of-sequence token that the
    # tokenizer appends; tokenizer-side truncation keeps it in place.
    input_ids = tokenizer.encode(input_text, return_tensors='pt', truncation=True)

    # Keep the inputs on the same device as the model.
    input_ids = input_ids.to(model.device)

    truncated_input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    output_ids = model.generate(input_ids)[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response, context, truncated_input_text

In [None]:
### grid search
# Evaluate every (k1, b) combination of BM25 on the dev set, end to end:
# retrieve with BM25, generate with the model, score with robs_eval.
# Leaves `bm25`, `best_params`, `best_score` and `results` as globals.

from sklearn.model_selection import ParameterGrid
import numpy as np
from tqdm import tqdm

# Number of dev question/answer pairs to evaluate per parameter setting.
max_pairs = len(dev_questions)
# Uncomment for a quick smoke run on a small subset:
#max_pairs = 10

# 3 values per parameter -> 9 combinations:
# k1 in {0.75, 1.25, 1.75}, b in {0.3, 0.5, 0.7}.
param_grid = {
    'k1': np.linspace(0.75, 1.75, num=3),
    'b': np.linspace(0.3, 0.7, num=3)
}
grid = ParameterGrid(param_grid)
best_score = -1
best_params = None
results = {}  # Maps str(params) -> number of correct dev answers

for params in tqdm(grid, desc="Grid Search"):
    print(f"\n{params}")
    # Rebuild the index for this setting; query_bm25() reads this global.
    bm25 = BM25Okapi(corpus=df['tokenized_text'].tolist(), k1=params['k1'], b=params['b'])
    total_correct = 0

    for query, answer in tqdm(zip(dev_questions[:max_pairs], dev_answers[:max_pairs]), total=len(dev_questions[:max_pairs]), desc="Evaluating Queries"):
        # Retrieve the 15 best documents as context, then generate an answer.
        context = query_bm25(query, 15)
        response, context, truncated_input_text = generate_response(query, context)
        print(f"""
{query}
{response}
{answer}
----------------------------------------------""")

        # The running total is printed only when an answer is judged correct.
        if robs_eval(answer, response):
            total_correct += 1
            print("total = " + str(total_correct))

    # Save the results
    results[str(params)] = total_correct

    # Assess the effectiveness of the parameters.
    # Strict '>' means the first setting to reach a score wins ties.
    print(f"Total Correct: {total_correct}")
    if total_correct > best_score:
        best_score = total_correct
        best_params = params

print(f"Best Params: {best_params}, Best Score: {best_score}")
# Print the results dictionary
print("\nAll Results:")
for param, score in results.items():
    print(f"{param}: {score}")


In [None]:
# Final evaluation on the test split using the best BM25 parameters.
# Depends on `best_params` and `tqdm` from the grid-search cell, so that cell
# must have been run first.
test_set_path = '/content/drive/MyDrive/data-anlp/star_wars_dataset_test/'
test_questions = read_file_to_list(test_set_path + 'questions.txt')
test_answers = read_file_to_list(test_set_path + 'answers.txt')

# Rebuild the index with the selected parameters; query_bm25() reads this global.
bm25 = BM25Okapi(corpus=df['tokenized_text'].tolist(), k1=best_params['k1'], b=best_params['b'])

# One record per test question, collected for df_results below.
data = []

total_correct = 0

for query, answer in tqdm(zip(test_questions, test_answers), total=len(test_answers), desc="Processing"):
    # Retrieve the 15 best documents as context.
    context = query_bm25(query, 15)

    response, context, truncated_input_text = generate_response(query, context)

    # Print the prompt as seen by the model, the model's answer and the gold answer
    # (optional, can be removed if not needed)
    print(f"""
{truncated_input_text}
{response}
{answer}
----------------------------------------------""")

    # Check if the response is correct
    if robs_eval(answer, response):
        total_correct += 1
        print("total = " + str(total_correct))

    data.append({'Query': query, 'Correct Answer': answer, 'Model Response': response})

# Per-question results table (built here but not displayed or saved in this cell).
df_results = pd.DataFrame(data)
print(f"Total correct answers: {total_correct}/{len(test_answers)}")