In [1]:
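# Uncomment to install the dependencies. `%pip` (rather than `!pip`) installs into the
# environment of the running kernel.
# %pip install sentence_transformers
# %pip install langchain
# %pip install faiss-gpu langchain_openai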

In [2]:
import os
import re
import json
import csv
import numpy as np
import concurrent.futures
from typing import Type, List, TypeVar
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
from tqdm import tqdm  # For progress bar visualization

In [3]:
# Type variable bound to BaseModel. The extraction helpers below take `schema: Type[T]`
# and return `List[T]` (or a single `T`), so the schema class passed in determines the
# element type of the results.
T = TypeVar('T', bound=BaseModel)

def extract_multi_needle(
    schema: Type[T],
    haystack: str,
    example_needles: List[str],
    similarity_threshold: float = 0.3,
) -> List[T]:
    """
    Extracts structured records ("needles") from a large text ("haystack").

    The haystack is split into sentences. A sentence becomes a candidate when its
    embedding has a cosine similarity of at least `similarity_threshold` with any
    example needle, or when it contains one of the keywords generated by the LLM.
    Each candidate is then sent to the LLM, which extracts data matching `schema`.

    Args:
        schema: Pydantic model class describing the data to extract.
        haystack: The text to search.
        example_needles: One or more example sentences of the kind to look for.
        similarity_threshold: Minimum cosine similarity for a sentence to be
            selected by the embedding search. Defaults to 0.3.

    Returns:
        A list of `schema` instances, one per candidate sentence that yielded data.
        Returns an empty list when the haystack contains no non-blank sentence.

    Raises:
        ValueError: If `example_needles` is empty.
        RuntimeError: If the AZURE_OPENAI_KEY environment variable is not set.
    """
    if not example_needles:
        raise ValueError("example_needles must contain at least one example sentence.")

    # Split the haystack into sentences and drop blank fragments so that no
    # empty string is embedded or sent to the LLM.
    sentences = [
        sentence
        for sentence in re.split(r'(?<=[.!?])\s+', haystack)
        if sentence.strip()
    ]
    if not sentences:
        return []

    # The API key must come from the environment; a credential must never be
    # embedded in the notebook. Checked before the (slow) embedding step so a
    # missing key fails fast.
    api_key = os.environ.get("AZURE_OPENAI_KEY")
    if not api_key:
        raise RuntimeError(
            "The AZURE_OPENAI_KEY environment variable is not set; "
            "set it before calling extract_multi_needle."
        )

    # Embed the haystack sentences and the example needles with the same model
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    sentence_embeddings = embedding_model.encode(
        sentences, batch_size=256, show_progress_bar=True
    )
    example_embeddings = embedding_model.encode(
        example_needles, batch_size=256, show_progress_bar=True
    )

    # Scale every row to unit length so the dot product below is the cosine
    # similarity. A zero norm is replaced by 1 to avoid dividing by zero.
    sentence_norms = np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
    example_norms = np.linalg.norm(example_embeddings, axis=1, keepdims=True)
    sentence_unit = sentence_embeddings / np.where(sentence_norms == 0, 1.0, sentence_norms)
    example_unit = example_embeddings / np.where(example_norms == 0, 1.0, example_norms)

    # Rows are example needles, columns are sentences
    cosine_similarities = np.dot(example_unit, sentence_unit.T)

    # A sentence qualifies when it is similar enough to at least one example
    similar_mask = (cosine_similarities >= similarity_threshold).any(axis=0)

    # Initialize the Azure OpenAI LLM model
    # NOTE(review): the default endpoint below is a full chat-completions URL;
    # confirm whether `azure_endpoint` expects only the resource base URL.
    model = AzureChatOpenAI(
        openai_api_version=os.environ.get("AZURE_OPENAI_VERSION", "2023-03-15-preview"),
        azure_deployment=os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini"),
        azure_endpoint=os.environ.get(
            "AZURE_OPENAI_ENDPOINT",
            "https://gptmini4o.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview"
        ),
        openai_api_key=api_key,
    )

    # Generate a description of the schema to include in the prompts
    schema_description = generate_schema_description(schema)

    # Ask the LLM for keywords; lower-case them once instead of once per sentence
    keywords = generate_keywords(example_needles, schema_description, model)
    lowered_keywords = [keyword.lower() for keyword in keywords if keyword]

    def mentions_keyword(sentence: str) -> bool:
        lowered_sentence = sentence.lower()
        return any(keyword in lowered_sentence for keyword in lowered_keywords)

    # Combine both selections. Iterating over `sentences` keeps haystack order and
    # dict.fromkeys removes repeated sentences, so the candidate list (and hence
    # the order of LLM calls) is the same on every run.
    candidate_sentences = list(dict.fromkeys(
        sentence
        for sentence, is_similar in zip(sentences, similar_mask)
        if is_similar or mentions_keyword(sentence)
    ))

    print(f"Number of candidate sentences: {len(candidate_sentences)}")

    # Construct the system prompt with schema description for the LLM
    system_prompt = f"""
You are an assistant that extracts information from text according to a given schema.

The schema is:
{schema_description}

Your task is to read the provided text and extract any information that matches the schema.

Provide the extracted data as a JSON object conforming to the schema.

If the text does not contain relevant information, output an empty JSON object.

Only provide the JSON object, and no additional text.

Consider variations in sentence structure and wording. Extract information even if the text differs from the examples.
"""

    # Process the candidate sentences in parallel using threading
    return process_sentences_in_parallel(candidate_sentences, system_prompt, model, schema)

def generate_schema_description(schema: Type[BaseModel]) -> str:
    """
    Generates a text description of the schema fields and their types.

    Args:
        schema: Pydantic model class whose fields should be described.

    Returns:
        One line per field in the form "- name (type): description", each
        terminated by a newline. An empty string if the model has no fields.
    """
    from typing import get_origin  # local import: not part of the notebook's import cell

    # Prefer `model_fields`; fall back to `__fields__` for models that do not
    # provide it.
    fields = getattr(schema, "model_fields", None)
    if fields is None:
        fields = schema.__fields__

    def type_label(annotation) -> str:
        # Plain classes (int, str, ...) are shown by name. Parameterised
        # annotations such as Optional[int] are rendered in full, because their
        # `__name__` (where present) names only the outer construct and would
        # hide the actual value type from the LLM.
        if get_origin(annotation) is None and hasattr(annotation, '__name__'):
            return annotation.__name__
        return str(annotation).replace("typing.", "")

    lines = []
    for field_name, field in fields.items():
        field_desc = getattr(field, "description", None) or ''
        field_type = type_label(getattr(field, "annotation", None))
        lines.append(f"- {field_name} ({field_type}): {field_desc}\n")
    return "".join(lines)

def generate_keywords(example_needles: List[str], schema_description: str, model) -> List[str]:
    """
    Uses the LLM to generate a list of keywords based on the example needles and schema.

    Args:
        example_needles: Example sentences of the kind being searched for.
        schema_description: Text description of the schema fields.
        model: Chat model exposing `invoke(messages, temperature=...)`.

    Returns:
        The keywords from the LLM reply, in reply order, with list markers,
        surrounding quotes and trailing periods removed and case-insensitive
        duplicates dropped. May be empty if the reply contains no usable text.
    """
    # Construct a prompt for the LLM
    prompt = f"""Given the following schema and example needles, generate a list of keywords that would be useful for identifying relevant sentences in a text. The keywords should be related to the schema fields and the type of information we're looking for.

Schema:
{schema_description}

Example needles:
{', '.join(example_needles)}

Please provide a comma-separated list of approximately 10 keywords. These keywords should be closed words in english, i.e., there are needles present in the haystack which are structurally very similar to the example needles. Choose keywords that will help identify these other similar needles as well."""
    # Create the conversation messages for the LLM
    messages = [
        SystemMessage(content="You are an assistant that extracts keywords from text."),
        HumanMessage(content=prompt)
    ]

    # Call the LLM to generate keywords
    response = model.invoke(messages, temperature = 0.3)

    # The prompt asks for a comma-separated list, but the reply may instead be
    # one keyword per line, bulleted, numbered or quoted. Split on commas and
    # newlines and clean each fragment so such replies still yield keywords.
    keywords = []
    seen = set()
    for fragment in re.split(r'[,\n]+', response.content.strip()):
        # Drop a leading bullet ("- ", "* ") or number ("1. ", "2) "). The marker
        # must be followed by whitespace, so a value like "1.5" is left intact.
        keyword = re.sub(r'^\s*(?:[-*\u2022]|\d+[.)])\s+', '', fragment)
        keyword = keyword.strip().strip('"\'`').strip().rstrip('.')
        if not keyword:
            continue
        dedupe_key = keyword.lower()
        if dedupe_key in seen:
            continue
        seen.add(dedupe_key)
        keywords.append(keyword)
    return keywords

def process_sentences_in_parallel(candidate_sentences: List[str], system_prompt: str, model, schema: Type[T]) -> List[T]:
    """
    Processes multiple sentences in parallel using threading to make API calls concurrently.

    Args:
        candidate_sentences: Sentences to send to the LLM, one call per sentence.
        system_prompt: System prompt passed unchanged to every call.
        model: Chat model handed to `process_sentence`.
        schema: Pydantic model class the extracted data must conform to.

    Returns:
        The extracted items in the same order as `candidate_sentences`. Sentences
        that produced no data or raised an exception are omitted.
    """
    # One slot per input sentence: results are stored by position, so the output
    # order does not depend on which API call happens to finish first.
    results = [None] * len(candidate_sentences)
    failed_count = 0

    # Use ThreadPoolExecutor to process sentences concurrently
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit tasks to the executor
        future_to_index = {
            executor.submit(process_sentence, index, text, system_prompt, model, schema): index
            for index, text in enumerate(candidate_sentences)
        }

        # Process the futures as they complete
        for future in tqdm(concurrent.futures.as_completed(future_to_index),
                           total=len(candidate_sentences), desc="Processing Sentences"):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as exc:
                # A failing sentence is reported and skipped; the rest still run.
                print(f'Sentence {index} generated an exception: {exc}')
                failed_count += 1

    if failed_count:
        print(f"{failed_count} of {len(candidate_sentences)} sentences failed and were skipped.")

    # Keep only the sentences that actually yielded an item
    return [result for result in results if result is not None]

def process_sentence(index: int, text: str, system_prompt: str, model, schema: Type[T]) -> T:
    """
    Processes a single sentence using the LLM to extract information according to the schema.

    Args:
        index: Position of the sentence in the candidate list (used in log messages).
        text: The sentence to analyse.
        system_prompt: Instructions describing the schema and the expected JSON reply.
        model: Chat model exposing `invoke(messages, temperature=...)`.
        schema: Pydantic model class used to validate the extracted data.

    Returns:
        An instance of `schema` if the reply contains data, otherwise None. None is
        returned when the reply is not valid JSON, is not a JSON object, is empty,
        contains only null values, or fails schema validation.
    """
    # Create the conversation messages for the LLM
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=text)
    ]

    # Call the LLM to process the text
    response = model.invoke(messages, temperature = 0.6)

    try:
        payload = response.content.strip()

        # The reply may be wrapped in a Markdown code fence (```json ... ```)
        # despite the prompt; unwrap it so valid JSON is not discarded.
        fenced = re.fullmatch(r'```(?:json)?\s*(.*?)\s*```', payload, flags=re.DOTALL | re.IGNORECASE)
        if fenced:
            payload = fenced.group(1)

        data = json.loads(payload)

        if not isinstance(data, dict):
            # The schema is instantiated from keyword arguments, so anything
            # other than a JSON object cannot be used.
            print(f"Unexpected JSON type for sentence {index}: {type(data).__name__}")
            return None

        # An empty object, or one whose values are all null, carries no
        # information; treat both as "nothing extracted".
        if not data or all(value is None for value in data.values()):
            return None

        # Validate and instantiate the schema with the extracted data
        return schema(**data)
    except json.JSONDecodeError:
        print(f"JSONDecodeError for sentence {index}: {text}")
        print(f"LLM response: {response.content}")
    except Exception as e:
        print(f"Exception for sentence {index}: {text}")
        print(f"Error: {e}")
    return None


In [4]:
from typing import Optional
from pydantic import BaseModel, Field

class TechCompany(BaseModel):
    """Schema for one technology company; every field is optional and defaults to None."""

    # Identity and location
    name: Optional[str] = Field(
        default=None,
        description="The full name of the technology company",
    )
    location: Optional[str] = Field(
        default=None,
        description="City and country where the company is headquartered",
    )

    # Size and age
    employee_count: Optional[int] = Field(
        default=None,
        description="Total number of employees",
    )
    founding_year: Optional[int] = Field(
        default=None,
        description="Year the company was established",
    )

    # Financial status
    is_public: Optional[bool] = Field(
        default=None,
        description="Whether the company is publicly traded (True) or privately held (False)",
    )
    valuation: Optional[float] = Field(
        default=None,
        description="Company's valuation in billions of dollars",
    )

    # Line of business
    primary_focus: Optional[str] = Field(
        default=None,
        description="Main area of technology or industry the company focuses on",
    )

In [5]:
# Example sentence(s) showing the kind of statement to look for in the haystack.
# The literal is split across lines for readability; adjacent string literals are
# concatenated into a single sentence.
example_needles = [
    "Ryoshi, based in Neo Tokyo, Japan, is a private quantum computing firm "
    "founded in 2031, currently valued at $8.7 billion with 1,200 employees "
    "focused on quantum cryptography."
]

In [6]:
# Load the haystack. The encoding is given explicitly so the text is decoded the
# same way on every platform instead of using the locale's default.
# NOTE(review): assumes haystack.txt is UTF-8 encoded — confirm.
with open("haystack.txt", "r", encoding="utf-8") as file:
    haystack_text = file.read()

In [7]:
# Example usage: extract every TechCompany record from the haystack
extracted_data = extract_multi_needle(schema=TechCompany, haystack=haystack_text, example_needles=example_needles)

# Convert the models to plain dictionaries. `model_dump()` is used when the model
# provides it; `dict()` is kept as a fallback for models that do not.
serialized_items = [
    item.model_dump() if hasattr(item, "model_dump") else item.dict()
    for item in extracted_data
]

# Serialize the extracted data to a JSON file
with open('extracted_needles.json', 'w', encoding='utf-8') as f:
    json.dump(serialized_items, f, indent=2)

Batches:   0%|          | 0/1070 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Number of candidate sentences: 36


Processing Sentences: 100%|██████████| 36/36 [00:02<00:00, 16.37it/s]


In [8]:
def json_to_csv(json_file: str, csv_file: str):
    """
    Converts a JSON file to a CSV file.

    The JSON document must be either a single object or a list of objects. The CSV
    header is the union of the keys of all records, in order of first appearance;
    a record that lacks a column gets an empty cell.

    Args:
        json_file (str): The path to the input JSON file.
        csv_file (str): The path to the output CSV file.

    Returns:
        None. Prints a message and writes nothing if the JSON file does not exist
        or contains no records.
    """
    # Check if the JSON file exists
    if not os.path.exists(json_file):
        print(f"File {json_file} not found.")
        return

    # Read the JSON data
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # If the JSON data is a dictionary, convert it to a list of dictionaries
    if isinstance(data, dict):
        data = [data]

    # Nothing to convert (e.g. the extraction found no needles): report it
    # rather than failing on data[0].
    if not data:
        print(f"No records found in {json_file}; CSV file not written.")
        return

    # Collect the column names from every record, not just the first, so a record
    # with an additional key does not make DictWriter raise ValueError.
    # dict.fromkeys keeps first-seen order and drops duplicates.
    fieldnames = list(dict.fromkeys(key for record in data for key in record))

    # Write data to the CSV file
    with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)

    print(f"CSV file saved as {csv_file}")

In [9]:
# Convert the JSON file written by the extraction cell into a CSV file.
csv_file = "needles.csv"
json_to_csv("extracted_needles.json", csv_file)
# NOTE: json_to_csv prints its own status line and returns early when the JSON file
# is missing, so the message below is printed even if no CSV file was produced.
print(f"Extracted data saved to {csv_file}")

CSV file saved as needles.csv
Extracted data saved to needles.csv
