In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import random
import argparse

from load_data import *
from utils import *
from training import *
from models import *

In [1022]:
def print_grad(grad):
    """Print whatever is passed in; GradLSTM.forward registers this as a tensor hook."""
    print(grad)
    
# Redefined LSTM
class GradLSTM(ModifiableModule):
    """One-step LSTM cell whose weights are plain tensors created with requires_grad=True.

    In addition to the output and the new state, ``forward`` returns the
    pre-activation value of every gate and the concatenated input, which the
    caller uses to record a computation graph by hand.
    """

    # Gate tags in the order the parameters are created and listed:
    # input, forget, candidate, output.
    _GATES = ("i", "f", "g", "o")

    def __init__(self, input_size, hidden_size):
        super(GradLSTM, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size

        # A throw-away nn.Linear per gate supplies freshly initialised
        # parameters; only their data is kept, re-wrapped as tensors that
        # require grad (attributes w<gate>_weights / w<gate>_bias).
        joint_size = hidden_size + input_size
        templates = [nn.Linear(joint_size, hidden_size) for _ in self._GATES]
        for gate, template in zip(self._GATES, templates):
            setattr(self, "w" + gate + "_weights", V(template.weight.data, requires_grad=True))
            setattr(self, "w" + gate + "_bias", V(template.bias.data, requires_grad=True))

    def forward(self, inp, hidden):
        """Run one LSTM step.

        Args:
            inp: input for this step; concatenated with the hidden state along
                dimension 2.
            hidden: tuple ``(h, c)`` holding the previous hidden and cell state.

        Returns:
            A 7-tuple ``(h_new, (h_new, c_new), o_pre, joint, i_pre, f_pre,
            g_pre)`` where the ``*_pre`` entries are the gate values before
            their sigmoid/tanh and ``joint`` is the concatenated ``[inp, h]``.
        """
        h_prev, c_prev = hidden
        joint = torch.cat((inp, h_prev), 2)

        # The gate computations stay in this order so the autograd graph is
        # built exactly as before.
        i_tpre = F.linear(joint, self.wi_weights, self.wi_bias)
        i_t = torch.sigmoid(i_tpre)
        f_tpre = F.linear(joint, self.wf_weights, self.wf_bias)
        f_t = torch.sigmoid(f_tpre)
        g_tpre = F.linear(joint, self.wg_weights, self.wg_bias)
        g_t = torch.tanh(g_tpre)
        o_tpre = F.linear(joint, self.wo_weights, self.wo_bias)
        o_t = torch.sigmoid(o_tpre)

        c_new = f_t * c_prev + i_t * g_t
        h_new = o_t * torch.tanh(c_new)

        # Debug aid: print the gradient reaching the new cell state during
        # backward (only possible when the tensor takes part in autograd).
        if c_new.requires_grad:
            c_new.register_hook(print_grad)

        return h_new, (h_new, c_new), o_tpre, joint, i_tpre, f_tpre, g_tpre

    def named_leaves(self):
        """Return ``(attribute_name, tensor)`` pairs for all eight gate parameters."""
        leaves = []
        for gate in self._GATES:
            for kind in ("weights", "bias"):
                attr = "w" + gate + "_" + kind
                leaves.append((attr, getattr(self, attr)))
        return leaves



In [1023]:
# Encoder/decoder model
class EncoderDecoder(ModifiableModule):
    """Character-level LSTM encoder/decoder that also records its own computation graph.

    ``forward`` returns, next to the usual outputs, a dictionary describing
    every intermediate value as ``node_name -> [op_type, inputs]``.
    ``backprop_gradient`` walks that dictionary to compute gradients by hand.
    ``set_dicts`` must be called before ``forward`` so that ``char2ind`` and
    ``ind2char`` exist.
    """

    def __init__(self, vocab_size, input_size, hidden_size):
        super(EncoderDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.embedding = GradEmbedding(vocab_size, input_size)
        self.enc_lstm = GradLSTM(input_size, hidden_size)

        self.dec_lstm = GradLSTM(input_size, hidden_size)
        self.dec_output = GradLinear(hidden_size, vocab_size)

        # Number of decoding steps; the decoder always runs this many steps
        self.max_length = 20

    def _record_lstm_step(self, graph, prefix, step, lstm, cprev_name, c_prev, c_new,
                          iph, o_tpre, i_tpre, f_tpre, g_tpre):
        """Add the graph nodes for one LSTM step: h, c and the four gate pre-activations.

        ``prefix`` is "enc" or "dec"; ``cprev_name`` is the node name of the
        cell state that was fed into this step and ``c_prev`` its value.
        """
        tag = str(step)
        graph[prefix + "_h" + tag] = ["tanhsigmoideltwisemul",
                                      [(prefix + "_c" + tag, c_new.detach()),
                                       (prefix + "_o" + tag, o_tpre.detach())]]
        graph[prefix + "_c" + tag] = ["newc",
                                      [(cprev_name, c_prev.detach()),
                                       (prefix + "_f" + tag, f_tpre.detach()),
                                       (prefix + "_i" + tag, i_tpre.detach()),
                                       (prefix + "_g" + tag, g_tpre.detach())]]
        joint_name = prefix + "_inputhidden" + tag
        for gate in ("o", "f", "i", "g"):
            graph[prefix + "_" + gate + tag] = ["weightbias",
                                                [(joint_name, iph.detach()),
                                                 (prefix + "_w" + gate, getattr(lstm, "w" + gate + "_weights")),
                                                 (prefix + "_b" + gate, getattr(lstm, "w" + gate + "_bias"))]]

    def forward(self, sequence_list, correct):
        """Encode ``sequence_list``, greedily decode ``max_length`` steps, and record the graph.

        Args:
            sequence_list: list of input sequences; every element of every
                sequence must be a key of ``self.char2ind``.
            correct: indexable target; ``correct[i]`` is the symbol whose
                log-probability node "logit<i>" refers to. Steps beyond
                ``len(correct)`` get no "logit" node.

        Returns:
            A 12-tuple: ``(out_strings, logits, encoding, preds, hiddens, ots,
            iphs, hidden_prev, its, fts, gts, computation_graph)``. ``ots``,
            ``its``, ``fts`` and ``gts`` hold the decoder's gate
            pre-activations per step, and ``hidden_prev`` is the state that
            was fed into the last decoder step.

        NOTE(review): the recorded graph stores one symbol index per decoder
        input and does not model padding, so it appears to assume a batch of
        size 1 — confirm before using it with larger batches.
        """
        computation_graph = {}
        batch_size = len(sequence_list)

        # Initialize the hidden and cell states
        hidden = (V(torch.zeros(1, batch_size, self.hidden_size)),
                  V(torch.zeros(1, batch_size, self.hidden_size)))

        computation_graph["enc_h-1"] = ["init", [("ZERO", hidden[0].detach())]]
        computation_graph["enc_c-1"] = ["init", [("ZERO", hidden[1].detach())]]

        # Index of the last encoder step actually run. -1 points the decoder
        # at the "enc_h-1"/"enc_c-1" init nodes when no step runs (previously
        # the loop variable was read here and was unbound in that case).
        last_enc_index = -1

        # The input is a list of sequences. Here the sequences are converted
        # into integer keys
        all_seqs = [torch.LongTensor([self.char2ind[elt] for elt in sequence])
                    for sequence in sequence_list]
        max_length = max([len(x) for x in sequence_list])
        if max_length > 0:
            # Pad the sequences to allow batching; the pad value is 0
            all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)

            # 1.0 where a real symbol is present, 0.0 on padding
            all_seqs_onehot = (all_seqs > 0).type(torch.FloatTensor)

            # Pass the sequences through the encoder, one character at a time
            for index, elt in enumerate(all_seqs):
                step = str(index)
                cprev_name = "enc_c" + str(index-1)
                hprev_name = "enc_h" + str(index-1)

                # Embed the character
                emb = self.embedding(elt.unsqueeze(0))

                computation_graph["enc_input" + step] = ["emb", [("onehot", elt), ("emb_mat", self.embedding.weights)]]
                computation_graph["enc_inputhidden" + step] = ["concat", [("enc_input" + step, emb), (hprev_name, hidden)]]

                # Keep the state that goes INTO this step; it is needed both
                # for the graph and for the padding mask below.
                hidden_prev = hidden

                # Pass through the LSTM
                output, hidden_new, o_tpre, iph, i_tpre, f_tpre, g_tpre = self.enc_lstm(emb, hidden)

                self._record_lstm_step(computation_graph, "enc", index, self.enc_lstm,
                                       cprev_name, hidden_prev[1], hidden_new[1],
                                       iph, o_tpre, i_tpre, f_tpre, g_tpre)

                # Variable-length inputs: for each sequence in the batch, use
                # the new state if the sequence is still running, or retain
                # the old state if we are now in its padding. (Previously
                # hidden_prev had already been overwritten with hidden_new at
                # this point, which made the mask a no-op.)
                mask = all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[0].shape)
                hx = hidden_prev[0] * (1 - mask) + hidden_new[0] * mask
                cx = hidden_prev[1] * (1 - mask) + hidden_new[1] * mask

                hidden = (hx, cx)
                last_enc_index = index

        encoding = hidden
        # Decoding

        # Previous output characters (used as input for the following time step)
        prev_output = ["SOS" for _ in range(batch_size)]

        # Accumulates the output sequences
        out_strings = ["" for _ in range(batch_size)]

        # Per-step decoder values handed back to the caller
        logits = []   # log-probabilities (used for computing the loss)
        preds = []    # scores before log_softmax
        hiddens = []  # (h, c) after each step
        ots = []      # output-gate pre-activations
        iphs = []     # concatenated [input, hidden]
        its = []      # input-gate pre-activations
        fts = []      # forget-gate pre-activations
        gts = []      # candidate pre-activations
        hidden_prev = hidden

        # The first decoder step continues from the last encoder state
        cprev_name = "enc_c" + str(last_enc_index)
        hprev_name = "enc_h" + str(last_enc_index)

        for i in range(self.max_length):
            step = str(i)

            # Determine the previous output character for each element
            # of the batch; to be used as the input for this time step
            prev_outputs = [self.char2ind[elt] for elt in prev_output]

            # Embed the previous outputs
            emb = self.embedding(torch.LongTensor([prev_outputs]))

            # As before, only the last batch element's index is stored here
            computation_graph["dec_input" + step] = ["emb", [("onehot", prev_outputs[-1]), ("emb_mat", self.embedding.weights)]]
            computation_graph["dec_inputhidden" + step] = ["concat", [("dec_input" + step, emb), (hprev_name, hidden)]]

            hidden_prev = hidden

            # Pass through the decoder
            output, hidden, o_tpre, iph, i_tpre, f_tpre, g_tpre = self.dec_lstm(emb, hidden)

            # Determine the output probabilities used to make predictions
            pred = self.dec_output(output)
            probs = F.log_softmax(pred, dim=2)
            logits.append(probs)

            # A loss node can only be recorded where a target symbol exists;
            # a target shorter than max_length no longer raises IndexError.
            if i < len(correct):
                computation_graph["logit" + step] = ["logsoftmax", [("pred" + step, pred.detach()), self.char2ind[correct[i]]]]
            computation_graph["pred" + step] = ["weightbias", [("dec_h" + step, output.detach()), ("output_weights", self.dec_output.weights), ("output_bias", self.dec_output.bias)]]
            self._record_lstm_step(computation_graph, "dec", i, self.dec_lstm,
                                   cprev_name, hidden_prev[1], hidden[1],
                                   iph, o_tpre, i_tpre, f_tpre, g_tpre)

            preds.append(pred)
            hiddens.append(hidden)
            ots.append(o_tpre)
            iphs.append(iph)
            its.append(i_tpre)
            fts.append(f_tpre)
            gts.append(g_tpre)

            # Discretize the output labels (via argmax) for generating an output character
            _, topi = probs.data.topk(1)
            label = topi[0]

            prev_output = []
            for batch_index, elt in enumerate(label):
                char = self.ind2char[elt.item()]

                out_strings[batch_index] += char
                prev_output.append(char)

            cprev_name = "dec_c" + step
            hprev_name = "dec_h" + step

        return out_strings, logits, encoding, preds, hiddens, ots, iphs, hidden_prev, its, fts, gts, computation_graph

    def named_submodules(self):
        """Return ``(name, module)`` pairs for the four sub-modules."""
        return [('embedding', self.embedding), ('enc_lstm', self.enc_lstm),
                ('dec_lstm', self.dec_lstm), ('dec_output', self.dec_output)]

    # Create a copy of the model
    def create_copy(self, same_var=False):
        """Build a new EncoderDecoder of the same sizes and copy this model into it.

        ``same_var`` is forwarded unchanged to ``copy``. The symbol
        dictionaries created by ``set_dicts`` are not assigned here.
        """
        new_model = EncoderDecoder(self.vocab_size, self.input_size, self.hidden_size)
        new_model.copy(self, same_var=same_var)

        return new_model

    def set_dicts(self, vocab_list):
        """Create ``char2ind`` / ``ind2char`` from ``vocab_list``.

        The entries "NULL", "SOS" and "EOS" are placed in front, so they get
        indices 0, 1 and 2; index 0 doubles as the padding value in
        ``forward``.
        """
        vocab_list = ["NULL", "SOS", "EOS"] + vocab_list

        char2ind = {}
        ind2char = {}
        for index, elt in enumerate(vocab_list):
            char2ind[elt] = index
            ind2char[index] = elt

        self.char2ind = char2ind
        self.ind2char = ind2char



