In [1]:
from decoder import *
from encoder import *
from textEncoder import *
from genZero import *
from other_data_loader import *
import pickle
import random
import torch.optim as optim
from torch.autograd import Variable
import csv
import time
from tqdm import tqdm
import gc
import os
import torchvision.transforms as tf
import json
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu
from matplotlib import pyplot as plt


[nltk_data] Downloading package punkt to
[nltk_data]     /datasets/home/64/364/rhadden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [13]:


def validate(val_loader, encoder, decoder, embed, stack, maxSeqLen,
             vocab, batch_size, use_gpu = True):
    """Compute the mean per-batch validation loss of the text -> image -> text pipeline.

    Every batch is run through ``embed`` -> ``stack`` -> ``encoder`` -> ``decoder``
    with gradients disabled, and the loss is the sum of:

    * SmoothL1 between the generated images and the real images,
    * cross-entropy between the decoder outputs and the caption labels,
    * BCE-with-logits of ``discr_real`` against an all-ones target.

    Args:
        val_loader: iterable yielding ``(inputs, labels, lengths)`` batches.
        encoder: module applied to the generated images.
        decoder: module exposing ``resetHidden(batch)`` and called as
            ``decoder(labels, enc_out, lengths)``.
        embed: module exposing ``resetHidden(batch)``; ``embed(labels, lengths)``
            must return a 3-tuple whose last element is passed to ``stack``.
        stack: module called as ``stack(s, inputs)`` returning
            ``(generated_imgs, h0s, discr_gen, discr_real)``.
        maxSeqLen: unused here; kept so existing callers keep working.
        vocab: unused here; kept so existing callers keep working.
        batch_size: unused here; kept so existing callers keep working.
        use_gpu: when True batches are moved to ``cuda:0``, otherwise to the CPU.

    Returns:
        float: the sum of per-batch losses divided by the number of batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    # Evaluation mode for every module taking part in the forward pass.
    decoder.eval()
    encoder.eval()
    embed.eval()
    stack.eval()

    # Criteria
    discCrit = nn.BCEWithLogitsLoss()
    textCrit = nn.CrossEntropyLoss()
    distCrit = nn.SmoothL1Loss()

    # A device is always defined so the loop below has a single code path
    # (the previous version called .cuda() unconditionally).
    device = torch.device("cuda:0" if use_gpu else "cpu")

    count = 0
    loss_sum = 0.0

    with torch.no_grad():
        for inputs, labels, lengths in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # STEM
            embed.resetHidden(labels.shape[0])
            pred_out, w, s = embed(labels, lengths)

            # GLAM -- use the module handed in as `stack`, not a notebook global.
            generated_imgs, h0s, discr_gen, discr_real = stack(s, inputs)

            # STREAM
            enc_out = encoder(generated_imgs)

            decoder.resetHidden(inputs.shape[0])
            outputs = decoder(labels, enc_out, lengths)

            loss = (
                distCrit(generated_imgs, inputs)
                + textCrit(outputs, labels)
                # ones_like keeps the target on the same device/dtype as the logits.
                + discCrit(discr_real, torch.ones_like(discr_real))
            )

            loss_sum += loss.item()
            count += 1

            # Release the largest per-batch tensors before the next iteration.
            del outputs, generated_imgs

    if count == 0:
        raise ValueError("validate() received an empty val_loader; cannot average the loss")

    return loss_sum / count

In [14]:
def _tokens_to_caption(vocab, token_ids):
    """Join the words for ``token_ids`` into one string, skipping '<pad>' tokens."""
    words = (vocab.idx2word[token.item()] for token in token_ids)
    # Compare by value: `is not '<pad>'` tests object identity, which is not
    # guaranteed for equal strings.
    return " ".join(word for word in words if word != '<pad>')


def trainStack(encoder, decoder, embed, stackZero, epochs, train_loader,val_loader, test_loader,
                        name, batch_size, maxSeqLen, vocab,save_generated_imgs= False):
    """Train ``stackZero`` inside the embed -> stackZero -> encoder -> decoder pipeline.

    Only ``stackZero.parameters()`` are handed to the Adam optimizer. Each epoch:
    every training batch is pushed through the pipeline and one of two losses is
    back-propagated; the last batch is then previewed (real vs. generated image,
    reference vs. generated caption); ``validate`` is run on ``val_loader``;
    weights and log files are written; and the early-stopping rule is evaluated.

    Loss schedule (as implemented):
        * epoch 0: even iterations use the reconstruction objective
          (SmoothL1 + cross-entropy - BCE(discr_gen, 0)); odd iterations use the
          discrimination objective (BCE(discr_gen, 0) + BCE(discr_real, 1)).
        * epoch > 0: epochs divisible by 5 use the reconstruction objective, all
          other epochs use the discrimination objective.

    Files written:
        * ``./logs/<name>[N].log`` -- per-epoch training/validation loss.
        * ``./logs/<name>_summary<N>.log`` -- full loss history, rewritten each epoch.
        * ``weights/stack0_<name>_epoch<E>`` and ``weights/stack0_<name>_best``.
        * ``./generated_imgs/...`` -- only when ``save_generated_imgs`` is True.

    Args:
        encoder: module applied to the generated images.
        decoder: module exposing ``resetHidden`` and ``generate_caption``.
        embed: module exposing ``resetHidden``; ``embed(labels, lengths)`` returns a
            3-tuple whose last element is passed to ``stackZero``.
        stackZero: the module being optimised; called as ``stackZero(s, inputs)`` and
            returning ``(generated_imgs, h0s, discr_gen, discr_real)``.
        epochs: maximum number of epochs.
        train_loader: iterable of ``(inputs, labels, lengths)`` training batches.
        val_loader: iterable of validation batches, forwarded to ``validate``.
        test_loader: unused here; kept so existing callers keep working.
        name: run name used to build log and weight file names.
        batch_size: forwarded to ``validate``.
        maxSeqLen: maximum caption length passed to ``decoder.generate_caption``.
        vocab: object with an ``idx2word`` mapping from token index to word.
        save_generated_imgs: when True, the previewed generated images and their
            generated captions are also written to ``./generated_imgs``.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    # Make sure every output directory exists before anything is written.
    os.makedirs('./logs', exist_ok=True)
    os.makedirs('./generated_imgs', exist_ok=True)
    os.makedirs('weights', exist_ok=True)

    # Pick a log file name that does not exist yet.
    logname = './logs/' + name + '.log'
    i = 0
    if os.path.exists(logname):
        logname = './logs/' + name + str(i) + '.log'
        while os.path.exists(logname):
            i += 1
            logname = './logs/' + name + str(i) + '.log'

    print('Loading results to logfile: ' + logname)
    with open(logname, "w") as log_file:
        log_file.write("Log file DATA: Validation Loss and Accuracy\n")

    # NOTE(review): the first run (<name>.log) and the second run (<name>0.log) both
    # map to suffix 0 here, so their summary / caption files collide. File names are
    # left unchanged for compatibility -- confirm whether that is acceptable.
    logname_summary = './logs/' + name + '_summary' + str(i) + '.log'
    print('Loading Summary to : ' + logname_summary)

    generated_imgs_filename = './generated_imgs/generated_imgs' + name + '_summary' + str(i) + '.log'

    parameters = list(stackZero.parameters())
    optimizer = optim.Adam(parameters, lr=5e-5)

    discCrit = nn.BCEWithLogitsLoss()
    textCrit = nn.CrossEntropyLoss()
    distCrit = nn.SmoothL1Loss()

    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")
    for module in (encoder, decoder, embed, stackZero):
        module.to(device)

    temperature = 1

    val_loss_set = []
    training_loss = []

    # Early-stop criteria
    minLoss = float("inf")
    minLossIdx = 0
    earliestStopEpoch = 7
    earlyStopDelta = 3

    for epoch in range(epochs):
        ts = time.time()

        # validate() leaves the modules in eval mode, so switch back every epoch.
        stackZero.train()
        encoder.train()
        decoder.train()
        embed.train()

        loss = None
        step = -1
        for step, (inputs, labels, lengths) in enumerate(tqdm(train_loader)):

            optimizer.zero_grad()

            inputs = inputs.to(device)
            labels = labels.to(device)

            # STEM
            embed.resetHidden(labels.shape[0])
            pred_out, w, s = embed(labels, lengths)

            # GLAM
            generated_imgs, h0s, discr_gen, discr_real = stackZero(s, inputs)

            # STREAM
            enc_out = encoder(generated_imgs)

            decoder.resetHidden(inputs.shape[0])
            outputs = decoder(labels, enc_out, lengths)

            # Targets are created on the same device/dtype as the logits.
            fake_target = torch.zeros_like(discr_gen)
            real_target = torch.ones_like(discr_real)

            # Epoch 0 alternates the two objectives every iteration; afterwards the
            # reconstruction objective is used only on epochs divisible by 5.
            # NOTE(review): the original inline comments labelled these branches
            # "discriminator"/"generator" inconsistently and described the later
            # schedule as "discriminator every 5 epochs"; the executed schedule is
            # preserved here -- confirm it matches the intended one.
            if epoch == 0:
                use_reconstruction = (step % 2 == 0)
            else:
                use_reconstruction = (epoch % 5 == 0)

            if use_reconstruction:
                # NOTE(review): subtracting the BCE term makes this loss unbounded
                # below (negative values are expected) -- confirm this is intended.
                loss = (
                    distCrit(generated_imgs, inputs)
                    + textCrit(outputs, labels)
                    - discCrit(discr_gen, fake_target)
                )
            else:
                loss = (
                    discCrit(discr_gen, fake_target)
                    + discCrit(discr_real, real_target)
                )

            loss.backward()
            optimizer.step()

            if step % 200 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, step, loss.item()))

        if loss is None:
            raise ValueError("trainStack() received an empty train_loader")

        # Keep a plain float: storing the tensor would keep its autograd graph alive.
        train_loss_value = loss.item()

        print("epoch{}, iter{}, loss: {}, epoch duration: {}".format(
            epoch, step, train_loss_value, time.time() - ts))

        # Preview the last training batch of the epoch; no gradients are needed.
        with torch.no_grad():
            test_pred = decoder.generate_caption(enc_out, maxSeqLen, temperature).cpu()

        for b in range(inputs.shape[0]):
            gen_caption = _tokens_to_caption(vocab, test_pred[b])
            raw_caption = _tokens_to_caption(vocab, labels[b])
            gen_img = tf.ToPILImage()(generated_imgs[b, :, :, :].detach().cpu())
            raw_img = tf.ToPILImage()(inputs[b, :, :, :].cpu())

            plt.figure(figsize=(14,8))
            plt.subplot(1,2,1)
            plt.imshow(raw_img)
            plt.subplot(1,2,2)
            plt.imshow(gen_img)
            plt.show()
            print("Base Caption: " + raw_caption)
            print("Generated Caption: " + gen_caption)

            if save_generated_imgs:
                # The same tag is used for the PNG name and the log line so the two
                # can be matched up afterwards.
                img_tag = "train_epoch" + str(epoch) + "im_" + str(b)
                gen_img.save("./generated_imgs/" + img_tag + ".png", "PNG")
                with open(generated_imgs_filename, "a") as caption_log:
                    caption_log.write("writing! " + img_tag + "\n")
                    caption_log.write("Caption: " + gen_caption + "\n \n")

        # Drop references to the last batch before running validation.
        del labels
        del outputs

        # Validation loss for this epoch.
        val_loss = validate(val_loader, encoder, decoder, embed, stackZero, maxSeqLen,
                             vocab, batch_size, use_gpu)
        val_loss_set.append(val_loss)

        print("epoch{}, iter{}, val loss: {}, epoch duration: {}".format(
            epoch, step, val_loss, time.time() - ts))

        training_loss.append(train_loss_value)

        torch.save(stackZero, 'weights/stack0_{}_epoch{}'.format(name, epoch))

        with open(logname, "a") as log_file:
            log_file.write("writing!\n")
            log_file.write("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
            log_file.write("\n training Loss:   " + str(train_loss_value))
            # Trailing newline keeps consecutive epoch entries on separate lines.
            log_file.write("\n Validation Loss: " + str(val_loss_set[-1]) + "\n")

        # Written before the early-stop check so the final epoch is always recorded.
        with open(logname_summary, "w") as summary_file:
            summary_file.write("Summary!\n")
            summary_file.write("\n training Loss:   " + str(training_loss))
            summary_file.write("\n Validation Loss : " + str(val_loss_set))

        # Early stopping
        if val_loss < minLoss:
            # Store new best
            torch.save(stackZero, 'weights/stack0_{}_best'.format(name))
            minLoss = val_loss
            minLossIdx = epoch

        # Past the minimum epoch count and no new best for more than delta epochs.
        elif epoch > earliestStopEpoch and (epoch - minLossIdx) > earlyStopDelta:
            print("Stopping early at {}".format(minLossIdx))
            break



In [None]:
if __name__=='__main__':

    def _read_ids(csv_path):
        """Return every line of ``csv_path`` longer than one character, newline stripped."""
        with open(csv_path, 'r') as f:
            return [line.strip("\n") for line in f if len(line) > 1]

    # ---- Image ids -------------------------------------------------------
    trainIds = _read_ids('trainvalIds.csv')
    testIds = _read_ids('testIds.csv')

    print("found {} train ids".format(len(trainIds)))
    print("found {} test ids".format(len(testIds)))

    # Shuffle in case the csv is ordered, then hold out the first fifth of the
    # shuffled ids as the validation split.
    random.shuffle(trainIds)
    splitIdx = len(trainIds) // 5
    valIds, trainIds = trainIds[:splitIdx], trainIds[splitIdx:]

    # ---- Data locations --------------------------------------------------
    trainValRoot = "./data/realImages/"
    testRoot = "./data/realImages/"

    trainValCaps = "./data/captions/trainvalCaps.csv"
    testCaps = "./data/captions/testCaps.csv"

    # NOTE: pickle.load executes code embedded in the file; only load a
    # vocabulary file that comes from a trusted source.
    with open('./data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    # ---- Loaders ---------------------------------------------------------
    img_side_length = 64
    transform = tf.Compose([
        tf.Resize(img_side_length),
        tf.CenterCrop(img_side_length),
        tf.ToTensor(),
    ])
    batch_size = 20
    shuffle = True
    num_workers = 20

    # The three loaders share every option except root / captions / ids.
    loader_options = dict(transform=transform, batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers)
    trainDl = get_loader(trainValRoot, trainValCaps, trainIds, vocab, **loader_options)
    valDl = get_loader(trainValRoot, trainValCaps, valIds, vocab, **loader_options)
    testDl = get_loader(testRoot, testCaps, testIds, vocab, **loader_options)

    # ---- Models ----------------------------------------------------------
    # encoded_feature_dim, hidden_dim and depth are assigned but not referenced
    # anywhere else in this cell.
    encoded_feature_dim = 800
    maxSeqLen = 49
    hidden_dim = 1500
    depth = 1

    embed = torch.load('./weights/bs{}_embed_best'.format(batch_size))

    encoder = torch.load('./weights/lstm{}encoder_best'.format(img_side_length))
    decoder = torch.load('./weights/lstm{}decoder_best'.format(img_side_length))

    # Freeze the pretrained modules: none of their parameters receive gradients.
    for pretrained in (embed, encoder, decoder):
        for param in pretrained.parameters():
            param.requires_grad = False

    stackZero = BaseGenerator(batch_size, embed.hidden_dim)

    # ---- Train -----------------------------------------------------------
    epochs = 100
    trainStack(encoder, decoder, embed, stackZero, epochs,
               trainDl, valDl, testDl, "stackZero",
               batch_size, maxSeqLen, vocab, save_generated_imgs=False)

found 7323 train ids
found 866 test ids
# ids: 58590
# ids: 14640
# ids: 8660
Loading results to logfile: ./logs/stackZero.log
Loading Summary to : ./logs/stackZero_summary0.log



0it [00:00, ?it/s][A
1it [00:01,  1.34s/it][A

epoch0, iter0, loss: -0.7631231546401978



2it [00:01,  1.03s/it][A
3it [00:01,  1.27it/s][A
5it [00:02,  1.73it/s][A
6it [00:02,  2.27it/s][A
7it [00:02,  2.93it/s][A
8it [00:02,  3.70it/s][A
9it [00:02,  4.34it/s][A

In [None]:
# Leftover post-mortem debugger magic from an interactive session; it opens the
# debugger on the most recent exception and should be removed before sharing.
%debug

In [None]:
# Scratch cell: reload the pickled vocabulary and display its `idx` attribute.
# pickle.load runs code embedded in the file -- only use a trusted vocab.pkl.
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

vocab.idx

In [None]:
import os

import torch
import torch.optim as optim

# Scratch cell: load a saved embedding module, build an Adam optimizer over its
# parameters, and show the help text for `optimizer.step`.
embed = torch.load('./weights/base_embed_best')
embed.batch_size  # bare attribute access; not displayed because it is not the last expression
optimizer = optim.Adam(list(embed.parameters()), lr=5e-5)

help(optimizer.step)
