In [None]:
# Libraries
import math
import pickle
import os
import numpy as np
from PIL import Image
from collections import Counter
from pycocotools.coco import COCO
import glob
import matplotlib.pyplot as plt
import nltk
nltk.download('punkt')
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
import torch.utils.data as data

In [None]:
#COCOデータダウンロード
!mkdir data
!wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
!wget http://images.cocodataset.org/zips/train2014.zip -P ./data/

!unzip ./data/captions_train-val2014.zip -d ./data/
!rm ./data/captions_train-val2014.zip
!unzip ./data/train2014.zip -d ./data/
!rm ./data/train2014.zip 

In [None]:
# Vocabulary construction
class Vocabulary(object):
    """Two-way mapping between word strings and consecutive integer ids."""

    def __init__(self):
        # Ids are handed out in insertion order, starting at 0.
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Register `word` under the next free id; known words are left untouched."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of `word`, falling back to the id of '<unk>' when unknown."""
        if word in self.word2idx:
            return self.word2idx[word]
        return self.word2idx['<unk>']

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.word2idx)

def make_vocab(json, threshold):
    """Build a Vocabulary from the captions of a COCO annotation file.

    Args:
        json: path of the COCO caption annotation file, passed to `COCO`.
        threshold: minimum number of occurrences a token needs in order to
            be added to the vocabulary.

    Returns:
        A Vocabulary holding '<pad>', '<start>', '<end>', '<unk>' (in that
        order) followed by every token seen at least `threshold` times.
    """
    coco = COCO(json)
    ann_ids = coco.anns.keys()
    total = len(ann_ids)

    # Count lower-cased tokens over all captions.
    counter = Counter()
    for done, ann_id in enumerate(ann_ids, start=1):
        text = str(coco.anns[ann_id]['caption']).lower()
        counter.update(nltk.tokenize.word_tokenize(text))

        if done % 1000 == 0:
            print("[{}/{}] Tokenized captions.".format(done, total))

    vocab = Vocabulary()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Tokens below the threshold are dropped and later resolve to '<unk>'.
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    return vocab

In [None]:
# Build the vocabulary from the training captions and pickle it to disk
vocab_path = './data/vocab.pkl'
vocab = make_vocab(json='./data/annotations/captions_train2014.json', threshold=4)
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved vocabulary wrapper to '{}'".format(vocab_path))

In [None]:
# Resize the COCO images
def resize_image(image, size):
    """Return `image` resized to a `size` x `size` square.

    Uses the Lanczos filter. `Image.ANTIALIAS` was an alias of
    `Image.LANCZOS` that newer Pillow releases no longer provide, so the
    canonical name is used here; the resampling result is unchanged.

    Args:
        image: a PIL image.
        size: edge length in pixels of the square output.
    """
    return image.resize((size, size), Image.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize every file in `image_dir` and save it under the same name in `output_dir`.

    Args:
        image_dir: directory whose entries are all opened as images.
        output_dir: destination directory; created if it does not exist.
        size: edge length in pixels of the square output images.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        # Read-only access is sufficient; 'r+b' needlessly required write
        # permission on the source images.
        with open(os.path.join(image_dir, image), 'rb') as f:
            with Image.open(f) as img:
                resized = resize_image(img, size)
                # The resized image is a new object whose `.format` is None,
                # so the output format is inferred from the file extension
                # (this is also what the previous `img.format` argument
                # amounted to).
                resized.save(os.path.join(output_dir, image))
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))

# Shrink the raw training images to 256x256 before training
image_size = 256
image_dir = './data/train2014/'
output_dir = './data/resized2014/'
resize_images(image_dir=image_dir, output_dir=output_dir, size=image_size)

