Skip to content

Source code of paper: A Framework for Inference Inspired by Human Memory Mechanisms

License

Notifications You must be signed in to change notification settings

zengxyyu/PMI-TR

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

28 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

A Framework for Inference Inspired by Human Memory Mechanisms

This repository contains the code to reproduce the relational reasoning: sort_of_clever,text-based question-answering: bAbI, detecting equilateral triangles and cifar-10 tasks from our paper.

Install relevant libraries

pip install -r requirements.txt 

Task1: Sort-of-CLEVR

You can find the source code for the Sort-of-CLEVR task in sort_of_clevr_and_babi folder.

Firstly, dataset generation:

python sort_of_clevr_generator.py

Then, you can run sort_main.py directly.

Or execute the following commands to reproduce the experimental results of the Sort-of-CLEVR dataset in our paper.

sh pmi_sort.sh h_dim num_layers num_heads share_vanilla_parameters use_topk topk shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed set_transformer

Explanation of Parameters:

h_dim: Embedding dimensions

num_layers: Number of model layers

num_heads: Number of heads in multi-headed attention

share_vanilla_parameters: Whether share parameters across layers.

use_topk: Whether to use top-k competition

topk: Value of k in top-k competition

shared_memory_attention: Whether to use shared working memory and long-term memory. If shared_memory_attention is false, then vanilla multi-head attention is used.

mem_slots: Number of slots in workspace

use_long_men: Whether to use long-term memory component. It must be True in our PMI-TR.

long_mem_segs: Number of long-term memory segments

long_mem_aggre: Whether cross-attention is performed on information retrieved from the working memory and long-term memory. If True, it will run PMI-TR $_{m}w/o2$.

use_wm_inference: Whether working memory come into play during the reasoning process

seed: Random seed

functional: ues Set Transformer (ISAB) or not. If True, it will run ISAB.

Specifically, please execute the following commands to reproduce all experiments for the Sort-of-CLEVR task in the paper:

PMI-TR $_{s}$
sh pmi_sort.sh 256 4 4 True True 5 True 8 True 5 True True 1 False

PMI-TR $_{m}$
sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 True True 1 False

PMI-TR $_{l}$
sh pmi_sort.sh 256 12 16 True True 5 True 8 True 5 True True 1 False

PMI-TR $_{m}w/o1$ (without memory sharing among its layers)
sh pmi_sort.sh 256 8 8 False True 5 True 8 True 5 True True 1 False

PMI-TR $_{m}w/o2$  (info retrieved from LTM is directly aggregated with data from WM via $\alpha$ without correction step.
That is,there is no Equation 11 in the paper, and in Equation 12, \(U_{wl}^t\) is changed to \(U_{l}^t\).)

sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 False True 1 False

PMI-TR $_{m}w/o3$ (without WM involvement during inference)
sh pmi_sort.sh 256 8 8 True True 5 True 8 True 5 True False 1 False

PMI-TR $_{m}$ +soft
sh pmi_sort.sh 256 8 8 True False 5 True 8 True 5 True True 1 False

TR + HSW
sh pmi_sort.sh 256 4 4 True True 5 True 8 False 5 False False 1 False

TR
sh pmi_sort.sh 256 4 4 True False 5 False 8 False 5 False False 1 False

TR + HC
sh pmi_sort.sh 256 4 4 False False 5 False 8 False 5 False False 1 False

ISAB
sh pmi_sort.sh 256 4 4 False False 5 False 8 False 5 False False 1 True

Task2: bAbI

You can find the source code for the bAbI task in sort_of_clevr_and_babi folder.

Specifically, please run babi_main.py directly.

Or execute the following commands to reproduce all experiments for the bAbI task in the paper:

sh pmi_babi.sh h_dim num_layers num_heads share_vanilla_parameters use_topk topk shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed set_transformer
PMI-TR $_{m}$
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 True True 1 False

PMI-TR $_{m}w/o1$ 
sh pmi_babi.sh 256 8 8 False True 5 True 8 True 5 True True 1 False

PMI-TR $_{m}w/o2$ 
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 False True 1 False

PMI-TR $_{m}w/o3$ 
sh pmi_babi.sh 256 8 8 True True 5 True 8 True 5 True False 1 False

PMI-TR $_{m}$ +soft
sh pmi_babi.sh 256 8 8 True False 5 True 8 True 5 True True 1 False

TR + HSW
sh pmi_babi.sh 256 4 4 True True 5 True 8 False 5 False False 1 False

Task3: Detecting Equilateral Triangles

You can find the source code for the Triangle task in triangle_and_cifar10 folder.

Specifically, please run run.py directly.

Or execute the following commands to reproduce all experiments for the Triangle task in the paper:

sh run.sh dataset model patch_size num_layers h_dim ffn_dim share_vanilla_parameters use_topk topk
shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed
PMI-TR
sh run.sh "Triangle" "default" 32 2 128 256 True True 5 True 8 True 5 True True 1

PMI-TR+S
sh run.sh "Triangle" "default" 32 2 128 256 True False 5 True 8 True 5 True True 1

TR + HSW
sh run.sh "Triangle" "default" 4 4 128 256 True True 5 True 8 False 5 False True 1

TR
sh run.sh "Triangle" "default" 4 4 128 256 True False 5 False 8 False 5 False True 1

STR
sh run.sh "Triangle" "default" 4 4 128 256 True True 5 False 8 False 5 False True 1

ISAB
sh run.sh "Triangle" "functional" 4 4 128 256 False False 5 False 8 False 5 False True 1

Task4: Image Classification

You can find the source code for the Cifar-10 task in triangle_and_cifar10 folder.

Specifically, please run run.py directly.

Or execute the following commands to reproduce all experiments for the Cifar-10 task in the paper:

sh run.sh dataset model patch_size num_layers h_dim ffn_dim share_vanilla_parameters use_topk topk
shared_memory_attention mem_slots use_long_men long_mem_segs long_mem_aggre use_wm_inference seed

1.Trans.

PMI-TR
sh run.sh "cifar10" "default" 4 4 256 256 True True 5 True 8 True 5 True True 1

TR + HSW
sh run.sh "cifar10" "default" 4 4 256 256 True True 5 True 8 False 5 False True 1

ViT
sh run.sh "cifar10" "default" 4 4 256 256 True False 5 False 8 False 5 False True 1

ISAB
sh run.sh "cifar10" "functional" 4 4 256 256 False False 5 False 8 False 5 False True 1

2.Conv.

CNN_MLP
sh run.sh "cifar10" "CNN_MLP" 4 4 256 256 True True 5 True 8 True 5 False True 1

CNN_PMI
sh run.sh "cifar10" "CNN_PMI" 4 4 256 256 True True 5 True 8 True 5 True True 1

CNN_PMI $w/o$ 
sh run.sh "cifar10" "CNN_PMI" 4 4 256 256 True True 5 True 8 True 5 False True 1

About

Source code of paper: A Framework for Inference Inspired by Human Memory Mechanisms

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published