# Pythonで学ぶ画像認識　第６章 画像キャプショニング

## 第6.4節 アテンション機構による手法〜Show, attend and tellを実装してみよう

### ライブラリの準備

In [1]:
!pip install torch torchvision pycocotools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### 実行環境の設定

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/data/coco2014/val2014.zip

### モデル学習の前処理（辞書の準備）

In [4]:
import pickle
from pycocotools.coco import COCO
from collections import Counter

# データの保存先
fp_train_caption = '/content/drive/MyDrive/data/coco2014/captions_val2014.json'
fp_word_to_id = '/content/drive/MyDrive/6_image_captioning/vocab/word_to_id.pkl'
fp_id_to_word = '/content/drive/MyDrive/6_image_captioning/vocab/id_to_word.pkl'

# キャプションを読み込み
coco = COCO(fp_train_caption)
anns_keys = coco.anns.keys()

# 単語ーID対応表の作成
coco_token = []
for key in anns_keys:
    caption = coco.anns[key]['caption']
    tokens = caption.lower().split()
    coco_token.extend(tokens)

# ピリオド、カンマを削除
table = str.maketrans({"." : "",
                       "," : ""})
for k in range(len(coco_token)):
    coco_token[k] = coco_token[k].translate(table)

# 単語ヒストグラムを作成
freq = Counter(coco_token)

# 3回以上出現する単語を限定して辞書を作成
vocab = []
common = freq.most_common()
for t,c in common:
    if c >= 3:
        vocab.append(t)
sorted(vocab)

# 特殊トークンの追加
vocab.append('<start>') # 文書の始まりを表すトークンを追加
vocab.append('<end>') # 文書の終わりを表すトークンを追加
vocab.append('<unk>') # 辞書内に無い単語を表すトークンを追加
vocab.append('<null>') # 系列長を揃えるためのトークンを追加

# 単語ー単語ID対応表の作成
word_to_id = {t:i for i,t in enumerate(vocab)}
id_to_word = {i:t for i,t in enumerate(vocab)}

# ファイル出力
with open(fp_word_to_id, 'wb') as f:
    pickle.dump(word_to_id, f)
with open(fp_id_to_word, 'wb') as f:
    pickle.dump(id_to_word, f)

print('単語数: ' + str(len(word_to_id)))

loading annotations into memory...
Done (t=0.73s)
creating index...
index created!
単語数: 8583


### エンコーダの実装

In [5]:
import torch
from torch import nn
from torchvision import models
from torch.nn.utils.rnn import pack_padded_sequence

# GPUの設定
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

class EncoderCNN(nn.Module):
    ''' Show, attend and tellのエンコーダ
    encoded_image_size: 画像部分領域サイズ
    embedding_dim:      埋込みの次元
    '''
    def __init__(self, encoded_image_size: int, 
                 embedding_dim: int):
        super(EncoderCNN, self).__init__()
        self.enc_image_size = encoded_image_size

        # IMAGENET1K_V2で事前学習された
        # ResNet152モデルをバックボーンとする
        resnet = models.resnet152(weights="IMAGENET1K_V2") 
        
        # プーリング層と全結合層を削除
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # AdaptiveAvgPool2dで部分領域(14x14)を作成
        self.adaptive_pool = nn.AdaptiveAvgPool2d(
                                (encoded_image_size, 
                                 encoded_image_size))

    ''' エンコーダの順伝播
    images : 入力画像テンソル [バッチサイズ, チャネル数, 高さ, 幅]
    '''
    def forward(self, images: torch.Tensor):
        # 特徴抽出
        features = self.resnet(images) 
        features = self.adaptive_pool(features)

        # 並び替え -> [バッチサイズ, 14, 14, 2048]
        features = features.permute(0, 2, 3, 1)

        return features

### アテンション機構の実装

