In [None]:
# Libraries
import math
import pickle
import os
import numpy as np
from PIL import Image
from collections import Counter
from pycocotools.coco import COCO
import glob
import matplotlib.pyplot as plt
import nltk
# nltk.tokenize.word_tokenize needs the Punkt sentence-tokenizer data.
# Recent NLTK releases look it up under the name 'punkt_tab', older ones under
# 'punkt', so request both; a name unknown to the installed NLTK is only
# reported by nltk.download, not raised.
nltk.download('punkt')
nltk.download('punkt_tab')
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
import torch.utils.data as data

In [None]:
#COCOデータダウンロード
!mkdir data
!wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
!wget http://images.cocodataset.org/zips/train2014.zip -P ./data/

!unzip ./data/captions_train-val2014.zip -d ./data/
!rm ./data/captions_train-val2014.zip
!unzip ./data/train2014.zip -d ./data/
!rm ./data/train2014.zip 

In [None]:
# Vocabulary construction
class Vocabulary(object):
    """Two-way mapping between words and consecutive integer ids.

    Ids are handed out in insertion order, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = {}  # id -> word
        self.idx = 0        # next id to hand out

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are left untouched."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, or the id of '<unk>' when the word is unknown.

        Raises:
            KeyError: if the word is unknown and '<unk>' was never added.
        """
        key = word if word in self.word2idx else '<unk>'
        return self.word2idx[key]

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.word2idx)

def make_vocab(json, threshold):
    """Build a Vocabulary from the captions of a COCO annotation file.

    Args:
        json: Path of the caption annotation file, handed to ``COCO``.
        threshold: Minimum number of occurrences a token needs to be kept.

    Returns:
        A Vocabulary containing '<pad>', '<start>', '<end>' and '<unk>'
        (added first, in that order) followed by every lower-cased token
        that occurs at least ``threshold`` times.
    """
    coco = COCO(json)
    ann_ids = coco.anns.keys()
    total = len(ann_ids)
    token_counts = Counter()

    for done, ann_id in enumerate(ann_ids, start=1):
        text = str(coco.anns[ann_id]['caption']).lower()
        token_counts.update(nltk.tokenize.word_tokenize(text))

        # Report progress every 1000 captions.
        if done % 1000 == 0:
            print("[{}/{}] Tokenized captions.".format(done, total))

    vocab = Vocabulary()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)
    for word, cnt in token_counts.items():
        if cnt >= threshold:
            vocab.add_word(word)
    return vocab

# Build the vocabulary from the training captions (tokens seen at least 4 times)
# and pickle it so that the later cells can reload it from disk.
vocab_path = './data/vocab.pkl'
vocab = make_vocab(
    json='./data/annotations/captions_train2014.json',
    threshold=4,
)
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved vocabulary wrapper to '{}'".format(vocab_path))

