In [3]:
from datasets import load_dataset_builder, load_dataset, get_dataset_infos, get_dataset_config_names, list_datasets
from langchain.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from langchain.chat_models import ChatOpenAI
from operator import itemgetter
from sklearn.model_selection import train_test_split
from typing import List, Dict, Any, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from tqdm import tqdm
import json
import yaml
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# list_datasets()

# Data

In [38]:
def load_dataset_by_name(name):
    df = pd.read_csv(f"data/{name}_data.csv")
    df.columns.values[0] = 'sentence'
    df.columns.values[1] = 'classification'      
    df = df.dropna()
    return df

In [39]:
def split_datasets(df_dict, test_size=0.2):
    df_dict_train = {}
    df_dict_test = {}

    for dataset, df in df_dict.items():
        df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
        df_dict_train[dataset] = df_train
        df_dict_test[dataset] = df_test
    return df_dict_train, df_dict_test

def construct_example_strings(df, str_fn=None):
    df_dict_examples = {}
    for dataset_name, dataset_df in df.items():
        if str_fn:
            df_dict_examples[dataset_name] = dataset_df.apply(str_fn, axis=1).tolist()
        else:
            df_dict_examples[dataset_name] = dataset_df.apply(lambda x: f"{x['sentence']}[{x['classification']}]", axis=1).tolist()
    return df_dict_examples

In [40]:
easy_datasets = ["imdb_numbers", 
               "imdb_digits", 
               "backpack",]

hard_datasets = ["lowercase", 
               "sd_addition",
               "dd_addition",
               "simple_punctuation",
               "medium_punctuation",
               "pronoun",
               "past_tense",
               "gpt_digits"]
df_dict = {name: load_dataset_by_name(name) for name in easy_datasets}

dict_train, dict_test = split_datasets(df_dict)

str_fn = lambda x: f"sentence: {x['sentence']}\nclassification: {x['classification']}"
dict_examples = construct_example_strings(dict_train, str_fn=str_fn)

In [41]:
# print a few examples from dict_examples
for dataset, examples in dict_examples.items():
    print(f"Dataset: {dataset}")
    print("Examples:")
    for example in examples[:3]:
        print(example)
    print()

Dataset: lowercase
Examples:
sentence: i like to watch movies on weekends.
classification: True
sentence: going to the park is always fun.
classification: True
sentence: the cat sleeps on the sofa.
classification: True

Dataset: imdb_numbers
Examples:
sentence: I have never worked as an office temp before, but after seeing this movie, I have a very clear image as to what it must be like.
classification: False
sentence: Two or even three movies for the price of one! The first is a travelog that was shot somewhere south of the US border.
classification: True
sentence: Ring 0 is the Japanese prequel to the Ring Mythos.
classification: True

Dataset: imdb_digits
Examples:
sentence: One wishes for more for more talent in such a film! The actors looked like they're fresh from acting class! Amateurish! Come on, the filmmaker could have done better.
classification: False
sentence: I saw this on cable last night, just 2 days after seeing the Sprecher sisters' latest film, 13 Conversations About

# Config

In [42]:
# model_types = {"gpt-4": "gpt-4-1106-preview", "gpt-3":"gpt-3.5-turbo-1106"}

@dataclass
class RunArgs():
    n_test = 20
    n_examples = 40
    dataset_name = "dd_addition"
    model = "gpt-3.5-turbo-1106"
    template = "templates/classification_9.yaml"
    randomize_inputs = True
    randomize_examples = False
    label_rename_map = {"True":"true", "False":"false"}
    explanation_template = "templates/explain_1.yaml"
args = RunArgs()

# Model

#### Utility Functions