In [1024]:
def init_grad(args, result):
    """Gradient rule for "init" nodes: nothing is propagated any further."""
    return list()

In [1025]:
def logsoftmax_grad(args):
    """Gradient of the negative log-likelihood of the correct class w.r.t. the scores.

    Args:
        args: ``[(name, pred), correct_ind]`` where ``pred`` holds the scores
            fed to log_softmax for a single prediction (it is flattened, so
            any shape with V elements works) and ``correct_ind`` is the index
            of the target class.

    Returns:
        ``[(name, grad)]`` with ``grad`` of shape ``(1, V)`` equal to
        ``softmax(pred) - onehot(correct_ind)``.
    """
    name, pred = args[0]
    correct_ind = args[1]

    # torch.softmax subtracts the maximum internally, so large scores no
    # longer overflow the way exp(pred) / sum(exp(pred)) did. The class count
    # is taken from pred instead of being fixed at 34.
    probs = torch.softmax(pred.reshape(-1), dim=0).unsqueeze(0)

    onehot = torch.zeros_like(probs)
    onehot[0][correct_ind] = 1.0

    grad_pred = probs - onehot

    return [(name, grad_pred)]

In [1026]:
def weightbias_grad(args, result):
    """Gradient rule for an affine node ``out = inp @ weight.T + bias``.

    Args:
        args: ``[(inp_name, inp), (weight_name, weight), (bias_name, bias)]``;
            ``inp`` is flattened to a single row.
        result: gradient arriving at the node output, one row.

    Returns:
        ``[(weight_name, dW), (bias_name, db), (inp_name, dinp)]``.
    """
    (name_inp, inp), (name_weight, weight), (name_bias, bias) = args[0], args[1], args[2]
    row = inp.view(-1).unsqueeze(0)

    # The bias gradient is the incoming gradient itself.
    grad_bias = result
    # Outer product of input and incoming gradient, laid out like `weight`.
    grad_weight = torch.mm(row.t(), grad_bias).t()
    grad_inp = torch.mm(grad_bias, weight).view(-1).unsqueeze(0)

    return [(name_weight, grad_weight), (name_bias, grad_bias), (name_inp, grad_inp)]


