In [1]:
from collections import defaultdict
from collections import namedtuple
import time
import random
import json
import string
from nltk.stem.porter import *
from nltk.stem.snowball import SnowballStemmer
import h5py
import numpy as np
from nltk import word_tokenize
from nltk.corpus import stopwords
from heapq import nlargest
import dill
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import PIL
import os
import matplotlib.pyplot as plt


In [3]:
CUDA = torch.cuda.is_available()

#unfotianlly SEQ2SEQ does not work on CUDA
CUDA = False

In [23]:
IMG_FEATURES = "./images/IR_image_features.h5"
IMG_ID = "./images/IR_img_features2id.json"

TRAIN_HARD = "./Data/Hard/IR_train_hard.json"
TRAIN_EASY = "./Data/Easy/IR_train_easy.json"
TEST_HARD = "./Data/Hard/IR_test_hard.json"
TEST_EASY = "./Data/Easy/IR_test_easy.json"
VAL_HARD = "./Data/Hard/IR_val_hard.json"
VAL_EASY = "./Data/Easy/IR_val_easy.json"

IMGID2IMGINFO = "./Data/imgid2imginfo.json"

BATCH_SIZE = 100
LEARNING_RATE = 1e-3
EPOCHS = 30
CAPTION_DICT = {}
CURRENT_MODEL_NAME = "SEQ2SEQ_HARD"

In [5]:
#GET features from images
with open("./images/IR_img_features2id.json", 'r') as f:
    visual_feat_mapping = json.load(f)['IR_imgid2id']
f.close()

img_features = np.asarray( h5py.File("./images/IR_image_features.h5", 'r')['img_features'])

def get_feature_from_id(img_id):
    h5_id = visual_feat_mapping[str(img_id)]
    return img_features[h5_id]

In [37]:
""" Helper function to show an image, from a image id """
#source: https://stackoverflow.com/questions/34255938/is-there-a-way-to-specify-the-width-of-a-rectangle-in-pil
def draw_rectangle(draw, coordinates, color, width=100):
    for i in range(width):
        rect_start = (coordinates[0][0] - i, coordinates[0][1] - i)
        rect_end = (coordinates[1][0] + i, coordinates[1][1] + i)
        draw.rectangle((rect_start, rect_end), outline = color)

def show_img_from_id(img_ids, target_id = -1): 
    try:
        with open(IMGID2IMGINFO, 'r') as f:
            imgid2info = json.load(f)

        response = requests.get(imgid2info[str(img_ids[0])]['coco_url'])
        img = Image.open(BytesIO(response.content))
        imgs = [img, img]
        width, height = img.size
        draw = ImageDraw.Draw(img)
        line_width = 30
        draw.rectangle(((0, 0), (width*2, height*2)), fill="white")
        for img_id in img_ids:
            response = requests.get(imgid2info[str(img_id)]['coco_url'])
            img = Image.open(BytesIO(response.content))
            if(img_id == target_id):
                width, height = img.size
                draw = ImageDraw.Draw(img)
                line_width = 30
                draw_rectangle(draw, coordinates=((line_width, line_width), (width - line_width, height - line_width)), color="green", width=line_width)
            imgs.append(img)

        f.close()

        min_shape = sorted( [(np.sum(i.size), i.size ) for i in imgs])[-1][1]
        imgs_comb = np.hstack( (np.asarray( i.resize(min_shape) ) for i in imgs ) )

        # save that beautiful picture
        imgs_comb = PIL.Image.fromarray( imgs_comb)
        width, height = imgs_comb.size
        draw = ImageDraw.Draw(imgs_comb)
        if(target_id != -1):
            # fnt = ImageFont.truetype("/Library/Fonts/Comic Sans MS.ttf", 30) 
            fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 30)
            draw.multiline_text((10, 10), CAPTION_DICT[target_id].replace("?", "?") ,(0,0,0), font=fnt)

        files = "["+str(target_id)+"]_"
        files += '_'.join(map(str, img_ids))
        if(img_ids[0] == target_id):
            file_name = './results/' + CURRENT_MODEL_NAME + '/correct/top1/'
        elif(img_ids[1] == target_id or img_ids[2] == target_id or img_ids[3] == target_id or img_ids[4] == target_id):
            file_name = './results/' + CURRENT_MODEL_NAME + '/correct/top5/'  
        else:
            file_name = './results/' + CURRENT_MODEL_NAME + '/wrong/' 
        imgs_comb.save( file_name + files + '.jpg' )
    except:
        import traceback
        print( traceback.format_exc())
        pass
    
