Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
vinhkhuc committed Jan 10, 2016
1 parent 4e8e716 commit 8252345
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions babi_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
random.seed(seed_val)
np.random.seed(seed_val) # for reproducing


def run_task(data_dir, task_id):
"""
Train and test for each task
Expand Down
2 changes: 2 additions & 0 deletions demo/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def train_model(data_dir, model_file):
memn2n = MemN2N(data_dir, model_file)
memn2n.train()


def run_console_demo(data_dir, model_file):
"""
Console-based demo
Expand Down Expand Up @@ -231,6 +232,7 @@ def run_console_demo(data_dir, model_file):
if will_continue != '' and will_continue.lower() != 'y': break
print("=" * 70)


def run_web_demo(data_dir, model_file):
from demo.web import webapp
webapp.init(data_dir, model_file)
Expand Down
5 changes: 5 additions & 0 deletions demo/web/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
memn2n = None
test_story, test_questions, test_qstory = None, None, None


def init(data_dir, model_file):
""" Initialize web app """
global memn2n, test_story, test_questions, test_qstory
Expand All @@ -26,13 +27,16 @@ def init(data_dir, model_file):
test_story, test_questions, test_qstory = \
parse_babi_task(test_data_path, memn2n.general_config.dictionary, False)


def run():
app.run()


@app.route('/')
def index():
return flask.render_template("index.html")


@app.route('/get/story', methods=['GET'])
def get_story():
question_idx = np.random.randint(test_questions.shape[1])
Expand All @@ -52,6 +56,7 @@ def get_story():
"correct_answer": correct_answer
})


@app.route('/get/answer', methods=['GET'])
def get_answer():
question_idx = flask.request.args.get('question_idx')
Expand Down
2 changes: 2 additions & 0 deletions memn2n/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from memn2n.nn import ElemMult, Identity, Sequential, LookupTable, Module
from memn2n.nn import Sum, Parallel, Softmax, MatVecProd


class Memory(Module):
"""
Memory:
Expand Down Expand Up @@ -119,6 +120,7 @@ def init_output_module(self):
self.mod_out.add(p)
self.mod_out.add(MatVecProd(False))


class MemoryL(Memory):
"""
MemoryL:
Expand Down
1 change: 1 addition & 0 deletions memn2n/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# when Softmax is not included at the end layer.
np.seterr(divide='ignore')


class Module(object):
"""
Abstract Module class for neural net
Expand Down
1 change: 1 addition & 0 deletions train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def train_linear_start(train_story, train_questions, train_qstory, memory, model
# Train with old settings
train(train_story, train_questions, train_qstory, memory, model, loss, general_config)


def test(test_story, test_questions, test_qstory, memory, model, loss, general_config):
total_test_err = 0.
total_test_num = 0
Expand Down

0 comments on commit 8252345

Please sign in to comment.