In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.utils.data import *
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import easydict
import sys
sys.path.append('./Whatiswrong')
sys.path.append('./Scatter')

import re
import six
import math
import torchvision.transforms as transforms

import utils
from utils import *
import augs
import www_model_jamo
import torch.distributed as dist
# from apex.parallel import DistributedDataParallel as DDP
import en_dataset
import ko_dataset
from albumentations import GaussNoise, IAAAdditiveGaussianNoise, Compose, OneOf
from albumentations.pytorch import ToTensor

In [26]:
# Re-import ./Whatiswrong/utils.py after editing it, without restarting the kernel.
# NOTE: names bound earlier via `from utils import *` (e.g. Averager) keep pointing
# at the old definitions; only attributes accessed as `utils.<name>` see the reload.
from importlib import reload
reload(utils)

<module 'utils' from './Whatiswrong/utils.py'>

### arguments

In [3]:
# Hyper-parameters and model options for the jamo (top / middle / bottom) recognizer.
# Every *_char set starts with ' ' (presumably the "no jamo at this position" symbol).
# NOTE(review): 'top_char' contains '0' twice ('...01234567890...'). If AttnLabelConverter
# builds its class list straight from the string, one class index is never a target.
# Removing it would change top_n_cls and break existing checkpoints, so it is kept as is.
opt = easydict.EasyDict({
    "experiment_name": f'{utils.SaveDir_maker(base_model = "www_jamo", base_model_dir = "./models")}',
    'saved_model': '',
    "manualSeed": 1111,
    "imgH": 35,
    "imgW": 90,
    "PAD": True,
    'batch_size': 128,
    'data_filtering_off': True,
    'workers': 20,
    'rgb': True,
    'sensitive': True,
    'top_char': ' !"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉ',
    'middle_char': ' ㅏㅑㅓㅕㅗㅛㅜㅠㅡㅣㅐㅒㅔㅖㅘㅙㅚㅝㅞㅟㅢ',
    'bottom_char': ' ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㄳㄵㄶㄺㄻㄼㄽㄾㄿㅀㅄ',
    'batch_max_length': 25,
    'num_fiducial': 20,
    'output_channel': 512,
    'hidden_size': 256,
    'lr': 1,
    'rho': 0.95,
    'eps': 1e-8,
    'grad_clip': 5,
    'valInterval': 200,
    'num_epoch': 100,
    'input_channel': 3,
    'FT': True,
    'extract': 'RCNN'
})

# One attention label converter per jamo position; the class counts size the three heads.
top_converter, middle_converter, bottom_converter = (
    utils.AttnLabelConverter(charset)
    for charset in (opt.top_char, opt.middle_char, opt.bottom_char)
)
opt.top_n_cls, opt.middle_n_cls, opt.bottom_n_cls = (
    len(conv.character)
    for conv in (top_converter, middle_converter, bottom_converter)
)

# utils.py keeps its own `device`, which has to be set separately there.
device = torch.device('cuda')

### dataset

In [4]:
# --- Seed setting ---
# print("Random Seed: ", opt.manualSeed)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(opt.manualSeed)

# --- Korean sources (only the synthetic set is enabled) ---
data = []
# ko_hand = ko_dataset.hand_dataset(num_samples = 100000, dataset_mode = 'word', label_mode = 'jamo')  # disabled: jamo order is wrong
# ko_public = ko_dataset.public_crop(mode = 'jamo')  # disabled: jamo separation is broken
ko_synthetic = ko_dataset.korean_synthetic(need_samples=200000, mode='jamo')

# --- English sources (disabled) ---
# eng_dataset = en_dataset.get_english_dataset(mode='jamo')
# eng_synthetic = en_dataset.en_synthetic(mode='jamo', need_samples=2000000)  # 0 for all

100%|██████████| 200000/200000 [00:09<00:00, 20277.62it/s]


In [5]:
# Build the pooled sample list from scratch, so re-running this cell does not append
# the same samples a second time. Enable more sources by uncommenting them below.
data = list(ko_synthetic.dataset)
# data += ko_hand.dataset
# data += ko_public.dataset
# data += eng_dataset
# data += eng_synthetic.dataset
random.shuffle(data)
print('Total number of data : ', len(data))

Total number of data :  200000


In [16]:
# Training-only augmentation: with p=0.4 apply one of GridMask / AugMix, then convert to tensor.
transformers = Compose([
    OneOf([
        # augs.VinylShining(1),
        augs.GridMask(num_grid=(10, 20)),
        augs.RandomAugMix(severity=1, width=1),
    ], p=0.4),
    ToTensor(),
])