In [None]:
# Resize the COCO images
def resize_image(image, size):
    """Return ``image`` resized to a ``size`` x ``size`` square.

    The aspect ratio is not preserved. Resampling uses ``Image.LANCZOS``:
    the ``Image.ANTIALIAS`` alias for that filter is no longer available in
    current Pillow releases.

    Args:
        image: Object exposing ``resize((width, height), resample)``,
            e.g. a PIL image.
        size: Edge length of the square output, in pixels.
    """
    return image.resize((size, size), Image.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize every file in ``image_dir`` and save it under ``output_dir``.

    Each output keeps the file name of its source, and the output format is
    inferred from that file name's extension.

    Args:
        image_dir: Directory holding the source images.
        output_dir: Destination directory; created (with parents) if missing.
        size: Edge length passed on to ``resize_image``.
    """
    # exist_ok avoids the check-then-create race and makes re-runs harmless.
    os.makedirs(output_dir, exist_ok=True)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        src_path = os.path.join(image_dir, image)
        # Sub-directories cannot be opened as images; skip them instead of crashing.
        if not os.path.isfile(src_path):
            continue
        # Read-only access is enough; the source file is never modified.
        with open(src_path, 'rb') as f:
            with Image.open(f) as img:
                resized = resize_image(img, size)
                resized.save(os.path.join(output_dir, image))
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))

# Shrink every downloaded training image to 256x256 and store the copies in a
# separate directory, which the training cell reads from.
image_size = 256
image_dir = './data/train2014/'
output_dir = './data/resized2014/'
resize_images(image_dir=image_dir, output_dir=output_dir, size=image_size)

In [None]:
# Custom COCO caption dataset
class CocoDataset(data.Dataset):
    """Dataset with one (image, caption) sample per COCO caption annotation."""

    def __init__(self, root, json, vocab, transform=None):
        """
        Args:
            root: Directory containing the image files.
            json: Path of the caption annotation file, handed to ``COCO``.
            vocab: Callable mapping a token to its integer id.
            transform: Optional callable applied to the loaded RGB image.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th annotation.

        ``target`` is a 1-D ``torch.Tensor`` (float) holding the ids of
        '<start>', the lower-cased caption tokens and '<end>'.
        """
        annotation = self.coco.anns[self.ids[index]]
        file_name = self.coco.loadImgs(annotation['image_id'])[0]['file_name']

        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        words = nltk.tokenize.word_tokenize(str(annotation['caption']).lower())
        token_ids = (
            [self.vocab('<start>')]
            + [self.vocab(word) for word in words]
            + [self.vocab('<end>')]
        )
        return image, torch.Tensor(token_ids)

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.ids)

In [None]:
# Batching
def collate_fn(data):
    """Merge a list of ``(image, caption)`` samples into one padded batch.

    The list is sorted in place by caption length, longest first.

    Args:
        data: List of ``(image tensor, 1-D caption tensor)`` pairs.

    Returns:
        images: The image tensors stacked along a new first dimension.
        targets: Long tensor of shape (batch, longest caption); shorter
            captions are right-padded with zeros.
    """
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, captions = zip(*data)
    images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = cap[:length]
    return images, targets

def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Create a DataLoader over a ``CocoDataset``.

    Batches are assembled with ``collate_fn`` and the last incomplete batch
    is dropped, so every batch has exactly ``batch_size`` samples.

    Args:
        root, json, vocab, transform: Forwarded to ``CocoDataset``.
        batch_size, shuffle, num_workers: Forwarded to ``DataLoader``.
    """
    dataset = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )

In [None]:
class Encoder(nn.Module):
    """Image encoder: ResNet-152 feature maps pooled under ``NM`` attention maps.

    The backbone is the pretrained torchvision ResNet-152 without its last two
    child modules. Its output is weighted by ``NM`` sigmoid attention maps,
    average-pooled per map, concatenated and projected to ``hidden_size``.
    """

    def __init__(self, hidden_size, NM=4):
        """
        Args:
            hidden_size: Size of the returned feature vector.
            NM: Number of attention maps.
        """
        super().__init__()
        resnet = models.resnet152(pretrained=True)
        backbone_layers = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*backbone_layers)
        # 1x1 convolution producing one attention logit per map and location.
        self.attn_conv = nn.Conv2d(2048, NM, 1, bias=False)
        nn.init.xavier_uniform_(self.attn_conv.weight)
        self.linear = nn.Linear(resnet.fc.in_features*NM, hidden_size)
        self.ln = nn.LayerNorm(hidden_size)
        self.NM = NM

    def forward(self, images):
        """Encode ``images`` into a (batch, hidden_size) tensor.

        The backbone runs under ``torch.no_grad()``, so no gradient reaches
        it; the attention convolution, linear layer and LayerNorm stay
        trainable. The backbone output is expected to have 2048 channels.
        """
        with torch.no_grad():
            conv_maps = self.resnet(images)

        gates = torch.sigmoid(self.attn_conv(conv_maps))
        batch, _, height, width = gates.shape

        # Broadcast (batch, 1, 2048, H, W) against (batch, NM, 1, H, W).
        gated = (conv_maps.reshape(batch, 1, 2048, height, width)
                 * gates.reshape(batch, self.NM, 1, height, width))

        # Average over the spatial dimensions separately for every attention map.
        pooled = F.adaptive_avg_pool2d(
            gated.reshape(batch*self.NM, 2048, height, width), (1, 1))

        # Concatenate the NM pooled vectors of each image, then project.
        return self.ln(self.linear(pooled.reshape(batch, -1)))

