In [1316]:
import json
from dataclasses import dataclass, field
from operator import itemgetter
from typing import List, Dict, Any, Optional, Union, Tuple, Callable

import numpy as np
import pandas as pd
import yaml
from datasets import load_dataset_builder, load_dataset, get_dataset_infos, get_dataset_config_names, list_datasets
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [1317]:
# list_datasets()

# Data

In [1318]:
def load_dataset_by_name(name, data_dir="data"):
    """Load one dataset CSV and normalize its first two column names.

    Reads ``{data_dir}/{name}_data.csv``, renames the first column to
    ``'sentence'`` and the second to ``'classification'`` (any further
    columns keep their names), and drops every row containing a missing value.

    Args:
        name: Dataset name; used to build the file name ``{name}_data.csv``.
        data_dir: Directory holding the CSV files. Defaults to ``"data"``.

    Returns:
        A new DataFrame with the renamed columns and no NaN rows.

    Raises:
        ValueError: If the CSV has fewer than two columns.
    """
    df = pd.read_csv(f"{data_dir}/{name}_data.csv")
    if df.shape[1] < 2:
        raise ValueError(
            f"dataset '{name}' needs at least two columns, found {df.shape[1]}"
        )
    # Assign a fresh column list instead of writing into `df.columns.values`:
    # an Index is meant to be immutable, and mutating its backing array is not
    # guaranteed to propagate to the frame.
    df.columns = ['sentence', 'classification', *df.columns[2:]]
    return df.dropna()

In [1319]:
def split_datasets(df_dict, test_size=0.2, random_state=42):
    """Split every DataFrame in ``df_dict`` into a train and a test part.

    Args:
        df_dict: Mapping of dataset name -> DataFrame.
        test_size: Forwarded to ``train_test_split`` for every dataset.
        random_state: Seed forwarded to ``train_test_split``; the default of
            42 matches the value that was previously hard-coded.

    Returns:
        A ``(train, test)`` tuple of dicts keyed by the same dataset names.
    """
    df_dict_train = {}
    df_dict_test = {}

    for dataset, df in df_dict.items():
        # Honor the caller's test_size (it used to be ignored in favor of a
        # literal 0.2 here).
        df_train, df_test = train_test_split(
            df, test_size=test_size, random_state=random_state
        )
        df_dict_train[dataset] = df_train
        df_dict_test[dataset] = df_test
    return df_dict_train, df_dict_test

def construct_example_strings(df, str_fn=None):
    """Render every row of every dataset as a string.

    Args:
        df: Mapping of dataset name -> DataFrame (despite the name, this is a
            dict of frames, not a single frame).
        str_fn: Optional callable applied to each row; when omitted (or
            falsy), rows are rendered as ``"<sentence>[<classification>]"``.

    Returns:
        Dict mapping each dataset name to the list of rendered row strings.
    """
    def _bracket_format(row):
        return f"{row['sentence']}[{row['classification']}]"

    row_formatter = str_fn if str_fn else _bracket_format
    return {
        name: frame.apply(row_formatter, axis=1).tolist()
        for name, frame in df.items()
    }

In [1320]:
# Names of the datasets to load; load_dataset_by_name reads each one from
# data/<name>_data.csv.
my_datasets = ["lowercase", 
               "imdb_numbers", 
               "imdb_digits", 
               "backpack", 
               "sd_addition",
               "gpt_digits"]
df_dict = {name: load_dataset_by_name(name) for name in my_datasets}

# Split every dataset into train and test frames (default test_size).
dict_train, dict_test = split_datasets(df_dict)

# Render each training row as a two-line "sentence: ...\nclassification: ..." string.
str_fn = lambda x: f"sentence: {x['sentence']}\nclassification: {x['classification']}"
dict_examples = construct_example_strings(dict_train, str_fn=str_fn)

In [1321]:
# Sanity check: print the first three rendered training examples of every
# dataset, followed by a blank line between datasets.
# NOTE: the loop variables (`dataset`, `examples`, `example`) stay defined at
# module level after this cell; `examples` is reassigned by later cells.
for dataset, examples in dict_examples.items():
    print(f"Dataset: {dataset}")
    print("Examples:")
    for example in examples[:3]:
        print(example)
    print()

