In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm
from torch.optim.optimizer import Optimizer
import matplotlib.pyplot as plt
from copy import deepcopy
import numpy as np
import random
import torch
from transformers import pipeline
import warnings 
warnings.filterwarnings('ignore')
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
import os
import gc
gc.collect()

def get_jaccard_sim(str1, str2): 
    a = set(str1.split()) 
    b = set(str2.split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

  '"sox" backend is being deprecated. '


# set_seed(42)

In [2]:
def set_seed(seed = int):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    seed_everything(seed)
    return random_state
random_state = set_seed(42)

Global seed set to 42


# LOAD DATA

In [3]:
df = pd.read_csv('preprocess_for_SQUAD_銀行.csv',index_col=0)
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
print(train_df.shape)
print(val_df.shape)
val_df

(1160, 4)
(291, 4)


Unnamed: 0,string_X_train,Y_label,string_Y_1,string_Y_2
3174,1SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND ...,CHINA CITIC BANK,2399,2415
4253,SIGNED COMMERCIAL INVOICE IN 3 COPIESx000DFULL...,KEB HANA BANK,120,133
311,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,492,505
8266,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS INDIC...,CTBC BANK CO LTD,2379,2395
870,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,492,505
...,...,...,...,...
7357,nan ALL DOCUMENTS MUST INDICATE THIS LC NOx000...,SAUDI BRITISH BANK,2612,2630
7333,SIGNED COMMERCIAL INVOICE IN 1 ORIGINAL INDICA...,TAISHIN INTERNATIONAL BANK,185,211
4024,1 MANUALLY SIGNED COMMERCIAL INVOICE IN 3 ORIG...,BANK OF CHINA LTD,828,845
1734,1 SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND...,BANK CENTRAL ASIA,214,231


# Load Model

In [5]:
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForQuestionAnswering

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
new_tokens = df['Y_label'].values.tolist()
num_added_toks = tokenizer.add_tokens(new_tokens)
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
model.resize_token_embeddings(len(tokenizer))
model.load_state_dict(torch.load('Product_Data_SQuAD_model_bank.pt'))
model.eval()
nlp = pipeline('question-answering', model=model.to('cpu'), tokenizer=tokenizer)
gc.collect()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

244

# rule base

In [6]:
import numpy as np

import re

def preprocess(x):
    x = str(x)
    x = re.sub('[\u4e00-\u9fa5]', '', x) # 1.去除中文
    x = re.sub('[’!"#$%&\'()*+,/:;<=>?@[\\]^_`{|}~，。,.]', '', x) # 2.去除標點符號
    x = x.replace('\n', '').replace('\r', '').replace('\t', '') # 3.去除換行符號
    x = str.strip(x) # 4.移除左右空白
    if 'x000D' in x:
        x = x.replace('x000D','')
    return x

def get_bank(text):
    text = str(text)
    text = preprocess(text)
    keywords = ['TO ORDER OF','TO THEORDER OF','TO THE ORDER OF','TOTHE ORDER OF','TO THE ORDER+OF','TOORDER OF']
    for i in keywords:
        if i in text:
            idx = text.split(i)[1].find('BANK')
            result = preprocess(text.split(i)[1][:idx+len('BANK')])
            if 'BANK' in result:
                return result
            else:
                return None
        else:
            return None

In [7]:
result = pd.DataFrame()
result['string_X_train'] = val_df['string_X_train']
result['Y_label'] = val_df['Y_label']
result

Unnamed: 0,string_X_train,Y_label
3174,1SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND ...,CHINA CITIC BANK
4253,SIGNED COMMERCIAL INVOICE IN 3 COPIESx000DFULL...,KEB HANA BANK
311,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD
8266,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS INDIC...,CTBC BANK CO LTD
870,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD
...,...,...
7357,nan ALL DOCUMENTS MUST INDICATE THIS LC NOx000...,SAUDI BRITISH BANK
7333,SIGNED COMMERCIAL INVOICE IN 1 ORIGINAL INDICA...,TAISHIN INTERNATIONAL BANK
4024,1 MANUALLY SIGNED COMMERCIAL INVOICE IN 3 ORIG...,BANK OF CHINA LTD
1734,1 SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND...,BANK CENTRAL ASIA


In [8]:
a = val_df['string_X_train'].apply(get_bank)
a = a.dropna(axis=0)
a

8266                                         ISSUING BANK
4500    SOAP  ALLIED INDUSTRIES LTDOLD MOKA ROAD BELL ...
5501                                           VIETINBANK
7547                                 SUMITOMO MITSUI BANK
3459                                             YES BANK
1124                                      UNITEDARAB BANK
3675                                           STATE BANK
3747                                            SACOMBANK
4550                                                 BANK
1666                                           ICICI BANK
6241                                 ASIA COMMERCIAL BANK
1667                                           ICICI BANK
4307                                            SACOMBANK
3083                                           VIETINBANK
3883                                                 BANK
3428                                 SUMITOMO MITSUI BANK
3576                                                 BANK
4511          

In [9]:
result['predict'] = val_df['string_X_train'].apply(get_bank)
result

Unnamed: 0,string_X_train,Y_label,predict
3174,1SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND ...,CHINA CITIC BANK,
4253,SIGNED COMMERCIAL INVOICE IN 3 COPIESx000DFULL...,KEB HANA BANK,
311,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
8266,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS INDIC...,CTBC BANK CO LTD,ISSUING BANK
870,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
...,...,...,...
7357,nan ALL DOCUMENTS MUST INDICATE THIS LC NOx000...,SAUDI BRITISH BANK,
7333,SIGNED COMMERCIAL INVOICE IN 1 ORIGINAL INDICA...,TAISHIN INTERNATIONAL BANK,
4024,1 MANUALLY SIGNED COMMERCIAL INVOICE IN 3 ORIG...,BANK OF CHINA LTD,
1734,1 SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND...,BANK CENTRAL ASIA,BANK


In [10]:
result

Unnamed: 0,string_X_train,Y_label,predict
3174,1SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND ...,CHINA CITIC BANK,
4253,SIGNED COMMERCIAL INVOICE IN 3 COPIESx000DFULL...,KEB HANA BANK,
311,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
8266,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS INDIC...,CTBC BANK CO LTD,ISSUING BANK
870,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
...,...,...,...
7357,nan ALL DOCUMENTS MUST INDICATE THIS LC NOx000...,SAUDI BRITISH BANK,
7333,SIGNED COMMERCIAL INVOICE IN 1 ORIGINAL INDICA...,TAISHIN INTERNATIONAL BANK,
4024,1 MANUALLY SIGNED COMMERCIAL INVOICE IN 3 ORIG...,BANK OF CHINA LTD,
1734,1 SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND...,BANK CENTRAL ASIA,BANK


# 接bert

In [11]:
not_find = []
for j,i in enumerate(result.iloc[:,2].values):
    if i == None:
        not_find.append(j)
len(not_find)

234

In [12]:
not_find_df = result.iloc[not_find]
not_find_df

Unnamed: 0,string_X_train,Y_label,predict
3174,1SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND ...,CHINA CITIC BANK,
4253,SIGNED COMMERCIAL INVOICE IN 3 COPIESx000DFULL...,KEB HANA BANK,
311,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
870,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALSx000DC...,MUFG BANK LTD,
3518,1BENEFICIARYS MANUALLY SIGNED COMMERCIAL INVOI...,ASKARI BANK LIMITED,
...,...,...,...
3303,1MANUALLY SIGNED COMMERCIAL INVOICE IN 2 ORIGI...,BANK OF CHINA,
7357,nan ALL DOCUMENTS MUST INDICATE THIS LC NOx000...,SAUDI BRITISH BANK,
7333,SIGNED COMMERCIAL INVOICE IN 1 ORIGINAL INDICA...,TAISHIN INTERNATIONAL BANK,
4024,1 MANUALLY SIGNED COMMERCIAL INVOICE IN 3 ORIG...,BANK OF CHINA LTD,


In [13]:
def model_predict(nlp,df):
    table = pd.DataFrame()
    for i in tqdm(df.index):
        sample = df.loc[[i]]
        string_X_train = sample['string_X_train'].values[0]
        QA_input = {
            'question': 'What is the bank name?',
            'context': string_X_train
        }
        res = nlp(QA_input)
        predict = QA_input['context'][res['start']:res['end']]
        row = pd.DataFrame({'predict:':predict},index=[i])
        table = table.append(row)
    return table

In [14]:
bert_predict = model_predict(nlp,not_find_df)
bert_predict

  0%|          | 0/234 [00:00<?, ?it/s]

Unnamed: 0,predict:
3174,SERVICECHINA CITIC BANKGUANGZHOU
4253,KEB HANA BANKx000DMARKED
311,MUFG BANK LTD
870,MUFG BANK LTD
3518,ASKARI BANK LTD
...,...
3303,BANK OF CHINA
7357,SAUDI BRITISH BANK SABBx000DGLOBAL
7333,TAISHIN INTERNATIONAL BANK
4024,BANK OF CHINA LTD


In [15]:
result.loc[bert_predict.index] = bert_predict.values
result

Unnamed: 0,string_X_train,Y_label,predict
3174,SERVICECHINA CITIC BANKGUANGZHOU,SERVICECHINA CITIC BANKGUANGZHOU,SERVICECHINA CITIC BANKGUANGZHOU
4253,KEB HANA BANKx000DMARKED,KEB HANA BANKx000DMARKED,KEB HANA BANKx000DMARKED
311,MUFG BANK LTD,MUFG BANK LTD,MUFG BANK LTD
8266,SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS INDIC...,CTBC BANK CO LTD,ISSUING BANK
870,MUFG BANK LTD,MUFG BANK LTD,MUFG BANK LTD
...,...,...,...
7357,SAUDI BRITISH BANK SABBx000DGLOBAL,SAUDI BRITISH BANK SABBx000DGLOBAL,SAUDI BRITISH BANK SABBx000DGLOBAL
7333,TAISHIN INTERNATIONAL BANK,TAISHIN INTERNATIONAL BANK,TAISHIN INTERNATIONAL BANK
4024,BANK OF CHINA LTD,BANK OF CHINA LTD,BANK OF CHINA LTD
1734,1 SIGNED COMMERCIAL INVOICE IN 3 ORIGINALS AND...,BANK CENTRAL ASIA,BANK


In [16]:
def get_acc(df,t=0.75):
    correct = []
    correct_label = []
    for i in df.index:
        jac = get_jaccard_sim(df.loc[i,'Y_label'],df.loc[i,'predict'])
        if jac >= t:
            correct.append('yes')
        else:
            correct.append('no')
    result = pd.Series(correct)
    return result.value_counts()['yes']/len(result)

def get_jac(df):
    all_jacs = []
    for i in df.index:
        all_jacs.append(get_jaccard_sim(str(df.loc[i,'Y_label']),str(df.loc[i,'predict'])))
    return np.sum(all_jacs)/len(all_jacs)

In [17]:
get_acc(result,1),get_acc(result,0.75),get_jac(result)

(0.8350515463917526, 0.8384879725085911, 0.8927684827872512)