show_img_from_id([378466, 378466, 378466, 378466], 378466)

In [7]:
"""
Preprocess a sentence
"""
def preprocess(sentence, stemmer, stop):
    low_sent = sentence.lower()
    # Possibility to tokenize entire dataset and put in counter to filter out
    # infrequent/frequent words
    tok_sent = word_tokenize(low_sent)
    stop_stem_sent = [stemmer.stem(i) for i in tok_sent if i not in stop]
    return stop_stem_sent


In [8]:
"""
Convert Samples
"""

def get_dialog_caption_targets_from_sample(sample):
    dialog = ''
    caption = sample['caption']
    targets = []
    targetidx = sample['target']
   
    dialog_text = ''
    for d in sample['dialog']:
        dialog += ' ' + d[0]
        dialog_text += ' ' + d[0] + '\n'

    for img in sample['img_list']:
        targets.append(img)
    
    CAPTION_DICT[targets[targetidx]] = dialog_text + "  \n " + caption
    return dialog, caption, targets, targetidx

Sample = namedtuple("Sample", ["words", "images", "target"])

"""
 For every Sample we retrieve the sentences and the img_ids, and the correct target_id. 
"""
def read_dataset(filename, stemmer, stopwords):
    with open(filename, "r") as f:
        dataset = json.load(f)
    f.close()
    for idx, sample in enumerate(dataset):
        if(idx % 10000 == 0):
            print(idx)
        dialog, caption, targets, targetidx = get_dialog_caption_targets_from_sample(dataset[
                                                                                     str(sample)])
        sentences = preprocess(dialog + ' ' + caption, stemmer, stopwords)
        yield Sample(words=[w2i[x] for x in sentences], images=targets, target=targetidx)



In [9]:
w2i = defaultdict(lambda: len(w2i))
UNK = w2i["<UNK>"]
PAD = w2i["<PAD>"]

# Do this super one time 
import nltk
# nltk.download("english")
stemmer = SnowballStemmer("english")
stop = stopwords.words('english') + list(string.punctuation)


In [27]:
#read the datasets and use w2i (only do this once)
train = list(read_dataset(TRAIN_HARD, stemmer, stop))
w2i = defaultdict(lambda: UNK, w2i)
val = list(read_dataset(VAL_HARD, stemmer, stop))
test = list(read_dataset(TEST_HARD, stemmer, stop))

0
10000
20000
30000
0
0


In [13]:
nwords = len(w2i)

In [14]:
"""
CLASS LSTM 
"""
class Seq2SeqNN(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_dim_mlp, output_dim, batch_size):
        super(Seq2SeqNN,self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size,embedding_size,padding_idx=PAD)
        self.batch_size = batch_size
        self.linear1 = nn.Linear(embedding_size,hidden_dim_mlp)
        self.linear2 = nn.Linear(hidden_dim_mlp,output_dim)
        
    def forward(self,sentence):
        embeds = self.word_embeddings(sentence)
        x = torch.sum(embeds,1)
        lin1 = F.sigmoid(self.linear1(x))
        lin2 = self.linear2(lin1)
        return lin2

class ClassificationNN(nn.Module):
    def __init__(self, vocab_size, embedding_size, img_feat_dim, hidden_dim_mlp, output_dim, batch_size):
        super(ClassificationNN,self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD)
        self.batch_size = batch_size
        self.linear1 = nn.Linear(embedding_size+img_feat_dim,hidden_dim_mlp)
        self.linear2 = nn.Linear(hidden_dim_mlp,output_dim)

    def forward(self, sentence, image_feat):
        embeds = self.word_embeddings(sentence)
        x = torch.sum(embeds, 1)
        x = x.unsqueeze(1)
        x = x.repeat(1,10,1)
        lin1 = F.sigmoid(self.linear1(torch.cat((x, image_feat),2)))
        lin2 = self.linear2(lin1)
        return lin2

In [15]:
# INIT MODEL AND INIT OPTIMIZER
print("Batch_size: ", BATCH_SIZE, "LEARNING_RATE: ",LEARNING_RATE)
print()

