In [59]:
import json
import logging
import os
from argparse import Namespace

import click
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import WEIGHTS_NAME

from luke.luke_utils.entity_vocab import MASK_TOKEN

from luke.utils import set_seed
from luke.utils.trainer import Trainer, trainer_args
from luke.model import LukeForRelationClassification
from luke.re_utils import HEAD_TOKEN, TAIL_TOKEN, convert_examples_to_features, DatasetProcessor
from transformers.tokenization_roberta import RobertaTokenizer

import numpy as np

In [13]:
# !pip install wikipedia2vec

In [52]:
# Initial values (hardcode) for now
class params:
    def __init__(self):
        self.data_dir = "luke/data/tacred/json"
        self.do_train = "--no-train"
        self.train_batch_size = 4
        self.num_train_epochs = 5.0
        self.do_val = "--no-eval"
        self.eval_batch_size = 128
        self.seed = 42
        self.bert_model_name = "roberta-large"
        self.max_mention_length = 30
        self.local_rank = -1
        self.tokenizer =  RobertaTokenizer.from_pretrained(self.bert_model_name)
#         self.tokenizer = {"max_len": 512, "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>", "init_inputs": []}


In [53]:
args = params()
args.tokenizer.pad_token_id
logger = logging.getLogger(__name__)

In [54]:
def load_and_cache_examples(args, fold="train"):

    processor = DatasetProcessor()
    if fold == "train":
        examples = processor.get_train_examples(args.data_dir)
    elif fold == "dev":
        examples = processor.get_dev_examples(args.data_dir)
    else:
        examples = processor.get_test_examples(args.data_dir)

    label_list = processor.get_label_list(args.data_dir)

    bert_model_name = args.bert_model_name

    cache_file = os.path.join(
        args.data_dir,
        "cached_" + "_".join((args.bert_model_name.split("-")[0], str(args.max_mention_length), fold)) + ".pkl",
    )
    if os.path.exists(cache_file):
        logger.info("Loading features from cached file %s", cache_file)
        features = torch.load(cache_file)
    else:
        logger.info("Creating features from dataset file")
        features = convert_examples_to_features(examples, label_list, args.tokenizer, args.max_mention_length)

        if args.local_rank in (-1, 0):
            torch.save(features, cache_file)

    
    def collate_fn(batch):
        def create_padded_sequence(attr_name, padding_value):
            tensors = [torch.tensor(getattr(o, attr_name), dtype=torch.long) for o in batch]
            return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        return dict(
            word_ids=create_padded_sequence("word_ids", args.tokenizer.pad_token_id),
            word_attention_mask=create_padded_sequence("word_attention_mask", 0),
            word_segment_ids=create_padded_sequence("word_segment_ids", 0),
            entity_ids=create_padded_sequence("entity_ids", 0),
            entity_attention_mask=create_padded_sequence("entity_attention_mask", 0),
            entity_position_ids=create_padded_sequence("entity_position_ids", -1),
            entity_segment_ids=create_padded_sequence("entity_segment_ids", 0),
            label=torch.tensor([o.label for o in batch], dtype=torch.long),
        )

    if fold in ("dev", "test"):
        dataloader = DataLoader(features, batch_size=args.eval_batch_size, shuffle=False, collate_fn=collate_fn)
    else:
        if args.local_rank == -1:
            sampler = RandomSampler(features)
        else:
            sampler = DistributedSampler(features)
        dataloader = DataLoader(features, sampler=sampler, batch_size=args.train_batch_size, collate_fn=collate_fn)

    return dataloader, examples, features, label_list

In [55]:
dataloader, examples, features, label_list = load_and_cache_examples(args)

In [57]:
type(dataloader), type(examples), type(features), type(label_list)

(torch.utils.data.dataloader.DataLoader, list, list, list)

In [61]:
np.array(examples).shape, np.array(features).shape, np.array(label_list).shape, 

((68124,), (68124,), (42,))

In [65]:
print(label_list[0:20])

['no_relation', 'org:alternate_names', 'org:city_of_headquarters', 'org:country_of_headquarters', 'org:dissolved', 'org:founded', 'org:founded_by', 'org:member_of', 'org:members', 'org:number_of_employees/members', 'org:parents', 'org:political/religious_affiliation', 'org:shareholders', 'org:stateorprovince_of_headquarters', 'org:subsidiaries', 'org:top_members/employees', 'org:website', 'per:age', 'per:alternate_names', 'per:cause_of_death']
