# Transitive inference

One way to do this is to intermix a bunch of transitive chains in the list:

a > b

x > y

b > c

y > z

Another way is to have a single transitive chain (a>b>c>d>...) and present each of its adjacent relations once, in a random order. Currently, this script is implemented THIS way.

In [74]:
# Number of words sampled for the transitive chain; also bounds the index
# loops that enumerate probe pairs in make_transitive_inference_probes.
transitive_chain_length = 10 # A > B > C > D > E ...
# Size of each half of the interjection (distractor) word list: twice this many
# words are sampled, then the first half is paired with the second half.
interjection_length = 200

import random
import numpy as np
import matplotlib.pyplot as plt
from helpers import * # general helper functions
%reload_ext autoreload
%autoreload 2

## task-specific helper functions

In [75]:
def make_transitive_inference_probes(words):
    """Build the shuffled test probes for all non-adjacent word pairs.

    Every pair of positions (i, j) with j >= i + 2 inside the first
    ``transitive_chain_length`` entries of ``words`` becomes one probe, so
    pairs that are adjacent in the chain are left out.

    Args:
        words: The chain words, ordered from biggest to smallest.

    Returns:
        A tuple ``(test_probes, bigger, smaller)``: the probes produced by
        ``make_pairs`` with a "?" joiner, plus the bigger and smaller word of
        each probe, all in the same shuffled order.
    """
    # Seeds NumPy's *global* generator so the probe order is reproducible.
    # NOTE(review): make_pairs(randomize=True) may also draw from this
    # generator -- confirm before replacing it with a local one.
    np.random.seed(0)

    chain_pairs = [
        (words[hi], words[lo])
        for hi in range(transitive_chain_length)
        for lo in range(hi + 2, transitive_chain_length)  # gap of 2+ skips adjacent pairs
    ]

    order = np.random.permutation(len(chain_pairs))
    bigger = [chain_pairs[k][0] for k in order]
    smaller = [chain_pairs[k][1] for k in order]

    # e.g. "B ? E"; presumably randomize=True varies which word is on the left.
    test_probes = make_pairs(bigger, smaller, joiner="?", randomize=True)
    return test_probes, bigger, smaller


def plot_results(bigger, smaller, output, distance=None):
    """Plot response rates for the transitive-inference probes.

    Draws a bar chart with the fraction of responses that equal the bigger
    word (correct), equal the smaller word (incorrect), or match neither.
    When ``distance`` is given, a second panel shows the accuracy for each
    distinct distance value.

    Args:
        bigger: For each probe, the word that is bigger in the chain.
        smaller: For each probe, the word that is smaller in the chain.
        output: The responses to score. Response ``i`` is compared with
            ``bigger[i]`` and ``smaller[i]``; only the first ``len(output)``
            probes are scored.
        distance: Optional sequence with one distance value per response.
            ``None`` or an empty sequence draws the bar chart only.

    Raises:
        IndexError: If there are more responses than probes, or if
            ``distance`` is non-empty and its length differs from the number
            of responses.
    """
    bigger = np.array(bigger)
    smaller = np.array(smaller)
    output = np.array(output)
    # None replaces the former mutable ``[]`` default; both mean "no distances".
    distance = np.array([]) if distance is None else np.array(distance)

    n_scored = len(output)
    if n_scored > min(len(bigger), len(smaller)):
        raise IndexError(
            f"got {n_scored} responses but only "
            f"{min(len(bigger), len(smaller))} probes to score them against"
        )
    has_distance = len(distance) > 0
    if has_distance and len(distance) != n_scored:
        raise IndexError(
            f"distance has {len(distance)} entries but there are "
            f"{n_scored} responses"
        )

    # Score each response once; the masks are reused for the per-distance panel.
    is_correct = np.array(
        [resp == big for resp, big in zip(output, bigger)], dtype=bool
    )
    is_incorrect = np.array(
        [resp == small for resp, small in zip(output, smaller)], dtype=bool
    )

    correct = np.mean(is_correct)
    incorrect = np.mean(is_incorrect)
    # Responses that are neither probe word (named to avoid shadowing `random`).
    off_target = np.mean(~(is_correct | is_incorrect))

    dist_x, dist_y = [], []
    if has_distance:
        for dist in np.unique(distance):  # np.unique returns sorted values
            dist_x.append(dist)
            dist_y.append(np.mean(is_correct[distance == dist]))

    _, ax = plt.subplots(1, 2 if has_distance else 1)
    ax = np.atleast_1d(ax)  # ensures ax is always indexable
    ax[0].bar([1, 2, 3], [correct, incorrect, off_target])
    ax[0].set_xticks(
        [1, 2, 3],
        labels=['Correct', 'Incorrect', 'Random response'],
        rotation=30,
    )
    # Values are fractions in [0, 1], not percentages.
    ax[0].set_ylabel('Proportion of responses')
    ax[0].set_ylim(0, 1)

    if has_distance:
        ax[1].plot(dist_x, dist_y)
        ax[1].set_xlabel('Distance in transitive chain')
        ax[1].set_ylabel('Accuracy')
        ax[1].set_ylim(0, 1)

