<a href="https://colab.research.google.com/github/williamconvertino/RAG-BERT-GPT2/blob/main/RAG_with_BERT_and_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**RAG implementation using BERT and GPT2**

**Packages**

In [None]:
!pip3 install -q -U tensorflow==2.15.0

[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-text 2.16.1 requires tensorflow<2.17,>=2.16.1; platform_machine != "arm64" or platform_system != "Darwin", but you have tensorflow 2.15.0 which is incompatible.[0m[31m
[0m

In [None]:
!pip3 install -q -U tensorflow-text==2.15.0


In [None]:
! pip3 install -q pymilvus

In [None]:
import numpy as np
import os
import shutil
import re
import time
import csv

In [None]:
from google.colab import drive
from google.colab import auth
from google.colab import files

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras

In [None]:
from transformers import GPT2Tokenizer, pipeline

In [None]:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility

**Model URLs**

In [None]:
BERT_ENCODER_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"
BERT_PREPROCESSOR_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

**Google Drive**

In [None]:
drive.mount('/content/drive/')
HOME_DIR = '/content/drive/My Drive/RAG_BERT_GPT'
os.chdir(HOME_DIR)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


**Milvus Setup**

In [None]:
milvus_uri=open('milvus_uri.txt').read().strip()
token=open('milvus_api_key.txt').read().strip()
connections.connect("default", uri=milvus_uri, token=token)

In [None]:
SEARCH_PARAMS = {
    'metric_type': "COSINE",
    'index_type': "HNSW",
    'params': {
        "M": 32,
        "efConstruction": 64
    },
    'auto_tune': True
  }
COLLECTION_NAME = 'document_embeddings'
DOCUMENT_EMBEDDING_COLLECTION = Collection(COLLECTION_NAME)
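The collection is searched with the `COSINE` metric through an HNSW index. HNSW is an approximate nearest-neighbour structure, so it is useful to have an exact reference ranking to compare against. The sketch below computes the exact top-k by cosine similarity in plain Python. `cosine_similarity` and `exact_top_k` are hypothetical helpers and are not part of pymilvus. On a small collection, the results HNSW returns should usually match this ranking.

```python
import heapq
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of a and b divided by the product of their L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Treat zero vectors as dissimilar to everything
        return 0.0
    return dot / (norm_a * norm_b)


def exact_top_k(query: Sequence[float],
                vectors: Dict[str, Sequence[float]],
                k: int) -> List[Tuple[str, float]]:
    """Score every stored vector against the query and keep the k most similar."""
    scored = ((doc_id, cosine_similarity(query, vec)) for doc_id, vec in vectors.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


if __name__ == "__main__":
    # Toy 3-d "embeddings" standing in for the 768-d BERT vectors
    store = {
        "A_00000000": [1.0, 0.0, 0.0],
        "B_00000000": [0.7, 0.7, 0.0],
        "C_00000000": [0.0, 0.0, 1.0],
    }
    for doc_id, score in exact_top_k([1.0, 0.1, 0.0], store, k=2):
        print(doc_id, round(score, 3))
```

The exact search compares the query against every stored vector. An HNSW index trades a little recall for avoiding that linear scan.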

**Database Initialization**

In [None]:
DOCUMENT_DIRECTORY = os.path.join(HOME_DIR, 'news_database')

In [None]:
PUNC_REGEX = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s')
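`PUNC_REGEX` splits on a single whitespace character that follows `.`, `?` or `!`. It does not split when the period closes an abbreviation shaped like `p.m.` or a capitalised title like `Mr.`. The sketch below wraps the same pattern in a small hypothetical `split_sentences` helper and runs it on two inputs. The second input shows a limitation: the title rule is case-sensitive, so lowercased text like the BBC articles in this notebook still splits after `mr.`.

```python
import re

# Same pattern as PUNC_REGEX above:
#   (?<!\w\.\w.)     not after an abbreviation shaped like "p.m."
#   (?<![A-Z][a-z]\.) not after a capitalised title like "Mr." or "Dr."
#   (?<=\.|\?|!)     only after sentence-ending punctuation
#   \s               the single whitespace character that is split on (and consumed)
PUNC_REGEX = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s')


def split_sentences(text):
    """Split text into sentences, dropping empty or whitespace-only pieces."""
    return [s for s in PUNC_REGEX.split(text) if s.strip()]


if __name__ == "__main__":
    print(split_sentences("Mr. Smith went to Washington. He arrived at 5 p.m. on Monday! Was it late? No."))
    # Lowercase titles are not protected, so "mr." becomes its own sentence
    print(split_sentences("mr. smith left. he was late."))
```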

In [None]:
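def download_dataset():
  # Path to the Kaggle API token stored in Google Drive
  KAGGLE_KEY_PATH = os.path.join(HOME_DIR, 'kaggle.json')

  !pip3 install -q kaggle
  !mkdir -p ~/.kaggle
  !cp "{KAGGLE_KEY_PATH}" ~/.kaggle/
  !chmod 600 ~/.kaggle/kaggle.json
  # Note: build_database() reads /content/bbc-text.csv, while the archive below is
  # unzipped into the current working directory (HOME_DIR). Check that the CSV
  # build_database() expects actually exists at that path before building.
  !kaggle datasets download -d jeet2016/us-financial-news-articles
  !unzip us-financial-news-articles.zip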

In [None]:
def build_schema():
  # Start from an empty collection on every rebuild
  if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

  # IDs look like 'A_00000142' (10 characters); embeddings are 768-d BERT vectors
  id_field = FieldSchema(name='id', dtype=DataType.VARCHAR, max_length=10, is_primary=True)
  embedding_field = FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=768)

  schema = CollectionSchema(fields=[id_field, embedding_field], description="BERT Embeddings for documents")

  collection = Collection(COLLECTION_NAME, schema)
  collection.create_index(field_name='embedding', index_params=SEARCH_PARAMS)
  collection.load()
  return collection

In [None]:
def build_database(max_chunk_size=500, limit=2000):
  # Rebind the module-level collection so later cells use the rebuilt one
  global DOCUMENT_EMBEDDING_COLLECTION
  DOCUMENT_EMBEDDING_COLLECTION = build_schema()

  BERT_preprocessor = hub.load(BERT_PREPROCESSOR_URL)
  BERT_encoder = hub.load(BERT_ENCODER_URL)

  csv_path = '/content/bbc-text.csv'

  # Spread chunk files round-robin across lettered subdirectories so that no
  # single Google Drive folder has to hold every file
  file_letters = ['A', 'B', 'C', 'D', 'E']
  file_letter_index = [0] * len(file_letters)
  file_index = 0
  file_count = 0

  # Remove chunk files left over from a previous build
  for file_letter in file_letters:
    dir_path = os.path.join(DOCUMENT_DIRECTORY, file_letter)
    if os.path.exists(dir_path):
      shutil.rmtree(dir_path)

  with open(csv_path, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
      article = row['text']

      if len(article) <= max_chunk_size:
        chunks = [article]
      else:
        # Greedily pack whole sentences into chunks of at most max_chunk_size characters
        sentences = [x for x in PUNC_REGEX.split(article) if x.strip()]
        chunks = []
        current_chunk = ''
        for sentence in sentences:
          # Re-insert the space that the split consumed
          candidate = f'{current_chunk} {sentence}' if current_chunk else sentence
          if len(candidate) <= max_chunk_size:
            current_chunk = candidate
          else:
            if current_chunk:
              chunks.append(current_chunk)
            current_chunk = sentence
        if current_chunk:
          chunks.append(current_chunk)

      for chunk in chunks:
        # Encode exactly as BERTDocumentRetrieval.encode_text does: a batch of one, pooled output
        document_embedding = BERT_encoder(BERT_preprocessor([chunk]))['pooled_output'].numpy()[0]

        doc_path = os.path.join(DOCUMENT_DIRECTORY, file_letters[file_index])
        doc_filename = file_letters[file_index] + '_' + str(file_letter_index[file_index]).zfill(8)
        file_letter_index[file_index] += 1

        DOCUMENT_EMBEDDING_COLLECTION.insert([[doc_filename], [document_embedding]])

        os.makedirs(doc_path, exist_ok=True)
        with open(os.path.join(doc_path, doc_filename + '.txt'), 'w') as doc_file:
          doc_file.write(chunk)

        print(f'Saved {doc_filename} to {doc_path}')
        file_count += 1

        if file_count >= limit:
          print("Reached limit.")
          DOCUMENT_EMBEDDING_COLLECTION.flush()
          return

      # Move to the next lettered subdirectory for the next article
      file_index = (file_index + 1) % len(file_letters)

  DOCUMENT_EMBEDDING_COLLECTION.flush()
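The chunking step inside `build_database` can be pulled out into a pure function and tested without BERT or Milvus. The sketch below is a hypothetical `chunk_text` helper that assumes the same greedy strategy. Whole sentences are appended to the current chunk until the next one would push it past `max_chunk_size` characters, with a single space re-inserted between sentences. As in the notebook, a sentence longer than the limit is kept intact rather than cut.

```python
import re
from typing import List

PUNC_REGEX = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s')


def chunk_text(text: str, max_chunk_size: int = 500) -> List[str]:
    """Greedily pack whole sentences into chunks of at most max_chunk_size characters.

    A single sentence longer than max_chunk_size becomes its own (oversized) chunk.
    """
    if len(text) <= max_chunk_size:
        return [text]

    sentences = [s for s in PUNC_REGEX.split(text) if s.strip()]
    chunks: List[str] = []
    current = ''
    for sentence in sentences:
        # Re-insert the single space that the split consumed
        candidate = f'{current} {sentence}' if current else sentence
        if len(candidate) <= max_chunk_size:
            current = candidate
        else:
            # Never emit an empty chunk when the first sentence is already too long
            if current:
                chunks.append(current)
            current = sentence
    if current:
        chunks.append(current)
    return chunks


if __name__ == "__main__":
    article = "First sentence here. Second one is a bit longer. Third. Fourth sentence ends it."
    for chunk in chunk_text(article, max_chunk_size=40):
        print(repr(chunk))
```

Keeping sentence boundaries intact means BERT never encodes a half sentence. The cost is that chunk sizes vary, and a single very long sentence can still exceed the limit.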

**BERT Document Retrieval**

In [None]:
class BERTDocumentRetrieval:

  def __init__(self, collection=DOCUMENT_EMBEDDING_COLLECTION, search_params=SEARCH_PARAMS, doc_directory=DOCUMENT_DIRECTORY):
    self.collection = collection
    self.doc_directory = doc_directory
    self.search_params = search_params
    self.BERT_preprocessor = hub.load(BERT_PREPROCESSOR_URL)
    self.BERT_encoder = hub.load(BERT_ENCODER_URL)

  def encode_text(self, text):
    return self.BERT_encoder(self.BERT_preprocessor([text]))['pooled_output'].numpy()[0]

  def get_doc_content(self, doc_id):
    dir_path = self.doc_directory
    letter_id = doc_id.split('_')[0]
    file_path = os.path.join(dir_path, letter_id, doc_id + '.txt')

    with open(file_path, 'r') as file:
      return file.read()

  def get_k_nearest_docs(self, query, k=10, verbose=False):
    query_embedding = self.encode_text(query)
    self.collection.load()
    results = self.collection.search(anns_field='embedding', data=[query_embedding], limit=k, param=self.search_params)[0]

    content = []

    for doc in results:
      doc_id = doc.id
      doc_content = self.get_doc_content(doc_id)
      content.append(doc_content)
      if verbose:
        print('='*10)
        print(doc_id)
        print(doc_content)

    return content
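`build_database` names each chunk `<letter>_<8-digit counter>` and writes it under a subdirectory named after the letter. `get_doc_content` reverses this by taking the text before the underscore. The hypothetical helpers below (`make_doc_id` and `doc_id_to_path` are illustrative names, not from the notebook) make that round trip explicit. They also show why the Milvus `id` field's `max_length=10` is exactly enough: one letter, one underscore and eight digits.

```python
import os


def make_doc_id(letter: str, counter: int) -> str:
    """Build an ID such as 'A_00000142'; same result as letter + '_' + str(counter).zfill(8)."""
    return f"{letter}_{counter:08d}"


def doc_id_to_path(doc_directory: str, doc_id: str) -> str:
    """Map an ID back to <doc_directory>/<letter>/<doc_id>.txt, mirroring get_doc_content."""
    letter = doc_id.split('_')[0]
    return os.path.join(doc_directory, letter, doc_id + '.txt')


if __name__ == "__main__":
    doc_id = make_doc_id('A', 142)
    print(doc_id, len(doc_id))
    print(doc_id_to_path('news_database', doc_id))
```

A counter above 99,999,999 would produce an 11-character ID and overflow the VARCHAR field. With the default `limit=2000` this is far out of reach.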

**RAG with GPT**

In [None]:
class GPTRAG:
  def __init__(self, document_retrieval):
    self.dr = document_retrieval
    self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    self.generator = pipeline('text-generation', model='gpt2')

  def generate_text(self, prompt, max_new_tokens=30, num_return_sequences=5, verbose=False):

    if verbose:
      print('='*10)
      print("ORIGINAL PROMPT:")
      print(prompt)
      responses = self.generator(prompt, max_new_tokens=max_new_tokens, num_return_sequences=num_return_sequences)
      responses = [response['generated_text'] for response in responses]
      print('='*10)
      print("DEFAULT GPT RESPONSES:")
      for response in responses:
        print(response)

    documents = self.dr.get_k_nearest_docs(prompt, k=10)

    MAX_TOKENS = 512
    modified_prompt = prompt

    for doc in documents:
      num_doc_tokens = len(self.tokenizer(doc + ' ' + modified_prompt)['input_ids'])
      if num_doc_tokens >= MAX_TOKENS:
        break
      else:
        modified_prompt = doc + ' ' + modified_prompt

    if verbose:
      print('='*10)
      print("RAG PROMPT:")
      print(modified_prompt)

    responses = self.generator(modified_prompt, max_new_tokens=max_new_tokens, num_return_sequences=num_return_sequences)

    # Drop the prepended documents so each response starts at the original prompt
    responses = [response['generated_text'][len(modified_prompt) - len(prompt):] for response in responses]

    if verbose:
      print('='*10)
      print("RAG RESPONSES:")
      for response in responses:
        print(response)

    return responses
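`generate_text` builds the RAG prompt by prepending retrieved documents, most similar first, until adding the next one would reach `MAX_TOKENS`. Each document is prepended, so the most similar one ends up directly before the prompt. After generation, the method slices off the first `len(modified_prompt) - len(prompt)` characters so that each response begins at the original prompt. The sketch below reproduces both steps. It uses a hypothetical whitespace token counter in place of `GPT2Tokenizer` (real GPT-2 token counts will differ) and a fake generation string in place of the pipeline output.

```python
from typing import Callable, List


def whitespace_token_count(text: str) -> int:
    """Stand-in for len(tokenizer(text)['input_ids']); NOT GPT-2's real tokenization."""
    return len(text.split())


def pack_prompt(prompt: str, documents: List[str], max_tokens: int,
                count_tokens: Callable[[str], int] = whitespace_token_count) -> str:
    """Prepend documents (most relevant first) until the next one would reach max_tokens."""
    packed = prompt
    for doc in documents:
        candidate = doc + ' ' + packed
        if count_tokens(candidate) >= max_tokens:
            # Like the notebook, stop at the first document that does not fit
            break
        packed = candidate
    return packed


def strip_context(generated: str, packed_prompt: str, prompt: str) -> str:
    """Drop the prepended context so the text starts at the original prompt."""
    return generated[len(packed_prompt) - len(prompt):]


if __name__ == "__main__":
    prompt = "the crew tests the world's"
    docs = ["most relevant doc", "second doc", "a third document that is too long"]
    packed = pack_prompt(prompt, docs, max_tokens=12)
    print(packed)
    fake_output = packed + " first biometric id cards"
    print(strip_context(fake_output, packed, prompt))
```

Because the loop breaks at the first document that does not fit, a long document can keep later, shorter ones out of the prompt. The 512-token budget stays well inside GPT-2's 1024-token context, which leaves room for `max_new_tokens`.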

**Example Usage**

In [None]:
doc_retrieval = BERTDocumentRetrieval()

In [None]:
prompt = "seamen on the luxury cruise liner crystal harmony test a new technology. Holidaymakers enjoy balmy breezes as their crew tests the world's"

In [None]:
docs = doc_retrieval.get_k_nearest_docs(prompt, k=10, verbose=True)

A_00000142
seamen sail into biometric future the luxury cruise liner crystal harmony  currently in the gulf of mexico  is the unlikely setting for tests of biometric technology. as holidaymakers enjoy balmy breezes  their ship s crew is testing prototype versions of the world s first internationally issued biometric id cards  the seafarer s equivalent of a passport.
B_00000190
the  ticking budget  facing the us the budget proposals laid out by the administration of us president george w bush are highly controversial.the washington-based economic policy institute  which tends to be critical of the president  looks at possible fault lines.us politicians and citizens of all political persuasions are in for a dose of shock therapy.without major changes in current policies and political prejudices  the federal budget simply cannot hold together.
A_00000075
telegraph newspapers axe 90 jobs the daily and sunday telegraph newspapers are axing 90 journalist jobs - 17% of their editorial staff. 

In [None]:
rag = GPTRAG(doc_retrieval)

In [None]:
responses = rag.generate_text(prompt, verbose=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


ORIGINAL PROMPT:
seamen on the luxury cruise liner crystal harmony test a new technology. Holidaymakers enjoy balmy breezes as their crew tests the world's
DEFAULT GPT RESPONSES:
seamen on the luxury cruise liner crystal harmony test a new technology. Holidaymakers enjoy balmy breezes as their crew tests the world's most popular brands before opening the doors for the first time ever to Disney World to celebrate 100 years of the classic theme for a limited period. Disney,
seamen on the luxury cruise liner crystal harmony test a new technology. Holidaymakers enjoy balmy breezes as their crew tests the world's newest luxury luxury liner by the name of the Star Wars Star Wars. From a new starfighter to a new rocket launcher, the Star Wars starfighter
seamen on the luxury cruise liner crystal harmony test a new technology. Holidaymakers enjoy balmy breezes as their crew tests the world's signature holiday theme. The holiday mood continues with the first-half of the season.

1.10.12 The fin

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


RAG PROMPT:
boeing unveils new 777 aircraft us aircraft firm boeing has unveiled its new long-distance 777 plane  as it tries to regain its position as the industry s leading manufacturer. the 777-200lr will be capable of flying almost 11 000 miles non-stop  linking cities such as london and sydney.boeing  in contrast to european rival airbus  hopes airlines will want to fly smaller aircraft over longer distances. apple has sold more than six million ipods since the gadget was launched and has an 87% share of the market for portable digital music players  market research firm npd group has reported. more than 200 million songs have been sold by the itunes music store since it was launched. telegraph newspapers axe 90 jobs the daily and sunday telegraph newspapers are axing 90 journalist jobs - 17% of their editorial staff. the telegraph group says the cuts are needed to fund an £150m investment in new printing facilities.journalists at the firm met on friday afternoon to discuss how to

In [None]:
responses = rag.generate_text("The economy in the UK is ", verbose=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


ORIGINAL PROMPT:
The economy in the UK is 
DEFAULT GPT RESPONSES:
The economy in the UK is  growing faster than any other aspect of the global economy.
This growth is clearly reflected in the value of the pound, which rose 0.2
The economy in the UK is  still a large part of my view as a person as there are still people out there that say that the economy will be better as a result of
The economy in the UK is  still strong, albeit  a lot weaker than after the Great Recession. There is still an unemployment rate that's higher than  we would like for
The economy in the UK is  weak, and we are seeing growth slows in Europe. It's time to start cutting down on wasteful and unsustainable spending - something that will help us
The economy in the UK is  still quite small compared to other emerging economies after a huge loss of £50bn in GDP last year.
A new study by economists at the


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


RAG PROMPT:
he later secured the film rights from paramount  enabling them to use the title it s a wonderful life. under the us foreign corrupt practices act  it is a crime for american firms to bribe foreign officials. there have been rumours that the deal could be in trouble because us government agencies fear it could offer china opportunities for industrial espionage.the reports of the possibility of an investigation into the risk sent lenovo s shares up 6% in late january. the payments are shown in bands of up £5 000  making it difficult to calculate the exact earnings. but that is the kind of problem most gadget fans can live with. if labour voters  stayed at home  in marginal seats they could see tory leader michael howard  coming in the back door to number 10 with the tradesman s key to number 10  getting into power   he added. in the album charts  athlete s latest offering tourist claimed the top spot  toppling the chemical brother s push the button down  which fell to number 