In [15]:
%matplotlib inline
from random import choice, randrange
import os
import numpy as np
import sys
sys.argv.append('--dynet_mem')
sys.argv.append('6000')

embed_size = 64
filter_width = 5
encoder_depth = 3
decoder_depth = 3


#os.environ["CUDA_VISIBLE_DEVICES"]="1"
PAD = "<PAD>" #all strings will end with the End Of String token
EOS = "<EOS>"
characters = list("abcd") + [PAD, EOS]


int2char = list(characters)
char2int = {c:i for i,c in enumerate(characters)}

VOCAB_SIZE = len(characters)

def sample_model(min_length, max_lenth):
    random_length = randrange(min_length, max_lenth)                             # Pick a random length
    random_char_list = [choice(characters[:-2]) for _ in range(random_length)]  # Pick random chars
    random_string = ''.join(random_char_list) 
    return random_string, random_string[::-1]  # Return the random string and its reverse

MAX_STRING_LEN = 10

train_set = [sample_model(1, MAX_STRING_LEN) for _ in range(3000)]
val_set = [sample_model(1, MAX_STRING_LEN) for _ in range(50)]

In [16]:
import matplotlib.pyplot as plt
import dynet as dy
from tqdm import tqdm_notebook as tqdm

def train(network, train_set, val_set, epochs = 20):
    def get_val_set_loss(network, val_set):
        loss = [network.get_loss(input_string, output_string).value() for input_string, output_string in val_set]
        return sum(loss)
    
    train_set = train_set*epochs
    trainer = dy.SimpleSGDTrainer(network.model)
    losses = []
    iterations = []
    for i, training_example in enumerate(tqdm(train_set)):
        input_string, output_string = training_example
        dy.renew_cg()
        loss = network.get_loss(input_string, output_string)
        loss_value = loss.value()
        loss.backward()
        trainer.update()

        # Accumulate average losses over training to plot
        if i%(len(train_set)/100) == 0:
            val_loss = get_val_set_loss(network, val_set)
            losses.append(val_loss)
            iterations.append(i/((len(train_set)/100)))
            print(val_loss)

    plt.plot(iterations, losses)
    plt.axis([0, 100, 0, len(val_set)*MAX_STRING_LEN])
    plt.show() 
    print('loss on validation set:', val_loss)

In [17]:
def pp(expr):
    print(expr.npvalue().shape)

