In [132]:
from datasets import load_dataset_builder, load_dataset, get_dataset_infos, get_dataset_config_names, list_datasets
from langchain.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from langchain.chat_models import ChatOpenAI
from operator import itemgetter
from sklearn.model_selection import train_test_split
from typing import List, Dict, Any, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from tqdm import tqdm
import json
import yaml
import pandas as pd
import numpy as np

In [133]:
# list_datasets()

# Data

In [134]:
def load_dataset_by_name(name):
    df = pd.read_csv(f"data/{name}_data.csv")
    df.columns.values[0] = 'sentence'
    df.columns.values[1] = 'classification'      
    df = df.dropna()
    return df

In [135]:
def split_datasets(df_dict, test_size=0.2):
    df_dict_train = {}
    df_dict_test = {}

    for dataset, df in df_dict.items():
        df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
        df_dict_train[dataset] = df_train
        df_dict_test[dataset] = df_test
    return df_dict_train, df_dict_test

def construct_example_strings(df, str_fn=None):
    df_dict_examples = {}
    for dataset_name, dataset_df in df.items():
        if str_fn:
            df_dict_examples[dataset_name] = dataset_df.apply(str_fn, axis=1).tolist()
        else:
            df_dict_examples[dataset_name] = dataset_df.apply(lambda x: f"{x['sentence']}[{x['classification']}]", axis=1).tolist()
    return df_dict_examples

In [136]:
hard_datasets = ["imdb_numbers", 
               "imdb_digits", 
               "backpack",]

easy_datasets = ["lowercase", 
               "sd_addition",
               "dd_addition",
               "simple_punctuation",
               "medium_punctuation",
               "pronoun",
               "past_tense",
               "gpt_digits",]
df_dict = {name: load_dataset_by_name(name) for name in easy_datasets}

dict_train, dict_test = split_datasets(df_dict)

str_fn = lambda x: f"sentence: {x['sentence']}\nclassification: {x['classification']}"
dict_examples = construct_example_strings(dict_train, str_fn=str_fn)

In [137]:
# print a few examples from dict_examples
for dataset, examples in dict_examples.items():
    print(f"Dataset: {dataset}")
    print("Examples:")
    for example in examples[:3]:
        print(example)
    print()

Dataset: lowercase
Examples:
sentence: i like to watch movies on weekends.
classification: True
sentence: going to the park is always fun.
classification: True
sentence: the cat sleeps on the sofa.
classification: True

Dataset: sd_addition
Examples:
sentence: 2 + 6 = 8
classification: True
sentence: 7 + 5 = 8
classification: False
sentence: 9 + 4 = 13
classification: True

Dataset: dd_addition
Examples:
sentence: 71 + 62 = 69
classification: False
sentence: 70 + 72 = 142
classification: True
sentence: 82 + 78 = 160
classification: True

Dataset: simple_punctuation
Examples:
sentence: She listened to the radio while working.
classification: False
sentence: The toast burned in the toaster.
classification: False
sentence: She collected seashells at the beach!
classification: True

Dataset: medium_punctuation
Examples:
sentence: She listened to the radio while working!
classification: True
sentence: The toast burned in the toaster.
classification: False
sentence: She collected seashells a

# Config

In [166]:
# model_types = {"gpt-4": "gpt-4-1106-preview", "gpt-3":"gpt-3.5-turbo-1106"}

@dataclass
class RunArgs():
    n_test = 30
    n_examples = 50
    dataset_name = "pronoun"
    model = "gpt-3.5-turbo-1106"
    template = "templates/classification_4.yaml"
    randomize_inputs = True
    randomize_examples = False
    label_rename_map = {"True":"TypeA", "False":"TypeB"}
    explanation_template = "templates/explain_1.yaml"
args = RunArgs()

# Model

#### Utility Functions

