# Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import random
import argparse

from load_data import *
from utils import *
from training import *
from models import *

# Models

In [2]:
def print_grad(grad):
    """Print ``grad``; a one-argument debugging callback (e.g. for ``tensor.register_hook``)."""
    print(grad)
    
# Redefined LSTM
class GradLSTM(ModifiableModule):
    """One-step LSTM cell whose parameters are stored as plain leaf Variables.

    For each gate ``x`` in (i, f, g, o) the cell keeps ``w<x>_weights`` and
    ``w<x>_bias``, initialised from a throwaway ``nn.Linear`` mapping
    ``hidden_size + input_size`` features to ``hidden_size`` features.
    """

    # Gate order fixes both the order in which the throwaway nn.Linear layers
    # draw their random initial values and the order of ``named_leaves``.
    _GATES = ("i", "f", "g", "o")

    def __init__(self, input_size, hidden_size):
        super(GradLSTM, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size

        joint_size = hidden_size + input_size
        # All four donors are created first, then wrapped, as a donor layer is
        # only needed for its default initialisation.
        donors = [nn.Linear(joint_size, hidden_size) for _ in self._GATES]
        for gate, donor in zip(self._GATES, donors):
            setattr(self, "w%s_weights" % gate,
                    V(donor.weight.data, requires_grad=True))
            setattr(self, "w%s_bias" % gate,
                    V(donor.bias.data, requires_grad=True))

    def forward(self, inp, hidden):
        """Advance the cell by one step.

        ``inp`` and the hidden state are concatenated along dimension 2, so
        both must be 3-D. Returns a 7-tuple: the new hidden state, the
        ``(hidden, cell)`` pair, the output-gate pre-activation, the
        concatenated input, and the input/forget/candidate pre-activations.
        """
        h_prev, c_prev = hidden
        joint = torch.cat((inp, h_prev), 2)

        # Pre-activations are kept because the caller records them.
        i_tpre = F.linear(joint, self.wi_weights, self.wi_bias)
        f_tpre = F.linear(joint, self.wf_weights, self.wf_bias)
        g_tpre = F.linear(joint, self.wg_weights, self.wg_bias)
        o_tpre = F.linear(joint, self.wo_weights, self.wo_bias)

        input_gate = torch.sigmoid(i_tpre)
        forget_gate = torch.sigmoid(f_tpre)
        candidate = torch.tanh(g_tpre)
        output_gate = torch.sigmoid(o_tpre)

        c_next = forget_gate * c_prev + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)

        return h_next, (h_next, c_next), o_tpre, joint, i_tpre, f_tpre, g_tpre

    def named_leaves(self):
        """Return ``(name, parameter)`` pairs for all eight gate parameters."""
        leaves = []
        for gate in self._GATES:
            for kind in ("weights", "bias"):
                attr = "w%s_%s" % (gate, kind)
                leaves.append((attr, getattr(self, attr)))
        return leaves