In [None]:
class Decoder(nn.Module):
    """Caption decoder: embedding -> single-head self-attention -> LSTM -> vocabulary logits."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, batch_size):
        """
        Args:
            embed_size: Width of the token embeddings and the attention layers.
            hidden_size: LSTM hidden size.
            vocab_size: Number of output classes.
            num_layers: Number of LSTM layers.
            batch_size: Batch dimension of the learned initial cell state.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        # Learned initial cell state; its shape fixes the batch size forward() accepts.
        self.cell = nn.Parameter(torch.zeros(num_layers, batch_size, hidden_size))
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.embed_size = embed_size
        # Query / key / value / output projections of the self-attention.
        self.liner_q = nn.Linear(embed_size, embed_size)
        self.liner_k = nn.Linear(embed_size, embed_size)
        self.liner_v = nn.Linear(embed_size, embed_size)
        self.atten_out = nn.Linear(embed_size, embed_size)

    def forward(self, captions, features, padmask=None):
        """Decode ``captions`` starting from the learned initial cell state.

        Same computation as ``decode`` with ``cell=self.cell``.
        """
        return self.decode(captions, features, self.cell, padmask)

    def attention(self, q, k, v, padmask=None):
        """Scaled dot-product attention with learned input and output projections.

        Positions where ``padmask == 0`` get a score of -1e9 before the softmax.
        """
        query = self.liner_q(q)
        key = self.liner_k(k)
        value = self.liner_v(v)

        scale = math.sqrt(self.embed_size)
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if padmask is not None:
            scores = scores.masked_fill(padmask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)

        return self.atten_out(torch.matmul(weights, value))

    def decode(self, captions, features, cell, padmask=None):
        """Decode ``captions`` from an explicit initial LSTM state.

        Args:
            captions: Token ids, indexed as (batch, sequence).
            features: Initial hidden state without the layer dimension;
                it is unsqueezed at dim 0 before being given to the LSTM.
            cell: Initial cell state given to the LSTM as-is.
            padmask: Optional attention mask; 0 marks positions to ignore.

        Returns:
            ``(outputs, state)``: the vocabulary logits for every position and
            the final ``(hidden, cell)`` state of the LSTM.
        """
        embedded = self.embed(captions)
        attended = self.attention(embedded, embedded, embedded, padmask)
        hidden0 = features.unsqueeze(0)
        hiddens, state = self.lstm(attended, (hidden0, cell))
        return self.linear(hiddens), state

In [None]:
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths
model_path = 'models/'
vocab_path = 'data/vocab.pkl'
image_dir = 'data/resized2014'
caption_path = 'data/annotations/captions_train2014.json'

# Model sizes
embed_size = 256
hidden_size = 512
num_layers = 1

# Data loading and optimisation
crop_size = 224
batch_size = 64
num_workers = 2
num_epochs = 10
learning_rate = 0.001

# Logging / checkpoint intervals, counted in training steps
log_step = 10
save_step = 1000

def subsequent_mask(size):
    """Return a boolean mask of shape (1, size, size) that hides later positions.

    Entry ``[0, i, j]`` is True when ``j <= i`` and False when ``j > i``.
    """
    # Ones strictly above the diagonal mark the positions to hide.
    hidden = np.triu(np.ones((1, size, size)), k=1)
    return torch.from_numpy(hidden == 0)

# Make sure the checkpoint directory exists (safe to re-run).
os.makedirs(model_path, exist_ok=True)

# Training-time preprocessing: random crop and flip, then tensor conversion and
# per-channel normalisation.
transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# NOTE: unpickling executes code stored in the file; only load a vocab.pkl
# that was produced by this notebook.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

data_loader = get_loader(image_dir, caption_path, vocab, transform, batch_size,
                         shuffle=True, num_workers=num_workers)

encoder = Encoder(hidden_size).to(device)
decoder = Decoder(embed_size, hidden_size, len(vocab), num_layers, batch_size).to(device)

