In [1]:
from utils import generate_few_shot, string_generator, wcst_generator

This version of the WCST for LLMs provides three different tasks:
1. card: the closest to the original WCST; LLMs are asked to match cards by color, shape, or number of shapes.
2. card-random: same as card, but the order of the attributes in each card description is shuffled.
3. string: LLMs are asked to match random strings by length, number of vowels, or number of consonants.
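To make the three matching rules concrete, here is a minimal sketch of how an option could be checked against the given card or string under each rule. The helper names (`card_attribute`, `string_attribute`, `matches`), the rule identifiers other than `"color"`, and the card vocabularies (read off the sample outputs in this notebook) are assumptions for illustration, not part of `utils`.

```python
# Hypothetical helpers; not the utils API. Only the "color" rule name is
# confirmed by wcst_generator("color"); "number", "shape", "length",
# "vowels", and "consonants" are assumed identifiers.
CARD_VOCAB = {
    "number": {"one", "two", "three", "four"},
    "color": {"red", "green", "blue", "yellow"},
    "shape": {"circle", "triangle", "star", "cross"},
}
VOWELS = set("aeiou")  # assumption: 'y' is counted as a consonant


def card_attribute(card: str, rule: str) -> str:
    """Return the word in `card` that belongs to the vocabulary for `rule`.

    Words are looked up by vocabulary rather than by position, so this also
    works for card-random, where the attribute order is shuffled."""
    for word in card.split():
        if word in CARD_VOCAB[rule]:
            return word
    raise ValueError(f"no {rule} attribute in {card!r}")


def string_attribute(s: str, rule: str) -> int:
    """Return the feature of `s` selected by `rule`."""
    letters = [c for c in s.lower() if c.isalpha()]
    if rule == "length":
        return len(s)
    if rule == "vowels":
        return sum(c in VOWELS for c in letters)
    if rule == "consonants":
        return sum(c not in VOWELS for c in letters)
    raise ValueError(f"unknown rule {rule!r}")


def matches(given: str, option: str, rule: str, task: str = "card") -> bool:
    """True if `option` shares the attribute selected by `rule` with `given`."""
    attribute = card_attribute if task.startswith("card") else string_attribute
    return attribute(given, rule) == attribute(option, rule)


# Given card and options taken from the wcst_generator("color") output in this notebook.
given = "three green circle"
options = ["two green triangle", "three blue star", "four yellow cross", "one red circle"]
for rule in ("color", "number", "shape"):
    print(rule, [matches(given, o, rule) for o in options])
# color [True, False, False, False]
# number [False, True, False, False]
# shape [False, False, False, True]
```

In this sample each option matches the given card under exactly one rule, which is what lets the model's choices reveal which rule it is currently applying.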

Queries are generated with the corresponding generator function, as shown below:

In [2]:
output = wcst_generator("color")
print(output)                                   # generator outputs a tuple

given, options = output
print(f"Given: {given}")                        # first element is the given card/string
print("Options: ")
for i, option in enumerate(options):            # second element is the options
    print(f"{i+1}. {option}")
print(f"Correct Answer: {options[0]}")          # Correct answer is always the first option

('three green circle', ['two green triangle', 'three blue star', 'four yellow cross', 'one red circle'])
Given: three green circle
Options: 
1. two green triangle
2. three blue star
3. four yellow cross
4. one red circle
Correct Answer: two green triangle
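Because the generator always places the correct answer first, the options have to be shuffled before they are shown to a model, and the position of the correct answer has to be tracked. A minimal sketch of that step follows; `shuffle_options` is a hypothetical helper, not a function from `utils`.

```python
import random
from typing import List, Optional, Tuple


def shuffle_options(
    options: List[str], rng: Optional[random.Random] = None
) -> Tuple[List[str], int]:
    """Return a shuffled copy of `options` and the 1-based position of the correct one.

    Assumes, as in wcst_generator's output, that options[0] is the correct answer."""
    rng = rng if rng is not None else random.Random()
    order = list(range(len(options)))
    rng.shuffle(order)
    shuffled = [options[i] for i in order]
    answer = order.index(0) + 1  # where the original first (correct) option landed
    return shuffled, answer


options = ["two green triangle", "three blue star", "four yellow cross", "one red circle"]
shuffled, answer = shuffle_options(options, random.Random(0))
for i, option in enumerate(shuffled, start=1):
    print(f"{i}. {option}")
print(f"Correct Answer: {answer}")
```

Shuffling a list of indices rather than the option strings themselves leaves the input list untouched and makes recovering the correct position a single lookup.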


A preview of a trial can be generated with the generate_few_shot() function. Note that the order of the options is randomized each time to prevent contamination.

In [3]:
print(generate_few_shot("card"))

Example of a short session:
Given: two red star
Options:
1. one red circle
2. two green triangle
3. four yellow cross
4. three blue star

ANSWER: 3

Incorrect. Please try again.
Given: two red star
Options:
1. one red circle
2. two green triangle
3. four yellow cross
4. three blue star

ANSWER: 2

Correct!
Given: three red circle
Options:
1. three blue star
2. four yellow cross
3. two green triangle
4. one red circle

ANSWER: 1

Correct!
Given: four green circle
Options:
1. three blue star
2. two green triangle
3. one red circle
4. four yellow cross

ANSWER: 4

Correct!
Given: two red cross
Options:
1. two green triangle
2. three blue star
3. four yellow cross
4. one red circle

ANSWER: 1

Correct!
Given: three green star
Options:
1. four yellow cross
2. one red circle
3. two green triangle
4. three blue star

ANSWER: 1

Incorrect. Please try again.
Given: three green star
Options:
1. four yellow cross
2. one red circle
3. two green triangle
4. three blue star

ANSWER: 3

Correct!
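The preview above shows the prompt layout: a `Given:` line, numbered `Options:`, an `ANSWER: n` reply, and `Correct!` / `Incorrect. Please try again.` feedback. The sketch below renders a query in that layout and extracts the chosen option from a reply. The helper names and the choice to take the last `ANSWER:` match (useful when chain-of-thought reasoning precedes the answer) are assumptions; `wcst.py` may parse replies differently.

```python
import re
from typing import List, Optional


def format_query(given: str, options: List[str]) -> str:
    """Render a query in the same layout as the generate_few_shot() preview."""
    lines = [f"Given: {given}", "Options:"]
    lines += [f"{i}. {opt}" for i, opt in enumerate(options, start=1)]
    return "\n".join(lines)


ANSWER_RE = re.compile(r"ANSWER:\s*(\d+)")


def parse_answer(reply: str) -> Optional[int]:
    """Return the last 'ANSWER: n' choice in a model reply, or None if absent."""
    found = ANSWER_RE.findall(reply)
    return int(found[-1]) if found else None


def feedback(correct: bool) -> str:
    """Feedback strings as they appear in the preview."""
    return "Correct!" if correct else "Incorrect. Please try again."


print(format_query("two red star", ["one red circle", "two green triangle"]))
print(parse_answer("The number matches option 2.\nANSWER: 2"))
print(feedback(False))
```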

Finally, the main evaluation can be run via the CLI:

`python wcst.py`

Currently, we support two models: "llama" (Llama 3.1 8B-Instruct) and "gemini" (Gemini 1.5 Flash).

The parameters for the CLI are:
* --model       : "llama" or "gemini", default: "llama"; model to use
* --variant     : "card", "card-random", or "string", default: "card"; task variant
* --max_trials  : int, default: 64; maximum number of queries in a trial
* --num_correct : int, default: 5; number of consecutive correct answers needed to complete a rule section
* --repeats     : int, default: 1; number of trials to run
* --few_shot    : bool, default: False; whether to use few-shot prompting
* --cot         : bool, default: False; whether to use chain-of-thought prompting
* --verbose     : int, default: 15; how often model output is printed
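To show how `--max_trials` and `--num_correct` interact, here is a minimal sketch of a single trial. It is not the loop in `wcst.py`: `run_trial`, its callback signatures, and the toy stand-ins are hypothetical, and the control flow rests on stated assumptions (rules are visited in order, every attempt counts as a query, a wrong answer re-presents the same query as in the preview above, and the trial ends after `max_trials` queries or once every rule section is complete).

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Query = Tuple[str, List[str]]  # (given, options) with the correct option first


def run_trial(
    make_query: Callable[[str], Query],
    answer: Callable[[str, List[str]], str],
    rules: List[str],
    max_trials: int = 64,
    num_correct: int = 5,
    seed: Optional[int] = None,
) -> Dict[str, int]:
    """Hypothetical WCST trial loop; not the code in wcst.py."""
    rng = random.Random(seed)
    rule_idx = streak = queries = errors = 0
    current: Optional[Tuple[str, List[str], str]] = None  # (given, shown options, correct)
    while queries < max_trials and rule_idx < len(rules):
        if current is None:
            given, options = make_query(rules[rule_idx])
            shown = rng.sample(options, len(options))  # hide the correct-first order
            current = (given, shown, options[0])
        given, shown, correct = current
        queries += 1
        if answer(given, shown) == correct:
            streak += 1
            current = None  # draw a fresh query next time
            if streak == num_correct:  # rule section complete: switch to the next rule
                rule_idx += 1
                streak = 0
        else:
            errors += 1
            streak = 0  # a wrong answer resets the consecutive-correct count
    return {"queries": queries, "errors": errors, "rules_completed": rule_idx}


# Toy stand-ins: one fixed query per rule and an oracle that always picks correctly.
def toy_query(rule: str) -> Query:
    return f"given-{rule}", [f"right-{rule}", "wrong-1", "wrong-2"]


def oracle(given: str, shown: List[str]) -> str:
    return next(o for o in shown if o.startswith("right"))


print(run_trial(toy_query, oracle, ["color", "shape", "number"]))
# {'queries': 15, 'errors': 0, 'rules_completed': 3}
```

With the defaults, a perfect model finishes three rule sections in 3 × 5 = 15 queries, well under the 64-query cap; the cap only binds for models that keep making errors.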

In [None]:
!python wcst.py