# Encoder/decoder model
class EncoderDecoder(ModifiableModule):
    """LSTM encoder/decoder that also records a hand-made computation graph.

    Besides the usual outputs, ``forward`` returns a dict mapping each
    intermediate value's name to ``[op_type, [(input_name, value), ...]]``.
    ``set_dicts`` must be called before ``forward`` so that ``char2ind`` and
    ``ind2char`` exist.
    """

    def __init__(self, vocab_size, input_size, hidden_size):
        super(EncoderDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.embedding = GradEmbedding(vocab_size, input_size)
        self.enc_lstm = GradLSTM(input_size, hidden_size)

        self.dec_lstm = GradLSTM(input_size, hidden_size)
        self.dec_output = GradLinear(hidden_size, vocab_size)

        # Upper bound on the number of decoding steps.
        self.max_length = 20

    def _record_lstm_step(self, graph, prefix, step, lstm, cprev_name,
                          c_prev, c_new, iph, o_tpre, i_tpre, f_tpre, g_tpre):
        """Record the graph nodes of one LSTM step: h, c and the o/f/i/g gates."""
        step = str(step)
        graph[prefix + "_h" + step] = ["tanhsigmoideltwisemul", [
            (prefix + "_c" + step, c_new.detach()),
            (prefix + "_o" + step, o_tpre.detach())]]
        graph[prefix + "_c" + step] = ["newc", [
            (cprev_name, c_prev.detach()),
            (prefix + "_f" + step, f_tpre.detach()),
            (prefix + "_i" + step, i_tpre.detach()),
            (prefix + "_g" + step, g_tpre.detach())]]
        # Every gate pre-activation is a linear map of the same concatenated input.
        for gate in "ofig":
            graph[prefix + "_" + gate + step] = ["weightbias", [
                (prefix + "_inputhidden" + step, iph.detach()),
                (prefix + "_w" + gate, getattr(lstm, "w" + gate + "_weights")),
                (prefix + "_b" + gate, getattr(lstm, "w" + gate + "_bias"))]]

    def forward(self, sequence_list, correct):
        """Encode ``sequence_list``, decode greedily, and record the graph.

        Args:
            sequence_list: list of input sequences; every element of every
                sequence must be a key of ``self.char2ind``.
            correct: indexable sequence of target symbols, one per decoding
                step. Decoding stops after the step whose target is ``"EOS"``
                or after ``self.max_length`` steps. It must have an entry for
                every step that runs.

        Returns:
            The 12-tuple ``(out_strings, logits, encoding, preds, hiddens,
            ots, iphs, hidden_prev, its, fts, gts, computation_graph)``.

        NOTE(review): ``correct`` is a single target sequence and the decoder
        graph records only one input index per step, so the recorded graph
        looks designed for batch size 1 -- confirm before using larger batches.
        """
        computation_graph = {}
        batch_size = len(sequence_list)

        # Initialize the hidden and cell states
        hidden = (V(torch.zeros(1, batch_size, self.hidden_size)),
                  V(torch.zeros(1, batch_size, self.hidden_size)))

        computation_graph["enc_h-1"] = ["init", [("ZERO", hidden[0].detach())]]
        computation_graph["enc_c-1"] = ["init", [("ZERO", hidden[1].detach())]]

        # The input is a list of sequences. Here the sequences are converted
        # into integer keys
        all_seqs = [torch.LongTensor([self.char2ind[elt] for elt in sequence])
                    for sequence in sequence_list]
        max_length = max([len(x) for x in sequence_list])

        # Index of the last encoder step that ran. Starting at -1 makes the
        # decoder link to the "enc_h-1"/"enc_c-1" init nodes when the input is
        # empty, instead of to "enc_h0"/"enc_c0", which would not exist.
        index = -1
        if max_length > 0:
            # Pad the sequences to allow batching
            all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)

            # 1.0 where a real symbol is present, 0.0 on padding (index 0).
            all_seqs_onehot = (all_seqs > 0).type(torch.FloatTensor)

            # Pass the sequences through the encoder, one character at a time
            for index, elt in enumerate(all_seqs):
                step = str(index)
                cprev_name = "enc_c" + str(index - 1)
                hprev_name = "enc_h" + str(index - 1)

                # Embed the character
                emb = self.embedding(elt.unsqueeze(0))

                computation_graph["enc_input" + step] = ["emb", [("onehot", elt), ("emb_mat", self.embedding.weights)]]
                computation_graph["enc_inputhidden" + step] = ["concat", [("enc_input" + step, emb), (hprev_name, hidden)]]

                # Keep the state that goes INTO this step: it is both the
                # "previous cell" of the graph and the value padded sequences
                # must retain below.
                hidden_prev = hidden

                # Pass through the LSTM
                output, hidden_new, o_t, iph, i_tpre, f_tpre, g_tpre = self.enc_lstm(emb, hidden)

                self._record_lstm_step(
                    computation_graph, "enc", index, self.enc_lstm, cprev_name,
                    hidden_prev[1], hidden_new[1], iph, o_t, i_tpre, f_tpre, g_tpre)

                # Variable-length inputs: for each sequence in the batch, use
                # the new state while the sequence is still running, and keep
                # the old state once we are in its padding.
                mask = all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_new[0].shape)
                hx = hidden_prev[0] * (1 - mask) + hidden_new[0] * mask
                cx = hidden_prev[1] * (1 - mask) + hidden_new[1] * mask

                hidden = (hx, cx)

        encoding = hidden
        # Decoding

        # Previous output characters (used as input for the following time step)
        prev_output = ["SOS" for _ in range(batch_size)]

        # Accumulates the output sequences
        out_strings = ["" for _ in range(batch_size)]

        # Per-step decoder values (logits are used for computing the loss)
        logits = []
        preds = []
        hiddens = []
        ots = []
        iphs = []
        hidden_prev = hidden
        its = []
        fts = []
        gts = []

        # The first decoder step is fed by the last encoder step.
        cprev_name = "enc_c" + str(index)
        hprev_name = "enc_h" + str(index)

        for i in range(self.max_length):
            # Stop once the previous target was EOS. The i > 0 guard keeps
            # correct[-1] (the LAST target) from being consulted on step 0.
            if i > 0 and correct[i-1] == "EOS":
                break

            step = str(i)

            # Determine the previous output character for each element
            # of the batch; to be used as the input for this time step
            prev_outputs = [self.char2ind[elt] for elt in prev_output]

            # Embed the previous outputs
            emb = self.embedding(torch.LongTensor([prev_outputs]))

            # Only the last batch element's index is recorded in the graph.
            computation_graph["dec_input" + step] = ["emb", [("onehot", prev_outputs[-1]), ("emb_mat", self.embedding.weights)]]
            computation_graph["dec_inputhidden" + step] = ["concat", [("dec_input" + step, emb), (hprev_name, hidden)]]

            hidden_prev = hidden

            # Pass through the decoder
            output, hidden, o_t, iph, i_tpre, f_tpre, g_tpre = self.dec_lstm(emb, hidden)

            # Determine the output probabilities used to make predictions
            pred = self.dec_output(output)
            probs = F.log_softmax(pred, dim=2)
            logits.append(probs)

            computation_graph["logit" + step] = ["logsoftmax", [("pred" + step, pred.detach()), self.char2ind[correct[i]]]]
            computation_graph["pred" + step] = ["weightbias", [("dec_h" + step, output.detach()), ("output_weights", self.dec_output.weights), ("output_bias", self.dec_output.bias)]]
            self._record_lstm_step(
                computation_graph, "dec", i, self.dec_lstm, cprev_name,
                hidden_prev[1], hidden[1], iph, o_t, i_tpre, f_tpre, g_tpre)

            preds.append(pred)
            hiddens.append(hidden)
            ots.append(o_t)
            iphs.append(iph)
            its.append(i_tpre)
            fts.append(f_tpre)
            gts.append(g_tpre)

            # Discretize the output labels (via argmax) for generating an output character
            topv, topi = probs.data.topk(1)
            label = topi[0]

            prev_output = []
            for batch_pos, elt in enumerate(label):
                char = self.ind2char[elt.item()]

                out_strings[batch_pos] += char
                prev_output.append(char)

            cprev_name = "dec_c" + step
            hprev_name = "dec_h" + step

        return out_strings, logits, encoding, preds, hiddens, ots, iphs, hidden_prev, its, fts, gts, computation_graph

    def named_submodules(self):
        """Return ``(name, submodule)`` pairs for the four sub-modules."""
        return [('embedding', self.embedding), ('enc_lstm', self.enc_lstm),
                ('dec_lstm', self.dec_lstm), ('dec_output', self.dec_output)]

    # Create a copy of the model
    def create_copy(self, same_var=False):
        """Build a same-sized model and copy this one into it via ``copy``."""
        new_model = EncoderDecoder(self.vocab_size, self.input_size, self.hidden_size)
        new_model.copy(self, same_var=same_var)

        return new_model

    def set_dicts(self, vocab_list):
        """Build ``char2ind`` / ``ind2char`` from ``vocab_list``.

        ``"NULL"``, ``"SOS"`` and ``"EOS"`` are prepended, so they take
        indices 0, 1 and 2.
        """
        vocab_list = ["NULL", "SOS", "EOS"] + vocab_list

        self.char2ind = {elt: index for index, elt in enumerate(vocab_list)}
        self.ind2char = dict(enumerate(vocab_list))



