In [108]:
import random
import numpy as np
from tqdm import tnrange, tqdm_notebook
import tensorflow as tf
import math

def score(context, response):
    return random.random()


def best_response(context, candidates):
    index = np.argmax([score(context, response) for response in candidates])
    return candidates[index]


def parse_dialogs(filename):
    dialogs = []
    with open(filename, 'r') as f:
        dialog = []
        for line in f:
            if line.strip() == '':
                dialogs.append(dialog)
                dialog = []
            else:
                user_utt, bot_utt = line.strip().split('\t')
                utt_num = user_utt.split(' ')[0]
                user_utt = ' '.join(user_utt.split(' ')[1:])
                dialog.append((utt_num, user_utt, bot_utt))
    return dialogs


def parse_candidates(filename):
    with open(filename, 'r') as f:
        return [' '.join(line.strip().split(' ')[1:]) for line in f]            

    
def responses_accuracy(dialogs, candidates):
    correct = 0
    count = 0
    for dialog in dialogs:
        for _, user_utt, bot_utt in dialog:
            count += 1
            context = user_utt
            response = best_response(context, candidates)
            if response == bot_utt:
                correct += 1
    return correct / count, correct, count


def build_vocab_to_ind_map(dialogs):
    vocab = set()
    for d in dialogs:
        for _, user_utt, bot_utt in d:
            vocab = vocab.union(user_utt.split(' ') + bot_utt.split(' '))
    vocab = sorted(list(vocab))
    
    cntr = 0
    vocab_ind_map = {}
    for w in vocab:
        vocab_ind_map[w] = cntr
        cntr += 1
    return vocab_ind_map


def build_vec(vocab_ind_map, utt):
    vocab_len = len(vocab_ind_map.keys())
    vec = np.zeros((vocab_len, 1))
    for w in utt.split(' '):
        try:
            vec[vocab_ind_map[w]] += 1
        except KeyError:
            pass
    return vec


def get_vec_set(dialogs, vocab_ind_map):
    vec_set = []
    for d in dialogs:
        for _, user_utt, bot_utt in d:
            x = build_vec(vocab_ind_map, user_utt)
            y = build_vec(vocab_ind_map, bot_utt)
            vec_set.append([x, y])
    return vec_set

In [2]:
train_set_task1_dialogs = parse_dialogs('dataset/dialog-bAbI-tasks/dialog-babi-task1-API-calls-trn.txt')
dev_set_task1_dialogs = parse_dialogs('dataset/dialog-bAbI-tasks/dialog-babi-task1-API-calls-dev.txt')
candidates = parse_candidates('dataset/dialog-bAbI-tasks/dialog-babi-candidates.txt')
vocab_to_ind_map = build_vocab_to_ind_map(train_set_task1_dialogs)
vec_set = get_vec_set(train_set_task1_dialogs, vocab_to_ind_map)

In [87]:
D = 32
V = len(vocab_to_ind_map)

context = tf.placeholder(dtype=tf.float32, name='Context', shape=[V, 1])
response = tf.placeholder(dtype=tf.float32, name='Response', shape=[V, 1])
f_neg = tf.placeholder(dtype=tf.float32, name='f_neg', shape=())
A_var = tf.Variable(initial_value=tf.truncated_normal(shape=[D, V], stddev= 1 / math.sqrt(D)))
B_var = tf.Variable(initial_value=tf.truncated_normal(shape=[D, V], stddev= 1 / math.sqrt(D)))

resp_mult = tf.matmul(B_var, response)
cont_mult = tf.matmul(A_var, context)

f = tf.nn.tanh(tf.matmul(tf.transpose(cont_mult), resp_mult))

m = 0.01
loss = tf.nn.relu(f_neg - f + m)

LR = 0.001
optimizer = tf.train.GradientDescentOptimizer(LR).minimize(loss)

In [73]:
print(len(vec_set))

6024