## define task

In [76]:
# Load the word pool that both the chain and the interjection words are drawn from.
all_words = txt_to_list('_data/wasnorm_wordpool.txt')

words = subsample_words(all_words, transitive_chain_length, seed=0) # each stimulus is one word in the chain

# Study list: each word paired with its successor (words[k] > words[k+1]).
adjacent_pairs = make_pairs(words[:-1], words[1:], joiner = '>') # pair up in chain
# Fixed seed so the presentation order of the relations is reproducible.
random.seed(0)
random.shuffle(adjacent_pairs)

# Probes cover only non-adjacent chain positions (index gap of at least 2).
test_probes, bigger, smaller = make_transitive_inference_probes(words) # query on pairs of words at least 1 link away

# Distractor relations: 2 * interjection_length words sampled with avoid=words
# (presumably excluding the chain words -- confirm in helpers), first half
# paired with the second half.
interjection_words = subsample_words(all_words, interjection_length*2, seed=0, avoid=words)
interjection_pairs = make_pairs(interjection_words[:interjection_length], 
                                interjection_words[interjection_length:],
                                joiner = '>')

In [83]:
# Test instructions shared verbatim by all three prompt variants.
test_cue = (
    "Now we'll begin the test. "
    "Here is a list of test probes; each one has two target words from the original set with a ? in between. "
    "Your task is to select the bigger word (>). "
    "Please output a list of your chosen bigger targets (one for each in the list of probes). Do not respond with any word that is not one of the possible targets. "
    "Do not think step by step, only output the list of chosen targets. Here are the test probes:"
)

# Variant 1: the study list is introduced without mentioning the upcoming test.
ti_inp = make_inp(
    adjacent_pairs,
    test_probes,
    preamble="Here is a list of relations between words:",
    cue=test_cue,
)
print(ti_inp)

# Variant 2: the preamble announces the inference test in advance.
informed_preamble = (
    "Here is a list of relations between words. "
    "After seeing this list, you will be tested on inferential relations between words in this original set. "
    "Here are the word relations:"
)
ti_inp_informed = make_inp(
    adjacent_pairs,
    test_probes,
    preamble=informed_preamble,
    cue=test_cue,
)
print(ti_inp_informed)

# Variant 3: as variant 2, plus a second list of relations passed as the interjection.
interjected_preamble = (
    "Here is a list of relations between words. "
    "After seeing this list, you will see another list of relations. "
    "Then, you will be tested on inferential relations between words in this original set. "
    "Here are the word relations:"
)
ti_inp_informed_interjected = make_inp(
    adjacent_pairs,
    test_probes,
    preamble=interjected_preamble,
    cue=test_cue,
    interjection="You will now see another list of word relations. Later, during the test, you should not base your responses on the relations or words in this list:",
    interjection_words=interjection_pairs,
)
print(ti_inp_informed_interjected)

Here is a list of relations between words: list>whistle, perch>ozone, vehicle>man, bank>employee, employee>perch, man>bank, knuckle>vehicle, whistle>garbage, ozone>list. Now we'll begin the test. Here is a list of test probes; each one has two target words from the original set with a ? in between. Your task is to select the bigger word (>). Please output a list of your chosen bigger targets (one for each in the list of probes). Do not respond with any word that is not one of the possible targets. Do not think step by step, only output the list of chosen targets. Here are the test probes:perch?whistle, man?garbage, man?perch, perch?list, ozone?bank, man?employee, perch?vehicle, knuckle?employee, vehicle?ozone, garbage?employee, employee?list, list?garbage, ozone?whistle, employee?whistle, garbage?perch, bank?vehicle, whistle?vehicle, knuckle?list, man?ozone, vehicle?garbage, knuckle?garbage, employee?ozone, bank?knuckle, vehicle?list, garbage?bank, bank?whistle, whistle?knuckle, bank?l

## run task

In [None]:
# Build the model handle that query_model is called with below, via the
# project helper make_pipe (presumably loads the named model -- see helpers).
pipe = make_pipe(model_id = "Qwen/Qwen3-4B-Instruct-2507")

In [None]:
# Chain distance of each probe: how many positions separate its two words.
# (Previously the interjection word list was passed as `distance`, which has
# one entry per distractor word rather than one per probe.)
chain_position = {word: idx for idx, word in enumerate(words)}
probe_distance = [
    chain_position[small] - chain_position[big]
    for big, small in zip(bigger, smaller)
]

# Query the model with each prompt variant and plot its accuracy.
for inp in [ti_inp, ti_inp_informed, ti_inp_informed_interjected]:
    print(inp)
    output = query_model(pipe, inp)
    plot_results(bigger, smaller, output, probe_distance)