# 95 / 5 train / validation split of the sample list (shuffled in the previous cell).
n_train = int(len(data) * 0.95)
train_custom = utils.CustomDataset_jamo(data[:n_train], resize_shape=(opt.imgH, opt.imgW), transformer=transformers)
valid_custom = utils.CustomDataset_jamo(data[n_train:], resize_shape=(opt.imgH, opt.imgW), transformer=ToTensor())

data_loader = DataLoader(train_custom, batch_size=opt.batch_size, num_workers=15,
                         shuffle=True, drop_last=True, pin_memory=True)
# Validation must see every sample exactly once, in a fixed order: no shuffling and
# no dropped tail batch (drop_last=True silently skipped up to batch_size-1 samples).
# NOTE(review): confirm utils.validation_jamo handles a final batch smaller than batch_size.
valid_loader = DataLoader(valid_custom, batch_size=opt.batch_size, num_workers=10,
                          shuffle=False, drop_last=False, pin_memory=True)

## train

In [27]:
# def Criterion(input, target, size_average=True):
#     """Categorical cross-entropy with logits input and one-hot target"""
#     l = -(target * torch.log(F.softmax(input, dim=1) + 1e-10)).sum(1)
#     if size_average:
#         l = l.mean()
#     else:
#         l = l.sum()
#     return l

def _init_weights(model):
    """Initialize ``model`` in place: Kaiming-normal weights, zero biases.

    Parameters whose name contains 'localization_fc2' are skipped (already
    initialized by the model). Weights that ``kaiming_normal_`` rejects are
    filled with 1.
    """
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initializaed')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except ValueError:
            # kaiming_normal_ cannot compute fan-in/fan-out for tensors with < 2 dims.
            if 'weight' in name:
                param.data.fill_(1)


def _sequence_loss(criterion, preds, text):
    """Cross-entropy of ``preds`` against ``text`` shifted left by one position.

    Assumes ``preds`` is (batch, steps, n_cls) and ``text`` is (batch, steps + 1)
    -- TODO confirm against www_model_jamo.STR / AttnLabelConverter.encode.
    """
    return criterion(preds.view(-1, preds.shape[-1]), text[:, 1:].contiguous().view(-1))


def _write_validation_report(model, criterion, loss_avg, opt, n_epoch, n_iter,
                             elapsed_time, best_accuracy, best_norm_ED):
    """Validate ``model``, print and log a report, and checkpoint improvements.

    Saves ``best_accuracy.pth`` / ``best_norm_ED.pth`` in the experiment directory
    when the metric beats the best so far, resets ``loss_avg`` and returns the
    updated ``(best_accuracy, best_norm_ED)``. Reads the notebook-global
    ``valid_loader``, ``data`` and the three label converters.
    """
    save_dir = f'./models/{opt.experiment_name}'

    model.eval()
    with torch.no_grad():
        (valid_loss, current_accuracy, current_norm_ED,
         pred_top_str, pred_mid_str, pred_bot_str,
         label_top, label_mid, label_bot,
         infer_time, length_of_data) = utils.validation_jamo(
            model, criterion, valid_loader, top_converter, middle_converter, bottom_converter, opt)
    model.train()

    # Wall-clock time in UTC+9 (KST). It is derived from UTC, so the hour wraps at 24
    # and the date rolls over; adding 9 to the local hour did neither.
    kst = time.gmtime(time.time() + 9 * 3600)
    loss_log = (f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.95)}]\n'
                f'Train loss : {loss_avg.val():0.5f}, Valid loss : {valid_loss:0.5f}, Elapsed time : {elapsed_time:0.5f}, '
                f'Present time : {kst.tm_mon}/{kst.tm_mday}, {kst.tm_hour} : {kst.tm_min}')
    loss_avg.reset()

    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

    # Keep the best checkpoints; model.module unwraps DataParallel.
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        torch.save(model.module.state_dict(), f'{save_dir}/best_accuracy.pth')
    if current_norm_ED > best_norm_ED:
        best_norm_ED = current_norm_ED
        torch.save(model.module.state_dict(), f'{save_dir}/best_norm_ED.pth')

    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
    print(loss_model_log)

    # Show up to 5 random validation samples, each as "top/middle/bottom".
    # Assumes validation_jamo returns per-sample sequences indexable by position -- TODO confirm.
    dashed_line = '-' * 80
    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| T/F'
    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
    n_show = min(5, len(label_top))
    for idx in np.random.choice(len(label_top), size=n_show, replace=False):
        gt = f'{label_top[idx]}/{label_mid[idx]}/{label_bot[idx]}'
        pred = f'{pred_top_str[idx]}/{pred_mid_str[idx]}/{pred_bot_str[idx]}'
        predicted_result_log += f'{gt:25s} | {pred:25s} | \t{str(pred == gt)}\n'
    predicted_result_log += f'{dashed_line}'
    print(predicted_result_log)

    with open(f'{save_dir}/log_train.txt', 'a') as log:
        log.write(loss_model_log + '\n')
        log.write(predicted_result_log + '\n')

    return best_accuracy, best_norm_ED


