In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import sys

# If running from a checkout of the GitHub repo, put the parent of the current
# working directory (assumed to be the repo root) on the import path.
# The membership check keeps re-runs of this cell from appending duplicates.
_repo_root = str(Path.cwd().parent.resolve())
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Remove the surrounding triple quotes for more debugging printouts.
"""
import logging
root = logging.getLogger()
root.setLevel(logging.DEBUG)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
root.addHandler(handler)
"""
# End the cell on `None` so the string literal above is not the cell's last
# expression (it would otherwise be echoed as the cell output).
None

In [None]:
from trulens_eval.keys import check_keys

# Check up front for the two API keys named below.
# NOTE(review): presumably fails early when a key is not configured — confirm
# against trulens_eval.keys.check_keys.
check_keys("OPENAI_API_KEY", "HUGGINGFACE_API_KEY")

In [None]:
from langchain.llms import OpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from trulens_eval import TruChain, Feedback, Tru, Query, FeedbackMode
from trulens_eval import OpenAI as OAI

# Start the TruLens dashboard. `_dev` is given the parent of the current
# working directory.
# NOTE(review): `force=True` presumably replaces an already-running dashboard —
# confirm against Tru.start_dashboard.
Tru().start_dashboard(
    _dev=Path.cwd().parent.resolve(),
    force=True,
)

In [None]:
# TruLens OpenAI feedback provider (imported above as `OAI`).
open_ai = OAI()

# Moderation ("not hate") feedback function backed by the OpenAI provider.
# Its `text` argument is selected with `Query.RecordInput[:].page_content`,
# i.e. the `page_content` of each item in the record's input.
mod_not_hate = Feedback(open_ai.moderation_not_hate).on(
    text=Query.RecordInput[:].page_content
)

def wrap_chain_trulens(chain):
    """
    Wrap `chain` in a TruChain registered under the app id 'ChainOAI'.

    The wrapper is given the module-level `mod_not_hate` feedback and
    `FeedbackMode.WITH_APP`, so calls to TruChain will block until feedback
    is done evaluating.
    """
    tru_options = dict(
        app_id='ChainOAI',
        feedbacks=[mod_not_hate],
        feedback_mode=FeedbackMode.WITH_APP,
    )
    return TruChain(chain, **tru_options)

def get_summary_model(text, chunk_size=8000, chunk_overlap=350):
    """
    Produce summary chain, given input text.

    Args:
        text: The text to be summarized.
        chunk_size: `chunk_size` handed to RecursiveCharacterTextSplitter.
        chunk_overlap: `chunk_overlap` handed to RecursiveCharacterTextSplitter.

    Returns:
        A `(docs, chain)` tuple: the documents the splitter created from
        `text`, and the 'map_reduce' chain built by `load_summarize_chain`.
    """

    # No explicit `openai_api_key` is passed. A hard-coded empty string is at
    # best a no-op placeholder and at worst overrides the real credential;
    # leaving it out lets the client resolve the key from the OPENAI_API_KEY
    # environment variable.
    llm = OpenAI(temperature=0)

    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " "],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    docs = text_splitter.create_documents([text])
    print(f"You now have {len(docs)} docs instead of 1 piece of text.")

    return docs, load_summarize_chain(llm=llm, chain_type='map_reduce')

In [None]:
from datasets import load_dataset
# Load the "ca_test" split of the "billsum" dataset.
billsum = load_dataset("billsum", split="ca_test")
# Index the row first and the column second: this reads only the first
# example instead of selecting the whole `text` column to take one element.
text = billsum[0]['text']

# Split the text and build the summarize chain, then run the documents through
# the TruChain wrapper, keeping both the chain output and the TruLens record.
docs, chain = get_summary_model(text)
output, record = wrap_chain_trulens(chain).call_with_record(docs)