In [16]:
def generate_cv(length):
    if length == 0:
        return [""]
    
    else:
        previous = generate_cv(length - 1)
        new = []
        
        for elt in previous:
            new.append(elt + "V")
            new.append(elt + "C")
            
        return new
    
def generate_cv_cumul(max_length):
    output = []
    for i in range(max_length + 1):
        output += generate_cv(i)
        
    return output

In [17]:
def syllabifiable(word):
    if word[:2] == "CC":
        return False
    if word[-2:] == "CC":
        return False
    if "CCC" in word:
        return False
    if word == "C":
        return False
    
    return True

def syllabify(word):
    prev = "#"
    syll = ""
    
    rev = word[::-1]
    
    for char in rev:
        if prev == "#":
            syll += "."
            syll += char
            prev = char
        elif prev == "C" and char == "V":
            syll += char
            prev = char
        elif prev == "V" and char == "C":
            syll += char
            syll += "."
            prev = "."
        elif prev == "V" and char == "V":
            syll += "."
            syll += char
            prev = char
        elif prev == ".":
            syll += char
            prev = char
        else:
            print(word)
    
    syll = syll[::-1]
    if len(syll) > 0:
        if syll[0] == "V" or syll[0] == "C":
            syll = "." + syll
    
    return syll

In [18]:
def violations(ur, sr):
    onset = 0
    nocoda = 0
    mx = 0
    dep = 0
    
    if len(sr) > 0:
        if sr[0] == ".":
            sr = sr[1:]
        if sr[-1] == ".":
            sr = sr[:-1]
        
        syllables = sr.split(".")
        

        # Onset, NoCoda
        for syllable in syllables:
            parts = syllable.split("V")
            ons = parts[0]
            cod = parts[1]
            if ons == "":
                onset += 1
            if cod != "":
                nocoda += 1
            
    # Max, Dep
    edit_paths = edit_path(ur,sr.replace(".",""))
    
    all_violations = []
    for path in edit_paths:
        all_violations.append([onset, nocoda, path[1], path[0]])
        
    return all_violations
        


In [19]:
def edit_path(w1,w2):
    l1 = len(w1) + 1
    l2 = len(w2) + 1
    
    #if l1 == 0 and l2 == 0:
    #    return [[0,0]]
    #elif l1 == 0:
    #    return [[l2, 0]]
    #elif l2 == 0:
    #    return [[0, l2]]
    
    grid = [[0 for i in range(l2)] for j in range(l1)]
    
    for ind in range(l1):
        grid[ind][0] = [[0,ind]]
    for ind in range(l2):
        grid[0][ind] = [[ind,0]]
        
    #print(grid)
        
    for i1 in range(1,l1):
        for i2 in range(1,l2):
            p1 = grid[i1-1][i2]
            p2 = grid[i1][i2-1]
            
            if w1[i1-1] == w2[i2-1]:
                possibles = grid[i1-1][i2-1]
            else:
                new_poss = []
                
                possibles = p1 
                for poss in possibles:
                    new_poss.append([poss[0], 1 + poss[1]])
                    
                possibles = p2 
                for poss in possibles:
                    new_poss.append([1 + poss[0], poss[1]])
                    
                possibles = new_poss
            
            grid[i1][i2] = min_cands(possibles)
            
    return grid[l1-1][l2-1]
    
    
    
    
    

In [20]:
def min_cands(cands):
    min_first = 1000000
    min_second = 1000000
    
    for cand in cands:
        first = cand[0]
        second = cand[1]
        
        if first < min_first:
            min_first = first
            
        if second < min_second:
            min_second = second
            
    firsts = []
    seconds = []
    
    for cand in cands:
        if cand[0] == min_first:
            firsts.append(cand)
        if cand[1] == min_second:
            seconds.append(cand)
            
    min_second_firsts = 1000000
    best_first = []
    for first in firsts:
        if first[1] < min_second_firsts:
            best_first = first
            
    min_first_seconds = 1000000
    best_second = []
    for second in seconds:
        if second[1] < min_first_seconds:
            best_second = second
            
    if best_first[0] ==  best_second[0] and best_first[1] == best_second[1]:
        return [best_first]
    else:
        return [best_first, best_second]
    
            
        

In [21]:
def winner(ur, candidates, ranking):
    all_violations = []
    for cand in candidates:
        viols = violations(ur, cand)
        for viol in viols:
            all_violations += [[cand, viol]] 
            
    for constraint in ranking:
        min_viols = 1000000
        for candidate in all_violations:
            #print(all_violations, candidate)
            this_constraint_viols = candidate[1][constraint]
            if this_constraint_viols < min_viols:
                min_viols = this_constraint_viols
                
        filtered_cands = []
        for candidate in all_violations:
            if candidate[1][constraint] == min_viols:
                filtered_cands.append(candidate)
                
        all_violations = filtered_cands
        
    return all_violations

