# Retrieval Augmented Generation - RAG 

We will use retrieval-augmented generation (RAG) with language models to tackle the multi-class, multi-label classification problem.

In [None]:
from dspy.retrieve.qdrant_rm import QdrantRM
from qdrant_client import QdrantClient
import dspy
import json
from typing import List, Dict, Tuple
from utils import parse_sgm_to_dataframe

## Data Exploration

In [None]:
# Parse one Reuters-21578 SGML file with the project helper. The result is used
# below as a DataFrame with (at least) 'ID', 'Body' and 'Topic' columns.
df = parse_sgm_to_dataframe('../../data/reuters21578/reut2-000.sgm')

In [None]:
def check_topics_in_file(topic_list, file_path):
    """Check that every topic in ``topic_list`` is listed in a topics file.

    The file is read as one topic per line. Whitespace around each line is
    ignored, so a trailing space in the file does not make a topic that is
    actually listed look missing.

    Args:
        topic_list: Iterable of hashable topic names to look up.
        file_path: Path of the text file holding the known topics.

    Returns:
        A ``(all_present, missing)`` tuple: ``(True, [])`` when every topic is
        found, otherwise ``(False, missing)`` where ``missing`` lists each
        absent topic once, in the order it first appears in ``topic_list``.
    """
    with open(file_path, 'r') as file:
        file_topics = {line.strip() for line in file.read().splitlines()}

    # dict.fromkeys de-duplicates while keeping the caller's order, so the
    # reported list is deterministic (a plain set difference is not).
    missing_topics = [
        topic for topic in dict.fromkeys(topic_list) if topic not in file_topics
    ]
    return not missing_topics, missing_topics

In [None]:
# Distinct values of the 'Topic' column in the parsed file.
topic_list = df['Topic'].unique()

In [None]:
# Drop the empty-string entry so only non-empty topic values remain; this also
# turns topic_list into a plain Python list.
topic_list = [topic for topic in topic_list if topic != '']

In [None]:
# Sanity check: are all topics seen in the data listed in the dataset's topics
# file? Displays an (all_present, missing_topics) tuple.
check_topics_in_file(topic_list, '../../data/reuters21578/all-topics-strings.lc.txt')

In [None]:
# Load the topics file as a set of its lines (one entry per line, line endings
# removed by splitlines()).
topic_file = '../../data/reuters21578/all-topics-strings.lc.txt'
with open(topic_file, 'r') as file:
    file_topics = set(file.read().splitlines())

In [None]:
# Turn the set into a list so it can be cleaned and ordered.
topics = list(file_topics)

In [None]:
# Remove leading/trailing whitespace from every topic string.
topics = [s.strip() for s in topics]

In [None]:
# Sort in place so the topic list has a stable, lexicographic order.
topics.sort()

In [None]:
# Display the (rows, columns) dimensions of the parsed data.
df.shape

In [None]:
# Preview the first 20 rows.
df.head(20)

In [None]:
def parse_reuters_dataframe(df):
    """Group Reuters rows into per-article bodies and topic lists.

    Rows that share an ``'ID'`` are merged into one article. The body is taken
    from the first row seen for that ID, and every non-blank string found in
    ``'Topic'`` for that ID is collected in row order.

    Args:
        df: DataFrame with ``'ID'``, ``'Body'`` and ``'Topic'`` columns.

    Returns:
        A ``(bodies, topics)`` tuple of equal-length lists in first-seen
        article order, where ``topics[i]`` is the list of topic strings for
        ``bodies[i]`` (empty when the article has no topic).
    """
    articles = {}

    # Walk the three columns in lockstep instead of building a Series per row
    # the way DataFrame.iterrows() does.
    for article_id, body, topic in zip(df['ID'], df['Body'], df['Topic']):
        entry = articles.setdefault(article_id, {'body': body, 'topics': []})

        # Only non-blank strings count as topics. The isinstance check rejects
        # missing values (None / NaN) without relying on a `pd` name, which
        # this notebook never imports.
        if isinstance(topic, str) and topic.strip():
            entry['topics'].append(topic)

    bodies = [entry['body'] for entry in articles.values()]
    topics = [entry['topics'] for entry in articles.values()]
    return bodies, topics

In [None]:
# One body and one topic list per article.
# NOTE: this rebinds `topics`, replacing the list built from the topics file
# in the earlier cells.
bodies, topics = parse_reuters_dataframe(df)

In [None]:
# Number of distinct articles.
len(bodies)