Dataset: lowercase
Examples:
sentence: i like to watch movies on weekends.
classification: True
sentence: going to the park is always fun.
classification: True
sentence: the cat sleeps on the sofa.
classification: True

Dataset: imdb_numbers
Examples:
sentence: I have never worked as an office temp before, but after seeing this movie, I have a very clear image as to what it must be like.
classification: False
sentence: Two or even three movies for the price of one! The first is a travelog that was shot somewhere south of the US border.
classification: True
sentence: Ring 0 is the Japanese prequel to the Ring Mythos.
classification: True

Dataset: imdb_digits
Examples:
sentence: One wishes for more for more talent in such a film! The actors looked like they're fresh from acting class! Amateurish! Come on, the filmmaker could have done better.
classification: False
sentence: I saw this on cable last night, just 2 days after seeing the Sprecher sisters' latest film, 13 Conversations About

# Config

In [1338]:
# model_types = {"gpt-4": "gpt-4-1106-preview", "gpt-3":"gpt-3.5-turbo-1106"}

@dataclass
class RunArgs():
    """Configuration for one classification / explanation run.

    Every attribute is an annotated dataclass field, so values can be
    overridden at construction time (``RunArgs(n_test=5)``) and show up in the
    generated ``repr``. Without annotations ``@dataclass`` creates no fields
    at all and the values are plain class attributes shared by all instances.
    """
    # number of test rows requested from get_inputs
    n_test: int = 20
    # number of training rows requested from get_examples
    n_examples: int = 20
    # key into the loaded train/test dataset dicts
    dataset_name: str = "sd_addition"
    # model identifier passed to ChatOpenAI
    model: str = "gpt-4-1106-preview"
    # path of the classification prompt loaded with load_prompt
    template: str = "templates/class_format_0.yaml"
    # whether test rows / example rows are sampled instead of taken from the top
    randomize_inputs: bool = True
    randomize_examples: bool = False
    # maps str(label) -> label text shown in prompts; default_factory gives
    # each instance its own dict instead of one shared mutable object
    label_rename_map: Dict[str, str] = field(
        default_factory=lambda: {"True": "Correct", "False": "Incorrect"}
    )
    # path of the explanation prompt loaded with load_prompt
    explanation_template: str = "templates/explain_1.yaml"
args = RunArgs()

# Model

#### Utility Functions

In [1339]:
def get_output_parser():
    """Build the structured output parser for classification responses.

    The parser is created from two response schemas, ``sentence`` and
    ``classification``, in that order.
    """
    schema_fields = (
        ("sentence", "the sentence to classify"),
        ("classification", "the classification"),
    )
    return StructuredOutputParser.from_response_schemas(
        [
            ResponseSchema(name=field_name, description=field_description)
            for field_name, field_description in schema_fields
        ]
    )

def get_inputs(dataset_name: str, n_test: int, randomize: bool = False,
               label_rename_map: Optional[Dict] = None,
               random_state: Optional[int] = None) -> pd.DataFrame:
    """Select test rows for a dataset and optionally rename their labels.

    Args:
        dataset_name: Key into the module-level ``dict_test`` mapping.
        n_test: Number of rows wanted. If the test split is smaller, all of
            its rows are returned.
        randomize: Sample rows at random instead of taking the first ones.
        label_rename_map: Maps ``str(label)`` to the label text to use. When
            ``None`` the module-level ``args.label_rename_map`` is used, which
            was the only behavior before this parameter existed; pass an empty
            dict to skip renaming.
        random_state: Optional seed for the random sample.

    Returns:
        A new DataFrame with columns ``sentence`` and ``classification``.

    Raises:
        KeyError: If a label's string form is missing from the rename map.
    """
    df_inputs = dict_test[dataset_name].copy()

    if randomize:
        # Clamp so that asking for more rows than exist behaves like the
        # slicing branch instead of raising inside DataFrame.sample.
        df_inputs = df_inputs.sample(min(n_test, len(df_inputs)), random_state=random_state)
    else:
        df_inputs = df_inputs[:n_test]

    df_inputs.columns = ['sentence', 'classification']
    # The previous `astype(bool)` call discarded its result and was a no-op;
    # it is dropped rather than "fixed" because casting string labels to bool
    # would turn every non-empty string into True.
    rename_map = args.label_rename_map if label_rename_map is None else label_rename_map
    if rename_map:
        # Plain column assignment replaces the column (and its dtype) instead
        # of writing strings into the existing, possibly boolean, column.
        df_inputs['classification'] = df_inputs['classification'].apply(lambda x: rename_map[str(x)])

    return df_inputs

