In [1]:
import spacy
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import torch
import re
from tqdm.notebook import tqdm
# Report the torch build and whether a GPU is visible, one value per line.
print(torch.__version__, torch.cuda.is_available(), sep='\n')
# The notebook lives one level below the repository root; datasets sit in <root>/data.
src_path = Path.cwd().parent
data_path = src_path / 'data'

1.10.1
True


In [2]:
import sys
# Make the repository root importable so `src.ontology` resolves.
sys.path.append(str(src_path))

import yaml
import networkx as nx
from src.ontology import OntologySystem

settings_file = src_path / 'setting_files' / 'app_settings.yml'
with settings_file.open('r') as fh:
    settings = yaml.load(fh, Loader=yaml.FullLoader)

ontology_cfg = settings['ontology']
onto = OntologySystem(
    acc_name_path=data_path / 'AccountName.csv',
    rdf_path=data_path / 'AccountRDF.xml',
    model_path=data_path / ontology_cfg['model']['model_name'],
    kwargs_graph_drawer=ontology_cfg['graph_drawer'],
)

In [4]:
# Account metadata keyed by account id; later cells read each entry's 'eng_name' and 'group' fields.
ACC_DICT = onto.ACC_DICT

# Test: guessing masked tokens with BERT

In [3]:
from transformers import BertForMaskedLM, BertTokenizerFast

# Stock uncased BERT with its masked-LM head, used to probe plausible fillers for [MASK] slots.
model_path = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=model_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path=model_path)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Question 1

Asking for information based on facts and knowledge

In [None]:
# Question 1
# what is the Cost of sales ratio in last year?
threshold = 0.01
exceptions = ['BalanceSheet', 'IncomeStatement', 'CalendarOneYear']
times = ['year', 'quarter']
sentence_format = "[MASK] is the {} in the [MASK] {}?"
n_top = 15

predicted_tokens_dict = defaultdict(set)
# Count the accounts actually probed, so the progress total is correct even if an
# exception name is absent from ACC_DICT.
accounts = [(acc, dic) for acc, dic in ACC_DICT.items() if acc not in exceptions]
progress_bar = tqdm(total=len(accounts) * len(times))

# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    for acc, dic in accounts:
        account_name = dic['eng_name'].lower()
        for t in times:
            s = sentence_format.format(account_name, t.lower())

            inputs = tokenizer(s, padding=True, truncation=True, return_token_type_ids=True, return_tensors='pt')
            masked = inputs['input_ids'].eq(tokenizer.mask_token_id)
            # Logits restricted to the [MASK] positions: one row per mask.
            outputs = model(**inputs).logits[masked]
            # Highest-probability candidate token ids (and their probabilities) per mask.
            probs_top, token_ids_top = outputs.softmax(-1).topk(n_top, dim=-1)

            for i, keep in enumerate(probs_top >= threshold):
                for k in token_ids_top[i, keep]:
                    tkn = tokenizer.decode(k)
                    # Skip candidates containing a double quote.
                    if '"' not in tkn:
                        predicted_tokens_dict[f'[MASK]-{i}-{t}'].add(tkn)

            progress_bar.update(1)
progress_bar.close()

In [None]:
import csv