In [43]:
def get_output_parser():
    response_schemas = [
        ResponseSchema(name="sentence", description="the sentence to classify"),
        ResponseSchema(name="classification", description="the classification")
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    return output_parser

def get_inputs(dataset_name: str, n_test: int, randomize: bool = False) -> pd.DataFrame:
    df_inputs = dict_test[dataset_name].copy()

    if randomize:
        df_inputs = df_inputs.sample(n_test)
    else: 
        df_inputs = df_inputs[:n_test]
    
    df_inputs.columns = ['sentence', 'classification']
    df_inputs['classification'].astype(bool)
    if args.label_rename_map:
        df_inputs.loc[:, 'classification'] = df_inputs['classification'].apply(lambda x: args.label_rename_map[str(x)])

    return df_inputs

def get_examples(dataset_name: str, n_examples: int, randomize_examples: bool = False, label_rename_map: Dict = None) -> pd.DataFrame:
    df_examples = dict_train[dataset_name].copy()

    if randomize_examples:
        df_examples = df_examples.sample(n_examples)
    else:
        df_examples = df_examples[:n_examples]
    
    df_examples.columns = ['sentence', 'classification']
    df_examples['classification'].astype(bool)
    if label_rename_map:
        df_examples['classification'] = df_examples['classification'].apply(lambda x: label_rename_map[str(x)])

    return df_examples

def json_formatted_examples(examples = pd.DataFrame) -> List[str]:
    json_examples = []
    for index, row in examples.iterrows():
        json_examples.append(f"{{sentence: {row['sentence']}, classification: {row['classification']}}}")


    # json_examples = examples.to_json(orient="records")
    return json_examples

def simple_formatted_examples(examples = pd.DataFrame) -> List[str]:
    simple_examples = []
    for index, row in examples.iterrows():
        simple_examples.append(f"{row['sentence']} [{row['classification']}]")
    return simple_examples

def numbered_formatted_inputs(inputs = pd.DataFrame) -> List[str]:
    numbered_inputs = []
    for i in range(len(inputs)):
        numbered_inputs.append(f"{i+1}. {inputs.iloc[i]['sentence']}")
    return numbered_inputs

def set_a_or_b_examples(examples = pd.DataFrame) -> List[str]:
    sets = {"Set_A": [], "Set_B": []}
    for _, row in examples.iterrows():
        if row['classification'] == "Set_A":
            sets["Set_A"].append(row['sentence'])
        else:
            sets["Set_B"].append(row['sentence'])
    # print(sets['Set_A'])
    return sets['Set_A'], sets['Set_B']

def prompt_str(template, inputs):
    return template.invoke(inputs).to_string()

def construct_explanation_inputs(examples, inputs):
    explanation_inputs = {"examples": examples, "inputs": inputs}
    return explanation_inputs
    

#### Components

In [44]:
# model
model = ChatOpenAI(model=args.model)

# prompt
classification_template = load_prompt(args.template)

# output parser
output_parser = get_output_parser()

# Run Chain

In [45]:
examples = get_examples(
    args.dataset_name, 
    args.n_examples, 
    randomize_examples=args.randomize_examples,
    label_rename_map=args.label_rename_map
)

# print some examples
print_examples = simple_formatted_examples(examples)
for i in range(10):
    print(print_examples[i])

71 + 62 = 69 [false]
70 + 72 = 142 [true]
82 + 78 = 160 [true]
27 + 96 = 70 [false]
68 + 85 = 153 [true]
22 + 51 = 73 [true]
87 + 34 = 67 [false]
33 + 21 = 110 [false]
30 + 84 = 58 [false]
67 + 93 = 28 [false]


In [46]:
inputs = get_inputs(
    args.dataset_name, 
    args.n_test, 
    randomize=args.randomize_inputs
)

In [47]:
label_chain = (
    classification_template
    | model
    # | output_parser if "format_instructions" in str(classification_template) else RunnablePassthrough()
)

In [48]:
# mixed
if "format_instructions" in str(classification_template):
    examples = json_formatted_examples(examples)
    examples = "\n".join(examples)
    chain_inputs = {"input": None, "examples": examples, "format_instructions": output_parser.get_format_instructions()}
else:
    examples = simple_formatted_examples(examples)
    examples = "\n".join(examples)
    chain_inputs = {"input": None, "examples": examples}

# grouped into sets
if "set" in args.template:
    set_a, set_b = set_a_or_b_examples(examples)
    chain_inputs = {"input": None, "set_a": set_a, "set_b": set_b}

In [49]:
print(prompt_str(classification_template, chain_inputs))

Objective: You are to analyze sets of sentences to identify a simple, undisclosed feature shared by some of them. Your task is to apply this feature to label a new set of sentences as 'true' or 'false'.

**Instructions:**

1. **Review the Provided Examples**: Below are sets of sentences labeled as 'true' or 'false'. All 'true' sentences share a simple feature.

  `71 + 62 = 69 [false]
70 + 72 = 142 [true]
82 + 78 = 160 [true]
27 + 96 = 70 [false]
68 + 85 = 153 [true]
22 + 51 = 73 [true]
87 + 34 = 67 [false]
33 + 21 = 110 [false]
30 + 84 = 58 [false]
67 + 93 = 28 [false]
79 + 58 = 153 [false]
48 + 26 = 178 [false]
65 + 26 = 91 [true]
94 + 76 = 170 [true]
59 + 30 = 89 [true]
21 + 81 = 119 [false]
84 + 39 = 123 [true]
24 + 22 = 120 [false]
18 + 38 = 131 [false]
86 + 66 = 152 [true]
89 + 28 = 61 [false]
12 + 15 = 27 [true]
16 + 68 = 120 [false]
27 + 49 = 110 [false]
61 + 77 = 33 [false]
99 + 27 = 126 [true]
70 + 53 = 61 [false]
60 + 92 = 104 [false]
73 + 14 = 148 [false]
51 + 13 = 164 [fal

#### Classify Multiple

In [50]:
input = "\n".join(numbered_formatted_inputs(inputs))
chain_inputs = {"input": input, "examples": examples}

out = label_chain.invoke(chain_inputs)

In [51]:
print(out.content)

Based on the analysis, the simple feature shared by the 'true' sentences is that the sum of the first and second numbers is equal to the reverse of the third number. For example, 70 + 72 = 142, where 70 + 72 = 142, and 82 + 78 = 160, where 82 + 78 = 160. 

Using this feature, the labeled sentences are as follows:

1. 48 + 12 = 60 [false]
2. 19 + 41 = 194 [false]
3. 89 + 13 = 102 [true]
4. 80 + 85 = 96 [false]
5. 18 + 20 = 38 [true]
6. 16 + 85 = 57 [false]
7. 24 + 21 = 128 [false]
8. 10 + 70 = 80 [false]
9. 72 + 73 = 145 [false]
10. 75 + 51 = 126 [false]
11. 80 + 52 = 103 [false]
12. 93 + 14 = 30 [false]
13. 24 + 61 = 107 [true]
14. 91 + 20 = 111 [true]
15. 59 + 19 = 78 [false]
16. 89 + 12 = 101 [true]
17. 89 + 75 = 164 [false]
18. 66 + 44 = 110 [true]
19. 83 + 50 = 60 [false]
20. 69 + 79 = 148 [false]


#### Classify One

In [1461]:
results = inputs.copy()

_print = True

for index, row in tqdm(inputs.iterrows()):
    chain_inputs['input'] = row['sentence']
    prediction = label_chain.invoke(chain_inputs)
    print(prediction)

    # if _print:
    #     print(f"Input: {row['sentence']}")
    #     print(f"Ground Truth: {row['classification']}")
    #     print("Prediction: " + prediction + "\n")

    # update results
    # results.loc[index, 'prediction'] = 'a'

1it [00:02,  2.40s/it]

content="Based on the analysis of the provided examples, the simple feature shared by the 'true' sentences seems to be the presence of items related to outdoor activities or travel. These items include clothing, accessories, and equipment commonly used for activities like hiking, camping, beach outings, sports, and exploration.\n\nApplying this hypothesis to the unlabeled set of sentences:\n- Cliff bars, hydration pack, GPS watch, map, compass\n\nBased on the presence of items commonly used for outdoor activities and exploration, the label for this set of sentences would be 'true'."


2it [00:04,  2.01s/it]

content="Hypothesis: The simple feature shared by the 'true' sentences is that they all contain items related to outdoor or recreational activities.\n\nApplication to Unlabeled Sentences:\n- Gardening gloves, seeds, trowel, pruning shears, watering can: true\n\nJustification: The items in the sentence are related to outdoor activities, specifically gardening, which aligns with the simple feature identified in the 'true' sentences. Therefore, the label for this sentence is 'true'."


3it [00:07,  2.45s/it]

content="Based on the analysis of the provided examples, the simple, undisclosed feature that is shared by the 'true' sentences is that they all contain items related to specific activities or environments (e.g., ballet, beach, hiking, cycling, etc.). On the other hand, the 'false' sentences seem to lack this specificity and contain more general items (e.g., binder, lined paper, graph paper, etc.).\n\nApplying this hypothesis to the new set of sentences:\n\nNotebook, pen, water bottle, laptop, headphones, calculator, USB drive\n\nBased on the hypothesis, this set of sentences would be labeled as 'false' because it contains more general items and does not relate to a specific activity or environment.\n\nTherefore, the label for the new set of sentences is 'false'."


4it [00:09,  2.35s/it]

content="Based on the analysis of the provided examples, the simple feature shared by the 'true' sentences is that they all contain items related to a specific activity or purpose (e.g., ballet, hiking, swimming, etc.). In contrast, the 'false' sentences seem to lack a clear theme or purpose, and the items listed appear to be random.\n\nApplying this hypothesis to the new set of sentences:\n\n- Tank top, yoga pants, sports bra, water bottle, headphones\n\nBased on the presence of items related to a specific activity (yoga or exercise), this set of sentences would be labeled as 'true'."


5it [00:12,  2.43s/it]

content="Based on the analysis, the simple, undisclosed feature shared by the 'true' sentences is that they all contain items related to specific activities or environments (e.g., ballet, beach, cycling, hiking, soccer, rain, camping, etc.). On the other hand, the 'false' sentences seem to consist of random items without a clear connection or theme.\n\nApplying this hypothesis to the new set of sentences:\n- Knit hat, ski jacket, gloves, goggles, neck warmer\n\nBased on the hypothesis, this set of sentences would be labeled as 'true' because the items are related to the specific activity or environment of skiing.\n\nLabel with Justification:\n- Knit hat, ski jacket, gloves, goggles, neck warmer : true (related to skiing)"





In [1433]:
results

Unnamed: 0,sentence,classification,prediction
0,"Notebook, pen, water bottle, laptop, headphone...",False,a
65,"Tank top, yoga pants, sports bra, water bottle...",True,a
81,"Knit hat, ski jacket, gloves, goggles, neck wa...",True,a
86,"Cliff bars, hydration pack, GPS watch, map, co...",False,a
56,"Frisbee, sunscreen, beach towel, sunglasses, w...",False,a
44,"Action figures, comic books, trading cards, sn...",False,a
28,"Gardening gloves, seeds, trowel, pruning shear...",False,a
40,"Hammer, screwdriver, tape measure, nails, safe...",False,a
10,"Camera, memory cards, tripod, extra batteries,...",False,a
22,"Fishing line, hooks, bait, fishing license, su...",False,a


In [1270]:
correct = results['classification'] == results['prediction']
print(f"percent correct: {correct.mean()}")

percent correct: 0.6


# Explanation

In [1348]:
explain_args = RunArgs()
explain_args.dataset_name = "backpack"
explain_args.explanation_template = "templates/explain_1.yaml"
explain_args.n_test = 50

In [1349]:
explanation_template = load_prompt(explain_args.explanation_template)

explanation_chain = (
    explanation_template
    | model
    | StrOutputParser()
)

In [1350]:
examples = get_examples(
    explain_args.dataset_name, 
    explain_args.n_examples, 
    randomize_examples=explain_args.randomize_examples,
    label_rename_map=explain_args.label_rename_map
)
inputs = get_inputs(
    explain_args.dataset_name, 
    explain_args.n_test, 
    randomize=explain_args.randomize_inputs
)
explanation_inputs = construct_explanation_inputs(examples, inputs)

KeyError: ' "True"'

In [1346]:
explanation_inputs = {"examples": examples}
for s in explanation_chain.stream(explanation_inputs):
    print(s, end="")

A sentence is classified as `Correct` if the sum of the two numbers is mathematically accurate.

In [1290]:
explanation_inputs = {"examples": examples}
explanation_chain.invoke(explanation_inputs)

KeyboardInterrupt: 