In [94]:
train_vec_set = vec_set
sess = tf.Session()
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('log/my_graph', sess.graph)
avg_loss = 0.0
for _ in tnrange(10):
    for x, y in tqdm_notebook(train_vec_set):            
        y_negs = random.sample(train_vec_set, 10)
        for _, y_neg in y_negs:
            f_neg_val = sess.run([f], feed_dict={context: x, response: y_neg})[0][0][0]
            loss_val = sess.run([loss, optimizer], feed_dict={context: x, response: y, f_neg: f_neg_val})
            avg_loss += loss_val[0][0]
    print(loss_val[0][0], avg_loss)
    avg_loss = 0
    val_pos = sess.run([f], feed_dict={context: train_vec_set[-1][0], response: train_vec_set[-1][1]})
    val_neg = sess.run([f], feed_dict={context: train_vec_set[-1][0], response: train_vec_set[100][1]})
    print(val_pos, val_neg)




[ 0.] [ 681.26000977]
[array([[ 0.99286354]], dtype=float32)] [array([[ 0.30868417]], dtype=float32)]


[ 0.01037545] [ 261.6229248]
[array([[ 0.99764019]], dtype=float32)] [array([[ 0.38179165]], dtype=float32)]


[ 0.01075728] [ 265.97769165]
[array([[ 0.99864209]], dtype=float32)] [array([[ 0.42058]], dtype=float32)]


[ 0.01022298] [ 273.97573853]
[array([[ 0.99906611]], dtype=float32)] [array([[ 0.44525474]], dtype=float32)]


[ 0.] [ 277.25915527]
[array([[ 0.9992978]], dtype=float32)] [array([[ 0.46429858]], dtype=float32)]


[ 0.] [ 279.20220947]
[array([[ 0.99944097]], dtype=float32)] [array([[ 0.47857484]], dtype=float32)]


[ 0.0096897] [ 279.83337402]
[array([[ 0.99953592]], dtype=float32)] [array([[ 0.49092939]], dtype=float32)]


[ 0.01022233] [ 282.09359741]
[array([[ 0.9996047]], dtype=float32)] [array([[ 0.50150567]], dtype=float32)]


[ 0.] [ 283.96875]
[array([[ 0.99965662]], dtype=float32)] [array([[ 0.51009864]], dtype=float32)]


[ 0.] [ 285.58453369]
[array([[ 0.99969721]], dtype=float32)] [array([[ 0.51810676]], dtype=float32)]



In [97]:
val = sess.run([f], feed_dict={context: train_vec_set[1][0], response: train_vec_set[3][1]})
print(val)

[array([[ 0.6867553]], dtype=float32)]


In [125]:
cache_vec = {}
def best_response_emb(context_, candidates, session):
    index = np.argmax([score_emb(context_, response_, session) for response_ in candidates])
    return candidates[index]

def score_emb(context_, response_, session):
    if cache_vec.get(context_) is not None:
        context_vec = cache_vec[context_]
    else:
        context_vec = build_vec(vocab_to_ind_map, context_)
        cache_vec[context_] = context_vec
    
    if cache_vec.get(response_) is not None:
        response_vec = cache_vec[response_]
    else:
        response_vec = build_vec(vocab_to_ind_map, response_)
        cache_vec[response_] = response_vec
        
    val = session.run([f], feed_dict={context: context_vec, response: response_vec})[0][0][0]
    return val

def responses_accuracy_emb(dialogs, candidates, session):
    correct = 0
    count = 0
    for dialog in tqdm_notebook(dialogs):
        for _, user_utt, bot_utt in dialog:
            count += 1
            context = user_utt
            response = best_response_emb(context, candidates, session)
            if response == bot_utt:
                correct += 1
    return correct / count, correct, count

In [126]:
responses_accuracy_emb(dev_set_task1_dialogs, candidates, sess)




(0.4370739817123857, 2629, 6015)

Не забывай подход с утилитами и кэшированием! (cat | grep | python)