# Data loading

In [None]:
import numpy as np
from random import shuffle

In [None]:
# Load a list of abstract language descriptors
def load_languages(language_file):
    """Read tab-separated language descriptors from *language_file*.

    Each line holds three tab-separated fields: a comma-separated list of
    integer ranking indices, a comma-separated vowel inventory, and a
    comma-separated consonant inventory.

    Returns:
        list of ``[ranking, vowel_inventory, consonant_inventory]`` lists, one
        per line; ``ranking`` is a list of ints, the inventories are lists of
        strings.
    """
    lang_list = []

    # The context manager closes the handle even if a line fails to parse.
    with open(language_file, "r") as fi:
        for line in fi:
            parts = line.strip().split("\t")

            ranking = [int(x) for x in parts[0].split(",")]
            vowel_inventory = parts[1].split(",")
            consonant_inventory = parts[2].split(",")

            lang_list.append([ranking, vowel_inventory, consonant_inventory])

    return lang_list

# Load the file input/output correspondences
def load_io(io_file):
    """Read ranking -> input/output correspondences from *io_file*.

    Each line has two tab-separated fields:

    * a comma-separated list of integer ranking indices, used as a tuple key;
    * a ``&``-separated list of groups of the form ``input#output#steps``,
      where ``steps`` is comma-separated.

    Returns:
        dict mapping ``tuple`` of ints to a list of ``[input, output, steps]``
        entries. A later line with the same ranking overwrites an earlier one.
    """
    io_correspondences = {}

    # The context manager closes the handle even if a line fails to parse.
    with open(io_file, "r") as fi:
        for line in fi:
            parts = line.strip().split("\t")
            ranking = tuple(int(x) for x in parts[0].split(","))

            value_list = []
            for group in parts[1].split("&"):
                components = group.split("#")
                inp = components[0]
                outp = components[1]
                steps = components[2].split(",")
                value_list.append([inp, outp, steps])

            io_correspondences[ranking] = value_list

    return io_correspondences

# Load a dataset file: per line, train/dev/test splits, a vocabulary and a key
def load_dataset(dataset_file):
    """Parse *dataset_file* into one entry per line.

    Each line has five tab-separated fields:

    0-2. train / dev / test examples, whitespace-separated; each example is
         split on ``","`` into a list of strings;
    3.   the vocabulary, whitespace-separated;
    4.   the key: three comma-separated parts, namely whitespace-separated
         vowels, whitespace-separated consonants, and whitespace-separated
         integer ranking indices.

    Returns:
        list of ``[train_set, dev_set, test_set, vocab, key]`` where
        ``key = [v_list, c_list, ranking]``.
    """
    langs = []

    # The context manager closes the handle even if a line fails to parse.
    with open(dataset_file, "r") as fi:
        for line in fi:
            parts = line.strip().split("\t")

            train_set = [elt.split(",") for elt in parts[0].split()]
            dev_set = [elt.split(",") for elt in parts[1].split()]
            test_set = [elt.split(",") for elt in parts[2].split()]
            vocab = parts[3].split()
            key_string = parts[4].split(",")

            v_list = key_string[0].split()
            c_list = key_string[1].split()
            ranking = [int(x) for x in key_string[2].split()]

            key = [v_list, c_list, ranking]

            langs.append([train_set, dev_set, test_set, vocab, key])

    return langs



# Load a dataset file and scramble its examples across all languages
def load_dataset_scramble(dataset_file):
    """Like ``load_dataset``, but pool and shuffle examples across lines.

    The train, dev and test examples of every line are pooled, shuffled with
    ``random.shuffle`` (train, then dev, then test), and re-dealt into as many
    tasks as there were lines. Chunk sizes come from the split sizes of the
    *last* line, so this assumes every line has equally sized splits. Every
    task shares the last line's vocabulary and gets a placeholder key
    ``["scrambled", "scrambled", "scrambled"]``. Field 4 (the key) is not read.

    Returns:
        list of ``[train_set, dev_set, test_set, vocab, key]``; an empty list
        if the file has no lines.
    """
    all_train_sets = []
    all_dev_sets = []
    all_test_sets = []

    n_tasks = 0
    train_len = dev_len = test_len = 0
    vocab = []

    with open(dataset_file, "r") as fi:
        for line in fi:
            parts = line.strip().split("\t")

            train_set = [elt.split(",") for elt in parts[0].split()]
            dev_set = [elt.split(",") for elt in parts[1].split()]
            test_set = [elt.split(",") for elt in parts[2].split()]
            all_train_sets += train_set
            all_dev_sets += dev_set
            all_test_sets += test_set

            # Chunk sizes and vocabulary are taken from the last line read.
            train_len = len(train_set)
            dev_len = len(dev_set)
            test_len = len(test_set)
            vocab = parts[3].split()

            n_tasks += 1

    # Nothing to redistribute for an empty file.
    if n_tasks == 0:
        return []

    shuffle(all_train_sets)
    shuffle(all_dev_sets)
    shuffle(all_test_sets)

    langs = []
    for i in range(n_tasks):
        train_set = all_train_sets[i * train_len:(i + 1) * train_len]
        dev_set = all_dev_sets[i * dev_len:(i + 1) * dev_len]
        test_set = all_test_sets[i * test_len:(i + 1) * test_len]

        key = ["scrambled", "scrambled", "scrambled"]

        langs.append([train_set, dev_set, test_set, vocab, key])

    return langs




# Load a language that is just Cs and Vs (train/test splits and vocab only)
def load_dataset_cv(dataset_file):
    """Parse *dataset_file* into one ``[train_set, test_set, vocab]`` per line.

    Each line has three tab-separated fields: whitespace-separated training
    examples, whitespace-separated test examples (each example split on
    ``","`` into a list of strings), and a whitespace-separated vocabulary.
    """
    langs = []

    # The context manager closes the handle even if a line fails to parse.
    with open(dataset_file, "r") as fi:
        for line in fi:
            parts = line.strip().split("\t")

            train_set = [elt.split(",") for elt in parts[0].split()]
            test_set = [elt.split(",") for elt in parts[1].split()]
            vocab = parts[2].split()

            langs.append([train_set, test_set, vocab])

    return langs




In [None]:
# Parse the "yonc.test" dataset file: one [train, dev, test, vocab, key] entry per line.
test_set = load_dataset("yonc.test")


# Utils

In [None]:
# Break a list into batches of the desired size
def batchify_list(lst, batch_size=100):
    """Group ``(input, output)`` pairs into batches of *batch_size* items.

    Element ``[0]`` of each item goes to the batch's input list and element
    ``[1]`` to its output list.

    Returns:
        list of ``[inputs, outputs]`` pairs. For a positive *batch_size* every
        batch except possibly the last is full; no empty batch is produced.
    """
    result = []
    inputs, outputs = [], []

    for count, pair in enumerate(lst, start=1):
        inputs.append(pair[0])
        outputs.append(pair[1])

        if count % batch_size == 0:
            # Batch is full: emit it and start collecting a new one.
            result.append([inputs, outputs])
            inputs, outputs = [], []

    # Emit the trailing partial batch, if any.
    if inputs:
        result.append([inputs, outputs])

    return result

# Trim the excess from the end of an output string
def process_output(output):
    """Return *output* cut just before its first ``"EOS"`` entry.

    Works on anything supporting ``in``, ``.index`` and slicing, such as a
    list of tokens or a ``str``. Without an ``"EOS"``, *output* is returned
    as-is.
    """
    return output[:output.index("EOS")] if "EOS" in output else output


# Models

In [None]:
import random
from random import shuffle
from collections import OrderedDict