#model = ClassificationNN(nwords, 50, 2048, 45, 1, BATCH_SIZE)
model = Seq2SeqNN(nwords, 50, 320, 2048, BATCH_SIZE)
if CUDA:
    model.cuda()
print(model)

#@TODO we can use a adaptive learnrate for adam
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Batch_size:  100 LEARNING_RATE:  0.001

Seq2SeqNN(
  (word_embeddings): Embedding(15456, 50, padding_idx=1)
  (linear1): Linear(in_features=50, out_features=320)
  (linear2): Linear(in_features=320, out_features=2048)
)


In [16]:
"""
HELPER FUNCTIONS
"""
def preprocessbatch(batch):
    """ Add zero-padding to a batch. """
    seqs = [sample.words for sample in batch]
    max_length = max(map(len, seqs))
    seqs = [seq + [PAD] * (max_length - len(seq)) for seq in seqs]

    ims = np.array([get_feature_from_id(sample.images[sample.target]) for sample in batch])

    idxs = [sample.target for sample in batch]
    
    image_ids = [sample.images for sample in batch]

    return seqs, ims, idxs, image_ids

def minibatch(data, batch_size=BATCH_SIZE):
    for i in range(0, len(data), batch_size):
        yield data[i:i+batch_size]

def getLongTensor(x):
    tensor = torch.cuda.LongTensor(x) if CUDA else torch.LongTensor(x)
    return Variable(tensor)


def getFloatTensor(x):
    tensor = torch.cuda.FloatTensor(torch.from_numpy(
        x).cuda()) if CUDA else torch.FloatTensor(x)
    return Variable(tensor)


In [17]:
def evaluate(model, data,show_n_wrong_images=0, show_n_good_images = 0):
    top1 = 0
    top5 = 0
    counter_good = 0
    counter_wrong = 0
    for batch in minibatch(data):
        seqs, ims, idxs, imids = preprocessbatch(batch)
        scores = model(getLongTensor(seqs))
        for i in range(len(batch)):
            tmpscore = scores[i].unsqueeze(0)
            tmpimids = imids[i]
            tmpidx = idxs[i]
            currentdist = 0
            bestid = - 1
            top5list = []
            top5indices = [x for x in range(10)]
            for j in range(10):
                comparevec = getFloatTensor(get_feature_from_id(tmpimids[j])).unsqueeze(0)
                
                cos = nn.CosineSimilarity(dim=1, eps=1e-6)
                
                dist = cos(tmpscore,comparevec).data[0]
                top5list.append(dist)
                if dist>currentdist:
                    currentdist = dist
                    bestid = j
            if(bestid==tmpidx):
                top1 +=1
                
            top5largestcos = nlargest(5, top5indices, key=lambda i: top5list[i])
            
            if(tmpidx in top5largestcos):
                top5 += 1
                if counter_good < show_n_good_images:
                    image_list = list(np.array(imids[i])[(top10_predictions[i].cpu().numpy() if CUDA else nlargest(10, top5indices, key=lambda i: top5list[i]))])
                    show_img_from_id(image_list, target_id = int(np.array(imids[i])[idxs[i]]))
                    counter_good += 1
            else:
                if counter_wrong < show_n_wrong_images:
                    image_list = list(np.array(imids[i])[(top10_predictions[i].cpu().numpy() if CUDA else nlargest(10, top5indices, key=lambda i: top5list[i]))])
                    print(image_list, int(np.array(imids[i])[idxs[i]]))
                    show_img_from_id(image_list, target_id = int(np.array(imids[i])[idxs[i]]))
                    counter_wrong += 1

    return top1/len(data), top5/len(data)

In [19]:
# top1, top5 = evaluate(model, val, show_n_wrong_images=2, show_n_good_images=2)