def train(opt):
    """Train the three-head (top / middle / bottom jamo) STR model from ``www_model_jamo``.

    Uses the notebook-global ``data_loader``, ``valid_loader``, ``data``, ``device``
    and the top/middle/bottom label converters. Options are appended to
    ``./models/<experiment_name>/opt.txt``. Every ``opt.valInterval`` iterations a
    validation report is printed and logged (see ``_write_validation_report``), and
    ``<epoch>.pth`` is saved every 5 epochs.

    Args:
        opt: EasyDict of hyper-parameters built in the arguments cell.
    """
    model = www_model_jamo.STR(opt, device)
    print('model parameters. height {}, width {}, num of fiducial {}, input channel {}, output channel {}, hidden size {}, \
    batch max length {}'.format(opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.batch_max_length))

    _init_weights(model)

    if opt.saved_model != '':
        base_path = './models'
        print(f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}')
        try:
            model.load_state_dict(torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete ')
        except Exception as e:
            # Best effort: report the problem and continue from the fresh initialization.
            print(e)
            print('coud not find model')

    # data parallel for multi GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # Index 0 is excluded from the loss (the [GO] token).
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # Created once, not per iteration, so val() averages over the whole validation interval.
    loss_avg = Averager()

    # Only parameters that require gradients are optimized.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print('Tranable params : ', sum(np.prod(p.size()) for p in filtered_parameters))

    optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)

    # Append the full option set to opt.txt for reproducibility.
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        opt_log += ''.join(f'{str(k)} : {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1

    for n_epoch in range(opt.num_epoch):
        for n_iter, data_point in enumerate(data_loader):
            # Assumes the dataset yields (images, top, middle, bottom) batches -- TODO confirm.
            image_tensors, top, mid, bot = data_point
            image = image_tensors.to(device)
            text_top, _ = top_converter.encode(top, batch_max_length=opt.batch_max_length)
            text_mid, _ = middle_converter.encode(mid, batch_max_length=opt.batch_max_length)
            text_bot, _ = bottom_converter.encode(bot, batch_max_length=opt.batch_max_length)

            # Decoder input is each sequence without its last position; the loss
            # targets the same sequence shifted left by one (see _sequence_loss).
            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1], text_mid[:, :-1], text_bot[:, :-1])
            cost = (_sequence_loss(criterion, pred_top, text_top)
                    + _sequence_loss(criterion, pred_mid, text_mid)
                    + _sequence_loss(criterion, pred_bot, text_bot))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            loss_avg.add(cost)

            if n_iter % opt.valInterval == 0 and n_iter != 0:
                best_accuracy, best_norm_ED = _write_validation_report(
                    model, criterion, loss_avg, opt, n_epoch, n_iter,
                    time.time() - start_time, best_accuracy, best_norm_ED)

        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/{n_epoch}.pth')

## main

In [30]:
os.makedirs(f'./models/{opt.experiment_name}', exist_ok=True)

# Seed every RNG the pipeline uses: Python, NumPy, PyTorch CPU, and all visible GPUs
# (DataParallel may replicate the model across several devices).
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)

# cuDNN: benchmark mode autotunes kernels and may pick different algorithms from run
# to run, which defeats `deterministic`. Keep it off so seeded runs are reproducible
# (set it to True for speed if exact reproducibility is not required).
cudnn.benchmark = False
cudnn.deterministic = True
opt.num_gpu = torch.cuda.device_count()

train(opt)

model parameters. height 35, width 90, num of fiducial 20, input channel 3, output channel 512, hidden size 256,     batch max length 25
Skip Trans.LocalizationNetwork.localization_fc2.weight as it is already initializaed
Skip Trans.LocalizationNetwork.localization_fc2.bias as it is already initializaed
Tranable params :  8643606


KeyError: 'ㅐ'

In [None]:
# NOTE: this cell was executed first (out of order)
from jamo import h2j, j2hcj, j2h

In [None]:
# Sanity check: split one syllable into jamo and convert them to compatibility jamo.
syllable_jamo = j2hcj(h2j('회'))
syllable_jamo

----------------

## validation by visualization

In [None]:
import matplotlib.font_manager as fm
# Hangul-capable font for the prediction captions. If the file is missing, fall back to
# Matplotlib's default font (Hangul may render as boxes) instead of failing at draw time.
# (The former fm.get_fontconfig_fonts() call discarded its result and no longer exists
# in recent Matplotlib releases.)
font_location = '/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf'
if os.path.exists(font_location):
    fontprop = fm.FontProperties(fname=font_location)
