In [1]:
import os
import argparse
import multiprocessing
import numpy as np
import random
import time
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import yaml
import PIL
from tqdm import tqdm
from checkpoint import (
    default_checkpoint,
    load_checkpoint,
    save_checkpoint,
    init_tensorboard,
    write_tensorboard,
)
from psutil import virtual_memory

from flags import Flags
from utils import get_network, get_optimizer
from dataset import dataset_loader, START, PAD,load_vocab
from scheduler import CircularLRBeta

from metrics import word_error_rate,sentence_acc

In [2]:
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument(
        "-c",
        "--config_file",
        dest="config_file",
        default="configs/SATRN.yaml",
        type=str,
        help="Path of configuration file",
    )
parser = parser.parse_args(args=[])

options = Flags(parser.config_file).get()

In [4]:
is_cuda = torch.cuda.is_available()
hardware = "cuda" if is_cuda else "cpu"
device = torch.device(hardware)

In [5]:
checkpoint = (
        load_checkpoint(options.checkpoint, cuda=is_cuda)
        if options.checkpoint != ""
        else default_checkpoint
    )
model_checkpoint = checkpoint["model"]

In [6]:
transformed = transforms.Compose(
        [
            # Resize so all images have the same size
            transforms.Resize((options.input_size.height, options.input_size.width)),
            transforms.ToTensor(),
        ])

train_data_loader, validation_data_loader, train_dataset, valid_dataset = dataset_loader(options, transformed)

In [7]:
data_ex = next(iter(train_dataset))
print(data_ex)

{'path': '/opt/ml/input/data/train_dataset/images/train_77821.jpg', 'truth': {'text': '\\therefore \\overline { B H } = 3 + 3 = 6 \\left( k m \\right)', 'encoded': [0, 191, 196, 224, 237, 166, 213, 180, 111, 108, 111, 180, 15, 229, 169, 149, 137, 1]}, 'image': tensor([[[0.7569, 0.7608, 0.7608,  ..., 0.6941, 0.6941, 0.6902],
         [0.7529, 0.7569, 0.7647,  ..., 0.7020, 0.6902, 0.6941],
         [0.7529, 0.7608, 0.7569,  ..., 0.6941, 0.6902, 0.6980],
         ...,
         [0.7647, 0.7686, 0.7765,  ..., 0.6667, 0.6667, 0.6667],
         [0.7647, 0.7686, 0.7765,  ..., 0.6745, 0.6627, 0.6588],
         [0.7647, 0.7686, 0.7725,  ..., 0.6706, 0.6667, 0.6627]]])}


In [8]:
data = next(iter(train_data_loader))
input = data["image"].cuda()
expected = data["truth"]["encoded"].cuda()

In [9]:
print(input.shape)
print(data['truth']['text'])
print(len(data['truth']['text'][0].split(' ')))