In [None]:
# Custom COCO caption dataset
class CocoDataset(data.Dataset):
    """Dataset yielding one (image, caption ids) pair per COCO caption annotation."""

    def __init__(self, root, json, vocab, transform=None):
        """
        Args:
            root: directory containing the image files.
            json: path of the COCO caption annotation file.
            vocab: callable mapping a token string to its integer id.
            transform: optional callable applied to the loaded PIL image.
        """
        self.root = root
        self.coco = COCO(json)
        # One sample per annotation id, so an image appears once per caption.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return (image, target) for the `index`-th annotation.

        `target` is a 1-D float tensor (`torch.Tensor`) holding the ids of
        '<start>', the caption tokens, and '<end>'.
        """
        ann = self.coco.anns[self.ids[index]]
        file_name = self.coco.loadImgs(ann['image_id'])[0]['file_name']

        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        words = nltk.tokenize.word_tokenize(str(ann['caption']).lower())
        token_ids = [self.vocab('<start>')]
        token_ids += [self.vocab(w) for w in words]
        token_ids.append(self.vocab('<end>'))
        return image, torch.Tensor(token_ids)

    def __len__(self):
        """Number of caption annotations."""
        return len(self.ids)

In [None]:
# Batching
def collate_fn(data):
    """Merge a list of (image, caption) samples into one padded batch.

    Args:
        data: list of (image tensor, 1-D caption tensor) pairs. The list is
            sorted in place, longest caption first.

    Returns:
        images: the image tensors stacked along a new first dimension.
        targets: long tensor of shape (batch, longest caption), with each
            caption left-aligned and the remainder filled with 0.
    """
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    image_list, caption_list = zip(*data)

    images = torch.stack(image_list, 0)

    lengths = [len(cap) for cap in caption_list]
    targets = torch.zeros(len(caption_list), max(lengths)).long()
    for row, (cap, n) in enumerate(zip(caption_list, lengths)):
        targets[row, :n] = cap[:n]
    return images, targets

def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Create a DataLoader over the COCO caption dataset.

    Batches are assembled by `collate_fn`, and the last incomplete batch is
    dropped (`drop_last=True`), so every batch has exactly `batch_size` samples.
    """
    dataset = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )

In [None]:
class Encoder(nn.Module):
    """ResNet-152 feature extractor that also produces the decoder's initial state.

    Args:
        hidden_size: size of the initial hidden/cell state vectors.
        encoder_out_size: number of channels in the backbone's output map.
        encode_size: the feature map is pooled to encode_size x encode_size.
    """

    def __init__(self, hidden_size, encoder_out_size=2048, encode_size=14):
        super(Encoder, self).__init__()
        resnet = models.resnet152(pretrained=True)
        # Drop the last two child modules so the backbone returns a spatial
        # feature map instead of class scores.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.linear_hc = nn.Linear(encoder_out_size, hidden_size)
        self.encode_size = encode_size
        self.encoder_out_size = encoder_out_size

    def forward(self, images):
        """Encode a batch of images.

        Returns:
            features: (batch, encode_size * encode_size, encoder_out_size),
                one channel vector per spatial location.
            h, c: (batch, hidden_size) initial states. Both come from the
                same `linear_hc` projection, so they are identical.
        """
        # The backbone is used as a fixed feature extractor: no gradients.
        with torch.no_grad():
            features = self.resnet(images)
        features = F.adaptive_avg_pool2d(features, (self.encode_size, self.encode_size))
        # The pooled map is channel-first (batch, channels, H, W). Move the
        # channels last before flattening the spatial grid; a plain
        # view(batch, -1, channels) would reinterpret the memory and mix
        # channel and spatial values within each row.
        features = features.permute(0, 2, 3, 1)
        features = features.reshape(features.shape[0], -1, self.encoder_out_size)
        mean_features = features.mean(dim=1)
        h = self.linear_hc(mean_features)
        c = self.linear_hc(mean_features)
        return features, h, c

