In [9]:
from typing import List, Type, TypeVar, Optional
from pydantic import BaseModel, Field
from langchain.chat_models import AzureChatOpenAI  # Ensure correct import
import tiktoken
import json
import time
import os
import re
from dotenv import load_dotenv

# Load environment variables from a .env file
load_dotenv()

# Define a generic type for Pydantic models
T = TypeVar('T', bound=BaseModel)

def process_chunk(model, encoding, schema, system_prompt, user_prompt, chunk, max_tokens_for_prompt, max_tokens_for_response):
    """
    Process a single chunk by invoking the Azure OpenAI model and extracting data.
    """
    prompt_tokens = encoding.encode(user_prompt)
    if len(prompt_tokens) > max_tokens_for_prompt:
        print(f"Prompt too long for this chunk, skipping.")
        return []

    # Call Azure OpenAI API
    try:
        response = model.invoke(
            input=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            )
        assistant_reply = response.content

        # Parse the assistant's reply
        try:
            extracted_data = json.loads(assistant_reply)
            extracted_needles = []
            if isinstance(extracted_data, list):
                for item in extracted_data:
                    try:
                        extracted_item = schema(**item)
                        extracted_needles.append(extracted_item)
                    except Exception as e:
                        print(f"Error parsing item: {item}, error: {e}")
            else:
                print(f"Expected a list, got: {type(extracted_data)}")
            return extracted_needles
        except json.JSONDecodeError as e:
            print(f"JSON decode error: {e}")
            print("Assistant reply:")
            print(assistant_reply)
            return []
    except Exception as e:
        print(f"Error with Azure OpenAI API: {e}")
        if "rate limit" in str(e).lower():
            print("Rate limit exceeded. Sleeping for 60 seconds.")
            time.sleep(60)
        return []

def generate_keywords(model, example_needles: List[str]) -> List[str]:
    """
    Use an LLM to generate relevant keywords based on the given example needles.
    """

    user_prompt = (
        """
        You are an expert in extracting keywords from text. Given the following examples, generate relevant keywords such that they can be used to extract similar data from a large text corpus
        The keywords have to generalize well to other examples. The other examples have similar structure to the ones provided below.
        """ + '\n'.join(example_needles) + "\n\nProvide the keywords as a comma-separated list."
    )

    try:
        response = model.invoke(
            input= user_prompt,
            temperature=0.3,
            )
        assistant_reply = response.content
        keywords = [kw.strip() for kw in assistant_reply.split(',') if kw.strip()]
        print(f"Generated keywords: {keywords}")
        return keywords
    except Exception as e:
        print(f"Error generating keywords with Azure OpenAI API: {e}")
        return []