In [1122]:
def emb_grad(args, result):
    """Gradient rule for an embedding lookup.

    Args:
        args: ``[(index_name, ind), (weight_name, weight)]`` where ``ind``
            selects the looked-up row(s) and ``weight`` is the embedding
            matrix with one row per symbol.
        result: gradient arriving at the looked-up embedding, shape ``(1, E)``.

    Returns:
        ``[(weight_name, grad)]`` where ``grad`` has one row per symbol and is
        zero everywhere except the selected row(s), which hold ``result``.
    """
    name_ind, ind = args[0]
    name_weight, weight = args[1]

    # Size the selector from the embedding matrix rather than assuming a
    # 34-symbol vocabulary.
    vocab_size = weight.shape[0]
    onehot = torch.zeros(1, vocab_size, dtype=result.dtype, device=result.device)
    onehot[0][ind] = 1.0

    grad_weight = torch.mm(torch.transpose(onehot, 0, 1), result)

    return [(name_weight, grad_weight)]

In [1123]:
def tanhsigmoideltwisemul_grad(args, result):
    """Gradient rule for ``h = sigmoid(o) * tanh(c)``.

    Args:
        args: ``[(c_name, c), (o_name, o)]`` with ``o`` the output-gate value
            before its sigmoid.
        result: gradient arriving at ``h``.

    Returns:
        ``[(c_name, dc), (o_name, do)]``, each flattened to a single row.
    """
    (name_c, ct), (name_o, ot) = args[0], args[1]

    def as_row(tensor):
        return tensor.view(-1).unsqueeze(0)

    # d h / d o = sigmoid'(o) * tanh(c)
    grad_ot = torch.sigmoid(ot) * (1 - torch.sigmoid(ot)) * torch.tanh(ct) * result
    # d h / d c = sigmoid(o) * tanh'(c)
    grad_ct = torch.sigmoid(ot) * result * (1 - torch.pow(torch.tanh(ct), 2))

    return [(name_c, as_row(grad_ct)), (name_o, as_row(grad_ot))]
    