# Gradient definitions

In [3]:
def init_grad(args, result):
    """Gradient rule for ``"init"`` nodes: nothing is propagated further.

    Both arguments are ignored; an empty list of ``(name, gradient)`` pairs
    is returned.
    """
    return list()

In [4]:
def logsoftmax_grad(args):
    """Seed gradient for a ``"logsoftmax"`` node.

    Args:
        args: ``[(name, pred), correct_ind]`` where ``pred`` holds the raw
            scores for one prediction (any shape; it is flattened) and
            ``correct_ind`` is the index of the target class.

    Returns:
        ``[(name, grad)]`` with ``grad = softmax(pred) - onehot(correct_ind)``
        of shape ``(1, number_of_scores)``. This is the gradient of the
        negative log-probability of the target class with respect to ``pred``.
    """
    name, pred = args[0]
    correct_ind = args[1]

    # The class count comes from pred itself rather than a fixed vocabulary
    # size. torch.softmax is used so large scores do not overflow to NaN.
    sm = torch.softmax(pred.reshape(1, -1), dim=1)

    onehot = torch.zeros_like(sm)
    onehot[0][correct_ind] = 1.0

    return [(name, sm - onehot)]

In [5]:
def weightbias_grad(args, result):
    """Gradient rule for a ``"weightbias"`` (affine) node.

    Args:
        args: ``[(input_name, input), (weight_name, weight), (bias_name, bias)]``.
            The input is flattened to a single row.
        result: gradient arriving at the node's output, a 2-D row tensor.

    Returns:
        ``(name, gradient)`` pairs for the weight, the bias and the input,
        in that order.
    """
    name_inp, raw_inp = args[0]
    name_weight, weight = args[1]
    name_bias, _bias = args[2]

    row_inp = raw_inp.view(-1).unsqueeze(0)

    # The bias receives the incoming gradient unchanged.
    grad_bias = result
    outer = torch.mm(row_inp.t(), grad_bias)
    grad_weight = outer.t()
    grad_inp = torch.mm(grad_bias, weight).view(-1).unsqueeze(0)

    return [(name_weight, grad_weight), (name_bias, grad_bias), (name_inp, grad_inp)]


