In [1]:
from decoder import *
from encoder import *
from other_data_loader import *
import pickle
import random
import torch.optim as optim
from torch.autograd import Variable
import csv
import time
from tqdm import tqdm
import gc
import os
import torchvision.transforms as tf
import json
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu
from matplotlib import pyplot as plt


[nltk_data] Downloading package punkt to
[nltk_data]     /datasets/home/64/364/rhadden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:


def validate_test(val_loader, encoder, decoder, criterion, maxSeqLen,
             vocab, batch_size, use_gpu = True, calculate_bleu = True):

    
    #Evaluation Mode
    decoder.eval()
    encoder.eval()

    
    references = list()
    hypotheses = list() 
   
    if use_gpu:
        device = torch.device("cuda:0")
        
        
    with torch.no_grad():
        
        count    = 0
        loss_avg = 0
        bleu1_avg = 0
        bleu4_avg = 0
                
        for i, (inputs, caps, allcaps) in enumerate(val_loader):
            
            
            
            # Move to device, if available
            if use_gpu:
                inputs = inputs.to(device)
                caps = caps.to(device)

                        
            enc_out = encoder(inputs)
            actual_lengths = allcaps
            
            
            
            temperature = 1
            test_pred = decoder.generate_caption(enc_out, maxSeqLen, temperature)

            test_pred_sample = test_pred[0].cpu().numpy()          

            

            
            #Build a list of the predicted sentences
            # Convert word_ids to words
            sampled_caption = []

            for word_id in test_pred_sample:
                word = vocab.idx2word[word_id]
                sampled_caption.append(word)
                if word == '<end>':
                    break
            sentence = ' '.join(sampled_caption)
            hypotheses.append(sampled_caption)                
            #if i % 750 ==0:
                #print ('generated sentence: ',sentence)            
                #print(type(sampled_caption))        
                #print(sampled_caption)        
                #print('len(generated_sentence): ',len(sampled_caption))

                
            #targets = pack_padded_sequence(labels, lengths, batch_first=True)[0]
            
            
            decoder.resetHidden(inputs.shape[0])
            outputs = decoder(caps, enc_out, actual_lengths)
#             if i % 1000 == 0:
#                 print('VAL: outputs shape: ', outputs.size())
#            new_outputs = torch.zeros(inputs.shape[0], maxSeqLen, vocab.idx)
#            new_outputs[:inputs.shape[0],:maxSeqLen, 0] = torch.ones((inputs.shape[0], maxSeqLen))
#             for dim in range(maxSeqLen):
#                 for b in range(inputs.shape[0]):
#                     new_outputs[b, dim, 0] = 1.0
#            new_outputs[:, :(outputs.shape[1]), :] = outputs
#            new_outputs = new_outputs.permute(0, 2, 1).to(device)
            
            #del inputs

            
            loss = criterion(outputs, Variable(caps.long()))
            loss_avg += loss
            count+=1
            
            #del outputs            
            
            #print('VAL: loss: ', loss)


            caps_array = caps.cpu().numpy()  
            # Convert word_ids to words
            reference_caption = []
            sampled_caption = []
            
            for word_id in caps_array[0]:
                word = vocab.idx2word[word_id]
                reference_caption.append(word)
                if word == '<end>':
                    break
            ref_sentence = ' '.join(reference_caption)
            #if i % 500 == 0:
                #print('ref_sentence: ', ref_sentence)
                #print('len(ref_sentence): ',len(reference_caption))
            references.append(reference_caption)   
        

        
        
            #print('len(references)', len(references))
            #print('len(hypotheses)', len(hypotheses))
            #print('references: ', references)
            #print('hypotheses: ', hypotheses)
        

            # Calculate BLEU-4 scores
            if calculate_bleu:
                bleu4 = corpus_bleu(references, hypotheses)                
                bleu1 = corpus_bleu(references, hypotheses,weights=(1.0, 0, 0, 0))
                #print('bleu4: ', bleu4)        
                #print('bleu1: ', bleu1)  
                bleu4_avg+=bleu4
                bleu1_avg+=bleu1
            
            
            
            del caps
            del outputs            
            
            
            #if i % 10 == 0:
            #    break
                
        loss_avg  = loss_avg/count
        print('VAL: loss_avg: ', loss_avg)

        if calculate_bleu:
            
            bleu4_avg = bleu4_avg/count
            bleu1_avg = bleu1_avg/count 
            
            print('VAL: bleu4_avg: ', bleu4_avg)
            print('VAL: bleu1_avg: ', bleu1_avg)
        
        
        
            
    return loss_avg, bleu1_avg, bleu4_avg

