# Synthetic Data Generation

This guide is a quickstart for creating a synthetic question-answering (QA) and Retrieval-Augmented Generation (RAG) dataset from your own PDF document.

## Setup
We recommend using Python 3.11. Make sure the required packages are installed; if they are not, install them with the following commands:

In [None]:
!pip install ipykernel==6.29.5
!pip install langchain-sambanova==0.1.6
!pip install "unstructured[pdf,local-inference]"
!pip install numpy==1.26.4

To use [SambaNova Cloud](https://cloud.sambanova.ai) models, you'll need to set your API key. Run the following code to securely input your [SambaNova Cloud API Key](https://cloud.sambanova.ai/apis).

In [None]:
import getpass
import os
if not os.getenv("SAMBANOVA_API_KEY"):
    os.environ["SAMBANOVA_API_KEY"] = getpass.getpass(
        "Enter your SambaNova Cloud API key: "
    )

Next, initialize the LLM that will generate the QA pairs.

In [None]:
from langchain_sambanova import ChatSambaNovaCloud

# Initialize the LLM and specify the model
llm = ChatSambaNovaCloud(
    model="Meta-Llama-3.1-8B-Instruct",
    temperature=0.01,
    max_tokens=2048
)

## Load data

First, specify the path of the PDF file to process and extract elements from.

In [None]:
# Specify location of the PDF file
filename = "./data/SambaNova_Dataflow.pdf"

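Then, load the data from your source file using the [Unstructured](https://docs.unstructured.io/open-source/introduction/quick-start) library. The `extract_pdf` function partitions the PDF with the high-resolution strategy, infers table structure, and groups the resulting elements into chunks by section title.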

In [None]:
from unstructured.partition.pdf import partition_pdf

def extract_pdf(file_path):
    """Extract text and tables from a PDF file.

    Args:
        file_path (str): Path to the PDF file to be processed.

    Returns:
        List[Element]: Document elements (text chunks, tables, etc.) extracted from the PDF.
    """
    raw_pdf_elements = partition_pdf(
        filename=file_path,
        extract_images_in_pdf=False,  # Images are not used for text-based QA generation
        strategy='hi_res',  # Layout-aware parsing, needed for table structure inference
        hi_res_model_name='yolox',  # Layout detection model used by the hi_res strategy
        infer_table_structure=True,  # True: tables also get an HTML rendering in metadata.text_as_html
        chunking_strategy='by_title',  # Start a new chunk at each section title
        max_characters=4096,  # Hard upper limit on chunk length
        combine_text_under_n_chars=500,  # Combine small sections until they reach this length
    )

    return raw_pdf_elements
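The `chunking_strategy='by_title'` options decide how elements are grouped into chunks. As a rough mental model only, the sketch below assumes a simplified version of that behavior: every title starts a new section, a section shorter than `combine_text_under_n_chars` is merged with the one that follows, and any chunk longer than `max_characters` is split. The function `simple_chunk_by_title` and its `(kind, text)` input format are hypothetical; Unstructured's real implementation handles more cases and should be treated as the reference.

```python
def simple_chunk_by_title(elements, max_characters=4096, combine_text_under_n_chars=500):
    """Hypothetical, simplified model of title-based chunking (not Unstructured's code).

    Args:
        elements: list of (kind, text) tuples, where kind is "Title" or another label.

    Returns:
        list[str]: chunk texts.
    """
    # 1. Group elements into sections: each title starts a new section.
    sections = []
    for kind, text in elements:
        if kind == "Title" or not sections:
            sections.append([])
        sections[-1].append(text)
    section_texts = ["\n\n".join(section) for section in sections]

    # 2. Merge a section into its predecessor while the predecessor is still short.
    merged = []
    for text in section_texts:
        if merged and len(merged[-1]) < combine_text_under_n_chars:
            merged[-1] += "\n\n" + text
        else:
            merged.append(text)

    # 3. Split any oversized chunk into pieces of at most max_characters.
    chunks = []
    for text in merged:
        for start in range(0, len(text), max_characters):
            chunks.append(text[start:start + max_characters])
    return chunks


elements = [
    ("Title", "Intro"), ("Text", "Short."),
    ("Title", "Details"), ("Text", "x" * 600),
    ("Title", "Outro"), ("Text", "Bye."),
]
# The short "Intro" section is merged into "Details"; "Outro" stays a separate chunk
for chunk in simple_chunk_by_title(elements):
    print(len(chunk), repr(chunk[:20]))
```

Because the merge step only looks at the previous section's length, a small trailing section (like "Outro" here) can remain a short chunk of its own.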

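Next, convert the extracted elements into LangChain `Document` objects. Tables are stored as HTML (from `metadata.text_as_html`) so their row and column structure is preserved; all other elements are stored as plain text.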

In [None]:
from langchain_core.documents import Document

text_documents = []
table_documents = []
raw_pdf_elements = extract_pdf(filename)
for element in raw_pdf_elements:
    html = element.metadata.text_as_html
    if element.category == 'Table' or html is not None:
        # Store tables (and any chunk carrying an HTML table rendering) as HTML;
        # fall back to plain text if no HTML is available (e.g. infer_table_structure=False)
        table_documents.append(Document(page_content=html or element.text))
    else:
        # Transform text chunks into LangChain documents
        text_documents.append(Document(page_content=element.text))

print(f"Table documents: {len(table_documents)}")
print(f"Text documents: {len(text_documents)}")
documents = text_documents + table_documents
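The table/text routing rule can be checked without parsing a real PDF. The sketch below uses hypothetical stand-in classes (`FakeMetadata`, `FakeElement`) that mimic only the three attributes the loop reads (`category`, `text`, and `metadata.text_as_html`), and a hypothetical helper `split_elements` that returns plain strings instead of LangChain `Document` objects so it runs without third-party packages. It applies the same rule (tables and elements with an HTML rendering go to the table list, everything else to the text list) and falls back to plain text when a table has no HTML, which can happen when `infer_table_structure=False`.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeMetadata:
    """Hypothetical stand-in for element metadata."""
    text_as_html: Optional[str] = None  # HTML rendering, set for tables when structure is inferred


@dataclass
class FakeElement:
    """Hypothetical stand-in for an Unstructured element."""
    category: str
    text: str
    metadata: FakeMetadata = field(default_factory=FakeMetadata)


def split_elements(elements):
    """Route elements into (table_contents, text_contents) as plain strings."""
    tables, texts = [], []
    for element in elements:
        html = element.metadata.text_as_html
        if element.category == "Table" or html is not None:
            # Keep the HTML table structure; fall back to plain text if there is none
            tables.append(html or element.text)
        else:
            texts.append(element.text)
    return tables, texts


elements = [
    FakeElement("CompositeElement", "Dataflow overview..."),
    FakeElement("Table", "a b", FakeMetadata("<table><tr><td>a</td><td>b</td></tr></table>")),
    FakeElement("Table", "c d"),  # table without an inferred HTML rendering
]
tables, texts = split_elements(elements)
print(len(tables), len(texts))  # → 2 1
```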

## Generate QA pairs

With our granular documents ready, we can use a Large Language Model (LLM) to create QA pairs. Consider the following:

- Depending on the dataset's purpose, you may want the model to include references used to generate the answer.
- You might want the model to include reasoning steps from context to answer. A good strategy for this is [Chain of Thought (CoT)](https://www.promptingguide.ai/techniques/cot).
- The model should generate structured output from which we can extract the question, the thought process, the answer, and the references.

First, we'll define the schema for the QA data.

In [None]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import json

class SyntheticDatum(BaseModel):
    """A single synthetic QA example."""
    question: str = Field(description='generated question')
    answer: str = Field(description='generated answer')
    references: list[str] = Field(description='references for generated answer')
    thought: str = Field(description='thought for answer generation')


class SyntheticData(BaseModel):
    """A collection of synthetic QA examples."""
    data: list[SyntheticDatum] = Field(description='synthetic data pairs')
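`JsonOutputParser` returns plain Python lists and dicts; the `pydantic_object` it receives is used for format instructions, not for validating the parsed result. If you want to drop malformed items, one option, assuming Pydantic v2, is to validate each generated item with `SyntheticDatum.model_validate`. The snippet re-declares `SyntheticDatum` with the same fields so it runs on its own, and the `raw` list is hand-written for illustration, not real model output.

```python
from pydantic import BaseModel, Field, ValidationError


class SyntheticDatum(BaseModel):
    """A single synthetic QA example (same fields as above)."""
    question: str = Field(description='generated question')
    answer: str = Field(description='generated answer')
    references: list[str] = Field(description='references for generated answer')
    thought: str = Field(description='thought for answer generation')


# Hand-written example of parsed output (not produced by a model)
raw = [
    {
        "question": "How many living elephant species are recognised?",
        "thought": "To answer the question, I need the number of living species.",
        "answer": "Three living species of elephant are currently recognised.",
        "references": ["Three living species are currently recognised"],
    },
    {"question": "What order do elephants belong to?"},  # missing required fields
]

valid = []
for item in raw:
    try:
        valid.append(SyntheticDatum.model_validate(item))
    except ValidationError as err:
        # One error per missing field: answer, references, thought
        print(f"Skipping invalid item: {err.error_count()} error(s)")
print(len(valid))  # → 1
```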

Next, define a prompt that asks the model to generate a given number of QA pairs from the provided document. The model must return a JSON list in which each object contains the question, the thought process, the answer, and the references.

In [None]:
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a JSON generator who generates machine-readable JSON."),
    ("human", """
        Based on the following document, follow the instruction below.
        Document:
        {document}
        Instruction:
        Generate {amount} unique sets of question, thought, answer, and references from the above document in the following JSON format.
        The answers must avoid words that are not specific (e.g., "many", "several", "few").
        The answers must contain specific, verbose, self-contained, grammatically correct sentences that answer the question comprehensively.
        The answers must strictly contain content from the document and no content from outside the document.
        Each answer may have multiple references, each containing verbatim text from the document that supports the answer.
        JSON format:
        [
            {{
                "question": "<generated question>",
                "thought": "<generated thought on what is needed to answer the question. Start with 'To answer the question, I need'>",
                "answer": "<generated answer>",
                "references": [
                    "<verbatim text from document that supports the answer>",
                    "<verbatim text from document that supports the answer>"
                ]
            }}
        ]
        The first character of the response must be '[' and the last character must be ']'. Do not include any header text.
        """
    ),
])
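The doubled braces `{{` and `}}` in the JSON example are escapes. LangChain's default `f-string` template format follows Python's `str.format` rules, so `{document}` and `{amount}` are substituted while `{{` and `}}` render as literal braces. The snippet below demonstrates the same rule with plain `str.format` so it runs without LangChain; `template` is a shortened, hypothetical stand-in for the prompt above.

```python
# Shortened, hypothetical stand-in for the human message template
template = 'Generate {amount} items from: {document}\n[{{"question": "<generated question>"}}]'
rendered = template.format(amount=3, document="Elephants are large.")
print(rendered)

# Without the escapes, the JSON object is parsed as a placeholder named '"question"'
try:
    '[{"question": "<generated question>"}]'.format()
except KeyError as err:
    print("Unescaped braces are treated as a placeholder:", err)
```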


With the prompt defined, create a function that builds a LangChain chain (`prompt | llm | parser`), passes in the input arguments (the context document and the number of QA pairs to generate), and converts the parsed JSON response into QA pair dictionaries.

In [None]:
def generate_qa_pairs(context, amount, include_context=False, include_thoughts=False, include_references=False):
    """Generate synthetic QA pairs from a given context using a LangChain chain.

    Args:
        context (str): The source text to generate questions and answers from.
        amount (int): Number of QA pairs to generate.
        include_context (bool): Whether to include the original context in each output entry.
        include_thoughts (bool): Whether to include the model's reasoning ("thought") in each QA pair.
        include_references (bool): Whether to include the supporting verbatim references in each QA pair.

    Returns:
        List[dict]: A list of dictionaries containing QA pairs (and optional metadata).
    """
    # JsonOutputParser returns plain Python objects; it does not enforce the schema
    synthetic_datum_parser = JsonOutputParser(pydantic_object=SyntheticData)
    qa_generate_chain = prompt | llm | synthetic_datum_parser
    generation = qa_generate_chain.invoke({'document': context, 'amount': amount})

    # The prompt asks for a bare JSON list, but also accept a {"data": [...]} object
    if isinstance(generation, dict):
        generation = generation.get('data', [])

    qa_pairs = []
    for datum in generation:
        qa_pair = {
            'question': datum['question'],
            'context': context if include_context else None,
            'answer': datum['answer'],
            'thought': datum['thought'] if include_thoughts else None,
            'references': datum['references'] if include_references else None,
        }
        # Drop the optional fields that were not requested
        qa_pair = {k: v for k, v in qa_pair.items() if v is not None}
        qa_pairs.append(qa_pair)
    return qa_pairs
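The prompt asks for a bare JSON list, but smaller models sometimes surround the list with extra text. `JsonOutputParser` already tolerates some of this (for example, Markdown code fences). If parsing still fails, a fallback such as the hypothetical `extract_json_list` below, which keeps only the text between the first `[` and the last `]`, can recover the list. It is a heuristic sketch, not part of LangChain.

```python
import json


def extract_json_list(text):
    """Parse the outermost JSON list in `text`, ignoring any text around it.

    Raises ValueError if no list can be found or parsed.
    """
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end < start:
        raise ValueError("No JSON list found in model response")
    # json.JSONDecodeError is a subclass of ValueError, so callers can catch ValueError
    return json.loads(text[start:end + 1])


response = 'Here are the pairs:\n[{"question": "Q1", "answer": "A1"}]\nHope this helps!'
print(extract_json_list(response))  # → [{'question': 'Q1', 'answer': 'A1'}]
```

To use it in the chain, one option is to pipe the model into `StrOutputParser()` (from `langchain_core.output_parsers`) instead of `JsonOutputParser` and pass the resulting string to `extract_json_list`.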

The following example generates five synthetic QA pairs from a short sample passage and keeps the original context, thoughts, and references. Including the context is useful when training models for Retrieval-Augmented Generation (RAG) applications.

In [None]:
sample_doc="""Elephants are the largest living land animals. 
Three living species are currently recognised:
the African bush elephant (Loxodonta africana),
the African forest elephant (L. cyclotis), and the Asian elephant (Elephas maximus). 
They are the only surviving members of the family Elephantidae and the order Proboscidea;
extinct relatives include mammoths and mastodons."""

generate_qa_pairs(sample_doc, 5, include_context = True, include_thoughts = True, include_references = True)

## Generate full dataset

Next, create a simple function that converts each QA pair dictionary into a JSON string with `prompt` and `completion` fields, the format required for the fine-tuning process. Then, iterate over every chunk of the source data to generate QA pairs and collect the resulting lines.

In [None]:
def qa_pairs_to_prompt_completion(qa_pairs):
    """Convert QA pair dictionaries into prompt-completion strings formatted for fine-tuning.

    Args:
        qa_pairs (Union[dict, List[dict]]): A single QA pair or a list of QA pairs.

    Returns:
        List[str]: A list of JSON-formatted strings, each representing a prompt-completion example.
    """
    # Ensure input is a list of QA pairs
    if isinstance(qa_pairs, dict):
        qa_pairs = [qa_pairs]

    lines = []

    for pair in qa_pairs:
        # The prompt is the question itself. To add a fixed instruction, prefix it here, e.g.
        # 'You are a helpful assistant for question-answering tasks. ' + pair["question"]
        line = {'prompt': pair["question"], 'completion': ''}

        # Optionally include context if available
        if pair.get('context'):
            line['prompt'] += f'\nContext: {pair["context"]}\n'

        # Optionally include the model's "thoughts" before the answer
        if pair.get('thought'):
            line['completion'] += f'Thought: {pair["thought"]}\n'

        # Append the answer to the completion
        line['completion'] += f'Answer: {pair["answer"]}\n'

        # Optionally include references at the end
        if pair.get('references'):
            line['completion'] += f'References: {pair["references"]}\n'

        # Serialize the prompt-completion pair as one JSONL line
        lines.append(json.dumps(line))
    return lines
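For reference, this is the shape of one record produced by `qa_pairs_to_prompt_completion` when only the question and answer are kept, which is the setting used in the generation loop. The QA pair is hand-written for illustration, and the snippet rebuilds the record directly so it runs on its own.

```python
import json

# Hand-written example QA pair (not model output)
pair = {
    "question": "Which family do elephants belong to?",
    "answer": "Elephants are the only surviving members of the family Elephantidae.",
}

# Same fields the function builds when context, thoughts, and references are omitted
record = {"prompt": pair["question"], "completion": f"Answer: {pair['answer']}\n"}
line = json.dumps(record)
print(line)
# → {"prompt": "Which family do elephants belong to?", "completion": "Answer: Elephants are the only surviving members of the family Elephantidae.\n"}
```

Note that `json.dumps` escapes the trailing newline as `\n`, so each record still occupies exactly one line in the JSONL file.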

In [None]:
lines = []
for document in documents:
    try:
        qa_pairs = generate_qa_pairs(
            context=document.page_content,
            amount=5,
            include_context=False,
            include_thoughts=False,
            include_references=False,
        )
        lines.extend(qa_pairs_to_prompt_completion(qa_pairs))
    except Exception as e:
        # Skip chunks whose response could not be generated or parsed, and keep going
        print(f"Error generating QA pairs for document: {document.page_content[:200]}...")
        print(e)
lines

Finally, write the list of JSON strings to a JSONL file, one example per line.

In [None]:
with open("output.jsonl", "w") as f:
    for line in lines:
        json_obj = json.loads(line)  # ensure it's valid JSON
        f.write(json.dumps(json_obj) + "\n")
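Before using the file for fine-tuning, it can help to read it back and check that every line parses and has non-empty `prompt` and `completion` strings. `check_jsonl` below is a hypothetical helper, not part of any SambaNova tooling, and the exact fields your fine-tuning job expects may differ.

```python
import json


def check_jsonl(path):
    """Return (number of valid records, list of (line_number, problem)) for a JSONL file."""
    valid, problems = 0, []
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # Ignore blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                problems.append((line_number, f"invalid JSON: {err.msg}"))
                continue
            if not isinstance(record, dict):
                problems.append((line_number, "not a JSON object"))
                continue
            # Both fields must be present, be strings, and contain non-whitespace text
            missing = [key for key in ("prompt", "completion")
                       if not isinstance(record.get(key), str) or not record[key].strip()]
            if missing:
                problems.append((line_number, f"missing or empty: {', '.join(missing)}"))
            else:
                valid += 1
    return valid, problems
```

For example, `valid, problems = check_jsonl("output.jsonl")` returns the count of usable records and the line numbers of any records that need attention.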