In [1]:
from llm_explain.llm.propose import propose, create_proposer_diff_prompt_body
from llm_explain.llm.validate import validate, create_prompt_body
from llm_explain.models.diff import explain_diff, ExplainDiffResult

In [2]:
# Toy dataset: three animal words form the negative class (label False) and
# three produce words form the positive class (label True).
X = ["cat", "dog", "fish"] + ["carrot", "potato", "apple"]
Y = [False] * 3 + [True] * 3

### Proposer Prompt

In [3]:
# Render the text prompt the proposer would receive for the toy dataset,
# asking for at most 3 candidate predicates, and print it for inspection.
proposer_prompt = create_proposer_diff_prompt_body(X, Y, num_explanations=3, precise=False)
print(proposer_prompt)


Here are two sets of text x_samples.

Some x_samples from the negative class:
Negative class sample.0: cat
Negative class sample.1: dog
Negative class sample.2: fish

Some x_samples from the positive class:
Positive class sample.0: carrot
Positive class sample.1: potato
Positive class sample.2: apple

We want to understand what kind of text x_samples are more likely to be in the positive class. Please suggest me at most 3 descriptions. Each of them needs to be a predicate about a text.. Each predicate needs to be wrapped in <predicate> tags.

Here are some example predicates:
<predicate>uses double negation</predicate>
<predicate>has a conservative stance</predicate>

Just output the predicates surrounded by <predicate> tags. Do not include any other text.



### Proposed Explanations $\phi$

In [4]:
# Ask the proposer for up to 3 candidate explanations of what separates the
# positive class from the negative class, then list them one per line.
explanations = propose(
    x_samples=X,
    y=Y,
    task_name="diff",
    num_explanations=3,
    precise=False,
)
for e in explanations:
    print("-", e)

- is a type of vegetable
- is a type of fruit
- is related to food


### Validator Prompt

In [5]:
# Pair the first sample with the first proposed explanation and print the
# validator prompt that is built for that (predicate, sample) pair.
one_sample, e = X[0], explanations[0]
prompt = create_prompt_body(predicate=e, x_sample=one_sample)
print("Prompt:", prompt)

Prompt: 
Your job is to validate whether an x_sample surrounded by <x_sample> tags satisfies a predicate surrounds by <predicate> tags. Your output should be a yes or no.


<predicate>has a positive sentiment</predicate>
<x_sample>this movie is bad</x_sample>
<answer>no</answer>

<predicate>contains a green object</predicate>
<x_sample>the frog is climbing a tree</x_sample>
<answer>yes</answer>


Now validate the following.

<predicate>is a type of vegetable</predicate>
<x_sample>cat</x_sample>

Just output the answer surrounded by <answer> tags.


### Results of the validation $[[\phi]](x)$

In [6]:
# Run the validator on the same (predicate, sample) pair used for the prompt
# above. The recorded output for this pair is 0 -- presumably meaning the
# predicate does not hold for the sample; confirm against `validate`.
ans = validate(predicate=e, x_sample=one_sample)
print("Answer:", ans)

Answer: 0


For the full algorithm, see the implementation of `explain_diff`. Running it on the toy data above gives the following results.

In [7]:
# Run the end-to-end pipeline on the toy dataset: 2 proposer rounds with
# 3 explanations requested per round. Printing the result lists each
# explanation together with its accuracy.
result: ExplainDiffResult = explain_diff(
    X=X,
    Y=Y,
    proposer_num_rounds=2,
    proposer_num_explanations_per_round=3,
    proposer_precise=False,
)
print(result)


Explanation: is a type of fruit or vegetable
Accuracy: 1.0

Explanation: is consumable by humans
Accuracy: 0.8333333333333333

Explanation: grows in a garden or farm
Accuracy: 1.0

Explanation: is a type of vegetable or fruit
Accuracy: 1.0

Explanation: is related to food or agriculture
Accuracy: 0.8333333333333333

Explanation: is something that grows in soil
Accuracy: 1.0