# collate_fn right-pads captions with zeros, so 0 marks padding.
pad = 0

# Padded target positions must not contribute to the loss; otherwise the model
# is trained (and its perplexity measured) on predicting padding.
criterion = nn.CrossEntropyLoss(ignore_index=pad)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
# The learning rate is multiplied by 0.8 after every epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(data_loader):

        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and is trained to
        # predict tokens [1, T).
        caption = captions[:, :-1]
        targets = captions[:, 1:]

        # Attention mask: a position is visible only if its target is not
        # padding and it is not later than the querying position.
        pad_mask = (targets != pad).unsqueeze(1)
        pad_mask = pad_mask & subsequent_mask(targets.size(-1)).to(device)

        feature = encoder(images)
        outputs, state = decoder(caption, feature, pad_mask)

        # Flatten (batch, time, vocab) logits and (batch, time) targets for the loss.
        outputs = outputs.reshape(-1, outputs.shape[-1])
        targets = targets.reshape(-1)

        loss = criterion(outputs, targets)

        # The optimizer holds every encoder and decoder parameter.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

        # Save both networks together every save_step steps.
        if (i+1) % save_step == 0:
            torch.save(encoder.state_dict(), os.path.join(
                model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(decoder.state_dict(), os.path.join(
                model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))

    scheduler.step()

In [None]:
# Inference setup: rebuild the networks, load trained weights, prepare one test image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths
model_path = 'models/'
vocab_path = 'data/vocab.pkl'
image_dir = 'data/resized2014'
caption_path = 'data/annotations/captions_train2014.json'

# Model sizes
embed_size = 256
hidden_size = 512
num_layers = 1
batch_size = 64

with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

encoder = Encoder(hidden_size).to(device)
decoder = Decoder(embed_size, hidden_size, len(vocab), num_layers, batch_size).to(device)

# Checkpoints are read onto the CPU first; load_state_dict then copies the
# values into the networks that already live on `device`.
encoder.load_state_dict(
    torch.load('./models/encoder-10-6000.ckpt', torch.device('cpu')))
decoder.load_state_dict(
    torch.load('./models/decoder-10-6000.ckpt', torch.device('cpu')))
encoder.eval()
decoder.eval()

# Despite its name, test_img_dir is the path of a single image file.
test_img_dir = './data/train2014/COCO_train2014_000000000263.jpg'
test_img = Image.open(test_img_dir).convert('RGB')
test_img = test_img.resize([256, 256], Image.LANCZOS)

# No random crop or flip here: only tensor conversion and normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
test_img = transform(test_img)

In [None]:
# Prediction: greedy decoding of a caption for test_img
max_length = 20  # maximum number of words to generate

pad = 0
start_id = vocab('<start>')
end_id = vocab('<end>')
predicted_ids = []

with torch.no_grad():
    # decoder.cell has a fixed batch dimension, so the single test image is
    # repeated batch_size times; only the first row of the result is read.
    images = test_img.to(device).unsqueeze(0).repeat(batch_size, 1, 1, 1)
    features = encoder(images)

    # Tokens generated so far, one row per batch entry; starts with '<start>'.
    input_char = torch.tensor([[start_id]] * batch_size, device=device)

    for i in range(max_length):
        pad_mask = (input_char != pad).unsqueeze(1)
        pad_mask = pad_mask & subsequent_mask(input_char.size(-1)).to(device)

        # Re-run the decoder on the whole prefix from the same initial state
        # used in training (image features as hidden state, learned
        # decoder.cell as cell state). Feeding the full prefix lets the
        # self-attention see the previous tokens, as it does in
        # Decoder.forward.
        outputs, _ = decoder.decode(input_char, features, decoder.cell, pad_mask)

        # The logits at the last position predict the next token.
        output_char = outputs[:, -1].argmax(dim=-1)
        next_id = int(output_char[0])
        if next_id == end_id:
            break
        predicted_ids.append(next_id)
        input_char = torch.cat([input_char, output_char.unsqueeze(1)], dim=1)

# Map the predicted ids back to words.
caption = [vocab.idx2word[idx] for idx in predicted_ids]