"""
    RUNNING THE MODEL!!!!!!!!!!
"""
top1_val_list, top5_val_list, top1_train_list, top5_train_list,train_loss_list = [],[],[],[],[]
try:
    for ITER in range(EPOCHS):
        # Init variable
        random.shuffle(train)
        train_loss = 0.0
        start = time.time()
        updates = 0
        
        for batch in minibatch(train):
            updates += 1
            
            # pad data with zeros
            seqs, ims, idxs, imids = preprocessbatch(batch)
            # forward pass

            scores = model(getLongTensor(seqs))

            targets = getFloatTensor(ims)

            loss = nn.MSELoss()
            output = loss(scores, targets)

            train_loss += output.data[0]
            
            # backward pass
            model.zero_grad()
            output.backward()
            
            # update weights
            optimizer.step()
            
            if(updates % 1000 == 0):
                print("update: {}, train_loss: {}, time {}".format(updates, train_loss/updates, time.time()-start))
        
        train_loss_list.append(train_loss/updates)
        print("iter %r: avg train loss=%.4f, time=%.2fs" % (ITER, train_loss/updates, time.time()-start))
        top1, top5 = evaluate(model, val, show_n_wrong_images=0, show_n_good_images=0)   
        top1_val_list.append(top1)
        top5_val_list.append(top5)
        
        print("VALIDATION: TOP 1: {}, TOP 5: {}".format(top1, top5))
        top1, top5 = evaluate(model, train, show_n_wrong_images=0, show_n_good_images=0)  
        top1_train_list.append(top1)
        top5_train_list.append(top5)
        
        print("TRAIN: TOP 1: {}, TOP 5: {} \n".format(top1, top5))

except KeyboardInterrupt:
    print('Stopped at ITER: ' + str(ITER))
    
create_learning_curves(top1_val_list, top5_val_list, top1_train_list, top5_train_list)
top1, top5 = evaluate(model, val, show_n_wrong_images=2, show_n_good_images=2)
print("TOP 1: {}, TOP 5: {}".format(top1, top5))

iter 0: avg train loss=0.1666, time=6.71s
VALIDATION: TOP 1: 0.2624, TOP 5: 0.789
TRAIN: TOP 1: 0.25055, TOP 5: 0.77 

iter 1: avg train loss=0.1397, time=6.92s
VALIDATION: TOP 1: 0.3274, TOP 5: 0.8432
TRAIN: TOP 1: 0.2936, TOP 5: 0.821975 

iter 2: avg train loss=0.1323, time=5.82s
VALIDATION: TOP 1: 0.3488, TOP 5: 0.8604
TRAIN: TOP 1: 0.31515, TOP 5: 0.8414 

iter 3: avg train loss=0.1286, time=6.05s
VALIDATION: TOP 1: 0.3642, TOP 5: 0.8726
Stopped at ITER: 3


NameError: name 'create_learning_curves' is not defined

In [38]:
top1, top5 = evaluate(model, test, show_n_wrong_images=1000, show_n_good_images=1000)
print("TOP 1: {}, TOP 5: {}".format(top1, top5))

