In [1]:
from utils import generate_few_shot, string_generator, wcst_generator

This version of the Wisconsin Card Sorting Test (WCST) for LLMs provides three task variants:
1. card: The variant most similar to the original WCST; the LLM is asked to match cards by color, shape, or number of shapes.
2. card-random: Similar to card, but the order of the attributes in each card description is shuffled.
3. string: The LLM is asked to match random strings by their length, number of vowels, or number of consonants.
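
To make the string variant's matching rules concrete, here is a minimal sketch of how the three attributes could be computed for a single string. The helper `string_features` is hypothetical and is not part of `utils`; treating only alphabetic non-vowels as consonants is also an assumption.

```python
VOWELS = set("aeiouAEIOU")

def string_features(s: str) -> dict:
    """Return the three attributes the string variant matches on (illustrative only)."""
    num_vowels = sum(ch in VOWELS for ch in s)
    # Assumption: only alphabetic characters that are not vowels count as consonants.
    num_consonants = sum(ch.isalpha() and ch not in VOWELS for ch in s)
    return {"length": len(s), "vowels": num_vowels, "consonants": num_consonants}

print(string_features("kaeto"))
# {'length': 5, 'vowels': 3, 'consonants': 2}
```

Under a "vowels" rule, for example, "kaeto" would match any option that also contains three vowels, regardless of its length.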

Queries are generated using the corresponding generator functions, as shown below:

In [2]:
output = wcst_generator("color")
print(output)                                   # generator outputs a tuple

given, options = output
print(f"Given: {given}")                        # first element is the given card/string
print("Options: ")
for i, option in enumerate(options):            # second element is the options
    print(f"{i+1}. {option}")
print(f"Correct Answer: {options[0]}")          # Correct answer is always the first option

('two green circle', ['four green triangle', 'one blue star', 'one blue triangle', 'two yellow star'])
Given: two green circle
Options: 
1. four green triangle
2. one blue star
3. one blue triangle
4. two yellow star
Correct Answer: four green triangle
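
Because the generator always places the correct answer first, the options need to be shuffled before they are shown to a model. The following is a rough sketch of one way to do that while keeping track of the correct position; `shuffle_options` is a hypothetical helper and may not match how the repository does it.

```python
import random

def shuffle_options(options, rng=random):
    """Shuffle a copy of the options; return it with the 1-based index of the correct one."""
    correct = options[0]          # generator convention: the first option is correct
    shuffled = list(options)      # copy so the generator output is left untouched
    rng.shuffle(shuffled)
    return shuffled, shuffled.index(correct) + 1

options = ["four green triangle", "one blue star", "one blue triangle", "two yellow star"]
shuffled, answer = shuffle_options(options, random.Random(0))
for i, option in enumerate(shuffled, start=1):
    print(f"{i}. {option}")
print(f"ANSWER: {answer}")
```

Passing a seeded `random.Random` instance makes the ordering reproducible, which can help when debugging a run.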


A preview of a trial can be generated using the generate_few_shot() function. Note that the order of the options is randomized each time to prevent contamination.

In [4]:
print(generate_few_shot("card"))

Example of a short session:
Given: one green triangle
Options:
1. four green circle
2. four red cross
3. three red star
4. one red triangle

ANSWER: 3

Incorrect. Please try again.
Given: one green triangle
Options:
1. four green circle
2. four red cross
3. three red star
4. one red triangle

ANSWER: 4

Correct!
Given: one green cross
Options:
1. one blue star
2. four green circle
3. four green cross
4. three red star

ANSWER: 1

Correct!
Given: one blue circle
Options:
1. four blue triangle
2. four red star
3. three yellow triangle
4. one blue triangle

ANSWER: 4

Correct!
Given: two green star
Options:
1. three red triangle
2. three blue triangle
3. two yellow circle
4. one red circle

ANSWER: 3

Correct!
Given: two yellow triangle
Options:
1. four green circle
2. three red triangle
3. three yellow star
4. three green circle

ANSWER: 1

Incorrect. Please try again.
Given: two yellow triangle
Options:
1. four green circle
2. three red triangle
3. three yellow star
4. three green circle
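
The feedback pattern in the session above (a wrong answer repeats the same query, a correct answer moves on) can be sketched as a small loop. This is an illustration under assumptions, not the code in `wcst.py`: `run_rule_section` and `ask_model` are hypothetical names, a wrong answer is assumed to reset the streak, and every attempt is assumed to count toward the query limit.

```python
def run_rule_section(queries, ask_model, num_correct=5, max_trials=64):
    """Run queries until `num_correct` consecutive correct answers or `max_trials` attempts.

    queries:   iterable of (given, options, correct_index) tuples, 1-based index
    ask_model: callable(given, options, feedback) -> int, a stand-in for an LLM call
    Returns (completed, attempts_used).
    """
    streak, attempts, feedback = 0, 0, None
    for given, options, correct_index in queries:
        while attempts < max_trials:
            attempts += 1
            answer = ask_model(given, options, feedback)
            if answer == correct_index:
                streak += 1
                feedback = "Correct!"
                break
            # Assumption: a wrong answer resets the streak and repeats the same query.
            streak = 0
            feedback = "Incorrect. Please try again."
        if streak >= num_correct:
            return True, attempts
        if attempts >= max_trials:
            return False, attempts
    return False, attempts

# Usage with a dummy "model" that always picks option 1.
queries = [("one green triangle", ["one red triangle", "four red cross"], 1)] * 10
print(run_rule_section(queries, lambda given, options, feedback: 1))
# (True, 5)
```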

Finally, the main evaluation can be run via the CLI:

`python wcst.py`

Currently, we support "llama" and "gemini" as models, which correspond to Llama 3.1 8B-Instruct and Gemini 1.5 Flash, respectively.

The parameters for the CLI are:
* --model       : "llama" or "gemini", default: "llama"; model to be used
* --variant     : "card", "card-random", or "string", default: "card"; task variant to run
* --max_trials  : int, default: 64; maximum number of queries in a trial
* --num_correct : int, default: 5; number of consecutive correct guesses needed to complete a rule section
* --repeats     : int, default: 1; number of trials to run
* --few_shot    : bool, default: False; whether to use few-shot prompting
* --cot         : bool, default: False; whether to use chain-of-thought prompting
* --verbose     : int, default: 15; how often model output is printed
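
For reference, the parameters listed above could be declared with `argparse` roughly as follows. This is only a sketch of an equivalent parser, not the actual contents of `wcst.py`. In particular, treating `--few_shot` and `--cot` as on/off switches is an assumption; the real script may expect an explicit value instead.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="WCST evaluation (illustrative sketch)")
    parser.add_argument("--model", choices=["llama", "gemini"], default="llama")
    parser.add_argument("--variant", choices=["card", "card-random", "string"], default="card")
    parser.add_argument("--max_trials", type=int, default=64)
    parser.add_argument("--num_correct", type=int, default=5)
    parser.add_argument("--repeats", type=int, default=1)
    # Assumption: the boolean options are plain switches that default to False.
    parser.add_argument("--few_shot", action="store_true")
    parser.add_argument("--cot", action="store_true")
    parser.add_argument("--verbose", type=int, default=15)
    return parser

# Parse an explicit argument list instead of sys.argv so the example runs anywhere.
args = build_parser().parse_args(["--model", "gemini", "--variant", "string", "--cot"])
print(args.model, args.variant, args.cot, args.few_shot, args.max_trials)
# gemini string True False 64
```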