In [1]:
import torch
import os
from train import id_to_string
from metrics import word_error_rate, sentence_acc
from checkpoint import load_checkpoint
from torchvision import transforms
from dataset import LoadEvalDataset, collate_eval_batch, START, PAD
from flags import Flags
from utils import get_network, get_optimizer
import csv
from torch.utils.data import DataLoader
import argparse
import random
from tqdm import tqdm

In [2]:
import copy 
import math
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- "--checkpoint" 에 불러올 .pth 파일 주소

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument(
        "--checkpoint",
        dest="checkpoint",
        default="./log/satrn_adaptive/checkpoints/0050.pth",
        type=str,
        help="Path of checkpoint file",
)
parser.add_argument(
        "--max_sequence",
        dest="max_sequence",
        default=20,
        type=int,
        help="maximun sequence when doing inference",
)
parser.add_argument(
        "--batch_size",
        dest="batch_size",
        default=8,
        type=int,
        help="batch size when doing inference",
)

eval_dir = os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/')
file_path = os.path.join(eval_dir, 'eval_dataset/input.txt')
parser.add_argument(
        "--file_path",
        dest="file_path",
        default=file_path,
        type=str,
        help="file path when doing inference",
)

output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', 'submit')
parser.add_argument(
        "--output_dir",
        dest="output_dir",
        default=output_dir,
        type=str,
        help="output directory",
)

parser = parser.parse_args([])

