In [1]:
import os
import logging

from dotenv import load_dotenv, find_dotenv
# Load settings from the .env file located by find_dotenv(); the result is bound
# to `_` because it is not used afterwards.
_ = load_dotenv(find_dotenv())

In [2]:
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

# ------------- Retrieval-Augmented Generation  ------------- #

def get_docs(path="docs/lyrics", glob="*.txt"):
    """
    Loads each matching file into one document, forming the knowledge base.

    :param path: directory to read the custom data from (defaults to the lyrics folder)
    :param glob: filename pattern selecting which files are loaded (defaults to text files)
    :return: docs, i.e. whatever ``DirectoryLoader.load()`` returns for the matched files
    """

    # Every matched file is handed to TextLoader, so one file yields one document
    loader = DirectoryLoader(path, glob, loader_cls=TextLoader)

    return loader.load()

def split_text(docs, chunk_size=2000, chunk_overlap=200):
    """
    Get chunks from docs. A loaded doc may be too long for most models, and even if it
    fits, the model can struggle to find the relevant context, so we split it into chunks.

    :param docs: docs to be split
    :param chunk_size: chunk size handed to the splitter (default 2000, the original value)
    :param chunk_overlap: overlap between chunks handed to the splitter (default 200, the original value)
    :return: chunks
    """

    text_splitter = RecursiveCharacterTextSplitter(  # recommended splitter for generic text
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        add_start_index=True,  # start-index option stays enabled, as before
    )

    return text_splitter.split_documents(docs)

def get_data_store(chunks, model_name='sentence-transformers/all-MiniLM-L6-v2', device='cpu'):
    """
    Store chunks into a db. ChromaDB uses vector embeddings as the key; a new DB is
    created from the given chunks on every call.

    :param chunks: document chunks to embed and store
    :param model_name: embedding model passed to HuggingFaceEmbeddings
                       (defaults to the MiniLM model this notebook used originally)
    :param device: device placed into ``model_kwargs`` for the embedding model;
                   defaults to 'cpu' (pass e.g. 'cuda' to run on a GPU, if one is available)
    :return: database
    """
    # HuggingFace embeddings are used instead of OpenAIEmbeddings() because of the rate limit.
    # NOTE(review): the captured notebook output echoes this constructor line in a warning,
    # presumably a deprecation notice for this import path. Confirm against the installed
    # langchain-community version.
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': device}
    )

    db = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings
    )
    return db

In [3]:
# NOTE(review): `os`, the dotenv import and the load_dotenv() call below repeat what the
# first cell already does; `sys` and `warnings` are not used anywhere in the visible code.
import os, sys, warnings
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Build the retrieval pipeline; each step consumes the previous step's result.
docs = get_docs()           # Load the custom files as documents
chunks = split_text(docs)   # Split the documents into chunks
db = get_data_store(chunks) # Embed the chunks and build the vector store

  embeddings = HuggingFaceEmbeddings( #  embedding=OpenAIEmbeddings() rate limit
  from tqdm.autonotebook import tqdm, trange