def get_examples(dataset_name: str, n_examples: int, randomize_examples: bool = False,
                 label_rename_map: Optional[Dict] = None,
                 random_state: Optional[int] = None) -> pd.DataFrame:
    """Select training rows to use as few-shot examples.

    Args:
        dataset_name: Key into the module-level ``dict_train`` mapping.
        n_examples: Number of rows wanted. If the train split is smaller, all
            of its rows are returned.
        randomize_examples: Sample rows at random instead of taking the first
            ones.
        label_rename_map: Optional mapping from ``str(label)`` to the label
            text to use; labels are left untouched when it is ``None`` or empty.
        random_state: Optional seed for the random sample.

    Returns:
        A new DataFrame with columns ``sentence`` and ``classification``.

    Raises:
        KeyError: If a label's string form is missing from ``label_rename_map``.
    """
    df_examples = dict_train[dataset_name].copy()

    if randomize_examples:
        # Clamp so that asking for more rows than exist behaves like the
        # slicing branch instead of raising inside DataFrame.sample.
        df_examples = df_examples.sample(min(n_examples, len(df_examples)), random_state=random_state)
    else:
        df_examples = df_examples[:n_examples]

    df_examples.columns = ['sentence', 'classification']
    # The previous `astype(bool)` call discarded its result and was a no-op;
    # it is dropped rather than "fixed" because casting string labels to bool
    # would turn every non-empty string into True.
    if label_rename_map:
        df_examples['classification'] = df_examples['classification'].apply(lambda x: label_rename_map[str(x)])

    return df_examples

def json_formatted_examples(examples: pd.DataFrame) -> List[str]:
    """Render each example row as ``{sentence: ..., classification: ...}``.

    The result is a brace-wrapped, JSON-like string per row (keys and values
    are not quoted, so it is not strict JSON).

    Args:
        examples: DataFrame with ``sentence`` and ``classification`` columns.
            This was previously written as a default value
            (``examples = pd.DataFrame``) where a type annotation was meant.

    Returns:
        One formatted string per row, in row order.
    """
    # Iterate the two columns directly; iterrows would build a Series per row
    # and coerce the row to a single dtype.
    return [
        f"{{sentence: {sentence}, classification: {label}}}"
        for sentence, label in zip(examples['sentence'], examples['classification'])
    ]

def simple_formatted_examples(examples: pd.DataFrame) -> List[str]:
    """Render each example row as ``"<sentence> [<classification>]"``.

    Args:
        examples: DataFrame with ``sentence`` and ``classification`` columns.
            This was previously written as a default value
            (``examples = pd.DataFrame``) where a type annotation was meant.

    Returns:
        One formatted string per row, in row order.
    """
    return [
        f"{sentence} [{label}]"
        for sentence, label in zip(examples['sentence'], examples['classification'])
    ]

def set_a_or_b_examples(examples: pd.DataFrame) -> Tuple[List[str], List[str]]:
    """Partition example sentences into two lists by their label.

    Args:
        examples: DataFrame with ``sentence`` and ``classification`` columns.

    Returns:
        ``(set_a, set_b)``: sentences whose classification equals ``"Set_A"``,
        and all remaining sentences (any other label lands in ``set_b``).
    """
    set_a: List[str] = []
    set_b: List[str] = []
    for sentence, label in zip(examples['sentence'], examples['classification']):
        if label == "Set_A":
            set_a.append(sentence)
        else:
            set_b.append(sentence)
    return set_a, set_b