In [6]:
def emb_grad(args, result):
    """Gradient rule for an ``"emb"`` (embedding lookup) node.

    Args:
        args: ``[(index_name, ind), (weight_name, weight)]`` where ``ind``
            selects the looked-up row(s) and ``weight`` is the embedding
            matrix, whose first dimension is taken as the vocabulary size.
        result: gradient arriving at the embedding output, a 2-D row tensor.

    Returns:
        ``[(weight_name, grad_weight)]``. ``grad_weight`` has one row per
        vocabulary entry and is zero outside the selected row(s).
    """
    name_ind, ind = args[0]
    name_weight, weight = args[1]

    # The vocabulary size comes from the embedding matrix rather than a
    # fixed constant, so other model sizes work too.
    onehot = result.new_zeros(1, weight.shape[0])
    onehot[0][ind] = 1.0

    grad_weight = torch.mm(torch.transpose(onehot, 0, 1), result)

    return [(name_weight, grad_weight)]

In [7]:
def tanhsigmoideltwisemul_grad(args, result):
    """Gradient rule for ``h = sigmoid(o) * tanh(c)``.

    Args:
        args: ``[(cell_name, c), (gate_name, o)]`` where ``o`` is the
            pre-activation of the output gate.
        result: gradient arriving at ``h``.

    Returns:
        ``(name, gradient)`` pairs for ``c`` then ``o``, each flattened to a
        single row.
    """
    name_c, ct = args[0]
    name_o, ot = args[1]

    gate = torch.sigmoid(ot)
    squashed = torch.tanh(ct)

    # d h / d o = sigmoid'(o) * tanh(c);  d h / d c = sigmoid(o) * tanh'(c)
    grad_ot = gate * (1 - gate) * squashed * result
    grad_ct = gate * result * (1 - torch.pow(squashed, 2))

    return [(name_c, grad_ct.view(-1).unsqueeze(0)),
            (name_o, grad_ot.view(-1).unsqueeze(0))]
    

