In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

from CharToIndex import CharToIndex
from MyDatasets import Cross_Validation


# Dataset whose target is the centre character of a 5-character window.
class MyDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over OCR output for the proofreader/detector.

    Item ``index`` covers positions ``index .. index+4``; the prediction
    target is the centre position ``index + 2``.

    Args:
        data: mapping with keys
            'value'    -- per position, a sequence of at most 10 OCR candidate
                          characters (the first is the OCR's top output);
            'answer'   -- per position, the ground-truth character;
            'distance' -- per position, one distance value per candidate
                          (NaN entries are replaced by 0).
        chars_file_path: path handed to ``CharToIndex`` to build the vocabulary.
        device: device every tensor returned by ``__getitem__`` is moved to.
    """
    def __init__(self,data,chars_file_path,device=torch.device('cpu')):
        self.data = data
        self.char2index = CharToIndex(chars_file_path)
        # Number of full 5-character windows.
        self.length = len(data['answer'])-4
        # Up to 10 candidate indices per position; unused slots stay 0.
        self.p_val_idx = torch.zeros((self.length+4,10),dtype=torch.long)
        self.p_ans_idx = torch.zeros(self.length+4,dtype=torch.long)
        self.d_ans     = torch.zeros(self.length+4,dtype=torch.long)
        self.device = device

        for i_r,chars in enumerate(data['value']):
            for i_c, idx in enumerate(map(self.char2index.get_index,chars)):
                self.p_val_idx[i_r][i_c] = idx

        for i,char in enumerate(data['answer']):
            self.p_ans_idx[i] = self.char2index.get_index(char)
            # Detector label: 1 if the OCR first candidate equals the answer, else 0.
            self.d_ans[i] = 1 if self.p_val_idx[i][0] == self.p_ans_idx[i] else 0

        # Turn each position's candidates into a "ten-hot" vector over the
        # vocabulary whose non-zero entries carry the candidate distances.
        distances = np.nan_to_num(data['distance'])
        # NOTE(review): allocated with two more rows than p_val_idx; the extra
        # rows stay zero and are never read by __getitem__.
        self.distanced_ten_hot_encoded_value = torch.full((self.length+6,len(self.char2index)),0,dtype=torch.float)
        for row,indicies in enumerate(self.p_val_idx):
            for id_distance,id_value in enumerate(indicies):
                # If a row repeats an index, the last candidate's distance wins.
                self.distanced_ten_hot_encoded_value[row][id_value]=distances[row][id_distance]


    def __len__(self):
        """Number of 5-character windows."""
        return self.length


    def __getitem__(self,index):
        """Return ``(distance, d_ans, p_inp, p_tar)`` for the window starting at ``index``.

        distance: ten-hot distance vector of the centre position.
        d_ans:    detector label of the centre position (0-dim tensor).
        p_inp:    OCR first candidate of each of the 5 window positions.
        p_tar:    ground-truth index of the centre position (0-dim tensor).
        """
        centre = index + 2
        # Use the device chosen at construction time instead of relying on a
        # notebook-global ``device`` variable.
        p_inp = self.p_val_idx[index:index+5,0].to(self.device)
        p_tar = self.p_ans_idx[centre].to(self.device)
        d_ans = self.d_ans[centre].to(self.device)
        distance = self.distanced_ten_hot_encoded_value[centre].to(self.device)
        return distance,d_ans,p_inp,p_tar



class Proofreader(nn.Module):
    """Bidirectional-RNN proofreader that classifies the centre of a character window.

    Each character index is embedded, a bidirectional ``nn.RNN`` runs over
    the window, and the output at time step 2 is mapped to vocabulary
    logits. Those logits are multiplied element-wise by ``distance`` before
    being returned.

    Args:
        input_size: unused; kept for backward compatibility of the signature.
        hidden_dim: hidden size of each RNN direction.
        output_size: vocabulary size, used both as the embedding-table size
            and as the number of output logits.
        n_layers: number of stacked RNN layers.
        embedding_dim: size of the character embedding (default 256).
    """
    def __init__(self, input_size, hidden_dim, output_size,n_layers, embedding_dim=256):
        super(Proofreader, self).__init__()

        self.output_size = output_size
        self.hidden_dim = hidden_dim
        self.n_layers  = n_layers
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.embedding = nn.Embedding(output_size,embedding_dim=embedding_dim)
        # num_layers must follow n_layers: init_hidden allocates n_layers*2 states.
        self.rnn = nn.RNN(embedding_dim, self.hidden_dim, num_layers=n_layers,
                          batch_first=True, bidirectional=True)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.fc = nn.Linear(self.hidden_dim*2, output_size)
        # Not used by forward (which returns raw, distance-weighted logits);
        # kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim=1)
        self.to(self.device)


    def init_hidden(self, batch_size):
        """Return a zero initial hidden state of shape (n_layers*2, batch_size, hidden_dim)."""
        hidden = torch.zeros(self.n_layers*2, batch_size, self.hidden_dim)
        return hidden


    def forward(self, x, distance):
        """Return distance-weighted logits for the window centre.

        Args:
            x: character indices, (batch, seq_len) with seq_len >= 3 because
               time step 2 is read.
            distance: tensor broadcastable to (batch, output_size).

        Returns:
            Tensor of shape (batch, output_size).
        """
        batch_size = x.size(0)
        x = self.embedding(x.long())
        hidden = self.init_hidden(batch_size).to(self.device)
        out, _ = self.rnn(x, hidden)
        out = out[:,2,:]  # centre of the 5-character window
        out = self.dropout(out)
        out = self.fc(out)
        # Out-of-place multiply so the Linear output is not mutated in place.
        return out * distance



class Detector(nn.Module):
    """Two-layer MLP (encode_size -> 128 -> 2) producing 2-class logits.

    Args:
        encode_size: length of the input feature vector.
    """

    def __init__(self, encode_size):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Layer attribute names fc1/fc2 are part of the saved state_dict keys.
        self.fc1 = nn.Linear(encode_size, 128)
        self.fc2 = nn.Linear(128, 2)
        self.to(self.device)

    def forward(self, x):
        """Return raw (un-normalised) logits whose last dimension is 2."""
        activated = F.relu(self.fc1(x))
        return self.fc2(activated)


# Data location: defaults to the original NFS layout; set TEGAKI_DATA_DIR
# to run the notebook elsewhere.
DATA_DIR = os.environ.get("TEGAKI_DATA_DIR", r"/net/nfs2/export/home/ohno/CR_pytorch/data/tegaki_katsuji")
chars_file_path = os.path.join(DATA_DIR, "all_chars_3812.npy")
datas_file_path = os.path.join(DATA_DIR, "tegaki_distance.npz")
tokens = CharToIndex(chars_file_path)
# NOTE(review): allow_pickle=True can execute code stored in the file; only load trusted data.
data = np.load(datas_file_path,allow_pickle=True)

# NOTE(review): EMBEDDING_DIM is not passed to Proofreader when it is built in
# the evaluation loop, so the embedding size there is Proofreader's own 256.
EMBEDDING_DIM = 10
HIDDEN_SIZE = 128
BATCH_SIZE = 64
VOCAB_SIZE = len(tokens)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tegaki_dataset = MyDataset(data,chars_file_path,device=device)


[31mERROR: No such char --> [0mb'\xe3\x82\x91'
[31mERROR: No such char --> [0mb'\xe7\xb8\x8a'


In [3]:

def eval(proofreader,valid_dataloader):
    """Return the proofreader's top-1 accuracy over every sample of a loader.

    Batches are ``(d_x, d_y, p_x, p_y)`` tuples; the proofreader is called as
    ``proofreader(p_x, d_x)`` and the arg-max of its output is compared with
    ``p_y``. Accuracy is correct / total samples, so a smaller final batch is
    weighted correctly. An empty loader yields 0.0.

    Note: this function shadows the ``eval`` builtin in the notebook namespace.
    """
    proofreader.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: do not build autograd graphs
        for d_x,_d_y,p_x,p_y in valid_dataloader:
            p_output = proofreader(p_x,d_x)
            p_prediction = p_output.max(1)[1]
            correct += p_prediction.eq(p_y).sum().item()
            total += p_y.size(0)
    return correct / total if total else 0.0

def examination(detector,proofreader,valid_dataloader,show_out=False,threshold=1):
    """Evaluate the combined detector + proofreader pipeline.

    Per sample the detector decides whether the OCR first candidate (centre
    column of ``p_x``) is kept or replaced by the proofreader's arg-max:

    * ``compare = softmax(detector(d_x)) >= threshold`` for each class;
    * ``flg_ocr`` is the index of the first maximum of ``compare``, i.e. 1
      only when class 1 meets the threshold and class 0 does not;
    * the prediction is the OCR candidate where ``flg_ocr == 1`` and the
      proofreader's prediction elsewhere.

    With the default ``threshold=1`` a class only qualifies when its softmax
    probability is exactly 1.0, so unless the detector is extremely confident
    the proofreader output is used.

    Args:
        detector: module mapping ``d_x`` to 2-class logits.
        proofreader: module called as ``proofreader(p_x, d_x)``.
        valid_dataloader: iterable of ``(d_x, d_y, p_x, p_y)`` batches.
        show_out: if True, print OCR / RNN / detector / prediction / answer
            strings per batch (decoded with the notebook-global ``tokens``).
        threshold: softmax probability a detector class must reach.

    Returns:
        ``(accuracy, rnn_collect_cnt, ocr_collect_cnt)``: accuracy of the
        combined prediction over all samples (0.0 for an empty loader), the
        number of correct combined predictions, and the number of correct OCR
        first candidates.
    """
    detector.eval()
    proofreader.eval()

    def _show_row(label, indices):
        print(label,end='')
        for idx in indices:
            print(tokens.get_decoded_char(idx),end='')

    total = 0
    rnn_collect_cnt=0
    ocr_collect_cnt=0
    with torch.no_grad():  # inference only: do not build autograd graphs
        for d_x,_d_y,p_x,p_y in valid_dataloader:
            # Detector: decide per sample whether to trust the OCR output.
            ocr_pred = p_x[:,2]  # OCR first candidate at the window centre
            d_output = detector(d_x)
            # A scalar threshold broadcasts to any batch size and device.
            compare = (F.softmax(d_output,dim=1) >= threshold).long()
            flg_ocr = compare.max(1)[1]
            ocr = ocr_pred * flg_ocr

            # Proofreader: used wherever the OCR output is not trusted.
            rnn_pred = proofreader(p_x,d_x).max(1)[1]
            flg_rnn = (flg_ocr == 0).long()
            rnn = rnn_pred * flg_rnn

            prediction = ocr + rnn
            batch_correct = prediction.eq(p_y).sum().item()
            rnn_collect_cnt += batch_correct
            ocr_collect_cnt += ocr_pred.eq(p_y).sum().item()
            total += p_y.size(0)

            if show_out:
                _show_row('\nＯＣＲ: ', ocr_pred)
                _show_row('\nＲＮＮ: ', rnn_pred)
                print('\n検出　: ',end='')
                for idx,ocr_idx in zip(ocr,ocr_pred):
                    # Bracket characters whose OCR output was not used.
                    if idx == 0:
                        print('[',tokens.get_decoded_char(ocr_idx),']',end='')
                    else :
                        print(tokens.get_decoded_char(ocr_idx),end='')
                _show_row('\n予測　: ', prediction)
                _show_row('\n正解　: ', p_y)
                print()
    accuracy = rnn_collect_cnt / total if total else 0.0
    return accuracy,rnn_collect_cnt,ocr_collect_cnt



In [4]:

def examination2(proofreader,valid_dataloader,show_out=False):
    """Evaluate the proofreader alone (no detector).

    Args:
        proofreader: module called as ``proofreader(p_x, d_x)``.
        valid_dataloader: iterable of ``(d_x, d_y, p_x, p_y)`` batches.
        show_out: if True, print OCR / RNN / answer strings per batch
            (decoded with the notebook-global ``tokens``).

    Returns:
        ``(accuracy, rnn_collect_cnt, ocr_collect_cnt)``: proofreader top-1
        accuracy over all samples (0.0 for an empty loader), the number of
        correct proofreader predictions, and the number of correct OCR first
        candidates.
    """
    proofreader.eval()
    total = 0
    rnn_collect_cnt=0
    ocr_collect_cnt=0
    with torch.no_grad():  # inference only: do not build autograd graphs
        for d_x,_d_y,p_x,p_y in valid_dataloader:
            ocr_pred = p_x[:,2]  # OCR first candidate at the window centre
            rnn_pred = proofreader(p_x,d_x).max(1)[1]
            rnn_collect_cnt += rnn_pred.eq(p_y).sum().item()
            ocr_collect_cnt += ocr_pred.eq(p_y).sum().item()
            total += p_y.size(0)

            if show_out:
                for label, indices in (('\nＯＣＲ: ', ocr_pred), ('\nＲＮＮ: ', rnn_pred), ('\n正解　: ', p_y)):
                    print(label,end='')
                    for idx in indices:
                        print(tokens.get_decoded_char(idx),end='')
                print()
    accuracy = rnn_collect_cnt / total if total else 0.0
    return accuracy,rnn_collect_cnt,ocr_collect_cnt

In [5]:

def examination3(proofreader,valid_dataloader,vocab_size=None):
    """Tally how the proofreader changes the OCR first-candidate results.

    Args:
        proofreader: module called as ``proofreader(p_x, d_x)``.
        valid_dataloader: iterable of ``(d_x, d_y, p_x, p_y)`` batches.
        vocab_size: number of rows of ``count_arr``; defaults to
            ``len(tokens)`` (the notebook-global vocabulary).

    Returns:
        ``(count_arr, success_cnt, fatal_cnt, misread_cnt)``:
        ``count_arr[c, 0]`` counts OCR errors on answer ``c`` that the
        proofreader fixed; ``count_arr[c, 1]`` counts all OCR errors on answer
        ``c``; ``success_cnt`` and ``misread_cnt`` are the totals of those two
        columns; ``fatal_cnt`` counts correct OCR results the proofreader broke.
    """
    if vocab_size is None:
        vocab_size = len(tokens)
    count_arr = np.zeros((vocab_size,2),dtype=int)
    success_cnt=0
    misread_cnt=0
    fatal_cnt = 0
    proofreader.eval()
    with torch.no_grad():  # inference only: do not build autograd graphs
        for d_x,_d_y,p_x,p_y in valid_dataloader:
            ocr_pred = p_x[:,2]  # OCR first candidate at the window centre
            rnn_pred = proofreader(p_x,d_x).max(1)[1]

            misread = ocr_pred != p_y
            fixed = misread & (rnn_pred == p_y)
            broken = (~misread) & (rnn_pred != p_y)

            answers = p_y.cpu().numpy()
            # np.add.at accumulates correctly when an answer index repeats.
            np.add.at(count_arr[:,0], answers[fixed.cpu().numpy()], 1)
            np.add.at(count_arr[:,1], answers[misread.cpu().numpy()], 1)
            success_cnt += int(fixed.sum())
            misread_cnt += int(misread.sum())
            fatal_cnt += int(broken.sum())

    return count_arr,success_cnt,fatal_cnt,misread_cnt

In [9]:
# K-fold evaluation of pretrained proofreaders (one saved model per fold).
cross_validation = Cross_Validation(tegaki_dataset)
k_num = cross_validation.k_num  # defaults to 10
# k_num=1
acc_record=[]
all_rnn = 0
all_ocr = 0
count_arr = np.zeros((len(tokens),2),dtype=int)
success_cnt=0
misread_cnt=0
fatal_cnt = 0
MODEL_DIR = "/net/nfs2/export/home/ohno/CR_pytorch/Culmination/Learned_models"
## Evaluation only: models are loaded from disk, nothing is trained here.
for k_idx in range(k_num):
    train_dataset,valid_dataset = cross_validation.get_datasets(k_idx=k_idx)
    valid_dataloader=DataLoader(valid_dataset,batch_size=BATCH_SIZE,shuffle=False,drop_last=True)

    proofreader = Proofreader(VOCAB_SIZE, hidden_dim=HIDDEN_SIZE, output_size=VOCAB_SIZE, n_layers=1)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(f"{MODEL_DIR}/proof_k{k_idx + 1}", map_location=proofreader.device)
    proofreader.load_state_dict(state_dict)
    # One pass both prints the per-batch strings and returns the accuracy;
    # a separate printing pass would only repeat the same evaluation.
    acc,_,_ = examination2(proofreader,valid_dataloader,show_out=True)
    acc_record.append(acc)
    # cnt_arr,suc_cnt,fat_cnt,mis_cnt = examination3(proofreader,valid_dataloader)
    # success_cnt += suc_cnt
    # misread_cnt += mis_cnt
    # fatal_cnt   += fat_cnt
print(np.mean(acc_record))
acc_record


ＯＣＲ: もご刹囲゛ただきまして．誠にありがと５ござぃます。さて１お客様の定韻・足期貯金は１右訪のとおり満期を迎えますので．ご案内ぃたレ末
ＲＮＮ: もご到囲いただきまして<UNK>我にありがとらございます。さて１お客様の定誤<UNK>定期貯金は１右誌のとおリ満期を迎えますので<UNK>ご案内いたし末
正解　: もご利用いただきまして、誠にありがとうございます。さて、お客様の定額・定期預金は、右記のとおり満期を迎えますので、ご案内いたしま

ＯＣＲ: す．今後ともー層のご愛顧を賜りますょうお願ぃ申しあげますっ（払戻しおょび預け替えのお手続きの際に必要となる書類等）貯金証書１総合
ＲＮＮ: す<UNK>今後ともー層のご愛顧を賜りますようお験い中しあげますつ（払戻しおよび預け替えのお手続きの際に必要となる書類等）貯金証書１総合
正解　: す。今後とも一層のご愛顧を賜りますようお願い申し上げます。（払戻しおよび預け替えのお手続きの際に必要となる書類等）貯金証書、総合

ＯＣＲ: 口座通帳（無遍帳型総合ロ座の場合はキャッシュカード１（ぉ届け印おょびご本人であるこヒを碓認できる証明書類（お名前・ご住所・生年月
ＲＮＮ: ロ座通帳（無遍限型総合ロ座の場合はキャッシュカードノ（お届け印およびご本んであることを確認できる証明書類（お名前、ご住所・生年月
正解　: 口座通帳（無通帳型総合口座の場合はキャッシュカード）、お届け印およびご本人であることを確認できる証明書類（お名前・ご住所・生年月

ＯＣＲ: 日の入）た運転免許証等１法人名義っ場合は登認事項証明書等．団体名義の場合は規約の写し筆もああせてお持５くださぃ）このことにつ゛て
ＲＮＮ: 日の入）た運転免評証等１法人名義つ場合は登認事漠証明書等<UNK>団体名義の場合は規約の軍し箏もあわせてお持ちください）このことについて
正解　: 日の入った運転免許証等、法人名義の場合は登記事項証明書等、団体名義の場合は規約の写し等もあわせてお持ちください）このことについて

ＯＣＲ: に日本争術振輿会か５別添のとおソ通知があルまレた。フきまレては司卦の研究者使用ルしルと科研費ハンドブゥ７もごニ送頂き二盈究童旦過
ＲＮＮ: に日本争術振興会から別逵のとおリ通知があルました。っきましては周郵の研究者使用ルしルと科研悪ハンドブッ７もご二送頂きニ亜

KeyboardInterrupt: 

In [7]:
# for idx in np.argsort(-count_arr[:,0]):
#     if count_arr[idx][0] == 0 and count_arr[idx][1] >0:
#         print(tokens.get_decoded_char(idx),end=' ')
#         print(count_arr[idx][0],'回', end=' ')
#         print(count_arr[idx][1],'回')


In [8]:
# Number of 5-character windows in the dataset (len(data['answer']) - 4).
len(tegaki_dataset)

15011