def prompt_str(template, inputs):
    """Invoke ``template`` with ``inputs`` and return the result's ``to_string()``."""
    rendered_prompt = template.invoke(inputs)
    return rendered_prompt.to_string()

def construct_explanation_inputs(examples, inputs):
    """Bundle ``examples`` and ``inputs`` into a dict with those two keys."""
    return dict(examples=examples, inputs=inputs)
    

#### Components

In [1340]:
# Chat model, selected by the run configuration (args.model).
model = ChatOpenAI(model=args.model)

# Classification prompt template, loaded from the path in args.template.
classification_template = load_prompt(args.template)

# Structured output parser with `sentence` and `classification` fields.
output_parser = get_output_parser()

# Run Chain

In [1341]:
# Few-shot examples for the configured dataset.
examples = get_examples(
    args.dataset_name, 
    args.n_examples, 
    randomize_examples=args.randomize_examples,
    label_rename_map=args.label_rename_map
)

# Print up to ten examples. Slicing, unlike indexing with range(10), does not
# raise IndexError when fewer than ten examples are available.
print_examples = simple_formatted_examples(examples)
for example_line in print_examples[:10]:
    print(example_line)

2 + 6 = 8 [Correct]
7 + 5 = 8 [Incorrect]
9 + 4 = 13 [Correct]
0 + 7 = 7 [Correct]
8 + 9 = 0 [Incorrect]
1 + 3 = 4 [Correct]
9 + 6 = 15 [Correct]
6 + 3 = 9 [Correct]
6 + 6 = 12 [Correct]
4 + 4 = 8 [Correct]


In [1342]:
# Test rows to classify, taken from the configured dataset's test split;
# sampled at random when args.randomize_inputs is set.
inputs = get_inputs(
    args.dataset_name, 
    args.n_test, 
    randomize=args.randomize_inputs
)

In [1343]:
# Classification pipeline: prompt -> model -> (structured parser | passthrough).
# The conditional expression binds more loosely than `|`, so it must be
# parenthesized. Unparenthesized, the whole `template | model | parser`
# pipeline becomes the "if" branch, and the else branch leaves `chain` as a
# bare RunnablePassthrough() that never calls the model.
chain = (
    classification_template
    | model
    | (output_parser if "format_instructions" in str(classification_template) else RunnablePassthrough())
)

In [1332]:
# mixed
# Build the prompt variables. `examples` is deliberately left untouched (it
# used to be overwritten with the joined string), so this cell can be re-run
# and the set-based branch below still receives the DataFrame it iterates over.
if "format_instructions" in str(classification_template):
    examples_text = "\n".join(json_formatted_examples(examples))
    chain_inputs = {"input": None, "examples": examples_text, "format_instructions": output_parser.get_format_instructions()}
else:
    chain_inputs = {"input": None, "examples": examples}

# grouped into sets
# NOTE(review): this is a substring test on the template path, so any path
# containing "set" (e.g. a "dataset" directory) selects this branch — confirm
# the template naming convention.
if "set" in args.template:
    set_a, set_b = set_a_or_b_examples(examples)
    chain_inputs = {"input": None, "set_a": set_a, "set_b": set_b}

In [1333]:
# Show the fully rendered classification prompt (the "input" slot is still None here).
print(prompt_str(classification_template, chain_inputs))

INSTRUCTIONS: 
Your task is to classify a sentence as TypeA or TypeB. You will be given a list of examples of TypeA and TypeB sentences. Identify the patterns that make a sentence TypeA or TypeB and use those patterns to classify the sentence.

