In [1]:
import argparse
import torch
import torch.nn as nn
import os
from torchvision import transforms
import pickle
from data_preprocess import get_loader
from model import EncoderCNN,DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from build_vocab import Vocabulary
import numpy as np

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\manda\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.


In [2]:
# Select the compute device string used by the cells below: the first CUDA
# GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
# Echo the selected device string as this cell's output.
device

'cuda:0'

In [4]:
def _save_checkpoint(encoder, decoder, model_path, epoch, step):
    """Write the decoder and encoder state dicts into ``model_path``.

    The files are named ``decoder-{epoch}-{step}.ckpt`` and
    ``encoder-{epoch}-{step}.ckpt``; the decoder is written first.
    """
    torch.save(decoder.state_dict(),
               os.path.join(model_path, 'decoder-{}-{}.ckpt'.format(epoch, step)))
    torch.save(encoder.state_dict(),
               os.path.join(model_path, 'encoder-{}-{}.ckpt'.format(epoch, step)))


def main(args):
    """Train the captioning encoder/decoder pair and write checkpoints.

    Reads the module-level ``device`` to decide where the models and batches
    are placed.

    Args:
        args: namespace providing ``model_path``, ``crop_size``, ``vocab_path``,
            ``image_dir``, ``caption``, ``batch_size``, ``num_workers``,
            ``embed_size``, ``hidden_size``, ``num_layers``, ``learning_rate``,
            ``num_epochs``, ``log_step`` and ``save_step``.

    Side effects:
        Creates ``args.model_path`` if needed, prints a progress line every
        ``args.log_step`` steps, and saves encoder/decoder state dicts every
        ``args.save_step`` steps plus once more after the last step of
        training when that step was not already saved.
    """
    # Create the model directory; exist_ok avoids the check-then-create race.
    os.makedirs(args.model_path, exist_ok=True)

    # Image preprocessing and normalization (per-channel mean / std).
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # Load the vocabulary wrapper.
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point vocab_path at a file from a trusted source.
    # Presumably the pickled object is a build_vocab.Vocabulary - TODO confirm.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build the data loader.
    data_loader = get_loader(args.image_dir, args.caption, vocab, transform,
                             args.batch_size, shuffle=True,
                             num_workers=args.num_workers)

    # Build the models.
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(len(vocab), args.embed_size, args.hidden_size,
                         args.num_layers).to(device)

    # Loss and optimizer. Only the decoder plus the encoder's `linear` and
    # `batch_norm` sub-modules are handed to the optimizer.
    criterion = nn.CrossEntropyLoss()
    params = (list(decoder.parameters())
              + list(encoder.linear.parameters())
              + list(encoder.batch_norm.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models.
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            images = images.to(device)
            captions = captions.to(device)
            # [0] is the packed data tensor, i.e. the caption tokens without
            # padding. Assumes the decoder returns its scores in the same
            # packed order - TODO confirm against DecoderRNN.forward.
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            # Forward pass.
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)

            # Clear old gradients, backpropagate, update.
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info. The logged epoch and step are zero-based,
            # whereas checkpoint file names below are one-based.
            if i % args.log_step == 0:
                print('Epoch {}/{}  , Step {}/{}  , Loss {:.4f} , Perplexity{:5.4f}'.format(
                    epoch, args.num_epochs, i, total_step,
                    loss.item(), np.exp(loss.item())))

            # Save the models periodically.
            if (i + 1) % args.save_step == 0:
                _save_checkpoint(encoder, decoder, args.model_path, epoch + 1, i + 1)

    # The periodic save only fires on multiples of save_step, so any updates
    # made after the last such step (or all of them, when an epoch is shorter
    # than save_step) would otherwise never be written. Save the final weights
    # unless the very last step already produced a checkpoint.
    if args.num_epochs > 0 and total_step > 0 and total_step % args.save_step != 0:
        _save_checkpoint(encoder, decoder, args.model_path, args.num_epochs, total_step)
            

In [None]:
if __name__=='__main__':
    # Command-line options as (flag, add_argument keyword options) pairs,
    # grouped as: paths and data, model dimensions, training schedule.
    option_table = [
        ('--model_path', dict(type=str, default='models/', help='path for saving models')),
        ('--crop_size', dict(type=int, default=224, help='crop size')),
        ('--vocab_path', dict(type=str, default='data/vocab.pkl',
                              help='this is path for vocabulary wrapper')),
        ('--image_dir', dict(type=str, default='data/processed_images/',
                             help='path for processed images')),
        ('--caption', dict(type=str, default='data/annotations/captions_train2014.json',
                           help='path for train annotations')),
        ('--batch_size', dict(type=int, default=128)),

        ('--embed_size', dict(type=int, default=256,
                              help='dimension of word embedding vectors')),
        ('--hidden_size', dict(type=int, default=512,
                               help='dimension of lstm hidden states')),
        ('--num_layers', dict(type=int, default=1, help='number of layers of lstm')),

        ('--learning_rate', dict(type=float, default=0.001)),
        ('--num_epochs', dict(type=int, default=5)),
        ('--num_workers', dict(type=int, default=4)),
        ('--log_step', dict(type=int, default=10, help='step size for printing log info')),
        ('--save_step', dict(type=int, default=1000, help='step size for saving the model')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    # An explicit empty argument list is parsed instead of sys.argv, so every
    # option takes the default declared above.
    args = parser.parse_args(args=[])
    print(args)
    main(args)
    

Namespace(batch_size=128, caption='data/annotations/captions_train2014.json', crop_size=224, embed_size=256, hidden_size=512, image_dir='data/processed_images/', learning_rate=0.001, log_step=10, model_path='models/', num_epochs=5, num_layers=1, num_workers=4, save_step=1000, vocab_path='data/vocab.pkl')
loading annotations into memory...
Done (t=2.73s)
creating index...
index created!


Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to C:\Users\manda/.torch\models\resnet152-b121ed2d.pth
100.0%


Epoch 0/5  , Step 0/3236  , Loss 9.2067 , Perplexity9963.3521
Epoch 0/5  , Step 10/3236  , Loss 5.9100 , Perplexity368.7021
Epoch 0/5  , Step 20/3236  , Loss 5.4377 , Perplexity229.9238
Epoch 0/5  , Step 30/3236  , Loss 4.9734 , Perplexity144.5235
Epoch 0/5  , Step 40/3236  , Loss 4.5520 , Perplexity94.8214
Epoch 0/5  , Step 50/3236  , Loss 4.6409 , Perplexity103.6411
Epoch 0/5  , Step 60/3236  , Loss 4.3414 , Perplexity76.8171
Epoch 0/5  , Step 70/3236  , Loss 4.1295 , Perplexity62.1489
Epoch 0/5  , Step 80/3236  , Loss 4.1163 , Perplexity61.3329
Epoch 0/5  , Step 90/3236  , Loss 4.0430 , Perplexity56.9974
Epoch 0/5  , Step 100/3236  , Loss 3.8160 , Perplexity45.4235
Epoch 0/5  , Step 110/3236  , Loss 3.8374 , Perplexity46.4070
Epoch 0/5  , Step 120/3236  , Loss 3.8711 , Perplexity47.9956
Epoch 0/5  , Step 130/3236  , Loss 3.5995 , Perplexity36.5805
Epoch 0/5  , Step 140/3236  , Loss 3.6755 , Perplexity39.4696
Epoch 0/5  , Step 150/3236  , Loss 3.7865 , Perplexity44.1026
Epoch 0/5  , 

Epoch 0/5  , Step 1320/3236  , Loss 2.5165 , Perplexity12.3846
Epoch 0/5  , Step 1330/3236  , Loss 2.4453 , Perplexity11.5337
Epoch 0/5  , Step 1340/3236  , Loss 2.4396 , Perplexity11.4682
Epoch 0/5  , Step 1350/3236  , Loss 2.4138 , Perplexity11.1765
Epoch 0/5  , Step 1360/3236  , Loss 2.4960 , Perplexity12.1344
Epoch 0/5  , Step 1370/3236  , Loss 2.4555 , Perplexity11.6525
Epoch 0/5  , Step 1380/3236  , Loss 2.3633 , Perplexity10.6264
Epoch 0/5  , Step 1390/3236  , Loss 2.3760 , Perplexity10.7622
Epoch 0/5  , Step 1400/3236  , Loss 2.5006 , Perplexity12.1893
Epoch 0/5  , Step 1410/3236  , Loss 2.4937 , Perplexity12.1058
Epoch 0/5  , Step 1420/3236  , Loss 2.4835 , Perplexity11.9829
Epoch 0/5  , Step 1430/3236  , Loss 2.4955 , Perplexity12.1275
Epoch 0/5  , Step 1440/3236  , Loss 2.5685 , Perplexity13.0459
Epoch 0/5  , Step 1450/3236  , Loss 2.4623 , Perplexity11.7323
Epoch 0/5  , Step 1460/3236  , Loss 2.3490 , Perplexity10.4749
Epoch 0/5  , Step 1470/3236  , Loss 2.4117 , Perplexity