In [1]:
import torch
import os
from train import id_to_string
from metrics import word_error_rate, sentence_acc
from checkpoint import load_checkpoint
from torchvision import transforms
from dataset import LoadEvalDataset, collate_eval_batch, START, PAD
from flags import Flags
from utils import get_network, get_optimizer
import csv
from torch.utils.data import DataLoader
import argparse
import random
from tqdm import tqdm

In [2]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
from collections import OrderedDict

from dataset import START, PAD

- "--checkpoint" 에 불러올 .pth 파일 주소

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument(
        "--checkpoint1",
        dest="checkpoint1",
        default="./log/ensemble/checkpoints/0034.pth",
        type=str,
        help="Path of checkpoint file",
)

parser.add_argument(
        "--checkpoint2",
        dest="checkpoint2",
        default="./log/ensemble/checkpoints/0035.pth",
        type=str,
        help="Path of checkpoint file",
)

parser.add_argument(
        "--checkpoint3",
        dest="checkpoint3",
        default="./log/ensemble/checkpoints/0036.pth",
        type=str,
        help="Path of checkpoint file",
)

parser.add_argument(
        "--max_sequence",
        dest="max_sequence",
        default=20,
        type=int,
        help="maximun sequence when doing inference",
)

parser.add_argument(
        "--batch_size",
        dest="batch_size",
        default=1,
        type=int,
        help="batch size when doing inference",
)

eval_dir = os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/')
file_path = os.path.join(eval_dir, 'eval_dataset/input.txt')
parser.add_argument(
        "--file_path",
        dest="file_path",
        default=file_path,
        type=str,
        help="file path when doing inference",
)

output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', 'submit')
parser.add_argument(
        "--output_dir",
        dest="output_dir",
        default=output_dir,
        type=str,
        help="output directory",
)

parser = parser.parse_args([])

In [4]:
is_cuda = torch.cuda.is_available()

checkpoint1 = load_checkpoint(parser.checkpoint1, cuda=is_cuda)
checkpoint2 = load_checkpoint(parser.checkpoint2, cuda=is_cuda)
checkpoint3 = load_checkpoint(parser.checkpoint3, cuda=is_cuda)
options1 = Flags(checkpoint1["configs"]).get()
options2 = Flags(checkpoint2["configs"]).get()
options3 = Flags(checkpoint3["configs"]).get()