In [4]:
is_cuda = torch.cuda.is_available()
checkpoint = load_checkpoint(parser.checkpoint, cuda=is_cuda)
options = Flags(checkpoint["configs"]).get()
torch.manual_seed(options.seed)
random.seed(options.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
hardware = "cuda" if is_cuda else "cpu"
device = torch.device(hardware)
print("--------------------------------")
print("Running {} on device {}\n".format(options.network, device))

--------------------------------
Running SATRN on device cuda



In [6]:
model_checkpoint = checkpoint["model"]
if model_checkpoint:
    print(
            "[+] Checkpoint\n",
            "Resuming from epoch : {}\n".format(checkpoint["epoch"]),
    )

[+] Checkpoint
 Resuming from epoch : 50



## Get Data

In [7]:
transformed = transforms.Compose(
        [
            transforms.Resize((options.input_size.height, options.input_size.width)),
            transforms.ToTensor(),
        ]
    )

In [8]:
dummy_gt = "\sin " * parser.max_sequence  # set maximum inference sequence

In [9]:
root = os.path.join(os.path.dirname(parser.file_path), "images")
with open(parser.file_path, "r") as fd:
    reader = csv.reader(fd, delimiter="\t")
    data = list(reader)

In [10]:
test_data = [[os.path.join(root, x[0]), x[0], dummy_gt] for x in data]
test_dataset = LoadEvalDataset(
        test_data, checkpoint["token_to_id"], checkpoint["id_to_token"], crop=False, transform=transformed,
        rgb=options.data.rgb
    )
test_data_loader = DataLoader(
        test_dataset,
        batch_size=parser.batch_size,
        shuffle=False,
        num_workers=options.num_workers,
        collate_fn=collate_eval_batch,
    )

In [11]:
print(
        "[+] Data\n",
        "The number of test samples : {}\n".format(len(test_dataset)),
    )

[+] Data
 The number of test samples : 32



## Get Network

In [23]:
model = get_network(
        options.network,
        options,
        model_checkpoint,
        device,
        test_dataset,
    )
model.eval()
print()




In [13]:
encoder = model.encoder
embedding = model.decoder.embedding
pos_encoder = model.decoder.pos_encoder
attention_layers = model.decoder.attention_layers
generator = model.decoder.generator

In [24]:
st_id = test_dataset.token_to_id['<SOS>']
pad_id = test_dataset.token_to_id['<PAD>']
layer_num = len(model.decoder.attention_layers)

In [19]:
def pad_mask(text):
        pad_mask = text == pad_id
        pad_mask[:, 0] = False
        pad_mask = pad_mask.unsqueeze(1)

        return pad_mask

def order_mask(length):
        order_mask = torch.triu(torch.ones(length, length), diagonal=1).bool()
        order_mask = order_mask.unsqueeze(0).to(device)
        return order_mask

def text_embedding(texts):
        tgt = embedding(texts)
        tgt *= math.sqrt(tgt.size(2))

        return tgt

In [20]:
class BeamSearchNode(object):
    def __init__(self, src_batch, previousNode, target_batch, features, logProb, length):
        self.src_batch = src_batch
        self.prevNode = previousNode
        self.target_batch = target_batch
        self.features = features
        self.logp = logProb
        self.leng = length
        
    def eval(self, alpha = 1.0):
        reward = 0
        
        return self.logp / float(self.leng - 1 +1e6) + alpha * reward

## Beam Search decode

In [21]:
results = []

In [22]:
for d in tqdm(test_data_loader):
    input = d["image"].to(device)
    expected = d["truth"]["encoded"].to(device)
    
    src = encoder(input)
    text = expected[:, :-1]
    is_train = False
    batch_max_length = 230
    teacher_forcing_ratio = 0
    
    beam_width = 5
    topk = 1  # how many sentence do you want to generate
    decoded_batch = []
    
    for idx in range(src.size(0)):
        num_steps = batch_max_length - 1
        target = torch.LongTensor(src.size(0)).fill_(st_id).to(device) # [START] token
        features = [None] * layer_num

        target_batch = target[idx].unsqueeze(-1)
        target_batch = target_batch.unsqueeze(-1)
        src_batch = src[idx,:,:].unsqueeze(0)

        endnodes = []
        number_required = 5

        node = BeamSearchNode(src_batch, None, target_batch, features, 0, 1)
        nodes = PriorityQueue()

        nodes.put((-node.eval(), node))
        flag = 0

        # each step 
        for t in range(num_steps):        
            if flag == 1:
                break
            nextnodes = []

            # ecah beam search candidate
            while nodes.queue != []:

                score, n = nodes.get()
                target_batch = n.target_batch
                src_batch = n.src_batch
                if n.features[0] != None:
                    features = []
                    for i in range(len(n.features)):
                        with torch.no_grad():
                            features.append(copy.deepcopy(n.features[i].detach()))
                else:
                    features = n.features

                tgt = text_embedding(target_batch)
                tgt = pos_encoder(tgt, point=t)
                tgt_mask = order_mask(t + 1)
                tgt_mask = tgt_mask[:, -1].unsqueeze(1) 
                        
                if n.target_batch.item() == 1 and n.prevNode != None:
                    endnodes.append((score, n))
                    # if we reached maximum # of sentences required
                    if len(endnodes) >= number_required:
                        flag = 1
                        break
                    else:
                        continue
                        
                # each attention layer 
                for l, layer in enumerate(attention_layers):
                    tgt = layer(tgt, features[l], src_batch, tgt_mask)
                    features[l] = (tgt if features[l] == None else torch.cat([features[l], tgt], 1))

                _out = generator(tgt)
                log_prob, indexes = torch.topk(_out, beam_width)

                for new_k in range(beam_width):
                    decoded_t = indexes[0][0][new_k].view(1, -1)
                    log_p = log_prob[0][0][new_k].item()

                    node = BeamSearchNode(src_batch, n, decoded_t, features, n.logp + log_p, n.leng + 1)
                    score = -node.eval()
                    nextnodes.append((score, node))
            
            if flag == 0:
                #sorting nextnode
                sorted(nextnodes, key=operator.itemgetter(0))
                # put them into queue
                for i in range(beam_width):
                    score, nn = nextnodes[i]
                    nodes.put((score, nn))

                        
            
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(1)]
            
            
        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.target_batch)
            # back trace
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.target_batch)

            del utterance[-1]
            utterance = utterance[::-1]
            utterances.append(utterance)
            break

        seq_len = len(utterances[0])
        if num_steps - seq_len > 0:
            for i in range(num_steps - seq_len):
                utterances[0].append(torch.LongTensor(1).fill_(2).to(device))
        decoded_batch.append(utterances)
            

            
    sequence = torch.tensor(decoded_batch)
    sequence = sequence.squeeze(1)
    sequence_str = id_to_string(sequence, test_data_loader, do_eval=1)
    for path, predicted in zip(d["file_path"], sequence_str):
            results.append((path, predicted))