In [1124]:
def newc_grad(args, result):
    """Gradient rule for ``c = sigmoid(f) * c_prev + sigmoid(i) * tanh(g)``.

    Args:
        args: ``[(cprev_name, c_prev), (f_name, f), (i_name, i), (g_name, g)]``
            where ``f``, ``i`` and ``g`` are gate values before their
            sigmoid/tanh.
        result: gradient arriving at ``c``.

    Returns:
        ``[(cprev_name, dc_prev), (f_name, df), (i_name, di), (g_name, dg)]``,
        each flattened to a single row.
    """
    (name_cprev, cprev), (name_f, ft), (name_i, it), (name_g, gt) = args[0], args[1], args[2], args[3]

    raw = [
        # d c / d c_prev = sigmoid(f)
        (name_cprev, result * torch.sigmoid(ft)),
        # d c / d f = c_prev * sigmoid'(f)
        (name_f, cprev * result * torch.sigmoid(ft) * (1 - torch.sigmoid(ft))),
        # d c / d i = sigmoid'(i) * tanh(g)
        (name_i, torch.sigmoid(it) * (1 - torch.sigmoid(it)) * torch.tanh(gt) * result),
        # d c / d g = tanh'(g) * sigmoid(i)
        (name_g, (1 - torch.pow(torch.tanh(gt), 2)) * torch.sigmoid(it) * result),
    ]

    return [(label, grad.view(-1).unsqueeze(0)) for label, grad in raw]
    

