In [22]:
import random

In [1]:
import torch
from transformers import AutoTokenizer

In [2]:
# Load the pretrained tokenizer registered under the name 'bert-base-uncased'.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [3]:
# Preprocessor comes from a module named Preprocessor (presumably project-local — confirm).
from Preprocessor import Preprocessor
# NOTE(review): root_dir is a hard-coded absolute path; it must exist on the
# machine that runs this cell.
preprocessor = Preprocessor(root_dir='/media/dmlab/My Passport/DATA/CABERT')

In [9]:
# Company names exposed by the preprocessor; preview the first five entries.
company_names = preprocessor.company_names
company_names[:5]

['aar corp',
 'abbott laboratories',
 'worlds inc',
 'acme united corp',
 'adams resources & energy  inc ']

In [12]:
# Sample sentence used for the rest of the walkthrough; it contains the
# company name 'amazon com inc' several times.
text = 'amazon com inc completed our initial public offering in may 1997 and our common stock is listed on the nasdaq global select market under the symbol amazon com inc as used herein amazon com inc com inc amazon com inc our and similar terms include amazon com inc com inc inc'
text

'amazon com inc completed our initial public offering in may 1997 and our common stock is listed on the nasdaq global select market under the symbol amazon com inc as used herein amazon com inc com inc amazon com inc our and similar terms include amazon com inc com inc inc'

In [17]:
# Encode the text, then map the ids back to token strings. The printed list
# starts with '[CLS]' and ends with '[SEP]'.
input_tokens = tokenizer.convert_ids_to_tokens(tokenizer(text)['input_ids'])
print(input_tokens)

['[CLS]', 'amazon', 'com', 'inc', 'completed', 'our', 'initial', 'public', 'offering', 'in', 'may', '1997', 'and', 'our', 'common', 'stock', 'is', 'listed', 'on', 'the', 'nas', '##da', '##q', 'global', 'select', 'market', 'under', 'the', 'symbol', 'amazon', 'com', 'inc', 'as', 'used', 'here', '##in', 'amazon', 'com', 'inc', 'com', 'inc', 'amazon', 'com', 'inc', 'our', 'and', 'similar', 'terms', 'include', 'amazon', 'com', 'inc', 'com', 'inc', 'inc', '[SEP]']


In [18]:
# Keep only the company names that occur (case-insensitively) as a substring
# of the text, preserving their order in company_names.
company_names_selected = [
    candidate
    for candidate in company_names
    if candidate.lower() in text.lower()
]
company_names_selected

['amazon com inc']

In [19]:
# Map every company name to its token strings. The [1:-1] slice drops the
# first and last token that the tokenizer adds around the encoded name.
company_name_to_tokens = {
    name_key: tokenizer.convert_ids_to_tokens(tokenizer(name_key)['input_ids'])[1:-1]
    for name_key in company_names
}
company_name_to_tokens['amazon com inc']

['amazon', 'com', 'inc']

In [21]:
def find_sub_list_indices(sl, l):
    """Locate every occurrence of the sub-list ``sl`` inside the list ``l``.

    Slices of ``l`` are compared to ``sl`` with ``==``, so both arguments
    should be the same sequence type (e.g. both lists).

    Args:
        sl: The run of items to search for.
        l: The list to search in.

    Returns:
        A list of ``(start, end)`` tuples, one per occurrence, where both
        bounds are inclusive positions in ``l``. Overlapping occurrences are
        all reported. An empty ``sl`` yields an empty list.
    """
    sll = len(sl)
    if sll == 0:
        # Nothing to search for; indexing sl[0] below would raise IndexError.
        return []
    first = sl[0]  # looked up once instead of on every iteration
    results = []
    for ind, e in enumerate(l):
        # Cheap first-item check before paying for the slice comparison.
        if e == first and l[ind:ind + sll] == sl:
            results.append((ind, ind + sll - 1))
    return results


# Gather the inclusive (start, end) token spans of every selected company
# name inside input_tokens, in the order the names were selected.
range_list_of_indices = [
    span
    for name in company_names_selected
    for span in find_sub_list_indices(company_name_to_tokens[name], input_tokens)
]
range_list_of_indices

[(1, 3), (29, 31), (36, 38), (41, 43), (49, 51)]

In [31]:
# Expand each inclusive (start, end) span into the explicit list of token
# positions it covers.
grouped_indices = [
    list(range(first, last + 1))
    for first, last in range_list_of_indices
]
grouped_indices

[[1, 2, 3], [29, 30, 31], [36, 37, 38], [41, 42, 43], [49, 50, 51]]

In [32]:
# Shuffle the order of the company-name spans in place.
# NOTE(review): no random seed is set in the visible cells, so the resulting
# order differs between runs.
random.shuffle(grouped_indices)
grouped_indices

[[36, 37, 38], [49, 50, 51], [1, 2, 3], [41, 42, 43], [29, 30, 31]]

In [33]:
# Tokens that are not part of any company-name span, each wrapped in its own
# single-element list.
# Flatten the company-name positions into a set once, so the membership test
# below is O(1) instead of rebuilding and scanning a list on every iteration.
company_token_indices = {item for sub in grouped_indices for item in sub}
cand_indexes = [
    [idx]
    for idx in range(1, len(input_tokens) - 1)  # skip [CLS] (first) and [SEP] (last)
    if idx not in company_token_indices
]
print(cand_indexes)

[[4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [32], [33], [34], [35], [39], [40], [44], [45], [46], [47], [48], [52], [53], [54]]


In [34]:
random.shuffle(cand_indexes) # shuffle the single-token candidates in place
print(cand_indexes)

[[9], [34], [47], [13], [39], [18], [21], [12], [40], [26], [32], [6], [24], [17], [15], [25], [22], [27], [33], [4], [45], [46], [7], [54], [14], [10], [11], [28], [53], [44], [8], [52], [23], [16], [20], [19], [35], [48], [5]]


In [35]:
# Put the company-name spans ahead of the single-token candidates.
# NOTE(review): cand_indexes is rebound to a value built from itself, so
# re-running this cell prepends grouped_indices again.
cand_indexes = grouped_indices + cand_indexes
print(cand_indexes)

[[36, 37, 38], [49, 50, 51], [1, 2, 3], [41, 42, 43], [29, 30, 31], [9], [34], [47], [13], [39], [18], [21], [12], [40], [26], [32], [6], [24], [17], [15], [25], [22], [27], [33], [4], [45], [46], [7], [54], [14], [10], [11], [28], [53], [44], [8], [52], [23], [16], [20], [19], [35], [48], [5]]
