In [1]:
!pip install transformers
!git clone https://github.com/hkbae20/npex20.git

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 3.2MB/s 
[?25hCollecting tokenizers==0.8.1.rc1
[?25l  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 7.2MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 10.3MB/s 
[?25hCollecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (

In [2]:
import os
import json
import random
import logging
import sys

from itertools import chain
from tqdm.notebook import tqdm

from npex20.utils.data import pad_ids, truncate_sequences
from npex20.scripts.dataset_walker import DatasetWalker
from npex20.scripts.knowledge_reader import KnowledgeReader

from npex20.dataset import BaseDataset, SPECIAL_TOKENS

import torch

from transformers import (
    AdamW,
    AutoConfig,
    AutoTokenizer,
    GPT2DoubleHeadsModel,
    GPT2LMHeadModel,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

logger = logging.getLogger(__name__)


# Track 1: Knowledge Selection

In [3]:
class KnowledgeSelectionDataset(BaseDataset):
    def __init__(self, args, tokenizer, split_type, labels=True, labels_file=None):
        super(KnowledgeSelectionDataset, self).__init__(args, tokenizer, split_type, labels, labels_file)
        if self.args.negative_sample_method not in ["all", "mix", "oracle"]:
            raise ValueError("negative_sample_method must be all, mix, or oracle, got %s" % self.args.negative_sample_method)

    def _knowledge_to_string(self, doc, name=""):
        join_str = " %s " % self.knowledge_sep_token
        return join_str.join([name, doc["title"], doc["body"]])

    def __getitem__(self, index):
        example = self.examples[index]

        this_inst = {
            "dialog_id": example["dialog_id"],
            "input_ids": [],
            "token_type_ids": [],
            "mc_token_ids": []
        }

        if self.split_type != "train":
            # if eval_all_snippets is set, we use all snippets as candidates
            if self.args.eval_all_snippets:
                candidates = list(self.snippets.keys())
            else:
                candidates = example["candidates"]
        else:
            if self.args.negative_sample_method == "all":
                candidates = list(self.snippets.keys())
            elif self.args.negative_sample_method == "mix":
                candidates = example["candidates"] + random.sample(list(self.snippets.keys()), k=len(example["candidates"]))
            elif self.args.negative_sample_method == "oracle":
                candidates = example["candidates"]
            else: 
                raise ValueError("negative_sample_method must be all, mix, or oracle, got %s" % self.args.negative_sample_method)

        candidate_keys = candidates
        this_inst["candidate_keys"] = candidate_keys
        candidates = [self.snippets[cand_key] for cand_key in candidates]

        if self.split_type == "train":
            candidates = self._shrink_label_cands(example["knowledge"], candidates)
            # candidates: [examples["knowledge"] + neg_sampled_knowledge]

        label_idx = candidates.index(example["knowledge"])
            
        this_inst["label_idx"] = label_idx
        for cand in candidates:
            instance, _ = self.build_input_from_segments(
                cand,
                example["history"]
            )
            this_inst["input_ids"].append(instance["input_ids"])
            this_inst["token_type_ids"].append(instance["token_type_ids"])
            this_inst["mc_token_ids"].append(instance["mc_token_ids"])

        return this_inst

    def build_input_from_segments(self, knowledge, history):
        """ Build a sequence of input from 2 segments: knowledge and history"""
        instance = {}

        sequence = [[self.bos]] + history
        sequence_with_speaker = [
            [self.speaker1 if (len(sequence) - i) % 2 == 0 else self.speaker2] + s
            for i, s in enumerate(sequence[1:])
        ]
        sequence = [sequence[0]] + sequence_with_speaker + [[self.knowledge_tag] + knowledge + [self.eos]]

        instance["input_ids"] = list(chain(*sequence))
        instance["token_type_ids"] = [0 for s in sequence[:-1] for _ in s] + [1 for _ in sequence[-1]]
        instance["mc_token_ids"] = len(instance["input_ids"]) - 1

        return instance, sequence
    
    def _shrink_label_cands(self, label, candidates):
        shrunk_label_cands = candidates.copy()
        shrunk_label_cands.remove(label)
        shrunk_label_cands = random.sample(shrunk_label_cands, k=self.args.n_candidates-1)
        shrunk_label_cands.append(label)
        random.shuffle(shrunk_label_cands)

        return shrunk_label_cands

    def collate_fn(self, batch):
        input_ids = [ids for ins in batch for ids in ins["input_ids"]]
        token_type_ids = [ids for ins in batch for ids in ins["token_type_ids"]]
        mc_token_ids = [id for ins in batch for id in ins["mc_token_ids"]]
        label_idx = [ins["label_idx"] for ins in batch]

        data_info = {
            "dialog_ids": [ins["dialog_id"] for ins in batch],
            "candidate_keys": [ins["candidate_keys"] for ins in batch]
        }

        batch_size = len(batch)
        n_candidates = len(batch[0]["input_ids"])
        input_ids = torch.tensor(
            pad_ids(input_ids, self.pad)
        ).view(batch_size, n_candidates, -1)
        
        token_type_ids = torch.tensor(
            pad_ids(token_type_ids, self.pad)
        ).view(batch_size, n_candidates, -1)

        lm_labels = torch.full_like(input_ids, -100)
        mc_token_ids = torch.tensor(mc_token_ids).view(batch_size, n_candidates)
        label_idx = torch.tensor(label_idx)

        return input_ids, token_type_ids, mc_token_ids, lm_labels, label_idx, data_info


## Dataset arguments

In [4]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        
args = Namespace(
        dataroot = 'npex20/data',
        task = "selection",
        history_max_tokens = 128,
        history_max_utterances = 10000,
        knowledge_file = "knowledge.json",
        knowledge_max_tokens = 128,
        n_candidates = 3,
        negative_sample_method = "mix",
        eval_all_snippets = None, 
        local_rank = -1
)

In [5]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [6]:
train_dataset = KnowledgeSelectionDataset(args, tokenizer, split_type="train")

HBox(children=(FloatProgress(value=0.0, max=30000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30000.0), HTML(value='')))




In [7]:
# train_dataset[1]

In [8]:
train_dataset.collate_fn([train_dataset[1]])

(tensor([[[  100,   100,  2833,  1999,  2568,  1029,   100,  1045,  1005,  1040,
            2066,  2000,  2031,  2070,  2822,  2833,  1012,   100,  2008, 25142,
            2091,  1996,  4825,  9804,  2000,  2184,  1012,  2003,  2045,  1037,
            3976,  2846,  2017,  2052,  2066,  2000,  2994,  1999,  1029,   100,
            1045,  2572,  2559,  2005,  1037, 17844, 21125,  2173,  2000,  4521,
            1010,  1045,  2572,  2036,  2559,  2000,  2338,  1037,  2282,  1999,
            1996,  2958,  4113,  2160,  3309,  1012,   100,  2029,  5246,  2097,
            2017,  2022,  6595,  2012,  1996,  2958,  4113,  2282,  2160,  1029,
             100,  2077,  1045, 10797,  1045,  2031,  1037,  2261,  3980,  1012,
            2054,  2181,  2003,  1996,  3309,  2284,  1999,  1029,   100,  1996,
            3309,  2003,  1999,  1996,  2148,  2181,  1012,   100,  2079,  2027,
            2031,  2393,  2005,  9776,  5581,  1029,   100,  2748,  1012,  2023,
            3200,  2038,  78

In [9]:
print(train_dataset.examples[1].keys())
print()
for key in train_dataset.examples[1].keys():
    print(key)
    print(train_dataset.examples[1][key])
    print()

dict_keys(['history', 'knowledge', 'candidates', 'response', 'response_text', 'label', 'knowledge_seeking', 'dialog_id'])

history
[[2833, 1999, 2568, 1029], [1045, 1005, 1040, 2066, 2000, 2031, 2070, 2822, 2833, 1012], [2008, 25142, 2091, 1996, 4825, 9804, 2000, 2184, 1012, 2003, 2045, 1037, 3976, 2846, 2017, 2052, 2066, 2000, 2994, 1999, 1029], [1045, 2572, 2559, 2005, 1037, 17844, 21125, 2173, 2000, 4521, 1010, 1045, 2572, 2036, 2559, 2000, 2338, 1037, 2282, 1999, 1996, 2958, 4113, 2160, 3309, 1012], [2029, 5246, 2097, 2017, 2022, 6595, 2012, 1996, 2958, 4113, 2282, 2160, 1029], [2077, 1045, 10797, 1045, 2031, 1037, 2261, 3980, 1012, 2054, 2181, 2003, 1996, 3309, 2284, 1999, 1029], [1996, 3309, 2003, 1999, 1996, 2148, 2181, 1012], [2079, 2027, 2031, 2393, 2005, 9776, 5581, 1029], [2748, 1012, 2023, 3200, 2038, 7801, 5581, 1012, 2151, 2060, 3160, 1029], [2003, 1037, 9425, 4003, 11701, 2005, 2023, 21725, 1029]]

knowledge
[2958, 4113, 2160, 1026, 3716, 1035, 19802, 1028, 2054, 7909, 7

In [10]:
" ".join(tokenizer.convert_ids_to_tokens(train_dataset.snippets["hotel__11__0"]))

'bridge guest house < knowledge _ sep > are pets allowed here ? < knowledge _ sep > no , pets are not allowed at this property .'

# Track 2 : Response generation

In [11]:
class ResponseGenerationDataset(BaseDataset):
    def __init__(self, args, tokenizer, split_type, labels=True, labels_file=None):
        super(ResponseGenerationDataset, self).__init__(args, tokenizer, split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        instance, _ = self.build_input_from_segments(
            example["knowledge"],
            example["history"],
            example["response"]
        )
        return instance

    def collate_fn(self, batch):
        input_ids = [ins["input_ids"] for ins in batch]
        token_type_ids = [ins["token_type_ids"] for ins in batch]
        lm_labels = [ins["lm_labels"] for ins in batch]

        input_ids = torch.tensor(pad_ids(input_ids, self.pad))
        token_type_ids = torch.tensor(pad_ids(token_type_ids, self.pad))
        lm_labels = torch.tensor(pad_ids(lm_labels, -100))

        return input_ids, token_type_ids, lm_labels


class ResponseGenerationEvalDataset(BaseDataset):
    def __init__(self, args, tokenizer, split_type, labels=True, labels_file=None):
        super(ResponseGenerationEvalDataset, self).__init__(args, tokenizer, split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        return example

    def collate_fn(self, batch):
        return batch


In [12]:
args.task = "generation"
args.labels_file = "npex20/data/val/labels.json"

In [13]:
generation_datset = ResponseGenerationDataset(args, tokenizer, split_type="train")

HBox(children=(FloatProgress(value=0.0, max=30000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30000.0), HTML(value='')))




In [14]:
generation_datset[1].keys()

dict_keys(['input_ids', 'token_type_ids', 'mc_token_ids', 'lm_labels'])

In [15]:
generation_eval_datset = ResponseGenerationEvalDataset(args, tokenizer, split_type="val",
                                                      labels_file=args.labels_file)

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [16]:
generation_eval_datset[1].keys()

dict_keys(['history', 'knowledge', 'candidates', 'response', 'response_text', 'label', 'knowledge_seeking', 'dialog_id'])