In [None]:
import math
import pickle
import os
import numpy as np
from PIL import Image
from collections import Counter
from pycocotools.coco import COCO
import glob
import matplotlib.pyplot as plt
import nltk
nltk.download('punkt')
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
import torch.utils.data as data

In [None]:
!mkdir -p data
# The msvocds.blob.core.windows.net host no longer serves the COCO captions.
# The official annotations archive extracts to ./data/annotations/, which
# includes captions_train2014.json used by the cells below.
!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip -P ./data/
!wget http://images.cocodataset.org/zips/train2014.zip -P ./data/

# -q keeps unzip from printing one line per extracted file (~80k for train2014).
!unzip -q ./data/annotations_trainval2014.zip -d ./data/
!rm ./data/annotations_trainval2014.zip
!unzip -q ./data/train2014.zip -d ./data/
!rm ./data/train2014.zip

In [None]:
class Vocabulary(object):
    """Bidirectional mapping between caption tokens and integer ids.

    Ids are assigned sequentially, starting at 0, in the order words are first
    added. Calling the instance with an unknown word returns the id of '<unk>'.
    """

    def __init__(self):
        self.word2idx = {}  # token -> id
        self.idx2word = {}  # id -> token
        self.idx = 0        # next id to hand out

    def add_word(self, word):
        """Register ``word`` under the next free id; no-op if already known."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, falling back to the id of '<unk>'.

        Raises KeyError if ``word`` is unknown and '<unk>' was never added.
        """
        lookup = self.word2idx
        return lookup[word] if word in lookup else lookup['<unk>']

    def __len__(self):
        """Number of distinct registered words."""
        return len(self.word2idx)

def make_vocab(json, threshold):
    """Build a Vocabulary from every caption in a COCO captions file.

    Args:
        json: path to a COCO captions annotation file (loaded via ``COCO``).
        threshold: minimum occurrence count for a token to be kept.

    Returns:
        A Vocabulary whose ids 0-3 are '<pad>', '<start>', '<end>', '<unk>',
        followed by the kept tokens in first-seen order.
    """
    coco = COCO(json)
    ann_ids = coco.anns.keys()
    total = len(ann_ids)
    counts = Counter()
    for n, ann_id in enumerate(ann_ids, start=1):
        text = str(coco.anns[ann_id]['caption']).lower()
        counts.update(nltk.tokenize.word_tokenize(text))
        if n % 1000 == 0:
            print("[{}/{}] Tokenized captions.".format(n, total))

    vocab = Vocabulary()
    # Special tokens first so they get the fixed ids 0..3.
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)
    for word, cnt in counts.items():
        if cnt >= threshold:
            vocab.add_word(word)
    return vocab

# Build the vocabulary from the training captions and pickle it for the
# training and prediction cells.
vocab_path = './data/vocab.pkl'
vocab = make_vocab(json='./data/annotations/captions_train2014.json', threshold=4)
with open(vocab_path, 'wb') as vocab_file:
    pickle.dump(vocab, vocab_file)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved vocabulary wrapper to '{}'".format(vocab_path))