In [3]:
def trainEncoderDecoder(encoder, decoder, criterion, epochs,
                        train_loader,val_loader, test_loader,
                        name, batch_size, maxSeqLen, vocab,save_generated_imgs= False):
    
    #Create non-existing logfiles
    logname = './logs/' + name + '.log'
    i = 0
    if os.path.exists(logname) == True:
        
        logname = './logs/' + name + str(i) + '.log'
        while os.path.exists(logname):
            i+=1
            logname = './logs/' + name + str(i) + '.log'

    print('Loading results to logfile: ' + logname)
    with open(logname, "w") as file:
        file.write("Log file DATA: Validation Loss and Accuracy\n") 
    
    logname_summary = './logs/' + name + '_summary' + str(i) + '.log'    
    print('Loading Summary to : ' + logname_summary) 
    
    
    try:
        os.mkdir('./generated_imgs')
    except:
        pass
    
    generated_imgs_filename = './generated_imgs/generated_imgs' + name + '_summary' + str(i) + '.log'
    
    
    parameters = list(encoder.fc.parameters())
    parameters.extend(list(decoder.parameters()))
    optimizer = optim.Adam(parameters, lr=5e-5)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        device = torch.device("cuda:0")
#         encoder = torch.nn.DataParallel(encoder)
#         decoder = torch.nn.DataParallel(decoder)
        
        encoder.to(device)
        decoder.to(device)
        
        
    
    val_loss_set = []
    val_bleu1_set = []
    val_bleu4_set = []
    
    
    training_loss = []
    
    # Early Stop criteria
    minLoss = 1e6
    minLossIdx = 0
    earliestStopEpoch = 10
    earlyStopDelta = 5
    for epoch in range(epochs):
        ts = time.time()
        print("Type")    

        print(type(train_loader))
        for iter, (inputs, labels, lengths) in tqdm(enumerate(train_loader)):

            optimizer.zero_grad()
            
            
            
            encoder.train()
            decoder.train()
            
            if use_gpu:
                inputs = inputs.to(device)# Move your inputs onto the gpu
                labels = labels.to(device) # Move your labels onto the gpu
            
                
            enc_out = encoder(inputs)
            
            temperature = 1
            
            
            
            decoder.resetHidden(inputs.shape[0])
            outputs = decoder(labels, enc_out, lengths) #calls forward
            #targets = pack_padded_sequence(labels, lengths, batch_first=True)
            #targets = pack_padded_sequence(labels, actual_lengths, batch_first=True, enforce_sorted=False)
            #targets = torch.zeros()
            loss = criterion(outputs, labels.cuda())
            del labels
            del outputs

            loss.backward()
#             loss = loss#.item()
            optimizer.step()

            if iter % 200 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss))
                #break
#                 test_pred = decoder.generate_caption(enc_out, maxSeqLen, temperature).cpu()
#                 for b in range(inputs.shape[0]):
#                     caption = (" ").join([vocab.idx2word[x.item()] for x in test_pred[b]])
#                     img = tf.ToPILImage()(inputs[b,:,:,:].cpu())
#                     plt.imshow(img)
                    