# One CSV row per mask key: key, token, token, ...
# csv.writer quotes tokens that contain commas or quotes (a plain ','.join would corrupt
# such rows). Tokens are sorted so the file is identical across runs.
# NOTE(review): every question cell writes this same file, so only the last run survives.
with (data_path / 'tkns.csv').open('w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file, lineterminator='\n')
    for k, v in predicted_tokens_dict.items():
        writer.writerow([k, *sorted(v)])

## Question 2

What-if analysis based on facts

In [40]:
knowledge = 'BS'
# NOTE(review): the 'R'-suffixed knowledge name presumably selects the relation query — confirm in OntologySystem.
knowledge_query = onto.sparql.get_predefined_knowledge(knowledge=f'{knowledge}R')
results = onto.sparql.query(knowledge_query)
nx_graph = onto.get_nx_graph(results)
# Breadth-first from the balance-sheet root: {parent: [children, ...]} for every reachable node.
sub_tree = dict(nx.bfs_successors(nx_graph, source='BalanceSheet'))

In [41]:
# Question 2
# what happens to the operating income when the cost of sales increases by 10% this year?

threshold = 0.01
exceptions = ['BalanceSheet', 'IncomeStatement', 'CalendarOneYear']
times = ['year', 'quarter']
# sentence_format = "what [MASK] to the {} when the {} [MASK] by {} {} in the [MASK] {}?"
sentence_format = "what will be the effect to {} if the {} [MASK] by {} {} in the [MASK] {}?"
n_top = 15
predicted_tokens_dict = defaultdict(set)

# Pair each parent account only with its OWN direct children from the BFS tree. The previous
# version kept extending one shared list, so every parent was also paired with the children
# of all previously visited parents, and earlier pairs were re-queried repeatedly.
pairs = [
    (sub_acc, acc)
    for sub_acc, accs in sub_tree.items()
    if sub_acc not in exceptions
    for acc in accs
]
progress_bar = tqdm(total=len(pairs) * len(times))

# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    for sub_acc, acc in pairs:
        sub_acc_name = ACC_DICT[sub_acc]['eng_name'].lower()
        account_name = ACC_DICT[acc]['eng_name'].lower()
        for t in times:
            s = sentence_format.format(
                account_name, sub_acc_name,
                np.random.randint(1, 50), np.random.choice(['percent', '%']),
                t.lower())

            inputs = tokenizer(s, padding=True, truncation=True, return_token_type_ids=True, return_tensors='pt')
            masked = inputs['input_ids'].eq(tokenizer.mask_token_id)
            # Logits restricted to the [MASK] positions: one row per mask.
            outputs = model(**inputs).logits[masked]
            probs_top, token_ids_top = outputs.softmax(-1).topk(n_top, dim=-1)
            for i, keep in enumerate(probs_top >= threshold):
                for k in token_ids_top[i, keep]:
                    tkn = tokenizer.decode(k)
                    # Skip candidates containing a double quote.
                    if '"' not in tkn:
                        predicted_tokens_dict[f'[MASK]-{i}-{t}'].add(tkn)

            progress_bar.update(1)
progress_bar.close()

0it [00:00, ?it/s]

In [42]:
import csv

# One CSV row per mask key: key, token, token, ...
# csv.writer quotes tokens that contain commas or quotes (a plain ','.join would corrupt
# such rows). Tokens are sorted so the file is identical across runs.
# NOTE(review): every question cell writes this same file, so only the last run survives.
with (data_path / 'tkns.csv').open('w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file, lineterminator='\n')
    for k, v in predicted_tokens_dict.items():
        writer.writerow([k, *sorted(v)])

## Question 3

What-if forecasting with an embedded ML model

In [44]:
# Question 3
# what will be our revenue in the 4th quarter?

threshold = 0.01
exceptions = ['BalanceSheet', 'IncomeStatement', 'CalendarOneYear']
times = ['year', 'quarter']
sentence_format = "[MASK] will be the {} in the [MASK] {}?"
# sentence_format = "how is the {} going to be in the [MASK] {}?"
n_top = 15

predicted_tokens_dict = defaultdict(set)
# Count the accounts actually probed, so the progress total is correct even if an
# exception name is absent from ACC_DICT.
accounts = [(acc, dic) for acc, dic in ACC_DICT.items() if acc not in exceptions]
progress_bar = tqdm(total=len(accounts) * len(times))

# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    for acc, dic in accounts:
        account_name = dic['eng_name'].lower()
        for t in times:
            s = sentence_format.format(account_name, t.lower())

            inputs = tokenizer(s, padding=True, truncation=True, return_token_type_ids=True, return_tensors='pt')
            masked = inputs['input_ids'].eq(tokenizer.mask_token_id)
            # Logits restricted to the [MASK] positions: one row per mask.
            outputs = model(**inputs).logits[masked]
            probs_top, token_ids_top = outputs.softmax(-1).topk(n_top, dim=-1)

            for i, keep in enumerate(probs_top >= threshold):
                for k in token_ids_top[i, keep]:
                    tkn = tokenizer.decode(k)
                    # Skip candidates containing a double quote.
                    if '"' not in tkn:
                        predicted_tokens_dict[f'[MASK]-{i}-{t}'].add(tkn)

            progress_bar.update(1)
progress_bar.close()

  0%|          | 0/78 [00:00<?, ?it/s]

In [45]:
import csv

# One CSV row per mask key: key, token, token, ...
# csv.writer quotes tokens that contain commas or quotes (a plain ','.join would corrupt
# such rows). Tokens are sorted so the file is identical across runs.
# NOTE(review): every question cell writes this same file, so only the last run survives.
with (data_path / 'tkns.csv').open('w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file, lineterminator='\n')
    for k, v in predicted_tokens_dict.items():
        writer.writerow([k, *sorted(v)])

---

# Create dataset

In [92]:
# Future words need no special handling, but past words must never be used in future questions.
df = pd.read_csv(data_path / 'AccountWords.csv', encoding='utf-8')

format_dict = {
    0: ['help'],
    1: [
        # what/how, target_account, [MASK] + year/quarter
        "{} is the {} in the {}?",
        # [MASK] + year/quarter, what/how, target_account
        "In the {}, {} is the value of the {}"
    ],
    2: [
        # target_account, subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter
        "what happens to the {} when the {} {} by {} in the {}?",
        # target_account, subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter
        "what will be the effect to {} if the {} {} by {} in the {}?",
        # reverse the relation
        # subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter, target_account
        "when the {} {} by {} in the {}, what will happen to the {}?",
        # subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter, target_account
        "if the {} {} by {} in the {}, what will be the effect to {}?"
    ],
    3: [
        # what/how, target_account, [MASK] + year/quarter
        "{} will be the {} in the {}?"
    ]
}

# TODO: maybe add today's information after [SEP]?
context = ['HELP', 'PAST', 'FUTURE']
# words[kind] -> list of (word, tag, desc) tuples for kind in year / quarter / words,
# keeping only rows where that kind's word column is filled.
words = defaultdict(list)
for typ in ['year', 'quarter', 'words']:
    columns = [typ, f'{typ}_tag', f'{typ}_desc']
    df_temp = df.loc[df[typ].notna(), columns]
    words[typ].extend(df_temp.itertuples(index=False, name=None))

In [104]:
# TODO: build entity positions
# ("I was driving a BMW", {"entities": [(16,19, "PRODUCT")]})

def get_entity(s, x, tag):
    """Locate the first occurrence of ``x`` in ``s`` as a spaCy-style ``(start, end, tag)`` span.

    ``end`` is exclusive. Raises ValueError (from ``str.index``) when ``x`` does not occur.
    """
    start = s.index(x)
    end = start + len(x)
    return start, end, tag

def random_sampling(x_dict, x_key):
    """Draw one ``(word, tag, desc)`` entry uniformly at random from ``x_dict[x_key]``.

    Uses NumPy's global RNG. Raises ValueError when the list is empty.
    """
    candidates = x_dict[x_key]
    n_candidates = len(candidates)
    positions = np.arange(n_candidates)
    uniform = np.ones(n_candidates) / n_candidates
    picked = np.random.choice(positions, replace=False, p=uniform)
    word, tag, desc = candidates[picked]
    return word, tag, desc

def get_words_filtered(words, text):
    """Return a ``defaultdict(list)`` copy of ``words`` without entries whose ``desc`` equals ``text``.

    Each entry is re-packed as a ``(word, tag, desc)`` tuple. Keys whose entries are all
    removed are left out; missing keys still read as ``[]``.
    """
    words_filtered = defaultdict(list)
    for key, entries in words.items():
        kept = [(word, tag, desc) for word, tag, desc in entries if desc != text]
        if kept:
            words_filtered[key].extend(kept)
    return words_filtered

# Ontology nodes never used as question accounts.
exceptions = ['BalanceSheet', 'IncomeStatement', 'Ratios', 'CalendarOneYear']
times = ['year', 'quarter']

all_data = []
# Markers wrapped around each slot value so its character span can be located in the question.
s_ENT, e_ENT = '[E]', '[/E]'


def f_ENT(x):
    """Wrap ``x`` in the start/end entity markers."""
    return f'{s_ENT}{x}{e_ENT}'


## Question 1

```python
# what/how, target_account, [MASK] + year/quarter
"{} is the {} in the {}?",
```

In [105]:
# Scenario 1 (intent 'PAST.value'): "what is <account> in <time>?" style questions.
# One example per (format, account, time kind, time word); time words tagged 'FUTURE' are excluded.
# NOTE: output depends on the exact sequence of global NumPy RNG calls below.
data1 = []
trg_scenario = 1
n_sample = 10  # NOTE(review): not used in this cell
progress_bar = tqdm()
words_filtered = get_words_filtered(words, text='FUTURE')
for idx_fmt, fmt in enumerate(format_dict[trg_scenario]):
    for acc, dic in ACC_DICT.items():
        if acc in exceptions:
            continue
        target_account = dic['eng_name'].lower()
        # knowledge prefix = first '-'-separated part of the account's group, used in the entity tag
        knowledge, *_ = dic['group'].split('-')

        for t in ['year', 'quarter']:
            for t_word, t_tag, t_desc in words_filtered[t]:
                entities = []
                pre_token = np.random.choice(['what', 'how'], replace=False, p=np.ones(2)/2)
                if idx_fmt == 0:
                    # what/how, target_account, [MASK] + year/quarter
                    # "{} is the {} in the {}?",
                    
                    s = fmt.format(
                        pre_token,
                        f_ENT(target_account), 
                        f_ENT(f'{t_word} {t}')
                        )
                else:
                    # [MASK] + year/quarter, what/how, target_account
                    # "In the {}, {} is the value of the {}"
                    s = fmt.format(
                        f_ENT(f'{t_word} {t}'),
                        pre_token,
                        f_ENT(target_account)
                    )
                # [has_relation flag, order code, order code] (the processing cell reads
                # relation[0] as has_relation); all zero because scenario 1 has no account pair.
                relation = [0, 0, 0]  # no_relation, order1, order2
                # entities
                ## target_account
                entities.append(get_entity(s, f_ENT(target_account), f'{knowledge}.{acc}'))
                ## MASK year/quarter
                entities.append(get_entity(s, f_ENT(f'{t_word} {t}'), t_tag))
                
                # entities are stored in order of their position in the question
                data1.append(
                    {'question': s, 'entities': sorted(entities, key=lambda x: x[0]), 'intent': 'PAST.value', 'relation': relation}
                )
            
                progress_bar.update(1)

print(len(data1))

0it [00:00, ?it/s]

2106


## Question 2

```python
# target_account, subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter
"what happens to the {} when the {} {} by {} in the {}?"
# target_account, subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter
"what will be the effect to {} if the {} {} by {} in the {}?"
# reverse the relation
# subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter, target_account
"when the {} {} by {} in the {}, what will happen to the {}?"
# subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter, target_account
"if the {} {} by {} in the {}, what will be the effect to {}?"
```

In [106]:
def get_role_dict(onto, knowledge):
    """Group the subjects of a predefined SPARQL knowledge query by their object.

    Returns a ``defaultdict(list)`` mapping each object name to the distinct subject names
    (first-seen order) of triples pointing at it. Triples touching 'CalendarOneYear' on either
    side are ignored, and the predicate does not affect the grouping.
    """
    query = onto.sparql.get_predefined_knowledge(knowledge=knowledge)
    triples = onto.sparql.query(query)
    to_str = onto.graph_drawer.convert_to_string
    role_dict = defaultdict(list)
    for raw_s, raw_p, raw_o in triples:
        subj, _pred, obj = (to_str(term) for term in (raw_s, raw_p, raw_o))
        if 'CalendarOneYear' in (subj, obj):
            continue
        members = role_dict[obj]
        if subj not in members:
            members.append(subj)
    return role_dict

def process_successor(successors, role_dict, trg_acc, acc, _path=None):
    """Append every transitive child of ``acc`` in ``role_dict`` to ``successors[trg_acc]`` (DFS).

    Children are appended level by level, before recursing into them. A node reachable
    along several paths is appended once per path, which matches the original behaviour.

    Args:
        successors: ``defaultdict(list)`` that accumulates results, mutated in place.
        role_dict: mapping ``parent -> [children]``; nodes without an entry are leaves.
        trg_acc: key under which descendants are collected.
        acc: node whose children are expanded.
        _path: internal set of nodes on the current recursion path. Children already on the
            path are skipped, so a cyclic role graph no longer recurses forever. For acyclic
            graphs the output is identical to the unguarded version.

    Returns:
        None.
    """
    children = role_dict.get(acc)
    if children is None:
        return None
    path = {acc} if _path is None else _path | {acc}
    fresh = [child for child in children if child not in path]
    successors[trg_acc].extend(fresh)
    for child in fresh:
        process_successor(successors, role_dict, trg_acc, child, path)
    return None

def get_successor(onto, knowledge, exceptions=None):
    """Map every parent account in ``knowledge`` to the list of its transitive successors.

    Parents listed in ``exceptions`` (if given) are not expanded as targets. Returns a
    ``defaultdict(list)`` filled by ``process_successor``.
    """
    role_dict = get_role_dict(onto, knowledge=knowledge)
    excluded = exceptions if exceptions is not None else ()
    successors = defaultdict(list)
    for trg_acc in role_dict:
        if trg_acc not in excluded:
            process_successor(successors, role_dict, trg_acc, trg_acc)
    return successors


In [107]:
exceptions = ['BalanceSheet', 'IncomeStatement', 'Ratios', 'CalendarOneYear']

# Transitive successor lists per target account, for balance sheet ('BS') then income statement ('IS').
bs_successors, is_successors = (get_successor(onto, kn, exceptions) for kn in ('BS', 'IS'))

In [108]:
# Scenario 2 (intent 'IF.fact'): "what happens to <target> when <subject> <changes> by <n %> in <time>?"
# n_sample random variants per (format, target, subject) pair; exact duplicates are dropped.
data2 = []
trg_scenario = 2
n_sample = 5
progress_bar = tqdm()
words_filtered = get_words_filtered(words, text='FUTURE')
# Hashable fingerprints of examples already in data2: O(1) de-duplication instead of
# scanning the whole list (`d not in data2`) for every sample.
seen_examples = set()

for idx_fmt, fmt in enumerate(format_dict[trg_scenario]):
    # Distinct loop names so the notebook-level `sub_tree` / `successors` are not overwritten.
    for successor_map in [bs_successors, is_successors]:
        for trg_acc, sub_accs in successor_map.items():
            if trg_acc in exceptions:
                continue
            target_account = ACC_DICT[trg_acc]['eng_name'].lower()
            target_knowledge, *_ = ACC_DICT[trg_acc]['group'].split('-')
            for sub_acc in sub_accs:
                subject_account = ACC_DICT[sub_acc]['eng_name'].lower()
                # The subject entity's knowledge prefix must come from the subject account
                # itself (it was read from trg_acc, mislabelling cross-group subjects).
                subject_knowledge, *_ = ACC_DICT[sub_acc]['group'].split('-')
                for _ in range(n_sample):
                    entities = []

                    apply_word, apply_tag, apply_desc = random_sampling(x_dict=words_filtered, x_key='words')
                    t = np.random.choice(times, replace=False, p=np.ones(len(times))/len(times))
                    t_word, t_tag, t_desc = random_sampling(x_dict=words_filtered, x_key=t)

                    number = np.random.randint(1, 99)
                    percent = np.random.choice(['percent', '%'], replace=False, p=np.ones(2)/2)

                    if idx_fmt in [0, 1]:
                        # target_account, subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter
                        s = fmt.format(
                            f_ENT(target_account),
                            f_ENT(subject_account),
                            f_ENT(apply_word),
                            f_ENT(f'{number} {percent}'),
                            f_ENT(f'{t_word} {t}')
                            )
                        # target account appears before the subject account
                        relation = [1, 1, 2]
                    else:
                        # subject_account, [MASK], random_number + percent/%, [MASK] + year/quarter, target_account
                        s = fmt.format(
                            f_ENT(subject_account),
                            f_ENT(apply_word),
                            f_ENT(f'{number} {percent}'),
                            f_ENT(f'{t_word} {t}'),
                            f_ENT(target_account)
                            )
                        # subject account appears before the target account
                        relation = [1, 2, 1]
                    # entities
                    ## target_account
                    entities.append(get_entity(s, f_ENT(target_account), f'{target_knowledge}.{trg_acc}'))
                    ## subject_account
                    entities.append(get_entity(s, f_ENT(subject_account), f'{subject_knowledge}.{sub_acc}'))
                    ## MASK apply words
                    entities.append(get_entity(s, f_ENT(apply_word), apply_tag))
                    ## percentages
                    entities.append(get_entity(s, f_ENT(f'{number} {percent}'), 'PERCENT'))
                    ## MASK year/quarter
                    entities.append(get_entity(s, f_ENT(f'{t_word} {t}'), t_tag))

                    d = {'question': s, 'entities': sorted(entities, key=lambda x: x[0]), 'intent': 'IF.fact', 'relation': relation}
                    key = (d['question'], tuple(d['entities']), d['intent'], tuple(d['relation']))
                    if key not in seen_examples:
                        seen_examples.add(key)
                        data2.append(d)

                    progress_bar.update(1)

print(len(data2))

0it [00:00, ?it/s]

1940


## Question 3

```python
# what/how, target_account, [MASK] + year/quarter
"{} will be the {} in the {}?"
```

In [109]:
# Scenario 3 (intent 'IF.forecast'): "what/how will be <account> in <time>?" forecasting questions.
# Time words tagged 'PAST' are excluded.
data3 = []
trg_scenario = 3
progress_bar = tqdm()
words_filtered = get_words_filtered(words, text='PAST')

for fmt in format_dict[trg_scenario]:
    for acc, dic in ACC_DICT.items():
        if acc in exceptions:
            continue
        target_account = dic['eng_name'].lower()
        knowledge, *_ = dic['group'].split('-')
        for t in ['year', 'quarter']:
            for t_word, t_tag, t_desc in words_filtered[t]:
                entities = []
                s = fmt.format(
                    np.random.choice(['what', 'how']),
                    f_ENT(target_account),
                    f_ENT(f'{t_word} {t}')
                    )
                relation = [0, 0, 0]
                # entities
                ## target_account
                entities.append(get_entity(s, f_ENT(target_account), f'{knowledge}.{acc}'))
                ## MASK year/quarter
                entities.append(get_entity(s, f_ENT(f'{t_word} {t}'), t_tag))

                # Keep entities in position order like data1/data2, so offset post-processing
                # does not silently depend on the order of the append calls above.
                data3.append(
                    {'question': s, 'entities': sorted(entities, key=lambda x: x[0]), 'intent': 'IF.forecast', 'relation': relation}
                )

                progress_bar.update(1)

print(len(data3))

0it [00:00, ?it/s]

1014


In [110]:
# All scenarios in one list (scenario 1, then 2, then 3); the dicts are shared, not copied.
all_data = [*data1, *data2, *data3]

---

# Post-process for entities

In [111]:
import json

special_len = len(s_ENT)+len(e_ENT)


def strip_entity_markers(example, start_marker, end_marker):
    """Return a copy of ``example`` with entity markers removed and entity offsets re-based.

    Each ``(start, end, label)`` in ``example['entities']`` indexes the marked question, and its
    span includes the entity's own markers. The returned span indexes the cleaned question and
    covers only the entity text. The shift is computed from the markers that actually precede
    each entity, so the result does not rely on the entity list being sorted by position.
    """
    question = example['question']
    marker_len = len(start_marker) + len(end_marker)
    entities = []
    for start, end, label in example['entities']:
        prefix = question[:start]
        shift = prefix.count(start_marker) * len(start_marker) + prefix.count(end_marker) * len(end_marker)
        new_start = start - shift
        entities.append((new_start, new_start + (end - start) - marker_len, label))
    cleaned = dict(example)
    cleaned['question'] = question.replace(start_marker, '').replace(end_marker, '')
    cleaned['entities'] = entities
    return cleaned


# Build a new list instead of mutating `all_data` (and the data1/2/3 dicts it shares) in
# place: re-running this cell used to shift the already-corrected offsets a second time.
all_data_clean = [strip_entity_markers(x, s_ENT, e_ENT) for x in tqdm(all_data, total=len(all_data))]

with (data_path / 'all_data.jsonl').open('w', encoding='utf-8') as file:
    for line in tqdm(all_data_clean, total=len(all_data_clean), desc='saving'):
        file.write(json.dumps(line) + '\n')

  0%|          | 0/5060 [00:00<?, ?it/s]

saving:   0%|          | 0/5060 [00:00<?, ?it/s]

In [112]:
import spacy
from transformers import BertTokenizerFast

class NLUTokenizer:
    """Wrapper around a HuggingFace ``BertTokenizerFast`` plus a spaCy pipeline.

    Adds helpers that convert character-level entity spans into token-level IOB tags and
    align those tags with ``input_ids`` that contain special tokens.
    """

    def __init__(
        self, 
        hugg_path='bert-base-uncased', 
        spacy_path='en_core_web_sm'
    ):
        """
        Args:
            hugg_path: HuggingFace model name or path for the fast BERT tokenizer.
            spacy_path: spaCy pipeline name or path passed to ``spacy.load``.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(hugg_path)
        # NOTE(review): not used by any method of this class.
        self.spacy_nlp = spacy.load(spacy_path)

    def tokenize(self, text):
        """Return the word-piece tokens of ``text``."""
        return self.tokenizer.tokenize(text)

    def decode(self, token_ids, **kwargs):
        """Decode ``token_ids`` to a string; kwargs are forwarded to the HF tokenizer."""
        return self.tokenizer.decode(token_ids, **kwargs)


    def __call__(self, text, **kwargs):
        """Encode ``text``; kwargs are forwarded to the HF tokenizer call."""
        return self.tokenizer(text, **kwargs)

    @classmethod
    def offsets_to_iob_tags(cls, encodes, ents, get_acc_relation=False):
        """
        ```
        IOB SCHEME
        I - Token is inside an entity.
        O - Token is outside an entity.
        B - Token is the beginning of an entity.

        BILUO SCHEME
        B - Token is the beginning of a multi-token entity.
        I - Token is inside a multi-token entity.
        L - Token is the last token of a multi-token entity.
        U - Token is a single-token unit entity.
        O - Token is outside an entity.
        ```
        method: IOB SCHEME
        modified from https://github.com/explosion/spaCy/blob/9d63dfacfc85e7cd6db7190bd742dfe240205de5/spacy/training/iob_utils.py#L63

        encodes: batch encodes from huggingface TokenizerFast
        ents: entities with start & end characters in sentences + entity

        Args:
            encodes: encoding of ONE text from a fast tokenizer called with
                ``return_offsets_mapping=True``; must contain ``offset_mapping`` and ``input_ids``.
            ents: iterable of ``(start_char, end_char, label)`` spans (end exclusive). A falsy
                label marks every token with a non-empty offset as ``'O'``.
            get_acc_relation: when truthy, also collect the token spans of account entities
                (labels containing a ``'.'``, e.g. ``'BS.Revenue'``).

        Returns:
            ``(labels, acc_relation)``: one IOB label per token, and a list of
            ``(start_token, end_token_exclusive)`` account spans in ``ents`` order.

        Raises:
            ValueError: if two entities overlap, or a token inside an entity span could not be
                aligned to that entity's start/end character boundaries.
        """
        acc_relation = list()

        # Map character offsets to token indices; (0, 0) offsets belong to special tokens.
        starts, ends = dict(), dict()
        for tkn_idx, (s_idx, e_idx) in enumerate(encodes['offset_mapping']):
            if s_idx == e_idx == 0:
                continue
            starts[s_idx] = tkn_idx
            ends[e_idx] = tkn_idx
        
        char_in_ents = {}
        # '-' marks "not yet assigned"; any left over at the end is an alignment error.
        labels = ['-'] * len(encodes['input_ids'])
        for s_char, e_char, ent in ents:
            if not ent:
                for s in starts:
                    labels[starts[s]] = 'O'
            else:
                for char_idx in range(s_char, e_char):
                    if char_idx in char_in_ents.keys():
                        raise ValueError(f'Trying to Overlapping same tokens: {char_in_ents[char_idx]} / {(s_char, e_char, ent)}')
                    char_in_ents[char_idx] = (s_char, e_char, ent)
                s_token = starts.get(s_char)
                e_token = ends.get(e_char)

                # Only label entities whose boundaries fall exactly on token boundaries.
                if s_token is not None and e_token is not None:
                    labels[s_token] = f'B-{ent}'
                    # add relation
                    if get_acc_relation and len(ent.split('.')) > 1:
                        acc_relation.append((s_token, e_token+1))

                    for i in range(s_token + 1, e_token+1):
                        labels[i] = f'I-{ent}'
                        
        # Every token that touches no entity character is outside ('O').
        entity_chars = set()
        for s_char, e_char, ent in ents:
            for i in range(s_char, e_char):
                entity_chars.add(i)
        for token_idx, (s, e) in enumerate(encodes['offset_mapping']):
            for i in range(s, e):
                if i in entity_chars:
                    break
            else:
                labels[token_idx] = 'O'
        if '-' in labels:
            raise ValueError('Some Tokens are not properly assigned' + f'{labels}')

        return labels, acc_relation

    def pad_tags(self, input_ids, tags, pad_idx:int=-100):
        """Spread ``tags`` over the non-special positions of ``input_ids``.

        Special-token positions ([CLS], [SEP], [PAD], ...) receive ``pad_idx``; the remaining
        positions take ``tags`` in order. Raises IndexError if there are fewer tags than
        non-special tokens.
        """
        # `all_special_ids` is a property that rebuilds its id list on every access; read it
        # once instead of once per token (this runs for every item in every epoch).
        special_ids = self.tokenizer.all_special_ids
        padded_tags = [pad_idx] * len(input_ids)
        j = 0
        for i, tkn_id in enumerate(input_ids):
            if tkn_id in special_ids:
                continue
            padded_tags[i] = tags[j]
            j += 1
        return padded_tags

    # def get_labels(self, intent, input_ids, tags, pad_idx:int=-100):
    #     labels = self.pad_tags(input_ids, tags, pad_idx)
    #     labels[0] = intent
    #     return labels

# Shared tokenizer instance for the data-processing cells below.
nlu_tokenizer = NLUTokenizer('bert-base-uncased', 'en_core_web_sm')

In [113]:
import json
import pickle

with (data_path / 'all_data.jsonl').open('r', encoding='utf-8') as file:
    data = file.readlines()
    all_data = []
    for line in tqdm(data, total=len(data), desc='loading'):
        all_data.append(json.loads(line))

processed_data = []
pad_id = nlu_tokenizer.tokenizer.pad_token_type_id
for x in tqdm(all_data, total=len(all_data), desc='processing data'):
    encodes = nlu_tokenizer(text=x['question'], add_special_tokens=False, return_offsets_mapping=True)
    has_relation = x['relation'][0]
    tags, acc_relation = nlu_tokenizer.offsets_to_iob_tags(encodes, ents=x['entities'], get_acc_relation=has_relation)
    if has_relation:
        # relation = [has_relation, order_a, order_b]:
        #   (1, 2) -> target account appears first in the question, subject second
        #   (2, 1) -> subject account appears first, target second
        # acc_relation holds the account token spans in entity (= position) order.
        if len(acc_relation) != 2:
            raise ValueError(f'Expected exactly 2 account entities for a relation, got {len(acc_relation)}: {x}')
        order = tuple(x['relation'][1:])
        if order == (1, 2):
            (s_trg, e_trg), (s_sub, e_sub) = acc_relation
        elif order == (2, 1):
            (s_sub, e_sub), (s_trg, e_trg) = acc_relation
        else:
            # Previously fell through and reused coordinates from an earlier example.
            raise ValueError(f'Unknown relation order {order}: {x}')
        # plus 1 for the [CLS] token added in front of the sentence at training time
        target_relation = (s_trg+1, e_trg+1)
        subject_relation = (s_sub+1, e_sub+1)
        relation = [has_relation, target_relation, subject_relation]
    else:
        relation = [has_relation, (0,0), (0,0)]
    
    processed_data.append((x['question'], tags, x['intent'], relation))

with (data_path / 'all_data_processed.pickle').open('wb') as file:
    pickle.dump(processed_data, file)

loading:   0%|          | 0/5060 [00:00<?, ?it/s]

processing data:   0%|          | 0/5060 [00:00<?, ?it/s]

---

# Data Splits

In [114]:
import pickle
from sklearn.model_selection import StratifiedShuffleSplit

seed = 777
with (data_path / 'all_data_processed.pickle').open('rb') as file:
    all_data = pickle.load(file)

questions, tags, intents, relations = list(zip(*all_data))
# Split into train & test, stratified on intent so each intent keeps its share in both splits.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=seed)
train_idx, test_idx = next(splitter.split(questions, intents))
# Sets give O(1) membership; `idx in ndarray` scans the whole index array on every lookup.
train_idx_set = set(train_idx.tolist())
test_idx_set = set(test_idx.tolist())

train_data = []
test_data = []

tags_set = set()
for idx in tqdm(range(len(questions)), total=len(questions), desc='spliting data'):
    # process tags and intents
    data = (questions[idx], tags[idx], intents[idx], relations[idx])
    tags_set.update(tags[idx])

    if idx in train_idx_set:
        train_data.append(data)
    elif idx in test_idx_set:
        test_data.append(data)
    else:
        raise ValueError("Index Error")

# Assign ids in sorted order. Iterating a plain set of strings depends on the per-process
# hash seed, so the label -> id mapping used to change from one kernel run to the next.
intents2id = {'None': 0}
for intent in sorted(set(intents)):
    if intents2id.get(intent) is None:
        intents2id[intent] = len(intents2id)

tags2id = {'[PAD]': 0, 'O': 1}
for t in sorted(tags_set):
    if tags2id.get(t) is None:
        tags2id[t] = len(tags2id)

with (data_path / 'all_data_splitted.pickle').open('wb') as file:
    pickle.dump({
        'train': train_data, 
        'test': test_data, 
        }, file)

with (data_path / 'all_data_ids.pickle').open('wb') as file:
    pickle.dump({
        'tags2id': tags2id, 
        'intents2id': intents2id
        }, file)

spliting data:   0%|          | 0/5060 [00:00<?, ?it/s]

# Training

- Entities
- Entity relations (subject, target)

## Dataset

In [4]:
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import pickle

class NLUDataset(Dataset):
    """Map-style dataset that turns processed ``(question, tags, intent, relation)`` records into tensors.

    Each item holds the tokenizer outputs (padded to ``max_len``) plus ``intent``, ``tags``
    (id 0 on special-token positions), ``has_relation``, ``target_relation`` and ``subject_relation``.
    """

    def __init__(
        self, data, 
        tags2id=None, 
        intents2id=None, 
        hugg_path='bert-base-uncased', 
        spacy_path='en_core_web_sm', 
        max_len=128,
    ):
        # Transpose the list of 4-tuples into four parallel column tuples.
        columns = list(zip(*data))
        self.questions, self.tags, self.intents, self.relations = columns
        self.tokenizer = NLUTokenizer(hugg_path, spacy_path)
        self.tags2id = tags2id
        self.intents2id = intents2id
        self.max_len = max_len

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, index):
        question = self.questions[index]
        tag_ids = [self.tags2id.get(tag) for tag in self.tags[index]]
        intent_id = self.intents2id.get(self.intents[index])
        relation = self.relations[index]

        encodes = self.tokenizer(
            question,
            return_offsets_mapping=False,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
        )
        # Align word-piece tag ids with input_ids; special-token positions get id 0.
        aligned_tags = self.tokenizer.pad_tags(
            input_ids=encodes['input_ids'],
            tags=tag_ids,
            pad_idx=0,
        )

        item = {key: torch.as_tensor(value) for key, value in encodes.items()}
        item.update(
            intent=torch.as_tensor(intent_id),
            tags=torch.as_tensor(aligned_tags),
            has_relation=torch.as_tensor(relation[0]),
            target_relation=torch.as_tensor(relation[1]),
            subject_relation=torch.as_tensor(relation[2]),
        )
        return item

class NLUDataModule(pl.LightningDataModule):
    """LightningDataModule that serves the pre-split NLU data written earlier in this notebook.

    Args:
        data_path: pickle containing ``{'train': [...], 'test': [...]}`` example lists.
        ids_path: pickle containing ``{'tags2id': {...}, 'intents2id': {...}}``.
        batch_size: DataLoader batch size.
        max_len: padded/truncated sequence length passed to ``NLUDataset``.
        test_size, seed: stored only. The split is precomputed on disk, so they are unused here.
        num_workers: DataLoader worker processes.
    """
    def __init__(
        self, data_path:Path, ids_path:Path,
        batch_size:int=32, 
        max_len:int=128,
        test_size=0.1,
        num_workers=0,
        seed=777
    ):
        super().__init__()
        self.data_path = data_path
        self.ids_path = ids_path
        self.batch_size = batch_size
        self.max_len = max_len
        self.test_size = test_size
        self.seed = seed
        self.num_workers = num_workers

    def load_data(self):
        """Read the split and vocabulary pickles into attributes."""
        # NOTE(review): pickle.load can execute arbitrary code; only load files produced by this pipeline.
        with Path(self.data_path).open('rb') as file:
            data = pickle.load(file)
        
        with Path(self.ids_path).open('rb') as file:
            ids = pickle.load(file)

        self.train_data = data['train']
        self.test_data = data['test']
        self.tags2id = ids['tags2id']
        self.intents2id = ids['intents2id']

    def prepare_data(self):
        # Kept so the notebook can call prepare_data() and then *_dataloader() without a Trainer.
        self.load_data()

    def setup(self, stage=None):
        """Make sure the data is loaded on every process.

        Lightning runs ``prepare_data`` on a single process (per node), so attributes set there
        are missing on other ranks in distributed training. ``setup`` runs on every process.
        """
        if not hasattr(self, 'train_data'):
            self.load_data()
        
    def train_dataloader(self):
        """Shuffled DataLoader over the training split."""
        train_dataset = NLUDataset(
            self.train_data, 
            tags2id=self.tags2id, 
            intents2id=self.intents2id,
            max_len=self.max_len
        )
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """DataLoader over the held-out split, in file order."""
        val_dataset = NLUDataset(
            self.test_data, 
            tags2id=self.tags2id, 
            intents2id=self.intents2id,
            max_len=self.max_len
        )
        return DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [5]:
max_len = 64
split_pickle_path = data_path / 'all_data_splitted.pickle'
ids_pickle_path = data_path / 'all_data_ids.pickle'
# Batches of 64 questions padded to 64 tokens; num_workers=0 loads in the notebook process.
data_module = NLUDataModule(split_pickle_path, ids_pickle_path, batch_size=64, max_len=max_len, num_workers=0)
data_module.prepare_data()

In [6]:
train_loader = data_module.train_dataloader()
# Peek at one batch. next(iter(...)) raises StopIteration on an empty loader, whereas the
# `for x in ...: break` form silently kept a stale `x` bound by an earlier cell.
x = next(iter(train_loader))

In [None]:
# from collections import Counter

# train_loader = data_module.train_dataloader()

# dist = Counter() 
# for x in tqdm(train_loader, total=len(train_loader)):
#     dist.update(list(x['intent'].numpy()))

# test_loader = data_module.val_dataloader()

# test_dist = Counter() 
# for x in tqdm(test_loader, total=len(test_loader)):
#     test_dist.update(list(x['intent'].numpy()))

## Modeling

In [22]:
# `nn` is otherwise imported only in a later cell, which made this raise NameError on a fresh kernel.
import torch.nn as nn

nn.ModuleDict()

ModuleDict()

In [26]:
import torch.nn as nn
import torchmetrics
import pytorch_lightning as pl
from transformers import BertForTokenClassification, BertConfig

class BertPooler(nn.Module):
    def __init__(self, config):
        """from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert/modeling_bert.py#L627

        Args:
            config: any object exposing ``hidden_size`` (e.g. a ``BertConfig``).
        """
        super().__init__()
        width = config.hidden_size
        # Attribute names define the state_dict keys ('dense.*'); keep them stable.
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Pool by transforming the first token position (index 0 on dim 1) with Linear + tanh."""
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))

class RelationNetwork(nn.Module):
    def __init__(self, hidden_size, output_size):
        """Relation head on top of a pooled sentence vector.

        Args:
            hidden_size: width of the input feature vector.
            output_size: number of candidate token positions (the model passes ``max_len``).
                The last layer emits ``output_size * 4 + 1`` logits: one has-relation logit
                followed by four blocks of ``output_size`` position logits.
        """
        super().__init__()
        self.output_size = output_size
        # Layer order fixes the state_dict keys (relation_net.0/1/2); keep it stable.
        self.relation_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_size * 4 + 1),
        )

    def forward(self, x):
        """Return ``(has_relation, s_target, e_target, s_subject, e_subject)`` logits.

        ``has_relation`` has one value per row; each of the other four holds ``output_size``
        position logits per row (a trailing singleton dim is squeezed away).
        """
        logits = self.relation_net(x)
        has_relation = logits[:, 0:1].squeeze(-1).contiguous()
        position_blocks = torch.split(logits[:, 1:], self.output_size, dim=-1)
        s_target, e_target, s_subject, e_subject = (
            block.squeeze(-1).contiguous() for block in position_blocks
        )
        return has_relation, s_target, e_target, s_subject, e_subject

