## Introduction

Here we build a __causal abstraction model__ ([Geiger*, Lu*, Icard, and Potts (2021)](https://arxiv.org/pdf/2106.02997.pdf)) for the SCAN dataset, in a format that can be used for __distributed alignment search__ ([Geiger*, Wu*, Potts, Icard, and Goodman (2023)](https://arxiv.org/pdf/2303.02536.pdf)) to align the model with a neural network.

The dataset is SCAN, and the goal is a causal model of how its commands map to action sequences.

In [2]:
import pyvene
from pyvene import CausalModel

### SCAN Task

Each example in the SCAN dataset converts a natural language command into a sequence of actions.

$$\text{InputCommand} \longrightarrow \text{OutputSequence}$$

Example:

$$\text{jump thrice} \longrightarrow \text{JUMP JUMP JUMP}$$

Unlike the original hierarchical equality task:

- The number of nodes is not fixed, because commands vary in length.
- The number of nodes grows in proportion to the number of sentence conjuncts (separated by `and`/`after`).
- Layer 0 has one node for each token in the command.
- Layer 1 splits the sequence into as many nodes as there are sentence conjuncts.
- Layer 2 resolves the `twice`/`thrice` modifiers.
- Layer 3 resolves the turns.
- Layer 4 holds the complete output.
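The layer scheme above can be sketched as a graph whose size depends on the command. The sketch below builds plain `parents` and `functions` dictionaries, loosely following the layout that the pyvene tutorials pass to `CausalModel` (that correspondence is an assumption, and nothing here calls pyvene). The node names (`T0`, `C0`, `R0`, `A0`, `O`) and the helpers `build_graph` and `run_forward` are hypothetical. The layer 3 function is simplified: it handles only `verb` and `verb left/right`, and the output node concatenates conjuncts in textual order, so the ordering of `after` is not modeled.

```python
CONJUNCTIONS = {"and", "after"}
REPEATS = {"twice": 2, "thrice": 3}
TURNS = {"left": "LTURN", "right": "RTURN"}

def conjunct_spans(tokens):
    """Return (start, end) token index ranges, one per conjunct."""
    spans, start = [], 0
    for i, tok in enumerate(tokens):
        if tok in CONJUNCTIONS:
            spans.append((start, i))
            start = i + 1
    spans.append((start, len(tokens)))
    return spans

def resolve_repeat(phrase):
    """Layer 2: split off a trailing twice/thrice as (phrase, count)."""
    if phrase[-1] in REPEATS:
        return phrase[:-1], REPEATS[phrase[-1]]
    return phrase, 1

def resolve_actions(repeated):
    """Layer 3 (simplified): handles 'verb' and 'verb left/right' only."""
    phrase, count = repeated
    turn = [TURNS[phrase[1]]] if len(phrase) > 1 else []
    act = [] if phrase[0] == "turn" else [phrase[0].upper()]
    return (turn + act) * count

def build_graph(command):
    """Build parents/functions dicts whose size depends on the command."""
    tokens = command.lower().split()
    parents, functions = {}, {}
    # Layer 0: one input node per token.
    token_nodes = [f"T{i}" for i in range(len(tokens))]
    for name in token_nodes:
        parents[name] = []
    # Layers 1-3: one node per conjunct at each layer.
    action_nodes = []
    for k, (start, end) in enumerate(conjunct_spans(tokens)):
        parents[f"C{k}"] = token_nodes[start:end]
        functions[f"C{k}"] = lambda *words: list(words)
        parents[f"R{k}"] = [f"C{k}"]
        functions[f"R{k}"] = resolve_repeat
        parents[f"A{k}"] = [f"R{k}"]
        functions[f"A{k}"] = resolve_actions
        action_nodes.append(f"A{k}")
    # Layer 4: concatenate conjunct outputs in textual order.
    parents["O"] = action_nodes
    functions["O"] = lambda *seqs: [a for seq in seqs for a in seq]
    inputs = dict(zip(token_nodes, tokens))
    return parents, functions, inputs

def run_forward(parents, functions, inputs):
    """Evaluate each node as soon as all of its parents have values."""
    values = dict(inputs)
    pending = [n for n in parents if n not in values]
    while pending:
        for node in list(pending):
            if all(p in values for p in parents[node]):
                values[node] = functions[node](*(values[p] for p in parents[node]))
                pending.remove(node)
    return values

parents, functions, inputs = build_graph("Jump thrice and turn left")
values = run_forward(parents, functions, inputs)
print(len(parents))  # 12
print(values["O"])   # ['JUMP', 'JUMP', 'JUMP', 'LTURN']
```

A shorter command such as `walk twice` yields only six nodes under this scheme, which illustrates why the node set cannot be fixed in advance.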

## Causal Model for SCAN

In [90]:
# defining vocabulary
C = ["and", "after"]                 # conjunctions
N = ["twice", "thrice"]              # repetition modifiers
D = ["turn", "left", "right"]        # turn and direction words
V = ["walk", "run", "jump", "look"]  # action verbs

In [73]:
# getting layer 0 (base) variables
command = "Jump thrice and turn left"
words = command.split()
variables_zero = [w.lower() for w in words]
print(variables_zero)

['jump', 'thrice', 'and', 'turn', 'left']


In [81]:
# getting layer 1 variables: split the command at its conjunction
seq1 = list(variables_zero)  # with no conjunction, the whole command is one conjunct
seq2 = []
for index, item in enumerate(variables_zero):
    if item in C:
        seq1 = variables_zero[:index]
        seq2 = variables_zero[index+1:]
        break
print(seq1)
print(seq2)
variables_one = [seq1, seq2]
print(variables_one)

['jump', 'thrice']
['turn', 'left']
[['jump', 'thrice'], ['turn', 'left']]


In [87]:
# getting layer 2 variables: expand "twice"/"thrice" within each conjunct
repeats = {"twice": 2, "thrice": 3}
variables_two = []
for item in variables_one:
    expanded = []
    for index, word in enumerate(item):
        if word in N:
            # repeat the word immediately before the repetition modifier
            expanded.extend([item[index-1]] * repeats[word])
    variables_two.append(expanded)
# seq3 belongs to the first conjunct, seq4 to the second
seq3, seq4 = variables_two
print(seq3)
print(seq4)
print(variables_two)

['jump', 'jump', 'jump']
[]
[['jump', 'jump', 'jump'], []]
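The cell above repeats only the word immediately before `twice`/`thrice`. In SCAN the modifier scopes over the whole preceding phrase (for example `turn left twice`), so a phrase-level variant may be closer to the task. The following is a minimal sketch of that variant; `expand_repetition` is a hypothetical helper, and unlike the cell above it passes conjuncts without a modifier through unchanged instead of leaving them empty.

```python
REPEATS = {"twice": 2, "thrice": 3}

def expand_repetition(conjunct):
    """Repeat the whole phrase that precedes a trailing twice/thrice."""
    if conjunct and conjunct[-1] in REPEATS:
        return conjunct[:-1] * REPEATS[conjunct[-1]]
    return list(conjunct)

example_conjuncts = [["jump", "thrice"], ["turn", "left", "twice"]]
expanded = [expand_repetition(c) for c in example_conjuncts]
print(expanded)
# [['jump', 'jump', 'jump'], ['turn', 'left', 'turn', 'left']]
```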


In [89]:
# getting layer 3 variables: resolve turns in conjuncts that layer 2 left empty
turn_action = {"left": "LTURN", "right": "RTURN"}
turn_repeats = {"opposite": 2, "around": 4}  # number of turns per modifier
seq5 = []
seq6 = []
for position, resolved in enumerate(variables_two):
    if resolved:
        continue  # conjunct already expanded in layer 2
    item = seq1 if position == 0 else seq2      # fall back to the layer 1 conjunct
    target = seq5 if position == 0 else seq6    # seq5: first conjunct, seq6: second
    for index, word in enumerate(item):
        if word == "turn":
            modifier = item[index+1]
            if modifier in turn_action:  # simple turn: "turn left"
                target.append(turn_action[modifier])
            elif modifier in turn_repeats:  # "turn opposite left", "turn around right"
                target.extend([turn_action[item[index+2]]] * turn_repeats[modifier])
        # the case where turn is not the first word in the sequence
        # is not handled yet
print(seq5)
print(seq6)

[]
['LTURN']
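Layer 4, the complete output, is not computed by the cells above. As a reference point for checking the layered model, the following sketch interprets a whole command in one pass. It assumes the standard SCAN semantics (`after` executes its second conjunct first, `opposite` turns twice before acting, `around` repeats turn-then-act four times) and reuses the `LTURN`/`RTURN` action names from this notebook. The function names `interpret_phrase` and `interpret` are hypothetical, and the sketch does not follow the layer-by-layer decomposition used above.

```python
TURN = {"left": "LTURN", "right": "RTURN"}
REPEATS = {"twice": 2, "thrice": 3}

def interpret_phrase(words):
    """Interpret one conjunct, e.g. ['jump', 'around', 'left', 'twice']."""
    count = 1
    if words[-1] in REPEATS:
        count = REPEATS[words[-1]]
        words = words[:-1]
    verb, rest = words[0], words[1:]
    act = [] if verb == "turn" else [verb.upper()]
    if not rest:                     # bare verb: "jump"
        actions = act
    elif rest[0] == "opposite":      # turn twice, then act
        actions = [TURN[rest[1]]] * 2 + act
    elif rest[0] == "around":        # (turn, act) four times
        actions = ([TURN[rest[1]]] + act) * 4
    else:                            # plain direction: turn once, then act
        actions = [TURN[rest[0]]] + act
    return actions * count

def interpret(command):
    """Interpret a full command with at most one conjunction."""
    words = command.lower().split()
    for conj in ("and", "after"):
        if conj in words:
            i = words.index(conj)
            first = interpret_phrase(words[:i])
            second = interpret_phrase(words[i+1:])
            # "x after y" means: do y first, then x
            return first + second if conj == "and" else second + first
    return interpret_phrase(words)

print(interpret("Jump thrice and turn left"))
# ['JUMP', 'JUMP', 'JUMP', 'LTURN']
```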
