In [31]:
from torchtext.data import TabularDataset, Field, RawField
from tqdm import tqdm_notebook as tqdm
from src.keyword.data.token import get_token, find_stem_answer
from src.keyword.data.graph_util import build_graph, normalize_graph
from transformers import BertTokenizer
from torchtext.data import BucketIterator, Dataset

import re

In [72]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(tokens):
    return tokenizer(
        tokens,
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )

SRC = Field(tokenize=tokenize, sequential=True)
TRG = Field(tokenize=tokenize, sequential=True)

dataset = TabularDataset(path='../rsc/preprocessed/kp20k.valid.json',
                             format='json',
                             fields={'doc_words': ('doc_words', SRC),
                                     'keyphrases': ('keyphrases', TRG)})

In [75]:
data_fields = [('src', SRC), ('trg', TRG)]

train_iterator = BucketIterator.splits(
    dataset,
    batch_size = 32, 
    sort_key = lambda x: len(x.batch.doc_words),
    sort_within_batch=False)

In [86]:
for i, batch in enumerate(train_iterator):
    print(vars(batch.dataset.doc_words['data']))
    break

KeyError: 'data'

In [33]:
def build_dataset(
        dataset, 
        max_src_seq_len: int,
        max_trg_seq_len: int,
        lower:bool = True, 
        valid_check:bool = True):
    
    null_ids, absent_ids = 0, 0
    
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    SRC = Field(tokenize=tokenizer)
    TRG = Field(tokenize=tokenizer)
    
    for d in tqdm(dataset):
        abstract = ' '.join(d.abstract)
        keyword = d.keyword.split(';')
        
        src_filter_flag = False
        src_tokens = get_token(abstract)

        # max_seq_len 을 충족하지 못하면 넘긴다.
        if len(src_tokens) > max_src_seq_len:
            src_filter_flag = True

        if valid_check and src_filter_flag:
            continue

        trgs_tokens = []

        for trg in keyword:
            trg_filter_flag = False
            trg = trg.lower()

            # FILTER 1: remove all the abbreviations/acronyms in parentheses in keyphrases
            trg = re.sub(r'\(.*?\)', '', trg)
            trg = re.sub(r'\[.*?\]', '', trg)
            trg = re.sub(r'\{.*?\}', '', trg)

            # FILTER 2: ingore all the phrases that contains strange punctuations, very DIRTY data!
            puncts = re.findall(r'[,_\"<>\(\){}\[\]\?~`!@$%\^=]', trg)

            trg_tokens = get_token(trg)

            if len(puncts) > 0:
                continue

            if len(trg_tokens) > max_trg_seq_len:
                trg_filter_flag = True

            if valid_check and trg_filter_flag:
                continue

            if valid_check and (len(trg_tokens) > 0 and re.match(r'\d\d[a-zA-Z\-]\d\d',
                                                                 trg_tokens[0].strip())) or (
                    len(trg_tokens) > 1 and re.match(r'\d\d\w\d\d', trg_tokens[1].strip())):
                continue

            trgs_tokens.append(trg_tokens)

        if valid_check and len(trgs_tokens) == 0:
            continue

        if lower:
            src_tokens = [token.lower() for token in src_tokens]

        present_phrases = find_stem_answer(word_list=src_tokens, ans_list=trgs_tokens)

        if present_phrases is None:
            null_ids += 1
            continue

        if len(present_phrases['keyphrases']) != len(trgs_tokens):
            absent_ids += 1
        
#         print(src_tokens)
#         print(present_phrases['keyphrases'])
        d.src = ' '.join(src_tokens)
        d.trg = present_phrases['keyphrases']
#         break
    
    dataset.fields['src'] = SRC
    dataset.fields['trg'] = TRG
    
    return dataset

In [34]:
train_dataset = build_dataset(dataset, 512, 5, True, True)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app


HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))




In [36]:
print(vars(train_dataset[0]))

{'abstract': ['We', 'investigate', 'the', 'problem', 'of', 'delay', 'constrained', 'maximal', 'information', 'collection', 'for', 'CSMA-based', 'wireless', 'sensor', 'networks.', 'We', 'study', 'how', 'to', 'allocate', 'the', 'maximal', 'allowable', 'transmission', 'delay', 'at', 'each', 'node,', 'such', 'that', 'the', 'amount', 'of', 'information', 'collected', 'at', 'the', 'sink', 'is', 'maximized', 'and', 'the', 'total', 'delay', 'for', 'the', 'data', 'aggregation', 'is', 'within', 'the', 'given', 'bound.', 'We', 'formulate', 'the', 'problem', 'by', 'using', 'dynamic', 'programming', 'and', 'propose', 'an', 'optimal', 'algorithm', 'for', 'the', 'optimal', 'assignment', 'of', 'transmission', 'attempts.', 'Based', 'on', 'the', 'analysis', 'of', 'the', 'optimal', 'solution,', 'we', 'propose', 'a', 'distributed', 'greedy', 'algorithm.', 'It', 'is', 'shown', 'to', 'have', 'a', 'similar', 'performance', 'as', 'the', 'optimal', 'one.'], 'keyword': 'algorithms;design;performance;sensor netw

In [39]:
from torchtext.data import BucketIterator

train_iterator = BucketIterator.splits(
    train_dataset, 
    batch_size = 128, 
    sort_key = lambda x: len(x.src),
    sort_within_batch=False)



In [43]:
for batch in train_iterator:
    print(vars(batch))
    break

{'batch_size': 128, 'train': True, 'dataset': <torchtext.data.example.Example object at 0x2a0d59f50>, 'batch_size_fn': None, 'iterations': 0, 'repeat': False, 'shuffle': True, 'sort': False, 'sort_within_batch': False, 'sort_key': <function <lambda> at 0x2af93d200>, 'device': device(type='cpu'), 'random_shuffler': <torchtext.data.utils.RandomShuffler object at 0x2af581650>, '_iterations_this_epoch': 0, '_random_state_this_epoch': None, '_restored_from_state': False}


In [None]:
import torch


torch.save(list(train_dataset), "./train_data_not_have_graph.pt")

In [3]:
from transformers import BertTokenizer
from torchtext.data import Dataset


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

SRC = Field(tokenize = tokenizer) 
TGT = Field(tokenize = tokenizer)
GRH = RawField(postprocessing=None)



In [9]:
import torch

%time 
train_dataset_ = torch.load('./train_data_not_have_graph.pt')

CPU times: user 8 µs, sys: 6 µs, total: 14 µs
Wall time: 30 µs


RuntimeError: [enforce fail at inline_container.cc:144] . PytorchStreamReader failed reading zip archive: failed finding central directory