class NLUModel(pl.LightningModule):
    """Joint BERT model for token tagging, intent classification and relation extraction.

    One BERT encoder (taken from ``BertForTokenClassification``) feeds three heads:
    the token-classification head for tags, a linear intent classifier, and a
    ``RelationNetwork``. The intent and relation heads both read the pooled
    first-position vector produced by ``BertPooler``.

    Hyperparameters read from ``self.hparams`` (stored via ``save_hyperparameters``):
        model_path: checkpoint name/path passed to ``BertForTokenClassification.from_pretrained``.
        tags_size: number of tag labels.
        intent_size: number of intent classes.
        max_len: number of candidate positions per span boundary in the relation head.
        stage: losses and metrics are only created when this equals ``'train'``.
        lr: Adam learning rate.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # Heads evaluated by cal_metrics, in this order.
        self.outputs_keys = ['tags', 'intent', 'has_relation', 's_target', 'e_target', 's_subject', 'e_subject']
        # Networks
        self.bert_ner = BertForTokenClassification.from_pretrained(self.hparams.model_path, num_labels=self.hparams.tags_size)
        # Size the extra heads from the loaded backbone's config. A default BertConfig()
        # always reports hidden_size=768, which only matches base-size checkpoints.
        cfg = self.bert_ner.config
        self.bert_pooler = BertPooler(cfg)
        self.intent_network = nn.Linear(cfg.hidden_size, self.hparams.intent_size)
        self.relation_network = RelationNetwork(cfg.hidden_size, self.hparams.max_len)

        # Losses and metrics are only needed when training/validating.
        if self.hparams.stage == 'train':
            self.losses = {
                'bce': nn.BCEWithLogitsLoss(),
                'ce': nn.CrossEntropyLoss()
            }
            # Registered through ModuleDict so metric state moves with the model.
            self.metrics = nn.ModuleDict({
                'train_': self.create_metrics(prefix='train_'),
                'val_': self.create_metrics(prefix='val_')
            })

    def contiguous(self, x):
        """Drop a trailing size-1 dimension (if present) and return a contiguous tensor."""
        return x.squeeze(-1).contiguous()

    def create_metrics(self, prefix='train_'):
        """Build one Accuracy/Precision/Recall collection per output head.

        Each collection's keys are prefixed with ``f'{prefix}{head}_'``.
        """
        m = nn.ModuleDict()
        metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(), torchmetrics.Precision(), torchmetrics.Recall()])
        for k in self.outputs_keys:
            m[k] = metrics.clone(prefix+k+'_')
        return m

    def _forward_bert(self, input_ids, token_type_ids, attention_mask):
        """Run the shared BERT encoder and return its last hidden state."""
        outputs = self.bert_ner.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return outputs.last_hidden_state

    def _forward_tags(self, last_hidden_state):
        """Per-token tag logits, flattened to (B*seq_len, tags_size)."""
        tags_outputs = self.bert_ner.dropout(last_hidden_state)
        tags_logits = self.bert_ner.classifier(tags_outputs)
        return tags_logits.view(-1, self.hparams.tags_size)

    def _forward_intent(self, pooled_outputs):
        """Intent logits of shape (B, intent_size)."""
        intent_logits = self.intent_network(pooled_outputs)
        return intent_logits

    def _forward_relation(self, pooled_outputs):
        """Relation head logits.

        pooled_outputs: (B, hidden_size), the BertPooler output.
        Returns has_relation (B,) and s_target, e_target, s_subject, e_subject (B, max_len).
        """
        has_relation_logits, s_target_logits, e_target_logits, s_subject_logits, e_subject_logits = \
            self.relation_network(pooled_outputs)
        return has_relation_logits, s_target_logits, e_target_logits, s_subject_logits, e_subject_logits

    # def _get_relation_inputs(self, last_hidden_state, relation):
    #     x = torch.stack([last_hidden_state[i, s:e].mean(0) for i, (s, e) in enumerate(relation)])
    #     return x 

    # def _forward_relation(self, last_hidden_state, target_relation, subject_relation):
    #     target_inputs = self._get_relation_inputs(last_hidden_state, target_relation)
    #     subject_inputs = self._get_relation_inputs(last_hidden_state, subject_relation)
    #     relation_inputs = torch.concat([last_hidden_state[:, 0], target_inputs, subject_inputs], dim=1)
    #     has_relation_logits, s_target_logits, e_target_logits, s_subject_logits, e_subject_logits = self.relation_network(relation_inputs)
    #     return has_relation_logits, s_target_logits, e_target_logits, s_subject_logits, e_subject_logits

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Compute raw logits for every head; keys match ``self.outputs_keys``."""
        # tags
        last_hidden_state = self._forward_bert(input_ids, token_type_ids, attention_mask)
        tags_logits = self._forward_tags(last_hidden_state)

        # intent
        pooled_outputs = self.bert_pooler(last_hidden_state)
        intent_logits = self._forward_intent(pooled_outputs)
        # relation
        has_relation_logits, s_target_logits, e_target_logits, s_subject_logits, e_subject_logits = \
            self._forward_relation(pooled_outputs)

        return {
            'tags': tags_logits,                       # (B*max_len, tags_size)
            'intent': intent_logits,                   # (B, intent_size)
            'has_relation': has_relation_logits,       # (B, )
            's_target': s_target_logits,               # (B, max_len)
            'e_target': e_target_logits,               # (B, max_len)
            's_subject': s_subject_logits,             # (B, max_len)
            'e_subject': e_subject_logits              # (B, max_len)
        }

    def forward_all(self, batch, prefix='train_'):
        """Forward a batch, compute and log the total loss and per-head metrics.

        ``batch['target_relation']`` and ``batch['subject_relation']`` must have a last
        dimension of size 2 (start, end); each is split into two (B,) targets.
        Returns the scalar loss.
        """
        outputs = self.forward(
            input_ids=batch['input_ids'], 
            token_type_ids=batch['token_type_ids'], 
            attention_mask=batch['attention_mask'], 
        )
        s_target, e_target = map(self.contiguous, batch['target_relation'].split(1, dim=-1))
        s_subject, e_subject = map(self.contiguous, batch['subject_relation'].split(1, dim=-1))
        targets = {
            'tags': batch['tags'].view(-1),         # (B*max_len, )
            'intent': batch['intent'],              # (B, )
            'has_relation': batch['has_relation'],  # (B, )
            's_target': s_target,                   # (B, )
            'e_target': e_target,                   # (B, )
            's_subject': s_subject,                 # (B, )
            'e_subject': e_subject                  # (B, )
        }
        loss = self.cal_loss(outputs, targets)
        self.log(f'{prefix}loss', loss)
        # logging
        self.cal_metrics(outputs, targets, prefix=prefix)
        return loss

    def cal_loss(self, outputs, targets):
        """Unweighted sum of all head losses.

        BCE-with-logits for has_relation; cross-entropy for tags, intent and the four
        span-boundary heads.
        NOTE(review): span losses are applied to every example, including those whose
        has_relation target is 0. The tags loss covers every position unless padding tags
        use CrossEntropyLoss's ignore_index (-100). Confirm both against the data module.
        """
        has_relation_loss = self.losses['bce'](outputs['has_relation'], targets['has_relation'].float())

        tags_loss = self.losses['ce'](outputs['tags'], targets['tags'])
        intent_loss = self.losses['ce'](outputs['intent'], targets['intent'])
        s_target_loss = self.losses['ce'](outputs['s_target'], targets['s_target'])
        e_target_loss = self.losses['ce'](outputs['e_target'], targets['e_target'])
        s_subject_loss = self.losses['ce'](outputs['s_subject'], targets['s_subject'])
        e_subject_loss = self.losses['ce'](outputs['e_subject'], targets['e_subject'])

        return tags_loss + intent_loss + s_target_loss + e_target_loss + s_subject_loss + e_subject_loss + has_relation_loss

    def cal_metrics(self, outputs, targets, prefix='train_'):
        """Update the ``prefix`` metric collections for every head and log the batch values.

        Metrics receive probabilities, not raw logits. For the single-logit has_relation
        head this makes torchmetrics' default 0.5 threshold match ``_predict_from_outputs``
        (sigmoid >= 0.5). Softmax does not change the argmax of the multi-class heads.
        """
        outputs_metrics = {}
        for k in self.outputs_keys:
            if k == 'has_relation':
                probs = outputs[k].sigmoid()
            else:
                probs = outputs[k].softmax(dim=-1)
            outputs_metrics.update(self.metrics[prefix][k](probs, targets[k]))
        self.log_dict(outputs_metrics)

    def training_step(self, batch, batch_idx):
        """Lightning training step: returns the total loss."""
        loss = self.forward_all(batch, prefix='train_')
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation step: logs ``val_loss`` and ``val_*`` metrics."""
        loss = self.forward_all(batch, prefix='val_')
        return loss

    def configure_optimizers(self):
        """Adam over all parameters at ``hparams.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def predict(self, input_ids, token_type_ids, attention_mask):
        """Forward the inputs and convert the logits into discrete predictions."""
        outputs = self.forward(input_ids, token_type_ids, attention_mask)
        predicts = self._predict_from_outputs(outputs)
        return predicts

    def _predict_from_outputs(self, outputs):
        """Argmax for the multi-class heads; sigmoid >= 0.5 for has_relation.

        Note that 'tags' stays flattened to (B*seq_len,), matching the forward output.
        """
        predicts = {k: outputs[k].argmax(-1) for k in ['tags', 'intent', 's_target', 'e_target', 's_subject', 'e_subject']}
        predicts['has_relation'] = (outputs['has_relation'].sigmoid() >= 0.5).byte()
        return predicts

