# Chapter 5 (Cohere)

<a target="_blank" href="https://colab.research.google.com/github/wandb/edu/blob/rag-irl/rag-advanced/notebooks/Chapter05Cohere.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

<!--- @wandbcode{rag-course-cohere} -->

In [None]:
# !pip install -qqq markdown pymdown-extensions beautifulsoup4 wandb tiktoken blingfire numpy cohere python-dotenv scipy pandas weave

## Data Loading

In [None]:
import json
import pathlib
from datetime import datetime

import tiktoken

import wandb

# W&B entity/project that every ingestion run and artifact below is logged to.
WANDB_ENTITY = "wandbot"
WANDB_PROJECT = "advanced_rag"

# Opt in to the wandb "core" backend.
# NOTE(review): newer wandb releases may make this the default or deprecate the call — confirm for the installed version.
wandb.require("core")

In [None]:
# Start the W&B run that records the raw-document artifact (closed with run.finish() after logging).
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type="data_loading", group="ingestion")

In [None]:
# Recursively collect every markdown file in the docs snapshot, sorted for a stable order.
docs_dir = pathlib.Path("../data/wandb_docs_06_24")
docs_files = sorted(docs_dir.rglob("*.md"))

print(f"Number of files: {len(docs_files)}\n")
print("First 5 files:\n{files}".format(files='\n'.join(map(str, docs_files[:5]))))

In [None]:
# cl100k_base encoding; used for every token count in this notebook and for its special-token list.
tokenizer = tiktoken.get_encoding("cl100k_base")

In [None]:
# Load every markdown file as a record: the raw text plus its relative path and raw token count.
data = []
for file in docs_files:
    # Decode as UTF-8 explicitly; Path.read_text() otherwise uses the platform's
    # locale encoding, which breaks on non-ASCII docs on e.g. Windows.
    content = file.read_text(encoding="utf-8")
    data.append({
        "content": content,
        "metadata": {
            "source": str(file.relative_to(docs_dir)),
            # disallowed_special=() counts literal special-token strings such as
            # "<|endoftext|>" as ordinary text instead of raising ValueError.
            "raw_tokens": len(tokenizer.encode(content, disallowed_special=()))
        }})
data[:2]

In [None]:
# Aggregate the per-document raw token counts; logged as artifact metadata.
total_tokens = sum(record["metadata"]["raw_tokens"] for record in data)
print(f"Total Tokens in dataset: {total_tokens}")

In [None]:
# Log the raw documents as a JSONL "dataset" artifact (one JSON record per line), then close the run.
raw_artifact = wandb.Artifact(name="raw_data", type="dataset",
description="Wandb documentation", metadata={
    "total_files": len(docs_files),
    "date_downloaded": datetime.now().strftime("%Y-%m-%d"),
    "total_tokens": total_tokens
    })
with raw_artifact.new_file("documents.jsonl", mode="w") as f:
    for item in data:
        f.write(json.dumps(item) + "\n")
run.log_artifact(raw_artifact)
run.finish()

## Data Parsing and Pre-processing

