In [1]:
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Type, TypeVar

import faiss
import numpy as np
from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Generic type variable bound to BaseModel: extract_multi_needle uses it to declare
# that it returns instances of whichever Pydantic model class is passed as `schema`.
T = TypeVar('T', bound=BaseModel)

def compute_embeddings_multithreaded(texts: List[str], model_name: str = 'all-MiniLM-L6-v2', batch_size: int = 64) -> List:
    """
    Computes embeddings for a list of texts using multi-threading.

    The texts are cut into consecutive batches of ``batch_size`` items, each batch
    is encoded on a worker thread with one shared SentenceTransformer model, and
    the results are written back at the positions of the corresponding inputs.

    Args:
        texts (List[str]): The list of texts to compute embeddings for.
        model_name (str): The name of the SentenceTransformer model to use.
        batch_size (int): The batch size for embedding computation. Must be >= 1.

    Returns:
        List: A list with one entry per input text, in input order. If a batch
        fails, the error is printed and the entries for that batch stay ``None``,
        so callers should check for ``None`` before stacking the result.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before any expensive work: a zero step would make range() raise an
    # opaque error, and a negative step would silently produce no batches at all.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Nothing to embed: skip loading the model and starting a thread pool.
    if len(texts) == 0:
        return []

    # Initialize the embedding model once; all worker threads share it
    embedding_model = SentenceTransformer(model_name)

    # Function to compute embeddings for a batch of texts
    def compute_batch(batch_texts):
        return embedding_model.encode(batch_texts, show_progress_bar=False)

    # Pre-sized so each batch can be written to its own slice regardless of the
    # order in which the futures finish.
    embeddings = [None] * len(texts)

    # Use ThreadPoolExecutor for multi-threading
    with ThreadPoolExecutor() as executor:
        # Map each submitted future to the start offset of its batch
        future_to_start = {
            executor.submit(compute_batch, texts[start:start + batch_size]): start
            for start in range(0, len(texts), batch_size)
        }
        # Collect results as they complete
        for future in as_completed(future_to_start):
            start = future_to_start[future]
            try:
                batch_embeddings = future.result()
                embeddings[start:start + len(batch_embeddings)] = batch_embeddings
            except Exception as exc:
                # Best effort: report the failure and leave this batch's slots as None
                print(f'Batch starting at index {start} generated an exception: {exc}')

    return embeddings

def _describe_schema(schema: Type[BaseModel]) -> str:
    """Render one "- name (type): description" line per field of a Pydantic model."""
    # Pydantic v2 exposes `model_fields` (name -> FieldInfo); Pydantic v1 exposes
    # `__fields__` (name -> ModelField, which wraps a FieldInfo in `.field_info`).
    fields = getattr(schema, "model_fields", None)
    if fields is None:
        fields = schema.__fields__

    lines = []
    for field_name, field in fields.items():
        info = getattr(field, "field_info", field)
        field_desc = info.description or ''
        if hasattr(field, "outer_type_"):
            field_type = field.outer_type_
        else:
            field_type = getattr(field, "annotation", None)
        # Some typing constructs have no __name__; fall back to their repr
        type_name = getattr(field_type, "__name__", None) or str(field_type)
        lines.append(f"- {field_name} ({type_name}): {field_desc}\n")
    return "".join(lines)


def _parse_llm_json(raw) -> List[dict]:
    """Parse an LLM reply into a list of non-empty JSON objects (may be empty).

    Accepts a bare JSON object, a JSON array of objects, and either of those
    wrapped in a Markdown code fence. Raises json.JSONDecodeError when the
    payload is not valid JSON.
    """
    if not isinstance(raw, str):
        return []
    text = raw.strip()
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    data = json.loads(text)
    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        return []
    return [entry for entry in data if isinstance(entry, dict) and entry]


def extract_multi_needle(schema: Type[T], haystack: str, example_needles: List[str], top_k: int = 5) -> List[T]:
    """
    Extracts and structures information from a large text corpus based on a given schema and examples.

    Pipeline: split the haystack into sentences, embed sentences and examples,
    retrieve the ``top_k`` most similar sentences per example with a Faiss
    inner-product index over L2-normalised vectors, then ask the LLM to turn
    each candidate sentence into JSON that is validated against ``schema``.

    Args:
        schema (Type[T]): A Pydantic model defining the structure of the needle to be extracted.
        haystack (str): The large text corpus to search through (haystack).
        example_needles (List[str]): A list of example sentences (needles).
        top_k (int): Number of nearest sentences to retrieve per example. Defaults to 5.

    Returns:
        List[T]: A list of extracted needles conforming to the provided schema.
        Candidates are processed in haystack order. Replies that are not valid
        JSON or do not validate against ``schema`` are reported and skipped.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    # Initialize the list to hold the extracted needles
    extracted_needles = []

    # Initialize the SentenceTransformer model name
    model_name = 'all-MiniLM-L6-v2'

    # Split the haystack into sentences using regex, dropping empty fragments
    # (an empty haystack or trailing whitespace would otherwise yield '')
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', haystack) if s.strip()]
    if not sentences or not example_needles:
        return extracted_needles

    # Compute embeddings for the sentences in the haystack using multi-threading
    raw_embeddings = compute_embeddings_multithreaded(sentences, model_name=model_name)

    # Batches that failed to embed come back as None; drop those sentences so the
    # matrix stays rectangular and index positions keep matching `sentences`.
    embedded = [(s, e) for s, e in zip(sentences, raw_embeddings) if e is not None]
    if not embedded:
        return extracted_needles
    sentences = [s for s, _ in embedded]
    sentence_embeddings = np.ascontiguousarray([e for _, e in embedded], dtype=np.float32)

    # Compute embeddings for the example needles
    embedding_model = SentenceTransformer(model_name)
    example_embeddings = np.ascontiguousarray(
        embedding_model.encode(example_needles, show_progress_bar=True),
        dtype=np.float32,
    )

    # Build a Faiss index for efficient similarity search; with L2-normalised
    # vectors the inner product equals cosine similarity
    index = faiss.IndexFlatIP(sentence_embeddings.shape[1])
    faiss.normalize_L2(sentence_embeddings)
    index.add(sentence_embeddings)

    # Normalize example embeddings for cosine similarity
    faiss.normalize_L2(example_embeddings)

    # Never ask for more neighbours than there are indexed sentences
    k = min(top_k, len(sentences))

    # Perform similarity search for each example embedding
    _, neighbour_ids = index.search(example_embeddings, k)

    # Collect candidate sentences in haystack order, de-duplicated. Negative ids
    # mark "no neighbour" and must not be used as list indices.
    hit_ids = sorted({int(idx) for row in neighbour_ids for idx in row if idx >= 0})
    candidate_sentences = list(dict.fromkeys(sentences[idx] for idx in hit_ids))
    if not candidate_sentences:
        return extracted_needles

    # Initialize the Azure OpenAI LLM model
    # NOTE(review): the fallback values below look like placeholders - the version
    # default resembles a model version rather than an API version, the endpoint
    # default is a full chat-completions URL, and the key default is a dummy
    # string. Confirm the environment variables are always set in real runs.
    model = AzureChatOpenAI(
        openai_api_version=os.environ.get("AZURE_OPENAI_VERSION", "2024-07-18"),
        azure_deployment=os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini"),
        azure_endpoint=os.environ.get(
            "AZURE_OPENAI_ENDPOINT",
            "https://gptmini4o.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview"
        ),
        openai_api_key=os.environ.get("AZURE_OPENAI_KEY", "your_default_api_key_here"),
    )

    schema_description = _describe_schema(schema)

    # Construct the system prompt with schema description
    system_prompt = f"""
You are an assistant that extracts information from text according to a given schema.

The schema is:
{schema_description}

Your task is to read the provided text and extract any information that matches the schema.

Provide the extracted data as a JSON object conforming to the schema.

If the text does not contain relevant information, output an empty JSON object.

Only provide the JSON object, and no additional text.
"""

    # Process each candidate sentence
    for text in candidate_sentences:
        # Create the conversation messages for the LLM
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=text)
        ]

        # Call the LLM to process the text
        response = model(messages)

        # Attempt to parse the LLM response as JSON
        try:
            records = _parse_llm_json(response.content)
        except json.JSONDecodeError as exc:
            # If parsing fails, report and skip this text
            print(f'Skipping candidate: LLM reply is not valid JSON ({exc})')
            continue

        for record in records:
            try:
                # Validate and instantiate the schema
                extracted_needles.append(schema(**record))
            except (TypeError, ValueError) as exc:
                # If data does not conform to schema, report and skip
                print(f'Skipping record that does not conform to the schema: {exc}')

    return extracted_needles


