# 사용법
1. 드라이브 연결
2. Tokenizer - KoBERT 밑의 셀 실행
3. 데이터 로드 -> Paths 밑의 셀에서 압축파일 **경로를 본인 환경에 맞게 바꾸고** 압축 해제 (사람마다 경로가 다를 수 있음)
4. 나머지 셀 순서대로 실행

## 중간에 끊겨도 상관없음!!

# 0. Connect Drive

In [None]:
# Mount Google Drive inside the Colab runtime so the project folder is reachable.
from google.colab import drive
drive.mount('/content/drive')

# Root folder of the project on the mounted drive; later cells build paths from it.
BASE_DIR = '/content/drive/MyDrive/image_caption'

Mounted at /content/drive


# 1. Tokenizer

### KoBERT

In [None]:
# mxnet, gluonnlp, sentencepiece
!pip install mxnet gluonnlp sentencepiece transformers

# kobert
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting mxnet
[?25l  Downloading https://files.pythonhosted.org/packages/30/07/66174e78c12a3048db9039aaa09553e35035ef3a008ba3e0ed8d2aa3c47b/mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9MB)
[K     |████████████████████████████████| 46.9MB 98kB/s 
[?25hCollecting gluonnlp
[?25l  Downloading https://files.pythonhosted.org/packages/9c/81/a238e47ccba0d7a61dcef4e0b4a7fd4473cb86bed3d84dd4fe28d45a0905/gluonnlp-0.10.0.tar.gz (344kB)
[K     |████████████████████████████████| 348kB 35.7MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/f5/99/e0808cb947ba10f575839c43e8fafc9cc44e4a7a2c8f79c60db48220a577/sentencepiece-0.1.95-cp37-cp37m-manylinux2014_x86_64.whl (1.2MB)
[K     |████████████████████████████████| 1.2MB 31.7MB/s 
[?25hCollecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d5/43/cfe4ee779bbd6a678ac6a97c5a5cdeb03c35f9eaebbb9720b036680f9a2d/transformers-4.6.1-py3-none-any.whl (2.2MB)
[K    

In [None]:
from gluonnlp.data import SentencepieceTokenizer
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model


class KoBERTTokenizer:
    """Bundle the KoBERT model, its vocabulary and a SentencePiece tokenizer.

    The rest of the notebook talks to this wrapper instead of the KoBERT
    objects directly (vocabulary lookups, sizes, pretrained embedding layer).
    """

    def __init__(self):
        self.type = 'KOBERT'
        bert_model, bert_vocab = get_pytorch_kobert_model()
        self.model = bert_model
        self.vocab = bert_vocab
        self.tokenizer = SentencepieceTokenizer(get_tokenizer())

    def tokenize(self, sentence: str):
        """Split ``sentence`` with the wrapped SentencePiece tokenizer."""
        pieces = self.tokenizer(sentence)
        return pieces

    def idx_to_token(self, idx):
        """Return the vocabulary token stored at position ``idx``."""
        index_table = self.vocab.idx_to_token
        return index_table[idx]

    def token_to_idx(self, token):
        """Return the vocabulary index of ``token``."""
        return self.get_word_map()[token]

    def get_word_map(self):
        """Return the vocabulary's token -> index mapping (not a copy)."""
        return self.vocab.token_to_idx

    def get_vocab_size(self):
        """Return the number of vocabulary entries (original note: 7002)."""
        word_map = self.get_word_map()
        return len(word_map)

    def get_embedding_dim(self):
        """Return ``embedding_dim`` of the first child of the model's embeddings (original note: 768)."""
        embedding_children = list(self.model.embeddings.children())
        return embedding_children[0].embedding_dim

    def get_pretrained_embedding(self):
        """Return the model's word-embedding layer itself (shared, not cloned)."""
        embeddings = self.model.embeddings
        return embeddings.word_embeddings


# Notebook-wide tokenizer instance; later cells read it as a global.
tokenizer = KoBERTTokenizer()

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]
using cached model


# 2. 데이터 로드


## Paths

In [None]:
# Sample data (12,000 images)
# !unzip -qq /content/drive/MyDrive/image_caption/data/train2014.zip
# !unzip -qq /content/drive/MyDrive/image_caption/data/val2014.zip

# Full-size data (120,000 images)
!tar -zxf /content/drive/MyDrive/image_caption/data/train2014.tar.gz
!tar -zxf /content/drive/MyDrive/image_caption/data/val2014.tar.gz

In [None]:
import os

# BASE_DIR = os.getcwd()

# Adjust this to match your own folder structure!
BASE_DIR = '/content/drive/MyDrive/image_caption'

# Folder holding the dataset JSON files read by CaptionDataset.
data_folder = os.path.join(BASE_DIR, 'data')

# root_dir = "/content/drive/MyDrive/0.졸업프로젝트_공유/1.data"
# Folder holding tokenizer resources (user dictionary, embedding models).
tokenizer_dir = os.path.join(BASE_DIR, 'tokenizer')

# model_dir = os.path.join(tokenizer_dir, "model")
# User dictionary file for the KOMORAN tokenizer (not used by the KoBERT path in this notebook).
komoran_dict = os.path.join(tokenizer_dir, "userdict.txt")

In [None]:
# Paths of Word2Vec embedding models, one per tokenizer, under tokenizer_dir.
# NOTE(review): none of these paths are read by the cells shown in this notebook.
komoran_w2v_model_path = os.path.join(tokenizer_dir, "KOMORAN_W2V.model")
okt_w2v_model_path = os.path.join(tokenizer_dir, "OKT_W2V.model")
mecab_w2v_model_path = os.path.join(tokenizer_dir, "MECAB_W2V.model")

# Paths of GloVe embedding models, one per tokenizer.
komoran_glove_model_path = os.path.join(tokenizer_dir, "KOMORAN_GLOVE_COLAB.model")
okt_glove_model_path = os.path.join(tokenizer_dir, "OKT_GLOVE_COLAB.model")
mecab_glove_model_path = os.path.join(tokenizer_dir, "MECAB_GLOVE.model")

## Dataset Class



In [None]:
import os
import json
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.

    The split is described by ``<mode>_data.json`` inside ``data_folder``; the
    file must provide the keys ``images``, ``captions``, ``caplens`` and
    ``all_captions``, each indexable by sample position.
    """

    def __init__(self, data_folder, mode):
        """
        :param data_folder: directory containing ``train_data.json`` / ``val_data.json``
        :param mode: ``'TRAIN'`` or ``'VAL'``
        :raises ValueError: if ``mode`` is anything else
        """
        super(CaptionDataset, self).__init__()

        if mode not in ['TRAIN', 'VAL']:
            # The original message ended with a stray double quote.
            raise ValueError(
                '"mode" must be either "TRAIN" or "VAL". Got "{}".'.format(mode))

        self.data_folder = data_folder
        self.caption_per_image = 5  # not read anywhere inside this class
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

        filename = '{}_data.json'.format(mode.lower())
        # Explicit encoding so loading does not depend on the platform locale.
        with open(os.path.join(data_folder, filename), 'r', encoding='utf-8') as fp:
            raw_data = json.load(fp)

        self.images = raw_data['images']
        self.captions = raw_data['captions']
        self.caplens = raw_data['caplens']
        self.all_captions = raw_data['all_captions']

        self.data_size = len(self.images)

    def __getitem__(self, i):
        """
        Return ``(image, caption, caplen, all_captions)`` for sample ``i``.

        The image is resized to 256x256 and converted to a tensor; the other
        three values are wrapped in ``torch.LongTensor``.
        """
        # NOTE(review): images are resolved against the relative folder 'data'
        # (current working directory), not ``self.data_folder`` — presumably
        # because the archives are extracted locally; confirm before changing.
        image_path = os.path.join('data', self.images[i])
        # Close the file handle deterministically instead of waiting for GC,
        # which matters when DataLoader workers open many images.
        with Image.open(image_path) as raw_image:
            image = self.transform(raw_image.convert('RGB'))

        caption = torch.LongTensor(self.captions[i])
        caplen = torch.LongTensor([self.caplens[i]])
        all_caption = torch.LongTensor(self.all_captions[i])

        return image, caption, caplen, all_caption

    def __len__(self):
        """Number of samples (entries in the JSON ``images`` list)."""
        return self.data_size

## 실행

In [None]:
# Loader settings shared by the training and validation loaders.
batch_size = 32
num_workers = 4

# Each split reads '<mode>_data.json' from data_folder.
train_dataset = CaptionDataset(data_folder, 'TRAIN')
val_dataset = CaptionDataset(data_folder, 'VAL')

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)

# 3. 모델링


## 3.1. Encoder

In [None]:
from torchvision.models import resnet50
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """ResNet-50 feature extractor producing a fixed-size, channels-last grid."""

    def __init__(self, encoded_image_size=14):
        """
        :param encoded_image_size: height and width of the pooled feature grid
        """
        super().__init__()
        self.enc_image_size = encoded_image_size

        # Keep everything except the last two children of the pretrained network.
        backbone = resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        pooled_shape = (encoded_image_size, encoded_image_size)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(pooled_shape)
        self.fine_tune()

    def forward(self, images):
        """
        Encode a batch of images.

        :param images: image batch accepted by the backbone
        :return: pooled feature map moved to channels-last layout, i.e.
                 (batch_size, encoded_image_size, encoded_image_size, channels)
        """
        feature_map = self.resnet(images)
        pooled = self.adaptive_pool(feature_map)
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self, fine_tune=True):
        """
        Freeze the first five backbone children; set the rest to ``fine_tune``.

        :param fine_tune: ``requires_grad`` value for children from index 5 on
        """
        for position, child in enumerate(self.resnet.children()):
            trainable = fine_tune if position >= 5 else False
            for param in child.parameters():
                param.requires_grad = trainable

## 3.2. Attention Module

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Additive attention over encoder pixels with a sigmoid gate on the result."""

    def __init__(self, encoder_feature_size, decoder_hidden_size, attention_size):
        """
        :param encoder_feature_size: feature size of each encoded pixel
        :param decoder_hidden_size: size of the decoder's hidden state
        :param attention_size: size of the shared attention space
        """
        super().__init__()
        self.encoder_feature_size = encoder_feature_size
        self.decoder_hidden_size = decoder_hidden_size
        self.attention_size = attention_size

        # Project pixels and hidden state into the same attention space.
        self.encoder_att = nn.Linear(encoder_feature_size, attention_size)
        self.decoder_att = nn.Linear(decoder_hidden_size, attention_size)
        # Produces the gate applied to the attended encoding.
        self.f_beta = nn.Linear(decoder_hidden_size, encoder_feature_size)
        # Collapses each pixel's attention vector to one score.
        self.full_att = nn.Linear(attention_size, 1)

    def forward(self, encoder_output, decoder_hidden):
        """
        :param encoder_output: (batch_size, num_pixels, encoder_feature_size)
        :param decoder_hidden: (batch_size, decoder_hidden_size)
        :return: gated, attention-weighted encoding of shape
                 (batch_size, encoder_feature_size); the softmax weights
                 themselves are not returned
        """
        pixel_proj = self.encoder_att(encoder_output)               # (batch, pixels, att)
        hidden_proj = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch, 1, att)

        # One score per pixel, normalised across the pixel axis.
        scores = self.full_att(F.relu(pixel_proj + hidden_proj)).squeeze(2)
        alpha = F.softmax(scores, dim=1)                             # (batch, pixels)

        # Weighted sum of pixel features.
        context = (encoder_output * alpha.unsqueeze(2)).sum(dim=1)   # (batch, feat)

        # Element-wise gate computed from the decoder hidden state.
        gate = torch.sigmoid(self.f_beta(decoder_hidden))
        return gate * context

## 3.3. Decoder

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder fed by a token embedding and an attended image encoding."""

    def __init__(self, attention, tokenizer, hidden_size, dropout=0.5):
        """
        :param attention: attention module; must expose ``encoder_feature_size``
        :param tokenizer: provides embedding dim, vocab size and the embedding layer
        :param hidden_size: LSTM hidden size
        :param dropout: dropout probability applied to the token embedding
        """
        super().__init__()
        self.embedding_size = tokenizer.get_embedding_dim()
        self.hidden_size = hidden_size
        self.output_size = tokenizer.get_vocab_size()

        # The embedding layer is the tokenizer's own module, not a copy.
        self.embedding = tokenizer.get_pretrained_embedding()
        self.attention = attention

        lstm_input_size = self.embedding_size + self.attention.encoder_feature_size
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, self.output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_outputs, captions, hidden, cell):
        """
        Run one decoding step.

        :param encoder_outputs: encoded image passed to the attention module
        :param captions: current input token per sample, shape [batch_size]
        :param hidden: previous hidden state, [batch_size, hidden_size]
        :param cell: previous cell state, [batch_size, hidden_size]
        :return: (vocabulary scores, new hidden, new cell, attention output);
                 the last item is whatever the attention module returned
        """
        context = self.attention(encoder_outputs, hidden)

        # [batch_size] -> [batch_size, 1] -> [batch_size, 1, embedding_size]
        token_embedding = self.dropout(self.embedding(captions.unsqueeze(1)))
        step_input = torch.cat([token_embedding, context.unsqueeze(1)], dim=2)

        # nn.LSTM expects states shaped [num_layers, batch_size, hidden_size].
        lstm_state = (hidden.unsqueeze(0), cell.unsqueeze(0))
        lstm_output, (next_hidden, next_cell) = self.lstm(step_input, lstm_state)

        scores = self.fc(lstm_output).squeeze(1)
        return scores, next_hidden.squeeze(0), next_cell.squeeze(0), context

    def init_hidden(self, batch_size, device='cpu'):
        """Return zero-filled (hidden, cell) states of shape [batch_size, hidden_size]."""
        state_shape = (batch_size, self.hidden_size)
        return torch.zeros(*state_shape, device=device), torch.zeros(*state_shape, device=device)

## 3.4. ImageCaptioner - Integrated Model

In [None]:
import torch
import random
from torch.nn.utils.rnn import pack_padded_sequence


class ImageCaptioner(nn.Module):
    """Encoder + attention + decoder wired together for caption training."""

    def __init__(self, tokenizer, encoder_feature_size, decoder_hidden_size, attention_size, dropout=0.5):
        """
        :param tokenizer: tokenizer wrapper handed to the decoder
        :param encoder_feature_size: channel size of the encoder output
        :param decoder_hidden_size: hidden size of the decoder LSTM
        :param attention_size: size of the attention space
        :param dropout: dropout probability used by the decoder
        """
        super(ImageCaptioner, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.encoder_feature_size = encoder_feature_size
        self.decoder_hidden_size = decoder_hidden_size
        self.attention_size = attention_size

        self.encoder = Encoder(encoded_image_size=16).to(self.device)
        self.attention = Attention(encoder_feature_size, decoder_hidden_size, attention_size).to(self.device)
        self.decoder = Decoder(self.attention, tokenizer, decoder_hidden_size, dropout=dropout).to(self.device)

    def forward(self, images, captions, caption_lengths, teacher_forcing_ratio=0.5):
        """
        Decode every caption in the batch step by step.

        :param images: image batch, [batch_size, 3, H, W]
        :param captions: padded token ids, [batch_size, cap_len]
        :param caption_lengths: unpadded caption lengths, [batch_size, 1]
        :param teacher_forcing_ratio: per-step probability of feeding the
            ground-truth token instead of the model's own argmax
        :return: ``(predictions, targets, decode_lengths, sort_ind)`` where
            ``predictions`` is [batch_size, cap_len, vocab_size] and column j
            scores the token at caption position j + 1, ``targets`` is
            ``captions[:, 1:]`` (sorted), ``decode_lengths`` is each sorted
            caption length minus one (number of predicted tokens) and
            ``sort_ind`` maps sorted rows back to the input order.
            All returned tensors are in length-descending order.
        """
        batch_size, cap_len = captions.size()

        # Flatten the spatial grid into a pixel axis. Use the model's own
        # feature size rather than a notebook-level global.
        encoder_outputs = self.encoder(images).reshape(batch_size, -1, self.encoder_feature_size)

        # Sort by caption length, longest first, as pack_padded_sequence requires.
        caption_lengths = caption_lengths.squeeze(1)  # [batch_size]
        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_outputs = encoder_outputs[sort_ind]
        captions = captions[sort_ind]

        hidden, cell = self.decoder.init_hidden(batch_size, device=self.device)
        sorted_lengths = caption_lengths.tolist()

        # Packing drops the padding: ``data`` holds the tokens time step by
        # time step and ``batch_sizes[t]`` is how many captions have a token
        # at position t.
        captions_packed = pack_padded_sequence(captions, sorted_lengths, batch_first=True)
        packed_tokens = captions_packed.data
        batch_sizes = captions_packed.batch_sizes.tolist()

        # Vocabulary size comes from the decoder, not from a global tokenizer.
        vocab_size = self.decoder.output_size
        predictions = torch.zeros(batch_size, cap_len, vocab_size, device=self.device)

        # ``offset`` is where the tokens of the current input step start
        # inside ``packed_tokens``.
        offset = 0
        decoder_input = packed_tokens[:batch_sizes[0]]

        for t in range(1, len(batch_sizes)):
            # Only captions that actually contain a token at position t are decoded.
            active = batch_sizes[t]
            encoder_outputs = encoder_outputs[:active]
            hidden = hidden[:active]
            cell = cell[:active]
            decoder_input = decoder_input[:active]

            decoder_output, hidden, cell, _ = self.decoder(encoder_outputs, decoder_input, hidden, cell)
            # This step predicts token t, which is column t - 1 of captions[:, 1:].
            predictions[:active, t - 1] = decoder_output

            # Tokens of step t start after ALL tokens of step t - 1; advancing
            # by the next step's size (as before) selected wrong tokens as soon
            # as the batch shrank.
            offset += batch_sizes[t - 1]

            teacher_forcing = random.random() < teacher_forcing_ratio
            decoder_input = packed_tokens[offset:offset + active] if teacher_forcing else decoder_output.argmax(1)

        # A caption of length L yields L - 1 predictions (the first token is
        # only ever an input). Captions must therefore have at least two
        # tokens before the result can be packed for a loss.
        decode_lengths = [length - 1 for length in sorted_lengths]

        return predictions, captions[:, 1:], decode_lengths, sort_ind

## 3.5. 모델 선언

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model parameters
encoder_feature_size = 2048
decoder_hidden_size = 1024
attention_size = 512
dropout = 0.5
learning_rate = 0.001
# decay_rate = 0.9

# These names are also read as globals by other cells (e.g. the prediction helpers).
model = ImageCaptioner(tokenizer, encoder_feature_size, decoder_hidden_size, attention_size, dropout=dropout).to(device)

# Token-level cross entropy; Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




# 4. 학습 및 검증

## 4.1. Train

In [None]:
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm


def train(model, train_loader, criterion, optimizer, device='cpu'):
    """
    Run one training epoch and return the mean batch loss.

    :param model: captioning model; called as
        ``model(images, captions, caplens, teacher_forcing_ratio=0.5)`` and
        expected to return ``(predictions, captions, lengths, sort_ind)``
    :param train_loader: iterable of ``(images, captions, caplens, _)`` batches
    :param criterion: loss applied to the packed predictions and targets
    :param optimizer: optimizer stepped once per batch
    :param device: device the batch tensors are moved to
    :return: summed batch losses divided by ``len(train_loader)``
    """
    model.train()
    running_loss = 0

    for images, captions, caplens, _ in tqdm(train_loader):
        images, captions, caplens = images.to(device), captions.to(device), caplens.to(device)

        optimizer.zero_grad()

        # The model returns its outputs re-ordered by caption length.
        scores, targets, decode_lengths, _ = model(
            images, captions, caplens, teacher_forcing_ratio=0.5)

        # Packing keeps only the first ``decode_lengths[i]`` steps of each row,
        # so padded positions do not contribute to the loss.
        packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        loss = criterion(packed_scores, packed_targets)
        running_loss += loss.item()

        loss.backward()
        # Clip the global gradient norm to 1 before the update.
        clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

    return running_loss / len(train_loader)

## 4.2. Validate

In [None]:
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction


def validate(model, val_loader, criterion, device='cpu'):
    """
    Evaluate the model on ``val_loader`` and return the mean batch loss.

    Decoding runs with ``teacher_forcing_ratio=0`` and gradients disabled.

    :param model: captioning model returning ``(predictions, captions, lengths, sort_ind)``
    :param val_loader: iterable of ``(images, captions, caplens, _)`` batches
    :param criterion: loss applied to the packed predictions and targets
    :param device: device the batch tensors are moved to
    :return: summed batch losses divided by ``len(val_loader)``
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for images, captions, caplens, _ in tqdm(val_loader):
            images, captions, caplens = images.to(device), captions.to(device), caplens.to(device)

            scores, targets, decode_lengths, _ = model(
                images, captions, caplens, teacher_forcing_ratio=0)

            # Drop padded steps from both sides before computing the loss.
            packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            total_loss += criterion(packed_scores, packed_targets).item()

    return total_loss / len(val_loader)

## 4.3. 실행

In [None]:
import os
import math
import time
import torch
from torch.autograd import profiler
from torch.utils.tensorboard import SummaryWriter


def epoch_time(start_time, end_time):
    """Split the span between two timestamps (seconds) into whole (minutes, seconds)."""
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - 60 * whole_minutes)
    return whole_minutes, leftover_seconds


model_dir = os.path.join(BASE_DIR, 'output')
checkpoint_path = os.path.join(model_dir, 'training-model-resnet50.pt')
model_path = os.path.join(model_dir, 'savepoint-resnet50.pt')

# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(model_dir, exist_ok=True)

# gc.collect()
# torch.cuda.empty_cache()

# Train Parameters
num_epochs = 100
patience = 20  # consecutive epochs without improvement tolerated before stopping
early_stop_counting = 0

try:
    # Load Checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Load model and optimizer
    model.load_state_dict(checkpoint.get('model_state'))
    optimizer.load_state_dict(checkpoint.get('optimizer'))
    # Load epoch and losses
    start_epoch = checkpoint.get('epoch') or 0
    train_loss = checkpoint.get('train_loss')
    val_loss = checkpoint.get('val_loss')

    # Compare against None explicitly: a legitimate loss of 0.0 is falsy too.
    if train_loss is None:
        train_loss = float('inf')
    if val_loss is None:
        # Checkpoints written right after the training phase carry no val_loss.
        val_loss = float('inf')

    # Older checkpoints lack these keys; fall back to the previous behaviour.
    best_valid_loss = checkpoint.get('best_valid_loss', val_loss)
    early_stop_counting = checkpoint.get('early_stop_counting', 0)
except (FileNotFoundError, RuntimeError):
    # No usable checkpoint: start from scratch.
    start_epoch = 0
    train_loss = float('inf')
    val_loss = float('inf')
    best_valid_loss = val_loss

print('Resume training from {} epoch, {:.4f} train loss, {:.4f} valid loss\n'.format(start_epoch, train_loss, val_loss))

for epoch in range(start_epoch, num_epochs):
    start_time = time.time()

    train_loss = train(model, train_loader, criterion, optimizer, device=device)
    # Checkpoint straight after training so an interruption during
    # validation does not lose the epoch.
    state_dict = {
        'epoch': epoch + 1,
        'model_state': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'best_valid_loss': best_valid_loss,
        'early_stop_counting': early_stop_counting,
    }
    torch.save(state_dict, checkpoint_path)

    val_loss = validate(model, val_loader, criterion, device=device)

    # Update the early-stopping bookkeeping before checkpointing so that a
    # resumed run continues with the same best loss and counter.
    improved = val_loss < best_valid_loss
    if improved:
        best_valid_loss = val_loss
        # Patience counts consecutive bad epochs, so an improvement resets it.
        early_stop_counting = 0
    else:
        early_stop_counting += 1

    state_dict = {
        'epoch': epoch + 1,
        'model_state': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'best_valid_loss': best_valid_loss,
        'early_stop_counting': early_stop_counting,
    }
    torch.save(state_dict, checkpoint_path)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(
        'Epoch: [{:02d}/{:02d}] | Time: {:02d}m {:02d}s | Train Loss: {:.4f} | Val. Loss: {:.4f} |' \
            .format(epoch + 1, num_epochs, epoch_mins, epoch_secs, train_loss, val_loss),
        end=' '
    )

    if improved:
        # save model output state
        state_dict = {
            'encoder': model.encoder.state_dict(),
            'decoder': model.decoder.state_dict(),
        }
        torch.save(state_dict, model_path)
        print('| Improvement!')
    else:
        print('|')

    if early_stop_counting >= patience:
        print('Early Stopping')
        break

## 4.4. 테스트 (BLEU 평가)

In [None]:
from nltk.translate.bleu_score import corpus_bleu


def test(model, test_dataset, test_data_size, device='cpu'):
    """
    Compute corpus-level BLEU over the first ``test_data_size`` samples.

    Each sample is decoded without teacher forcing; the per-step argmax tokens
    form the candidate and the sample's ``all_captions`` form the references.
    Special tokens are removed from both sides. Token lookups go through the
    notebook-level ``tokenizer``.

    :param model: captioning model returning ``(predictions, captions, lengths, sort_ind)``
    :param test_dataset: indexable dataset yielding ``(image, caption, caplen, all_captions)``
    :param test_data_size: number of leading samples to evaluate
    :param device: device the sample tensors are moved to
    :return: the value returned by ``corpus_bleu``
    """
    model.eval()

    # Loop-invariant: resolve the special-token ids once, as a set of plain
    # ints, instead of rebuilding a list for every sample.
    stop_words = ['[UNK]', '[CLS]', '[SEP]', '[PAD]']
    stop_words_idx = {tokenizer.token_to_idx(token) for token in stop_words}

    candidates = []
    references = []

    with torch.no_grad():
        for i in range(test_data_size):
            images, captions, caplens, all_captions = test_dataset[i]
            images = images.unsqueeze(0).to(device)
            captions = captions.unsqueeze(0).to(device)
            caplens = caplens.unsqueeze(0).to(device)

            predictions, _, _, _ = model(
                images, captions, caplens, teacher_forcing_ratio=0)

            _, top1 = torch.max(predictions, dim=2)
            # Plain Python ints make the membership test and the vocabulary
            # lookup cheap and independent of tensor comparison semantics.
            top1 = top1.squeeze(0).tolist()

            # Generated token ids -> tokens: the candidate
            prediction_translated = [
                tokenizer.idx_to_token(idx) for idx in top1
                if idx not in stop_words_idx
            ]

            # The image's reference captions -> tokens: the references
            all_captions_translated = [
                [
                    tokenizer.idx_to_token(idx) for idx in caption
                    if idx not in stop_words_idx
                ]
                for caption in all_captions.tolist()
            ]

            candidates.append(prediction_translated)
            references.append(all_captions_translated)

    # NOTE(review): ``emulate_multibleu`` appears to exist only in older NLTK
    # releases — confirm against the installed version before upgrading.
    return corpus_bleu(references, candidates, emulate_multibleu=True)

# BLEU lies in [0, 1]; report it as a percentage. The previous factor of 10
# could never produce the recorded output (19.34).
print('BLEU Score: {}'.format(100 * test(model, val_dataset, 100, device=device)))

BLEU Score: 19.34


# 5. 프로파일링

In [None]:
import os
import numpy as np
import torch
from torch import nn
import torch.profiler as profiler

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Synthetic batch: 32 random images, random token ids below 32 padded to
# length 64, and random caption lengths in [1, 63].
images = torch.rand((32, 3, 256, 256), device=device)
captions = torch.randint(32, (32, 64), device=device)
caplens = torch.randint(1, 64, (32, 1), device=device)

profile_dir = os.path.join(BASE_DIR, 'profile')
# The trace export below writes into this folder, so make sure it exists.
os.makedirs(profile_dir, exist_ok=True)

# Profile a single forward pass on both CPU and CUDA activities.
with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA
    ],
    with_stack=True,
    record_shapes=True,
    profile_memory=True
) as prof:
    # NOTE: this rebinds the notebook-level names captions / caplens to the
    # model's (sorted) outputs.
    predictions, captions, caplens, sort_ind = model(
        images,
        captions,
        caplens,
        teacher_forcing_ratio=0.5
    )

print(prof.key_averages(group_by_stack_n=10).table(sort_by='self_cuda_time_total'))
prof.export_chrome_trace(os.path.join(profile_dir, 'profile-resnet50.pt.trace.json'))

In [None]:
%load_ext tensorboard
%tensorboard --logdir '/content/drive/MyDrive/image_caption/profile'

# 6. 실전 예측

In [None]:
import operator
from queue import PriorityQueue
from tqdm import tqdm


class BeamSearchNode:
    """One hypothesis step in beam search, linked to its predecessor."""

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder state associated with this node (opaque here)
        :param previousNode: parent node, or None for the start node
        :param wordId: token id chosen at this step
        :param logProb: accumulated score of the path ending at this node
        :param length: number of tokens on the path, including this one
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        """Return the length-normalised score (``reward`` is a zero placeholder)."""
        reward = 0
        # The small epsilon keeps the start node (length 1) from dividing by zero.
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        """
        Order nodes so the better-scoring one sorts first.

        Needed because ``(score, node)`` tuples in a priority queue fall back
        to comparing the nodes when two scores are equal, which otherwise
        raises ``TypeError``.
        """
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.eval() > other.eval()


def predict_with_beam_search(input_data, model_class, model_path, tokenizer):
    """
    Caption one image with best-first beam search.

    :param input_data: a single preprocessed image tensor, [channels, height, width]
    :param model_class: model instance exposing ``encoder`` and ``decoder``
    :param model_path: checkpoint file holding ``'encoder'`` and ``'decoder'`` state dicts
    :param tokenizer: object providing ``token_to_idx`` and ``idx_to_token``
    :return: ``(sentences, attention_weights)``: ``sentences`` is a list of
        token lists (best first, '[CLS]'/'[SEP]' removed) and
        ``attention_weights`` holds, for the best sentence only, the decoder's
        attention output for each returned token, in the same order
    """
    import heapq
    import itertools

    encoder = model_class.encoder
    decoder = model_class.decoder

    checkpoint = torch.load(model_path, map_location='cpu')
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    model_class.eval()

    # Take device and sizes from the model itself instead of notebook globals.
    run_device = next(decoder.parameters()).device
    feature_size = decoder.attention.encoder_feature_size

    cls_idx = tokenizer.token_to_idx('[CLS]')
    sep_idx = tokenizer.token_to_idx('[SEP]')

    max_length = 50        # maximum number of generated tokens per hypothesis
    beam_width = 10        # children created per expanded node
    topk = 1               # how many sentences to return
    max_queue_size = 2000  # safety cap on the number of open hypotheses

    # A running counter breaks score ties so nodes themselves are never compared.
    tie_breaker = itertools.count()
    end_nodes = []

    with torch.no_grad():
        encoder_outputs = encoder(input_data.to(run_device).unsqueeze(0)).reshape(1, -1, feature_size)

        hidden, cell = decoder.init_hidden(1, device=run_device)
        start_token = torch.tensor([cls_idx], device=run_device)

        # Each node carries (hidden, cell, attention) of the step that produced
        # its token, so every hypothesis is expanded from its own state.
        root = BeamSearchNode((hidden, cell, None), None, start_token, 0, 1)
        frontier = [(-root.eval(), next(tie_breaker), root)]

        while frontier and len(frontier) <= max_queue_size:
            # fetch the best open hypothesis
            score, _, node = heapq.heappop(frontier)

            reached_end = node.prevNode is not None and node.wordid.item() == sep_idx
            if reached_end or node.leng > max_length:
                end_nodes.append((score, node))
                if len(end_nodes) >= topk:
                    break
                continue

            node_hidden, node_cell, _ = node.h
            output, new_hidden, new_cell, attn = decoder(
                encoder_outputs, node.wordid, node_hidden, node_cell)

            # The decoder emits raw scores; path scores must add log-probabilities.
            log_probs = torch.log_softmax(output, dim=1)
            top_log_probs, top_indices = torch.topk(log_probs, beam_width)
            attn_np = attn.detach().cpu().numpy()

            for k in range(beam_width):
                child = BeamSearchNode(
                    (new_hidden, new_cell, attn_np),
                    node,
                    top_indices[0][k].unsqueeze(0),
                    node.logp + top_log_probs[0][k].item(),
                    node.leng + 1,
                )
                heapq.heappush(frontier, (-child.eval(), next(tie_breaker), child))

    # Nothing finished: fall back to the best open hypotheses (never blocks).
    if not end_nodes:
        end_nodes = [(score, node) for score, _, node in heapq.nsmallest(topk, frontier)]

    decoded_batch = []
    attention_weights = []
    stop_ids = {cls_idx, sep_idx}

    for rank, (score, node) in enumerate(sorted(end_nodes, key=operator.itemgetter(0))):
        # back trace from the last node to the start node
        path = []
        while node is not None:
            path.append(node)
            node = node.prevNode
        path.reverse()

        kept = [step for step in path if step.wordid.item() not in stop_ids]
        decoded_batch.append([tokenizer.idx_to_token(step.wordid.item()) for step in kept])

        if rank == 0:
            # One attention entry per returned token of the best sentence.
            attention_weights = [step.h[2] for step in kept]

    return decoded_batch, attention_weights

In [None]:
def predict_with_greedy(input_data, model_class, model_path, tokenizer):
    """
    Caption one image by always taking the highest-scoring next token.

    :param input_data: a single preprocessed image tensor, [channels, height, width]
    :param model_class: model instance exposing ``encoder`` and ``decoder``
    :param model_path: checkpoint file holding ``'encoder'`` and ``'decoder'`` state dicts
    :param tokenizer: object providing ``token_to_idx`` and ``idx_to_token``
    :return: ``([tokens], attention_weights)`` — one token list (generation
        stops at '[SEP]' or after 50 tokens) and the decoder's attention
        output for each of those tokens
    """
    encoder = model_class.encoder
    decoder = model_class.decoder

    checkpoint = torch.load(model_path, map_location='cpu')
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    model_class.eval()

    # Take device and sizes from the model itself instead of notebook globals.
    run_device = next(decoder.parameters()).device
    feature_size = decoder.attention.encoder_feature_size

    max_length = 50
    tokens = []
    attention_weights = []

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        encoder_outputs = encoder(input_data.to(run_device).unsqueeze(0)).reshape(1, -1, feature_size)

        # start symbol
        decoder_input = torch.tensor([tokenizer.token_to_idx('[CLS]')], device=run_device)
        hidden, cell = decoder.init_hidden(1, device=run_device)

        for _ in range(max_length):
            output, hidden, cell, attn = decoder(encoder_outputs, decoder_input, hidden, cell)
            decoder_input = output.argmax(1)

            # Look the token up with a plain int rather than a tensor.
            token = tokenizer.idx_to_token(decoder_input.item())
            if token == '[SEP]':
                break

            tokens.append(token)
            attention_weights.append(attn.detach().cpu().numpy())

    return [tokens], attention_weights

In [None]:
!apt-get update -qq
!apt-get install fonts-nanum* -qq

In [None]:
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as mfm


def plot_attention(image, tokens, attention_weights):
    """
    Draw one subplot per token: the image with a 16x16 overlay on top.

    :param image: image tensor laid out as [channels, height, width]
    :param tokens: generated tokens, used as subplot titles
    :param attention_weights: one array per token; each is forced to 16x16
        with ``np.resize``, which repeats or truncates the flattened values
    """
    # matplotlib expects [height, width, channels]
    image = image.permute(1, 2, 0)

    token_count = len(tokens)
    columns = 5
    rows = math.ceil(token_count / columns)

    figure = plt.figure(figsize=(20, 20))
    # Nanum font so Korean tokens render in the titles.
    prop = mfm.FontProperties(fname='/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf', size=14)

    for position, token in enumerate(tokens):
        # NOTE(review): in this notebook the decoder hands back the gated
        # context vector (encoder_feature_size values), not the softmax
        # weights, so this overlay shows its first 256 entries — confirm intent.
        overlay = np.resize(attention_weights[position], (16, 16))
        axis = figure.add_subplot(rows, columns, position + 1)
        axis.set_title(token, fontproperties=prop)
        shown = axis.imshow(image)
        # Stretch the overlay to the extent of the displayed image.
        axis.imshow(overlay, cmap='gray', alpha=0.6, extent=shown.get_extent())

    plt.tight_layout()
    plt.show()

In [None]:
import os
import glob
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model parameters
encoder_feature_size = 2048
decoder_hidden_size = 1024
attention_size = 512
dropout = 0.5

# Fresh model instance; the prediction helpers load the saved weights into it.
model = ImageCaptioner(tokenizer, encoder_feature_size, decoder_hidden_size, attention_size, dropout=dropout).to(device)

# NOTE(review): BASE_DIR is rebound to the current working directory here, and
# model_path points at 'output/model/...', whereas the training cell saves to
# '<BASE_DIR>/output/savepoint-resnet50.pt' on the mounted drive — confirm the
# checkpoint location before running.
BASE_DIR = os.getcwd()
sample_image_folder = os.path.join(BASE_DIR, 'data', 'sample')
sample_images = glob.glob(os.path.join(sample_image_folder, '*'))
model_path = os.path.join(BASE_DIR, 'output', 'model', 'savepoint-resnet50.pt')

# Same preprocessing as CaptionDataset: resize to 256x256 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# [channels, height, width]
images = [transform(Image.open(img).convert('RGB')) for img in sample_images]
num_images = len(images)

# Caption every sample image with both decoding strategies and plot each result.
# Each predict_* call reloads the checkpoint from model_path.
for i in range(num_images):
    tokens, attn_weights = predict_with_greedy(images[i], model, model_path, tokenizer)
    plot_attention(images[i], tokens[0], attn_weights)

    tokens, attn_weights = predict_with_beam_search(images[i], model, model_path, tokenizer)
    plot_attention(images[i], tokens[0], attn_weights)