#                     plt.show()
#                     print("Caption: " + caption)
                
        print("epoch{}, iter{}, loss: {}, epoch duration: {}".format(epoch, iter, loss, time.time() - ts))
        test_pred = decoder.generate_caption(enc_out, maxSeqLen, temperature).cpu()
        
        k = 0
        for b in range(inputs.shape[0]):
            caption = (" ").join([vocab.idx2word[x.item()] for x in test_pred[b]])
            img = tf.ToPILImage()(inputs[b,:,:,:].cpu())
            plt.imshow(img)
                    
            plt.show()
            print("Caption: " + caption)
            if save_generated_imgs:
                file = "./generated_imgs/" + "train_epoch" + str(epoch) + "im_"+ str(k) 
                img.save(file + ".png", "PNG")
                k+=1
                with open(generated_imgs_filename, "a") as file:
                    file.write("writing! " + "train_epoch" + str(epoch) + "im_"+ str(k) + "\n")            
                    file.write("Caption: " + caption +"\n \n")
        
        # calculate val loss each epoch
        val_loss, val_bleu1, val_bleu4  = validate_test(val_loader, encoder, decoder, criterion,maxSeqLen,
                             vocab, batch_size, use_gpu, calculate_bleu = False)
        val_loss_set.append(val_loss)
        val_bleu1_set.append(val_bleu1)
        val_bleu4_set.append(val_bleu4)

        