In [None]:
# Minimal NumPy stand-in for a PyTorch module: exposes named parameters and
# allows weights to be read and replaced manually by (dotted) name.
class ModifiableModule():
    """Parameter container mirroring a small part of ``torch.nn.Module``.

    Subclasses override ``named_leaves`` (their own ``(name, array)`` pairs)
    and ``named_submodules`` (``(name, module)`` pairs for children). Child
    parameter names are prefixed with ``"<child>."``.
    """

    def params(self):
        """Return all parameter values, in ``named_params`` order."""
        return [p for _, p in self.named_params()]

    def named_leaves(self):
        """Return this module's own ``(name, param)`` pairs (none by default)."""
        return []

    def named_submodules(self):
        """Return ``(name, module)`` pairs for child modules (none by default)."""
        return []

    def named_params(self):
        """Return ``(dotted_name, param)`` pairs: own leaves first, then children."""
        subparams = []
        for name, mod in self.named_submodules():
            for subname, param in mod.named_params():
                subparams.append((name + '.' + subname, param))
        return self.named_leaves() + subparams

    def set_param(self, name, param):
        """Assign *param* to the attribute *name*, which may be dotted.

        A dotted name such as ``"enc_lstm.wi_bias"`` is forwarded to the
        submodule whose name equals the first component. If no submodule
        matches, nothing is assigned.
        """
        if '.' in name:
            module_name, rest = name.split('.', 1)
            for sub_name, mod in self.named_submodules():
                if sub_name == module_name:
                    mod.set_param(rest, param)
                    break
        else:
            setattr(self, name, param)

    def copy(self, other, same_var=False):
        """Copy every parameter of *other* into this module.

        With ``same_var=False`` (default) each array is duplicated, so the two
        modules do not share storage; with ``same_var=True`` the very same
        objects are assigned.
        """
        for name, param in other.named_params():
            if not same_var:
                # NumPy arrays have no autograd wrapper (the torch-era
                # ``V(...)`` is undefined here); a plain copy decouples them.
                param = np.array(param, copy=True)
            self.set_param(name, param)

    def load_state_dict(self, sdict, same_var=False):
        """Assign each ``name -> param`` entry of *sdict* via ``set_param``.

        Parameters are copied unless *same_var* is True.
        """
        for name in sdict:
            param = sdict[name]
            if not same_var:
                param = np.array(param, copy=True)

            self.set_param(name, param)

    def state_dict(self):
        """Return an ``OrderedDict`` of ``named_params()``."""
        return OrderedDict(self.named_params())



In [None]:
# Redefined linear layer
class GradLinear(ModifiableModule):
    """Affine layer ``y = weights @ x + bias`` backed by NumPy arrays.

    ``weights`` starts with shape ``(outp_size, inp_size)`` and ``bias`` with
    shape ``(outp_size,)``, both drawn uniformly from [0, 1).
    """

    def __init__(self, inp_size, outp_size):
        super().__init__()
        # Weights are drawn before bias so the np.random call order is stable.
        self.weights = np.random.rand(outp_size, inp_size)
        self.bias = np.random.rand(outp_size)

    def forward(self, x):
        """Return ``weights @ x + bias``."""
        projected = self.weights @ x
        return projected + self.bias

    def named_leaves(self):
        """Expose the arrays as ``(name, array)`` pairs."""
        return [('weights', self.weights), ('bias', self.bias)]



In [None]:
def sigmoid(x):
    """Elementwise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def tanh(x):
    """Elementwise hyperbolic tangent (thin wrapper over ``np.tanh``)."""
    activated = np.tanh(x)
    return activated

def softmax(x):
    """Return the softmax of *x*, normalised over all of its elements.

    The maximum is subtracted before exponentiating, so large inputs no longer
    overflow to ``inf``/``nan``. The result is mathematically unchanged.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

def logsoftmax(x):
    """Return ``log(softmax(x))`` over all elements of *x*, computed stably.

    Uses the log-sum-exp identity
    ``x - max(x) - log(sum(exp(x - max(x))))``. Very negative scores stay
    finite instead of underflowing to ``-inf``, and large scores do not
    overflow.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    shifted = x - np.max(x)
    return shifted - np.log(np.sum(np.exp(shifted)))

In [None]:
# Redefined LSTM
class GradLSTM(ModifiableModule):
    """Single LSTM cell implemented with NumPy.

    Each gate (i, f, g, o) has a weight matrix initialised with shape
    ``(hidden_size, hidden_size + input_size)``. It is applied to the
    concatenation ``[input, hidden]``, and a bias of shape ``(hidden_size,)``
    is added.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size

        n_cols = hidden_size + input_size
        # Draw order (i, f, g, o; weights before bias) is kept fixed so a given
        # np.random seed yields the same initial values.
        self.wi_weights = np.random.rand(hidden_size, n_cols)
        self.wi_bias = np.random.rand(hidden_size)
        self.wf_weights = np.random.rand(hidden_size, n_cols)
        self.wf_bias = np.random.rand(hidden_size)
        self.wg_weights = np.random.rand(hidden_size, n_cols)
        self.wg_bias = np.random.rand(hidden_size)
        self.wo_weights = np.random.rand(hidden_size, n_cols)
        self.wo_bias = np.random.rand(hidden_size)

    def forward(self, inp, hidden):
        """Run one LSTM step.

        Args:
            inp: input for this step (flattened before use).
            hidden: ``(hx, cx)`` pair of hidden and cell states.

        Returns:
            ``(hx, (hx, cx), o_pre, input_plus_hidden, i_pre, f_pre, g_pre)``:
            the new hidden state, the new state pair, the output-gate
            pre-activation, the concatenated ``[input, hidden]`` vector, and
            the remaining gate pre-activations.
        """
        prev_h, prev_c = hidden
        joined = np.concatenate([inp.flatten(), prev_h.flatten()])

        def pre_activation(weights, bias):
            return weights @ joined + bias

        i_pre = pre_activation(self.wi_weights, self.wi_bias)
        f_pre = pre_activation(self.wf_weights, self.wf_bias)
        g_pre = pre_activation(self.wg_weights, self.wg_bias)
        o_pre = pre_activation(self.wo_weights, self.wo_bias)

        new_c = sigmoid(f_pre) * prev_c + sigmoid(i_pre) * tanh(g_pre)
        new_h = sigmoid(o_pre) * tanh(new_c)

        return new_h, (new_h, new_c), o_pre, joined, i_pre, f_pre, g_pre

    def named_leaves(self):
        """Expose all gate arrays as ``(name, array)`` pairs."""
        return [('wi_weights', self.wi_weights), ('wi_bias', self.wi_bias),
                ('wf_weights', self.wf_weights), ('wf_bias', self.wf_bias),
                ('wg_weights', self.wg_weights), ('wg_bias', self.wg_bias),
                ('wo_weights', self.wo_weights), ('wo_bias', self.wo_bias)]




In [None]:
# Redefined embedding layer
class GradEmbedding(ModifiableModule):
    """Embedding stored as a dense matrix of initial shape ``(emb_size, vocab_size)``.

    ``forward`` multiplies the matrix by its argument. With a one-hot vector
    this selects one column.
    """

    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.weights = np.random.rand(emb_size, vocab_size)

    def forward(self, x):
        """Return ``weights @ x``."""
        embedded = self.weights @ x
        return embedded

    def named_leaves(self):
        """Expose the embedding matrix as a ``(name, array)`` pair."""
        return [('weights', self.weights)]

In [None]:
def onehot(ind, size=34):
    """Return a float vector of length *size* with ``1.0`` at index *ind*.

    *size* defaults to 34, the vocabulary size used throughout this notebook
    (3 special symbols + 31 characters).
    """
    oh = np.zeros(size)
    oh[ind] = 1.0

    return oh

onehot(6)

