In [None]:
from subprocess import call
from subprocess import check_output
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import json

In [None]:
# this will be used to print accuracies.

def print_accuracy(path, att_type):
    """Print the accuracy stored in ``<path>/evaluate.json``.

    Parameters
    ----------
    path : str
        Directory that contains an ``evaluate.json`` file. The file must hold
        a JSON object with an ``accuracy`` key.
    att_type : str
        Label printed in front of the accuracy value.

    Prints a single line of the form ``<att_type> accuracy: <value>`` and
    returns nothing.
    """
    results_file = path + '/evaluate.json'

    # Parse straight from the open handle; the file is closed on leaving the block.
    with open(results_file, 'r') as handle:
        metrics = json.load(handle)

    accuracy_line = att_type + " accuracy: " + str(metrics['accuracy'])
    print(accuracy_line)
    
    
def get_path_and_print_acc(datasetName):
    """Look up the result folders of a dataset and print their accuracies.

    Runs ``./get_accuracy.sh <datasetName>`` and takes the first three lines
    of its output as the result directories of the 'none', 'tanh' and 'dot'
    runs, in that order. Prints the dataset name, one accuracy line per
    attention type (via ``print_accuracy``) and a separator line.

    Parameters
    ----------
    datasetName : str
        Dataset identifier handed to ``get_accuracy.sh``. It is inserted
        unquoted into a shell command line, so pass only trusted, plain names.

    Raises
    ------
    subprocess.CalledProcessError
        If ``get_accuracy.sh`` exits with a non-zero status.
    IndexError
        If the script prints fewer than three lines.
    """
    attention_types = ('none', 'tanh', 'dot')

    raw_path = check_output('./get_accuracy.sh ' + datasetName, shell=True)

    # check_output returns bytes. Decode each line (UTF-8) instead of slicing
    # the "b'...'" repr, which turns non-ASCII bytes and backslashes into
    # escape sequences and so yields paths that do not exist.
    result_dirs = [line.decode() for line in raw_path.splitlines()]
    if len(result_dirs) < len(attention_types):
        raise IndexError(
            "get_accuracy.sh printed {} line(s) for {!r}; expected at least {}".format(
                len(result_dirs), datasetName, len(attention_types)
            )
        )

    # zip stops after the three attention types, so extra lines are ignored.
    paths = dict(zip(attention_types, result_dirs))

    print(datasetName)

    for att_type in attention_types:
        print_accuracy(paths[att_type], att_type)
    print('---------------------')

**Attention:** the output of the scripts is printed to the console in which the notebook server is running.

Run the following cell to preprocess the datasets for the experiments on the original paper's setup. Note that the SNLI dataset might take a while to process and takes up some disk space; feel free to interrupt the cell once the other datasets are done (see the console output) if you just want to focus on the other tests.

In [None]:
# Run the preprocessing script and wait for it to finish. `call` returns the
# script's exit status (shown as this cell's output); a non-zero status is not
# raised as an exception.
call("./preprocess_datasets.sh")

The following cells run the experiments on the original codebase.
Each dataset has one cell that runs the experiments and one that prints the resulting accuracies.

**IMPORTANT**: In case manual checking is needed, the results (accuracy and other metrics) can be found in the `outputs/<dataset_name>/lstm+<attention_type>/<current_time>/evaluate.json` files. **Please note** that, because of how the original code is structured, both the "no attention" and the "tanh" attention results are stored in the folder called "lstm+tanh"; the "no attention" results are simply those with the latest timestamp. Sorry for the inconvenience.

If the bash scripts are not executable, please update their permissions first (e.g. `chmod +x ./*.sh`).

Execute the next cell to run the IMDB experiments:

In [None]:
# Run ./imdb_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./imdb_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset 'imdb'.
get_path_and_print_acc('imdb')

Execute the next cell to run the SST experiments:

In [None]:
# Run ./sst_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./sst_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset 'sst'.
get_path_and_print_acc('sst')

Execute the next cell to run the AgNews experiments:

In [None]:
# Run ./agnews_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./agnews_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset 'agnews'.
get_path_and_print_acc('agnews')

Execute the next cell to run the 20News experiments:

In [None]:
# Run ./20news_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./20news_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset '20News_sports'.
get_path_and_print_acc('20News_sports')

Execute the next cell to run the bAbI 1 experiments:

In [None]:
# Run ./babi1_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./babi1_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset 'babi_1'.
get_path_and_print_acc('babi_1')

Execute the next cell to run the SNLI experiments:

In [None]:
# Run ./snli_tests.sh and wait for it to finish; the returned exit status is
# shown as this cell's output and is not checked.
call("./snli_tests.sh")

In [None]:
# Print the 'none', 'tanh' and 'dot' accuracies found for dataset 'snli'.
get_path_and_print_acc('snli')

### Seq2seq

Run the autoencoder experiments, both with and without attention

In [None]:
# Run ./NMT_tests_auto.sh and wait for it to finish; the returned exit status
# is shown as this cell's output and is not checked.
call("./NMT_tests_auto.sh")

Print both the BLEU scores:

In [None]:
# Print the contents of the two autoencoder BLEU files, each prefixed with
# its label on the same line.
for label, bleu_path in (
    ('BLEU of autoencoder with attention: ', "./NMT/eng_bleu_att.txt"),
    ('BLEU of autoencoder without attention: ', "./NMT/eng_bleu_uni.txt"),
):
    # `with` closes the file even if reading or printing raises.
    with open(bleu_path, "r") as bleu_file:
        print(label, end='')
        print(bleu_file.read())

Visualize violin plots for autoencoder:

In [None]:
# Display the three saved autoencoder permutation-experiment images, each in
# its own figure.
for figure_path in (
    "./NMT/eng2eng_permutation_exp.png",
    "./NMT/avg_aggr_eng2eng_permutation_exp.png",
    "./NMT/max_aggr_eng2eng_permutation_exp.png",
):
    img = mpimg.imread(figure_path)
    imgplot = plt.imshow(img)
    plt.show()

Run the NMT experiments, both with and without attention

In [None]:
# Run ./NMT_tests_fra.sh and wait for it to finish; the returned exit status
# is shown as this cell's output and is not checked.
call("./NMT_tests_fra.sh")

Print both the BLEU scores:

In [None]:
# Print the contents of the two fra->eng BLEU files, each prefixed with its
# label on the same line.
for label, bleu_path in (
    ('BLEU of fra->eng model with attention: ', "./NMT/fra_bleu_att.txt"),
    ('BLEU of fra->eng model without attention: ', "./NMT/fra_bleu_uni.txt"),
):
    # `with` closes the file even if reading or printing raises.
    with open(bleu_path, "r") as bleu_file:
        print(label, end='')
        print(bleu_file.read())

Visualize violin plots for NMT:

In [None]:
# Display the three saved fra->eng permutation-experiment images, each in its
# own figure.
for figure_path in (
    "./NMT/fra2eng_permutation_exp.png",
    "./NMT/avg_aggr_fra2eng_permutation_exp.png",
    "./NMT/max_aggr_fra2eng_permutation_exp.png",
):
    img = mpimg.imread(figure_path)
    imgplot = plt.imshow(img)
    plt.show()