In [8]:
def newc_grad(args, result):
    """Gradient rule for ``c = sigmoid(f) * c_prev + sigmoid(i) * tanh(g)``.

    Args:
        args: ``[(name, c_prev), (name, f), (name, i), (name, g)]`` where
            ``f``, ``i`` and ``g`` are gate pre-activations.
        result: gradient arriving at the new cell state.

    Returns:
        ``(name, gradient)`` pairs for ``c_prev``, ``f``, ``i`` and ``g`` in
        that order, each flattened to a single row.
    """
    (name_cprev, cprev), (name_f, ft), (name_i, it), (name_g, gt) = args[:4]

    forget_gate = torch.sigmoid(ft)
    input_gate = torch.sigmoid(it)
    candidate = torch.tanh(gt)

    contributions = [
        # d c / d c_prev is the forget-gate activation.
        (name_cprev, result * forget_gate),
        (name_f, cprev * result * forget_gate * (1 - forget_gate)),
        (name_i, input_gate * (1 - input_gate) * candidate * result),
        (name_g, (1 - torch.pow(candidate, 2)) * input_gate * result),
    ]

    return [(node, grad.view(-1).unsqueeze(0)) for node, grad in contributions]
    

In [9]:
def concat_grad(args, result):
    """Gradient rule for a ``"concat"`` node (input followed by hidden state).

    Args:
        args: ``[(input_name, inp), (hidden_name, hprev)]``. Only the size of
            ``inp``'s last dimension is used; ``hprev`` is not read.
        result: gradient arriving at the concatenated row, a 2-D tensor whose
            second dimension is ``input_size + hidden_size``.

    Returns:
        ``(name, gradient)`` pairs for the input part and the hidden part.
    """
    inp_name, inp = args[0]
    hprev_name, hprev = args[1]

    # Split sizes come from the recorded input and the incoming gradient
    # rather than fixed sizes, so other model dimensions work too.
    inp_size = inp.shape[-1]
    hidden_size = result.shape[1] - inp_size
    grad_inp, grad_hprev = torch.split(result, [inp_size, hidden_size], dim=1)

    return [(inp_name, grad_inp), (hprev_name, grad_hprev)]

# Initialize gradients

In [10]:
# Accumulated gradient per trainable parameter, keyed by its graph name.
gradients = {}

# Graph names of all trainable parameters, in display order.
_PARAM_GRAD_NAMES = (
    "dec_wi", "dec_wf", "dec_wg", "dec_wo",
    "dec_bi", "dec_bf", "dec_bg", "dec_bo",
    "enc_wi", "enc_wf", "enc_wg", "enc_wo",
    "enc_bi", "enc_bf", "enc_bg", "enc_bo",
    "output_weights", "output_bias",
    "emb_mat",
)


def init_grads():
    """Reset every parameter entry of the global ``gradients`` to ``None``."""
    for key in _PARAM_GRAD_NAMES:
        gradients[key] = None


init_grads()

# Dispatch table: graph op type -> function computing that op's input gradients.
function2grad = {
    "init": init_grad,
    "logsoftmax": logsoftmax_grad,
    "weightbias": weightbias_grad,
    "emb": emb_grad,
    "tanhsigmoideltwisemul": tanhsigmoideltwisemul_grad,
    "newc": newc_grad,
    "concat": concat_grad,
}

# Defining backprop

