# About

We have a RAG-Bot for the BiofidPortal: an LLM that fetches context from a vector database and uses this context to answer user questions. The problem we try to solve in this notebook is:

- When do we need to fetch context and when don't we?

As of now, we fetch context matching the input for every user message, but often enough that context is useless or even harmful ("What does that mean?", "I dont understand that").

To fix that, we train a classifier that checks first: given the input, should we fetch more context? **That's the notebook's premise**.
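A minimal sketch of where such a classifier would sit in the bot, assuming the retrieval and generation steps can be passed in as plain functions. `answer`, `needs_context`, `retrieve`, `generate` and the `toy_*` stand-ins are hypothetical names for illustration, not the portal's actual API.

```python
from typing import Callable, List


def answer(
    user_input: str,
    needs_context: Callable[[str], bool],
    retrieve: Callable[[str], List[str]],
    generate: Callable[[str, List[str]], str],
) -> str:
    """Query the vector database only when the classifier says context is needed."""
    # needs_context stands in for the classifier trained in this notebook,
    # retrieve for the vector-database lookup and generate for the LLM call (all hypothetical).
    context = retrieve(user_input) if needs_context(user_input) else []
    return generate(user_input, context)


# Toy usage with stand-ins: a keyword rule instead of a trained model.
def toy_classifier(text: str) -> bool:
    return not text.lower().startswith(("what does that mean", "i dont understand"))


def toy_retrieve(query: str) -> List[str]:
    return [f"document about: {query}"]


def toy_generate(query: str, context: List[str]) -> str:
    return f"{len(context)} context document(s) used"


print(answer("Tell me about birds", toy_classifier, toy_retrieve, toy_generate))
# 1 context document(s) used
print(answer("What does that mean", toy_classifier, toy_retrieve, toy_generate))
# 0 context document(s) used
```

The point of the design is that a follow-up such as "What does that mean" skips the lookup entirely, so the previous turn's context is not overwritten by unrelated documents.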

----------------------

**Sources:**
- [Article about fine-tuning BERT for Classification](https://medium.com/@khang.pham.exxact/text-classification-with-bert-7afaacc5e49b)

In [197]:
import os
import torch
import random
import re
import uuid
import numpy as np
import json
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch implementation
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
from openai import OpenAI

print('Imports OK.')

Imports OK.


# Dataset

Load in the dataset and create the custom Dataset class.

## Create

First, we need to create the dataset. It could take the following form:

```
{
    'chat_history': [
        'User: I need info about birds, do you have documents for that?',
    ],
    'fetch_context': True,
},
{
    'chat_history': [
        'User: I need info about birds, do you have documents for that?',
        'System: Yes, we have documents for that. What would you like to know?',
        'User: Tell me more about them in general.',
    ],
    'fetch_context': False,
},
```

I haven't decided on the exact form yet; we will see.
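Whatever form is chosen, training only needs a (text, label) pair per record. A minimal sketch, assuming the list-of-strings form with `User: ` / `System: ` prefixes shown above; `record_to_example` is a hypothetical helper that, like the notebook's later loading step, keeps only the last user turn.

```python
from typing import Dict, List, Tuple, Union

Record = Dict[str, Union[List[str], bool]]


def record_to_example(record: Record) -> Tuple[str, int]:
    """Turn one chat record into a (last user turn, label) training pair."""
    # Walk the history backwards and take the most recent user turn.
    for turn in reversed(record["chat_history"]):
        if turn.startswith("User: "):
            text = turn[len("User: "):]
            break
    else:
        raise ValueError("chat_history contains no user turn")
    return text, int(record["fetch_context"])


example = {
    "chat_history": [
        "User: I need info about birds, do you have documents for that?",
        "System: Yes, we have documents for that. What would you like to know?",
        "User: Tell me more about them in general.",
    ],
    "fetch_context": False,
}
print(record_to_example(example))
# ('Tell me more about them in general.', 0)
```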

In [45]:
class ChatGPT:

    def __init__(self):
        pass

    def complete(self, messages, api_key):
        client = OpenAI(api_key = api_key)
        response = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages = messages
        )
        return response.choices[0].message.content
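To dry-run the generation loop without API calls or costs, one could swap in a stub exposing the same `complete(messages, api_key)` method. `FakeChatGPT` is a hypothetical helper, not part of the original notebook; it simply cycles through canned replies.

```python
from typing import Dict, List


class FakeChatGPT:
    """Hypothetical offline stand-in with the same complete(messages, api_key) interface."""

    def __init__(self, replies: List[str]):
        self.replies = list(replies)
        self.calls: List[List[Dict[str, str]]] = []

    def complete(self, messages: List[Dict[str, str]], api_key: str) -> str:
        # Record the prompt for inspection and return the next canned reply,
        # cycling through the list so the stub never runs dry.
        self.calls.append(messages)
        return self.replies[(len(self.calls) - 1) % len(self.replies)]


fake_gpt = FakeChatGPT(["Can you tell me about owls?", "1"])
print(fake_gpt.complete([{"role": "system", "content": "..."}], api_key="unused"))
# Can you tell me about owls?
print(fake_gpt.complete([{"role": "system", "content": "..."}], api_key="unused"))
# 1
```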

## Dataset Creation Loop

In [111]:
judge_template = '''
Below you find a chat history between a user and a Retrieval Augmented Generation Model. 
The problem: the RAG doesn't know when it should fetch context, when it shouldn't and when the user references old context. 
- Whenever the user opens a new subject or explicitly references one in his question, the RAG Model needs context.
- The RAG Model doesn't need new context when the user just chit-chats or asks follow up questions like "What does that mean" or "Elaborate". 
Given the chat history, I want you to print 1 if any or new context is needed and 0 if it's not.

------
History: {HISTORY}
------
'''

chat_start = '''
You simulate a chat between a user of an online search portal and a Retrieval Augmented Generation Model 
of that search portal which helps users find documents and explains topics. 
Below you find a chat history. Continue the chat with the appropriate role as described in the following:
- The user asks for information on the given topic below, while also simply chit-chatting or asking follow up 
questions like "What does that mean?" or "Could you elaborate?". The user should occasionally switch topic and subject as well.
- For the RAG model, assume the online portal contains all documents on any topics and 
can therefore fetch all context needed to answer the questions. As the model, reply shortly and politely.

Keep the chatting concise and short. Continue with a single turn. The topic of this chat is: {TOPIC}

------
'''

In [94]:
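def random_odd_number(upper):
    # Draw from 1..upper and round even draws up to the next odd number.
    # Assumes `upper` is odd; otherwise the result can exceed it (e.g. 8 -> 9).
    number = random.randint(1, upper)
    if number % 2 != 0:
        return number
    else:
        return number + 1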
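Rounding even draws up skews the distribution: with a bound of 7 (as used in the loop below), 1 comes up with probability 1/7 while 3, 5 and 7 each come up with 2/7, and with an even bound such as 8 the function can return 9. If uniform turn counts are preferred, a hypothetical alternative is to sample directly from the odd numbers:

```python
import random


def random_odd_number_uniform(upper: int) -> int:
    """Hypothetical alternative: draw uniformly from the odd numbers 1, 3, ..., <= upper."""
    if upper < 1:
        raise ValueError("upper must be at least 1")
    # range(1, upper + 1, 2) contains exactly the odd numbers up to upper.
    return random.choice(range(1, upper + 1, 2))


random.seed(0)
samples = [random_odd_number_uniform(7) for _ in range(1000)]
print(sorted(set(samples)))
# [1, 3, 5, 7]
```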

In [95]:
def txt_to_array(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
        arrays = [line.strip() for line in lines]
    return arrays

random_words = txt_to_array('rag/src/cBERT/random_words.txt')
print(random_words[:5])

['potato', 'perpetual', 'screeching', 'man', 'advertisement']


In [103]:
def write_to_json_file(data, file_path):
    with open(file_path, 'w') as file:
        json.dump(data, file, indent=4)
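The generation loop below only writes its results once, after all iterations; if the kernel dies halfway, every generated chat is lost. A hedged alternative is to append each record as one line of JSON (JSON Lines) as soon as it is created. `append_jsonl` and `read_jsonl` are hypothetical helpers; pandas can also read such a file with `pd.read_json(path, lines=True)`.

```python
import json
from typing import Any, Dict, List


def append_jsonl(record: Dict[str, Any], file_path: str) -> None:
    """Append one record as a single JSON line, so earlier rows survive a crash."""
    with open(file_path, "a", encoding="utf-8") as file:
        file.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read all records back, skipping blank lines."""
    with open(file_path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]
```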

In [114]:
gpt = ChatGPT()
iterations = 10000
datasets = []
api_key = os.environ['OPENAI_API_KEY']  # never hard-code secrets in a notebook

for i in range(0, iterations):
    try:
        print('Doing iteration ' + str(i))
        # First, we need to create a potential chat which we judge to be context required or not.
        chat = chat_start
        chat_id = str(uuid.uuid4())
        # We want a variety of topics the chat is about. Otherwise we get chats about pizza all the time.
        topic = random.choice(random_words)
        chat = chat.replace('{TOPIC}', topic)
        chat += 'System: Hello, how may I help you today?'
        turns = random_odd_number(7)

        for k in range(1, turns + 1):
            if(k % 2 == 1):
                chat += '\nUser: '
            else:
                chat += '\nSystem: '

            continuation = gpt.complete([{'role': 'system', 'content': chat}], api_key)
            chat += continuation

            # We store and judge those chats that end with a User turn
            if(k % 2 == 0):
                continue
            
            only_turns = chat.split('------')[1]
            #print(only_turns)
            
            # Judge the chat whether we needed context to answer that specific turn or not
            prompt = judge_template.replace('{HISTORY}', only_turns)
            fetch_context = gpt.complete([{'role': 'system', 'content': prompt}], api_key)
            #print(fetch_context)

            datasets.append({
                'chat_id': chat_id,
                'topic': topic,
                'chat': only_turns,
                'fetch_context': fetch_context
            })
    except Exception as ex:
        print("Error trying to create and judge chats, skipping one iteration:\n")
        print(ex)

write_to_json_file(datasets, 'rag/src/cBERT/data/context_chats_10k.json')
    


Doing iteration 0
Doing iteration 1
Doing iteration 2
...
Doing iteration 52

## Chit-Chat dataset

Since the generated dataset is biased towards chats that need context, we add random chit-chat from the [chit-chat-dataset](https://github.com/microsoft/botframework-cli/blob/main/packages/qnamaker/docs/chit-chat-dataset.md). These chats don't require context because the questions are small talk, irrelevant to the portal's documents.

This augments the data with chats that do not require context (`fetch_context = 0`) at no extra cost or effort.

In [259]:
chit_chat_df = pd.read_csv('~/home/biofid/BioFIDPortal/rag/src/cBERT/data/chit-chat_dataset.tsv', sep='\t')
chit_chat_df['chat'] = chit_chat_df['Question']
chit_chat_df['fetch_context'] = 0
chit_chat_df['chat_id'] = uuid.uuid4()  # the same UUID is assigned to every chit-chat row
chit_chat_df['topic'] = 'chit-chat'
chit_chat_df = chit_chat_df[['chat', 'fetch_context', 'chat_id', 'topic']]
display(chit_chat_df.head())
print(len(chit_chat_df))

|   | chat | fetch_context | chat_id | topic |
|---|---|---|---|---|
| 0 | Do you get hurt? | 0 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat |
| 1 | Do you have fingers? | 0 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat |
| 2 | Do you ever breathe | 0 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat |
| 3 | Do you masticate? | 0 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat |
| 4 | Can you throw up? | 0 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat |


9793


## Load

In [260]:
def load_dataset(file_path):
    # Read in our created dataset
    df = pd.read_json(file_path)

    # For each chat, we only look at the user's last turn.
    df['chat'] = df['chat'].apply(lambda c: c.split('\n')[-1])
    # Drop empty strings (assign back: inplace replace on a single column is unreliable in newer pandas)
    df['chat'] = df['chat'].replace('', np.nan)
    df.dropna(subset=['chat'], inplace=True)
    df['chat'] = df['chat'].apply(lambda c: c.replace('User: ', ''))

    # Clean the labels. Sometimes GPT's output contained more than just "1" or "0".
    df['fetch_context'] = df['fetch_context'].apply(lambda x: re.sub(r'[^01]', '', str(x)))
    df['fetch_context'] = df['fetch_context'].astype(int)

    # ... and merge it with the chit-chat dataset
    result = pd.concat([df, chit_chat_df])

    # A funny phenomenon showed up in the first training rounds:
    # the model relied very heavily on ! . ? at the end of each sentence; the ending pretty much
    # decided whether context was needed or not, which is not desired. Users often enough
    # forget proper sentence endings and I don't want the model to rely on them, so I delete them.
    result['chat'] = result['chat'].apply(lambda c: c[:-1] if c.endswith(('.', '?', '!')) else c)

    print('Total dataset length: ' + str(len(result)))
    print('Chats that would not require context fetching: ' + str(len(result[result['fetch_context'] == 0])))
    print('Chats that would require context fetching: ' + str(len(result[result['fetch_context'] == 1])))
    display(result.sample(15))

    texts = result['chat'].tolist()
    labels = result['fetch_context'].tolist()
    return texts, labels

In [261]:
path = '~/home/biofid/BioFIDPortal/rag/src/cBERT/data/context_chats_10k.json'
texts, labels = load_dataset(path)

Total dataset length: 34984
Chats that would not require context fetching: 11637
Chats that would require context fetching: 23347
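Two details of the cleaning in `load_dataset` may be worth tightening. The punctuation step removes only the last character, so "Really?!" keeps its "?". The label step deletes every character except 0 and 1, so an answer such as "Answer: 1, not 0" would become 10. A minimal sketch of stricter variants; `strip_final_punctuation` and `parse_judge_label` are hypothetical helpers, and taking the first 0/1 digit is an assumption about how the judge phrases its answers.

```python
import re
from typing import Optional


def strip_final_punctuation(text: str) -> str:
    """Remove every trailing '.', '?' or '!' (and trailing spaces), not just the last one."""
    return text.rstrip().rstrip(".?!").rstrip()


def parse_judge_label(raw: object) -> Optional[int]:
    """Return the first 0/1 digit in the judge's answer, or None if there is none."""
    match = re.search(r"[01]", str(raw))
    return int(match.group()) if match else None


print(strip_final_punctuation("Really?!"))    # Really
print(strip_final_punctuation("Hmm... "))     # Hmm
print(parse_judge_label("Answer: 1, not 0"))  # 1
print(parse_judge_label("no digit here"))     # None
```

Rows where `parse_judge_label` returns `None` could then be dropped instead of failing the `astype(int)` conversion.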


|   | chat_id | topic | chat | fetch_context |
|---|---|---|---|---|
| 6737 | 68980e39-5c7f-4591-a4d3-e49ae6dbfc1b | needle | Hi! Can you tell me about the different types ... | 1 |
| 12182 | 7312ee30-2c84-448c-ba41-f6649eb91de0 | normal | Hi there! I'm curious about the concept of "no... | 1 |
| 9120 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat | I'm planning to end it all today | 0 |
| 7464 | 0c84b3e9-78d0-4760-ad71-3784957d6f3f | ancient | Hi! I'm interested in learning more about anci... | 1 |
| 3248 | 4d371a4e-eacb-4fef-9e3c-bec25693d38f | wall | What else would you like to know about the Gre... | 1 |
| 898 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat | You seem really upbeat | 0 |
| 21921 | f9f08787-ea2c-4d94-9b5f-e125bb207804 | modern | That's a great summary of Impressionism! Do yo... | 0 |
| 16727 | f699f890-8c0c-429c-b5f4-b40d3fc0a88c | clear | Hi there! I'm curious about the concept of "cl... | 1 |
| 27029 | d4ba1061-8953-4e5b-a4b9-af965708b939 | brake | Which type of brake is more common in modern cars | 1 |
| 5683 | 51e591fa-6874-4188-95d4-a10d5781040e | chit-chat | Bullseye | 0 |
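The labels are imbalanced roughly 1:2 (11,637 chats without vs. 23,347 with context fetching). One common, optional remedy is to weight the loss per class. A minimal sketch of the "balanced" rule n_samples / (n_classes * count_c), the same formula scikit-learn uses for `class_weight='balanced'`; `balanced_class_weights` is a hypothetical helper, and whether weighting actually helps this model is untested here. The resulting weights could be passed to PyTorch as `nn.CrossEntropyLoss(weight=torch.tensor([w0, w1]))`.

```python
from collections import Counter
from typing import Dict, List


def balanced_class_weights(labels: List[int]) -> Dict[int, float]:
    """n_samples / (n_classes * count_c) for every class c, sorted by class id."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


# Class counts as printed by load_dataset above (synthetic list, not the real label order).
example_labels = [0] * 11637 + [1] * 23347
weights = balanced_class_weights(example_labels)
print({c: round(w, 3) for c, w in weights.items()})
# {0: 1.503, 1: 0.749}
```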


**Now that we have our dataset, let's create our own Dataset class and begin with the model and training.**

In [262]:
class TextClassificationDataset(Dataset):

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Tokenize one example, padded/truncated to a fixed length so batches stack cleanly
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        # The tokenizer returns tensors of shape (1, max_length); flatten them to (max_length,)
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label),
        }

# cBERT

Create the BERT classifier called "cBERT". For that, we create a class wrapper around BERT.

In [263]:
class cBERT(nn.Module):

    def __init__(self, bert_model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        # Linear classification head on top of BERT's pooled [CLS] representation
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits
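For reference, the forward pass of `cBERT` can be written out as follows. In BERT, `pooler_output` is the final hidden state of the first ([CLS]) token passed through a dense layer with a tanh activation; cBERT then applies dropout (p = 0.1) and the linear layer `fc`, with hidden size d = 768 for `bert-base-multilingual-cased`. The softmax is not part of the model: `nn.CrossEntropyLoss` in the training loop applies it internally, and `torch.max` at evaluation time simply picks the larger logit.

$$
\mathbf{p} = \tanh\!\left(W_p\,\mathbf{h}_{[\mathrm{CLS}]} + \mathbf{b}_p\right), \qquad
\mathbf{z} = W_c\,\mathrm{Dropout}_{0.1}(\mathbf{p}) + \mathbf{b}_c, \qquad
W_p \in \mathbb{R}^{d \times d},\; W_c \in \mathbb{R}^{2 \times d}
$$

$$
P(y = c \mid x) = \frac{e^{z_c}}{e^{z_0} + e^{z_1}}, \qquad
\mathcal{L}(x, y^{*}) = -\log P(y = y^{*} \mid x)
$$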

# Train

We need a training loop and an evaluation function. The model is saved once after training.

In [264]:
def train(model, data_loader, optimizer, scheduler, device):
    model.train()
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [265]:
def evaluate(model, data_loader, device):
    model.eval()
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())
    return accuracy_score(actual_labels, predictions), classification_report(actual_labels, predictions)

The prediction method, adapted from a sentiment classifier to our labels: class 1 means context should be fetched, class 0 means it should not.

In [266]:
def predict_context_needed(text, model, tokenizer, device, max_length=128):
    model.eval()
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
    
    print(preds.item())
    return "context_needed" if preds.item() == 1 else "context_not_needed"
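`predict_context_needed` takes the argmax of the two logits, which is the same as fetching context whenever the softmax probability of class 1 exceeds 0.5. If missing context turns out to be costlier than an unnecessary lookup (or vice versa), the threshold could be tuned on the validation set. A minimal sketch; `context_decision` is a hypothetical helper that takes the two logits, e.g. `model(input_ids=..., attention_mask=...)[0].tolist()`.

```python
import math
from typing import Sequence


def context_decision(logits: Sequence[float], threshold: float = 0.5) -> str:
    """Decide from the two logits [z_0, z_1] using a tunable threshold on P(fetch)."""
    z0, z1 = logits
    # Two-class softmax: P(class 1) = sigmoid(z1 - z0), written via tanh to avoid overflow.
    p_fetch = 0.5 * (1.0 + math.tanh((z1 - z0) / 2.0))
    # Strict '>' so that a tie goes to class 0, matching torch.max's first-index tie-break.
    return "context_needed" if p_fetch > threshold else "context_not_needed"


print(context_decision([0.2, 0.1]))                 # context_not_needed
print(context_decision([0.2, 0.1], threshold=0.4))  # context_needed
```

Lowering the threshold makes the bot fetch context more often, raising recall for class 1 at the expense of recall for class 0.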

Set the hyperparameters, split the data and prepare the training process.

In [267]:
# Set up parameters
bert_model_name = 'google-bert/bert-base-multilingual-cased' # We want to support multiple languages 
num_classes = 2
max_length = 128
batch_size = 16
num_epochs = 4
learning_rate = 2e-5

In [268]:
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)
print(len(train_texts))
print(len(val_texts))
print(len(train_labels))
print(len(val_labels))


27987
6997
27987
6997


Load the corresponding tokenizer and wrap both splits in datasets and data loaders.

In [269]:
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

Load in the model and put it onto the correct device.

In [270]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = cBERT(bert_model_name, num_classes).to(device)

In [271]:
optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
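Because `scheduler.step()` is called after every batch, the learning rate changes per step. A pure-Python sketch of the multiplier that `get_linear_schedule_with_warmup` applies (linear warmup, then linear decay to zero); `linear_warmup_decay` is a re-implementation for illustration, not the library function itself. With 27,987 training examples and a batch size of 16 there are 1,750 batches per epoch, so 7,000 steps over 4 epochs, and since `num_warmup_steps=0` the rate starts at 2e-5 and falls linearly to 0.

```python
def linear_warmup_decay(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """LR multiplier: linear warmup to 1, then linear decay to 0 at num_training_steps."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


# With the settings above: ceil(27987 / 16) = 1750 batches per epoch, 4 epochs.
batches_per_epoch = -(-27987 // 16)
total = batches_per_epoch * 4
for step in (0, total // 2, total):
    print(step, 2e-5 * linear_warmup_decay(step, 0, total))
# 0 2e-05
# 3500 1e-05
# 7000 0.0
```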



**The Training Loop**

In [272]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, scheduler, device)
    accuracy, report = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(report)

Epoch 1/4
Validation Accuracy: 0.9495
              precision    recall  f1-score   support

           0       0.97      0.88      0.92      2371
           1       0.94      0.98      0.96      4626

    accuracy                           0.95      6997
   macro avg       0.95      0.93      0.94      6997
weighted avg       0.95      0.95      0.95      6997

Epoch 2/4
Validation Accuracy: 0.9491
              precision    recall  f1-score   support

           0       0.98      0.87      0.92      2371
           1       0.94      0.99      0.96      4626

    accuracy                           0.95      6997
   macro avg       0.96      0.93      0.94      6997
weighted avg       0.95      0.95      0.95      6997

Epoch 3/4
Validation Accuracy: 0.9480
              precision    recall  f1-score   support

           0       0.95      0.89      0.92      2371
           1       0.95      0.98      0.96      4626

    accuracy                           0.95      6997
   macro avg  

Save the model.

In [273]:
model_path = "rag/src/cBERT/models/cBERT.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)  # make sure the target folder exists
torch.save(model.state_dict(), model_path)

# Inference/Prediction

Use the newly trained model for prediction/inference.

In [281]:
test_text = "What does that mean"
context_needed = predict_context_needed(test_text, model, tokenizer, device)
print(f"Predicted context: {context_needed}")

0
Predicted context: context_not_needed