In [167]:
def get_output_parser():
    response_schemas = [
        ResponseSchema(name="sentence", description="the sentence to classify"),
        ResponseSchema(name="classification", description="the classification")
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    return output_parser

def get_inputs(dataset_name: str, n_test: int, randomize: bool = False) -> pd.DataFrame:
    df_inputs = dict_test[dataset_name].copy()

    if randomize:
        df_inputs = df_inputs.sample(n_test)
    else: 
        df_inputs = df_inputs[:n_test]
    
    df_inputs.columns = ['sentence', 'classification']
    df_inputs['classification'].astype(bool)
    if args.label_rename_map:
        df_inputs.loc[:, 'classification'] = df_inputs['classification'].apply(lambda x: args.label_rename_map[str(x)])

    return df_inputs

def get_examples(dataset_name: str, n_examples: int, randomize_examples: bool = False, label_rename_map: Dict = None) -> pd.DataFrame:
    df_examples = dict_train[dataset_name].copy()

    if randomize_examples:
        df_examples = df_examples.sample(n_examples)
    else:
        df_examples = df_examples[:n_examples]
    
    df_examples.columns = ['sentence', 'classification']
    df_examples['classification'].astype(bool)

    if label_rename_map:
        df_examples['classification'] = df_examples['classification'].apply(lambda x: label_rename_map[str(x)])

    return df_examples

def json_formatted_examples(examples = pd.DataFrame) -> List[str]:
    json_examples = []
    for index, row in examples.iterrows():
        json_examples.append(f"{{sentence: {row['sentence']}, classification: {row['classification']}}}")


    # json_examples = examples.to_json(orient="records")
    return json_examples

def simple_formatted_examples(examples = pd.DataFrame) -> List[str]:
    simple_examples = []
    for index, row in examples.iterrows():
        simple_examples.append(f"{row['sentence']} [{row['classification']}]")
    return simple_examples

def numbered_formatted_inputs(inputs = pd.DataFrame) -> List[str]:
    numbered_inputs = []
    for i in range(len(inputs)):
        numbered_inputs.append(f"{i+1}. {inputs.iloc[i]['sentence']}")
    return numbered_inputs

def set_a_or_b_examples(examples = pd.DataFrame) -> List[str]:
    sets = {"Set_A": [], "Set_B": []}
    for _, row in examples.iterrows():
        if row['classification'] == "Set_A":
            sets["Set_A"].append(row['sentence'])
        else:
            sets["Set_B"].append(row['sentence'])
    # print(sets['Set_A'])
    return sets['Set_A'], sets['Set_B']

def prompt_str(template, inputs):
    return template.invoke(inputs).to_string()

def construct_explanation_inputs(examples, inputs):
    explanation_inputs = {"examples": examples, "inputs": inputs}
    return explanation_inputs
    

#### Components

In [168]:
# model
model = ChatOpenAI(model=args.model)

# prompt
classification_template = load_prompt(args.template)

# output parser
output_parser = get_output_parser()

# Run Single Chains

In [169]:
examples = get_examples(
    args.dataset_name, 
    args.n_examples, 
    randomize_examples=args.randomize_examples,
    label_rename_map=args.label_rename_map
)

# print some examples
print_examples = simple_formatted_examples(examples)
for i in range(10):
    print(print_examples[i])

The cake was delicious. [TypeB]
Her favorite color is blue. [TypeA]
You did a great job on the project. [TypeA]
You should try to get enough sleep. [TypeA]
Fish live in water. [TypeB]
Neptune is known for its beautiful blue color. [TypeB]
She has a meeting scheduled for tomorrow. [TypeA]
He can play the guitar beautifully. [TypeA]
Mount Everest is the highest mountain in the world. [TypeB]
Reading enhances your vocabulary. [TypeB]


In [170]:
inputs = get_inputs(
    args.dataset_name, 
    args.n_test, 
    randomize=args.randomize_inputs
)

In [143]:
label_chain = (
    classification_template
    | model
    # | output_parser if "format_instructions" in str(classification_template) else RunnablePassthrough()
)

In [144]:
# mixed
if "format_instructions" in str(classification_template):
    examples = json_formatted_examples(examples)
    examples = "\n".join(examples)
    chain_inputs = {"input": None, "examples": examples, "format_instructions": output_parser.get_format_instructions()}
else:
    examples = simple_formatted_examples(examples)
    examples = "\n".join(examples)
    chain_inputs = {"input": None, "examples": examples}

# grouped into sets
if "set" in args.template:
    set_a, set_b = set_a_or_b_examples(examples)
    chain_inputs = {"input": None, "set_a": set_a, "set_b": set_b}

In [145]:
print(prompt_str(classification_template, chain_inputs))

INSTRUCTIONS: 
Classify the sentences as TypeA or TypeB based on their similarity to the examples provided.

EXAMPLES:
The examples below are references for correct classifications of the sentences
The cake was delicious. [TypeB]
Her favorite color is blue. [TypeA]
You did a great job on the project. [TypeA]
You should try to get enough sleep. [TypeA]
Fish live in water. [TypeB]
Neptune is known for its beautiful blue color. [TypeB]
She has a meeting scheduled for tomorrow. [TypeA]
He can play the guitar beautifully. [TypeA]
Mount Everest is the highest mountain in the world. [TypeB]
Reading enhances your vocabulary. [TypeB]
Light travels faster than sound. [TypeB]
We'll have to reschedule the meeting. [TypeA]
Whales are the largest mammals on Earth. [TypeB]
The book was written by a famous author. [TypeB]
Learning a new language can be challenging. [TypeB]
You're doing an excellent job. [TypeA]
He is looking forward to the trip. [TypeA]
We're going to the museum next weekend. [TypeA]


#### Classify Multiple

In [160]:
input = "\n".join(numbered_formatted_inputs(inputs))
chain_inputs = {"input": input, "examples": examples}

out = label_chain.invoke(chain_inputs)

KeyboardInterrupt: 

In [None]:
if 'json' in out.content:
    json_str = out.content[7:-3]

# parse json
json_dict = json.loads(json_str)

# create dataframe from json
df = pd.DataFrame.from_dict(json_dict)
df.rename(columns={"classification": "prediction"}, inplace=True)

# change values in prediciton column to boolean
# df['prediction'] = df['prediction'].apply(lambda x: True if x == "TypeA" else False)



In [149]:
results = pd.merge(df, inputs, on='sentence', how='inner')

In [150]:
correct = results['classification'] == results['prediction']
print(f"percent correct: {correct.mean()}")

percent correct: 0.9666666666666667


#### Classify One

In [151]:
def classify_one():
    results = inputs.copy()

    _print = True

    for index, row in tqdm(inputs.iterrows()):
        chain_inputs['input'] = row['sentence']
        prediction = label_chain.invoke(chain_inputs)
        print(prediction)

        # if _print:
        #     print(f"Input: {row['sentence']}")
        #     print(f"Ground Truth: {row['classification']}")
        #     print("Prediction: " + prediction + "\n")

        # update results
        # results.loc[index, 'prediction'] = 'a'

# Run Experiments

In [197]:
def get_data_as_str(dataset: str, n_examples: int, n_test: int = 30, randomize_examples: bool = False, 
             randomize_inputs: bool = False, label_rename_map: Dict = {"True": "TypeA", "False": "TypeB"}):
    # get examples
    examples = get_examples(
        dataset, 
        n_examples, 
        randomize_examples=randomize_examples,
        label_rename_map=label_rename_map
    )

    #get inputs
    inputs = get_inputs(
        dataset, 
        n_test, 
        randomize=randomize_inputs
    )
    input = "\n".join(numbered_formatted_inputs(inputs))

    return input, examples

def classify(chain, input: str, examples: str):
    chain_inputs = {"input": input, "examples": examples}

    out = chain.invoke(chain_inputs)
    return out

def process_output(out):
    if 'json' in out:
        json_str = out[out.index('['):out.index(']')+1]

    # parse json
    json_dict = json.loads(json_str)

    # create dataframe from json
    df = pd.DataFrame.from_dict(json_dict)
    df.rename(columns={"classification": "prediction"}, inplace=True)

    results = pd.merge(df, inputs, on='sentence', how='inner')

    return results

def score_results(results: pd.DataFrame):
    correct = results['classification'] == results['prediction']
    print(f"percent correct: {correct.mean()}")
    return correct.mean()

In [198]:
# input_str, examples_str = get_data_as_str("pronoun", 5, 10, randomize_examples=True, randomize_inputs=True)
# results = classify(label_chain, input_str, examples_str)

In [209]:
import collections

out = None
experiments = collections.defaultdict(list)
for dataset in ['sd_addition', 'dd_addition']:
    print(dataset)
    for n_examples in tqdm([10, 20, 40, 80]):
        input_str, examples_str = get_data_as_str("pronoun", n_examples=n_examples, n_test=30, randomize_examples=True, randomize_inputs=True)
        out = classify(label_chain, input_str, examples_str)
        results = process_output(out.content)
        experiments[dataset].append((results, (results['classification'] == results['prediction']).mean()))

In [216]:
# experiments['sd_addition']
get_scores = lambda x: [y[1] for y in x]
get_scores(experiments['dd_addition'])

[0.9333333333333333, 0.9666666666666667, 0.9, 0.9666666666666667]

# Explanation

In [1348]:
explain_args = RunArgs()
explain_args.dataset_name = "backpack"
explain_args.explanation_template = "templates/explain_1.yaml"
explain_args.n_test = 50

In [1349]:
explanation_template = load_prompt(explain_args.explanation_template)

explanation_chain = (
    explanation_template
    | model
    | StrOutputParser()
)

In [1350]:
examples = get_examples(
    explain_args.dataset_name, 
    explain_args.n_examples, 
    randomize_examples=explain_args.randomize_examples,
    label_rename_map=explain_args.label_rename_map
)
inputs = get_inputs(
    explain_args.dataset_name, 
    explain_args.n_test, 
    randomize=explain_args.randomize_inputs
)
explanation_inputs = construct_explanation_inputs(examples, inputs)

KeyError: ' "True"'

In [1346]:
explanation_inputs = {"examples": examples}
for s in explanation_chain.stream(explanation_inputs):
    print(s, end="")

A sentence is classified as `Correct` if the sum of the two numbers is mathematically accurate.

In [1290]:
explanation_inputs = {"examples": examples}
explanation_chain.invoke(explanation_inputs)

KeyboardInterrupt: 