In [1]:
from torchtext.data import TabularDataset, Field, RawField
from tqdm import tqdm_notebook as tqdm
from src.keyword.data.token import get_token, find_stem_answer
from src.keyword.data.graph_util import build_graph, normalize_graph

import re

In [2]:
dataset = TabularDataset(path='../rsc/kp20k/kp20k_training.json',
                             format='json',
                             fields={'abstract': ('abstract', Field(sequential=True)),
                                     'keyword': ('keyword', Field(sequential=False))})



In [3]:
for d in tqdm(dataset):
    print(d.keyword)
    print(d.abstract)
    break

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=530809.0), HTML(value='')))

telepresence;animation;avatars;application sharing;collaborative virtual environments
['This', 'paper', 'proposes', 'using', 'virtual', 'reality', 'to', 'enhance', 'the', 'perception', 'of', 'actions', 'by', 'distant', 'users', 'on', 'a', 'shared', 'application.', 'Here,', 'distance', 'may', 'refer', 'either', 'to', 'space', '(', 'e.g.', 'in', 'a', 'remote', 'synchronous', 'collaboration)', 'or', 'time', '(', 'e.g.', 'during', 'playback', 'of', 'recorded', 'actions).', 'Our', 'approach', 'consists', 'in', 'immersing', 'the', 'application', 'in', 'a', 'virtual', 'inhabited', '3D', 'space', 'and', 'mimicking', 'user', 'actions', 'by', 'animating', 'avatars.', 'We', 'illustrate', 'this', 'approach', 'with', 'two', 'applications,', 'the', 'one', 'for', 'remote', 'collaboration', 'on', 'a', 'shared', 'application', 'and', 'the', 'other', 'to', 'playback', 'recorded', 'sequences', 'of', 'user', 'actions.', 'We', 'suggest', 'this', 'could', 'be', 'a', 'low', 'cost', 'enhancement', 'for', 'tel

In [4]:
def build_dataset(
        dataset, 
        max_src_seq_len: int,
        max_trg_seq_len: int,
        lower:bool = True, 
        valid_check:bool = True):
    
    null_ids, absent_ids = 0, 0
    
    SRC = RawField(postprocessing=None)
    TRG = RawField(postprocessing=None)
    
    for d in tqdm(dataset):
        abstract = ' '.join(d.abstract)
        keyword = d.keyword.split(';')
        
        src_filter_flag = False
        src_tokens = get_token(abstract)

        # max_seq_len 을 충족하지 못하면 넘긴다.
        if len(src_tokens) > max_src_seq_len:
            src_filter_flag = True

        if valid_check and src_filter_flag:
            continue

        trgs_tokens = []

        for trg in keyword:
            trg_filter_flag = False
            trg = trg.lower()

            # FILTER 1: remove all the abbreviations/acronyms in parentheses in keyphrases
            trg = re.sub(r'\(.*?\)', '', trg)
            trg = re.sub(r'\[.*?\]', '', trg)
            trg = re.sub(r'\{.*?\}', '', trg)

            # FILTER 2: ingore all the phrases that contains strange punctuations, very DIRTY data!
            puncts = re.findall(r'[,_\"<>\(\){}\[\]\?~`!@$%\^=]', trg)

            trg_tokens = get_token(trg)

            if len(puncts) > 0:
                continue

            if len(trg_tokens) > max_trg_seq_len:
                trg_filter_flag = True

            if valid_check and trg_filter_flag:
                continue

            if valid_check and (len(trg_tokens) > 0 and re.match(r'\d\d[a-zA-Z\-]\d\d',
                                                                 trg_tokens[0].strip())) or (
                    len(trg_tokens) > 1 and re.match(r'\d\d\w\d\d', trg_tokens[1].strip())):
                continue

            trgs_tokens.append(trg_tokens)

        if valid_check and len(trgs_tokens) == 0:
            continue

        if lower:
            src_tokens = [token.lower() for token in src_tokens]

        present_phrases = find_stem_answer(word_list=src_tokens, ans_list=trgs_tokens)

        if present_phrases is None:
            null_ids += 1
            continue

        if len(present_phrases['keyphrases']) != len(trgs_tokens):
            absent_ids += 1

        d.src = ' '.join(src_tokens)
        d.trg = present_phrases['keyphrases']
    
    dataset.fields['src'] = SRC
    dataset.fields['trg'] = TRG
    
    return dataset

In [None]:
train_dataset = build_dataset(dataset, 512, 5, True, True)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  del sys.path[0]


HBox(children=(FloatProgress(value=0.0, max=530809.0), HTML(value='')))




In [None]:
import torch


torch.save(list(train_dataset), "./train_data_not_have_graph.pt")

# def build_graph_dataset(dataset):
#     pass