In [None]:
from llama_index import SimpleDirectoryReader
from llama_index.indices.service_context import ServiceContext
from llama_index.llms import OpenAI
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser.extractors import (
    MetadataExtractor,    
)
from llama_index.text_splitter import TokenTextSplitter
from llama_index.node_parser.extractors.marvin_entity_extractor import MarvinEntityExtractor

In [None]:
import os

# Keep a key that is already exported in the environment; the placeholder is
# only a fallback and must be replaced with a real key before any OpenAI call.
# Prefer exporting OPENAI_API_KEY outside the notebook over pasting it here,
# so the secret never ends up in a saved or shared copy of this file.
os.environ.setdefault("OPENAI_API_KEY", "<Your OpenAI API Key>")


In [None]:
# Load whatever the reader finds under the relative "data" directory.
documents = SimpleDirectoryReader(input_dir="data").load_data()

In [None]:
import marvin
from marvin import ai_model
from pydantic import BaseModel, Field

# Give marvin the key stored in the OPENAI_API_KEY environment variable
# (raises KeyError if that variable is not set).
marvin.settings.openai.api_key = os.environ["OPENAI_API_KEY"]

# Pydantic model describing one sports supplement, wrapped by marvin's
# `ai_model` decorator.
# NOTE(review): deliberately documented with comments rather than a class
# docstring, since a docstring could end up in the schema the decorator
# sees — confirm before adding one.
@ai_model
class SportsSupplement(BaseModel):
    # Every field is a required string (`...` default) with a description.
    name: str = Field(
        ...,
        description="The name of the sports supplement",
    )
    description: str = Field(
        ...,
        description="A description of the sports supplement",
    )
    pros_cons: str = Field(
        ...,
        description="The pros and cons of the sports supplement",
    )

In [None]:
# Model id shared by the LLM below and by the marvin extractor further down.
llm_model = "gpt-3.5-turbo"

# The llama_index OpenAI wrapper takes the model id as `model`; `model_name`
# is not one of its parameters, so the requested model would not be applied.
llm = OpenAI(temperature=0.1, model=llm_model, max_tokens=512)

# Service context built around that LLM, everything else left at defaults.
service_context = ServiceContext.from_defaults(llm=llm)

from llama_index import set_global_service_context

# Register the service context globally so it does not have to be passed
# explicitly when an index is built later.
set_global_service_context(service_context)

# Splitter that cuts text into chunks for processing: chunk_size 512,
# chunk_overlap 128, using a single space as the separator.
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=128, separator=" ")

# Metadata extraction with a single marvin-backed extractor, which pulls the
# custom SportsSupplement entity for each node.
metadata_extractor = MetadataExtractor(
    extractors=[
        MarvinEntityExtractor(
            marvin_model=SportsSupplement,
            llm_model_string=llm_model,
        ),
    ],
)

# Node parser wired up with the splitter and the metadata extractor.
node_parser = SimpleNodeParser(
    metadata_extractor=metadata_extractor,
    text_splitter=text_splitter,
)

# Run the parser over the loaded documents to obtain the nodes.
nodes = node_parser.get_nodes_from_documents(documents)

In [None]:
from pprint import pprint

# Show the extracted metadata of the first few nodes. Slicing (instead of
# indexing 0..4) avoids an IndexError when fewer than five nodes exist.
for node in nodes[:5]:
    pprint(node.metadata)