In [6]:
class Attention(nn.Module):
    ''' アテンション機構 (Attention mechanism)
    encoder_dim: エンコーダ出力の特徴次元
    decoder_dim: デコーダ出力の次元
    attention_dim: アテンション機構の次元
    '''
    def __init__(self, encoder_dim: int, 
                 decoder_dim: int, attention_dim: int):
        super(Attention, self).__init__()

        # z: エンコーダ出力を変換する線形層(Wz)
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)

        # h: デコーダ出力を変換する線形層(Wh)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)

        # e: アライメントスコアを計算するための線形層
        self.full_att = nn.Linear(attention_dim, 1)

        # α: アテンション重みを計算する活性化関数/ソフトマックス層
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    ''' Attentionの順伝播
    encoder_out: エンコーダ出力
    decoder_hidden: デコーダ隠れ状態の次元
    '''
    def forward(self, encoder_out: torch.Tensor, 
                decoder_hidden: torch.Tensor):

        # e: アライメントスコア
        att1 = self.encoder_att(encoder_out) # Wz * z
        att2 = self.decoder_att(decoder_hidden) # Wh * h_{t-1}
        att = self.full_att(
                self.relu(att1 + att2.unsqueeze(1))).squeeze(2) 

        # α: T個の部分領域ごとのアテンション重み
        alpha = self.softmax(att)

        # c: コンテキストベクトル
        context_vector = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)

        return context_vector, alpha

### アテンション機構付きデコーダの実装