def extract_multi_needle(schema: Type[T], haystack: str, example_needles: List[str]) -> List[T]:
    """
    Extracts and structures information from a large text corpus based on a given schema and examples.
    """
    extracted_needles = []
    model_name = 'gpt-4o-mini'  # Ensure this matches your deployment name

    # Initialize the Azure OpenAI model
    model = AzureChatOpenAI(
    openai_api_version=os.environ.get("AZURE_OPENAI_VERSION", "2024-07-18"),
    azure_deployment=os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini"),
    azure_endpoint=os.environ.get(
        "AZURE_OPENAI_ENDPOINT", 
        "https://gptmini4o.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview"
    ),
    openai_api_key=os.environ.get("AZURE_OPENAI_KEY", "your_default_api_key_here"),
    )

     # Initialize tokenizer
    try:
        encoding = tiktoken.get_encoding("cl100k_base")  # Use appropriate encoding
    except Exception as e:
        print(f"Error initializing tokenizer: {e}")
        return extracted_needles

    # Token limits
    max_tokens_per_request = 128000
    max_tokens_for_response = 1000
    max_tokens_for_prompt = max_tokens_per_request - max_tokens_for_response

    # Step 1: Generate keywords using LLM
    keywords = generate_keywords(model, example_needles)
    if not keywords:
        print("No keywords generated by LLM.")
        return []

    # Build a regex pattern to find sentences containing any of the keywords
    keyword_pattern = r'\b(?:' + '|'.join(map(re.escape, keywords)) + r')\b'
    sentence_pattern = re.compile(r'[^.!?]*' + keyword_pattern + r'[^.!?]*[.!?]', re.IGNORECASE)

    # Find all candidate sentences
    candidate_sentences = sentence_pattern.findall(haystack)
    print(candidate_sentences)
    if not candidate_sentences:
        print("No candidate sentences found using keyword-based pre-filtering.")
        return []

    print(f"Found {len(candidate_sentences)} candidate sentences.")

    # Step 2: Batch candidate sentences into large chunks
    candidate_text = ' '.join(candidate_sentences)
    candidate_tokens = encoding.encode(candidate_text)
    num_tokens = len(candidate_tokens)

    chunk_size = max_tokens_for_prompt  # 127,000 tokens
    chunks = [
        encoding.decode(candidate_tokens[i:i + chunk_size])
        for i in range(0, num_tokens, chunk_size)
    ]

    print(f"Split into {len(chunks)} chunks.")

    # Prepare schema description
    schema_description = "Extract information according to the following schema:\n{\n"
    for field_name, field in schema.__fields__.items():
        field_descr = field.description or ''
        field_type = (
            field.annotation.__name__ if hasattr(field.annotation, '__name__') else str(field.annotation)
        )
        schema_description += f'  "{field_name}": "{field_descr} ({field_type})",\n'
    schema_description += "}\n"

    # Prepare examples
    examples_text = "Examples of the desired output format:\n"
    for example in example_needles:
        examples_text += f"- {example}\n"

    # System prompt
    system_prompt = "You are an AI language model that extracts structured data from text."

    # Step 3: Process each chunk sequentially
    for idx, chunk in enumerate(chunks):
        user_prompt = (
            f"{schema_description}\n"
            f"{examples_text}\n"
            f"Text to analyze:\n\"\"\"\n{chunk}\n\"\"\"\n"
            "Extract any instances matching the schema from the text above. "
            "Provide the output as a JSON array of objects."
        )
        result = process_chunk(model, encoding, schema, system_prompt, user_prompt, chunk, max_tokens_for_prompt, max_tokens_for_response)
        extracted_needles.extend(result)

    return extracted_needles

# Example usage
if __name__ == "__main__":
    class TechCompany(BaseModel):
        name: Optional[str] = Field(default=None, description="The full name of the technology company")
        location: Optional[str] = Field(default=None, description="City and country where the company is headquartered")
        employee_count: Optional[int] = Field(default=None, description="Total number of employees")
        founding_year: Optional[int] = Field(default=None, description="Year the company was established")
        is_public: Optional[bool] = Field(default=None, description="Whether the company is publicly traded (True) or privately held (False)")
        valuation: Optional[float] = Field(default=None, description="Company's valuation in billions of dollars")
        primary_focus: Optional[str] = Field(default=None, description="Main area of technology or industry the company focuses on")

    example_needles = [
        "Ryoshi, based in Neo Tokyo, Japan, is a private quantum computing firm founded in 2031, currently valued at $8.7 billion with 1,200 employees focused on quantum cryptography."
    ]

    # Read the haystack from a file
    haystack_file = "haystack.txt"  # Ensure this file exists and is accessible
    try:
        with open(haystack_file, "r", encoding="utf-8") as file:
            haystack = file.read()
    except FileNotFoundError:
        print(f"Haystack file '{haystack_file}' not found.")
        haystack = ""

    if haystack:
        extracted_data = extract_multi_needle(TechCompany, haystack, example_needles)
        print(len(extracted_data), "items extracted:")
        for item in extracted_data:
            print(item)