In [None]:
# Should equal len(bodies): one topic list per article.
len(topics)

In [None]:
# Spot-check the topic list of the article at index 4.
topics[4]

In [None]:
# Spot-check the body text of the article at index 4.
bodies[4]

## Find suitable pre-trained models

In [None]:
import logging

# Set up logging
# basicConfig configures the root logger, so DEBUG output from any library
# that logs through the standard logging module is shown as well.
logging.basicConfig(level=logging.DEBUG)
# Logger used by the helper classes defined further down in this notebook.
logger = logging.getLogger(__name__)

In [None]:
def clean_json_string(json_str: str) -> str:
    """Strip Markdown code-fence markers from a model response.

    Removes every occurrence of an opening "```json" + newline and of a
    newline + closing "```"; all other text is returned unchanged.
    """
    for fence_marker in ('```json\n', '\n```'):
        json_str = json_str.replace(fence_marker, '')
    return json_str

In [None]:
import re

def split_sentences(text):
    """Split ``text`` into sentences and return them as a list of strings.

    A split happens at a whitespace character that directly follows '.', '?'
    or '!'. The negative look-behinds prevent splitting after dotted
    abbreviations such as "e.g." or "U.S." and after capitalised two-letter
    titles such as "Mr." or "Dr.". Each piece is stripped and empty pieces are
    dropped.
    """
    boundary = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
    stripped_pieces = (piece.strip() for piece in boundary.split(text))
    return [piece for piece in stripped_pieces if piece]

In [None]:
def parse_ollama_output(output_str: str, clean_values: bool = True) -> List[str]:
    """Return the labels that a JSON model response marks with the value 1.

    Args:
        output_str: JSON object text mapping label names to flags.
        clean_values: When true, pass the text through ``clean_json_string``
            before parsing.

    Returns:
        The keys whose value equals 1, in the order they appear in the JSON.

    Raises:
        json.JSONDecodeError: If the text is not valid JSON.
    """
    payload = clean_json_string(output_str) if clean_values else output_str
    label_flags = json.loads(payload)

    selected = []
    for label, flag in label_flags.items():
        if flag == 1:
            selected.append(label)
    return selected

In [None]:
def build_retriever_client(labels: List[str], collection_name: str, k: int, vectorizer: str = None) -> QdrantRM:
    """Index ``labels`` in a fresh ``":memory:"`` Qdrant client and wrap it for dspy.

    Args:
        labels: Label strings to store as documents; each gets its list index
            as document id.
        collection_name: Name of the collection to add the documents to.
        k: Passed to ``QdrantRM`` as its ``k`` argument.
        vectorizer: Optional model name handed to ``client.set_model``; when
            falsy the client's default is left in place.

    Returns:
        A ``QdrantRM`` bound to the new client and collection.
    """
    qdrant = QdrantClient(":memory:")

    # Only override the embedding model when one was explicitly requested.
    if vectorizer:
        qdrant.set_model(vectorizer)

    qdrant.add(
        collection_name=collection_name,
        documents=labels,
        ids=list(range(len(labels))),
    )
    return QdrantRM(collection_name, qdrant, k=k)

In [None]:
# Input/output contract for the classification LM call.
# NOTE(review): dspy is understood to send a Signature's docstring to the model
# as the task instructions, so the docstring below is prompt text rather than
# ordinary documentation; confirm against the dspy docs before rewording it.
class ClassifyText(dspy.Signature):
    """Classify the news article into multiple topic labels from the given candidates. 
    It is possible to have no label, a single label, or multiple labels. You should return the 
    extracted information as a single JSON string with a key for each candidate topic label and a value of
    1 if the article is about the topic and 0 otherwise. There should be no
    text or explanation, only the JSON. For example if there 
    were 3 candidates you could have the following output:

    {
        "label_1": 1,
        "label_2": 0,
        "label_3": 1
    }"""
    # The article text to classify.
    text = dspy.InputField()
    # Candidate labels; RAGMultiLabelClassifier passes them as one comma-separated string.
    label_candidates = dspy.InputField(desc="List of candidate labels for the text")
    # Raw model answer; expected to be the JSON object described in the docstring.
    article_labels = dspy.OutputField(desc="Dictionary of candidate labels, 1 or 0, for the text")