In [18]:
class DecoderWithAttention(nn.Module):
    ''' アテンション機構 (Attention mechanism)付きデコーダネットワーク
    attention_dim: アテンション機構の次元
    embed_dim: 埋込み次元
    decoder_dim: デコーダの次元
    vocab_size: 辞書の次元
    encoder_dim: エンコーダ出力の特徴次元
    '''
    def __init__(self, attention_dim: int, embed_dim: int, 
                 decoder_dim: int, vocab_size: int, 
                 encoder_dim: int=2048, dropout: float=0.5):
        super(DecoderWithAttention, self).__init__()

        # パラメータ
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        # アテンション機構
        self.attention = Attention( encoder_dim, 
                                    decoder_dim, 
                                    attention_dim)

        # 単語の埋め込み
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=self.dropout)

        # LSTMセル
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, 
                                       decoder_dim, bias=True)

        # LSTM隠れ状態/メモリセルを初期化
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)

         # シグモイド活性化前の線形層
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        # 単語出力用の線形層
        self.fc = nn.Linear(decoder_dim, vocab_size)

        # 埋め込み層、全結合層の重みを初期化
        self.init_weights()
        
    '''
    デコーダの重みパラメータを初期化
    '''
    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    ''' 画像特徴の平均値で隠れ状態とメモリセルを初期化
    encoder_out: エンコーダ出力 [バッチサイズ, 14, 14, 2048]
    '''
    def init_hidden_state(self, encoder_out: torch.Tensor):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    ''' アテンション機構付きデコーダの順伝播
    encoder_out: エンコーダ出力 [バッチサイズ, 14, 14, 2048]
    encoded_captions: キャプション [バッチサイズ, 最大系列長]
    caption_lengths: 系列長 [バッチサイズ, 1]
    '''
    def forward(self, encoder_out: torch.Tensor, 
                encoded_captions: torch.Tensor,
                caption_lengths: list):
        # パラメータ
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # エンコーダ出力特徴の平坦化
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim) 
        num_pixels = encoder_out.size(1)

        # 単語埋込み
        embedded_captions = self.embedding(encoded_captions)

        # 隠れ状態ベクトル、メモリセルを初期化
        h, c = self.init_hidden_state(encoder_out)

        # 最大系列長（<end>を除く）
        caption_lengths = torch.tensor(caption_lengths)
        dec_lengths = (caption_lengths - 1).tolist()

        # キャプショニング結果を保持するためのテンソル
        predictions = torch.zeros(batch_size, 
                                    max(dec_lengths), 
                                    vocab_size).to(device)

        # アテンション重みを保持するためのテンソル
        alphas = torch.zeros(batch_size, 
                                max(dec_lengths), 
                                num_pixels).to(device)

        # センテンス予測処理
        for t in range(max(dec_lengths)):
            batch_size_t = sum([l > t for l in dec_lengths])

            # コンテキストベクトル, アテンション重み
            context_vector, alpha = self.attention(
                                        encoder_out[:batch_size_t],
                                        h[:batch_size_t])

            # LSTMセル
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            context_vector = gate * context_vector
            h, c = self.decode_step(
                torch.cat([embedded_captions[:batch_size_t, t, :],
                            context_vector], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))

            # 情報保持
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        # Show and tellの出力に合わせる
        encoded_captions = encoded_captions[:, 1:] 
        predictions = pack_padded_sequence(predictions, 
                                            dec_lengths, 
                                            batch_first=True)
        encoded_captions = pack_padded_sequence(encoded_captions, 
                                                dec_lengths, 
                                                batch_first=True)

        return predictions.data, encoded_captions.data, \
               dec_lengths, alphas

    ''' サンプリングによる説明文出力（ビームサーチ無し）
    features:   エンコーダ出力特徴 [バッチサイズ, 埋め込み次元]
    word_to_id: 単語->単語ID辞書
    id_to_word: 単語ID->単語辞書
    '''    
    def sample(self, feature: torch.Tensor, 
               word_to_id: list, id_to_word: list, 
               states=None):
        vocab_size = self.vocab_size

        # エンコーダ出力特徴の平坦化
        enc_image_size = feature.size(1)
        encoder_dim = feature.size(-1)
        feature = feature.view(1, -1, encoder_dim)
        num_pixels = feature.size(1)
        feature = feature.expand(1, num_pixels, encoder_dim)
        
        # 隠れ状態ベクトル、メモリセルを初期化
        h, c = self.init_hidden_state(feature)

        # センテンス生成の初期値として<start>を埋め込み
        id_start = word_to_id['<start>']
        prev_words = torch.LongTensor([[id_start]]).to(device) 

        # サンプリングによるセンテンス生成
        predictions = []
        alphas = []
        step = 1
        while True:
            # 単語埋め込み
            embedded_captions = self.embedding(prev_words).squeeze(1)

            # アテンション重み/コンテキストベクトルの計算
            context_vector, alpha = self.attention(feature,h)
            alpha = alpha.view(-1, enc_image_size, enc_image_size)  
            gate = self.sigmoid(self.f_beta(h))
            context_vector = gate * context_vector

            # デコード処理
            h, c = self.decode_step(
                torch.cat([embedded_captions, context_vector], 
                            dim=1), (h, c))

            preds = self.fc(self.dropout(h))
            preds = torch.nn.functional.log_softmax(preds)
            
            # 単語予測
            prob, predicted = preds.max(1)
            word = id_to_word[predicted.item()]

            # 予測結果とアテンション重みを保存
            predictions.append(predicted)
            alphas.append(alpha)

            # 次のタイムステップへ
            prev_words = torch.LongTensor(
                [predicted.item()]).to(device) 

            # 系列が長くなりすぎたらBreak
            if step > 50:
                break
            step += 1

        return predictions, alphas

### データローダの実装

In [8]:
import pickle
import random
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.utils.data.sampler import SubsetRandomSampler
from pycocotools.coco import COCO