In [18]:
class ConvAtt:
    def params(self, size):
        return self.model.add_parameters(size)

    def get_conv_filters(self, filter_size, embeddings_size):
        f_a = self.params((1, filter_size, embeddings_size, embeddings_size))
        b_a = self.params((embeddings_size))
        f_b = self.params((1, filter_size, embeddings_size, embeddings_size))
        b_b = self.params((embeddings_size))
        return f_a, b_a, f_b, b_b

    def __init__(self, embeddings_size, filter_size, enc_layers, dec_layers):
        self.embeddings_size = embeddings_size

        self.model = dy.Model()

        self.word_embeddings = self.model.add_lookup_parameters((VOCAB_SIZE, embeddings_size))
        self.position_embeddings = self.model.add_lookup_parameters((MAX_STRING_LEN + 2, embeddings_size))

        self.enc_filters = []
        for _ in range(enc_layers):
            f_a, b_a, f_b, b_b = self.get_conv_filters(filter_size, embeddings_size)
            self.enc_filters.append((f_a, b_a, f_b, b_b))

        self.dec_filters = []
        self.att_ws = []
        self.att_bs = []
        for _ in range(dec_layers):
            f_a, b_a, f_b, b_b = self.get_conv_filters(filter_size, embeddings_size)
            self.dec_filters.append((f_a, b_a, f_b, b_b))
            self.att_ws.append(self.params((embeddings_size, embeddings_size)))
            self.att_bs.append(self.params((embeddings_size)))

        self.output_w = self.params((VOCAB_SIZE, embeddings_size))
        self.output_b = self.params((VOCAB_SIZE))

    def GLU(self, A, B):
        return dy.cmult(A, dy.tanh(B))

    def conv_block(self, f, block_input):
        f_a, b_a, f_b, b_b = f

        f_a = dy.parameter(f_a)
        b_a = dy.parameter(b_a)
        f_b = dy.parameter(f_b)
        b_b = dy.parameter(b_b)

        conv_output_a = dy.conv2d_bias(block_input, f_a, b_a, [1, 1], is_valid=False)
        conv_output_b = dy.conv2d_bias(block_input, f_b, b_b, [1, 1], is_valid=False)
        conv_output = self.GLU(conv_output_a, conv_output_b)
        return conv_output

    def embedd(self, string):
        seq_len = len(string)
        positions = [i for i in range(seq_len)]

        embedded_string = [self.word_embeddings[char] for char in string]
        embedded_positions = [self.position_embeddings[pos] for pos in positions]

        embedded = [char + pos for char, pos in zip(embedded_string, embedded_positions)]
        return embedded, seq_len

    def encode(self, input_string):
        embedded, seq_len = self.embedd(input_string)
        conv_input = dy.reshape(dy.concatenate_cols(embedded), (1, seq_len, self.embeddings_size))

        for enc_filter in self.enc_filters:
            conv_input = self.conv_block(enc_filter, conv_input) + conv_input
        return conv_input[0]

    def step_attention(self, conv_block_out, encoded, w, b, last_w, seq_len):
        w = dy.parameter(w)
        b = dy.parameter(b)
        conv_block_out = conv_block_out[0]
        # TODO replace with conv
        ds = [w * h + b + last_w for h in conv_block_out]
        aij = [dy.softmax(dy.concatenate([dy.transpose(d) * z for z in encoded])) for d in ds]

        cs = [dy.esum([z * a for a, z in zip(ai, encoded)]) for ai in aij]
        cs = dy.reshape(dy.transpose(dy.concatenate_cols(cs)), (1, seq_len, self.embeddings_size))
        return cs

    def decode(self, current_out, encoded):
        embedded, seq_len = self.embedd(current_out)
        last_w = embedded[-1]
        conv_input = dy.reshape(dy.concatenate_cols(embedded), (1, seq_len, self.embeddings_size))

        for dec_filter, att_w, att_b in zip(self.dec_filters, self.att_ws, self.att_bs):
            conv_block_out = self.conv_block(dec_filter, conv_input)
            conv_input = self.step_attention(conv_block_out, encoded, att_w, att_b, last_w, seq_len) + conv_input
        return conv_input

    def str2ints(self, string):
        return [char2int[c] for c in [PAD] + list(string) + [EOS]]

    def get_loss(self, input_string, output_string):
        input_string = self.str2ints(input_string)
        output_string = self.str2ints(output_string)

        w = dy.parameter(self.output_w)
        b = dy.parameter(self.output_b)

        encoded = self.encode(input_string)

        loss = []
        for j in range(1, len(output_string)):
            decoded = self.decode(output_string[:j], encoded)
            probs = dy.softmax(w * decoded[0][-1] + b)
            loss.append(-dy.log(dy.pick(probs, output_string[j])))
        decoded = self.decode(output_string, encoded)
        probs = dy.softmax(w * decoded[0][-1] + b)
        loss.append(-dy.log(dy.pick(probs, char2int[PAD])))
        return dy.esum(loss)

    def generate(self, input_string):
        input_string = self.str2ints(input_string)

        w = dy.parameter(self.output_w)
        b = dy.parameter(self.output_b)

        encoded = self.encode(input_string)
        output_string = [char2int[PAD]]
        for _ in range(MAX_STRING_LEN+2):
            decoded = self.decode(output_string, encoded)
            probs = dy.softmax(w * decoded[0][-1] + b)
            next_char = np.argmax(probs.npvalue())
            output_string.append(next_char)
            if int2char[next_char] == EOS:
                break
            
        return [int2char[char] for char in output_string]

In [19]:
conv = ConvAtt(embed_size, filter_width, encoder_depth, decoder_depth)
print(conv.generate('ab'))

['<PAD>', '<EOS>']


In [None]:
train(conv, train_set, val_set)

Widget Javascript not detected.  It may not be installed or enabled properly.



1077.915850162506
707.7158609628677
711.1349729351932
631.3835091533019


In [None]:
from slackclient import SlackClient

def get_val_set_loss(network, val_set):
    loss = [network.get_loss(input_string, output_string).value() for input_string, output_string in val_set]
    return sum(loss)
sc = SlackClient(token)
sc.api_call(str(get_val_set_loss(conv, val_set)))
sc.api_call(conv.generate('abcdab'))


In [14]:
print(conv.generate('bcaa'))

['<PAD>', 'a', 'a', 'c', 'a', 'c', '<PAD>', 'a', 'a', '<PAD>', 'c', 'a', 'a']


In [46]:
print(3)

3