else:
    print(f'Font not found at {font_location}; using the default font')
    fontprop = fm.FontProperties()
# Visualization runs on CPU. NOTE: this rebinds the notebook-global `device` that the
# training cells use; re-run the arguments cell before training again.
device = torch.device('cpu')

class Valid_visualizer():
    """Run a trained STR model on the images in a folder and plot the predictions in a grid.

    Args:
        opt: EasyDict of model options (imgH, imgW, batch_max_length, ...).
        model_path: path to a ``state_dict`` saved with ``torch.save``.
        val_path: directory whose files are the images to visualize.
        visual_samples: number of images to predict and plot.
        device: torch device the model and input tensors are moved to.
        converter: label converter whose ``decode(index, length)`` turns predicted
            indices into strings. When ``None``, the notebook-global ``converter``
            is looked up at call time.

    NOTE(review): ``_load_model`` builds ``www_model.STR``, but this notebook only
    imports ``www_model_jamo``. ``valid_visualize`` also expects a single output tensor,
    whereas the jamo model used in ``train`` returns three heads. This class appears
    to target the older non-jamo model; confirm before use.
    """

    def __init__(self, opt, model_path, val_path, visual_samples, device, converter=None):
        self.opt = opt
        self.model_path = model_path
        self.val_path = val_path
        self.visual_samples = visual_samples
        self.device = device
        self.converter = converter
        self.dataset = self._get_dataset()

    def _load_model(self):
        """Build the model, load the weights onto ``self.device`` and switch to eval mode."""
        model = www_model.STR(self.opt, self.device)
        # map_location lets checkpoints saved from GPU training load on a CPU-only machine.
        model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        model.to(self.device)
        model.eval()
        return model

    def _get_dataset(self):
        """Return ``[image_path, label]`` pairs for every file in ``val_path``, sorted by name."""
        # The same label is used for every image (presumably a placeholder; only
        # predictions are displayed).
        label = 'ㄱ'
        return [[os.path.join(self.val_path, val), label] for val in sorted(os.listdir(self.val_path))]

    def _get_valid_loader(self):
        """Return an iterator over batches of ``visual_samples`` images from ``self.dataset``."""
        test_streamer = utils.Dataset_streamer(self.dataset, resize_shape=(self.opt.imgH, self.opt.imgW), transformer=ToTensor())
        test_loader = DataLoader(test_streamer, batch_size=self.visual_samples, num_workers=0)
        return iter(test_loader)

    @staticmethod
    def _strip_eos(pred):
        """Cut a decoded string at its first '[s]' token; keep it whole if there is none."""
        end = pred.find('[s]')
        return pred if end == -1 else pred[:end]

    def valid_visualize(self):
        """Predict a randomly chosen batch from ``val_path`` and plot each image with its prediction."""
        random.shuffle(self.dataset)
        image_tensor, _ = next(self._get_valid_loader())
        image_tensor = image_tensor.to(self.device)
        # The batch may be smaller than visual_samples when the folder holds fewer images.
        n_shown = image_tensor.size(0)

        model = self._load_model()
        with torch.no_grad():
            output = model(input=image_tensor, text=' ', is_train=False)
        pred_index = output.max(2)[1]
        pred_length = torch.IntTensor([self.opt.batch_max_length] * n_shown).to(self.device)
        decoder = self.converter if self.converter is not None else converter
        preds = [self._strip_eos(pred) for pred in decoder.decode(pred_index, pred_length)]

        n_cols = 5
        n_rows = int(np.ceil(n_shown / n_cols))
        # squeeze=False keeps `axes` 2-D even when there is only one row.
        fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
        fig.set_size_inches((30, 30))
        for i, ax in enumerate(axes.flat):
            if i >= n_shown:
                ax.axis('off')  # hide unused cells in the last row
                continue
            with Image.open(self.dataset[i][0]) as img:
                ax.imshow(img)
            ax.set_xlabel(f'Prediction : {preds[i]}', fontproperties=fontprop, fontsize=30)

In [None]:
# Checkpoint and image folder used for the visual sanity check.
VIS_MODEL_PATH = './models/www_0708/2/best_accuracy_91.602.pth'
VIS_IMAGE_DIR = './val'
vv = Valid_visualizer(opt, model_path=VIS_MODEL_PATH, val_path=VIS_IMAGE_DIR,
                      visual_samples=8, device=device)

In [None]:
# Plot the model's predictions for a randomly shuffled batch of images from the val folder.
vv.valid_visualize()