'''
COCOデータセットローダ
batch_size:         バッチサイズ
word_to_id:         単語->単語ID辞書
fp_train_caption:   学習用のキャプション
fp_train_image_dir: 学習画像のパス
'''
def COCO_loader(batch_size: int, word_to_id: list, 
                fp_train_caption: str, 
                fp_train_image_dir: str):

    ''' トークナイザ
    文章(caption)を単語IDのリスト(tokens_id)に変換
    caption: 画像キャプション [バッチサイズ, 系列長]
    '''
    def tokenize_caption(caption: torch.Tensor):
        # 単語についたピリオド、カンマを削除
        tokens = caption.lower().split()
        tokens_temp = []
        for t in tokens:
            if t.endswith('.') and t != '.':
                tokens_temp.append(t.replace('.', ''))
            elif t.endswith(',') and t != ',':
                tokens_temp.append(t.replace(',', ''))
            elif t == '.' or t == ',':
                continue
            else:
                tokens_temp.append(t)
        tokens = tokens_temp        
        
        # 文章(caption)を単語IDのリスト(tokens_id)に変換
        tokens_ext = ['<start>'] + tokens + ['<end>']
        tokens_id = []
        for k in tokens_ext:
            if k in word_to_id:
                tokens_id.append(word_to_id[k])
            else:
                tokens_id.append(word_to_id['<unk>'])
        return torch.Tensor(tokens_id)

    '''
    COCOデータセットからデータを取り出すためのcollate関数
    '''
    def cap_collate_fn(data):
        images, captions = zip(*data)
        captions = [tokenize_caption(cap[random.randrange(len(cap))]) for cap in captions]
        
        data = zip(images, captions)
        data = sorted(data, key=lambda x: len(x[1]), reverse=True)
        images, captions = zip(*data)
        images = torch.stack(images, 0)

        lengths = [len(c) for c in captions]
        targets = torch.zeros(len(captions), max(lengths)).long()
        targets[:] = word_to_id['<null>']   # nullでパディング
        for i,c in enumerate(captions):
            end = lengths[i]
            targets[i,:end] = c[:end]
        return images, targets, lengths
 
    # 画像のtransformsを定義
    crop_size = (224,224)             # CNN入力画像サイズ
    in_mean = (0.485, 0.456, 0.406)   # ImageNetの平均値
    in_std = (0.229, 0.224, 0.225)    # ImageNetの標準偏差
    trans = transforms.Compose([
            transforms.Resize(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(in_mean, in_std) 
    ])

    # COCOデータロードの定義
    train_val_set = dset.CocoCaptions(root=fp_train_image_dir, 
                                        annFile=fp_train_caption, 
                                        transform=trans)
            
    # データサブセットを取得するサンプラーの定義
    # 学習データ70%、評価データ30%に分割
    n_samples = len(train_val_set)
    indices = list(range(n_samples))
    tr_split = int(0.7 * n_samples)      
    train_idx, val_idx = indices[:tr_split], indices[tr_split:]
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    # Dataloaderを生成
    train_loader = torch.utils.data.DataLoader(
                        train_val_set, 
                        batch_size=batch_size, 
                        num_workers=4, 
                        sampler=train_sampler,
                        collate_fn=cap_collate_fn)

    val_loader = torch.utils.data.DataLoader(
                        train_val_set, 
                        batch_size=batch_size, 
                        num_workers=4, 
                        sampler=val_sampler,
                        collate_fn=cap_collate_fn)
                                            
    return train_loader, val_loader

### Configクラスの実装

In [9]:
import os
import pickle

class Config(object):
    '''
    ハイパーパラメータ、グローバル変数の設定
    '''    
    def __init__(self):

        # ハイパーパラメータ（Show, attend and tell用）
        self.enc_image_size = 14    # Attention計算用画像サイズ
        self.attention_dim = 128    # Attention層の次元
        self.embedding_dim = 128    # 埋め込み層の次元
        self.hidden_dim = 128       # LSTM隠れ層の次元
        self.num_layers = 2         # LSTM階層の数
        self.max_seg_len = 30       # 最大系列長

        # ハイパーパラメータ（学習用）
        self.learning_rate = 0.001  # 学習率
        self.batch_size = 30        # ミニバッチの数
        self.num_epochs = 30        # エポック
        
        # グローバル変数
        self.fp_train_cap = '/content/drive/MyDrive/data/coco2014/captions_val2014.json'
        self.fp_train_image_dir = 'val2014'
        self.fp_word_to_id = '/content/drive/MyDrive/6_image_captioning/vocab/word_to_id.pkl'
        self.fp_id_to_word = '/content/drive/MyDrive/6_image_captioning/vocab/id_to_word.pkl'
        self.fp_model_dir = '/content/drive/MyDrive/6_image_captioning/model'

        # 辞書（単語→単語ID）の読み込み
        with open(self.fp_word_to_id, 'rb') as f:
            self.word_to_id = pickle.load(f)

        # 辞書（単語ID→単語）の読み込み
        with open(self.fp_id_to_word, 'rb') as f:
            self.id_to_word = pickle.load(f)

        # 辞書サイズ
        self.vocab_size = len(self.word_to_id)

        # モデル出力用のディレクトリ
        if not(os.path.isdir(self.fp_model_dir)):
            os.makedirs(self.fp_model_dir)

### 学習実装

In [10]:
import os
import pickle
import numpy as np
import datetime
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm
from torch.utils.data.dataset import Subset

'''
Show, attend and Tellの学習
'''
def train():

    # GPUの設定
    device = 'cuda' if torch.cuda.is_available() else 'cpu' 

    # ハイパーパラメータの設定
    cfg = Config()

    # 学習データの読み込み
    train_loader, valid_loader = COCO_loader(cfg.batch_size, 
                                    cfg.word_to_id, 
                                    cfg.fp_train_cap, 
                                    cfg.fp_train_image_dir)
    # モデルの定義
    encoder = EncoderCNN(cfg.enc_image_size, 
                         cfg.embedding_dim).to(device)
    decoder = DecoderWithAttention(cfg.attention_dim, 
                                    cfg.embedding_dim, 
                                    cfg.hidden_dim, 
                                    cfg.vocab_size).to(device)
    
    # 損失関数の定義
    criterion = nn.CrossEntropyLoss()

    # 最適化手法の定義
    params = list(decoder.parameters()) + \
             list(encoder.adaptive_pool.parameters())
    optimizer = torch.optim.AdamW(params, lr=cfg.learning_rate)

    # 学習経過の書き込み
    now = datetime.datetime.now()
    fp_train_loss_out = '{}/6-4_train_loss_{}.csv'\
        .format(cfg.fp_model_dir, now.strftime('%Y%m%d_%H%M%S'))
    fp_val_loss_out = '{}/6-4_val_loss_{}.csv'\
        .format(cfg.fp_model_dir, now.strftime('%Y%m%d_%H%M%S'))

    # 学習
    print("学習開始")
    val_loss_best = float('inf')
    for epoch in range(cfg.num_epochs):
        with tqdm(train_loader) as pbar:
            pbar.set_description("[Train epoch %d]" % (epoch + 1))
            train_losses = []
            for i, (images, captions,lengths) in enumerate(pbar):

                # 学習モード
                encoder.train()
                decoder.train()

                # ミニバッチを設定
                images, captions = \
                    images.to(device), captions.to(device)
                targets = pack_padded_sequence(captions, 
                                               lengths, 
                                               batch_first=True)[0]                
                optimizer.zero_grad()

                # Forward
                features = encoder(images)
                outputs, targets, decode_lengths, alphas = \
                    decoder(features, captions, lengths)
                loss = criterion(outputs, targets)

                # backward
                loss.backward()
                optimizer.step()

                # Training Lossをログに書き込み
                train_losses.append(loss.item())
                with open(fp_train_loss_out, 'a') as f:
                    print("{},{}".format(epoch, loss.item()), file=f)

        # Loss 表示
        print("Training loss: {}".format(np.average(train_losses)))

        # validation
        with tqdm(valid_loader) as pbar:
            pbar.set_description("[Validation %d]" % (epoch + 1))
            val_losses = []
            for j, (images, captions,lengths) in enumerate(pbar):

                # 評価モード
                encoder.eval()
                decoder.eval()

                # ミニバッチを設定
                images, captions = \
                    images.to(device), captions.to(device)
                targets = pack_padded_sequence(captions, 
                                               lengths, 
                                               batch_first=True)[0]

                features = encoder(images)
                outputs, targets, decode_lengths, alphas = \
                    decoder(features, captions, lengths)
                val_loss = criterion(outputs, targets)
                val_losses.append(val_loss.item())

                # Validation Lossをログに書き込み
                with open(fp_val_loss_out, 'a') as f:
                    print("{},{}".format(epoch, val_loss.item()), file=f)

        # Loss 表示
        val_loss = np.average(val_losses)
        print("Validation loss: {}".format(val_loss))

        # より良い検証結果が得られた場合、モデルを保存
        if val_loss < val_loss_best:
            val_loss_best = val_loss

            # エンコーダモデルを保存
            fp_encoder = '{}/6-4_encoder_best.pth'.format(cfg.fp_model_dir)
            torch.save(encoder.to('cpu').state_dict(), fp_encoder)
            encoder.to(device)

            # デコーダモデルを保存
            fp_decoder = '{}/6-4_decoder_best.pth'.format(cfg.fp_model_dir)
            torch.save(decoder.to('cpu').state_dict(), fp_decoder)
            decoder.to(device)
    
    print("学習終了")

if __name__ == '__main__':
    train()

loading annotations into memory...
Done (t=0.34s)
creating index...
index created!


Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

学習開始


[Train epoch 1]: 100%|██████████| 946/946 [03:04<00:00,  5.14it/s]


Training loss: 4.6005125449021325


[Validation 1]: 100%|██████████| 406/406 [00:31<00:00, 12.77it/s]


Validation loss: 3.8196538105386817


[Train epoch 2]: 100%|██████████| 946/946 [02:45<00:00,  5.71it/s]


Training loss: 3.765526287399437


[Validation 2]: 100%|██████████| 406/406 [00:32<00:00, 12.49it/s]


Validation loss: 3.4387045576067394


[Train epoch 3]: 100%|██████████| 946/946 [02:46<00:00,  5.68it/s]


Training loss: 3.529830376643207


[Validation 3]: 100%|██████████| 406/406 [00:32<00:00, 12.35it/s]


Validation loss: 3.2687661371794827


[Train epoch 4]: 100%|██████████| 946/946 [02:46<00:00,  5.68it/s]


Training loss: 3.3930501428769455


[Validation 4]: 100%|██████████| 406/406 [00:32<00:00, 12.48it/s]


Validation loss: 3.159297685317805


[Train epoch 5]: 100%|██████████| 946/946 [02:48<00:00,  5.63it/s]


Training loss: 3.3002687049970567


[Validation 5]: 100%|██████████| 406/406 [00:32<00:00, 12.53it/s]


Validation loss: 3.087300557808336


[Train epoch 6]: 100%|██████████| 946/946 [02:46<00:00,  5.67it/s]


Training loss: 3.217187118328651


[Validation 6]: 100%|██████████| 406/406 [00:32<00:00, 12.64it/s]


Validation loss: 3.0351133023576784


[Train epoch 7]: 100%|██████████| 946/946 [02:47<00:00,  5.64it/s]


Training loss: 3.1607824538777294


[Validation 7]: 100%|██████████| 406/406 [00:32<00:00, 12.58it/s]


Validation loss: 3.0016791033627364


[Train epoch 8]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 3.126956529395525


[Validation 8]: 100%|██████████| 406/406 [00:32<00:00, 12.58it/s]


Validation loss: 2.953347954844019


[Train epoch 9]: 100%|██████████| 946/946 [02:46<00:00,  5.67it/s]


Training loss: 3.0781797735181975


[Validation 9]: 100%|██████████| 406/406 [00:32<00:00, 12.61it/s]


Validation loss: 2.92090480433309


[Train epoch 10]: 100%|██████████| 946/946 [02:47<00:00,  5.63it/s]


Training loss: 3.052268370261404


[Validation 10]: 100%|██████████| 406/406 [00:32<00:00, 12.65it/s]


Validation loss: 2.906334676178805


[Train epoch 11]: 100%|██████████| 946/946 [02:46<00:00,  5.68it/s]


Training loss: 3.0153299623009517


[Validation 11]: 100%|██████████| 406/406 [00:32<00:00, 12.49it/s]


Validation loss: 2.884445573896023


[Train epoch 12]: 100%|██████████| 946/946 [02:46<00:00,  5.69it/s]


Training loss: 2.988496896328432


[Validation 12]: 100%|██████████| 406/406 [00:32<00:00, 12.61it/s]


Validation loss: 2.880455274887273


[Train epoch 13]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.963229777444997


[Validation 13]: 100%|██████████| 406/406 [00:32<00:00, 12.68it/s]


Validation loss: 2.8317553668186584


[Train epoch 14]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.940013558617606


[Validation 14]: 100%|██████████| 406/406 [00:32<00:00, 12.53it/s]


Validation loss: 2.8208623317074895


[Train epoch 15]: 100%|██████████| 946/946 [02:46<00:00,  5.67it/s]


Training loss: 2.9227093281755994


[Validation 15]: 100%|██████████| 406/406 [00:32<00:00, 12.58it/s]


Validation loss: 2.8147060066608374


[Train epoch 16]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.9089824850917116


[Validation 16]: 100%|██████████| 406/406 [00:32<00:00, 12.59it/s]


Validation loss: 2.8100951587038088


[Train epoch 17]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.8810201376739566


[Validation 17]: 100%|██████████| 406/406 [00:32<00:00, 12.56it/s]


Validation loss: 2.795070484353991


[Train epoch 18]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.8766321289110888


[Validation 18]: 100%|██████████| 406/406 [00:32<00:00, 12.61it/s]


Validation loss: 2.793437493845747


[Train epoch 19]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.85223215377356


[Validation 19]: 100%|██████████| 406/406 [00:32<00:00, 12.54it/s]


Validation loss: 2.782191715804227


[Train epoch 20]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.8419237250253464


[Validation 20]: 100%|██████████| 406/406 [00:32<00:00, 12.53it/s]


Validation loss: 2.7720210951537334


[Train epoch 21]: 100%|██████████| 946/946 [02:46<00:00,  5.68it/s]


Training loss: 2.829322222171324


[Validation 21]: 100%|██████████| 406/406 [00:32<00:00, 12.60it/s]


Validation loss: 2.7596496971957203


[Train epoch 22]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.8213256529471336


[Validation 22]: 100%|██████████| 406/406 [00:32<00:00, 12.47it/s]


Validation loss: 2.768274192152352


[Train epoch 23]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.8091993221016818


[Validation 23]: 100%|██████████| 406/406 [00:32<00:00, 12.53it/s]


Validation loss: 2.74632114406877


[Train epoch 24]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.7911395388980242


[Validation 24]: 100%|██████████| 406/406 [00:32<00:00, 12.50it/s]


Validation loss: 2.750002176009963


[Train epoch 25]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.78305109509958


[Validation 25]: 100%|██████████| 406/406 [00:32<00:00, 12.50it/s]


Validation loss: 2.751508169573516


[Train epoch 26]: 100%|██████████| 946/946 [02:46<00:00,  5.67it/s]


Training loss: 2.78392867457287


[Validation 26]: 100%|██████████| 406/406 [00:32<00:00, 12.63it/s]


Validation loss: 2.734361811518082


[Train epoch 27]: 100%|██████████| 946/946 [02:47<00:00,  5.64it/s]


Training loss: 2.7758394215122086


[Validation 27]: 100%|██████████| 406/406 [00:32<00:00, 12.54it/s]


Validation loss: 2.735458015808331


[Train epoch 28]: 100%|██████████| 946/946 [02:46<00:00,  5.68it/s]


Training loss: 2.7574122375464087


[Validation 28]: 100%|██████████| 406/406 [00:32<00:00, 12.63it/s]


Validation loss: 2.7402892500308935


[Train epoch 29]: 100%|██████████| 946/946 [02:47<00:00,  5.66it/s]


Training loss: 2.7499215683523746


[Validation 29]: 100%|██████████| 406/406 [00:32<00:00, 12.46it/s]


Validation loss: 2.7366784576124745


[Train epoch 30]: 100%|██████████| 946/946 [02:47<00:00,  5.65it/s]


Training loss: 2.7475697679197055


[Validation 30]: 100%|██████████| 406/406 [00:32<00:00, 12.57it/s]


Validation loss: 2.731278054232668
学習終了


### 推論（画像キャプショニング）

In [25]:
import torch
import torchvision.transforms as transforms
import glob
import os
import matplotlib.pyplot as plt
import skimage.transform
import matplotlib.cm as cm
from PIL import Image

''' 画像読み込み
image_file:   画像ファイル
transform:    画像変換
'''
def load_image(image_file: str, transform=None):
    image = Image.open(image_file)
    image = image.resize([224, 224], Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image

''' 
画像キャプショニングの推論
'''
def infer(fp_encoder: str, fp_decoder: str, fp_infer_image_dir: str):
  
    # GPUを利用
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Running in %s." % device)

    # パラメータ設定
    cfg = Config()
    
    # 画像の正規化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # エンコーダモデルの定義
    encoder = EncoderCNN(cfg.enc_image_size, cfg.embedding_dim)
    encoder = encoder.to(device).eval()

    # デコーダモデルの定義
    decoder = DecoderWithAttention(cfg.attention_dim, cfg.embedding_dim, cfg.hidden_dim, cfg.vocab_size)
    decoder = decoder.to(device).eval()

    # モデルの学習済み重みパラメータをロード
    encoder.load_state_dict(torch.load(fp_encoder), strict=False)
    decoder.load_state_dict(torch.load(fp_decoder), strict=False)
    print('エンコーダ: {}'.format(fp_encoder))
    print('デコーダ: {}'.format(fp_decoder))

    for image_file in sorted(glob.glob(os.path.join(fp_infer_image_dir, "*.jpg"))):

        # 画像読み込み
        print("ファイル名: {}".format(os.path.basename(image_file)))
        image = load_image(image_file, transform).to(device)

        # Encoder-decoderによる予測
        with torch.no_grad():

            # encoder
            feature = encoder(image)
            enc_image_size = feature.size(1)
            encoder_dim = feature.size(3)

            # decoder
            predictions, alphas = decoder.sample(feature, cfg.word_to_id, cfg.id_to_word)

        # 可視化
        sampled_caption = []
        word_len = len(predictions)
        image_plt = Image.open(image_file)
        image_plt = image_plt.resize([224, 224], Image.LANCZOS)
        plt.imshow(image_plt)
        plt.axis('off')
        plt.show()
        for t in range(word_len):

            # Attention重みを可視化
            cur_alpha = alphas[t]
            alpha = cur_alpha.to('cpu').numpy()
            alpha = skimage.transform.pyramid_expand(alpha[0, :, :], upscale=16, sigma=8)

            # キャプショニング
            word_id = predictions[t]
            word = cfg.id_to_word[word_id.item()]
            sampled_caption.append(word)

            # タイムステップtの画像をプロット
            plt.imshow(image_plt)
            plt.text(0, 1, '%s' % (word), color='black', backgroundcolor='white', fontsize=12)
            plt.imshow(alpha, alpha=0.8)
            plt.set_cmap(cm.Greys_r)
            plt.axis('off')
            plt.show()

            if word == '<end>':
                break
        
        sentence = ' '.join(sampled_caption)
        print ("  {}".format(sentence))

        # 推定結果を書き込み
        gen_sentence_out = image_file[:-4] + "_show_attend_and_tell.txt"
        with open(gen_sentence_out, 'w') as f:
            print("{}".format(sentence), file=f)

### 推論の実行

In [26]:
# 画像キャプショニング推論
fp_encoder = '/content/drive/MyDrive/6_image_captioning/model/6-4_encoder_best.pth'
fp_decoder = '/content/drive/MyDrive/6_image_captioning/model/6-4_decoder_best.pth'
fp_infer_image_dir = '/content/drive/MyDrive/data/image_captioning/'    

infer(fp_encoder, fp_decoder, fp_infer_image_dir)

Output hidden; open in https://colab.research.google.com to view.