In [1125]:
def concat_grad(args, result):
    """Gradient rule for the concatenation ``[input, previous hidden state]``.

    Args:
        args: ``[(inp_name, inp), (hprev_name, hprev)]``. Only the size of the
            last dimension of ``inp`` is used; ``hprev`` is not inspected.
        result: gradient arriving at the concatenated row, shape ``(1, I + H)``.

    Returns:
        ``[(inp_name, d_inp), (hprev_name, d_hprev)]``: the first ``I`` and
        the remaining ``H`` columns of ``result``.
    """
    inp_name, inp = args[0]
    hprev_name, hprev = args[1]

    # Derive the split point from the tensors instead of assuming a 10-wide
    # input and a 256-wide hidden state.
    inp_width = inp.shape[-1]
    hidden_width = result.shape[1] - inp_width
    grad_inp, grad_hprev = torch.split(result, [inp_width, hidden_width], dim=1)

    return [(inp_name, grad_inp), (hprev_name, grad_hprev)]

In [1126]:
# Accumulated gradient for every trainable tensor named in the computation
# graph. A value of None means "nothing accumulated yet"; backprop_gradient
# also uses membership in this dict to recognise leaf nodes.
gradients = dict.fromkeys([
    # decoder LSTM weights, then biases
    "dec_wi", "dec_wf", "dec_wg", "dec_wo",
    "dec_bi", "dec_bf", "dec_bg", "dec_bo",
    # encoder LSTM weights, then biases
    "enc_wi", "enc_wf", "enc_wg", "enc_wo",
    "enc_bi", "enc_bf", "enc_bg", "enc_bo",
    # output layer
    "output_weights", "output_bias",
    # embedding matrix
    "emb_mat",
])

# Maps the op type stored in a computation-graph node to its gradient rule.
function2grad = {
    "init": init_grad,
    "logsoftmax": logsoftmax_grad,
    "weightbias": weightbias_grad,
    "emb": emb_grad,
    "tanhsigmoideltwisemul": tanhsigmoideltwisemul_grad,
    "newc": newc_grad,
    "concat": concat_grad,
}

In [1127]:
def sort_results(results):
    """Reorder ``(name, grad)`` entries: "dec*" names first, then "enc*", then the rest.

    The relative order inside each of the three groups is kept, and a new
    list is returned.
    """
    grouped = {"dec": [], "enc": []}
    remaining = []

    for entry in results:
        # Entries whose three-character prefix is neither "dec" nor "enc"
        # fall through to `remaining`.
        grouped.get(entry[0][:3], remaining).append(entry)

    return grouped["dec"] + grouped["enc"] + remaining