In [None]:
class RAGMultiLabelClassifier(dspy.Module):
    """Retrieve candidate labels sentence by sentence, then let the LM choose.

    ``forward`` splits the article into sentences, queries the retriever once
    per sentence, merges the retrieved labels into one candidate set and asks
    the ``ClassifyText`` predictor which candidates apply.
    """

    def __init__(self, custom_retriever, num_candidates=10):
        """Store the retriever and build the predictor.

        Args:
            custom_retriever: Callable invoked as ``retriever(query, k=...)``
                that returns items exposing a ``'long_text'`` entry.
            num_candidates: Number of labels requested per sentence.
        """
        super().__init__()
        self.retrieve = custom_retriever
        self.classify = dspy.Predict(ClassifyText)
        self.num_candidates = num_candidates

    def forward(self, text):
        """Classify ``text`` and return a ``{label: flag}`` dictionary.

        When the model answer cannot be turned into a dictionary the sentinel
        ``{"wrong": 1}`` is returned instead of raising.
        """
        candidate_labels = set()
        for sentence in split_sentences(text):
            retrieved_docs = self.retrieve(sentence, k=self.num_candidates)
            candidate_labels.update(doc['long_text'] for doc in retrieved_docs)

        # Sorted so the prompt is identical from run to run; joining the set
        # directly would order the candidates arbitrarily.
        retrieved_labels = ','.join(sorted(candidate_labels))
        print(f"Retrieved labels: {retrieved_labels}")

        classification_result = self.classify(text=text, label_candidates=retrieved_labels)
        result = clean_json_string(classification_result.article_labels)
        logger.debug(f"Raw classification result: {result}")

        parsed_result = self._parse_label_dict(result)
        logger.debug(f"Final parsed result: {parsed_result}")
        return parsed_result

    @staticmethod
    def _parse_label_dict(raw_output):
        """Parse the model answer into a dict, or return ``{"wrong": 1}``."""
        import ast
        import re

        try:
            parsed = json.loads(raw_output)
        except json.JSONDecodeError:
            # Not clean JSON: fall back to the outermost {...} span, if any.
            dict_match = re.search(r'\{.*\}', raw_output, re.DOTALL)
            if dict_match is None:
                return {"wrong": 1}  # Fallback to hard-coded wrong output
            try:
                # literal_eval only accepts Python literals, so text produced
                # by the model is never executed (the previous eval() would
                # run arbitrary expressions).
                parsed = ast.literal_eval(dict_match.group(0))
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                return {"wrong": 1}  # Fallback to hard-coded wrong output

        # Ensure the output is a dictionary
        if not isinstance(parsed, dict):
            return {"wrong": 1}  # Fallback to hard-coded wrong output
        return parsed

## Explore the use of models pretrained on article data

In [None]:
# Model name passed to the Qdrant client (client.set_model) in build_retriever_client.
vectorizer = "BAAI/bge-large-en-v1.5"
# Model name passed to CustomOllamaLocal below.
ollama_model_name = 'gemma2'

In [None]:
# Index the candidate labels in the "reuters" collection; the retriever is
# created with k=10.
# NOTE(review): the labels come from topic_list (topics observed in the parsed
# data file), not from the full topics file; confirm that is intended.
retriever_model = build_retriever_client(labels=topic_list, 
                                         collection_name="reuters", 
                                         k=10, 
                                         vectorizer=vectorizer)

In [None]:
class CustomOllamaLocal(dspy.OllamaLocal):
    """dspy.OllamaLocal subclass that keeps a `model` attribute and logs its use."""

    def __init__(self, model, **kwargs):
        """Record `model` on the instance, then defer to dspy.OllamaLocal."""
        logger.debug(f"Initializing CustomOllamaLocal with model: {model}")
        self.model = model  # Explicitly set the model attribute
        super().__init__(model=model, **kwargs)
        
    def copy(self, **kwargs):
        """Build a new CustomOllamaLocal from this instance's attributes plus overrides."""
        logger.debug(f"Copying CustomOllamaLocal with kwargs: {kwargs}")
        # NOTE(review): __dict__ holds every instance attribute, and all of them
        # are forwarded to __init__ as keyword arguments. Confirm that
        # dspy.OllamaLocal.__init__ accepts them and that generation settings
        # survive the copy.
        new_kwargs = self.__dict__.copy()
        new_kwargs.update(kwargs)
        return CustomOllamaLocal(**new_kwargs)
    
    def basic_request(self, prompt, **kwargs):
        """Log the model name, then delegate the request to the parent class."""
        logger.debug(f"Making basic request with model: {self.model}")
        return super().basic_request(prompt, **kwargs)