In [None]:
def resize_image(image, size):
    """Return a copy of the PIL ``image`` resized to ``size`` x ``size``.

    The aspect ratio is not preserved. Uses Lanczos resampling;
    ``Image.ANTIALIAS`` (previously used here) was an alias of LANCZOS and was
    removed in Pillow 10.
    """
    return image.resize((size, size), Image.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize every file in ``image_dir`` to ``size`` x ``size`` pixels.

    Each result is saved under the same file name in ``output_dir``, which is
    created if missing. Every entry in ``image_dir`` is assumed to be an image
    Pillow can open.

    Args:
        image_dir: directory containing the source images.
        output_dir: destination directory.
        size: side length in pixels of the square output.
    """
    os.makedirs(output_dir, exist_ok=True)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        # Read-only access is enough: the originals are never modified, and
        # 'r+b' would additionally require write permission on every file.
        with open(os.path.join(image_dir, image), 'rb') as f:
            with Image.open(f) as img:
                # resize() returns a new image whose .format is None, so take
                # the encoder format from the decoded source image.
                fmt = img.format
                resized = resize_image(img, size)
                resized.save(os.path.join(output_dir, image), fmt)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))

# Pre-resize the raw COCO training images once to 256x256 squares.
image_size = 256
image_dir = './data/train2014/'
output_dir = './data/resized2014/'
resize_images(image_dir=image_dir, output_dir=output_dir, size=image_size)

In [None]:
class CocoDataset(data.Dataset):
    """COCO captions dataset yielding one ``(image, token ids)`` pair per caption.

    Each annotation id is one sample, so an image with several captions
    appears once per caption.
    """

    def __init__(self, root, json, vocab, transform=None):
        """
        Args:
            root: directory containing the image files named in the annotations.
            json: path to the COCO captions annotation file.
            vocab: Vocabulary used to map tokens to ids.
            transform: optional callable applied to the RGB PIL image.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th caption.

        ``target`` is a 1-D int64 tensor: [<start>, tok_1, ..., tok_n, <end>].
        """
        coco = self.coco
        vocab = self.vocab
        ann = coco.anns[self.ids[index]]
        path = coco.loadImgs(ann['image_id'])[0]['file_name']

        # convert() returns a new, fully loaded image, so the source file can
        # be closed deterministically here instead of waiting for GC.
        with Image.open(os.path.join(self.root, path)) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        tokens = nltk.tokenize.word_tokenize(str(ann['caption']).lower())
        ids = [vocab('<start>')]
        ids.extend(vocab(token) for token in tokens)
        ids.append(vocab('<end>'))
        # Token ids are integer labels: store them as int64 instead of the
        # float32 that torch.Tensor(...) produced.
        target = torch.tensor(ids, dtype=torch.long)
        return image, target

    def __len__(self):
        """Number of captions (annotations) in the dataset."""
        return len(self.ids)

In [None]:
# Batching
def collate_fn_transformer(data, max_len=65):
    """Collate ``(image, caption)`` samples into fixed-size batch tensors.

    Samples are ordered by caption length, longest first. Captions are
    right-padded with 0 (the '<pad>' id) to exactly ``max_len`` tokens. Longer
    captions are truncated rather than raising a shape-mismatch error.

    Args:
        data: list of ``(image, caption)`` pairs. Images must share one shape;
            captions are 1-D token-id tensors. The list is not modified.
        max_len: padded caption length. The default 65 leaves 64 decoder
            input/target positions after the one-token teacher-forcing shift,
            matching PositionalEncoding's default ``maxlength`` of 64.

    Returns:
        ``(images, targets)``: images stacked along a new dim 0, and a
        ``(batch, max_len)`` int64 tensor of token ids.
    """
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    images, captions = zip(*batch)
    images = torch.stack(images, 0)

    targets = torch.zeros(len(captions), max_len, dtype=torch.long)
    for i, cap in enumerate(captions):
        end = min(len(cap), max_len)
        targets[i, :end] = cap[:end]
    return images, targets

def get_loader_transformer(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Build a DataLoader over ``CocoDataset`` batched by ``collate_fn_transformer``.

    The final incomplete batch is dropped (``drop_last=True``), so every batch
    contains exactly ``batch_size`` samples.
    """
    dataset = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_transformer,
        drop_last=True,
    )
    return torch.utils.data.DataLoader(dataset=dataset, **loader_kwargs)

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with input scaling and dropout.

    ``forward(x)`` returns ``dropout(sqrt(embed_size) * x + PE)``. ``PE[pos]``
    is the sin/cos table: sin on even channels, cos on odd channels.

    Args:
        embed_size: channel dimension of the inputs.
        dropout: dropout probability applied to the sum.
        maxlength: longest sequence that can be encoded.
        batch_first: if True (default), inputs are (batch, seq_len, embed_size).
            This matches the ``batch_first=True`` Transformer layers used in
            this notebook. If False, inputs are (seq_len, batch, embed_size).
    """

    def __init__(self, embed_size, dropout=0.1, maxlength=64, batch_first=True):
        super(PositionalEncoding, self).__init__()

        posemb = torch.zeros(maxlength, embed_size)
        pos = torch.arange(0, maxlength, dtype=torch.float).unsqueeze(1)
        div = torch.exp(-torch.arange(0, embed_size, 2).float() * (math.log(10000.0) / embed_size))
        posemb[:, 0::2] = torch.sin(pos * div)
        # For odd embed_size there is one fewer odd channel than even channels.
        posemb[:, 1::2] = torch.cos(pos * div)[:, :embed_size // 2]
        # Kept in the original (maxlength, 1, embed_size) layout: the buffer is
        # part of state_dict, so existing checkpoints still load.
        self.register_buffer('posemb', posemb.unsqueeze(1))
        self.dropout = nn.Dropout(dropout)
        self.embed_size = embed_size
        self.batch_first = batch_first

    def forward(self, x):
        # Index the *sequence* dimension. The previous code sliced by
        # x.size(0), which for batch-first input is the batch size, so each
        # sample got a single position's encoding on every token.
        seq_len = x.size(1) if self.batch_first else x.size(0)
        if seq_len > self.posemb.size(0):
            raise ValueError("sequence length {} exceeds maxlength {}".format(
                seq_len, self.posemb.size(0)))
        pe = self.posemb[:seq_len]  # (seq_len, 1, embed_size)
        if self.batch_first:
            pe = pe.transpose(0, 1)  # (1, seq_len, embed_size)
        x = math.sqrt(self.embed_size) * x + pe
        return self.dropout(x)

In [None]:
class CNNEncoder(nn.Module):
    """ResNet-50 feature extractor producing 64 image tokens of size ``embed_size``.

    The backbone runs under ``torch.no_grad()``, so only the 1x1 projection
    conv receives gradients. The output is (batch, 64, embed_size): the
    projected map is pooled to an 8x8 grid and flattened.

    NOTE(review): the backbone's BatchNorm layers still update their running
    statistics whenever the module is in train mode.
    """

    def __init__(self, embed_size, img_embed_size=64, num_head=4):
        # img_embed_size and num_head are accepted for API compatibility but unused.
        super(CNNEncoder, self).__init__()
        try:
            # torchvision >= 0.13: explicit weights enum. IMAGENET1K_V1 is what
            # the deprecated `pretrained=True` flag resolved to.
            weights = models.ResNet50_Weights.IMAGENET1K_V1
        except AttributeError:
            weights = None
        if weights is not None:
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=True)
        # Drop the global average pool and classifier; keep the spatial map.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        # Project ResNet-50's 2048 output channels to the model width.
        self.conv = nn.Conv2d(2048, embed_size, kernel_size=(1, 1), stride=(1, 1), bias=True)

    def forward(self, images):
        """Encode ``images`` (presumably a (B, 3, H, W) normalized RGB batch).

        Returns:
            Tensor of shape (B, 64, embed_size).
        """
        with torch.no_grad():
            features = self.resnet(images)
        features = self.conv(features)
        features = F.adaptive_avg_pool2d(features, (8, 8))
        # (B, C, 8, 8) -> (B, C, 64) -> (B, 64, C)
        return features.flatten(2).permute(0, 2, 1)

In [None]:
class Img2capTransformer(nn.Module):
    """Encoder-decoder Transformer mapping image tokens to caption-token logits.

    All attention layers use ``batch_first=True``: image features are expected
    as (batch, num_tokens, embed_size) and captions as (batch, seq_len) ids.
    One PositionalEncoding module is shared by both input streams.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, embed_size, vocab_size, dim_ffn=256, dropout=0.1, num_head=4):
        super(Img2capTransformer, self).__init__()
        # Token embedding and positional encoding
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, dropout)
        # Encoder over image tokens. `dropout` is forwarded so the constructor
        # argument controls the layers too (it was previously ignored there).
        encoder_layer = nn.TransformerEncoderLayer(embed_size, num_head, dim_ffn, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        # Decoder over caption tokens, cross-attending to the encoded image.
        decoder_layer = nn.TransformerDecoderLayer(embed_size, num_head, dim_ffn, dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        # Projection to vocabulary logits
        self.linear = nn.Linear(embed_size, vocab_size)

    def forward(self, captions, features, mask_src=None, mask_tgt=None, mask_memory=None, padding_mask_src=None, padding_mask_tgt=None, memory_key_padding_mask=None):
        """Return logits of shape (batch, seq_len, vocab_size) for ``captions``."""
        memory = self.encode(features, mask_src, padding_mask_src)
        decoded = self.decode(captions, memory, mask_tgt, mask_memory,
                              padding_mask_tgt, memory_key_padding_mask)
        return self.linear(decoded)

    def encode(self, features, mask_src=None, padding_mask_src=None):
        """Positionally encode image tokens and run the Transformer encoder."""
        return self.transformer_encoder(self.positional_encoding(features),
                                        mask=mask_src,
                                        src_key_padding_mask=padding_mask_src)

    def decode(self, captions, features, mask_tgt=None, mask_memory=None, padding_mask_tgt=None, memory_key_padding_mask=None):
        """Embed caption ids and run the decoder against encoder output ``features``.

        Returns decoder hidden states. ``self.linear`` is not applied.
        """
        return self.transformer_decoder(self.positional_encoding(self.embed(captions)),
                                        features,
                                        tgt_mask=mask_tgt,
                                        memory_mask=mask_memory,
                                        tgt_key_padding_mask=padding_mask_tgt,
                                        memory_key_padding_mask=memory_key_padding_mask)

In [None]:
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = 'transformermodels/'
crop_size = 224
vocab_path = 'data/vocab.pkl'
image_dir = 'data/resized2014'
caption_path = 'data/annotations/captions_train2014.json'
log_step = 10
save_step = 6000
embed_size = 512
num_epochs = 30
batch_size = 64
num_workers = 2
learning_rate = 0.001
num_encoder_layers = 4
num_decoder_layers = 4
dim_ffn = 512
dropout = 0.1
num_head = 8

os.makedirs(model_path, exist_ok=True)

# Random 224x224 crops plus horizontal flips, then the standard ImageNet
# mean/std normalization.
transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# vocab.pkl is written by this notebook; only unpickle files you trust.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

data_loader = get_loader_transformer(image_dir, caption_path, vocab, transform, batch_size,
                                     shuffle=True, num_workers=num_workers)

cnnencoder = CNNEncoder(embed_size=embed_size).to(device)
transformer = Img2capTransformer(num_encoder_layers, num_decoder_layers, embed_size, len(vocab),
                                 dim_ffn=dim_ffn, dropout=dropout, num_head=num_head).to(device)

# '<pad>' id; the batch collate function fills unused positions with 0.
pad = vocab('<pad>')
# Without ignore_index the loss (and perplexity) is dominated by the many
# padding positions after '<end>'.
criterion = nn.CrossEntropyLoss(ignore_index=pad)
params = list(cnnencoder.parameters()) + list(transformer.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

cnnencoder.train()
transformer.train()
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(data_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: the decoder reads positions 0..L-2 and is trained
        # to predict positions 1..L-1.
        caption = captions[:, :-1]
        targets = captions[:, 1:]

        feature = cnnencoder(images)

        seq_len = caption.size(1)
        # 2-D boolean causal mask (True = may not attend). It broadcasts over
        # batch and heads, so no hard-coded num_head * batch_size copy is needed.
        tgt_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
        # The padding mask describes the decoder *input* (caption), not targets.
        tgt_key_padding_mask = caption == pad

        outputs = transformer(caption, feature, mask_tgt=tgt_mask, padding_mask_tgt=tgt_key_padding_mask)

        loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_step == 0:
            # 1-based epoch, consistent with the checkpoint file names.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch + 1, num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

        if (i + 1) % save_step == 0:
            torch.save(cnnencoder.state_dict(), os.path.join(
                model_path, 'cnnencoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            torch.save(transformer.state_dict(), os.path.join(
                model_path, 'transformer-{}-{}.ckpt'.format(epoch + 1, i + 1)))

    scheduler.step()

In [None]:
# Prediction: rebuild the models and load a training checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = 'transformermodels/'
crop_size = 224
vocab_path = 'data/vocab.pkl'
image_dir = 'data/resized2014'
caption_path = 'data/annotations/captions_train2014.json'
# These must match the training cell exactly, or load_state_dict fails on
# shape mismatches. The checkpoints were trained with embed_size=512, not 256.
embed_size = 512
num_epochs = 30
num_encoder_layers = 4
num_decoder_layers = 4
dim_ffn = 512
dropout = 0.1
num_head = 8

with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

cnnencoder = CNNEncoder(embed_size=embed_size).to(device)
transformer = Img2capTransformer(num_encoder_layers, num_decoder_layers, embed_size, len(vocab),
                                 dim_ffn=dim_ffn, dropout=dropout, num_head=num_head).to(device)
cnnencoder.load_state_dict(torch.load(os.path.join(model_path, 'cnnencoder-1-6000.ckpt'), map_location=device))
transformer.load_state_dict(torch.load(os.path.join(model_path, 'transformer-1-6000.ckpt'), map_location=device))
cnnencoder.eval()
transformer.eval()

test_img_dir = './data/train2014/COCO_train2014_000000000138.jpg'
test_img = Image.open(test_img_dir).convert('RGB')
test_img = test_img.resize([256, 256], Image.LANCZOS)
# Training fed random 224x224 crops of 256x256 images; use a deterministic
# centre crop of the same size here so the backbone sees the same input size.
transform = transforms.Compose([
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
test_img = transform(test_img)

In [None]:
# Prediction
def predict(images, max_length, states=None):
    """Greedily decode a caption for a single image.

    Uses the notebook-level ``cnnencoder``, ``transformer`` and ``vocab``; the
    models are expected to already be in eval mode.

    Args:
        images: one image tensor without a batch dimension (presumably
            (3, H, W), as produced by the prediction transform).
        max_length: maximum number of tokens to generate.
        states: unused; kept for backward compatibility.

    Returns:
        The generated words joined by single spaces ('<start>'/'<end>'
        excluded). Returns an empty string if the first predicted token is '<end>'.
    """
    # Run on whatever device the model lives on.
    device = next(transformer.parameters()).device
    start_id = vocab('<start>')
    end_id = vocab('<end>')
    predicted_ids = []
    # No gradients are needed anywhere, including the CNN and encoder passes.
    with torch.no_grad():
        feature = cnnencoder(images.unsqueeze(0).to(device))
        memory = transformer.encode(feature)
        # (1, L) decoder input; grows by one token per step.
        input_ids = torch.tensor([[start_id]], device=device)
        for _ in range(max_length):
            seq_len = input_ids.size(1)
            # 2-D causal mask: independent of batch size and head count.
            tgt_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
            decoded = transformer.decode(input_ids, memory, mask_tgt=tgt_mask)
            # Only the newest position is needed for the next token.
            next_id = int(transformer.linear(decoded[:, -1]).argmax(dim=-1))
            if next_id == end_id:
                break
            predicted_ids.append(next_id)
            input_ids = torch.cat([input_ids, torch.tensor([[next_id]], device=device)], dim=1)
    return ' '.join(vocab.idx2word[idx] for idx in predicted_ids)