In [1128]:
def backprop_gradient(cg, name, gradients, verbose=True, watch="enc_c0"):
    """Propagate gradients by hand from node ``name`` back through the graph ``cg``.

    Args:
        cg: computation graph, ``node_name -> [op_type, inputs]``.
        name: node to start from. Its gradient rule is called with the inputs
            only, so it must be a rule that takes a single argument
            (``logsoftmax_grad`` in this notebook).
        gradients: dict of leaf gradients, updated in place. A name counts as
            a leaf exactly when it is a key of this dict.
        verbose: when true (the default), print the names on the frontier and
            the "decoder first" flag on every round.
        watch: node name for which extra probe messages are printed while
            ``verbose`` is true. Defaults to the previously hard-coded
            "enc_c0".

    Returns:
        ``(results_new, interm_gradients)``: the entries left over after the
        last round (empty once propagation has finished) and a dict with the
        summed gradient of every non-leaf node that was visited.

    NOTE(review): every entry is propagated on its own, so a node reached
    along several paths is expanded once per path (the trace shows repeated
    names). The amount of work therefore grows with the number of paths, not
    the number of nodes.
    """
    interm_gradients = {}

    grad_type, args = cg[name]
    results = function2grad[grad_type](args)

    # Bound up front so the return below is valid even if the very first
    # frontier is empty.
    results_new = []

    while True:
        if verbose:
            print([r[0] for r in results])
        if not results:
            break

        # Entries are sorted decoder-first, so looking at the first one tells
        # us whether any decoder entry is still pending.
        contains_dec = results[0][0][:3] == "dec"
        if verbose:
            print(contains_dec)

        results_new = []
        for result in results:
            node_name, grad = result
            probing = verbose and node_name == watch
            if probing:
                print("here it is!")

            # While decoder entries remain, everything else is carried over
            # untouched to the next round.
            if contains_dec and node_name[:3] != "dec":
                results_new.append(result)
                continue

            if node_name in gradients:
                # Leaf: accumulate and stop.
                if probing:
                    print("updating")
                    print(grad)
                if gradients[node_name] is None:
                    gradients[node_name] = grad
                else:
                    gradients[node_name] = gradients[node_name] + grad
            else:
                # Intermediate node: accumulate, then push the gradient on to
                # the node's inputs.
                if probing:
                    print("updating")
                if node_name in interm_gradients:
                    interm_gradients[node_name] = interm_gradients[node_name] + grad
                else:
                    interm_gradients[node_name] = grad

                grad_type, args = cg[node_name]
                results_new = results_new + function2grad[grad_type](args, grad)

        results = sort_results(results_new)

    return results_new, interm_gradients
    
    
    

