In [360]:
from torchtext.data import Field, BucketIterator, RawField
from torchtext.vocab import GloVe
from tqdm import tqdm

In [361]:
SRC = Field(tokenize=lambda x: x.split(" "), lower=True, batch_first=True)
TRG = Field(tokenize=lambda x: x.split(" "), init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)



In [362]:
from torchtext.data import TabularDataset


dataset = TabularDataset(path='../rsc/preprocessed/kp20k.valid_100_lines.json',
                         format='json',
                         fields={
                                 'doc_words':('text', SRC), 
                                 'keyphrases': ('label', TRG)}
                        )



In [363]:
print(vars(dataset[0])['text'])

['we', 'investigate', 'the', 'problem', 'of', 'delay', 'constrained', 'maximal', 'information', 'collection', 'for', 'csma', 'based', 'wireless', 'sensor', 'networks', '.', 'we', 'study', 'how', 'to', 'allocate', 'the', 'maximal', 'allowable', 'transmission', 'delay', 'at', 'each', 'node', ',', 'such', 'that', 'the', 'amount', 'of', 'information', 'collected', 'at', 'the', 'sink', 'is', 'maximized', 'and', 'the', 'total', 'delay', 'for', 'the', 'data', 'aggregation', 'is', 'within', 'the', 'given', 'bound', '.', 'we', 'formulate', 'the', 'problem', 'by', 'using', 'dynamic', 'programming', 'and', 'propose', 'an', 'optimal', 'algorithm', 'for', 'the', 'optimal', 'assignment', 'of', 'transmission', 'attempts', '.', 'based', 'on', 'the', 'analysis', 'of', 'the', 'optimal', 'solution', ',', 'we', 'propose', 'a', 'distributed', 'greedy', 'algorithm', '.', 'it', 'is', 'shown', 'to', 'have', 'a', 'similar', 'performance', 'as', 'the', 'optimal', 'one', '.']


In [364]:
len(dataset)

84

In [365]:
print(vars(dataset[0])['label'])

['algorithms', '__;__', 'performance', '__;__', 'data', 'aggregation', '__;__', 'sensor', 'networks']


In [366]:
SRC.build_vocab(dataset, vectors=GloVe(name='6B', dim=100))
TRG.build_vocab(dataset, vectors=GloVe(name='6B', dim=100))

In [367]:
print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (de) vocabulary: 2648
Unique tokens in target (en) vocabulary: 439


In [368]:
BATCH_SIZE = 8

train_iterator = BucketIterator(
    dataset,
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.text),
    sort_within_batch=True
)

In [369]:
import torch


for i, batch in enumerate(train_iterator):
    print(batch.text[0])
#     print(' '.join([SRC.vocab.itos[i.item()] for i in batch.text[0]]))
#     print(' '.join([TRG.vocab.itos[i.item()] for i in batch.label[0]]))
    
#     batch.text = torch.einsum('ij->ji', batch.text)
#     batch.label = torch.einsum('ij->ji', batch.label)
    
#     print(batch.text.shape)
#     print(' '.join([SRC.vocab.itos[i.item()] for i in batch.text[0]]))
#     print(' '.join([TRG.vocab.itos[i.item()] for i in batch.label[0]]))
    
    break
#     print(' '.join([SRC.vocab.itos[i.item()] for i in batch.text[0]])

tensor([ 305,    4, 2296,  849,    4,  155,  131,   14,  346,    3,  659,    3,
           5,  390, 1420,    6,   62, 1037,   13,    6,   83, 1920,   65, 1209,
          78,   41,   44,   22,   16,    8, 2137,  316,   17, 1624, 2171,   65,
        2417,  131,   17,   18, 1072,    7, 1793,    6, 1065,    7,  557,    3,
          15,   28, 1674,    8,  144, 2540,  121,    4,  155,  390,   17, 2616,
         557,   21,    8,  300,    4,   62, 1309,    4, 1897, 1444,    6,  155,
         131,  317,  432,  362,    6, 1534, 1065,    7, 1372,  970,    3,  390,
         386, 2532, 1084,   23,  155,  131,   18,  461,   31,    2, 2220,    4,
        1547,    6, 1599,    3,   15,  197,  390, 1389,    7, 2037, 1630,   65,
           2,  432,  362,    4,  155,  131,    6,   93,   54,  290, 2348, 1019,
          33, 2082, 2556, 2595,  863,    3,    8, 2533,    4, 1309,    4,  155,
         131,   10,   35,    6, 1501,    9, 2048, 1356,   18,  207,    3,    2,
          35,  121,   26,   25,   58,   

In [405]:
from src.keyword.data.graph_util import build_graph, normalize_graph

def batch_graph(grhs):
    """ batch a list of graphs
    @param grhs: list(tensor,...) 
    """
    b = len(grhs)  # batch size
    graph_dims = [len(g) for g in grhs]
    s = max(graph_dims)  # max seq length
    
    G = torch.zeros([b, s, s])
    for i, g in enumerate(grhs):
        s_ = graph_dims[i]
        G[i,:s_,:s_] = g
    return G

def build_graph_dataset(dataset: TabularDataset):
    GRH = RawField(postprocessing=batch_graph)

    for d in tqdm(dataset):
        token_len = len(d.text)
        G = build_graph(token_len, token_len)
        A_f = G['forward']
        A_b = G['backward']
        d.A_f = normalize_graph(A_f)
        d.A_b = normalize_graph(A_b)

    dataset.fields['A_f'] = GRH
    dataset.fields['A_b'] = GRH

    return dataset

In [406]:
dataset = build_graph_dataset(dataset)

100%|██████████| 84/84 [00:57<00:00,  1.45it/s]


In [407]:
print(dataset[0].A_f)
print(dataset[0].A_f.shape)

tensor([[0.1601, 0.1602, 0.0802,  ..., 0.0024, 0.0027, 0.0038],
        [0.0000, 0.1604, 0.1605,  ..., 0.0025, 0.0027, 0.0038],
        [0.0000, 0.0000, 0.1606,  ..., 0.0025, 0.0028, 0.0039],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.4000, 0.4472, 0.3162],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5000, 0.7071],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000]])