1. Load and parse the markdown documents with Python-Markdown and the [pymdownx](https://facelessuser.github.io/pymdown-extensions/) extensions, converting them to HTML.
2. Convert the HTML to plain text using [BeautifulSoup](https://beautiful-soup-4.readthedocs.io/en/latest/).
3. Define a helper that removes tokenizer special tokens and extra whitespace (it is applied to the chunks in the next section).

In [None]:
import re

import markdown
from bs4 import BeautifulSoup

In [None]:
# Start the W&B run for the parsing/pre-processing stage.
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type="data_processing", group="ingestion")

In [None]:
def convert_contents_to_text(contents: str) -> str:
    """Convert a markdown document to plain text.

    The markdown is first rendered to HTML with Python-Markdown and the
    listed extensions. BeautifulSoup then discards the tags and keeps only
    the text nodes.

    Args:
        contents: Raw markdown source.

    Returns:
        The text content of the rendered HTML.
    """
    markdown_extensions = [
        "toc",
        "pymdownx.extra",
        "pymdownx.blocks.admonition",
        "pymdownx.magiclink",
        "pymdownx.blocks.tab",
        "pymdownx.pathconverter",
        "pymdownx.saneheaders",
        "pymdownx.striphtml",
    ]
    rendered_html = markdown.markdown(contents, extensions=markdown_extensions)
    return BeautifulSoup(rendered_html, "html.parser").get_text()


def make_text_tokenization_safe(content: str, special_tokens=None) -> str:
    """Remove tokenizer special-token strings from ``content`` and collapse whitespace.

    Special tokens are removed *before* whitespace is normalized, so a removed
    token (e.g. ``"a <|endoftext|> b"``) never leaves a double space behind.
    Removal repeats until no special token remains, because deleting one
    occurrence can splice its neighbours into a new one
    (``"<|endo<|endoftext|>ftext|>"`` -> ``"<|endoftext|>"``).

    Args:
        content: Text to clean.
        special_tokens: Strings to remove. Defaults to the notebook-level
            ``tokenizer.special_tokens_set``.

    Returns:
        The cleaned text. Every run of whitespace becomes a single space;
        leading/trailing whitespace is collapsed but not stripped.
    """
    if special_tokens is None:
        special_tokens = tokenizer.special_tokens_set
    # Ignore empty strings: "" is "in" every string and would loop forever.
    tokens = [token for token in special_tokens if token]

    while any(token in content for token in tokens):
        for token in tokens:
            content = content.replace(token, "")

    # Normalize whitespace including Space, Tab, Newline, Carriage return, Form feed, Vertical tab
    return re.sub(r'\s+', ' ', content)

In [None]:
# Download the latest raw_data artifact and parse its JSONL file back into a list of records.
raw_artifact = run.use_artifact(f'{WANDB_ENTITY}/{WANDB_PROJECT}/raw_data:latest', type='dataset')
artifact_dir = raw_artifact.download()
raw_data_file = pathlib.Path(f"{artifact_dir}/documents.jsonl")
raw_data = list(map(json.loads, raw_data_file.read_text().splitlines()))
raw_data[:2]

In [None]:
# Convert each raw markdown record to plain text and record its token count.
parsed_data = []

for doc in raw_data:
    parsed_content = convert_contents_to_text(doc["content"])
    parsed_data.append({
        **doc,
        "parsed_content": parsed_content,
        # Build a fresh metadata dict: a shallow doc.copy() would share
        # doc["metadata"] and write parsed_tokens back into raw_data too.
        "metadata": {
            **doc["metadata"],
            # disallowed_special=() counts special-token strings as plain text
            # instead of raising ValueError.
            "parsed_tokens": len(tokenizer.encode(parsed_content, disallowed_special=())),
        },
    })
parsed_data[:2]

In [None]:
# Total parsed-token count across all documents, stored as artifact metadata.
total_parsed_tokens = sum(map(lambda x: x["metadata"]["parsed_tokens"], parsed_data))

# Log the parsed records (raw content + parsed_content + metadata) as a JSONL dataset artifact, then close the run.
preprocessed_artifact = wandb.Artifact(name="preprocessed_data", type="dataset",
description="Preprocessed wandb documentation", metadata={
    "total_files": len(parsed_data),
    "date_preprocessed": datetime.now().strftime("%Y-%m-%d"),
    "total_parsed_tokens": total_parsed_tokens,
    }
)
with preprocessed_artifact.new_file("documents.jsonl", mode="w") as f:
    for item in parsed_data:
        f.write(json.dumps(item) + "\n")
run.log_artifact(preprocessed_artifact)
run.finish()

## Data Chunking

1. First, we split the text into sentences using the [BlingFire](https://github.com/microsoft/BlingFire) library.
2. Then, we group the sentences into chunks that each stay under a maximum number of tokens.

In [None]:
from blingfire import text_to_sentences

In [None]:
# Start the W&B run for the chunking stage.
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type="data_chunking", group="ingestion")

In [None]:
# Download the latest preprocessed_data artifact and parse its JSONL file into records.
preprocessed_artifact = run.use_artifact(f'{WANDB_ENTITY}/{WANDB_PROJECT}/preprocessed_data:latest', type='dataset')
artifact_dir = preprocessed_artifact.download()
preprocessed_data_file = pathlib.Path(f"{artifact_dir}/documents.jsonl")
preprocessed_data = list(map(json.loads, preprocessed_data_file.read_text().splitlines()))
preprocessed_data[:2]

In [None]:
#ref: https://platform.openai.com/docs/tutorials/web-qa-embeddings

# Maximum number of tokens per chunk.
CHUNK_SIZE = 500


def split_into_chunks(text, max_tokens = CHUNK_SIZE):
    """Greedily pack the sentences of ``text`` into chunks of at most ~``max_tokens`` tokens.

    Sentences come from BlingFire's ``text_to_sentences`` (one per line).
    Each sentence is measured as ``" " + sentence``, and one extra token is
    budgeted for the space that joins it to the chunk. A sentence that alone
    exceeds ``max_tokens`` is skipped.

    Args:
        text: Plain text to split.
        max_tokens: Token budget per chunk.

    Returns:
        Chunk strings (sentences joined by single spaces). Empty when ``text``
        contains no non-blank sentence. The final partial chunk is included,
        and no empty chunks are emitted.
    """
    # Drop blank "sentences" (e.g. from empty input) so they neither cost
    # budget nor produce empty chunks.
    sentences = [s for s in text_to_sentences(text).split("\n") if s.strip()]

    chunks = []
    chunk = []
    tokens_so_far = 0

    for sentence in sentences:
        # disallowed_special=() counts special-token strings as plain text
        # instead of raising ValueError.
        n_tokens = len(tokenizer.encode(" " + sentence, disallowed_special=()))

        # Flush the current chunk if this sentence would overflow it. Never
        # flush an empty chunk (that would emit "").
        if chunk and tokens_so_far + n_tokens > max_tokens:
            chunks.append(" ".join(chunk))
            chunk = []
            tokens_so_far = 0

        # Sentences that cannot fit in any chunk are skipped.
        if n_tokens > max_tokens:
            continue

        chunk.append(sentence)
        tokens_so_far += n_tokens + 1

    # Keep the trailing chunk; otherwise the end of every document is lost.
    if chunk:
        chunks.append(" ".join(chunk))

    return chunks

In [None]:
# Split every preprocessed document into token-bounded chunks, keeping its source path.
chunked_data = []
for doc in preprocessed_data:
    chunks = split_into_chunks(doc["parsed_content"])
    for chunk in chunks:
        chunked_data.append(
            {
                "parsed_content" : chunk,
                "metadata": {
                    "source": doc["metadata"]["source"],
                    # disallowed_special=() counts special-token strings as
                    # plain text instead of raising ValueError.
                    "parsed_tokens": len(tokenizer.encode(chunk, disallowed_special=()))
            }})

In [None]:
# Add a tokenizer-safe version of each chunk (special tokens removed, whitespace collapsed) and its token count.
cleaned_data = []

for doc in chunked_data:
    embeddable_content = make_text_tokenization_safe(doc["parsed_content"])
    cleaned_data.append({
        **doc,
        "embeddable_content": embeddable_content,
        # Fresh metadata dict: a shallow doc.copy() would share doc["metadata"]
        # and write embeddable_tokens back into chunked_data as well.
        "metadata": {
            **doc["metadata"],
            "embeddable_tokens": len(tokenizer.encode(embeddable_content)),
        },
    })

cleaned_data[:2]

In [None]:
# Aggregate token counts over all chunks for the artifact metadata.
total_parsed_tokens = sum(map(lambda x: x["metadata"]["parsed_tokens"], cleaned_data))
total_embeddable_tokens = sum(map(lambda x: x["metadata"]["embeddable_tokens"], cleaned_data))

# Log the cleaned chunks as a JSONL dataset artifact, then close the run.
# NOTE(review): "total_files" here is the number of chunks, not source files.
chunked_artifact = wandb.Artifact(name="chunked_data", type="dataset",
description="Chunked wandb documentation", metadata={
    "total_files": len(cleaned_data),
    "date_processed": datetime.now().strftime("%Y-%m-%d"),
    "total_parsed_tokens": total_parsed_tokens,
    "total_embeddable_tokens": total_embeddable_tokens, 
    "chunk_size": CHUNK_SIZE
    }
)
with chunked_artifact.new_file("documents.jsonl", mode="w") as f:
    for item in cleaned_data:
        f.write(json.dumps(item) + "\n")
run.log_artifact(chunked_artifact)
run.finish()

## Data Embedding

1. We use [Cohere](https://cohere.ai/) to embed the chunks of text.
2. We embed each chunk's `embeddable_content`, which was already cleaned of special tokens and extra whitespace during chunking.
3. We embed the text in batches using the Cohere `embed-english-v3.0` model.


In [None]:
import os
from typing import List

import nest_asyncio
from dotenv import load_dotenv

# Patch the event loop so the asyncio.run(...) calls below can run inside the notebook.
nest_asyncio.apply()
import asyncio

# Load environment variables (e.g. CO_API_KEY) from a .env file if one exists.
load_dotenv()

import cohere

In [None]:
# Start the W&B run for the embedding stage.
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type="data_embedding", group="ingestion")

In [None]:
# Download the latest chunked_data artifact and parse its JSONL file into records.
chunked_artifact = run.use_artifact(f'{WANDB_ENTITY}/{WANDB_PROJECT}/chunked_data:latest', type='dataset')
artifact_dir = chunked_artifact.download()
chunked_data_file = pathlib.Path(f"{artifact_dir}/documents.jsonl")
chunked_data = list(map(json.loads, chunked_data_file.read_text().splitlines()))
chunked_data[:2]

In [None]:
# Async Cohere client; the API key comes from the CO_API_KEY environment variable.
co = cohere.AsyncClient(api_key=os.getenv("CO_API_KEY"))

In [None]:
async def embed_batch(texts: List[str]) -> List[List[float]]:
    """Embed one batch of documents with Cohere ``embed-english-v3.0``.

    Args:
        texts: Documents to embed in a single API call.

    Returns:
        ``response.embeddings``: one vector per input text. The annotation
        previously claimed a single ``List[float]``.
        NOTE(review): callers zip this with the inputs, assuming Cohere keeps
        input order — confirm.
    """
    response = await co.embed(
        texts=texts, model="embed-english-v3.0", input_type="search_document"
    )
    return response.embeddings


async def embed_texts(texts: List[str], batch_size=50, max_concurrency: int = 10) -> List[List[float]]:
    """Embed ``texts`` in batches, with a bounded number of requests in flight.

    Args:
        texts: Documents to embed.
        batch_size: Number of texts sent per ``embed_batch`` call.
        max_concurrency: Maximum number of ``embed_batch`` calls awaited at
            once. Starting every batch simultaneously can trip API rate limits
            on large corpora.

    Returns:
        A flat list of embeddings. ``asyncio.gather`` preserves batch order, so
        the list is in the same order as ``texts``.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1 (a zero-sized
            semaphore would block forever).
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _embed_bounded(batch: List[str]) -> List[List[float]]:
        # Wait for a free slot before issuing the request.
        async with semaphore:
            return await embed_batch(batch)

    tasks = [_embed_bounded(texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)]
    results = await asyncio.gather(*tasks)
    return [item for sublist in results for item in sublist]

In [None]:
# Embed the cleaned text of every chunk; embeddings[i] corresponds to chunked_data[i].
embeddings = asyncio.run(embed_texts(list(map(lambda x: x["embeddable_content"], chunked_data))))

In [None]:
# Pair each chunk record with its embedding in a new (shallow-copied) dict.
embedded_data = [
    {**chunk_record, "embedding": vector}
    for chunk_record, vector in zip(chunked_data, embeddings)
]

embedded_data[:2]

In [None]:
# Log the embedded chunks as a JSONL dataset artifact, then close the run.
# NOTE(review): "total_files" counts chunks, and "embedding_dim" is hard-coded;
# it presumably matches embed-english-v3.0 — verify against len(embedded_data[0]["embedding"]).
embedded_artifact = wandb.Artifact(name="embedded_data", type="dataset",
description="Embedded wandb documentation", metadata={
    "total_files": len(embedded_data),
    "date_processed": datetime.now().strftime("%Y-%m-%d"),
    "embedding_model": "cohere-embed-english-v3.0",
    "embedding_dim": 1024,
    "chunk_size": CHUNK_SIZE,
    }
)
with embedded_artifact.new_file("documents.jsonl", mode="w") as f:
    for item in embedded_data:
        f.write(json.dumps(item) + "\n")
run.log_artifact(embedded_artifact)
run.finish()

## RAG

### Retrieval

1. We embed the query with the [Cohere](https://cohere.ai/) `embed-english-v3.0` model.
2. We compute the cosine distance between the query embedding and each chunk embedding to find the most relevant chunks.
3. We return the top-k chunks of text along with their metadata.

In [None]:
import json
import pathlib
import re
from typing import Any, Dict, List

import markdown
import nest_asyncio
import numpy as np
import pandas as pd
import weave
from IPython.display import Markdown
from scipy import spatial

import wandb

# Patch the event loop so asyncio.run(...) can be called inside the notebook.
nest_asyncio.apply()
import asyncio
import os

from dotenv import load_dotenv

# Load environment variables (e.g. CO_API_KEY) from a .env file if one exists.
load_dotenv()

from typing import List

import cohere

In [None]:
# Use the same entity as the ingestion cells ("wandbot") so this section reads
# the embedded_data artifact those cells logged. It previously pointed at
# "parambharat", a different entity.
WANDB_ENTITY = "wandbot"
WANDB_PROJECT = "advanced_rag"

wandb.require("core")

# Start the retrieval run and point Weave tracing at the same entity/project.
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type="data_retrieval", group="retrieval")
_ = weave.init(f"{WANDB_ENTITY}/{WANDB_PROJECT}")

In [None]:
# Download the latest embedded_data artifact and load it into a DataFrame.
embedded_artifact = run.use_artifact(f'{WANDB_ENTITY}/{WANDB_PROJECT}/embedded_data:latest', type='dataset')
artifact_dir = embedded_artifact.download()
embedded_data_file = pathlib.Path(f"{artifact_dir}/documents.jsonl")
embedded_data = list(map(json.loads, embedded_data_file.read_text().splitlines()))

embedded_df = pd.DataFrame(embedded_data)
# JSON stores embeddings as lists; convert them to numpy arrays for the distance computation.
embedded_df["embedding"] = embedded_df["embedding"].map(np.array)
embedded_df.head()

In [None]:
# Async Cohere client used for query embedding, reranking and chat; key from CO_API_KEY.
co = cohere.AsyncClient(api_key=os.getenv("CO_API_KEY"))

In [None]:
@weave.op()
async def retieve_context(question: str, context: List[Dict[str, Any]], top_k: int=10) -> List[Dict[str, Any]]:
    """Return the ``top_k`` context records nearest to ``question`` by cosine distance.

    The question is embedded with Cohere ``embed-english-v3.0``
    (``input_type="search_query"``) and compared with each record's
    ``embedding`` field.

    Args:
        question: The user query.
        context: Chunk records, each with an ``embedding`` vector.
        top_k: Number of records to return.

    Returns:
        Up to ``top_k`` records, sorted by ascending cosine distance. Each has
        an added ``distances`` key.
    """
    candidates = pd.DataFrame(context)
    query_response = await co.embed(texts=[question], model="embed-english-v3.0", input_type="search_query")
    query_matrix = np.array(query_response.embeddings)

    document_matrix = np.vstack(candidates.embedding.tolist())
    # cdist yields a (1, n_documents) matrix for the single query; keep its only row.
    distance_matrix = spatial.distance.cdist(query_matrix, document_matrix, metric='cosine')
    candidates["distances"] = distance_matrix[0]

    nearest = candidates.sort_values("distances").head(top_k)
    return nearest.to_dict(orient="records")

### Re-rank

1. We rerank the retrieved chunks using the [Cohere](https://cohere.ai/) `rerank-english-v3.0` model.
2. We return the top-k reranked chunks along with their metadata.

In [None]:
@weave.op()
async def rerank_context(question: str, context: List[Dict[str, Any]], top_k: int=5) -> List[Dict[str, Any]]:
    """Rerank ``context`` against ``question`` with Cohere ``rerank-english-v3.0``.

    Args:
        question: The user query.
        context: Candidate records; their ``embeddable_content`` text is scored.
        top_k: Number of records to keep (passed as ``top_n``).

    Returns:
        The original records, in the order given by ``response.results``.
    """
    candidate_texts = [record["embeddable_content"] for record in context]
    rerank_response = await co.rerank(query=question, documents=candidate_texts, model="rerank-english-v3.0", top_n=top_k)
    reranked = []
    for result in rerank_response.results:
        # result.index points back into `context`.
        reranked.append(context[result.index])
    return reranked

In [None]:
@weave.op()
async def build_context(question: str, context: List[Dict[str, Any]], max_len: int=4096) -> List[Dict[str, Any]]:
    """Retrieve, rerank and token-budget the chunks passed to the chat model.

    Retrieves 20 candidates by embedding distance and keeps the 10 best after
    reranking. It then takes chunks in reranked order until the cumulative
    ``parsed_tokens + 4`` per chunk would exceed ``max_len``.

    Args:
        question: The user query.
        context: All chunk records to search.
        max_len: Token budget for the selected chunks.

    Returns:
        A list of ``{"text": parsed_content, "source": source_path}`` dicts.
    """
    candidates = await retieve_context(question, context, top_k=20)
    ranked = await rerank_context(question, candidates, top_k=10)

    selected = []
    tokens_used = 0
    for record in ranked:
        # NOTE(review): the +4 per chunk looks like a per-document overhead allowance — confirm intent.
        tokens_used += record["metadata"]['parsed_tokens'] + 4
        if tokens_used > max_len:
            break
        selected.append({"text": record["parsed_content"], "source": record["metadata"]["source"]})

    return selected

### Response Synthesis

1. We use the [Cohere](https://cohere.ai/) `command-r-plus` model to generate an answer to the query, grounded in the retrieved chunks.
2. We render the answer as Markdown, with the cited sources listed as numbered footnotes.

In [None]:
# Preamble passed to co.chat in generate_response: restricts answers to the supplied documents and asks for citations.
# NOTE(review): "using MLA" below is ambiguous (citation style?) — confirm the intended instruction.
SYSTEM_PROMPT = """You are Wandbot, a support expert for Weights & Biases, wandb, and weave.
Your goal is to help users with questions related to the Weights & Biases Platform, providing accurate and helpful responses based solely on the given context.

You will be provided the context you should use to answer the user's question

First, ensure you understand the question and the relevant information in the context. If the question is unclear, prepare to ask for clarification.


Process the question and context as follows:
1. Identify the main topic and any subtopics in the question.
2. Locate relevant information in the context.
3. Formulate a clear, concise answer based only on the provided context.
4. If code snippets are needed, ensure they are derived only from the context and are syntactically correct and functional.
5. Prepare to cite your sources for each piece of information you use.

Format your response in Markdown using MLA but without using headers. Structure your answer as follows:
1. Direct answer to the question
2. Explanation or steps (if applicable)
3. Code snippet (if relevant)
4. Additional information or tips (if appropriate)

For each piece of information you use, add a citation.

If the context doesn't provide sufficient information to answer the question fully or accurately, admit your uncertainty and suggest contacting Weights & Biases support at support@wandb.com or visiting the community forums at https://wandb.me/community.

Remember, you must always provide both an answer (or an admission of uncertainty) and citations in your response. Do not refer to the context directly in your answer; instead, provide the information and cite the source."""

In [None]:
@weave.op()
async def generate_response(question: str, context: List[Dict[str, str]]) -> Dict[str, Any]:
    """Ask Cohere ``command-r-plus`` to answer ``question`` grounded in ``context``.

    Args:
        question: The user query, sent as the chat message.
        context: Documents (``{"text", "source"}`` dicts from ``build_context``)
            passed as ``documents`` so the model can cite them.

    Returns:
        The chat response converted to a dict via ``.dict()``.
    """
    chat_options = dict(
        preamble=SYSTEM_PROMPT,
        message=question,
        model="command-r-plus",
        documents=context,
        temperature=0.1,
        max_tokens=2000,
    )
    chat_response = await co.chat(**chat_options)
    return chat_response.dict()

In [None]:
@weave.op()
def render_html_response(question: str, response: Dict[str, Any])-> str:
    """Render a Cohere chat response as HTML with hoverable citations and a source list.

    Args:
        question: The user's question, converted with ``markdown.markdown``.
        response: The dict returned by ``generate_response``. Reads ``text``,
            ``citations`` (each with ``start``, ``end``, ``document_ids``) and
            ``documents`` (each with ``id`` and ``source``).

    Returns:
        An HTML string containing:
          - the question;
          - the answer, with each cited span wrapped in ``<span class="citation">``
            whose tooltip lists its source numbers;
          - a numbered ``<ol>`` of sources;
          - the tooltip CSS.
        Sources are numbered in the order they are first cited.
    """
    text = response['text']
    documents = {item['id']: item for item in response['documents']}

    # Number sources by first appearance in the answer (ascending start), not in
    # the end-to-start order the insertion loop below visits citations.
    sources_dict = {}
    for citation in sorted(response['citations'], key=lambda x: x['start']):
        for doc_id in citation['document_ids']:
            sources_dict.setdefault(documents[doc_id]['source'], len(sources_dict) + 1)

    # Insert spans from the end of the text backwards so the character offsets
    # of citations not yet processed stay valid.
    for citation in sorted(response['citations'], key=lambda x: x['start'], reverse=True):
        # Several cited chunks can share a source file; list each number once.
        source_numbers = list(dict.fromkeys(
            str(sources_dict[documents[doc_id]['source']]) for doc_id in citation['document_ids']
        ))
        hover_text = f"Sources: [{', '.join(source_numbers)}]"
        cited_text = text[citation['start']:citation['end']]
        html_span = f'<span class="citation" data-tooltip="{hover_text}">{cited_text}</span>'
        text = text[:citation['start']] + html_span + text[citation['end']:]

    # Convert markdown to HTML after processing citations
    text = markdown.markdown(text, extensions=[
            "toc",
            "pymdownx.extra",
            "pymdownx.blocks.admonition",
            "pymdownx.magiclink",
            "pymdownx.blocks.tab",
            "pymdownx.pathconverter",
            "pymdownx.saneheaders",
            "pymdownx.striphtml",
        ],
    )

    # Create the footer with numbered sources
    footer = "<ol>" + "".join(
        f"<li>{source}</li>" for source, _ in sorted(sources_dict.items(), key=lambda x: x[1])
    ) + "</ol>"

    # markdown.markdown(question) already wraps the question in <p>, so it is
    # not wrapped in another <p> (nested paragraphs are invalid HTML).
    html = f"""
    <div class="response">
        <h3>Question:</h3>
        {markdown.markdown(question)}
        <h3>Answer:</h3>
        {text}
        <hr>
        <h4>Sources:</h4>
        {footer}
    </div>

    <style>
        .citation {{
            text-decoration: underline;
            cursor: pointer;
            position: relative;
        }}
        .citation::after {{
            content: attr(data-tooltip);
            position: absolute;
            bottom: 100%;
            left: 50%;
            transform: translateX(-50%);
            background-color: #333;
            color: white;
            padding: 5px;
            border-radius: 3px;
            opacity: 0;
            transition: opacity 0.3s;
            white-space: nowrap;
            pointer-events: none;
        }}
        .citation:hover::after {{
            opacity: 1;
        }}
    </style>
    """
    return html

In [None]:
@weave.op()
def render_markdown_with_footnotes(question: str, response: Dict[str, Any]) -> str:
    """Render a Cohere chat response as Markdown with numbered source footnotes.

    Cited spans outside code are wrapped in ``<ins>`` and followed by a
    ``<sup>[n, ...]</sup>`` marker. A citation that starts inside a
    backtick-delimited code span instead gets its marker right after that span
    (on a new line for triple-backtick spans), so the code is left untouched.

    Args:
        question: The user's question, shown in bold.
        response: The dict returned by ``generate_response``. Reads ``text``,
            ``citations`` (each with ``start``, ``end``, ``document_ids``) and
            ``documents`` (each with ``id`` and ``source``).

    Returns:
        A Markdown string with Question, Answer and Sources sections. Sources
        are numbered in the order they are first cited in the answer.
    """
    text = response['text']
    documents = {item['id']: item for item in response['documents']}

    # Number sources by first appearance in the answer (ascending start).
    sources_dict = {}
    for citation in sorted(response['citations'], key=lambda x: x['start']):
        for doc_id in citation['document_ids']:
            sources_dict.setdefault(documents[doc_id]['source'], len(sources_dict) + 1)

    # Spans delimited by 1-3 backticks, located in the unmodified text.
    code_block_ranges = [
        (m.start(), m.end(), m.group(1))
        for m in re.finditer(r'(`{1,3})[\s\S]*?\1', text)
    ]

    # Process citations from the end backwards. Every insertion happens at or
    # after the current citation's start, so the offsets of the remaining
    # (earlier) citations and code spans stay valid.
    for citation in sorted(response['citations'], key=lambda x: x['start'], reverse=True):
        # Several cited chunks can share a source file; list each number once.
        source_numbers = list(dict.fromkeys(
            str(sources_dict[documents[doc_id]['source']]) for doc_id in citation['document_ids']
        ))
        footnote = f'<sup>[{", ".join(source_numbers)}]</sup>'

        enclosing_span = next(
            (span for span in code_block_ranges if span[0] <= citation['start'] < span[1]),
            None,
        )
        if enclosing_span is not None:
            _, span_end, delim = enclosing_span
            if len(delim) == 3:  # Multiline code block: put the marker on its own line
                footnote = f'\n{footnote}'
            text = text[:span_end] + footnote + text[span_end:]
        else:
            cited_text = text[citation['start']:citation['end']]
            underlined_text = f'<ins>{cited_text}</ins>{footnote}'
            text = text[:citation['start']] + underlined_text + text[citation['end']:]

    # One " - " bullet per source; no empty bullet when nothing was cited.
    sources_section = "".join(
        f'\n - <a name="fn{number}"></a>[{number}] {source}'
        for source, number in sorted(sources_dict.items(), key=lambda x: x[1])
    )
    # Named `rendered` rather than `markdown` to avoid shadowing the markdown module.
    rendered = f"""## Question\n**{question}**\n\n---\n\n## Answer\n{text}\n\n---\n\n## Sources\n{sources_section}"""
    return rendered

In [None]:
@weave.op()
async def answer_question(question: str, context: List[Dict[str, Any]]):
    """Run the full RAG pipeline for ``question`` and return Markdown with footnotes.

    Args:
        question: The user query.
        context: All chunk records to search (with embeddings).

    Returns:
        The Markdown produced by ``render_markdown_with_footnotes``.
    """
    selected_context = await build_context(question, context)
    chat_response = await generate_response(question, selected_context)
    # Returned directly rather than via a local named `markdown`, which would shadow the module.
    return render_markdown_with_footnotes(question, chat_response)

In [None]:
# Answer a sample question over every embedded chunk; asyncio.run works in the notebook because nest_asyncio.apply() was called above.
sample_query="How can I resume my accidentally stopped sweeps?"
markdown_output = asyncio.run(answer_question(sample_query, embedded_df.to_dict(orient="records")))

In [None]:
# Render the Markdown answer in the notebook output.
Markdown(markdown_output)