torch.Size([36, 1, 128, 128])
['\\left( \\sqrt \\left[ n \\right] { a } \\right) ^ { m } = \\sqrt \\left[ n \\right] { a ^ { m } }', '= \\frac { F _ { f } } { 2 }', '5 6 - 4 2 = 1 4', '4 9 \\times \\frac { 3 } { 7 } =', '2 \\times \\left( - 2 \\right) + 3', '4 \\times 1 0 ^ { 1 0 }', '\\therefore \\angle x = 4 5 ^ { \\circ }', '1 8 . 6 \\times 4 + 7 \\times 2', 'a ^ { 2 } + b ^ { 2 } = 2 1 , a - b = 3', '\\alpha = \\frac { \\sum \\tau _ { e x t } } { I } = \\frac { 2 \\left( m _ { f } - m _ { d } \\right) g \\cos \\theta } { l \\left[ \\left( M / 3 \\right) + m _ { f } + m _ { d } \\right] }', 'x = \\frac { - \\left( b \\right) \\pm \\sqrt { \\left( b \\right) ^ { 2 } - 4 a c } } { 2 a }', '\\overline { P _ { 1 } P _ { 2 } }', '\\frac { 3 \\sqrt { 3 } } { 4 }', '\\frac { 5 } { 1 2 } \\times 4 =', '{ C H } = h \\tan 3 2 ^ { \\circ } \\left( { m } \\right)', '\\left( x - 4 \\right) \\left( x - 8 \\right) = 0', 'n \\to \\infty', 'x = - \\frac { 5 } { 3 }', ', m x - y + m + 1 = 0', '1 - 0 

In [12]:
print(expected)
print(len(expected[0]))

tensor([[  0, 229,  13,  ...,  -1,  -1,  -1],
        [  0, 180,  82,  ...,  -1,  -1,  -1],
        [  0,  70,  15,  ...,  -1,  -1,  -1],
        ...,
        [  0,  96, 180,  ...,  -1,  -1,  -1],
        [  0, 204, 180,  ...,  -1,  -1,  -1],
        [  0, 178,  10,  ...,  -1,  -1,  -1]], device='cuda:0')
61


In [13]:
expected[expected == -1] = train_data_loader.dataset.token_to_id[PAD]

## Encoder

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random

from dataset import START, PAD
from networks.SATRN import TransformerEncoderFor2DFeatures, TransformerDecoder

In [16]:
input_size=options.data.rgb
hidden_dim=options.SATRN.encoder.hidden_dim
filter_size=options.SATRN.encoder.filter_dim
head_num=options.SATRN.encoder.head_num
layer_num=options.SATRN.encoder.layer_num
dropout_rate=options.dropout_rate

In [17]:
encoder = TransformerEncoderFor2DFeatures(input_size, hidden_dim, filter_size, head_num, layer_num, dropout_rate)
encoder.cuda()

TransformerEncoderFor2DFeatures(
  (shallow_cnn): DeepCNN300(
    (conv0): Conv2d(1, 48, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (block1): DenseBlock(
      (block): Sequential(
        (0): BottleneckBlock(
          (norm1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv1): Conv2d(48, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(72, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (1): BottleneckBlock(
          (norm1): BatchNorm2d(72, eps=1e-05, 

In [18]:
encoder_result = encoder(input)
encoder_result.shape

torch.Size([36, 256, 300])

## Decoder

In [19]:
from networks.SATRN import PositionEncoder1D, TransformerDecoderLayer

In [20]:
num_classes=len(train_dataset.id_to_token)
src_dim=options.SATRN.decoder.src_dim
hidden_dim=options.SATRN.decoder.hidden_dim
filter_dim=options.SATRN.decoder.filter_dim
head_num=options.SATRN.decoder.head_num
dropout_rate=options.dropout_rate
pad_id=train_dataset.token_to_id[PAD]
st_id=train_dataset.token_to_id[START]
layer_num=options.SATRN.decoder.layer_num

In [21]:
embedding = nn.Embedding(num_classes + 1, hidden_dim)
embedding.cuda()

Embedding(246, 128)

In [22]:
pos_encoder = PositionEncoder1D(
            in_channels=hidden_dim, dropout=dropout_rate
        )
pos_encoder.cuda()

PositionEncoder1D(
  (dropout): Dropout(p=0.1, inplace=False)
)

In [23]:
attention_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    hidden_dim, src_dim, filter_dim, head_num, dropout_rate
                )
                for _ in range(layer_num)
            ]
        )
attention_layers.cuda()

ModuleList(
  (0): TransformerDecoderLayer(
    (self_attention_layer): MultiHeadAttention(
      (q_linear): Linear(in_features=128, out_features=128, bias=True)
      (k_linear): Linear(in_features=128, out_features=128, bias=True)
      (v_linear): Linear(in_features=128, out_features=128, bias=True)
      (attention): ScaledDotProductAttention(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (out_linear): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (self_attention_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (attention_layer): MultiHeadAttention(
      (q_linear): Linear(in_features=128, out_features=128, bias=True)
      (k_linear): Linear(in_features=300, out_features=128, bias=True)
      (v_linear): Linear(in_features=300, out_features=128, bias=True)
      (attention): ScaledDotProductAttention(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (out_linear): Linear

In [24]:
generator = nn.Linear(hidden_dim, num_classes)
generator.cuda()

Linear(in_features=128, out_features=245, bias=True)

In [25]:
pad_id = pad_id
st_id = st_id

In [26]:
def pad_mask(text):
        pad_mask = text == pad_id
        pad_mask[:, 0] = False
        pad_mask = pad_mask.unsqueeze(1)

        return pad_mask

def order_mask(length):
        order_mask = torch.triu(torch.ones(length, length), diagonal=1).bool()
        order_mask = order_mask.unsqueeze(0).to(device)
        return order_mask

def text_embedding(texts):
        tgt = embedding(texts)
        tgt *= math.sqrt(tgt.size(2))

        return tgt

## Beam Search decode

In [154]:
import copy

In [155]:
class BeamSearchNode(object):
    def __init__(self, src_batch, previousNode, target_batch, features, logProb, length):
        self.src_batch = src_batch
        self.prevNode = previousNode
        self.target_batch = target_batch
        self.features = features
        self.logp = logProb
        self.leng = length
        
    def eval(self, alpha = 1.0):
        reward = 0
        
        return self.logp / float(self.leng - 1 +1e6) + alpha * reward

In [156]:
src = encoder_result
text = expected[:, :-1]
is_train = False
batch_max_length = 10
teacher_forcing_ratio = 0

In [157]:
beam_width = 3
topk = 1  # how many sentence do you want to generate
decoded_batch = []

In [158]:
# each sentence
for idx in range(src.size(0)):
    
    out = []
    num_steps = batch_max_length - 1
    target = torch.LongTensor(src.size(0)).fill_(st_id).to(device) # [START] token
    features = [None] * layer_num
    
    target_batch = target[idx].unsqueeze(-1)
    target_batch = target_batch.unsqueeze(-1)
    src_batch = src[idx,:,:].unsqueeze(0)
    
    endnodes = []
    number_required = 3
    
    node = BeamSearchNode(src_batch, None, target_batch, features, 0, 1)
    nodes = PriorityQueue()
     
    nodes.put((-node.eval(), node))
    flag = 0
    
    # each step 
    for t in range(num_steps):
        
        if flag == 1:
            break
            
        nextnodes = []
        
        # ecah beam search candidate
        while nodes.queue != []:
            
            score, n = nodes.get()
            target_batch = n.target_batch
            src_batch = n.src_batch
            #features = n.features
            if n.features[0] != None:
                features = []
                for i in range(len(n.features)):
                    with torch.no_grad():
                        features.append(copy.deepcopy(n.features[i].detach()))
            else:
                features = n.features
                
            tgt = text_embedding(target_batch)
            tgt = pos_encoder(tgt, point=t)
            tgt_mask = order_mask(t + 1)
            tgt_mask = tgt_mask[:, -1].unsqueeze(1) 
            
            if n.target_batch.item() == 1 and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    flag = 1
                    break
                else:
                    continue
                    
            # each attention layer 
            for l, layer in enumerate(attention_layers):
                tgt = attention_layers[l](tgt, features[l], src_batch, tgt_mask)
                features[l] = (tgt if features[l] == None else torch.cat([features[l], tgt], 1))
            
            _out = generator(tgt)
            log_prob, indexes = torch.topk(_out, beam_width)
            
            for new_k in range(beam_width):
                decoded_t = indexes[0][0][new_k].view(1, -1)
                log_p = log_prob[0][0][new_k].item()

                node = BeamSearchNode(src_batch, n, decoded_t, features, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))
            
        if flag == 0:
            #sorting nextnode
            sorted(nextnodes, key=operator.itemgetter(0))
            # put them into queue
            for i in range(beam_width):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
        
    
    if len(endnodes) == 0:
        endnodes = [nodes.get() for _ in range(1)]
    
    utterances = []
    for score, n in sorted(endnodes, key=operator.itemgetter(0)):
        utterance = []
        utterance.append(n.target_batch)
        # back trace
        while n.prevNode != None:
            n = n.prevNode
            utterance.append(n.target_batch)

        utterance = utterance[:0:-1]
        utterances.append(utterance)
        
    
    seq_len = len(utterances[0])
    if num_steps - seq_len > 0:
        for i in range(num_steps - seq_len):
            utterances[0].append(torch.LongTensor(1).fill_(2).to(device))
    #print(len(utterances[0]))
    decoded_batch.append(utterances)
    #print(decoded_batch[-1])
torch.tensor(decoded_batch)

tensor([[[  0,  49,  69,  69,  70,  53, 164,  41, 137]],

        [[  0, 217, 239, 138, 124, 150, 193,  73, 105]],

        [[  0,   5, 143, 230,  71,  26,  98, 214, 180]],

        [[  0,   5,  15, 230, 228, 121, 143,  81,  29]],

        [[  0, 105,  78,  26,  98, 214, 109,  29, 213]],

        [[  0, 173, 236, 120, 157, 193, 140, 236,  26]],

        [[  0, 217, 236, 143, 193, 175,  29,  29,  29]],

        [[  0,   5, 228, 121, 143, 195, 175,  29,  29]],

        [[  0, 236,  26, 224,  26,  98, 180, 208, 121]],

        [[  0,   5,  32, 200, 164,  41,  26, 206, 171]],

        [[  0, 236,  29, 121, 143, 195, 175, 180,  98]],

        [[  0,   5, 228, 143, 230,  71, 236,  29,  29]],

        [[  0,   5, 228, 143, 193, 143, 193,  45,  58]],

        [[  0,   5, 228, 143, 193,  45, 112, 146, 121]],

        [[  0,  49,  71, 236, 143, 195, 175, 180, 208]],

        [[  0,   5, 193, 236, 171, 102, 108, 236,  26]],

        [[  0,   5,  32, 180, 208, 121, 143, 230, 228]],

        [[  0,