#         print("epoch {}, time {}, train loss {}, val loss {}, val acc {}, val iou {}".format(epoch, time.time() - ts,
#                                                                                                loss, val_loss,
#                                                                                                val_acc,
#                                                                                                val_iou))        
        training_loss.append(loss)
        
        torch.save(encoder, 'weights_base/encoder_epoch{}'.format(epoch))
        torch.save(decoder, 'weights_base/decoder_epoch{}'.format(epoch))
        
        with open(logname, "a") as file:
            file.write("writing!\n")
            file.write("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
            file.write("\n training Loss:   " + str(loss.item()))
            file.write("\n Validation Loss: " + str(val_loss_set[-1]))
            file.write("\n Validation bleu1:  " + str(val_bleu1_set[-1]))
            file.write("\n Validation bleu4:  " + str(val_bleu4_set[-1]) + "\n ")                                            
                                                                                                
                                                                                                
        
        # Early stopping
#         if val_loss < minLoss:
#             # Store new best
#             torch.save(model, name)
#             minLoss = val_loss#.item()
#             minLossIdx = epoch
            
        # If passed min threshold, and no new min has been reached for delta epochs
#         elif epoch > earliestStopEpoch and (epoch - minLossIdx) > earlyStopDelta:
#             print("Stopping early at {}".format(minLossIdx))
#             break
        

        
        
        with open(logname_summary, "w") as file:
            file.write("Summary!\n")
            #file.write("Stopped early at {}".format(minLossIdx))
            file.write("\n training Loss:   " + str(training_loss))        
            file.write("\n Validation Loss : " + str(val_loss_set))
            file.write("\n Validation bleu1:  " + str(val_bleu1_set))
            file.write("\n Validation bleu4:  " + str(val_bleu4_set) + "\n ")
            
        
    #return val_loss_set, val_acc_set, val_iou_set

In [None]:
if __name__=='__main__':
    with open('trainvalIds.csv', 'r') as f:
        trainIds = []
        for line in f:
            if len(line) > 1:
                trainIds.append(line.strip("\n"))

        
    with open('testIds.csv', 'r') as f:
        testIds = []
        for line in f:
            if len(line) > 1:
                testIds.append(line.strip("\n"))
    
    print("found {} train ids".format(len(trainIds)))
    print("found {} test ids".format(len(testIds)))
    
    # Will shuffle the trainIds incase of ordering in csv
    random.shuffle(trainIds)
    splitIdx = int(len(trainIds)/5)
    
    # Selecting 1/5 of training set as validation
    valIds = trainIds[:splitIdx]
    trainIds = trainIds[splitIdx:]
    #print(trainIds)
    
    
    trainValRoot = "./data/realImages/"
    testRoot = "./data/realImages/"
    
    trainValCaps = "./data/captions/trainvalCaps.csv"
    testCaps = "./data/captions/testCaps.csv"
    
    
    with open('./data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)
    
    img_side_length = 256
    transform = tf.Compose([
        tf.Resize(img_side_length),
        #tf.RandomCrop(img_side_length),
        tf.CenterCrop(img_side_length),
        tf.ToTensor(),
    ])
    batch_size = 10
    shuffle = True
    num_workers = 5
    
    
    trainDl = get_loader(trainValRoot, trainValCaps, trainIds, vocab, 
                         transform=transform, batch_size=batch_size, 
                         shuffle=shuffle, num_workers=num_workers)
    valDl = get_loader(trainValRoot, trainValCaps, valIds, vocab, 
                         transform=transform, batch_size=batch_size, 
                         shuffle=shuffle, num_workers=num_workers)
    testDl = get_loader(testRoot, testCaps, testIds, vocab, 
                        transform=transform, batch_size=batch_size, 
                        shuffle=shuffle, num_workers=num_workers)
    
    encoded_feature_dim = 1024
    maxSeqLen = 56
    hidden_dim = 1500
    depth = 1
    
    encoder = Encoder(encoded_feature_dim)
    # Turn off all gradients in encoder
    for param in encoder.parameters():
        param.requires_grad = False
    # Turn on gradient of final hidden layer for fine tuning
    for param in encoder.fc.parameters():
        param.requires_grad = True
    decoder = Decoder(encoded_feature_dim, hidden_dim, depth, vocab.idx, batch_size)
    
#     criterion = nn.NLLLoss()
    criterion = nn.CrossEntropyLoss()
    
    epochs = 100
    trainEncoderDecoder(encoder, decoder, criterion, epochs,
                        trainDl, valDl, testDl, "LSTM",
                        batch_size, maxSeqLen, vocab,save_generated_imgs = True)

found 7323 train ids
found 866 test ids
# ids: 58590
# ids: 14640
# ids: 8660
Loading results to logfile: ./logs/LSTM2.log
Loading Summary to : ./logs/LSTM_summary2.log
Type
<class 'torch.utils.data.dataloader.DataLoader'>


2it [00:00,  3.86it/s]

epoch0, iter0, loss: 7.545274257659912


202it [00:24,  8.22it/s]

epoch0, iter200, loss: 1.0059680938720703


402it [00:49,  8.14it/s]

epoch0, iter400, loss: 0.5710200071334839


601it [01:14,  7.69it/s]

epoch0, iter600, loss: 0.5080430507659912


802it [01:38,  7.58it/s]

epoch0, iter800, loss: 0.7432484030723572


1002it [02:03,  8.22it/s]

epoch0, iter1000, loss: 0.4593638777732849


1202it [02:27,  8.07it/s]

epoch0, iter1200, loss: 0.5876981616020203


1402it [02:52,  8.17it/s]

epoch0, iter1400, loss: 0.4628753066062927


1602it [03:16,  7.65it/s]

epoch0, iter1600, loss: 0.6250199675559998


1802it [03:41,  8.22it/s]

epoch0, iter1800, loss: 0.5417322516441345


2002it [04:05,  8.31it/s]

epoch0, iter2000, loss: 0.4870154857635498


2202it [04:30,  8.27it/s]

epoch0, iter2200, loss: 0.48893558979034424


2402it [04:54,  7.56it/s]

epoch0, iter2400, loss: 0.6270290017127991


2601it [05:18,  7.62it/s]

epoch0, iter2600, loss: 0.5298337936401367


2802it [05:43,  8.17it/s]

epoch0, iter2800, loss: 0.42725706100463867


3002it [06:07,  8.21it/s]

epoch0, iter3000, loss: 0.4083181321620941


3202it [06:31,  8.29it/s]

epoch0, iter3200, loss: 0.4104045629501343


3402it [06:56,  8.10it/s]

epoch0, iter3400, loss: 0.40149155259132385


3602it [07:20,  8.04it/s]

epoch0, iter3600, loss: 0.43959739804267883


3802it [07:44,  8.15it/s]

epoch0, iter3800, loss: 0.356599897146225


4002it [08:09,  8.36it/s]

epoch0, iter4000, loss: 0.599996030330658


4201it [08:33,  7.77it/s]

epoch0, iter4200, loss: 0.508206307888031


4333it [08:49,  8.02it/s]

In [None]:
%debug

In [None]:
vocab.idx

In [None]:
x = torch.Tensor([])