In [None]:
# Language model used for classification.
# NOTE(review): temperature=0 and format='json' presumably ask Ollama for
# greedy decoding and JSON-only output; confirm against the dspy.OllamaLocal docs.
ollama_model = CustomOllamaLocal(
    model=ollama_model_name, 
    model_type='text',
    max_tokens=512,
    temperature=0,
    top_p=1,
    frequency_penalty=0,
    top_k=10,
    format='json'
)

In [None]:
# Register the LM and retriever in dspy's settings, then build the classifier,
# which requests 10 candidate labels per sentence.
dspy.settings.configure(lm=ollama_model, rm=retriever_model)
classifier = RAGMultiLabelClassifier(custom_retriever=retriever_model, num_candidates=10)

In [None]:
def calculate_metrics(ground_truth: List[List[str]], predictions: List[List[str]]) -> Dict[str, float]:
    """Compute micro-averaged precision, recall and F1 over label sets.

    True positives, false positives and false negatives are summed over all
    (ground truth, prediction) pairs before the ratios are taken. Duplicate
    labels within one list count once. Pairs are formed with ``zip``, so the
    longer input is truncated to the shorter one.

    Returns:
        A dict with the keys "precision", "recall" and "f1_score"; a ratio
        whose denominator is zero is reported as 0.
    """
    label_sets = [
        (set(expected), set(predicted))
        for expected, predicted in zip(ground_truth, predictions)
    ]
    true_pos = sum(len(expected & predicted) for expected, predicted in label_sets)
    false_pos = sum(len(predicted - expected) for expected, predicted in label_sets)
    false_neg = sum(len(expected - predicted) for expected, predicted in label_sets)

    precision = 0
    if true_pos + false_pos > 0:
        precision = true_pos / (true_pos + false_pos)

    recall = 0
    if true_pos + false_neg > 0:
        recall = true_pos / (true_pos + false_neg)

    f1_score = 0
    if precision + recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)

    return {"precision": precision, "recall": recall, "f1_score": f1_score}

In [None]:
import random

def sample_bodies_and_topics(bodies, topics, num_samples, random_seed=None):
    """Draw ``num_samples`` aligned (body, topics) pairs without replacement.

    Args:
        bodies: Sequence of article bodies.
        topics: Sequence of topic lists, parallel to ``bodies``.
        num_samples: How many positions to draw.
        random_seed: When not None, seeds the module-level ``random``
            generator before sampling so the draw is repeatable.

    Returns:
        ``(sampled_bodies, sampled_topics)``, two lists that share the same
        sampled positions and order.
    """
    assert len(bodies) == len(topics), "Bodies and topics lists must be of the same length."

    if random_seed is not None:
        random.seed(random_seed)

    # Sample positions rather than items so both lists stay aligned.
    chosen = random.sample(list(range(len(bodies))), num_samples)
    return [bodies[pos] for pos in chosen], [topics[pos] for pos in chosen]

In [None]:
# Keep a reproducible (seed 42) sample of 10 articles for the evaluation run.
# NOTE: this rebinds `bodies` and `topics` to the sampled subset.
bodies, topics = sample_bodies_and_topics(bodies, topics, num_samples=10, random_seed=42)

In [None]:
# Run the classifier on every sampled article and score the predictions.
predictions = []
# True labels of the articles that produced a prediction, kept in lockstep with
# `predictions` so a skipped article cannot shift the two lists against each other.
ground_truth = []
raw_results = []

for actual_labels, body in zip(topics, bodies):
    result = classifier(text=body)
    try:
        if isinstance(result, str):
            predicted_classes = parse_ollama_output(result)
        else:
            predicted_classes = [label for label, flag in result.items() if flag == 1]
    except json.JSONDecodeError:
        print("Warning! Could not parse output from Ollama. Skipping this result.")
        print(f'Body: {body}')
        print(f'Result string: {result}')
        continue

    predictions.append(predicted_classes)
    ground_truth.append(actual_labels)
    raw_results.append({
        'body': body,
        'predicted_labels': json.dumps(predicted_classes),
        'actual_labels': json.dumps(actual_labels)
    })

# Score against the true topic lists. The previous call passed `bodies`, which
# compared the predicted labels with the characters of each article's text.
metrics = calculate_metrics(ground_truth, predictions)

## Optimize our RAG pipeline

## Explore retrieval improvements

Retrieval is inexpensive, so in most cases it is a good trade-off to do more work on the retrieval side to ensure that the list of candidate labels includes the true labels.

To explore:
- Sentence splitting
- Retrieval ensemble