Generated keywords: ['Ryoshi', 'Neo Tokyo', 'Japan', 'private quantum computing firm', 'founded 2031', 'valued $8.7 billion', '1', '200 employees', 'quantum cryptography']
["\n\n\nFRANK HERBERT FRAN HERBERT \n\n\nFRANK HERBERT'S \nDUNE SAGA COLLECTION \n\n\nFrank Herbert's Dune Saga Collection: \nBooks 1 - 6 \n\n\nDune \nDune Messiah \nChildren of Dune \nGod Emperor of Dune \nHeretics of Dune \nChapterhouse: Dune \n\n\nFrank Herbert \n\n\nTable of Contents \n\n\nCover \nTitle Page \n\n\nDune \n\nDune Messiah \nChildren of Dune \n\nGod Emperor of Dune \nHeretics of Dune \nChapterhouse: Dune \n\n\n\n\n\nPRAISE FOR THE DUNE CHRONICLES \n\n\nDUNE \n\n\n“An astonishing science fiction phenomenon.", ' “1 \nshould wed your mother, make her my Duchess.', ' “Princess-daughter,” my father said, “1 \nwould that you’d been older when it came time for \nthis man to choose a woman.', ' If 1 refuse, it may offend him.', ' “1 \nhope it will not be necessary.', ' \nShall 1 tell him of the Duke’s daught

In [1]:
from typing import List, Type, TypeVar, Optional, Dict, Any
from pydantic import BaseModel
import spacy
import openai
import json
import re
import numpy as np
from sentence_transformers import SentenceTransformer, util

T = TypeVar('T', bound=BaseModel)

def extract_multi_needle(schema: Type[T], haystack: str, example_needles: List[str]) -> List[T]:
    """
    Extracts and structures information from a large text corpus based on a given schema and examples.

    Args:
    schema (Type[T]): A Pydantic model defining the structure of the needle to be extracted.
    haystack (str): The large text corpus to search through (haystack).
    example_needles (List[str]): A list of example sentences (needles).

    Returns:
    List[T]: A list of extracted needles conforming to the provided schema.
    """
    extracted_needles = []

    # Initialize spaCy model
    nlp = spacy.load('en_core_web_sm')
    nlp.max_length = len(haystack) + 1000  # Adjust based on the input size

    # Step 1: Extract keywords from example_needles
    keywords = extract_keywords(example_needles, nlp)

    # Step 2: Preprocess haystack and split into sentences
    haystack_sentences = split_into_sentences(haystack, nlp)

    # Step 3: Find candidate sentences using keyword matching
    candidate_sentences = find_candidate_sentences(haystack_sentences, keywords)

    if not candidate_sentences:
        return extracted_needles  # No candidates found

    # Step 4: Use embeddings to compute similarity between examples and candidates
    num_candidates = 100  # Adjust based on desired processing time
    candidate_sentences = rank_candidates_by_similarity(example_needles, candidate_sentences, num_candidates)

    # Step 5: Extract data from candidate sentences using LLM
    for sentence in candidate_sentences:
        prompt = construct_prompt(schema, sentence)
        response = call_llm_api(prompt)
        data = parse_llm_response(response)
        try:
            instance = schema.parse_obj(data)
            extracted_needles.append(instance)
        except Exception:
            # Handle validation errors silently or log them
            pass

    return extracted_needles

def extract_keywords(example_needles: List[str], nlp) -> set:
    combined_text = ' '.join(example_needles)
    doc = nlp(combined_text)
    keywords = {token.lemma_.lower() for token in doc if token.pos_ in {'NOUN', 'PROPN', 'ADJ'}}
    return keywords

def split_into_sentences(text: str, nlp) -> List[str]:
    doc = nlp(text)
    sentences = [sent.text.strip() for sent in doc.sents]
    return sentences

def find_candidate_sentences(sentences: List[str], keywords: set) -> List[str]:
    candidate_sentences = []
    for sentence in sentences:
        sentence_lower = sentence.lower()
        if any(keyword in sentence_lower for keyword in keywords):
            candidate_sentences.append(sentence)
    return candidate_sentences

def rank_candidates_by_similarity(example_needles: List[str], candidate_sentences: List[str], num_candidates: int) -> List[str]:
    model = SentenceTransformer('all-MiniLM-L6-v2')
    example_embeddings = model.encode(example_needles, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidate_sentences, convert_to_tensor=True)
    cosine_scores = util.cos_sim(example_embeddings, candidate_embeddings)
    max_similarities = cosine_scores.max(axis=0).values
    top_indices = np.argsort(-max_similarities.cpu().numpy())[:num_candidates]
    top_candidate_sentences = [candidate_sentences[idx] for idx in top_indices]
    return top_candidate_sentences

def construct_prompt(schema: Type[T], sentence: str) -> str:
    field_info = []
    for field_name, field in schema.__fields__.items():
        description = field.description or ''
        field_type = (
            field.annotation.__name__ if hasattr(field.annotation, '__name__') else str(field.annotation)
        )
        field_info.append(f"- {field_name} ({field_type}): {description}")
    field_info_text = '\n'.join(field_info)
    prompt = f"""Extract the following information from the sentence:

Sentence: "{sentence}"

Information to extract:
{field_info_text}

Return the information as a JSON object with keys matching the field names.
If a piece of information is not available, use null.

Example format:
{{
    "field1": value1,
    "field2": value2,
    ...
}}
"""
    return prompt

def call_llm_api(prompt: str) -> str:
    # Replace 'YOUR_API_KEY' with your actual OpenAI API key
    model = AzureChatOpenAI(
    openai_api_version=os.environ.get("AZURE_OPENAI_VERSION", "2024-07-18"),
    azure_deployment=os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini"),
    azure_endpoint=os.environ.get(
        "AZURE_OPENAI_ENDPOINT", 
        "https://gptmini4o.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview"
    ),
    openai_api_key=os.environ.get("AZURE_OPENAI_KEY", "your_default_api_key_here"),
    )
    response = model.invoke(
        input=[
            {'role': 'system', 'content': 'You are an assistant that extracts structured information from text.'},
            {'role': 'user', 'content': prompt}
        ],
        temperature=0,
    )
    return response.content

def parse_llm_response(response: str) -> Dict[str, Any]:
    try:
        data = json.loads(response)
        return data
    except json.JSONDecodeError:
        match = re.search(r'\{.*\}', response, re.DOTALL)
        if match:
            json_str = match.group(0)
            try:
                data = json.loads(json_str)
                return data
            except json.JSONDecodeError:
                return {}
        else:
            return {}



In [31]:
from pydantic import BaseModel, Field
from typing import Optional

class TechCompany(BaseModel):
    name: Optional[str] = Field(default=None, description="The full name of the technology company")
    location: Optional[str] = Field(default=None, description="City and country where the company is headquartered")
    employee_count: Optional[int] = Field(default=None, description="Total number of employees")
    founding_year: Optional[int] = Field(default=None, description="Year the company was established")
    is_public: Optional[bool] = Field(default=None, description="Whether the company is publicly traded (True) or privately held (False)")
    valuation: Optional[float] = Field(default=None, description="Company's valuation in billions of dollars")
    primary_focus: Optional[str] = Field(default=None, description="Main area of technology or industry the company focuses on")

# Example needles and haystack
example_needles = [
    "Ryoshi, based in Neo Tokyo, Japan, is a private quantum computing firm founded in 2031, currently valued at $8.7 billion with 1,200 employees focused on quantum cryptography."
]
with open("haystack.txt", "r") as file:
    haystack = file.read()
# Call the function
extracted_companies = extract_multi_needle(TechCompany, haystack, example_needles)

# Print the results
for company in extracted_companies:
    print(company)


: 