This repository contains the code to reproduce the relational reasoning: sort_of_clever
,text-based question-answering: bAbI
, detecting equilateral triangles
and cifar-10
tasks from our paper.
pip install -r requirements.txt
You can find the source code for the Sort-of-CLEVR task in sort_of_clevr_and_babi
folder.
Firstly, dataset generation:
python sort_of_clevr_generator.py
Then, you can run sort_main.py
directly.
Or execute the following commands to reproduce the experimental results of the Sort-of-CLEVR dataset in our paper.
sh pmi_sort.sh h_dim num_layers num_heads share_vanilla_parameters use_topk topk shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed set_transformer
Explanation of Parameters:
h_dim
: Embedding dimensions
num_layers
: Number of model layers
num_heads
: Number of heads in multi-headed attention
share_vanilla_parameters
: Whether share parameters across layers.
use_topk
: Whether to use top-k competition
topk
: Value of k in top-k competition
shared_memory_attention
: Whether to use shared working memory and long-term memory.
If shared_memory_attention is false, then vanilla multi-head attention is used.
mem_slots
: Number of slots in workspace
use_long_men
: Whether to use long-term memory component. It must be True in our PMI-TR.
long_mem_segs
: Number of long-term memory segments
long_mem_aggre
: Whether cross-attention is performed on information retrieved from the working memory and long-term memory. If True, it will run PMI-TR
use_wm_inference
: Whether working memory come into play during the reasoning process
seed
: Random seed
functional
: ues Set Transformer (ISAB) or not. If True, it will run ISAB.
Specifically, please execute the following commands to reproduce all experiments for the Sort-of-CLEVR task in the paper:
PMI-TR $_{s}$
sh pmi_sort.sh 256 4 4 True True 5 True 8 True 5 True True 1 False
PMI-TR $_{m}$
sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 True True 1 False
PMI-TR $_{l}$
sh pmi_sort.sh 256 12 16 True True 5 True 8 True 5 True True 1 False
PMI-TR $_{m}w/o1$ (without memory sharing among its layers)
sh pmi_sort.sh 256 8 8 False True 5 True 8 True 5 True True 1 False
PMI-TR $_{m}w/o2$ (info retrieved from LTM is directly aggregated with data from WM via $\alpha$ without correction step.
That is,there is no Equation 11 in the paper, and in Equation 12, \(U_{wl}^t\) is changed to \(U_{l}^t\).)
sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 False True 1 False
PMI-TR $_{m}w/o3$ (without WM involvement during inference)
sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 True False 1 False
PMI-TR $_{m}$ +soft
sh pmi_sort.sh 256 8 8 True False 5 True 8 True 5 True True 1 False
TR + HSW
sh pmi_sort.sh 256 4 4 True True 5 True 8 False 5 False False 1 False
TR
sh pmi_sort.sh 256 4 4 True False 5 False 8 False 5 False False 1 False
TR + HC
sh pmi_sort.sh 256 4 4 False False 5 False 8 False 5 False False 1 False
ISAB
sh pmi_sort.sh 256 4 4 False False 5 False 8 False 5 False False 1 True
You can find the source code for the bAbI task in sort_of_clevr_and_babi
folder.
Specifically, please run babi_main.py
directly.
Or execute the following commands to reproduce all experiments for the bAbI task in the paper:
sh pmi_babi.sh h_dim num_layers num_heads share_vanilla_parameters use_topk topk shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed set_transformer
PMI-TR $_{m}$
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 True True 1 False
PMI-TR $_{m}w/o1$
sh pmi_babi.sh 256 8 8 False True 5 True 8 True 5 True True 1 False
PMI-TR $_{m}w/o2$
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 False True 1 False
PMI-TR $_{m}w/o3$
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 True False 1 False
PMI-TR $_{m}$ +soft
sh pmi_babi.sh 256 8 8 True False 5 True 8 True 5 True True 1 False
TR + HSW
sh pmi_babi.sh 256 4 4 True True 5 True 8 False 5 False False 1 False
You can find the source code for the Triangle task in triangle_and_cifar10
folder.
Specifically, please run run.py
directly.
Or execute the following commands to reproduce all experiments for the Triangle task in the paper:
sh run.sh dataset model patch_size num_layers h_dim ffn_dim share_vanilla_parameters use_topk topk
shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed
PMI-TR
sh run.sh "Triangle" "default" 32 2 128 256 True True 5 True 8 True 5 True True 1
PMI-TR+S
sh run.sh "Triangle" "default" 32 2 128 256 True False 5 True 8 True 5 True True 1
TR + HSW
sh run.sh "Triangle" "default" 4 4 128 256 True True 5 True 8 False 5 False True 1
TR
sh run.sh "Triangle" "default" 4 4 128 256 True False 5 False 8 False 5 False True 1
STR
sh run.sh "Triangle" "default" 4 4 128 256 True True 5 False 8 False 5 False True 1
ISAB
sh run.sh "Triangle" "functional" 4 4 128 256 False False 5 False 8 False 5 False True 1
You can find the source code for the Cifar-10 task in triangle_and_cifar10
folder.
Specifically, please run run.py
directly.
Or execute the following commands to reproduce all experiments for the Cifar-10 task in the paper:
sh run.sh dataset model patch_size num_layers h_dim ffn_dim share_vanilla_parameters use_topk topk
shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed
1.Trans.
PMI-TR
sh run.sh "cifar10" "default" 4 4 256 256 True True 5 True 8 True 5 True True 1
TR + HSW
sh run.sh "cifar10" "default" 4 4 256 256 True True 5 True 8 False 5 False True 1
ViT
sh run.sh "cifar10" "default" 4 4 256 256 True False 5 False 8 False 5 False True 1
ISAB
sh run.sh "cifar10" "functional" 4 4 256 256 False False 5 False 8 False 5 False True 1
2.Conv.
CNN_MLP
sh run.sh "cifar10" "CNN_MLP" 4 4 256 256 True True 5 True 8 True 5 False True 1
CNN_PMI
sh run.sh "cifar10" "CNN_PMI" 4 4 256 256 True True 5 True 8 True 5 True True 1
CNN_PMI $w/o$
sh run.sh "cifar10" "CNN_PMI" 4 4 256 256 True True 5 True 8 True 5 False True 1