[228463, 75027, 27424, 216223, 287849, 118925, 188532, 263576, 204978, 58595] 118925
[231572, 475693, 329097, 93657, 180923, 118920, 51571, 184822, 337087, 201213] 118920
[305287, 486240, 500130, 524257, 262756, 520343, 382441, 182602, 208915, 326011] 326011
[444651, 144519, 95892, 197875, 425943, 162084, 338529, 325963, 441599, 224248] 162084
[168748, 459821, 406317, 182905, 23050, 559209, 423383, 410052, 404283, 527868] 423383
[449904, 549317, 505980, 515710, 95079, 268130, 324200, 495793, 59582, 121602] 59582
[228461, 352444, 106419, 274526, 534791, 397216, 521739, 415533, 437353, 484110] 415533
[12930, 2093, 533451, 207231, 575711, 568171, 238685, 18401, 575323, 355904] 355904
[104502, 95133, 366641, 466742, 107776, 337321, 151156, 44039, 504325, 562008] 44039
[126388, 395920, 244387, 242074, 318462, 372902, 84291, 183051, 5073, 51871] 372902
[537648, 375590, 342325, 29579, 35962, 29575, 572358, 477563, 107501, 528970] 29575
[527863, 181503, 21531, 43388, 442727, 154435, 516957, 33

[539734, 119765, 254743, 75727, 175189, 135836, 417961, 291625, 24712, 364862] 135836
[547863, 568614, 201623, 97260, 514601, 83507, 110011, 452968, 181042, 5913] 452968
[328054, 372723, 575612, 452968, 178307, 493500, 479693, 83507, 244227, 308168] 493500
[545850, 137382, 191847, 109640, 122583, 545394, 51571, 482300, 286499, 75243] 545394
[317322, 529857, 508195, 270784, 260962, 9615, 356664, 165253, 529780, 443764] 356664
[308590, 1811, 217856, 202413, 486606, 417315, 181043, 88267, 341145, 401041] 88267
[485080, 982, 184202, 269260, 63668, 61201, 565311, 311228, 355760, 346189] 311228
[164168, 321328, 417887, 341963, 554654, 150902, 552093, 358457, 46242, 254204] 46242
[219250, 207976, 51046, 542718, 515796, 473912, 559203, 523173, 354843, 174188] 559203
[530719, 515796, 285563, 524317, 514559, 141387, 501284, 472729, 559203, 198214] 501284
[284333, 206889, 581797, 223459, 17226, 299050, 144084, 403534, 437354, 557771] 299050
[164208, 574443, 333201, 10245, 515803, 485938, 475567, 

[457086, 219474, 308590, 243839, 517140, 58595, 364937, 163114, 286490, 384998] 364937
[120818, 427449, 309267, 207978, 123568, 237833, 137564, 457922, 141862, 87507] 87507
[139551, 534681, 417298, 240358, 1815, 439889, 156500, 134211, 470738, 555265] 439889
[444997, 231720, 148267, 178484, 136541, 125115, 258248, 126073, 538875, 332867] 125115
[477497, 33793, 261648, 342184, 277788, 208338, 370391, 67496, 91543, 210471] 208338
[11029, 391162, 565081, 325682, 242076, 529963, 438447, 26028, 358451, 167827] 358451
[370417, 547186, 503401, 385805, 407820, 158873, 250417, 359595, 415475, 91544] 250417
[565360, 92098, 247126, 572745, 382341, 277064, 493983, 345020, 530945, 28802] 493983
[256364, 271999, 93000, 88549, 334372, 134137, 503407, 487952, 405361, 91547] 134137
[566600, 312712, 172158, 57641, 213070, 322707, 49378, 165253, 41103, 206749] 322707
[51680, 421677, 195991, 418260, 493507, 281855, 493210, 530127, 419474, 212405] 419474
[544561, 40144, 125116, 308590, 19737, 291625, 16874

[220772, 205367, 420922, 32664, 500896, 423776, 122780, 404208, 378024, 135576] 423776
[544263, 346437, 276332, 95033, 178034, 449108, 423770, 563155, 152789, 541768] 423770
[224759, 477497, 171099, 277788, 23355, 487824, 43388, 442727, 562989, 565367] 487824
[149314, 203380, 407522, 51720, 1811, 445990, 424529, 11680, 509161, 447224] 445990
[525678, 102546, 196099, 277064, 223550, 303036, 533542, 565878, 499788, 343792] 565878
[208919, 285512, 527243, 369397, 500212, 264823, 361866, 14025, 67420, 513424] 14025
[519713, 160327, 508251, 306552, 275902, 294635, 524788, 547192, 554075, 265444] 294635
[315045, 403534, 503407, 136217, 415539, 553233, 53458, 524788, 265974, 382681] 553233
[168999, 275276, 53178, 170428, 53999, 106023, 210679, 521914, 426420, 206830] 106023
[77760, 432851, 518180, 128670, 366560, 323028, 542458, 2591, 342665, 101675] 342665
[77842, 154209, 315036, 404013, 273336, 239244, 337321, 36425, 409338, 562008] 36425
[499028, 222929, 193853, 60599, 114776, 465414, 2302

[426975, 157049, 49378, 53451, 236335, 358176, 507721, 392556, 142198, 305585] 305585
[77842, 154209, 95133, 445028, 366641, 466742, 337321, 151156, 160106, 421406] 466742
[269260, 208919, 446422, 557137, 87504, 268583, 119428, 386683, 317177, 192301] 119428
[355462, 262851, 511075, 319676, 550019, 59817, 559754, 250268, 155739, 215149] 250268
[468321, 455565, 517144, 468393, 215107, 542799, 121556, 383220, 560180, 560180] 383220
[77308, 449760, 575584, 449004, 317322, 95990, 474363, 568405, 260962, 411824] 260962
[525678, 195998, 282567, 434357, 489856, 247826, 397217, 201768, 130122, 122453] 247826
[137564, 37646, 538281, 285563, 82759, 559184, 408067, 431707, 381694, 39460] 408067
[527868, 280923, 126755, 33869, 178177, 194210, 517026, 480996, 304187, 570856] 194210
[175033, 101091, 194551, 402362, 406954, 321059, 496300, 497226, 459826, 126751] 459826
[165410, 528782, 274224, 491932, 258853, 383923, 349791, 416943, 274715, 540082] 349791
[411191, 428414, 502534, 175236, 246328, 285

[503616, 244222, 221915, 448494, 20392, 493581, 426427, 447932, 168994, 356535] 493581
[411825, 21257, 392936, 479427, 263479, 398967, 47175, 524788, 527865, 122411] 524788
[27424, 426343, 492975, 98193, 118925, 188532, 563707, 328588, 228293, 473911] 563707
[241860, 334328, 357926, 210672, 431708, 26221, 271157, 577623, 577623, 354804] 354804
[378541, 111040, 141015, 84291, 442422, 414607, 42865, 51871, 537213, 222929] 414607
[98193, 244141, 486187, 517139, 381485, 46535, 118413, 260651, 444769, 319962] 46535
[575711, 331560, 429433, 41818, 279803, 13568, 419085, 61203, 314979, 491381] 13568
[452963, 415728, 77533, 221422, 354237, 300177, 401812, 565330, 246839, 251072] 565330
[446661, 415190, 336966, 565641, 555446, 293880, 330923, 211116, 319952, 246156] 293880
[408263, 100468, 83507, 427633, 496179, 244227, 267708, 258890, 577968, 27343] 577968
[217463, 453549, 261785, 414285, 342946, 235893, 120360, 553067, 9615, 259652] 235893
[125720, 141304, 87671, 465414, 478525, 180505, 33644

[427771, 135576, 406805, 268333, 463753, 527863, 575711, 568369, 242549, 332003] 242549
[312480, 96754, 31612, 496855, 559209, 240568, 220232, 14691, 346637, 260982] 220232
[487401, 177530, 97543, 382898, 560020, 511224, 321059, 19739, 417887, 148263] 511224
[468867, 580523, 80041, 248579, 263685, 423389, 218985, 100798, 156213, 331191] 423389
[119423, 551571, 442600, 188805, 280187, 200659, 183293, 205362, 189038, 215586] 205362
[392556, 229821, 327770, 101807, 451420, 488965, 574770, 570660, 127680, 495134] 488965
[51046, 219250, 137564, 38670, 114674, 156391, 524317, 419769, 523173, 499372] 524317
[276671, 381334, 320192, 270809, 407210, 509712, 564976, 105561, 184536, 516996] 509712
[356535, 298639, 64802, 189542, 516994, 185667, 471665, 310716, 503647, 27725] 185667
[111040, 369132, 338091, 57839, 277663, 481565, 338224, 51871, 40735, 482392] 481565
[149314, 203380, 225113, 51720, 357184, 168335, 182904, 147250, 41340, 239769] 41340
[260415, 546183, 239769, 538747, 139948, 41343, 

In [17]:

def create_learning_curves(top1_val_list, top5_val_list, top1_train_list, top5_train_list):
    fig = plt.figure()
    ax = plt.subplot(111)

    
    ax.plot(top5_val_list, label='top5 validation', c='darkred')
    ax.plot(top5_train_list, label='top5 train', c='royalblue')
    
    ax.plot(top1_val_list, linestyle=':', label='top1 validation', c='darkred')
    ax.plot(top1_train_list, linestyle=':', label='top1 train', c='royalblue')
    
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title('Learning curve of SEQ2SEQ HARD')
    plt.ylabel('Accuracy')
    plt.xlabel('Number of iterations')
    plt.savefig('./results/' + CURRENT_MODEL_NAME + '/accuracy1.png')
    plt.clf()
    plt.cla()
    plt.close()


In [18]:
create_learning_curves(top1_val_list, top5_val_list, top1_train_list, top5_train_list)

In [19]:
def write(x, name):
    import pickle
    file_name = './results/' + CURRENT_MODEL_NAME + '/' + name + '.npy'
    np.save(file_name, np.array(x))

In [20]:
write(top1_val_list, "top1_val")
write(top5_val_list, "top5_val")
write(top1_train_list, "top1_train")
write(top5_train_list, "top5_train")
write(train_loss_list, "train_loss")
write(top1, "top1_test")
write(top5, "top5_test")