torch.Size([107, 107])


In [408]:
BATCH_SIZE = 8

train_iterator = Iterator(
    dataset,
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.text),
    sort_within_batch=True
)



In [401]:
def batch_graph(grhs):
    """ batch a list of graphs
    @param grhs: list(tensor,...) 
    """
    b = len(grhs)  # batch size
    graph_dims = [len(g) for g in grhs]
    s = max(graph_dims)  # max seq length
    
    G = torch.zeros([b, s, s])
    for i, g in enumerate(grhs):
        s_ = graph_dims[i]
        G[i,:s_,:s_] = g
    return G

In [413]:
for i, batch in enumerate(train_iterator):
#     print(batch_graph(batch.A_f).shape)
#     print(batch_graph(batch.A_f)[-1])
    print(batch.A_f.shape)
    print(batch.text.shape)
    for j in range(BATCH_SIZE):
        print(batch.text[j].shape)
    for j in range(BATCH_SIZE):
        print(batch.A_f[j].shape)
    break

torch.Size([8, 244, 244])
torch.Size([8, 244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])
torch.Size([244, 244])


In [462]:
SRC = Field(tokenize=lambda x: x.split(" "), lower=True, batch_first=True)
TRG = Field(tokenize=lambda x: x.split(" "), init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)


train_dataset, valid_dataset, test_dataset = TabularDataset.splits(
    path='../rsc/preprocessed',
    train='kp20k.train_100_lines.json',
    validation='kp20k.valid_100_lines.json',
    test='kp20k.test_100_lines.json',
    format='json',
    fields={
        'doc_words':('text', SRC), 
        'keyphrases': ('label', TRG)
    }
)

train_dataset = build_graph_dataset(train_dataset)

100%|██████████| 84/84 [00:47<00:00,  1.76it/s]


In [463]:
SRC.build_vocab(dataset)
TRG.build_vocab(dataset)

In [477]:
import torch

torch.save(list(train_dataset), './train_dataset.pt')
# torch.save(SRC, 'SRC.data')
# torch.save(TRG, 'TRG.vocab')

In [481]:
!pip install dill



In [486]:
import dill


with open('./train_dataset.pkl', 'wb') as f:
    dill.dump(list(train_dataset), f)
    f.close()
    
with open('./SRC.pkl', 'wb') as f:
    dill.dump(SRC, f)
    f.close()

with open('./TRG.pkl', 'wb') as f:
    dill.dump(TRG, f)
    f.close()

In [487]:
with open('./SRC.pkl', 'rb') as f:
    src_data = dill.load(f)
    f.close()

with open('./TRG.pkl', 'rb') as f:
    trg_data = dill.load(f)
    f.close()
    
with open('./train_dataset.pkl', 'rb') as f:
    loaded_dataset = dill.load(f)
    f.close()


In [490]:
GRH = RawField(postprocessing=None)

data_fields = [('text', src_data), ('label', trg_data), ('A_f', GRH), ('A_b', GRH)]

load = Dataset(loaded_dataset, data_fields)

In [491]:
it = Iterator(
        dataset=load,
        batch_size=4,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
    )



In [492]:
for batch in it:
    print(batch)
    break


[torchtext.data.batch.Batch of size 4]
	[.text]:[torch.LongTensor of size 4x196]
	[.label]:[torch.LongTensor of size 4x11]
	[.A_f]:[tensor([[0.1459, 0.1460, 0.0730,  ..., 0.0013, 0.0014, 0.0020],
        [0.0000, 0.1460, 0.1461,  ..., 0.0013, 0.0014, 0.0020],
        [0.0000, 0.0000, 0.1461,  ..., 0.0013, 0.0014, 0.0020],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.4000, 0.4472, 0.3162],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5000, 0.7071],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000]]), tensor([[0.1463, 0.1463, 0.0732,  ..., 0.0013, 0.0014, 0.0020],
        [0.0000, 0.1464, 0.1464,  ..., 0.0013, 0.0014, 0.0020],
        [0.0000, 0.0000, 0.1465,  ..., 0.0013, 0.0014, 0.0020],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.4000, 0.4472, 0.3162],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5000, 0.7071],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000]]), tensor([[0.1538, 0.1539, 0.0770,  ..., 0.0018, 0.0020, 0.0029],
     



In [471]:
from torchtext.data import Dataset

# data_fields = [('text', SRC), ('label', TGT), ('A_f', GRH), ('A_b', GRH)]

# dataset = (Dataset(torch.load('./train_dataset.pt'), data_fields))
dataset = torch.load('./train_dataset.pt')

In [472]:
it = Iterator(
        dataset=dataset,
        batch_size=4,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
    )

In [473]:
for batch in it:
    print(batch)
    break

AttributeError: 'list' object has no attribute 'fields'