In [14]:
def backprop_gradient(cg, name, gradients):
    """Push gradients from node ``name`` back through the graph ``cg``.

    Args:
        cg: dict mapping node name to ``[op_type, args]``.
        name: starting node. Its rule is called with ``args`` only, so it
            must be a rule that takes no incoming gradient.
        gradients: dict of parameter name to accumulated gradient (or
            ``None``). Updated in place; any name found here is treated as a
            leaf and not expanded further.

    Returns:
        ``None``.
    """
    def merge(target, contributions):
        # Sum contributions that reach the same node within one sweep.
        for node, grad in contributions:
            if node in target:
                target[node] = target[node] + grad
            else:
                target[node] = grad

    op_type, op_args = cg[name]
    frontier = {}
    merge(frontier, function2grad[op_type](op_args))

    # Expand one layer of nodes per sweep until nothing is left to propagate.
    while frontier:
        upcoming = {}
        for node, grad in frontier.items():
            if node in gradients:
                held = gradients[node]
                gradients[node] = grad if held is None else held + grad
                continue

            op_type, op_args = cg[node]
            merge(upcoming, function2grad[op_type](op_args, grad))

        frontier = upcoming
        

    
    
    

# Example

In [15]:
# Constructor arguments are (vocab_size, input_size, hidden_size).
model = EncoderDecoder(34,10,256)
model.load_state_dict(torch.load("maml_yonc_256_5.weights"))
# 31 symbols listed here plus the 3 specials prepended by set_dicts give 34 entries.
model.set_dicts("a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split())
# Element 11 of forward's return tuple is the computation graph.
cg = model(["dogsqmasmad"], ".....................")[11]


In [16]:
# Clear the accumulators, then backpropagate from the first output position.
init_grads()
# backprop_gradient updates ``gradients`` in place and returns None, so ``output`` is None.
output = backprop_gradient(cg,"logit0",gradients)

In [17]:
# init_grads() resets every entry to None before the dict is displayed,
# so this cell always shows None for every parameter.
init_grads()
gradients

{'dec_wi': None,
 'dec_wf': None,
 'dec_wg': None,
 'dec_wo': None,
 'dec_bi': None,
 'dec_bf': None,
 'dec_bg': None,
 'dec_bo': None,
 'enc_wi': None,
 'enc_wf': None,
 'enc_wg': None,
 'enc_wo': None,
 'enc_bi': None,
 'enc_bf': None,
 'enc_bg': None,
 'enc_bo': None,
 'output_weights': None,
 'output_bias': None,
 'emb_mat': None}

In [18]:
# load_dataset is not defined in this file (presumably from a star-imported project module -- TODO confirm).
test_set = load_dataset("yonc.test")


In [19]:
# Display one example; the output below shows an [input, output] pair of strings.
test_set[0][0][3]

['uuxu', '.u.u.xu.']

In [20]:
# Maps each model parameter name (as used by the train_model update loop) to
# the short name the same parameter carries in the computation graph.
param_name_dict = {
    "embedding.weights": "emb_mat",
    "enc_lstm.wi_weights": "enc_wi",
    "enc_lstm.wi_bias": "enc_bi",
    "enc_lstm.wf_weights": "enc_wf",
    "enc_lstm.wf_bias": "enc_bf",
    "enc_lstm.wg_weights": "enc_wg",
    "enc_lstm.wg_bias": "enc_bg",
    "enc_lstm.wo_weights": "enc_wo",
    "enc_lstm.wo_bias": "enc_bo",
    "dec_lstm.wi_weights": "dec_wi",
    "dec_lstm.wi_bias": "dec_bi",
    "dec_lstm.wf_weights": "dec_wf",
    "dec_lstm.wf_bias": "dec_bf",
    "dec_lstm.wg_weights": "dec_wg",
    "dec_lstm.wg_bias": "dec_bg",
    "dec_lstm.wo_weights": "dec_wo",
    "dec_lstm.wo_bias": "dec_bo",
    "dec_output.weights": "output_weights",
    "dec_output.bias": "output_bias",
}


In [21]:
import gc
# Force a garbage-collection pass; the number shown below is gc.collect()'s return value.
gc.collect()

20

In [22]:
def train_model(model, train_set, lr=0.01, weights_path="maml_yonc_256_5.weights"):
    """Fine-tune ``model`` one example at a time using the hand-written backprop.

    The model is first reset from ``weights_path``. For each example the
    model is run, its prediction is printed next to the target, gradients are
    accumulated over every output position, and one gradient-descent step is
    applied. Updates carry over from one example to the next.

    Args:
        model: object providing ``load_state_dict``, ``named_params`` and
            ``set_param``, callable as ``model([inp], targets)``.
        train_set: iterable of ``(inp, outp)`` pairs.
        lr: step size of each parameter update.
        weights_path: file passed to ``torch.load`` for the initial weights.

    Returns:
        The same ``model`` object, updated.
    """
    model.load_state_dict(torch.load(weights_path))

    for index, elt in enumerate(train_set):
        print(index)

        inp, outp = elt
        print(inp)

        # "EOS" ends the target; the trailing symbols only pad the list.
        all_outs = model([inp], list(outp) + ["EOS", "z", "z", "z"])

        print(outp)
        print(all_outs[0][0])
        # Dropping the last 3 characters removes a trailing "EOS".
        print(outp == all_outs[0][0][:-3])
        print("")
        cg = all_outs[11]

        init_grads()
        gc.collect()
        if outp:
            # One logit per target symbol plus one for the final "EOS".
            for i in range(len(outp) + 1):
                backprop_gradient(cg, "logit" + str(i), gradients)

            for name, param in model.named_params():
                grad = gradients[param_name_dict[name]]
                # A parameter that received no gradient is left unchanged
                # rather than failing on ``lr * None``.
                if grad is None:
                    continue
                model.set_param(name, param - lr * grad)

        # Drop references so the recorded graph can be freed.
        cg = None
        all_outs = None

    init_grads()
    gc.collect()
    cg = None
    return model
        
    

In [23]:
# Run the fine-tuning loop on the first 20 examples of test_set[1][0].
train_model(model,test_set[1][0][:20])
gc.collect()

0
IIpEp
.I.I.pEp.
.pII.IEpEOSEOS
False

1
cI
.cI.
.cI.EOS
True

2
IIIII
.I.I.I.I.I.
.I...I.I.I.EOS
False

3
EIEE
.E.I.E.E.
.E.I.E.E.EOS
True

4
mIIxx
.mI.I.xEx.
.mI.I.xxIx.
False

5
xpEcp
.xE.pE.cEp.
.pE.pE.cEp.EOS
False

6
IxEII
.I.xE.I.I.
.I.xE.I.I.EOS
True

7
IEIcI
.I.E.I.cI.
.I.E.I.cI.EOS
True

8


EOS
True

9
EIEIm
.E.I.E.Im.
.E.I.E.Im.EOS
True

10
xEEI
.xE.E.I.
.xE.E.I.EOS
True

11
IIEc
.I.I.Ec.
.I.I.Ec.EOS
True

12
pE
.pE.
.pE.EOS
True

13
Emcpx
.E.mEc.pEx.
.E.mEc.xEx.EOS
False

14
pmx
.pE.mEx.
.pE.mEx.EOS
True

15
pEccm
.pEc.cEm.
.pEc.cEm.EOS
True

16
xpcI
.xEp.cI.
.xEp.cI.EOS
True

17
pExEc
.pE.xEc.
.pE.xEc.EOS
True

18
EEm
.E.Em.
.E.Em.EOS
True

19
EIEIx
.E.I.E.Ix.
.E.I.E.Ix.EOS
True



0

In [24]:
# Element 0 of the return tuple is the list of decoded strings. No target is
# "EOS", so decoding runs for the full max_length steps.
model(["r"], "...........................")[0]

['.rE.EOSEOS.EOSEOS.EOSEOS.EOSEOS.EOSEOS.EOS']