In [27]:
import pickle  # not imported in any earlier visible cell

# Shared by the relation head (one logit per position) and the data module's padding;
# the two must agree, so define the value once.
MAX_LEN = 64

ids_path = data_path / 'all_data_ids.pickle'
# NOTE(review): pickle.load can execute arbitrary code; only load locally produced, trusted files.
with ids_path.open('rb') as file:
    ids = pickle.load(file)
tags2id = ids['tags2id']
intents2id = ids['intents2id']

hparams = {
    'stage': 'train',
    'model_path': 'bert-base-uncased',
    'intent_size': len(intents2id),
    'tags_size': len(tags2id),
    'max_len': MAX_LEN,
    # 1e-3 is far above the usual BERT fine-tuning range (about 2e-5 to 5e-5) and tends
    # to wreck the pretrained encoder with Adam.
    'lr': 5e-5,
    'load_path': None,
}

model = NLUModel(**hparams)

data_module = NLUDataModule(
    data_path=data_path / 'all_data_splitted.pickle',
    ids_path=ids_path,
    batch_size=32,
    max_len=MAX_LEN,
    num_workers=0
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

In [11]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

log_path = src_path / 'logs'
checkpoint_path = src_path / 'checkpoints'

logger = TensorBoardLogger(save_dir=str(log_path), name="NLU")
# Keep the two best checkpoints by validation loss (NLUModel logs it as 'val_loss').
checkpoint_callback = ModelCheckpoint(
    dirpath=str(checkpoint_path),
    save_top_k=2,
    monitor='val_loss',
    mode='min',
)
progress_callback = TQDMProgressBar(refresh_rate=20)
trainer = pl.Trainer(
    # Fall back to CPU instead of failing when no CUDA device is visible.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=3,
    logger=logger,
    callbacks=[checkpoint_callback, progress_callback],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [243]:
# Train `model` on the dataloaders provided by `data_module`.
trainer.fit(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                       | Params
----------------------------------------------------------------
0 | bert_ner         | BertForTokenClassification | 108 M 
1 | bert_pooler      | BertPooler                 | 590 K 
2 | intent_network   | Linear                     | 3.1 K 
3 | relation_network | RelationNetwork            | 789 K 
----------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
441.356   Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
data_module.prepare_data()
train_loader = data_module.train_dataloader()
# Take one batch for inspection. next(iter(...)) raises StopIteration here on an empty
# loader; the previous `for x in ...: break` silently left `x` bound to a stale value
# from an earlier cell in that case.
x = next(iter(train_loader))