In [None]:
# Encoder/decoder model (NumPy re-implementation built from the Grad* layers)
class EncoderDecoder(ModifiableModule):
    """Character-level LSTM encoder/decoder whose parameters are NumPy arrays.

    Submodules (see ``named_submodules``): an embedding shared by encoder and
    decoder, an encoder LSTM, a decoder LSTM, and a linear layer mapping the
    decoder output to ``vocab_size`` scores.
    """
    def __init__(self, vocab_size, input_size, hidden_size):
        super(EncoderDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.embedding = GradEmbedding(vocab_size, input_size)
        self.enc_lstm = GradLSTM(input_size, hidden_size)

        self.dec_lstm = GradLSTM(input_size, hidden_size)
        self.dec_output = GradLinear(hidden_size, vocab_size)

        # Upper bound on the number of decoding steps taken by forward().
        self.max_length = 20

        # Default vocabulary: 3 specials (NULL, SOS, EOS) + 31 characters = 34
        # symbols. NOTE(review): forward() builds inputs with onehot(), whose
        # vector length is fixed independently of vocab_size; keep them equal.
        self.set_dicts("a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split())


    def forward(self, inp, outp_length=20):
        """Encode *inp*, then greedily decode an output string.

        Args:
            inp: sequence of symbols, each a key of ``self.char2ind`` (for
                example a str of single characters).
            outp_length: requested number of decoding steps. The decoder runs
                ``min(self.max_length, outp_length)`` steps and does not stop
                early at ``"EOS"``.

        Returns:
            11-tuple ``(out_string, logits, encoding, preds, hiddens, ots,
            iphs, hidden_prev, its, fts, gts)``:

            - ``out_string``: concatenation of the argmax symbol of each step;
            - ``logits``: per-step log-softmax outputs;
            - ``encoding``: final encoder ``(hx, cx)`` pair;
            - ``preds``: raw output-layer scores;
            - ``hiddens``: decoder state pairs;
            - ``ots``/``its``/``fts``/``gts``: decoder gate pre-activations;
            - ``iphs``: concatenated decoder input vectors;
            - ``hidden_prev``: state handed to the decoder (same as ``encoding``).
        """
        # Initialize the hidden and cell states (zeros of shape (1, hidden_size))
        hidden = (np.zeros([1,self.hidden_size]), np.zeros([1,self.hidden_size]))

        this_seq = []
        # Map each input symbol to its vocabulary index (KeyError if unknown)
        for elt in inp:
            ind = self.char2ind[elt]
            this_seq.append(ind)

        inp_length = len(inp)
        if inp_length > 0:

            # Pass the sequences through the encoder, one character at a time
            for index, elt in enumerate(this_seq):
                # Embed the character
                emb = self.embedding.forward(onehot(elt))

                # Pass through the LSTM
                output, hidden_new, _, _, i_tpre, f_tpre, g_tpre = self.enc_lstm.forward(emb, hidden)
                # NOTE(review): overwritten after encoding (hidden_prev = hidden
                # below) and never read in between; looks like leftover debugging.
                hidden_prev = hidden


                hidden = hidden_new

        # Final encoder state pair; it is also the decoder's initial state
        encoding = hidden
        # Decoding

        # Previous output characters (used as input for the following time step)
        prev_output = "SOS"

        # Accumulates the output sequences
        out_string = ""



        # Per-step values collected for inspection; logits holds the
        # log-softmax outputs (used for computing the loss)
        logits = []
        preds = []
        hiddens = []
        ots = []
        iphs = []
        # State handed to the decoder; returned as hidden_prev
        hidden_prev = hidden
        its = []
        fts = []
        gts = []



        for i in range(min(self.max_length,outp_length)):
            # Feed back the previously emitted symbol (SOS at the first step);
            # there is a single sequence here, not a batch

            # Embed the previous outputs
            emb = self.embedding.forward(onehot(self.char2ind[prev_output]))

            # Pass through the decoder
            output, hidden, o_t, iph, i_tpre, f_tpre, g_tpre = self.dec_lstm.forward(emb, hidden)
            #myhook = o_t.register_hook(print_grad)

            # Determine the output probabilities used to make predictions
            pred = self.dec_output.forward(output.flatten())
            probs = logsoftmax(pred)
            logits.append(probs)
            preds.append(pred)
            hiddens.append(hidden)
            ots.append(o_t)
            iphs.append(iph)
            its.append(i_tpre)
            fts.append(f_tpre)
            gts.append(g_tpre)

            # Discretize the output labels (via argmax) for generating an output character
            label = np.argmax(probs)

            char = self.ind2char[label]
            out_string += char
            prev_output = char


        return out_string, logits, encoding, preds, hiddens, ots, iphs, hidden_prev, its, fts, gts

    def named_submodules(self):
        """Return the four child modules as ``(name, module)`` pairs."""
        return [('embedding', self.embedding), ('enc_lstm', self.enc_lstm),
                ('dec_lstm', self.dec_lstm), ('dec_output', self.dec_output)]

    # Create a copy of the model
    def create_copy(self, same_var=False):
        """Build a new EncoderDecoder of the same sizes and copy this model's parameters into it."""
        new_model = EncoderDecoder(self.vocab_size, self.input_size, self.hidden_size)
        new_model.copy(self, same_var=same_var)

        return new_model

    def set_dicts(self, vocab_list):
        """Build ``char2ind``/``ind2char`` from *vocab_list*.

        The specials ``"NULL"``, ``"SOS"`` and ``"EOS"`` are prepended and get
        indices 0, 1 and 2; the symbols of *vocab_list* follow in order.
        """
        vocab_list = ["NULL", "SOS", "EOS"] + vocab_list

        index = 0
        char2ind = {}
        ind2char = {}

        for elt in vocab_list:
            char2ind[elt] = index
            ind2char[index] = elt
            index += 1

        self.char2ind = char2ind
        self.ind2char = ind2char


In [None]:
# Build the NumPy EncoderDecoder (34-symbol vocabulary, 10-dim embeddings,
# 256 hidden units) and replace its random arrays with weights read from text
# files, presumably the ones written by the np.savetxt cell further down.
encdec = EncoderDecoder(34,10,256)

encdec.enc_lstm.wo_weights = np.loadtxt("enc_lstm.wo_weights")
encdec.enc_lstm.wi_weights = np.loadtxt("enc_lstm.wi_weights")
encdec.enc_lstm.wg_weights = np.loadtxt("enc_lstm.wg_weights")
encdec.enc_lstm.wf_weights = np.loadtxt("enc_lstm.wf_weights")
encdec.enc_lstm.wo_bias = np.loadtxt("enc_lstm.wo_bias")
encdec.enc_lstm.wi_bias = np.loadtxt("enc_lstm.wi_bias")
encdec.enc_lstm.wg_bias = np.loadtxt("enc_lstm.wg_bias")
encdec.enc_lstm.wf_bias = np.loadtxt("enc_lstm.wf_bias")

encdec.dec_lstm.wo_weights = np.loadtxt("dec_lstm.wo_weights")
encdec.dec_lstm.wi_weights = np.loadtxt("dec_lstm.wi_weights")
encdec.dec_lstm.wg_weights = np.loadtxt("dec_lstm.wg_weights")
encdec.dec_lstm.wf_weights = np.loadtxt("dec_lstm.wf_weights")
encdec.dec_lstm.wo_bias = np.loadtxt("dec_lstm.wo_bias")
encdec.dec_lstm.wi_bias = np.loadtxt("dec_lstm.wi_bias")
encdec.dec_lstm.wg_bias = np.loadtxt("dec_lstm.wg_bias")
encdec.dec_lstm.wf_bias = np.loadtxt("dec_lstm.wf_bias")

# Transposed on load: GradEmbedding multiplies weights @ onehot, i.e. expects
# (emb_size, vocab_size); the saved matrix is presumably (vocab_size, emb_size)
# (TODO confirm).
encdec.embedding.weights = np.loadtxt("embedding.weights").transpose()
encdec.dec_output.weights = np.loadtxt("dec_output.weights")
encdec.dec_output.bias = np.loadtxt("dec_output.bias")




In [None]:
# Greedy-decode the input "do" with the NumPy model and display the output string.
encdec.forward("do")[0]

In [None]:
from load_data import *
from utils import *
from training import *
from models import *

# After the star imports above, EncoderDecoder presumably refers to the
# project's torch model (models module) rather than the NumPy class defined
# earlier. Build it and load the trained MAML weights.
model = EncoderDecoder(34,10,256)
model.load_state_dict(torch.load("maml_yonc_256_5.weights"))

In [17]:
def flatten(lst):
    """Concatenate the sub-sequences of *lst* into one new list (one level deep).

    Example: ``flatten([[1, 2], [3]]) == [1, 2, 3]``. A single comprehension
    builds the result in linear time; repeated ``list + list`` would copy the
    accumulated list on every step (quadratic).
    """
    return [item for sub in lst for item in sub]

In [None]:
# Print shapes of the loaded (torch) parameters to work out the layout needed
# by the JavaScript export; trailing comments name the JS-side variables.
print(model.embedding.weights.data.shape) # weights_emb; weights_emb_bias is zeroes
print(model.enc_lstm.wo_weights.transpose(0,1).data.shape)
print(model.enc_lstm.wo_bias.data.shape)
print(model.dec_output.weights.transpose(0,1).data.shape) # weights_out_weights
print(model.dec_output.bias.data.shape) # weights_out_bias

In [None]:
# Prepare LSTM weights for export. For each gate, transpose the weight matrix
# and split it into the first 10 rows (x*: input part; input_size is 10) and
# the remaining rows (w*: recurrent part, presumably matching the [input,
# hidden] concatenation order). Each part is flattened to a plain list and the
# gates are concatenated in i, f, g, o order; biases likewise.
xi = flatten(model.enc_lstm.wi_weights.transpose(0,1)[:10].data.numpy().tolist())
wi = flatten(model.enc_lstm.wi_weights.transpose(0,1)[10:].data.numpy().tolist())
bi = model.enc_lstm.wi_bias.data.numpy().tolist()

xf = flatten(model.enc_lstm.wf_weights.transpose(0,1)[:10].data.numpy().tolist())
wf = flatten(model.enc_lstm.wf_weights.transpose(0,1)[10:].data.numpy().tolist())
bf = model.enc_lstm.wf_bias.data.numpy().tolist()

xg = flatten(model.enc_lstm.wg_weights.transpose(0,1)[:10].data.numpy().tolist())
wg = flatten(model.enc_lstm.wg_weights.transpose(0,1)[10:].data.numpy().tolist())
bg = model.enc_lstm.wg_bias.data.numpy().tolist()

xo = flatten(model.enc_lstm.wo_weights.transpose(0,1)[:10].data.numpy().tolist())
wo = flatten(model.enc_lstm.wo_weights.transpose(0,1)[10:].data.numpy().tolist())
bo = model.enc_lstm.wo_bias.data.numpy().tolist()

# Encoder: input weights, recurrent weights and biases, gates in i, f, g, o order
full_x = xi + xf + xg + xo
full_w = wi + wf + wg + wo
full_b = bi + bf + bg + bo



# Same split for the decoder LSTM (suffix "d")
xid = flatten(model.dec_lstm.wi_weights.transpose(0,1)[:10].data.numpy().tolist())
wid = flatten(model.dec_lstm.wi_weights.transpose(0,1)[10:].data.numpy().tolist())
bid = model.dec_lstm.wi_bias.data.numpy().tolist()

xfd = flatten(model.dec_lstm.wf_weights.transpose(0,1)[:10].data.numpy().tolist())
wfd = flatten(model.dec_lstm.wf_weights.transpose(0,1)[10:].data.numpy().tolist())
bfd = model.dec_lstm.wf_bias.data.numpy().tolist()

xgd = flatten(model.dec_lstm.wg_weights.transpose(0,1)[:10].data.numpy().tolist())
wgd = flatten(model.dec_lstm.wg_weights.transpose(0,1)[10:].data.numpy().tolist())
bgd = model.dec_lstm.wg_bias.data.numpy().tolist()

xod = flatten(model.dec_lstm.wo_weights.transpose(0,1)[:10].data.numpy().tolist())
wod = flatten(model.dec_lstm.wo_weights.transpose(0,1)[10:].data.numpy().tolist())
bod = model.dec_lstm.wo_bias.data.numpy().tolist()

full_xd = xid + xfd + xgd + xod
full_wd = wid + wfd + wgd + wod
full_bd = bid + bfd + bgd + bod


In [19]:
def stringify_lst(lst):
    """Render *lst* as an array literal such as ``[1, 2.5, 3]``.

    Each element is converted with ``str`` and the elements are joined with
    ``", "``.
    """
    return "[{}]".format(", ".join(map(str, lst)))

In [None]:
# Export the embedding, the concatenated LSTM gate arrays and the output layer
# as JavaScript array literals in tf_weights.js. The `with` block closes the
# file, which also flushes the buffered writes; a handle left open may keep
# them pending.
with open("tf_weights.js", "w") as tf_weights:
    tf_weights.write("emb_wg = " + stringify_lst(flatten(model.embedding.weights.data.numpy().tolist())) + ";\n")
    tf_weights.write("full_x = " + stringify_lst(full_x) + ";\n")
    tf_weights.write("full_w = " + stringify_lst(full_w) + ";\n")
    tf_weights.write("full_b = " + stringify_lst(full_b) + ";\n")
    tf_weights.write("full_xd = " + stringify_lst(full_xd) + ";\n")
    tf_weights.write("full_wd = " + stringify_lst(full_wd) + ";\n")
    tf_weights.write("full_bd = " + stringify_lst(full_bd) + ";\n")
    tf_weights.write("out_wg = " + stringify_lst(flatten(model.dec_output.weights.transpose(0,1).data.numpy().tolist())) + ";\n")
    tf_weights.write("out_wb = " + stringify_lst(model.dec_output.bias.data.numpy().tolist()) + ";\n")


In [None]:
# Inspect the encoder output-gate weights as a NumPy array.
model.enc_lstm.wo_weights.data.numpy()

In [None]:
# Dump every parameter of the loaded model to a text file named after the
# parameter, so the NumPy model can read them back with np.loadtxt (the
# loading cell earlier in this notebook) without torch.
np.savetxt("embedding.weights",model.embedding.weights.data.numpy())

np.savetxt("enc_lstm.wi_weights",model.enc_lstm.wi_weights.data.numpy())
np.savetxt("enc_lstm.wi_bias",model.enc_lstm.wi_bias.data.numpy())
np.savetxt("enc_lstm.wf_weights",model.enc_lstm.wf_weights.data.numpy())
np.savetxt("enc_lstm.wf_bias",model.enc_lstm.wf_bias.data.numpy())
np.savetxt("enc_lstm.wg_weights",model.enc_lstm.wg_weights.data.numpy())
np.savetxt("enc_lstm.wg_bias",model.enc_lstm.wg_bias.data.numpy())
np.savetxt("enc_lstm.wo_weights",model.enc_lstm.wo_weights.data.numpy())
np.savetxt("enc_lstm.wo_bias",model.enc_lstm.wo_bias.data.numpy())

np.savetxt("dec_lstm.wi_weights",model.dec_lstm.wi_weights.data.numpy())
np.savetxt("dec_lstm.wi_bias",model.dec_lstm.wi_bias.data.numpy())
np.savetxt("dec_lstm.wf_weights",model.dec_lstm.wf_weights.data.numpy())
np.savetxt("dec_lstm.wf_bias",model.dec_lstm.wf_bias.data.numpy())
np.savetxt("dec_lstm.wg_weights",model.dec_lstm.wg_weights.data.numpy())
np.savetxt("dec_lstm.wg_bias",model.dec_lstm.wg_bias.data.numpy())
np.savetxt("dec_lstm.wo_weights",model.dec_lstm.wo_weights.data.numpy())
np.savetxt("dec_lstm.wo_bias",model.dec_lstm.wo_bias.data.numpy())

np.savetxt("dec_output.weights",model.dec_output.weights.data.numpy())
np.savetxt("dec_output.bias",model.dec_output.bias.data.numpy())


In [None]:
# List the first element (the parameter name) of each named_params() entry.
[x[0] for x in model.named_params()]

In [None]:
from load_data import *
from utils import *
from training import *
from models import *

# Switch to the GRU variant of the project model (the `recurrent_unit` keyword
# comes from the external models module) and load its trained weights. The
# recurrent submodules keep the names enc_lstm/dec_lstm, as the cells below show.
model = EncoderDecoder(34,10,256, recurrent_unit="GRU")
model.load_state_dict(torch.load("yonc_maml_gru_256_5.weights"))



In [None]:
def flatten(lst):
    """Return a new list holding the items of every sub-sequence of *lst*, in order.

    Only one level is flattened. The comprehension runs in linear time,
    unlike repeated ``list + list`` concatenation (quadratic).
    """
    return [item for sub in lst for item in sub]

In [None]:
# Print shapes of the GRU model parameters relevant to the JavaScript export.
print(model.embedding.weights.data.shape) # weights_emb; weights_emb_bias is zeroes
print(model.enc_lstm.wr_weights.transpose(0,1).data.shape)
print(model.enc_lstm.wr_bias.data.shape)
print(model.dec_output.weights.transpose(0,1).data.shape) # weights_out_weights
print(model.dec_output.bias.data.shape) # weights_out_bias

In [None]:
# GRU export, first attempt. Split the reset (r) and update (z) gate matrices
# into the first 10 rows (input part) and the rest (recurrent part). Take the
# candidate-state input weights (wx) and recurrent weights (wrh) whole, as the
# manual check cells below use them, and sum the two candidate biases.
# Gates are concatenated in z, r, candidate order.
xr = flatten(model.enc_lstm.wr_weights.transpose(0,1)[:10].data.numpy().tolist())
wr = flatten(model.enc_lstm.wr_weights.transpose(0,1)[10:].data.numpy().tolist())
br = model.enc_lstm.wr_bias.data.numpy().tolist()

xz = flatten(model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy().tolist())
wz = flatten(model.enc_lstm.wz_weights.transpose(0,1)[10:].data.numpy().tolist())
bz = model.enc_lstm.wz_bias.data.numpy().tolist()

xx = flatten(model.enc_lstm.wx_weights.transpose(0,1).data.numpy().tolist())
#wx = flatten(model.enc_lstm.wx_weights.transpose(0,1)[10:].data.numpy().tolist())
bx = model.enc_lstm.wx_bias.data.numpy()

#xrh = flatten(model.enc_lstm.wrh_weights.transpose(0,1)[:10].data.numpy().tolist())
wrh = flatten(model.enc_lstm.wrh_weights.transpose(0,1).data.numpy().tolist())
brh = model.enc_lstm.wrh_bias.data.numpy()

# bx/brh are still arrays here, so this is an elementwise sum
bxrh = (bx + brh).tolist()

full_x = xz + xr + xx
full_w = wz + wr + wrh
full_b = bz + br + bxrh



# Same layout for the decoder GRU (suffix "d")
xrd = flatten(model.dec_lstm.wr_weights.transpose(0,1)[:10].data.numpy().tolist())
wrd = flatten(model.dec_lstm.wr_weights.transpose(0,1)[10:].data.numpy().tolist())
brd = model.dec_lstm.wr_bias.data.numpy().tolist()

xzd = flatten(model.dec_lstm.wz_weights.transpose(0,1)[:10].data.numpy().tolist())
wzd = flatten(model.dec_lstm.wz_weights.transpose(0,1)[10:].data.numpy().tolist())
bzd = model.dec_lstm.wz_bias.data.numpy().tolist()

xxd = flatten(model.dec_lstm.wx_weights.transpose(0,1).data.numpy().tolist())
#wxd = flatten(model.dec_lstm.wx_weights.transpose(0,1)[10:].data.numpy().tolist())
bxd = model.dec_lstm.wx_bias.data.numpy()

#xrhd = flatten(model.dec_lstm.wrh_weights.transpose(0,1)[:10].data.numpy().tolist())
wrhd = flatten(model.dec_lstm.wrh_weights.transpose(0,1).data.numpy().tolist())
brhd = model.dec_lstm.wrh_bias.data.numpy()

bxrhd = (bxd + brhd).tolist()

full_xd = xzd + xrd + xxd
full_wd = wzd + wrd + wrhd
full_bd = bzd + brd + bxrhd

In [None]:
# GRU export, second attempt (marked "HERE IT IS!!!" by the author). Keep the
# encoder parts as 2-D arrays so the gates can be concatenated column-wise
# before flattening; biases are made (1, n) rows. The decoder part below
# repeats the first attempt unchanged.
xr = model.enc_lstm.wr_weights.transpose(0,1)[:10].data.numpy()#.transpose()
wr = model.enc_lstm.wr_weights.transpose(0,1)[10:].data.numpy() #.transpose()
br = np.expand_dims(model.enc_lstm.wr_bias.data.numpy(),axis=0)

xz = model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy()#.transpose()
wz = model.enc_lstm.wz_weights.transpose(0,1)[10:].data.numpy() #.transpose()
bz = np.expand_dims(model.enc_lstm.wz_bias.data.numpy(), axis=0)

xx = model.enc_lstm.wx_weights.transpose(0,1).data.numpy()#.transpose()
bx = np.expand_dims(model.enc_lstm.wx_bias.data.numpy(), axis=0)

#xrh = flatten(model.enc_lstm.wrh_weights.transpose(0,1)[:10].data.numpy().tolist())
wrh = model.enc_lstm.wrh_weights.transpose(0,1).data.numpy() #.transpose()
brh = np.expand_dims(model.enc_lstm.wrh_bias.data.numpy(),axis=0)

bxrh = bx + brh

# Gates side by side (z, r, candidate); full_x is transposed before flattening
full_w = flatten(np.concatenate([wz, wr, wrh], axis=1).tolist())
full_x = flatten(np.concatenate([xz, xr, xx], axis=1).transpose().tolist())


full_b = flatten(np.concatenate([bz,br,bxrh], axis=0).tolist())
#print(full_b)

xrd = flatten(model.dec_lstm.wr_weights.transpose(0,1)[:10].data.numpy().tolist())
wrd = flatten(model.dec_lstm.wr_weights.transpose(0,1)[10:].data.numpy().tolist())
brd = model.dec_lstm.wr_bias.data.numpy().tolist()

xzd = flatten(model.dec_lstm.wz_weights.transpose(0,1)[:10].data.numpy().tolist())
wzd = flatten(model.dec_lstm.wz_weights.transpose(0,1)[10:].data.numpy().tolist())
bzd = model.dec_lstm.wz_bias.data.numpy().tolist()

xxd = flatten(model.dec_lstm.wx_weights.transpose(0,1).data.numpy().tolist())
#wxd = flatten(model.dec_lstm.wx_weights.transpose(0,1)[10:].data.numpy().tolist())
bxd = model.dec_lstm.wx_bias.data.numpy()

#xrhd = flatten(model.dec_lstm.wrh_weights.transpose(0,1)[:10].data.numpy().tolist())
wrhd = flatten(model.dec_lstm.wrh_weights.transpose(0,1).data.numpy().tolist())
brhd = model.dec_lstm.wrh_bias.data.numpy()

bxrhd = (bxd + brhd).tolist()

full_xd = xzd + xrd + xxd
full_wd = wzd + wrd + wrhd
full_bd = bzd + brd + bxrhd

In [None]:
#np.concatenate([bz,br,brh]).tolist()

In [None]:
def stringify_lst(lst):
    """Format *lst* as ``[a, b, ...]`` using ``str`` on each element."""
    pieces = [str(item) for item in lst]
    return "".join(["[", ", ".join(pieces), "]"])

In [None]:
# Sanity-check the number of flattened decoder recurrent weights.
len(full_wd)

In [None]:
# Export the GRU model's arrays as JavaScript array literals in tf_weights.js.
# The `with` block closes the file, which also flushes the buffered writes.
with open("tf_weights.js", "w") as tf_weights:
    tf_weights.write("emb_wg = " + stringify_lst(flatten(model.embedding.weights.data.numpy().tolist())) + ";\n")
    tf_weights.write("full_x = " + stringify_lst(full_x) + ";\n")
    tf_weights.write("full_w = " + stringify_lst(full_w) + ";\n")
    tf_weights.write("full_b = " + stringify_lst(full_b) + ";\n")
    tf_weights.write("full_xd = " + stringify_lst(full_xd) + ";\n")
    tf_weights.write("full_wd = " + stringify_lst(full_wd) + ";\n")
    tf_weights.write("full_bd = " + stringify_lst(full_bd) + ";\n")
    tf_weights.write("out_wg = " + stringify_lst(flatten(model.dec_output.weights.transpose(0,1).data.numpy().tolist())) + ";\n")
    tf_weights.write("out_wb = " + stringify_lst(model.dec_output.bias.data.numpy().tolist()) + ";\n")


In [None]:
# Inspect the input part (first 10 rows after transposing) of the encoder update-gate weights.
model.enc_lstm.wz_weights.transpose(0,1)[:10]

In [None]:
# Preview the exported encoder input weights as an array literal.
stringify_lst(full_x)

In [None]:
# Preview the exported encoder recurrent weights as an array literal.
stringify_lst(full_w)

In [None]:
# Install the default 34-symbol vocabulary on the loaded model, run it on the
# single-symbol input ["b"], and display element 2 of the result.
model.set_dicts("a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split())
model(["b"])[2]


In [None]:
# Changing some dimensionalities
# Manual NumPy recomputation of two GRU encoder steps, for comparison with the
# torch model. The inputs are the one-hot rows with a 1 at index 13 and then
# index 3 ("b" then "a" under the default vocabulary) times the embedding matrix.
# NOTE(review): unlike the next cell, wz/wr/wx are not transposed and the biases
# stay 1-D. Adding a (k,) bias to a (k, 1) column broadcasts to (k, k), so this
# cell presumably produces the wrong shapes; the "Does the right thing" cell
# below fixes both.
import numpy as np

emb_mat = model.embedding.weights.data.numpy()

inp = np.matmul(np.array([[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]), emb_mat).transpose() # GOOD
h = np.array([[0 for _ in range(256)]]).transpose()

print(h.shape)
print(inp.shape)

uz = model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy().transpose() # GOOD
wz = model.enc_lstm.wz_weights.transpose(0,1)[10:].data.numpy() #.transpose()
bz = model.enc_lstm.wz_bias.data.numpy() # GOOD

ur = model.enc_lstm.wr_weights.transpose(0,1)[:10].data.numpy().transpose()
wr = model.enc_lstm.wr_weights.transpose(0,1)[10:].data.numpy() #.transpose()
br = model.enc_lstm.wr_bias.data.numpy()

ux = model.enc_lstm.wx_weights.transpose(0,1).data.numpy().transpose() # GOOD
wx = model.enc_lstm.wrh_weights.transpose(0,1).data.numpy() #.transpose()
bx = model.enc_lstm.wx_bias.data.numpy() + model.enc_lstm.wrh_bias.data.numpy() # GOOD


# Update gate z = sigmoid(z_pre), written as exp/(1+exp).
# NOTE(review): this form overflows to nan for large z_pre; 1/(1+exp(-z_pre)) is stable.
z_pre = np.matmul(uz,inp) + np.matmul(wz, h) + bz
z = np.exp(z_pre) / (1 + np.exp(z_pre))
print("z", z.shape)

# Reset gate r = sigmoid(r_pre)
r_pre = np.matmul(ur,inp) + np.matmul(wr, h) + br
r = np.exp(r_pre) / (1 + np.exp(r_pre))
print("r", r.shape)

# Candidate state: the recurrent term uses the reset-scaled state r*h
htilde_pre = np.matmul(ux,inp) + np.matmul(wx, r*h) + bx
htilde = np.tanh(htilde_pre)
print("htilde", htilde.shape)

# Interpolate between candidate and previous state, then load the second input
h = (1 - z)*htilde + z*h
inp = np.matmul(np.array([[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]), emb_mat).transpose()
print(h)
print("new_enc", h.shape)

# Second step: same equations with the updated h
z_pre = np.matmul(uz,inp) + np.matmul(wz, h) + bz
z = np.exp(z_pre) / (1 + np.exp(z_pre))
print("z", z.shape)

r_pre = np.matmul(ur,inp) + np.matmul(wr, h) + br
r = np.exp(r_pre) / (1 + np.exp(r_pre))
print("r", r.shape)

htilde_pre = np.matmul(ux,inp) + np.matmul(wx, r*h) + bx
htilde = np.tanh(htilde_pre)
print("htilde", htilde.shape)

h = (1 - z)*htilde + z*h
print(h)


In [None]:
# Look up the embedding of "a" through the (torch) model's embedding layer.
model.embedding(torch.LongTensor([model.char2ind["a"]]))

In [None]:
# Does the right thing
# Corrected manual GRU check: the same two encoder steps as the previous cell,
# but the recurrent matrices are transposed and every bias is reshaped to a
# (k, 1) column with np.expand_dims(..., 1), so all terms are column vectors.
# NOTE(review): the gates still use exp/(1+exp), which overflows for large
# pre-activations.
import numpy as np

emb_mat = model.embedding.weights.data.numpy()

inp = np.matmul(np.array([[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]), emb_mat).transpose() # GOOD
h = np.array([[0 for _ in range(256)]]).transpose()

print(h.shape)
print(inp.shape)

uz = model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy().transpose() # GOOD
wz = model.enc_lstm.wz_weights.transpose(0,1)[10:].data.numpy().transpose()
bz = np.expand_dims(model.enc_lstm.wz_bias.data.numpy(),1) # GOOD

ur = model.enc_lstm.wr_weights.transpose(0,1)[:10].data.numpy().transpose()
wr = model.enc_lstm.wr_weights.transpose(0,1)[10:].data.numpy().transpose()
br = np.expand_dims(model.enc_lstm.wr_bias.data.numpy(),1)

ux = model.enc_lstm.wx_weights.transpose(0,1).data.numpy().transpose() # GOOD
wx = model.enc_lstm.wrh_weights.transpose(0,1).data.numpy().transpose()
bx = np.expand_dims(model.enc_lstm.wx_bias.data.numpy() + model.enc_lstm.wrh_bias.data.numpy(),1) # GOOD


# Update gate
z_pre = np.matmul(uz,inp) + np.matmul(wz, h) + bz
z = np.exp(z_pre) / (1 + np.exp(z_pre))
print("z", z.shape)

# Reset gate
r_pre = np.matmul(ur,inp) + np.matmul(wr, h) + br
r = np.exp(r_pre) / (1 + np.exp(r_pre))
print("r", r.shape)

# Candidate state from the input and the reset-scaled previous state
htilde_pre = np.matmul(ux,inp) + np.matmul(wx, r*h) + bx
htilde = np.tanh(htilde_pre)
print("htilde", htilde.shape)

# New state after the first input; then load the second input (index 3)
h = (1 - z)*htilde + z*h
inp = np.matmul(np.array([[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]), emb_mat).transpose()
print(h)
#print("new_enc", h.shape)


# Second step with the updated h
z_pre = np.matmul(uz,inp) + np.matmul(wz, h) + bz
z = np.exp(z_pre) / (1 + np.exp(z_pre))
#print("z", z.shape)

r_pre = np.matmul(ur,inp) + np.matmul(wr, h) + br
r = np.exp(r_pre) / (1 + np.exp(r_pre))
#print("r", r.shape)

htilde_pre = np.matmul(ux,inp) + np.matmul(wx, r*h) + bx
htilde = np.tanh(htilde_pre)
#print("htilde", htilde.shape)


h = (1 - z)*htilde + z*h

#print(h.shape)
print(h)





In [None]:
# Third export layout attempt: concatenate along axis 0 instead of axis 1.
# NOTE(review): the result depends on which earlier cell ran last. wz/wr/bz/br/bx
# are rebound by several cells above, while wrh/xz/xr/xx come from the GRU
# export cells; confirm the intended sources before relying on this.
#full_w = flatten(np.concatenate([wz, wr, wrh], axis=1).tolist())
#full_x = flatten(np.concatenate([xz, xr, xx], axis=1).transpose().tolist())
#full_b = flatten(np.concatenate([bz,br,bxrh], axis=0).tolist())

full_w = flatten(np.concatenate([wz, wr, wrh], axis=0).transpose().tolist())
full_x = flatten(np.concatenate([xz, xr, xx], axis=0).tolist())
full_b = np.concatenate([bz,br,bx]).tolist()


In [None]:
# Inspect the current update-gate bias (its shape depends on the last cell that set it).
bz

In [None]:
# Inspect the symbol -> index mapping installed by set_dicts.
model.char2ind

In [None]:
# Check the shape of the update-gate input matrix used in the manual GRU check.
uz.shape

In [None]:
# Shape of the input part of the update-gate weights before the extra transpose.
model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy().shape

In [None]:
# Good first hidden state
# Fourth export layout, built from the same per-gate matrices as the corrected
# manual check. Gates are stacked along axis 0 (z, r, candidate) and then
# transposed before flattening. The biases become a single column that is
# flattened to a plain list.
uz = model.enc_lstm.wz_weights.transpose(0,1)[:10].data.numpy().transpose() # GOOD
wz = model.enc_lstm.wz_weights.transpose(0,1)[10:].data.numpy().transpose()
bz = model.enc_lstm.wz_bias.data.numpy() # GOOD

ur = model.enc_lstm.wr_weights.transpose(0,1)[:10].data.numpy().transpose()
wr = model.enc_lstm.wr_weights.transpose(0,1)[10:].data.numpy().transpose()
br = model.enc_lstm.wr_bias.data.numpy()

ux = model.enc_lstm.wx_weights.transpose(0,1).data.numpy().transpose() # GOOD
wx = model.enc_lstm.wrh_weights.transpose(0,1).data.numpy().transpose()
bx = model.enc_lstm.wx_bias.data.numpy() + model.enc_lstm.wrh_bias.data.numpy() # GOOD


full_w = flatten(np.concatenate([wz, wr, wx], axis=0).transpose().tolist())
full_x = flatten(np.concatenate([uz, ur, ux], axis=0).transpose().tolist()) # CORRECT
full_b = flatten(np.concatenate([np.expand_dims(bz,1),np.expand_dims(br,1),np.expand_dims(bx,1)], axis=0).tolist()) # CORRECT

# Author's log of layout experiments for full_w
# (concatenation axis, with/without the final transpose):
# No internal transpose:
# 0, no transpose: wrong
# 1, no transpose: right first thing, but not rest
# 0, transpose: wrong
# 1, transpose: wrong

# Internal transpose:
# 0, no transpose: wrong
# 1, no transpose: wrong
# 0, transpose: right first thing, but not rest
# 1, transpose: wrong


# Export the encoder arrays built in this cell, plus the decoder arrays from
# the earlier GRU export cells, to tf_weights.js. The `with` block closes the
# file, which flushes the buffered writes.
with open("tf_weights.js", "w") as tf_weights:
    tf_weights.write("emb_wg = " + stringify_lst(flatten(model.embedding.weights.data.numpy().tolist())) + ";\n")
    tf_weights.write("full_x = " + stringify_lst(full_x) + ";\n")
    tf_weights.write("full_w = " + stringify_lst(full_w) + ";\n")
    tf_weights.write("full_b = " + stringify_lst(full_b) + ";\n")
    tf_weights.write("full_xd = " + stringify_lst(full_xd) + ";\n")
    tf_weights.write("full_wd = " + stringify_lst(full_wd) + ";\n")
    tf_weights.write("full_bd = " + stringify_lst(full_bd) + ";\n")
    tf_weights.write("out_wg = " + stringify_lst(flatten(model.dec_output.weights.transpose(0,1).data.numpy().tolist())) + ";\n")
    tf_weights.write("out_wb = " + stringify_lst(model.dec_output.bias.data.numpy().tolist()) + ";\n")


In [None]:
# Re-inspect the symbol -> index mapping.
model.char2ind

In [11]:
# Load the project modules and a trained GRU encoder-decoder checkpoint.
import torch  # used directly below; don't rely on the star imports re-exporting it

# Project modules; EncoderDecoder presumably comes from `models` (TODO confirm).
from load_data import *
from utils import *
from training import *
from models import *

# Symbols passed to set_dicts. The char2ind dump in a later cell shows them at
# indices 3-33, after NULL/SOS/EOS at 0-2 (34 entries in total).
VOCAB = "a e i o u A E I O U b c d f g h j k l m n p q r s t v w x z .".split()

# NOTE(review): positional args look like (vocab size=34, embedding dim=10,
# hidden size=256) given the shapes seen later in this notebook — confirm in models.py.
model = EncoderDecoder(34,10,256, recurrent_unit="GRU")
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's own (CPU) parameters,
# which later cells read via .numpy().
model.load_state_dict(torch.load("yonc_maml_gru_256_5.weights", map_location="cpu"))
model.set_dicts(VOCAB)

In [13]:
# Element 3 of the model's output for the input "za". The printed tensors hold
# values in [-1, 1]; presumably hidden states (TODO confirm in EncoderDecoder.forward).
model(["za"])[3]

[tensor([[[-0.3812, -0.4205, -0.8237,  0.8546,  0.8499, -0.2240,  0.2781,
            0.5281,  0.5563, -0.7545, -0.3346,  0.2481, -0.9950, -0.2187,
           -0.8483,  0.2551, -0.3982, -0.7903,  0.5887,  0.6602, -0.6101,
            0.6130, -0.5756, -0.8793,  0.1599, -0.1643, -0.0342,  0.5934,
           -0.8864, -0.4143,  0.9686, -0.3760, -0.8288, -0.2831,  0.5956,
           -0.9481,  0.5187,  0.6435,  0.1961, -0.9974, -0.4230, -0.7563,
           -0.9837,  0.9690,  0.1212, -0.8425, -0.2087, -0.0198,  0.3525,
            0.9036, -0.3997, -0.2741,  0.5559,  0.7476, -0.8179, -0.9807,
           -0.3296,  0.5416, -0.6303,  0.6871, -0.8596,  0.8338,  0.6590,
           -0.4631,  0.0190, -0.1234,  0.1824,  0.9628, -0.4491, -0.5275,
            0.4260,  0.9060,  0.3895,  0.9231, -0.4454,  0.8645,  0.2057,
           -0.3400, -0.5472, -0.7450,  0.0307,  0.8704, -0.9603,  0.2843,
            0.6638, -0.2122,  0.2270, -0.9739, -0.5123, -0.9031,  0.9681,
           -0.6376, -0.7842,  0.1731, 

In [59]:
# Retry: export the encoder and decoder GRU weights to tf_weights.js.
#
# For each GRU cell, the transposed gate matrices are split row-wise: the first
# EMB_DIM rows (u*) multiply the character embedding and the remaining rows (w*)
# multiply the hidden state, as in the manual NumPy forward pass further down.
# NOTE(review): the split is inferred from that forward pass matching the model;
# confirm against the GRU implementation in models.py.

# Embedding width (10 for this model), previously hard-coded as the split index.
EMB_DIM = model.embedding.weights.shape[1]


def _split_gru_weights(cell, emb_dim=EMB_DIM):
    """Return ``(uz, wz, bz, ur, wr, br, ux, wx, bx)`` for one GRU cell as NumPy arrays.

    ``u*`` are the input-side matrices, ``w*`` the hidden-side matrices and
    ``b*`` the 1-D gate biases; the candidate bias ``bx`` is
    ``wx_bias + wrh_bias``.
    """
    wz_t = cell.wz_weights.transpose(0, 1)
    wr_t = cell.wr_weights.transpose(0, 1)

    uz = wz_t[:emb_dim].data.numpy().transpose()
    wz = wz_t[emb_dim:].data.numpy().transpose()
    bz = cell.wz_bias.data.numpy()

    ur = wr_t[:emb_dim].data.numpy().transpose()
    wr = wr_t[emb_dim:].data.numpy().transpose()
    br = cell.wr_bias.data.numpy()

    ux = cell.wx_weights.transpose(0, 1).data.numpy().transpose()
    wx = cell.wrh_weights.transpose(0, 1).data.numpy().transpose()
    bx = cell.wx_bias.data.numpy() + cell.wrh_bias.data.numpy()
    return uz, wz, bz, ur, wr, br, ux, wx, bx


def _pack_gru(uz, wz, bz, ur, wr, br, ux, wx, bx):
    """Flatten one cell's gate arrays into the ``(x, w, b)`` lists written to JS."""
    packed_w = flatten(np.concatenate([wz.transpose(), wr.transpose(), wx.transpose()], axis=1).tolist())
    packed_x = flatten(np.concatenate([uz, ur, ux], axis=0).transpose().tolist())
    packed_b = flatten(np.concatenate(
        [np.expand_dims(bz, 1), np.expand_dims(br, 1), np.expand_dims(bx, 1)], axis=0).tolist())
    return packed_x, packed_w, packed_b


# Encoder (module-level names kept so later cells can inspect them).
uz, wz, bz, ur, wr, br, ux, wx, bx = _split_gru_weights(model.enc_lstm)
full_x, full_w, full_b = _pack_gru(uz, wz, bz, ur, wr, br, ux, wx, bx)

# Decoder.
uzd, wzd, bzd, urd, wrd, brd, uxd, wxd, bxd = _split_gru_weights(model.dec_lstm)
full_xd, full_wd, full_bd = _pack_gru(uzd, wzd, bzd, urd, wrd, brd, uxd, wxd, bxd)

# JS variable name -> flat list, in the same order as before.
_js_exports = [
    ("emb_wg", flatten(model.embedding.weights.data.numpy().tolist())),
    ("full_x", full_x),
    ("full_w", full_w),
    ("full_b", full_b),
    ("full_xd", full_xd),
    ("full_wd", full_wd),
    ("full_bd", full_bd),
    ("out_wg", flatten(model.dec_output.weights.transpose(0, 1).data.numpy().tolist())),
    ("out_wb", model.dec_output.bias.data.numpy().tolist()),
]

# A `with` block guarantees the file is flushed and closed when the cell ends;
# previously the handle stayed open, so readers could see a truncated file.
with open("tf_weights.js", "w") as tf_weights:
    for js_name, values in _js_exports:
        tf_weights.write(js_name + " = " + stringify_lst(values) + ";\n")


748

In [62]:
# Element 1 of the model's output for "za": per-step log-softmax tensors
# (grad_fn=LogSoftmaxBackward in the output), one score per vocabulary index.
model(["za"])[1]

[tensor([[[-3.1345e+01, -3.1722e+01, -1.0695e+01, -1.1993e+01, -1.9148e+01,
           -1.8033e+01, -1.8337e+01, -2.0455e+01, -1.7708e+01, -1.9106e+01,
           -1.7292e+01, -2.0588e+01, -2.1188e+01, -2.1301e+01, -1.7761e+01,
           -1.7994e+01, -2.0617e+01, -2.2182e+01, -2.0595e+01, -1.8950e+01,
           -2.4255e+01, -2.0938e+01, -1.6749e+01, -2.1373e+01, -1.9037e+01,
           -1.9733e+01, -2.0761e+01, -1.6086e+01, -1.9191e+01, -1.6779e+01,
           -2.0447e+01, -2.0326e+01, -8.7156e+00, -1.9322e-04]]],
        grad_fn=<LogSoftmaxBackward>),
 tensor([[[-2.5725e+01, -2.5668e+01, -1.1383e+01, -9.6359e+00, -1.6666e+01,
           -1.5934e+01, -1.6515e+01, -1.7916e+01, -1.5906e+01, -1.6670e+01,
           -1.5977e+01, -1.8716e+01, -1.9436e+01, -1.3169e+01, -9.7857e+00,
           -9.6532e+00, -1.2284e+01, -1.4586e+01, -1.2750e+01, -1.0180e+01,
           -1.6266e+01, -1.3031e+01, -9.9095e+00, -1.2121e+01, -1.0848e+01,
           -1.1269e+01, -1.2470e+01, -7.2635e+00, -1.1487e+

In [26]:
# Display the character -> index mapping (used to build one-hot inputs by hand).
model.char2ind

{'NULL': 0,
 'SOS': 1,
 'EOS': 2,
 'a': 3,
 'e': 4,
 'i': 5,
 'o': 6,
 'u': 7,
 'A': 8,
 'E': 9,
 'I': 10,
 'O': 11,
 'U': 12,
 'b': 13,
 'c': 14,
 'd': 15,
 'f': 16,
 'g': 17,
 'h': 18,
 'j': 19,
 'k': 20,
 'l': 21,
 'm': 22,
 'n': 23,
 'p': 24,
 'q': 25,
 'r': 26,
 's': 27,
 't': 28,
 'v': 29,
 'w': 30,
 'x': 31,
 'z': 32,
 '.': 33}

In [None]:
# Inspect the flattened encoder gate biases that are written to tf_weights.js.
full_b

In [None]:
# Embedding of "z" via the model's own layer, to compare with one-hot @ emb_mat.
model.embedding(torch.LongTensor([model.char2ind["z"]]))

In [61]:
# Does the right thing: a NumPy re-implementation of the GRU encoder-decoder
# forward pass for the input "za" (encode "z", "a"; then one decoder step on
# SOS), printed for comparison with model(["za"]).
import numpy as np

emb_mat = model.embedding.weights.data.numpy()
emb_dim = emb_mat.shape[1]


def _embed(char):
    """Return the embedding of `char` as an (emb_dim, 1) column vector.

    Built as one-hot @ emb_mat (rather than typing the one-hot by hand) so the
    index always matches model.char2ind, and the result stays float64.
    """
    one_hot = np.zeros((1, emb_mat.shape[0]))
    one_hot[0, model.char2ind[char]] = 1
    return np.matmul(one_hot, emb_mat).transpose()


def _sigmoid(x):
    """Logistic function in a form that cannot overflow.

    exp(x) / (1 + exp(x)) returns nan once exp(x) overflows to inf;
    0.5 * (1 + tanh(x / 2)) is the same function and is bounded for all finite x.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def _gru_params(cell):
    """Return (uz, wz, bz, ur, wr, br, ux, wx, bx) for `cell`; biases as (n, 1) columns."""
    wz_t = cell.wz_weights.transpose(0, 1)
    wr_t = cell.wr_weights.transpose(0, 1)
    uz = wz_t[:emb_dim].data.numpy().transpose()
    wz = wz_t[emb_dim:].data.numpy().transpose()
    bz = np.expand_dims(cell.wz_bias.data.numpy(), axis=1)
    ur = wr_t[:emb_dim].data.numpy().transpose()
    wr = wr_t[emb_dim:].data.numpy().transpose()
    br = np.expand_dims(cell.wr_bias.data.numpy(), axis=1)
    ux = cell.wx_weights.transpose(0, 1).data.numpy().transpose()
    wx = cell.wrh_weights.transpose(0, 1).data.numpy().transpose()
    bx = np.expand_dims(cell.wx_bias.data.numpy() + cell.wrh_bias.data.numpy(), axis=1)
    return uz, wz, bz, ur, wr, br, ux, wx, bx


def _gru_step(x, h_prev, params):
    """Apply one GRU update to column vectors `x` (input) and `h_prev` (state)."""
    uz, wz, bz, ur, wr, br, ux, wx, bx = params
    z = _sigmoid(np.matmul(uz, x) + np.matmul(wz, h_prev) + bz)    # update gate
    r = _sigmoid(np.matmul(ur, x) + np.matmul(wr, h_prev) + br)    # reset gate
    # The reset gate scales the previous state before the hidden-side matrix.
    htilde = np.tanh(np.matmul(ux, x) + np.matmul(wx, r * h_prev) + bx)
    return (1 - z) * htilde + z * h_prev


enc_params = _gru_params(model.enc_lstm)
dec_params = _gru_params(model.dec_lstm)
# Keep the individual arrays as notebook globals (a later cell inspects bz.shape).
uz, wz, bz, ur, wr, br, ux, wx, bx = enc_params
uzd, wzd, bzd, urd, wrd, brd, uxd, wxd, bxd = dec_params

inp = _embed("z")
h = np.zeros((bz.shape[0], 1))  # zero initial hidden state (256 units for this model)

print(h.shape)
print(inp.shape)
print("z", bz.shape)

# Encoder step 1: "z".
h = _gru_step(inp, h, enc_params)
print(h)

# Encoder step 2: "a".
inp = _embed("a")
h = _gru_step(inp, h, enc_params)
print("new enc")
print(h)

# Decoder step 1: feed SOS, starting from the final encoder state.
inp = _embed("SOS")
h = _gru_step(inp, h, dec_params)
print("in the dec!!!")
print(h)



(256, 1)
(10, 1)
z (256, 1)
[[ 2.09263693e-01]
 [-3.17104201e-01]
 [-3.46960212e-01]
 [-3.82099066e-01]
 [-2.71665826e-01]
 [ 5.18073654e-01]
 [ 2.30343703e-01]
 [ 1.36834229e-01]
 [-2.99621327e-01]
 [ 5.35334617e-02]
 [-8.24420005e-02]
 [ 1.70503532e-01]
 [-6.46259067e-01]
 [-5.18761615e-02]
 [-1.73623358e-01]
 [-1.36645874e-01]
 [-2.27169164e-01]
 [-8.66621523e-01]
 [-1.16869672e-01]
 [ 1.18509523e-01]
 [ 3.23339634e-01]
 [ 6.70590519e-01]
 [-2.80037580e-01]
 [ 2.53936094e-01]
 [-1.55476413e-01]
 [ 9.44273448e-02]
 [-1.61042798e-01]
 [ 7.55784983e-01]
 [ 3.73489295e-01]
 [ 2.23931687e-01]
 [-8.35582285e-02]
 [-2.44428162e-01]
 [-5.50272047e-01]
 [-6.31977549e-02]
 [ 4.13851373e-01]
 [-2.67462097e-01]
 [ 1.12094530e-01]
 [ 5.21987963e-02]
 [ 1.48834927e-02]
 [-1.85149505e-01]
 [ 1.09912258e-01]
 [-2.77185246e-01]
 [-1.34415487e-02]
 [-1.29794342e-01]
 [ 3.65792176e-01]
 [-1.51237959e-01]
 [ 2.71078636e-02]
 [-6.74984507e-02]
 [-5.26669585e-02]
 [-5.15164537e-01]
 [-3.33322784e-01]
 [-

In [None]:
# Inspect the shape of the encoder update-gate bias currently bound to `bz`.
bz.shape

In [57]:
# Element 3 of the model's output for the input "ba" (same check as for "za").
model(["ba"])[3]

[tensor([[[-0.7056,  0.1680, -0.8589,  0.9210,  0.6723,  0.3097, -0.2655,
            0.6218,  0.7066, -0.8842, -0.5191,  0.4401, -0.9818, -0.4042,
           -0.6082, -0.1423,  0.4787, -0.2996,  0.2559,  0.8032, -0.6610,
           -0.5030,  0.4970, -0.8645, -0.2783, -0.2295, -0.0696, -0.3059,
           -0.7734, -0.2524,  0.9616, -0.2255, -0.9022,  0.9719,  0.5728,
           -0.8875,  0.1908,  0.5525,  0.2557, -0.9923, -0.3461, -0.8700,
           -0.9734,  0.8422,  0.6291, -0.8584,  0.0915,  0.8212,  0.3918,
            0.8232, -0.3085, -0.2670,  0.0442,  0.5186, -0.7909, -0.9605,
           -0.0644, -0.5299, -0.7515,  0.4523, -0.9330,  0.7422,  0.7111,
            0.2425, -0.7234, -0.4766,  0.0549,  0.9401, -0.6049,  0.1121,
            0.5427,  0.8490,  0.5464,  0.9613, -0.7076, -0.3046, -0.0979,
           -0.3552, -0.7209, -0.7082,  0.0614,  0.5661,  0.3739,  0.5323,
           -0.3221, -0.0046,  0.3147, -0.9543, -0.5502,  0.9482,  0.9584,
           -0.4668, -0.5796,  0.1773, 