100%|██████████| 4/4 [00:35<00:00,  8.97s/it]


In [156]:
os.makedirs(parser.output_dir, exist_ok=True)
with open(os.path.join(parser.output_dir, "output1.csv"), "w") as w:
    for path, predicted in results:
        w.write(path + "\t" + predicted + "\n")

- PR

In [157]:
results

[('train_00000.jpg', '4 \\times 7 = 2 8 '),
 ('train_00001.jpg', 'a ^ { x } > q '),
 ('train_00002.jpg', '8 \\times 9 '),
 ('train_00003.jpg',
  '\\sum _ { k = 1 } ^ { n - 1 } b _ { k } = a _ { n } - a _ { 1 } '),
 ('train_00004.jpg', 'I = d q / d t t '),
 ('train_00005.jpg', '\\sum \\overrightarrow { F } _ { e x t } = d d '),
 ('train_00006.jpg', 'i ^ { 2 } = - 1 \\left( i = \\sqrt { - 1 } \\right) '),
 ('train_00007.jpg', '7 \\times 9 = 4 9 9 '),
 ('train_00008.jpg',
  'F \\left( 0 , \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right) \\left( 0 , - \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right) '),
 ('train_00009.jpg', '\\left( a - 2 \\right) \\left( a - 3 \\right) = 0 '),
 ('train_00010.jpg', '\\therefore b = - 9 9 '),
 ('train_00011.jpg', '2 2 + 7 - 1 2 = '),
 ('train_00012.jpg', '7 \\div 4 '),
 ('train_00013.jpg', 'f \\left( x \\right) = 4 x ^ { 3 '),
 ('train_00014.jpg',
  'M P _ { l } = \\lim _ { \\Delta l \\to 0 } \\frac { g \\left( l + \\Delta l - g l \\right) \\right) \\right) } } } = = \\f

- G.T

In [144]:
data

[['train_00000.jpg', '4 \\times 7 = 2 8'],
 ['train_00001.jpg', 'a ^ { x } > q'],
 ['train_00002.jpg', '8 \\times 9'],
 ['train_00003.jpg',
  '\\sum _ { k = 1 } ^ { n - 1 } b _ { k } = a _ { n } - a _ { 1 }'],
 ['train_00004.jpg', 'I = d q / d t'],
 ['train_00005.jpg', '\\sum \\overrightarrow { F } _ { e x t } = d'],
 ['train_00006.jpg', 'i ^ { 2 } = - 1 \\left( i = \\sqrt { - 1 } \\right)'],
 ['train_00007.jpg', '7 \\times 9 = 4 9'],
 ['train_00008.jpg',
  'F \\left( 0 , \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right) , \\left( 0 , - \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right)'],
 ['train_00009.jpg', '\\left( a - 2 \\right) \\left( a - 3 \\right) = 0'],
 ['train_00010.jpg', '\\therefore b = - 9'],
 ['train_00011.jpg', '2 2 + 7 - 1 2 ='],
 ['train_00012.jpg', '7 \\div 4'],
 ['train_00013.jpg', 'f \\left( x \\right) = 4 x ^ { 3 }'],
 ['train_00014.jpg',
  'M P _ { e } = \\lim _ { \\Delta l \\to 0 } \\frac { g \\left( l + \\Delta l \\right) - g \\left( l \\right) } { \\Delta l } = \\frac { d g }