In [None]:
class CellDecoder(nn.Module):
    """Caption decoder that steps an LSTMCell over the caption one token at a time.

    Args:
        embed_size: word embedding dimension.
        hidden_size: LSTM hidden/cell state dimension.
        vocab_size: number of words scored at every step.
        batch_size: kept for interface compatibility and stored on the
            instance; forward() sizes its output from the actual input.
        encoder_out_size: channel dimension of the encoder features.
        encode_size: the encoder produces encode_size * encode_size
            feature locations per image.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, batch_size, encoder_out_size=2048, encode_size=14):
        super(CellDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstmcell = nn.LSTMCell(embed_size, hidden_size, bias=True)
        self.linear_ft = nn.Linear(encoder_out_size, hidden_size)
        self.linear_h = nn.Linear(encode_size*encode_size, hidden_size)
        self.linear_word = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size
        self.batch_size = batch_size

    def forward(self, captions, features, init_h, init_c, padmask=None):
        """Score the vocabulary at every caption position.

        Args:
            captions: (batch, steps) integer word ids.
            features: encoder features; the last dimension must be
                encoder_out_size and the location dimension
                encode_size * encode_size.
            init_h, init_c: (batch, hidden_size) initial LSTM state.
            padmask: optional mask indexed as padmask[:, :, step]; positions
                equal to 0 are masked out of the attention scores. When
                None, no masking is applied.

        Returns:
            (batch, steps, vocab_size) tensor of unnormalised word scores.
        """
        h = init_h
        c = init_c
        embeddings = self.embed(captions)
        features = self.linear_ft(features)
        # Size and place the output from the inputs themselves rather than
        # from a notebook-global `device` and the constructor's batch_size.
        batch, steps = embeddings.shape[0], embeddings.shape[1]
        text = torch.zeros(batch, steps, self.vocab_size, device=embeddings.device)
        for i in range(steps):
            scores = torch.matmul(features, h.unsqueeze(1).transpose(-2, -1))
            if padmask is not None:
                scores = scores.masked_fill(padmask[:, :, i].unsqueeze(1) == 0, -1e9)
            # NOTE(review): with features of shape (batch, locations, hidden)
            # `scores` is (batch, locations, 1), so a softmax over dim=-1
            # yields all ones. Confirm whether dim=1 was intended; left
            # unchanged so existing checkpoints behave the same.
            atten_weights = F.softmax(scores, dim=-1)
            weights_features = h.unsqueeze(1) * atten_weights
            h = torch.matmul(weights_features, h.unsqueeze(1).transpose(-2, -1))
            h = self.linear_h(h.squeeze(2))
            h, c = self.lstmcell(embeddings[:, i, :], (h, c))
            # F.dropout defaults to training=True; tie it to the module mode
            # so that eval() really disables dropout at inference time.
            word = self.linear_word(F.dropout(h, training=self.training))
            text[:, i, :] = word
        return text

In [None]:
class Decoder(nn.Module):
    """LSTM caption decoder with dot-product attention over encoder features.

    Args:
        embed_size: word embedding dimension.
        hidden_size: LSTM hidden state dimension.
        vocab_size: number of words scored at every position.
        num_layers: number of stacked LSTM layers.
        batch_size: kept for interface compatibility; not used.
        encoder_out_size: channel dimension of the encoder features.
        decoder_out_size: input size of the query projection; it is applied
            to the LSTM outputs, so it must equal hidden_size.
        atten_size: dimension of the shared query/key attention space.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, batch_size, encoder_out_size=2048, decoder_out_size=512, atten_size=512):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear_q = nn.Linear(decoder_out_size, atten_size)
        self.linear_k = nn.Linear(encoder_out_size, atten_size)
        self.linear = nn.Linear((hidden_size+atten_size), vocab_size)
        self.embed_size = embed_size
        self.atten_size = atten_size
        self.encoder_out_size = encoder_out_size

    def forward(self, captions, features, h, c, padmask=None):
        """Score the vocabulary at every caption position.

        Args:
            captions: (batch, steps) integer word ids.
            features: (batch, locations, encoder_out_size) encoder features.
            h, c: initial LSTM state, passed to nn.LSTM unchanged.
            padmask: optional mask broadcastable to (batch, locations, steps);
                positions equal to 0 are masked out of the attention scores.
                When None, no masking is applied.

        Returns:
            hiddens: (batch, steps, vocab_size) unnormalised word scores.
            state: final (h, c) tuple returned by the LSTM.
            atten_weights: (batch, locations, steps) attention weights.
        """
        embeddings = self.embed(captions)
        hiddens, state = self.lstm(embeddings, (h, c))
        atten_hiddens = self.linear_q(hiddens)
        atten_features = self.linear_k(features)
        scores = torch.matmul(atten_features, atten_hiddens.transpose(-2, -1))
        if padmask is not None:
            scores = scores.masked_fill(padmask == 0, -1e9)
        # NOTE(review): dim=-1 normalises over caption positions, not over
        # image locations. Confirm this is intended; left unchanged.
        atten_weights = F.softmax(scores, dim=-1)
        # context[b, t] = sum_n atten_weights[b, n, t] * atten_features[b, n].
        # A single batched matmul replaces the per-step loop that grew the
        # tensor with torch.cat and relied on the notebook globals
        # `batch_size` and `device`.
        context = torch.matmul(atten_weights.transpose(-2, -1), atten_features)
        hiddens = torch.cat([hiddens, context], dim=2)
        hiddens = self.linear(hiddens)
        return hiddens, state, atten_weights