In [2]:
from typing import Optional
from pydantic import BaseModel, Field

class TechCompany(BaseModel):
    """Schema for one technology-company "needle"; every field is optional and defaults to None."""

    name: Optional[str] = Field(
        default=None,
        description="The full name of the technology company",
    )
    location: Optional[str] = Field(
        default=None,
        description="City and country where the company is headquartered",
    )
    employee_count: Optional[int] = Field(
        default=None,
        description="Total number of employees",
    )
    founding_year: Optional[int] = Field(
        default=None,
        description="Year the company was established",
    )
    is_public: Optional[bool] = Field(
        default=None,
        description="Whether the company is publicly traded (True) or privately held (False)",
    )
    valuation: Optional[float] = Field(
        default=None,
        description="Company's valuation in billions of dollars",
    )
    primary_focus: Optional[str] = Field(
        default=None,
        description="Main area of technology or industry the company focuses on",
    )

In [3]:
# Example sentence(s) in the shape of the facts to retrieve; passed to
# extract_multi_needle as the `example_needles` argument.
example_needles = [
    "Ryoshi, based in Neo Tokyo, Japan, is a private quantum computing firm founded in 2031, currently valued at $8.7 billion with 1,200 employees focused on quantum cryptography.",
]

In [4]:
# Load the whole corpus into memory. The encoding is given explicitly so the text
# decodes the same way on every platform instead of following the locale default.
with open("haystack.txt", "r", encoding="utf-8") as file:
    haystack_text = file.read()

In [None]:
# Run the extraction pipeline over the loaded haystack
extracted_data = extract_multi_needle(
    schema=TechCompany,
    haystack=haystack_text,
    example_needles=example_needles,
)

# Persist the results: one plain dict per extracted model, pretty-printed as JSON
needle_records = [needle.dict() for needle in extracted_data]
with open('extracted_needles.json', 'w') as out_file:
    json.dump(needle_records, out_file, indent=2)