In [1129]:
# Build the model with vocab_size=34, input_size=10, hidden_size=256 and load the saved weights
model = EncoderDecoder(34,10,256)
model.load_state_dict(torch.load("maml_yonc_256_5.weights"))
# 31 symbols here; set_dicts puts "NULL", "SOS", "EOS" in front, giving 34 entries with "." at index 33
model.set_dicts("a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split())
# Element 11 of forward's return tuple is the hand-recorded computation graph
cg = model(["d"], ".....................")[11]

In [1130]:
# Autograd reference: NLL of class 33 (".") under the log-probabilities of decoder step 1
loss = nn.NLLLoss()(model(["d"],"....................")[1][1][0],torch.LongTensor([33]))
# The tensors printed by this cell presumably come from the print_grad hook that GradLSTM.forward registers on the cell state
loss.backward(create_graph=True, retain_graph=True)

tensor([[[-5.3931e-01,  6.9094e-02,  4.0202e-03, -8.8672e-03, -1.1825e-02,
          -4.1266e-02,  3.4893e-02,  1.0484e-01,  3.2202e-03,  9.5029e-02,
           2.0861e-01,  2.0864e-01,  1.1967e-01, -6.4139e-02,  4.9992e-04,
           1.1397e-01,  3.5812e-01,  4.7062e-01,  9.0996e-02,  3.7696e-01,
          -1.2074e-01,  1.4445e-02, -6.2341e-02, -1.3772e-03,  3.3585e-02,
           1.6912e-01,  1.1109e-01,  1.4714e-04, -5.0178e-02, -1.4443e-02,
          -9.1349e-02,  1.3468e-01,  3.1671e-02,  3.3135e-03,  1.8037e-01,
          -5.2121e-02, -2.4404e-01,  2.4586e-01,  1.3810e-01, -6.1891e-02,
          -9.1680e-02, -8.0907e-03, -4.5414e-03, -4.9270e-02, -7.8413e-03,
           1.2321e-02,  2.0757e-03, -4.3205e-02,  4.8133e-02, -3.2909e-01,
          -2.3658e-02, -4.2870e-02, -5.5963e-02,  4.1287e-02, -2.1622e-01,
           3.7514e-04,  7.5480e-02, -2.1279e-02,  3.5426e-02,  5.9570e-02,
           1.9885e-01,  1.4978e-01,  6.1599e-02, -3.1100e-02,  1.5463e-01,
          -6.4519e-02, -2

In [1131]:
# Hand-written backward pass from the step-1 loss node; fills `gradients` in place and prints its trace
output, interm_gradients = backprop_gradient(cg,"logit1",gradients)

['pred1']
False
['dec_h1', 'output_weights', 'output_bias']
True
['dec_c1', 'dec_o1', 'output_weights', 'output_bias']
True
['dec_c0', 'dec_f1', 'dec_i1', 'dec_g1', 'dec_wo', 'dec_bo', 'dec_inputhidden1', 'output_weights', 'output_bias']
True
['dec_f0', 'dec_i0', 'dec_g0', 'dec_wf', 'dec_bf', 'dec_inputhidden1', 'dec_wi', 'dec_bi', 'dec_inputhidden1', 'dec_wg', 'dec_bg', 'dec_inputhidden1', 'dec_input1', 'dec_h0', 'enc_c0', 'output_weights', 'output_bias']
True
here it is!
['dec_wf', 'dec_bf', 'dec_inputhidden0', 'dec_wi', 'dec_bi', 'dec_inputhidden0', 'dec_wg', 'dec_bg', 'dec_inputhidden0', 'dec_input1', 'dec_h0', 'dec_input1', 'dec_h0', 'dec_input1', 'dec_h0', 'dec_c0', 'dec_o0', 'enc_c0', 'emb_mat', 'output_weights', 'output_bias']
True
here it is!
['dec_input0', 'dec_input0', 'dec_input0', 'dec_c0', 'dec_o0', 'dec_c0', 'dec_o0', 'dec_c0', 'dec_o0', 'dec_f0', 'dec_i0', 'dec_g0', 'dec_wo', 'dec_bo', 'dec_inputhidden0', 'enc_h0', 'enc_h0', 'enc_h0', 'enc_c0', 'enc_c0', 'emb_mat', 'emb

In [1132]:
# Hand-computed gradient accumulated at the decoder cell state of step 0
interm_gradients["dec_c0"]

tensor([[-7.6343e-01,  2.8821e-01,  2.2603e-02,  4.2585e-02,  4.5970e-01,
          1.1911e-01,  2.1751e-03,  1.0366e-01,  1.3653e-01,  9.5653e-02,
         -3.0947e-01,  3.0301e-01,  1.6842e-01, -9.7699e-02,  5.0340e-03,
         -1.1438e-01,  3.5180e-01, -6.7858e-01,  5.3403e-02,  2.0384e-01,
         -2.5872e-01,  8.1468e-02, -9.4247e-02, -2.7366e-02,  4.9285e-02,
          2.1124e-01,  6.1366e-02,  9.7330e-03,  5.1393e-02,  6.6620e-03,
          2.7665e-01,  2.2864e-01, -2.1797e-01, -3.8722e-02,  1.2291e-01,
         -4.3216e-02, -1.8936e-01,  1.4102e-01,  3.3920e-02, -7.8083e-02,
          4.5861e-02, -6.7152e-02, -2.1536e-02,  2.8714e-01, -1.0896e-02,
          3.4571e-02, -1.0766e-03,  2.2614e-01,  2.5757e-01, -1.5307e-01,
         -1.6521e-01,  5.5468e-01, -1.0896e-01,  3.6422e-01,  1.9840e-02,
         -5.0932e-02, -1.8607e-01,  1.1711e-02,  4.8586e-02,  4.6913e-02,
         -2.9490e-01, -1.4798e-01,  6.9339e-02,  1.4208e-02,  3.6916e-02,
          5.5357e-01,  3.2576e-01, -4.

In [1133]:
# Hand-computed gradient of the embedding matrix (compare with model.embedding.weights.grad)
gradients["emb_mat"]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1716,  0.0012,  0.0583,  0.2247, -0.0837, -0.1303,  0.0570, -0.2255,
         -0.1956,  0.2146],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  

In [1134]:
# Autograd's gradient of the embedding matrix (compare with gradients["emb_mat"])
model.embedding.weights.grad

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1716,  0.0012,  0.0583,  0.2247, -0.0837, -0.1303,  0.0570, -0.2255,
         -0.1956,  0.2146],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  

In [37]:
# Inputs recorded for the step-0 loss node: the ("pred0", scores) pair and the target index
cg["logit0"][1]

[('pred0',
  tensor([[[-18.2495, -18.2065,   2.8476,  -3.7418,  -5.2067,  -2.2878,  -0.3709,
             -6.6155,  -6.2988,  -5.7134,  -4.1288,  -4.8795,  -3.3390,  -7.0429,
             -4.9520,  -0.1065,  -7.0276,  -8.2674,  -7.8986,  -6.1748,  -5.8160,
             -8.9927,  -4.0365,  -8.4741,  -5.1798,  -9.3990,  -6.9244,  -5.5437,
             -6.2603,  -5.3010,  -4.9812,  -7.0763,  -4.7963,  11.8783]]],
         grad_fn=<AddBackward0>)),
 33]

In [375]:
# load_dataset is not defined in this notebook; presumably it comes from one of the star imports — confirm
test_set = load_dataset("yonc.test")
# Build the model with vocab_size=34, input_size=10, hidden_size=256 and load the saved weights
model = EncoderDecoder(34,10,256)
model.load_state_dict(torch.load("maml_yonc_256_5.weights"))

In [376]:
# 31 symbols here; set_dicts puts "NULL", "SOS", "EOS" in front, giving 34 entries with "." at index 33
model.set_dicts("a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split())
# Autograd reference: NLL of class 33 (".") under the log-probabilities of decoder step 0
loss = nn.NLLLoss()(model(["d"],"....................")[1][0][0],torch.LongTensor([33]))
loss.backward(create_graph=True, retain_graph=True)

tensor([[[ 4.5744e-02, -1.6112e-02, -4.7076e-02,  7.6808e-02,  1.2732e-01,
           1.0663e-01, -4.7536e-02, -5.7247e-02,  1.2138e-01, -8.1724e-02,
           8.2449e-02, -9.1530e-04,  7.0700e-02, -3.8209e-02,  1.7516e-01,
           9.8565e-02, -2.1420e-01, -3.6250e-02, -2.2710e-01, -2.3672e-01,
          -2.3202e-01,  1.0897e-01,  2.2687e-02,  1.8941e-02,  2.2788e-01,
          -2.4732e-01,  1.2328e-02, -1.2145e+00, -2.8930e-02, -5.0763e-02,
           2.2141e-01,  6.5797e-02, -2.4298e-02, -9.0569e-02,  3.4675e-02,
          -3.2900e-01, -1.1883e-01,  1.4096e-01,  1.2883e-03,  2.8239e-03,
           5.5968e-01,  9.4449e-02,  1.9732e-01, -2.0234e-01, -4.1721e-02,
          -1.9965e-01,  6.0732e-04,  2.9987e-02,  2.3678e-02, -3.8946e-02,
           2.8809e-02, -2.0510e-01, -4.9838e-02, -1.0804e-01, -8.3539e-02,
          -3.2697e-01,  1.9707e-01,  2.0227e-01, -5.0008e-02,  5.2575e-02,
          -9.0661e-02,  6.6438e-01, -1.9801e-01, -6.5950e-02,  1.4418e-01,
           3.4549e-01, -4

In [50]:
# Autograd's gradient of the loss with respect to the output-layer bias
model.dec_output.bias.grad

tensor([ 8.2341e-14,  8.5960e-14,  1.1966e-04,  1.6452e-07,  3.8022e-08,
         7.0423e-07,  4.7883e-06,  9.2937e-09,  1.2756e-08,  2.2909e-08,
         1.1173e-07,  5.2740e-08,  2.4615e-07,  6.0619e-09,  4.9053e-08,
         6.2377e-06,  6.1550e-09,  1.7816e-09,  2.5762e-09,  1.4442e-08,
         2.0673e-08,  8.6259e-10,  1.2254e-07,  1.4489e-09,  3.9060e-08,
         5.7459e-10,  6.8244e-09,  2.7145e-08,  1.3257e-08,  3.4601e-08,
         4.7638e-08,  5.8625e-09,  5.7316e-08, -1.3244e-04],
       grad_fn=<CloneBackward>)