In [None]:
# Training: configuration, data pipeline, models and optimiser
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = 'modelstest/'
crop_size = 224
vocab_path = 'data/vocab.pkl'
image_dir ='data/resized2014'
caption_path='data/annotations/captions_train2014.json'
log_step=10
save_step=100
embed_size=512
hidden_size=512
num_layers=1
num_epochs=100
batch_size=64
num_workers=2
learning_rate=0.0001

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(model_path, exist_ok=True)

# Random crop + horizontal flip for augmentation, then normalisation with
# the mean/std values listed below.
transform = transforms.Compose([ 
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# NOTE: unpickling executes arbitrary code; only load a vocabulary file
# that was produced by this notebook or another trusted source.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

data_loader = get_loader(image_dir, caption_path, vocab, transform, batch_size,
                          shuffle=True, num_workers=num_workers) 

encoder = Encoder(hidden_size).to(device)
decoder = CellDecoder(embed_size, hidden_size, len(vocab), batch_size).to(device)

# collate_fn pads short captions with 0, which is also the id make_vocab
# assigns to '<pad>'. Ignore those positions so padding neither contributes
# to the loss nor distorts the reported perplexity.
criterion = nn.CrossEntropyLoss(ignore_index=0)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
# The learning rate is multiplied by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Per-step loss history, kept for later inspection
all_losses = []

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(data_loader):
        images = images.to(device)

        # Teacher forcing: the decoder input is the caption without its last
        # token and the target is the caption shifted left by one position.
        caption = captions[:, :-1].to(device)
        targets = captions[:, 1:].to(device)

        # True wherever the target is a real token rather than padding (id 0).
        pad = 0
        pad_mask = (targets != pad).unsqueeze(1)

        feature, h, c = encoder(images)
        outputs = decoder(caption, feature, h, c, pad_mask)

        # Flatten (batch, steps, vocab) scores and (batch, steps) targets so
        # the loss is computed over all positions in a single call.
        loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))

        encoder.zero_grad()
        decoder.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, num_epochs, i, total_step, loss_value, np.exp(loss_value)))
        all_losses.append(loss_value)

        # Checkpoint both networks together every `save_step` steps.
        if (i+1) % save_step == 0:
            torch.save(encoder.state_dict(), os.path.join(
                model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(decoder.state_dict(), os.path.join(
                model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))

    # Decay the learning rate once per epoch.
    scheduler.step()

In [None]:
# Inference: restore the trained networks and prepare one test image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = 'modelstest/'
vocab_path = 'data/vocab.pkl'
image_dir ='data/resized2014'
caption_path='data/annotations/captions_train2014.json'
embed_size=512
hidden_size=512
num_layers=1
batch_size=64

with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# Build the networks, load the saved weights via the CPU, and switch both
# modules to evaluation mode.
encoder = Encoder(hidden_size).to(device)
decoder = CellDecoder(embed_size, hidden_size, len(vocab), batch_size).to(device)
encoder.load_state_dict(
    torch.load('./modelstest/encoder-1-600.ckpt', map_location=torch.device('cpu')))
decoder.load_state_dict(
    torch.load('./modelstest/decoder-1-600.ckpt', map_location=torch.device('cpu')))
encoder.eval()
decoder.eval()

# Load one image, resize it to 256x256 and normalise it with the same
# mean/std as in training (no random crop or flip here).
test_img_dir = './data/train2014/COCO_train2014_000000000025.jpg'
test_img = Image.open(test_img_dir).convert('RGB')
test_img = test_img.resize((256, 256), Image.LANCZOS)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_img = transform(test_img)

In [None]:
# Inference (CellDecoder): greedy decoding of at most 20 words
predicted_ids = []
start_char = [vocab('<start>')] * batch_size
input_char = torch.tensor(start_char, device=device).unsqueeze(1)
# The single test image is replicated to a full batch; only row 0 of the
# predictions is read back.
images = test_img.to(device).unsqueeze(0).repeat(batch_size, 1, 1, 1)
with torch.no_grad():
    features, h, c = encoder(images)
    for i in range(20):
        pad = 0
        pad_mask = (input_char != pad).unsqueeze(1)
        # NOTE(review): h and c are never reassigned inside this loop, so
        # every step starts again from the encoder's initial state and only
        # the previous word changes. Confirm this is intended.
        outputs = decoder(input_char, features, h, c, pad_mask)
        output_chars = torch.max(outputs, dim=2)[1]
        # Stop as soon as the first batch row predicts '<end>'.
        if int(output_chars[0]) == vocab('<end>'):
            break
        output_char = output_chars[:, -1]
        predicted_ids.append(int(output_char[0]))
        input_char = output_char.unsqueeze(1)

# Map the predicted ids back to words.
caption = [vocab.idx2word[int(word_id)] for word_id in predicted_ids]

In [None]:
# Inference (Decoder): greedy decoding, limited to 2 steps
# NOTE(review): this cell unpacks three return values and expects an LSTM
# state, which matches Decoder.forward. The setup cell builds `decoder` as
# a CellDecoder, whose forward returns a single tensor, so `decoder` must
# be a Decoder instance for this cell to run. Confirm before use.
predicted_ids = []
start_char = [vocab('<start>')] * batch_size
input_char = torch.tensor(start_char, device=device).unsqueeze(1)
# The single test image is replicated to a full batch; only row 0 of the
# predictions is read back.
images = test_img.to(device).unsqueeze(0).repeat(batch_size, 1, 1, 1)
with torch.no_grad():
    features, h, c = encoder(images)
    for i in range(2):
        pad = 0
        pad_mask = (input_char != pad).unsqueeze(1)
        outputs, states, _ = decoder(input_char, features, h, c, pad_mask)
        # Carry the recurrent state over to the next step.
        h = states[0]
        c = states[1]
        output_chars = torch.max(outputs, dim=2)[1]
        # Stop as soon as the first batch row predicts '<end>'.
        if int(output_chars[0]) == vocab('<end>'):
            break
        output_char = output_chars[:, -1]
        predicted_ids.append(int(output_char[0]))
        input_char = output_char.unsqueeze(1)

# Map the predicted ids back to words.
caption = [vocab.idx2word[int(word_id)] for word_id in predicted_ids]