In [5]:
torch.manual_seed(options1.seed)
random.seed(options1.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [6]:
hardware = "cuda" if is_cuda else "cpu"
device = torch.device(hardware)
print("--------------------------------")
print("Running {} on device {}\n".format(options1.network, device))

--------------------------------
Running SATRN on device cuda



In [7]:
model_checkpoint1 = checkpoint1["model"]
if model_checkpoint1:
        print(
            "[+] Checkpoint\n",
            "Resuming from epoch : {}\n".format(checkpoint1["epoch"]),
        )
print(options1.input_size.height)

[+] Checkpoint
 Resuming from epoch : 34

48


In [8]:
model_checkpoint2 = checkpoint2["model"]
if model_checkpoint2:
        print(
            "[+] Checkpoint\n",
            "Resuming from epoch : {}\n".format(checkpoint2["epoch"]),
        )
print(options2.input_size.height)

[+] Checkpoint
 Resuming from epoch : 35

48


In [9]:
model_checkpoint3 = checkpoint3["model"]
if model_checkpoint3:
        print(
            "[+] Checkpoint\n",
            "Resuming from epoch : {}\n".format(checkpoint3["epoch"]),
        )
print(options3.input_size.height)

[+] Checkpoint
 Resuming from epoch : 35

48


### Get Data

In [10]:
transformed = transforms.Compose(
        [
            transforms.Resize((options1.input_size.height, options1.input_size.width)),
            transforms.ToTensor(),
        ]
    )

dummy_gt = "\sin " * parser.max_sequence  # set maximum inference sequence

root = os.path.join(os.path.dirname(parser.file_path), "images")
with open(parser.file_path, "r") as fd:
    reader = csv.reader(fd, delimiter="\t")
    data = list(reader)
test_data = [[os.path.join(root, x[0]), x[0], dummy_gt] for x in data]
test_dataset = LoadEvalDataset(
    test_data, checkpoint1["token_to_id"], checkpoint1["id_to_token"], crop=False, transform=transformed,
    rgb=options1.data.rgb
)

test_data_loader = DataLoader(
        test_dataset,
        batch_size=parser.batch_size,
        shuffle=False,
        num_workers=options1.num_workers,
        collate_fn=collate_eval_batch,
)

print(
        "[+] Data\n",
        "The number of test samples : {}\n".format(len(test_dataset)),
    )

[+] Data
 The number of test samples : 32



### Get Network

In [11]:
model1 = get_network(
        options1.network,
        options1,
        model_checkpoint1,
        device,
        test_dataset,
    )
model1.eval()
print()




In [12]:
encoder1 = model1.encoder
embedding1 = model1.decoder.embedding
pos_encoder1 = model1.decoder.pos_encoder
attention_layers1 = model1.decoder.attention_layers
generator1 = model1.decoder.generator

In [13]:
model2 = get_network(
        options2.network,
        options2,
        model_checkpoint2,
        device,
        test_dataset,
    )
model2.eval()
print()




In [14]:
encoder2 = model2.encoder
embedding2 = model2.decoder.embedding
pos_encoder2 = model2.decoder.pos_encoder
attention_layers2 = model2.decoder.attention_layers
generator2 = model2.decoder.generator

In [15]:
model3 = get_network(
        options3.network,
        options3,
        model_checkpoint3,
        device,
        test_dataset,
    )
model3.eval()
print()




In [16]:
encoder3 = model3.encoder
embedding3 = model3.decoder.embedding
pos_encoder3 = model3.decoder.pos_encoder
attention_layers3 = model3.decoder.attention_layers
generator3 = model3.decoder.generator

In [17]:
st_id = test_dataset.token_to_id['<SOS>']
pad_id = test_dataset.token_to_id['<PAD>']
layer_num = len(model1.decoder.attention_layers)

In [18]:
def pad_mask(text):
        pad_mask = text == pad_id
        pad_mask[:, 0] = False
        pad_mask = pad_mask.unsqueeze(1)

        return pad_mask

def order_mask(length):
        order_mask = torch.triu(torch.ones(length, length), diagonal=1).bool()
        order_mask = order_mask.unsqueeze(0).to(device)
        return order_mask

def text_embedding1(texts):
        tgt = embedding1(texts)
        tgt *= math.sqrt(tgt.size(2))

        return tgt
    
def text_embedding2(texts):
        tgt = embedding2(texts)
        tgt *= math.sqrt(tgt.size(2))

        return tgt

def text_embedding3(texts):
        tgt = embedding3(texts)
        tgt *= math.sqrt(tgt.size(2))

        return tgt    

### Ensemble Decode

In [19]:
results = []

for d in tqdm(test_data_loader):
    input = d["image"].to(device)
    expected = d["truth"]["encoded"].to(device)
    
    src1 = encoder1(input)
    src2 = encoder2(input)
    src3 = encoder3(input)
    
    text = expected[:, :-1]
    is_train = False
    batch_max_length = 230
    teacher_forcing_ratio = 0
    
    out = []
    num_steps = batch_max_length - 1
    target = torch.LongTensor(src1.size(0)).fill_(st_id).to(device) # [START] token
    features = [None] * layer_num
    
    features1 = []
    features2 = []
    features3 = []
    if features[0] != None:
        for i in range(len(features)):
            features1.append(copy.deepcopy(features[i].detach()))
    
        for i in range(len(features)):
            features2.append(copy.deepcopy(features[i].detach()))
            
        for i in range(len(features)):
            features3.append(copy.deepcopy(features[i].detach()))
            
    else:
        features1 = [None] * layer_num
        features2 = [None] * layer_num
        features3 = [None] * layer_num
        
    
    for t in range(num_steps):
        target = target.unsqueeze(1)
        tgt1 = text_embedding1(target)
        tgt2 = text_embedding2(target)
        tgt3 = text_embedding3(target)
        
        tgt1 = pos_encoder1(tgt1, point=t)
        tgt2 = pos_encoder2(tgt2, point=t)
        tgt3 = pos_encoder3(tgt3, point=t)
        
        tgt_mask = order_mask(t + 1).to(device)
        tgt_mask = tgt_mask[:, -1].unsqueeze(1)  # [1, (l+1)]
        
        for l, layer in enumerate(zip(attention_layers1, attention_layers2, attention_layers3)):
            tgt1 = layer[0](tgt1, features[l], src1, tgt_mask)
            features1[l] = (
                tgt1 if features1[l] == None else torch.cat([features1[l], tgt1], 1)
            )
            
            tgt2 = layer[1](tgt2, features[l], src2, tgt_mask)
            features2[l] = (
                tgt2 if features2[l] == None else torch.cat([features2[l], tgt2], 1)
            )
            
            tgt3 = layer[1](tgt3, features[l], src3, tgt_mask)
            features3[l] = (
                tgt3 if features3[l] == None else torch.cat([features3[l], tgt3], 1)
            )
        
        for i in range(len(features)):
            features[i] = (features1[i] + features2[i] + features3[i]) / 3
            
        _out1 = generator1(tgt1)  # [b, 1, c]
        _out2 = generator2(tgt2)  # [b, 1, c]
        _out3 = generator2(tgt3)  # [b, 1, c]
        
        _out = _out1 + _out2 + _out3
        
        target = torch.argmax(_out[:, -1:, :], dim=-1)  # [b, 1]
        target = target.squeeze(1)   # [b]
        out.append(_out)
        
    out = torch.stack(out, dim=1).to(device)    # [b, max length, 1, class length]
    out = out.squeeze(2)    # [b, max length, class length]
    
    decoded_values = out.transpose(1, 2)
    _, sequence = torch.topk(decoded_values, 1, dim=1)
    sequence = sequence.squeeze(1)
    sequence_str = id_to_string(sequence, test_data_loader, do_eval=1)
    
    for path, predicted in zip(d["file_path"], sequence_str):
        results.append((path, predicted))  

100%|██████████| 32/32 [02:52<00:00,  5.45s/it]


In [20]:
os.makedirs(parser.output_dir, exist_ok=True)
with open(os.path.join(parser.output_dir, "output_ensemble.csv"), "w") as w:
    for path, predicted in results:
        w.write(path + "\t" + predicted + "\n")

In [None]:
results

In [21]:
data

[['train_00000.jpg', '4 \\times 7 = 2 8'],
 ['train_00001.jpg', 'a ^ { x } > q'],
 ['train_00002.jpg', '8 \\times 9'],
 ['train_00003.jpg',
  '\\sum _ { k = 1 } ^ { n - 1 } b _ { k } = a _ { n } - a _ { 1 }'],
 ['train_00004.jpg', 'I = d q / d t'],
 ['train_00005.jpg', '\\sum \\overrightarrow { F } _ { e x t } = d'],
 ['train_00006.jpg', 'i ^ { 2 } = - 1 \\left( i = \\sqrt { - 1 } \\right)'],
 ['train_00007.jpg', '7 \\times 9 = 4 9'],
 ['train_00008.jpg',
  'F \\left( 0 , \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right) , \\left( 0 , - \\sqrt { a ^ { 2 } + b ^ { 2 } } \\right)'],
 ['train_00009.jpg', '\\left( a - 2 \\right) \\left( a - 3 \\right) = 0'],
 ['train_00010.jpg', '\\therefore b = - 9'],
 ['train_00011.jpg', '2 2 + 7 - 1 2 ='],
 ['train_00012.jpg', '7 \\div 4'],
 ['train_00013.jpg', 'f \\left( x \\right) = 4 x ^ { 3 }'],
 ['train_00014.jpg',
  'M P _ { e } = \\lim _ { \\Delta l \\to 0 } \\frac { g \\left( l + \\Delta l \\right) - g \\left( l \\right) } { \\Delta l } = \\frac { d g }