In [22]:
inputs = generate_cv_cumul(5)
outputs = []

for inp in generate_cv_cumul(10):
    if syllabifiable(inp):
        outputs.append(syllabify(inp))
        


In [23]:
import random
from random import shuffle

In [24]:
train_rankings = [
    [0,1,2,3],
    [0,1,3,2],
    [0,2,3,1],
    [0,3,2,1],
    [2,3,0,1],
    [3,2,0,1],
    [1,2,3,0], # withheld
    [1,3,2,0] # withheld
]

test_rankings = [
]

all_rankings = train_rankings + test_rankings

In [26]:
all_input_outputs = {}

for ranking in all_rankings:
    print(ranking)
    io_list = []
    
    for inp in inputs:
        print(inp)
        output = winner(inp, outputs, ranking)[0][0]
        io_list.append([inp, output])
    
    all_input_outputs[tuple(ranking)] = io_list

[0, 1, 2, 3]

V
C
VV
VC
CV
CC
VVV
VVC
VCV
VCC
CVV
CVC
CCV
CCC
VVVV
VVVC
VVCV
VVCC
VCVV
VCVC
VCCV
VCCC
CVVV
CVVC
CVCV
CVCC
CCVV
CCVC
CCCV
CCCC
VVVVV
VVVVC
VVVCV
VVVCC
VVCVV
VVCVC
VVCCV
VVCCC
VCVVV
VCVVC
VCVCV
VCVCC
VCCVV
VCCVC
VCCCV
VCCCC
CVVVV
CVVVC
CVVCV
CVVCC
CVCVV
CVCVC
CVCCV
CVCCC
CCVVV
CCVVC
CCVCV
CCVCC
CCCVV
CCCVC
CCCCV
CCCCC
[0, 1, 3, 2]

V
C
VV
VC
CV
CC
VVV
VVC
VCV
VCC
CVV
CVC
CCV
CCC
VVVV
VVVC
VVCV
VVCC
VCVV
VCVC
VCCV
VCCC
CVVV
CVVC
CVCV
CVCC
CCVV
CCVC
CCCV
CCCC
VVVVV
VVVVC
VVVCV
VVVCC
VVCVV
VVCVC
VVCCV
VVCCC
VCVVV
VCVVC
VCVCV
VCVCC
VCCVV
VCCVC
VCCCV
VCCCC
CVVVV
CVVVC
CVVCV
CVVCC
CVCVV
CVCVC
CVCCV
CVCCC
CCVVV
CCVVC
CCVCV
CCVCC
CCCVV
CCCVC
CCCCV
CCCCC
[0, 2, 3, 1]

V
C
VV
VC
CV
CC
VVV
VVC
VCV
VCC
CVV
CVC
CCV
CCC
VVVV
VVVC
VVCV
VVCC
VCVV
VCVC
VCCV
VCCC
CVVV
CVVC
CVCV
CVCC
CCVV
CCVC
CCCV
CCCC
VVVVV
VVVVC
VVVCV
VVVCC
VVCVV
VVCVC
VVCCV
VVCCC
VCVVV
VCVVC
VCVCV
VCVCC
VCCVV
VCCVC
VCCCV
VCCCC
CVVVV
CVVVC
CVVCV
CVVCC
CVCVV
CVCVC
CVCCV
CVCCC
CCVVV
CCVVC
CCVCV
CCVCC
CCCVV
CCCVC
CCCCV
CCCC

In [27]:
def make_task(ranking, all_input_outputs, n=10):
    io_list = all_input_outputs[tuple(ranking)][:]
    shuffle(io_list)
    
    train_pairs = io_list[:n]
    test_pairs = io_list[n:]
        
    return train_pairs, test_pairs, None, None

In [28]:
io_dict = {}

for key in all_input_outputs:
    interior_dict = {}
    
    io_list = all_input_outputs[key]
    
    for elt in io_list:
        interior_dict[elt[0]] = elt[1]
        
    io_dict[key] = interior_dict

In [29]:
def same_preds_dict(inputs, r1, r2):
    for inp in inputs:
        p1 = io_dict[tuple(r1)][inp]
        p2 = io_dict[tuple(r2)][inp]
        
        if p1 != p2:
            return False
        
    return True

In [30]:
def check_task_dict(task, rankings):
    inputs = []
    for pair in task[0]:
        inputs.append(pair[0])
        
    for i1 in range(len(rankings)):
        for i2 in range(i1+1, len(rankings)):
            r1 = rankings[i1]
            r2 = rankings[i2]
    
            if same_preds_dict(inputs, r1, r2):
                return False
            
    return True

In [31]:
train_set = []

n_train = 20