FORMAT INSTRUCTIONS:
The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"sentence": string  // the sentence to classify
	"classification": string  // the classification
}
```

EXAMPLES:
The examples below are references for correct classifications of the sentences
{sentence: 2 + 6 = 8, classification: Correct}
{sentence: 7 + 5 = 8, classification: Incorrect}
{sentence: 9 + 4 = 13, classification: Correct}
{sentence: 0 + 7 = 7, classification: Correct}
{sentence: 8 + 9 = 0, classification: Incorrect}
{sentence: 1 + 3 = 4, classification: Correct}
{sentence: 9 + 6 = 15, classification: Correct}
{sentence: 6 + 3 = 9, classification: Correct}
{sentence: 

In [1268]:
# Copy so that predictions are added next to the ground truth without
# modifying `inputs`.
results = inputs.copy()

# Set to True to print every input, its ground truth and the raw prediction.
_print = False

# total= lets tqdm render a real progress bar; iterrows() is a generator with
# no length of its own.
for index, row in tqdm(inputs.iterrows(), total=len(inputs)):
    # Build a fresh dict per call rather than mutating the shared chain_inputs.
    prediction = chain.invoke({**chain_inputs, "input": row['sentence']})

    if _print:
        print(f"Input: {row['sentence']}")
        print(f"Ground Truth: {row['classification']}")
        # The prediction is indexed as a mapping below, so format it instead
        # of concatenating it to a str (which raises TypeError for a dict).
        print(f"Prediction: {prediction}\n")

    # update results
    results.loc[index, 'prediction'] = prediction['classification']

20it [00:41,  2.09s/it]


In [1305]:
# Display the results table: ground-truth `classification` next to `prediction`.
results

Unnamed: 0,sentence,classification,prediction
78,I enjoyed this movie.,TypeB,TypeB
0,This is a true classic starring Orson Wells in...,TypeB,TypeA
30,I am a die hard fan of the original Ginger Sna...,TypeA,TypeB
65,No kid can take on twice his size.,TypeB,TypeB
68,I rate it a 10 by how it affected me when I fi...,TypeA,TypeA
42,"The third movie produced by Howard Hughes, thi...",TypeB,TypeB
80,I don't really need to say anything about the ...,TypeB,TypeB
4,Since childhood I've found the actors playing ...,TypeA,TypeB
12,"As was noted by Cine Tiger, this excellent sil...",TypeB,TypeA
56,Chronologically situated between The World in ...,TypeA,TypeA


In [1270]:
# Row-wise agreement between ground truth and prediction.
correct = results['classification'] == results['prediction']
# `.mean()` of a boolean Series is a fraction in [0, 1]; format it as an
# actual percentage so that the value matches the "percent" label.
print(f"percent correct: {correct.mean():.1%}")

percent correct: 0.6


# Explanation

In [1347]:
# Separate configuration object for the explanation run. The two assignments
# below restate values equal to the RunArgs defaults; edit them here to
# change the explanation run only.
explain_args = RunArgs()
explain_args.explanation_template = "templates/explain_1.yaml"
explain_args.n_test = 20

In [1344]:
# Explanation prompt, loaded from the path in explain_args.explanation_template.
explanation_template = load_prompt(explain_args.explanation_template)

# Explanation pipeline: prompt -> the same chat model -> StrOutputParser.
explanation_chain = (
    explanation_template
    | model
    | StrOutputParser()
)

In [1345]:
# Rebuild examples and inputs from the explanation configuration. This
# reassigns the module-level `examples` and `inputs` used by the
# classification cells.
examples = get_examples(
    explain_args.dataset_name, 
    explain_args.n_examples, 
    randomize_examples=explain_args.randomize_examples,
    label_rename_map=explain_args.label_rename_map
)
# NOTE(review): no label map is passed here, so get_inputs takes its label
# renaming from the global `args`, not from `explain_args`.
inputs = get_inputs(
    explain_args.dataset_name, 
    explain_args.n_test, 
    randomize=explain_args.randomize_inputs
)
# NOTE(review): the next cell reassigns explanation_inputs to only
# {"examples": examples}, so `inputs` is not used by the explanation prompt.
explanation_inputs = construct_explanation_inputs(examples, inputs)

In [1346]:
# Only the examples are passed to the explanation prompt. Stream the response
# and print each chunk as it arrives, without inserting newlines.
explanation_inputs = {"examples": examples}
for s in explanation_chain.stream(explanation_inputs):
    print(s, end="")

A sentence is classified as `Correct` if the sum of the two numbers is mathematically accurate.

In [1290]:
# Non-streaming variant of the same request: invoke() returns the chain's
# result in one piece. The saved output below shows this run was interrupted.
explanation_inputs = {"examples": examples}
explanation_chain.invoke(explanation_inputs)

KeyboardInterrupt: 