while len(train_set) < 20000:
    ranking = random.choice(train_rankings)
    task = make_task(ranking, all_input_outputs, n=n_train)
    
    if check_task_dict(task, all_rankings):
        train_set.append(task)
        
        if len(train_set) % 1000 == 0:
            print(len(train_set))


1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000


In [32]:
train_set[0]

([['VCCCV', '.V.CV.'],
  ['CVVC', '.CV.V.'],
  ['VCVCC', '.V.CV.'],
  ['VVVVV', '.V.V.V.V.V.'],
  ['VVCVC', '.V.V.CV.'],
  ['CVCCV', '.CV.CV.'],
  ['VVVVC', '.V.V.V.V.'],
  ['CCVVV', '.CV.V.V.'],
  ['V', '.V.'],
  ['VCVVV', '.V.CV.V.V.'],
  ['VCCC', '.V.'],
  ['VCCVC', '.V.CV.'],
  ['CCV', '.CV.'],
  ['CCVCC', '.CV.'],
  ['', ''],
  ['CVCC', '.CV.'],
  ['VC', '.V.'],
  ['CC', ''],
  ['CVCCC', '.CV.'],
  ['VCV', '.V.CV.']],
 [['CVC', '.CV.'],
  ['CVCVV', '.CV.CV.V.'],
  ['VCCVV', '.V.CV.V.'],
  ['VVC', '.V.V.'],
  ['CVCVC', '.CV.CV.'],
  ['VCVVC', '.V.CV.V.'],
  ['CCVVC', '.CV.V.'],
  ['CCCVC', '.CV.'],
  ['VVCCC', '.V.V.'],
  ['VVCVV', '.V.V.CV.V.'],
  ['CCCC', ''],
  ['CCC', ''],
  ['VVVCC', '.V.V.V.'],
  ['CVVCC', '.CV.V.'],
  ['CCCV', '.CV.'],
  ['VCCCC', '.V.'],
  ['CVVCV', '.CV.V.CV.'],
  ['CCVC', '.CV.'],
  ['CVCV', '.CV.CV.'],
  ['VV', '.V.V.'],
  ['CCCVV', '.CV.V.'],
  ['CVVVC', '.CV.V.V.'],
  ['VCVCV', '.V.CV.CV.'],
  ['CCCCV', '.CV.'],
  ['VVCV', '.V.V.CV.'],
  ['CCVV', '.CV.

In [1]:
from load_data import *
from models import *
from random import shuffle

In [2]:
shuffle([3,4,5])

In [3]:
train_set = load_dataset("phonology.train")
dev_set = load_dataset("phonology.dev")
test_set = load_dataset("phonology.test")

In [33]:
class EncoderDecoder(ModifiableModule):
    def __init__(self, enc_vocab_size, dec_vocab_size, input_size, hidden_size):
        super(EncoderDecoder, self).__init__()
        self.enc_vocab_size = enc_vocab_size
        self.dec_vocab_size = dec_vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.enc_embedding = GradEmbedding(enc_vocab_size, input_size)
        self.enc_lstm = GradLSTM(input_size, hidden_size)

        self.dec_embedding = GradEmbedding(dec_vocab_size, input_size)
        self.dec_lstm = GradLSTM(input_size, hidden_size)
        self.dec_output = GradLinear(hidden_size, dec_vocab_size)

        self.max_length = 20

        self.char2ind = {}
        self.ind2char = {}

    def forward(self, sequence_list):
        # Initialize the hidden state
        hidden = (V(torch.zeros(1, len(sequence_list), self.hidden_size)),
                  V(torch.zeros(1, len(sequence_list), self.hidden_size)))

        all_seqs = []
        for sequence in sequence_list:
            this_seq = []
            # Iterate over the sequence
            for elt in sequence:
                ind = self.char2ind[elt]
                this_seq.append(ind)
            all_seqs.append(torch.LongTensor(this_seq))
            
        #print(all_seqs)
        #print(torch.nn.utils.rnn.pad_sequence(all_seqs))
        #all_seqs = torch.LongTensor(all_seqs).transpose(0,1)
        all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)
        all_seqs_onehot = (all_seqs > 0).type(torch.FloatTensor)
        #print(all_seqs_onehot)
        #all_seqs = torch.nn.utils.rnn.pack_padded_sequence(all_seqs, lengths=[len(x) for x in sequence_list])
                
        #print(all_seqs)
        #print(all_seqs[0])
        #print(all_seqs.shape)
            
        for index, elt in enumerate(all_seqs):
            #print(elt)
            emb = self.enc_embedding(elt.unsqueeze(0))
            #print(emb.shape, hidden[0].shape, hidden[1].shape)
            output, hidden_new = self.enc_lstm(emb, hidden)
            hidden_prev = hidden
            #print(hidden_prev[0].shape)
            #print(all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[0].shape))
            hx = hidden_prev[0] * (1 - all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[0].shape)) + hidden_new[0] * all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[0].shape)
            cx = hidden_prev[1] * (1 - all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[1].shape)) + hidden_new[1] * all_seqs_onehot[index].unsqueeze(0).unsqueeze(2).expand(hidden_prev[1].shape)
            
            hidden = (hx, cx)
            #print("")
            #print(hidden)
            
        #print(hidden[0].shape, hidden[1].shape)
            

        # Return the final hidden state
        # Note that hidden = (hidden state, cell state). So hidden[0] is just the hidden state
        #encoding = hidden[0]
        #print(encoding.shape)

        #hidden = (encoding,
        #          V(torch.zeros(1, len(sequence_list), self.hidden_size)))

        prev_output = ["SOS" for _ in range(len(sequence_list))]
        out_strings = ["" for _ in range(len(sequence_list))]
        logits = []

        for i in range(self.max_length):
            prev_outputs = []
            for elt in prev_output:
                ind = self.char2ind[elt]
                prev_outputs.append(ind)
                
            emb = self.dec_embedding(torch.LongTensor([prev_outputs]))
            output, hidden = self.dec_lstm(emb, hidden)
            pred = self.dec_output(output)

            probs = F.log_softmax(pred, dim=2)
            logits.append(probs)

            topv, topi = probs.data.topk(1)
            label = topi[0] #.item()
            
            prev_output = []
            for index, elt in enumerate(label):
                char = self.ind2char[elt.item()]
                
                out_strings[index] += char
                prev_output.append(char)

            #char = self.ind2char[label]


            #if char == "EOS":
            #    break

            #out_string += char

            #prev_output = char

        # Return the final hidden state
        # Note that hidden = (hidden state, cell state). So hidden[0] is just the hidden state
        return out_strings, logits
    
    def named_submodules(self):
        return [('enc_embedding', self.enc_embedding), ('enc_lstm', self.enc_lstm),
                ('dec_embedding', self.dec_embedding), ('dec_lstm', self.dec_lstm),
                ('dec_output', self.dec_output)]

    def set_dict(self,v_list,c_list):
        char2ind = {}
        char2ind["NULL"] = 0
        char2ind["SOS"] = 1
        char2ind["EOS"] = 2
        char2ind["."] = 3
        char2ind["C"] = 4
        char2ind["V"] = 5
        #char2ind["a"] = 4
        #char2ind["b"] = 5
        char2ind["c"] = 6
        char2ind["d"] = 7
        char2ind["e"] = 8
        char2ind["f"] = 9
        char2ind["g"] = 10
        char2ind["h"] = 11
        char2ind["i"] = 12
        char2ind["j"] = 13
        char2ind["k"] = 14
        char2ind["l"] = 15
        char2ind["m"] = 16
        char2ind["n"] = 17
        char2ind["o"] = 18
        char2ind["p"] = 19
        char2ind["q"] = 20
        char2ind["r"] = 21
        char2ind["s"] = 22
        char2ind["t"] = 23
        char2ind["u"] = 24
        char2ind["v"] = 25
        char2ind["w"] = 26
        char2ind["x"] = 27
        char2ind["y"] = 28
        char2ind["z"] = 29
        char2ind["A"] = 30
        char2ind["E"] = 31
        char2ind["I"] = 32
        char2ind["O"] = 33
        char2ind["U"] = 34

        ind2char = {}
        ind2char[0] = "NULL"
        ind2char[1] = "SOS"
        ind2char[2] = "EOS"
        ind2char[3] = "."
        ind2char[4] = "C"
        ind2char[5] = "V"
        
        #ind2char[4] = "a"
        #ind2char[5] = "b"
        ind2char[6] = "c"
        ind2char[7] = "d"
        ind2char[8] = "e"
        ind2char[9] = "f"
        ind2char[10] = "g"
        ind2char[11] = "h"
        ind2char[12] = "i"
        ind2char[13] = "j"
        ind2char[14] = "k"
        ind2char[15] = "l"
        ind2char[16] = "m"
        ind2char[17] = "n"
        ind2char[18] = "o"
        ind2char[19] = "p"
        ind2char[20] = "q"
        ind2char[21] = "r"
        ind2char[22] = "s"
        ind2char[23] = "t"
        ind2char[24] = "u"
        ind2char[25] = "v"
        ind2char[26] = "w"
        ind2char[27] = "x"
        ind2char[28] = "y"
        ind2char[29] = "z"
        ind2char[30] = "A"
        ind2char[31] = "E"
        ind2char[32] = "I"
        ind2char[33] = "O"
        ind2char[34] = "U"

        #phoneme_list = v_list + c_list
        #possible_indices = [i for i in range(4,12)]
        #shuffle(possible_indices)

        #indices = possible_indices[:len(phoneme_list)]
        #remainder = possible_indices[len(phoneme_list):]

        #for i in range(len(phoneme_list)):
        #    char2ind[phoneme_list[i]] = indices[i]
        #    ind2char[indices[i]] = phoneme_list[i]

        #for ind in remainder:
        #    char2ind[str(ind)] = ind
        #    ind2char[ind] = str(ind)

        self.char2ind = char2ind
        self.ind2char = ind2char






In [34]:
CV_REPTILE = EncoderDecoder(6,6,20,128)

In [35]:
CV_REPTILE.set_dict(train_set[0][2], train_set[0][3])

In [36]:
train_set[0][0][:3]

[['VCCCV', '.V.CV.'], ['CVVC', '.CV.V.'], ['VCVCC', '.V.CV.']]

In [37]:
CV_REPTILE(['VCCCV'])

(['SOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOSSOS'],
 [tensor([[[-1.7259, -1.6470, -1.8447, -1.8279, -1.8552, -1.8701]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7548, -1.6605, -1.8376, -1.7924, -1.8412, -1.8798]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7701, -1.6678, -1.8338, -1.7753, -1.8349, -1.8829]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7780, -1.6715, -1.8320, -1.7667, -1.8321, -1.8839]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7821, -1.6732, -1.8313, -1.7623, -1.8308, -1.8843]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7843, -1.6739, -1.8310, -1.7601, -1.8302, -1.8845]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7856, -1.6740, -1.8310, -1.7589, -1.8299, -1.8846]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7863, -1.6740, -1.8309, -1.7583, -1.8298, -1.8848]]],
         grad_fn=<LogSoftmaxBackward>),
  tensor([[[-1.7867, -1.6738, -1.8309, -1.7580, -1.8297, -1.8

In [None]:
"dogEOS"[:"dogEOS".index("EOS")]

In [None]:
"dogEOS".index("EOS")

In [5]:
def process_output(output):
    if "EOS" in output:
        return output[:output.index("EOS")]
    else:
        return output

In [56]:


def fit_task(model, task, lr_inner, create_graph=True):
    training_set = task[0]
    test_set = task[1]
    v_list = task[2]
    c_list = task[3]
    
    new_model = EncoderDecoder(model.enc_vocab_size,model.dec_vocab_size,model.input_size,model.hidden_size)
    new_model.copy(model, same_var=True)
    #new_model.copy(model, same_var=False)
    
    new_model.set_dict(v_list, c_list)
    
    loss = 0
    criterion = nn.NLLLoss(ignore_index=0, size_average=False)

    output, logits = new_model([pair[0] for pair in training_set])
    
    all_seqs = [] # HERE
    for inp, sequence in training_set:
        this_seq = []
        # Iterate over the sequencex
        for elt in sequence:
            ind = new_model.char2ind[elt]
            this_seq.append(ind)
        this_seq.append(new_model.char2ind["EOS"])
        all_seqs.append(torch.LongTensor(this_seq))
            
    all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)
    #print(all_seqs)

    for index, logit in enumerate(logits):
        if index >= len(all_seqs):
            break
            
        loss += criterion(logit[0], all_seqs[index])
        
            
    total_values = sum([len(x[1]) + 1 for x in training_set])  
    loss /= total_values
    loss.backward(create_graph=create_graph, retain_graph=True)
    
    #print("")
    #print("INNER UPDATE")
    #print(new_model.enc_lstm.wi_bias)
    #for param in new_model.params():
    #    print(param[0])
        #print("")
   #     break
    for name, param in new_model.named_params():
        grad = param.grad
   #     print("grad", grad)
   #     print(param)
   #     print(param - lr_inner * grad)
                
        new_model.set_param(name, param - lr_inner * grad)
    #print(new_model.enc_lstm.wi_bias)
    #for param in new_model.params():
    #    print(param[0])
    #    print("DONE INNER UPDATE")
    #    print("")
        
    #    break
        
    
    
    
    test_loss = 0  
    correct = 0
    total = 0
    
    all_seqs = [] # HERE
    for inp, sequence in test_set:
        this_seq = []
        # Iterate over the sequencex
        for elt in sequence:
            ind = new_model.char2ind[elt]
            this_seq.append(ind)
        this_seq.append(new_model.char2ind["EOS"])
        all_seqs.append(torch.LongTensor(this_seq))
            
    all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)
    #print(all_seqs)

    output, logits = new_model([pair[0] for pair in test_set])
    for index, output_guess in enumerate(output):
        if process_output(output_guess) == test_set[index][1]:
            correct += 1
        #else:
            #print(process_output(output_guess))
            #print(test_set[index][1])
            #print("")
        total += 1
        
    for index, logit in enumerate(logits):
        if index >= len(all_seqs):
            break
            
        test_loss += criterion(logit[0], all_seqs[index])

    total_values = sum([len(x[1]) + 1 for x in test_set])  
    #print(total_values)
    test_loss /= total_values
   
    test_acc = correct * 1.0 / total
                
    
    return test_loss, test_acc, new_model



In [57]:
def maml(model, epochs, train_set, lr_inner=0.0001, lr_outer=0.001, batch_size=1, first_order=False):
    optimizer = torch.optim.Adam(model.params(), lr=lr_outer)
    
    for _ in range(epochs):
        
        total_postupdate_test_losses = 0
        total_test_accs = 0
        count_postupdate_test_losses = 0
        
        print_total_postupdate_test_losses = 0
        print_total_test_accs = 0
        print_count_postupdate_test_losses = 0

        for i, t in enumerate(train_set):
            #if i % 10 == 0:
            #    print(i)
            
            #print("MAML UPDATE")
            #print(model.enc_lstm.wi_bias)
            #for param in model.params():
            #    print(param[0])
            #    break
            test_loss, test_acc, new_model = fit_task(model, t, lr_inner, create_graph=not first_order)
            #for param in model.params():
            #    print(param[0])
            #    break
            #print(model.enc_lstm.wi_bias) 
            #print(new_model.enc_lstm.wi_bias)
            #for param in new_model.params():
            #    print(param[0])
            #    break
            #print("END MAML UPDATE")
            #print("")
                
            #14/0
            
            total_postupdate_test_losses += test_loss
            total_test_accs += test_acc
            count_postupdate_test_losses += 1
            
            print_total_postupdate_test_losses += test_loss.item()
            print_total_test_accs += test_acc
            print_count_postupdate_test_losses += 1
            
            if i % 1000 == 0:
                print(i,print_total_postupdate_test_losses/print_count_postupdate_test_losses, print_total_test_accs/print_count_postupdate_test_losses)
                print_total_postupdate_test_losses = 0
                print_total_test_accs = 0
                print_count_postupdate_test_losses = 0


            if (i + 1) % batch_size == 0:
                
                total_postupdate_test_losses /= count_postupdate_test_losses
                total_postupdate_test_losses.backward(create_graph=True, retain_graph=True)
                
                #print("")
                #print("BEFORE UPDATE")
                #print(model.enc_lstm.wi_bias) 
                #for param in model.params():
                #    print(param[0])
                #    break
                    
                optimizer.step()
                optimizer.zero_grad()
                
                #print("AFTER UPDATE")
                #print(model.enc_lstm.wi_bias) 
                #for param in model.params():
                #    print(param[0])
                #    break
                #print("")
                
                total_postupdate_test_losses = 0
                total_test_accs = 0
                count_postupdate_test_losses = 0
                

In [68]:
def reptile(model, epochs, train_set, lr_inner=0.0001, lr_outer=0.001, batch_size=1, first_order=False):
    optimizer = torch.optim.Adam(model.params(), lr=lr_outer)
    
    name_to_param = dict(model.named_params())
    
    for _ in range(epochs):
        
        total_postupdate_test_losses = 0
        total_test_accs = 0
        count_postupdate_test_losses = 0
        
        print_total_postupdate_test_losses = 0
        print_total_test_accs = 0
        print_count_postupdate_test_losses = 0

        for i, t in enumerate(train_set):
            
            test_loss, test_acc, new_model = fit_task(model, t, lr_inner, create_graph=not first_order)
            
            
            for name, param in new_model.named_params():
                cur_grad = (name_to_param[name].data - param.data) / lr_inner
                if name_to_param[name].grad is None:
                    name_to_param[name].grad = V(torch.zeros(cur_grad.size()))
                name_to_param[name].grad.data.add_(cur_grad / batch_size)
            
            print_total_postupdate_test_losses += test_loss.item()
            print_total_test_accs += test_acc
            print_count_postupdate_test_losses += 1
            
            if i % 1000 == 0:
                print(i,print_total_postupdate_test_losses/print_count_postupdate_test_losses, print_total_test_accs/print_count_postupdate_test_losses)
                print_total_postupdate_test_losses = 0
                print_total_test_accs = 0
                print_count_postupdate_test_losses = 0


            if (i + 1) % batch_size == 0:
                
               
                    
                optimizer.step()
                optimizer.zero_grad()
                
                
                total_postupdate_test_losses = 0
                total_test_accs = 0
                count_postupdate_test_losses = 0
                

In [69]:
CV_REPTILE = EncoderDecoder(6,6,6,60)

In [None]:
for param in CV_REPTILE.named_params():
    #print(param)
    #break
    pass

print CV_REPTILE.enc_lstm.wi_bias

In [None]:
train_set[8]

In [39]:
_, acc, new_model = fit_task(CV_REPTILE, train_set[8], 0.001)
print(acc)

0.0


In [None]:
process_output(new_model(['chpcp'])[0][0])

In [62]:
maml(CV_REPTILE, 10, train_set, batch_size=1, lr_inner=0.01, first_order=True)

(0, 1.783060073852539, 0.0)
(1000, 0.9749329946637154, 0.10320930232558084)
(2000, 0.7846943070590496, 0.16751162790697627)
(3000, 0.6780034032464027, 0.22872093023255774)
(4000, 0.6204485424160957, 0.2227674418604649)
(5000, 0.5943890239596367, 0.23644186046511603)
(6000, 0.5718764623105526, 0.22309302325581323)
(7000, 0.565203436717391, 0.2372790697674412)
(8000, 0.5544433343410492, 0.2547674418604648)
(9000, 0.5509172289669514, 0.2450697674418601)


KeyboardInterrupt: 

In [71]:
reptile(CV_REPTILE, 10, train_set, batch_size=1, lr_inner=0.01, first_order=True)

(0, 0.47298333048820496, 0.3023255813953488)
(1000, 0.5191877368688583, 0.24927906976744155)
(2000, 0.5177044985890389, 0.2578372093023253)
(3000, 0.5197958633601666, 0.26423255813953417)
(4000, 0.515871185451746, 0.2692093023255812)
(5000, 0.5149109952151776, 0.2600465116279063)
(6000, 0.5185968863070011, 0.24397674418604628)
(7000, 0.5187369344234467, 0.25827906976744186)
(8000, 0.5143364772498608, 0.2686511627906968)
(9000, 0.5164067787826061, 0.24976744186046465)
(10000, 0.5237976936101914, 0.2374651162790695)
(11000, 0.5160904158949852, 0.25734883720930224)
(12000, 0.5153337240815162, 0.2696511627906975)
(13000, 0.5183304651677608, 0.25334883720930207)
(14000, 0.516270188510418, 0.25367441860465056)
(15000, 0.5138649631142617, 0.2741860465116277)
(16000, 0.5170898350477219, 0.25553488372092986)
(17000, 0.5166957820057869, 0.26272093023255777)
(18000, 0.5226676015555859, 0.2456744186046509)
(19000, 0.5099458540678025, 0.26951162790697625)
(0, 0.4973919689655304, 0.32558139534883723

KeyboardInterrupt: 

In [52]:
_, acc, new_model = fit_task(CV_REPTILE, train_set[0], 0.001)
print(acc)

0.232558139535


In [46]:
train_set[0]

([['VCCCV', '.V.CV.'],
  ['CVVC', '.CV.V.'],
  ['VCVCC', '.V.CV.'],
  ['VVVVV', '.V.V.V.V.V.'],
  ['VVCVC', '.V.V.CV.'],
  ['CVCCV', '.CV.CV.'],
  ['VVVVC', '.V.V.V.V.'],
  ['CCVVV', '.CV.V.V.'],
  ['V', '.V.'],
  ['VCVVV', '.V.CV.V.V.'],
  ['VCCC', '.V.'],
  ['VCCVC', '.V.CV.'],
  ['CCV', '.CV.'],
  ['CCVCC', '.CV.'],
  ['', ''],
  ['CVCC', '.CV.'],
  ['VC', '.V.'],
  ['CC', ''],
  ['CVCCC', '.CV.'],
  ['VCV', '.V.CV.']],
 [['CVC', '.CV.'],
  ['CVCVV', '.CV.CV.V.'],
  ['VCCVV', '.V.CV.V.'],
  ['VVC', '.V.V.'],
  ['CVCVC', '.CV.CV.'],
  ['VCVVC', '.V.CV.V.'],
  ['CCVVC', '.CV.V.'],
  ['CCCVC', '.CV.'],
  ['VVCCC', '.V.V.'],
  ['VVCVV', '.V.V.CV.V.'],
  ['CCCC', ''],
  ['CCC', ''],
  ['VVVCC', '.V.V.V.'],
  ['CVVCC', '.CV.V.'],
  ['CCCV', '.CV.'],
  ['VCCCC', '.V.'],
  ['CVVCV', '.CV.V.CV.'],
  ['CCVC', '.CV.'],
  ['CVCV', '.CV.CV.'],
  ['VV', '.V.V.'],
  ['CCCVV', '.CV.V.'],
  ['CVVVC', '.CV.V.V.'],
  ['VCVCV', '.V.CV.CV.'],
  ['CCCCV', '.CV.'],
  ['VVCV', '.V.V.CV.'],
  ['CCVV', '.CV.

In [55]:
process_output(new_model([''])[0][0])

''

In [None]:
print(torch.__version__)

In [8]:
CV_REPTILEB = EncoderDecoder(35,35,35,128)

In [12]:
maml(CV_REPTILEB, 10, train_set, batch_size=1, lr_inner=0.01)

(0, 3.5397775173187256, 0.0)
(1000, 1.9945365778803825, 0.01503999999999996)


KeyboardInterrupt: 

In [None]:
CV_REPTILEB.state_dict()

In [13]:
torch.save(CV_REPTILEB.state_dict(), "CV_REPTILE.weights")

In [14]:
loaded_model = EncoderDecoder(35,35,35,128)
loaded_model.load_state_dict(torch.load("CV_REPTILE.weights"))

In [15]:
_, acc, new_model = fit_task(loaded_model, train_set[8], 0.001)
print(acc)

0.02


In [None]:
train_set[8]

In [None]:
_, acc, new_model = fit_task(CV_REPTILEB, train_set[8], 0.001)
print(acc)

In [None]:
process_output(new_model(['p'])[0][0])

In [None]:


def fit_task(model, task, lr_inner):
    training_set = task[0]
    test_set = task[1]
    v_list = task[2]
    c_list = task[3]
    
    new_model = EncoderDecoder(model.enc_vocab_size,model.dec_vocab_size,model.input_size,model.hidden_size)
    new_model.copy(model, same_var=True)
    
    new_model.set_dict(v_list, c_list)
    
    loss = 0
    criterion = nn.NLLLoss(ignore_index=0, size_average=False)

    #for pair in training_set:
    #    inp, outp = pair
        #output, logits = new_model(inp)
    #    output, logits = new_model([pair[0] for pair in training_set])

    #    for index, logit in enumerate(logits):
    #        if index == len(pair[1]):
    #            loss += criterion(logit[0], torch.LongTensor([new_model.char2ind["EOS"]]))
    #        elif index > len(pair[1]):
    #            break
    #        else:
    #            loss += criterion(logit[0], torch.LongTensor([new_model.char2ind[pair[1][index]]]))
                
    #print(new_model.char2ind)
    #print(new_model.enc_vocab_size)
    output, logits = new_model([pair[0] for pair in training_set])
    
    all_seqs = [] # HERE
    for inp, sequence in training_set:
        this_seq = []
        # Iterate over the sequencex
        for elt in sequence:
            ind = new_model.char2ind[elt]
            this_seq.append(ind)
        this_seq.append(new_model.char2ind["EOS"])
        all_seqs.append(torch.LongTensor(this_seq))
            
    all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)
    print(all_seqs)

    for index, logit in enumerate(logits):
        if index >= len(all_seqs):
            break
            
        loss += criterion(logit[0], all_seqs[index])
        
        #if index == len(pair[1]):
        #    loss += criterion(logit[0], torch.LongTensor([new_model.char2ind["EOS"]]))
        #elif index > len(pair[1]):
        #    break
        #else:
        #    loss += criterion(logit[0], torch.LongTensor([new_model.char2ind[pair[1][index]]]))

            
    total_values = sum([len(x[1]) + 1 for x in training_set])  
    loss /= total_values
    loss.backward(create_graph=True, retain_graph=True)
    
    for name, param in new_model.named_params():
        grad = param.grad
                
        new_model.set_param(name, param - lr_inner * grad)
    
    
    
    
    test_loss = 0  
    correct = 0
    total = 0
    
    all_seqs = [] # HERE
    for inp, sequence in test_set:
        this_seq = []
        # Iterate over the sequencex
        for elt in sequence:
            ind = new_model.char2ind[elt]
            this_seq.append(ind)
        this_seq.append(new_model.char2ind["EOS"])
        all_seqs.append(torch.LongTensor(this_seq))
            
    all_seqs = torch.nn.utils.rnn.pad_sequence(all_seqs)
    print(all_seqs)

    output, logits = new_model([pair[0] for pair in test_set])
    for index, output_guess in enumerate(output):
        if output_guess == test_set[index][1]:
            correct += 1
        total += 1
        
    for index, logit in enumerate(logits):
        if index >= len(all_seqs):
            break
            
        test_loss += criterion(logit[0], all_seqs[index])

    total_values = sum([len(x[1]) + 1 for x in test_set])  
    test_loss /= total_values
    #for pair in test_set:
        #inp, outp = pair
        #output, logits = new_model(inp)
        
        #if output == outp:
        #    correct += 1
        #total += 1

        #for index, logit in enumerate(logits):
        #    if index == len(pair[1]):
        #        test_loss += criterion(logit[0], torch.LongTensor([new_model.char2ind["EOS"]]))
        #    elif index > len(pair[1]):
        #        break
        #    else:
        #        test_loss += criterion(logit[0], torch.LongTensor([new_model.char2ind[pair[1][index]]]))
                
        #test_loss /= len(logits)
        
    #for index, logit in enumerate(logits):
    #    if index >= len(all_seqs):
    #        break
            
   #     test_loss += criterion(logit[0], all_seqs[index])

    #test_loss /= len(test_set)
    test_acc = correct * 1.0 / total
                
    # I think we save this step for the end?
    #test_loss.backward(create_graph=True, retain_graph=True)
    
        
    #return test_loss.data.cpu().numpy() ###
    return test_loss, test_acc

