In [None]:
# Mount Google Drive (Colab only) so the XPersona files under
# /content/drive/MyDrive/Xpersona-master used by the cells below are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
 %cd drive/MyDrive/Xpersona-master/crosslingual

/content/drive/MyDrive/Xpersona-master/crosslingual


In [None]:
from tqdm import tqdm
import json

# Convert the XPersona JSON dialogues into line-aligned text files:
#   <OUTPUT_DIR>/<split>.x.<lang>  persona + dialogue context (model input)
#   <OUTPUT_DIR>/<split>.y.<lang>  the response to predict (same line number)
DATASET_DIR = "/content/drive/MyDrive/Xpersona-master/dataset"
OUTPUT_DIR = "/content/drive/MyDrive/Xpersona-master"

lang_list = ["En", "Zh", "Fr", "Id", "It", "Jp", "Ko"]
split_list = ["train", "valid", "test"]


def get_input_file_name(lang, split):
    """Return the XPersona JSON file name for one language/split pair."""
    if lang == "En":
        return "%s_persona_%s.json" % (lang, split)
    if split == "train":
        return "%s_persona_%s_corrected.json" % (lang, split)
    return "%s_persona_split_%s_human_annotated.json" % (lang, split)


def dialog_to_examples(each_dialog):
    """
    Yield (context, response) string pairs for one dialogue.

    The dialogue's turn tuples are flattened and every odd-indexed turn is a
    response. Its context is every persona sentence followed by a space, then
    the single preceding turn (first response) or the three preceding turns
    joined by spaces (later responses).
    """
    persona_str = "".join(persona + " " for persona in each_dialog["persona"])
    turns = [turn for tuple_ in each_dialog["dialogue"] for turn in tuple_]
    for idx in range(1, len(turns), 2):
        history = turns[idx - 1:idx] if idx == 1 else turns[idx - 3:idx]
        yield persona_str + " ".join(history), turns[idx]


for lang in lang_list:
    for split in split_list:
        file_name = get_input_file_name(lang, split)
        print(file_name)
        with open(DATASET_DIR + "/" + file_name, 'rb') as json_file:
            data = json.load(json_file)

        # `with` guarantees both outputs are flushed and closed even if a
        # dialogue is malformed; explicit UTF-8 so Zh/Jp/Ko text does not
        # depend on the platform's default encoding.
        out_prefix = "%s/%s" % (OUTPUT_DIR, split)
        with open("%s.x.%s" % (out_prefix, lang.lower()), "w", encoding="utf-8") as file_out_x:  # dialog history
            with open("%s.y.%s" % (out_prefix, lang.lower()), "w", encoding="utf-8") as file_out_y:  # response
                for each_dialog in tqdm(data):
                    for context, response in dialog_to_examples(each_dialog):
                        file_out_x.write(context + "\n")
                        file_out_y.write(response + "\n")

En_persona_train.json


100%|██████████| 16878/16878 [00:00<00:00, 31283.86it/s]


En_persona_valid.json


100%|██████████| 1000/1000 [00:00<00:00, 41896.12it/s]


En_persona_test.json


100%|██████████| 1000/1000 [00:00<00:00, 28314.06it/s]


Zh_persona_train_corrected.json


100%|██████████| 16878/16878 [00:00<00:00, 27581.90it/s]


Zh_persona_split_valid_human_annotated.json


100%|██████████| 222/222 [00:00<00:00, 24538.44it/s]


Zh_persona_split_test_human_annotated.json


100%|██████████| 222/222 [00:00<00:00, 18538.91it/s]


Fr_persona_train_corrected.json


100%|██████████| 16878/16878 [00:00<00:00, 24501.80it/s]


Fr_persona_split_valid_human_annotated.json


100%|██████████| 248/248 [00:00<00:00, 27001.72it/s]


Fr_persona_split_test_human_annotated.json


100%|██████████| 249/249 [00:00<00:00, 17515.25it/s]


Id_persona_train_corrected.json


100%|██████████| 16878/16878 [00:00<00:00, 29786.65it/s]


Id_persona_split_valid_human_annotated.json


100%|██████████| 484/484 [00:00<00:00, 26655.68it/s]


Id_persona_split_test_human_annotated.json


100%|██████████| 484/484 [00:00<00:00, 24175.24it/s]


It_persona_train_corrected.json


100%|██████████| 16878/16878 [00:01<00:00, 15146.85it/s]


It_persona_split_valid_human_annotated.json


100%|██████████| 140/140 [00:00<00:00, 16843.63it/s]


It_persona_split_test_human_annotated.json


100%|██████████| 140/140 [00:00<00:00, 12641.33it/s]


Jp_persona_train_corrected.json


100%|██████████| 16878/16878 [00:00<00:00, 21601.86it/s]


Jp_persona_split_valid_human_annotated.json


100%|██████████| 275/275 [00:00<00:00, 13024.47it/s]


Jp_persona_split_test_human_annotated.json


100%|██████████| 275/275 [00:00<00:00, 17683.65it/s]


Ko_persona_train_corrected.json


100%|██████████| 16878/16878 [00:00<00:00, 19991.66it/s]


Ko_persona_split_valid_human_annotated.json


100%|██████████| 299/299 [00:00<00:00, 18256.81it/s]


Ko_persona_split_test_human_annotated.json


100%|██████████| 300/300 [00:00<00:00, 16211.12it/s]


In [None]:
import logging
import time
from datetime import timedelta


class LogFormatter():
    """
    Format log records as "LEVEL - <wall-clock time> - <elapsed> - message".

    The elapsed time is measured from `start_time`, which is set at construction.
    Continuation lines of a multi-line message are indented to line up with the
    message text. An empty message is formatted as an empty string.
    """

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)

        # Stamp the record's own creation time (not "now"), so the wall-clock
        # column agrees with the elapsed column even if formatting is deferred.
        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime('%x %X', time.localtime(record.created)),
            timedelta(seconds=elapsed_seconds)
        )
        message = record.getMessage()
        # + 3 accounts for the " - " separator between prefix and message.
        message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
        return "%s - %s" % (prefix, message) if message else ''


def create_logger(filepath, rank):
    """
    Create a logger.
    Use a different log file for each process.

    Configures and returns the ROOT logger:
    - existing root handlers are removed and closed;
    - a console handler at INFO level is attached;
    - when `filepath` is not None, a file handler at DEBUG level (append mode)
      is also attached. Processes with rank > 0 log to "<filepath>-<rank>".
    Both handlers use LogFormatter. The returned logger also gets a
    `reset_time()` attribute that restarts the formatter's elapsed-time clock.
    """
    # create log formatter
    log_formatter = LogFormatter()

    # create file handler and set level to debug
    if filepath is not None:
        if rank > 0:
            filepath = '%s-%i' % (filepath, rank)
        file_handler = logging.FileHandler(filepath, "a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(log_formatter)

    # create console handler and set level to info
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_formatter)

    # create logger and set level to debug
    logger = logging.getLogger()
    # Remove AND close handlers left by previous calls. Merely reassigning
    # `logger.handlers = []` kept their log files open, leaking a file
    # descriptor each time this runs again (e.g. re-executing a notebook cell).
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    if filepath is not None:
        logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # reset logger elapsed time
    def reset_time():
        log_formatter.start_time = time.time()
    logger.reset_time = reset_time

    return logger

In [None]:
import os
import re
import sys
import pickle
import random
import getpass
import argparse
import subprocess
import numpy as np
import torch




# Accepted (lower-cased) spellings for boolean command-line flags; see bool_flag.
FALSY_STRINGS = set(['off', 'false', '0'])
TRUTHY_STRINGS = set(['on', 'true', '1'])

# Default experiment root, used by get_dump_path when params.dump_path is ''.
DUMP_PATH = '/checkpoint/{}/dumped'.format(getpass.getuser())
# Coefficients that may be given as an "iteration:value,..." schedule;
# see parse_lambda_config / update_lambdas.
DYNAMIC_COEFF = [
    'lambda_clm',
    'lambda_mlm',
    'lambda_pc',
    'lambda_ae',
    'lambda_mt',
    'lambda_bt',
    'lambda_s2slm',
]


class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes (d.key <=> d['key'])."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The instance's attribute namespace *is* the mapping itself.
        self.__dict__ = self


def bool_flag(s):
    """
    Parse a boolean command-line argument (usable as an argparse `type=`).

    Case-insensitive: values in FALSY_STRINGS give False and values in
    TRUTHY_STRINGS give True. Anything else raises argparse.ArgumentTypeError.
    """
    normalized = s.lower()
    if normalized in FALSY_STRINGS:
        return False
    if normalized in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")


def initialize_exp(params):
    """
    Initialize the experiment:
    - create the dump directory (get_dump_path) and pickle `params` into it
    - store the command line that reproduces this run in `params.command`
    - create a logger writing to <dump_path>/train.log and log all parameters

    Returns:
        The logger returned by create_logger.
    """
    # dump parameters
    get_dump_path(params)
    # `with` so the pickle file is flushed and closed deterministically.
    with open(os.path.join(params.dump_path, 'params.pkl'), 'wb') as params_file:
        pickle.dump(params, params_file)

    # get running command
    # Non-flag arguments that are not plain identifiers are single-quoted so
    # the command can be pasted back into a shell; quote characters inside
    # arguments are not supported (asserted).
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith('--'):
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            assert "'" not in x
            if re.match('^[a-zA-Z0-9_]+$', x):
                command.append("%s" % x)
            else:
                command.append("'%s'" % x)
    command = ' '.join(command)
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # check experiment name
    assert len(params.exp_name.strip()) > 0

    # create a logger
    logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0))
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("Running command: %s" % command)
    logger.info("")
    return logger


def get_dump_path(params):
    """
    Create a directory to store the experiment.

    The directory is <root>/<exp_name>/<exp_id>, where <root> is
    `params.dump_path`, or DUMP_PATH when that is ''. If `params.exp_id` is '',
    it is taken from $CHRONOS_JOB_ID or $SLURM_JOB_ID (at most one may be set,
    and it must be numeric). Otherwise a random 10-character id not already
    present in the sweep directory is generated.

    Side effects: creates the directories and overwrites `params.exp_id` and
    `params.dump_path`.
    """
    dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path
    assert len(params.exp_name) > 0

    # create the sweep path if it does not exist
    # os.makedirs rather than `mkdir -p` via a shell: paths containing spaces
    # or shell metacharacters are handled correctly, and a failure raises
    # instead of being silently ignored.
    sweep_path = os.path.join(dump_path, params.exp_name)
    os.makedirs(sweep_path, exist_ok=True)

    # create an ID for the job if it is not given in the parameters.
    # if we run on the cluster, the job ID is the one of Chronos.
    # otherwise, it is randomly generated
    if params.exp_id == '':
        chronos_job_id = os.environ.get('CHRONOS_JOB_ID')
        slurm_job_id = os.environ.get('SLURM_JOB_ID')
        assert chronos_job_id is None or slurm_job_id is None
        exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id
        if exp_id is None:
            chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
            while True:
                exp_id = ''.join(random.choice(chars) for _ in range(10))
                if not os.path.isdir(os.path.join(sweep_path, exp_id)):
                    break
        else:
            assert exp_id.isdigit()
        params.exp_id = exp_id

    # create the dump folder / update parameters
    params.dump_path = os.path.join(sweep_path, params.exp_id)
    os.makedirs(params.dump_path, exist_ok=True)


def to_cuda(*args):
    """
    Move each tensor argument to CUDA via `.cuda()`.

    None arguments are passed through unchanged. Always returns a list with
    one entry per argument, in order.
    """
    moved = []
    for tensor in args:
        moved.append(tensor if tensor is None else tensor.cuda())
    return moved


def restore_segmentation(path):
    """
    Take a file segmented with BPE and restore it to its original segmentation.

    Rewrites `path` in place, removing every "@@ " continuation marker and a
    trailing "@@" / "@@ " at the end of a line, e.g.
    "he@@ llo wor@@ ld" -> "hello world". This is equivalent to
    `sed -i -r 's/(@@ )|(@@ ?$)//g' path`, but avoids the shell: no command
    injection through `path`, no dependency on GNU sed, and I/O errors raise
    instead of being ignored.
    """
    assert os.path.isfile(path)
    bpe_marker = re.compile(rb'(@@ )|(@@ ?$)')
    # Operate on raw bytes, line by line (split on b"\n") like sed does: no
    # encoding assumption, and a missing final newline is preserved.
    with open(path, 'rb') as f:
        lines = f.read().split(b'\n')
    restored = b'\n'.join(bpe_marker.sub(b'', line) for line in lines)
    with open(path, 'wb') as f:
        f.write(restored)


def parse_lambda_config(params):
    """
    Parse the configuration of lambda coefficient (for scheduling).

    For every name in DYNAMIC_COEFF, the string attribute on `params` is
    replaced by its initial float value, and `<name>_config` is set to either
    None (constant) or a list of (iteration, value) breakpoints:
        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000
    Breakpoint iterations must be non-negative integers in strictly increasing
    order (asserted).
    """
    for name in DYNAMIC_COEFF:
        spec = getattr(params, name)
        pieces = spec.split(',')
        if len(pieces) == 1:
            setattr(params, name, float(spec))
            setattr(params, name + '_config', None)
            continue
        points = [piece.split(':') for piece in pieces]
        assert all(len(point) == 2 for point in points)
        assert all(it.isdigit() for it, _ in points)
        assert all(int(cur[0]) < int(nxt[0]) for cur, nxt in zip(points, points[1:]))
        setattr(params, name, float(points[0][1]))
        setattr(params, name + '_config', [(int(it), float(val)) for it, val in points])


def get_lambda_value(config, n_iter):
    """
    Evaluate a piecewise-linear lambda schedule at iteration `n_iter`.

    `config` is a list of (iteration, value) breakpoints with increasing
    iterations, as built by parse_lambda_config. Between two breakpoints the
    value is linearly interpolated. At or after the last breakpoint it is the
    last value. An AssertionError is raised when `n_iter` precedes the first
    breakpoint.
    """
    segments = [
        k for k, (start, stop) in enumerate(zip(config[:-1], config[1:]))
        if start[0] <= n_iter < stop[0]
    ]
    if not segments:
        assert n_iter >= config[-1][0]
        return config[-1][1]
    assert len(segments) == 1
    x_a, y_a = config[segments[0]]
    x_b, y_b = config[segments[0] + 1]
    return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)


def update_lambdas(params, n_iter):
    """
    Update all lambda coefficients.

    Every coefficient in DYNAMIC_COEFF that has a schedule (`<name>_config` is
    not None) is set to its value at iteration `n_iter`. Constant coefficients
    are left untouched.
    """
    for name in DYNAMIC_COEFF:
        schedule = getattr(params, name + '_config')
        if schedule is None:
            continue
        setattr(params, name, get_lambda_value(schedule, n_iter))


def set_sampling_probs(data, params):
    """
    Set the probability of sampling specific languages / language pairs during training.

    Only entries of data['mono_stream'] / data['para'] that have a 'train'
    split are considered; their keys go to params.mono_list / params.para_list.
    Each entry's probability is its share of training examples raised to
    params.lg_sampling_factor, then renormalized (stored in params.mono_probs /
    params.para_probs when the list is non-empty). A factor of -1 disables
    this, and nothing is set.
    """
    coeff = params.lg_sampling_factor
    if coeff == -1:
        return
    assert coeff > 0

    def smoothed_probs(sizes):
        # normalize -> raise to `coeff` -> renormalize
        probs = np.array(sizes)
        probs /= probs.sum()
        probs = np.array([p ** coeff for p in probs])
        probs /= probs.sum()
        return probs

    # monolingual data
    mono_stream = data['mono_stream']
    params.mono_list = [lang for lang, splits in mono_stream.items() if 'train' in splits]
    if params.mono_list:
        params.mono_probs = smoothed_probs(
            [1.0 * len(mono_stream[lang]['train']) for lang in params.mono_list])

    # parallel data
    para = data['para']
    params.para_list = [pair for pair, splits in para.items() if 'train' in splits]
    if params.para_list:
        params.para_probs = smoothed_probs(
            [1.0 * len(para[(l1, l2)]['train']) for (l1, l2) in params.para_list])


def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions):
    """
    Concat batches with different languages.

    For every column i, the first segment comes from x1 and the second segment
    x2[:len2[i], i] is written starting at row len1[i] (reset_positions=True)
    or len1[i] - 1 (reset_positions=False, so it overwrites the last token of
    the first segment and the combined length is len1 + len2 - 1).

    Args:
        x1, x2: token batches indexed as (position, batch); x1 is presumably of
            shape (len1.max(), bs) -- TODO confirm against callers.
        len1, len2: per-column segment lengths (1-D tensors).
        lang1_id, lang2_id: ids written into `langs` for each segment.
        pad_idx: fill value for rows past the combined length.
        eos_idx: only used by the final sanity check.
        reset_positions: if True, positions of the second segment restart at 0;
            otherwise positions keep increasing across both segments.

    Returns:
        (x, lengths, positions, langs). x, positions and langs have shape
        (lengths.max(), bs); lengths has one entry per column.
    """
    assert reset_positions is False or lang1_id != lang2_id
    lengths = len1 + len2
    if not reset_positions:
        lengths -= 1  # the two segments share one row
    slen, bs = lengths.max().item(), lengths.size(0)

    x = x1.new(slen, bs).fill_(pad_idx)
    x[:len1.max().item()].copy_(x1)
    positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device)
    langs = x1.new(slen, bs).fill_(lang1_id)

    for i in range(bs):
        # l1: first row of the second segment in column i
        l1 = len1[i] if reset_positions else len1[i] - 1
        x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i])
        if reset_positions:
            positions[l1:, i] -= len1[i]
        # NOTE: rows after the second segment (padding) are tagged lang2_id too
        langs[l1:, i] = lang2_id

    # Sanity check: 4 EOS per column when the segments are kept apart, 3 when
    # they overlap (presumably each input segment carries two EOS tokens and
    # the overlap removes one -- TODO confirm).
    assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs

    return x, lengths, positions, langs


def truncate(x, lengths, max_len, eos_index):
    """
    Truncate long sentences.

    `x` is indexed as (position, batch) and `lengths` holds one length per
    column. If no length exceeds `max_len`, the inputs are returned unchanged
    (same objects). Otherwise, cloned copies are returned where `x` keeps only
    its first `max_len` rows, and every over-long column gets length `max_len`
    and ends with `eos_index`.
    """
    if lengths.max().item() <= max_len:
        return x, lengths

    x = x[:max_len].clone()
    lengths = lengths.clone()
    too_long = [col for col, n in enumerate(lengths.tolist()) if n > max_len]
    for col in too_long:
        lengths[col] = max_len
        x[max_len - 1, col] = eos_index
    return x, lengths


def shuf_order(langs, params=None, n=5):
    """
    Randomize training order.

    Without `params`: return all of `langs` in random order.
    With `params`: `langs` holds (lang1, lang2) tuples (lang2 None for
    monolingual data). Monolingual and parallel entries are sampled separately,
    min(n, pool size) draws each with replacement. Sampling is uniform when
    params.lg_sampling_factor == -1, otherwise weighted by params.mono_probs /
    params.para_probs (parallel pairs are looked up in sorted order).
    Monolingual picks, as (lang, None), come first in the result.
    """
    if len(langs) == 0:
        return []

    if params is None:
        order = np.random.permutation(len(langs))
        return [langs[i] for i in order]

    # sample monolingual and parallel languages separately
    mono, para = [], []
    for l1, l2 in langs:
        if l2 is None:
            mono.append(l1)
        else:
            para.append((l1, l2))

    # uniform / weighted sampling
    if params.lg_sampling_factor == -1:
        p_mono = p_para = None
    else:
        p_mono = np.array([params.mono_probs[params.mono_list.index(k)] for k in mono])
        p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(k)))] for k in para])
        p_mono = p_mono / p_mono.sum()
        p_para = p_para / p_para.sum()

    def draw(pool, probs):
        # Draw min(n, len(pool)) entries with replacement (none if pool is empty).
        if len(pool) == 0:
            return []
        picks = np.random.choice(len(pool), size=min(n, len(pool)), p=probs, replace=True)
        return [pool[i] for i in picks]

    # mono is drawn before para, keeping the RNG call order fixed
    s_mono = draw(mono, p_mono)
    s_para = draw(para, p_para)

    assert len(s_mono) + len(s_para) > 0
    return [(lang, None) for lang in s_mono] + s_para


def find_modules(module, module_name, module_instance, found):
    """
    Recursively find all instances of a specific module inside a module.

    Appends (qualified_name, submodule) pairs to `found` in depth-first order.
    A matching module is not searched further. Numeric child names are
    rendered as "parent[i]" and other names as "parent.child".
    """
    if isinstance(module, module_instance):
        found.append((module_name, module))
        return
    for child_name, child in module.named_children():
        template = '%s[%s]' if child_name.isdigit() else '%s.%s'
        find_modules(child, template % (module_name, child_name), module_instance, found)


def mask_out_v2(params, x, lens, enc_lens=None):
    """
    Decide of random words to mask out, and what target they get assigned.

    Each position of `x` (indexed as (position, batch)) is selected when
    np.random.rand() <= params.word_pred. Padding positions and the whole
    first row are never selected. When `enc_lens` is given, rows
    t < enc_lens[b] of column b (the source part) are excluded too, so only
    the target sequence is masked. Selected positions are replaced by
    params.mask_index.

    Args:
        params: reads sample_alpha (only 0 is supported), word_pred,
            pad_index and mask_index.
        x: 2-D token tensor.
        lens: not used by this function.
        enc_lens: optional per-column lengths of the encoder part.

    Returns:
        (masked x, original ids at the selected positions, boolean mask).

    Raises:
        NotImplementedError: if params.sample_alpha != 0.
    """
    slen, bs = x.size()

    if params.sample_alpha != 0:
        raise NotImplementedError
    # Selection is drawn from numpy's global RNG.
    pred_mask = torch.from_numpy((np.random.rand(slen, bs) <= params.word_pred).astype(np.bool_))

    pred_mask[x == params.pad_index] = 0  # do not predict padding
    pred_mask[0] = 0                      # nor the first row
    if enc_lens is not None:
        row_ids = torch.arange(slen)
        pred_mask[enc_lens[None, :] > row_ids[:, None]] = 0

    # TODO what if all words are not masked? (targets would then be empty)
    targets = x[pred_mask]
    masked_x = x.masked_fill(pred_mask, params.mask_index)
    return masked_x, targets, pred_mask

In [None]:
import re
import math
import inspect

import torch
from torch import optim


class Adam(optim.Optimizer):
    """
    Adam, adapted from https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py.

    Differences from `torch.optim.Adam`:
    - no amsgrad;
    - the step counter is a plain Python int;
    - state is created eagerly in `__init__` (and lazily in `step` for
      parameters added later with `add_param_group`);
    - `eps` is added to sqrt(exp_avg_sq) before the bias corrections, which
      are folded into the step size;
    - weight decay is decoupled: p <- p - lr * weight_decay * p, applied before
      the gradient update, instead of being added to the gradient.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        for group in self.param_groups:
            for p in group['params']:
                self._init_state(self.state[p], p)

    @staticmethod
    def _init_state(state, p):
        """Reset the per-parameter state: step counter and both moment buffers."""
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(p.data)
        state['exp_avg_sq'] = torch.zeros_like(p.data)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                if len(state) == 0:
                    # parameter added after construction (add_param_group)
                    self._init_state(state, p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # Keyword forms: the positional (scalar, tensor) overloads are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    # decoupled weight decay: p <- p * (1 - lr * weight_decay)
                    p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


class AdamInverseSqrtWithWarmup(Adam):
    """
    Decay the LR based on the inverse square root of the update number.

    There is a warmup phase where the learning rate increases linearly from
    `warmup_init_lr` to the configured `lr`. After that it decays as a power of
    the update count, scaled so that it equals `lr` at the end of warmup.
    During warmup:
        lr = warmup_init_lr + update_num * (lr - warmup_init_lr) / warmup_updates
    After warmup:
        lr = decay_factor * update_num ** -exp_factor
    where
        decay_factor = lr * warmup_updates ** exp_factor
    With the default exp_factor=0.5 this is inverse-square-root decay. The
    optimizer starts at `warmup_init_lr`, and the schedule is applied after
    every `step()`.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7,
                 exp_factor=0.5):
        super().__init__(
            params,
            lr=warmup_init_lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )

        # linearly warmup for the first warmup_updates
        self.warmup_updates = warmup_updates
        self.warmup_init_lr = warmup_init_lr
        warmup_end_lr = lr
        self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates

        # then, decay prop. to the inverse square root of the update number
        self.exp_factor = exp_factor
        self.decay_factor = warmup_end_lr * warmup_updates ** self.exp_factor

        # total number of updates
        for param_group in self.param_groups:
            param_group['num_updates'] = 0

    def get_lr_for_step(self, num_updates):
        """Learning rate to use after `num_updates` optimizer updates."""
        if num_updates < self.warmup_updates:
            return self.warmup_init_lr + num_updates * self.lr_step
        else:
            return self.decay_factor * (num_updates ** -self.exp_factor)

    def step(self, closure=None):
        """
        Run one Adam update, then set each group's lr for the next update.

        Returns:
            The closure's loss (None without a closure), as Optimizer.step should.
        """
        loss = super().step(closure)
        for param_group in self.param_groups:
            param_group['num_updates'] += 1
            param_group['lr'] = self.get_lr_for_step(param_group['num_updates'])
        return loss


class AdamCosineWithWarmup(Adam):
    """
    Assign LR based on a cyclical schedule that follows the cosine function.
    See https://arxiv.org/pdf/1608.03983.pdf for details.

    There is a warmup phase where the learning rate increases linearly from
    `warmup_init_lr` to the configured `lr`.
    During warmup:
      lr = warmup_init_lr + update_num * (lr - warmup_init_lr) / warmup_updates
    After warmup:
      lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t_curr / t_i))
    where `t_curr` is the number of updates since the start of the current
    period and `t_i` is that period's length. The first period lasts
    `init_period` updates, and each later period is `period_mult` times longer.
    In period number `pid`, both lr_min (`min_lr`) and lr_max (`lr`) are
    multiplied by `lr_shrink ** pid`. The optimizer starts at `warmup_init_lr`,
    and the schedule is applied after every `step()`.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7,
                 min_lr=1e-9, init_period=1000000, period_mult=1, lr_shrink=0.75):
        super().__init__(
            params,
            lr=warmup_init_lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )

        # linearly warmup for the first warmup_updates
        self.warmup_updates = warmup_updates
        self.warmup_init_lr = warmup_init_lr
        warmup_end_lr = lr
        self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates

        # then, apply cosine scheduler
        self.min_lr = min_lr
        self.max_lr = lr
        self.period = init_period
        self.period_mult = period_mult
        self.lr_shrink = lr_shrink

        # total number of updates
        for param_group in self.param_groups:
            param_group['num_updates'] = 0

    def get_lr_for_step(self, num_updates):
        """Learning rate to use after `num_updates` optimizer updates."""
        if num_updates < self.warmup_updates:
            return self.warmup_init_lr + num_updates * self.lr_step
        else:
            t = num_updates - self.warmup_updates
            if self.period_mult == 1:
                pid = math.floor(t / self.period)
                t_i = self.period
                t_curr = t - (self.period * pid)
            else:
                # index of the current period when period lengths grow geometrically
                pid = math.floor(math.log(1 - t / self.period * (1 - self.period_mult), self.period_mult))
                t_i = self.period * (self.period_mult ** pid)
                t_curr = t - (1 - self.period_mult ** pid) / (1 - self.period_mult) * self.period
            lr_shrink = self.lr_shrink ** pid
            min_lr = self.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink
            return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))

    def step(self, closure=None):
        """
        Run one Adam update, then set each group's lr for the next update.

        Returns:
            The closure's loss (None without a closure), as Optimizer.step should.
        """
        loss = super().step(closure)
        for param_group in self.param_groups:
            param_group['num_updates'] += 1
            param_group['lr'] = self.get_lr_for_step(param_group['num_updates'])
        return loss


def get_optimizer(parameters, s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
        - "adam_inverse_sqrt,lr=1e-4,warmup_updates=4000,beta1=0.9,beta2=0.98"
    i.e. a method name, optionally followed by comma-separated `key=value`
    pairs. Values must be numbers (scientific notation such as 1e-4 is
    accepted) and are passed to the optimizer as floats. For the Adam
    variants, `beta1` / `beta2` are combined into the `betas` tuple.

    Returns:
        The optimizer instantiated over `parameters`.

    Raises:
        AssertionError: malformed `key=value` pair, or `sgd` without `lr`.
        Exception: unknown method, or a parameter the optimizer does not accept.
    """
    number_re = re.compile(r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$")
    if "," in s:
        method, raw_params = s.split(',', 1)
        optim_params = {}
        for x in raw_params.split(','):
            split = x.split('=')
            assert len(split) == 2
            assert number_re.match(split[1]) is not None
            optim_params[split[0]] = float(split[1])
    else:
        method = s
        optim_params = {}

    torch_optimizers = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }

    if method in ('adam', 'adam_inverse_sqrt', 'adam_cosine'):
        if method == 'adam':
            optim_fn = Adam
        elif method == 'adam_inverse_sqrt':
            optim_fn = AdamInverseSqrtWithWarmup
        else:
            optim_fn = AdamCosineWithWarmup
        # fold the flat beta1 / beta2 options into the tuple the Adam classes expect
        optim_params['betas'] = (optim_params.pop('beta1', 0.9), optim_params.pop('beta2', 0.999))
    elif method in torch_optimizers:
        optim_fn = torch_optimizers[method]
        if method == 'sgd':
            assert 'lr' in optim_params
    else:
        raise Exception('Unknown optimization method: "%s"' % method)

    # check that we give good parameters to the optimizer
    # inspect.getargspec was removed in Python 3.11 and rejects annotated or
    # keyword-only parameters (which recent torch optimizers have); use
    # inspect.signature and ignore *args / **kwargs entries.
    signature = inspect.signature(optim_fn.__init__)
    expected_args = [
        name for name, param in signature.parameters.items()
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    assert expected_args[:2] == ['self', 'params']
    if not all(k in expected_args[2:] for k in optim_params.keys()):
        raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
            str(expected_args[2:]), str(optim_params.keys())))

    return optim_fn(parameters, **optim_params)

In [None]:
from logging import getLogger
import math
import numpy as np
import torch


# Module-level handle on the root logger (getLogger() with no name) -- the same
# logger that create_logger configures.
logger = getLogger()


class StreamDataset(object):
    """
    A token stream split into `bs` contiguous parallel streams (columns), which
    is served in windows of `params.bptt` rows.
    """

    def __init__(self, sent, pos, bs, params):
        """
        Prepare batches for data iterator.

        Args:
            sent: all token ids concatenated, with `params.eos_index` ending
                each sentence (presumably a 1-D numpy array -- TODO confirm;
                `.dtype` and numpy fancy indexing are used below).
            pos: per-sentence offsets; `pos[:, 1]` must index each sentence's EOS.
            bs: number of parallel streams (columns).
            params: provides `bptt` (window length) and `eos_index`.
        """
        bptt = params.bptt
        self.eos = params.eos_index

        # checks: one EOS per sentence, and every end offset lands on an EOS
        assert len(pos) == (sent == self.eos).sum()
        assert len(pos) == (sent[pos[:, 1]] == self.eos).sum()

        n_tokens = len(sent)
        n_batches = math.ceil(n_tokens / (bs * bptt))
        t_size = n_batches * bptt * bs

        # EOS-filled buffer sized to a multiple of bs * bptt; the tokens are
        # placed at its END, so the EOS padding is at the start.
        buffer = np.zeros(t_size, dtype=sent.dtype) + self.eos
        buffer[t_size - n_tokens:] = sent
        # After the transpose each column is one contiguous chunk of the
        # stream: shape (n_batches * bptt, bs).
        buffer = buffer.reshape((bs, n_batches * bptt)).T
        # One extra all-EOS row is prepended: shape (n_batches * bptt + 1, bs).
        self.data = np.zeros((n_batches * bptt + 1, bs), dtype=sent.dtype) + self.eos
        self.data[1:] = buffer

        self.bptt = bptt
        self.n_tokens = n_tokens
        self.n_batches = n_batches
        self.n_sentences = len(pos)
        # every window is exactly bptt rows long in every column
        self.lengths = torch.LongTensor(bs).fill_(bptt)

    def __len__(self):
        """
        Number of sentences in the dataset.
        """
        return self.n_sentences

    def select_data(self, a, b):
        """
        Only select a subset of the dataset.

        Keeps rows [a * bptt, b * bptt) of `self.data` (windows a .. b-1).
        Invalid bounds log a warning and leave the dataset unchanged.
        """
        if not (0 <= a < b <= self.n_batches):
            logger.warning("Invalid split values: %i %i - %i" % (a, b, self.n_batches))
            return
        assert 0 <= a < b <= self.n_batches  # always true after the check above
        logger.info("Selecting batches from %i to %i ..." % (a, b))

        # sub-select
        self.data = self.data[a * self.bptt:b * self.bptt]
        self.n_batches = b - a
        # NOTE: counts every EOS in the selection, including EOS padding
        self.n_sentences = (self.data == self.eos).sum().item()

    def get_iterator(self, shuffle, subsample=1):
        """
        Return a sentences iterator.

        Yields (tokens, lengths) pairs. `tokens` is an int64 tensor of `bptt`
        rows x bs columns sliced from `self.data`, and `lengths` is
        `self.lengths`. Only the first `n_batches // subsample` windows are
        visited, in random order when `shuffle` is true.
        """
        # NOTE(review): windows start at row 0 of self.data. When select_data
        # has not been called, that is the prepended EOS row, so the last row
        # of self.data is never yielded -- confirm this is intended.
        indexes = (np.random.permutation if shuffle else range)(self.n_batches // subsample)
        for i in indexes:
            a = self.bptt * i
            b = self.bptt * (i + 1)
            yield torch.from_numpy(self.data[a:b].astype(np.int64)), self.lengths


class Dataset(object):
    """
    Batched (non-stream) view over a flat token array.

    `sent` is a 1-D array of token ids in which every sentence is followed by
    the EOS index. `pos` is an (n_sentences, 2) array whose row k holds the
    [start, end) span of sentence k in `sent`; `sent[end]` is that sentence's EOS.
    """

    def __init__(self, sent, pos, params):
        """
        Args:
            sent: 1-D array of token ids, one EOS after each sentence.
            pos: (n_sentences, 2) integer array of [start, end) spans into `sent`.
            params: object exposing eos_index, pad_index, batch_size,
                tokens_per_batch and max_batch_size.
        """
        self.eos_index = params.eos_index
        self.pad_index = params.pad_index
        self.batch_size = params.batch_size
        self.tokens_per_batch = params.tokens_per_batch
        self.max_batch_size = params.max_batch_size

        self.sent = sent
        self.pos = pos
        self.lengths = self.pos[:, 1] - self.pos[:, 0]

        # check number of sentences
        assert len(self.pos) == (self.sent == self.eos_index).sum()

        # Empty sentences are NOT removed here; callers invoke
        # `remove_empty_sentences` explicitly when they need that.

        # sanity checks
        self.check()

    def __len__(self):
        """
        Number of sentences in the dataset.
        """
        return len(self.pos)

    def check(self):
        """
        Sanity checks: every sentence span must end on an EOS token.
        """
        eos = self.eos_index
        assert len(self.pos) == (self.sent[self.pos[:, 1]] == eos).sum()  # check sentences indices
        # empty sentences are allowed, so `self.lengths` is not required to be > 0

    def batch_sentences(self, sentences):
        """
        Pad a list of n sentences into a single batch tensor.

        Args:
            sentences: list of n 1-D numpy arrays of token ids (without EOS).

        Returns:
            (sent, lengths): `sent` is a LongTensor of size (slen, n). Each
            column is EOS, the sentence tokens, EOS, then PAD up to slen, where
            slen is the longest sentence length plus 2. `lengths` is a
            LongTensor of size n that counts both EOS tokens.
        """
        lengths = torch.LongTensor([len(s) + 2 for s in sentences])
        sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index)

        sent[0] = self.eos_index
        for i, s in enumerate(sentences):
            if lengths[i] > 2:  # if sentence not empty
                sent[1:lengths[i] - 1, i].copy_(torch.from_numpy(s.astype(np.int64)))
            sent[lengths[i] - 1, i] = self.eos_index

        return sent, lengths

    def remove_empty_sentences(self):
        """
        Remove empty sentences (spans of length 0).
        """
        init_size = len(self.pos)
        indices = np.arange(len(self.pos))
        indices = indices[self.lengths[indices] > 0]
        self.pos = self.pos[indices]
        self.lengths = self.pos[:, 1] - self.pos[:, 0]
        logger.info("Removed %i empty sentences." % (init_size - len(indices)))
        self.check()

    def remove_long_sentences(self, max_len):
        """
        Remove sentences longer than `max_len` tokens (0 disables the filter).
        """
        assert max_len >= 0
        if max_len == 0:
            return
        init_size = len(self.pos)
        indices = np.arange(len(self.pos))
        indices = indices[self.lengths[indices] <= max_len]
        self.pos = self.pos[indices]
        self.lengths = self.pos[:, 1] - self.pos[:, 0]
        logger.info("Removed %i too long sentences." % (init_size - len(indices)))
        self.check()

    def select_data(self, a, b):
        """
        Only keep sentences [a, b), and re-base `pos` / `sent` on that range.
        """
        assert 0 <= a < b <= len(self.pos)
        logger.info("Selecting sentences from %i to %i ..." % (a, b))

        # sub-select
        self.pos = self.pos[a:b]
        self.lengths = self.pos[:, 1] - self.pos[:, 0]

        # re-index: shift spans to start at 0 and keep only the covered tokens.
        # The subtraction is out-of-place on purpose: `self.pos[a:b]` is a view,
        # and an in-place `-=` would rewrite the positions array the caller
        # passed in (and any other dataset sharing it).
        min_pos = self.pos.min()
        max_pos = self.pos.max()
        self.pos = self.pos - min_pos
        self.sent = self.sent[min_pos:max_pos + 1]

        # sanity checks
        self.check()

    def get_batches_iterator(self, batches, return_indices):
        """
        Yield padded batches for the given lists of sentence ids.

        Batches larger than `max_batch_size` (when > 0) are randomly
        sub-sampled down to that size.
        """
        assert type(return_indices) is bool

        for sentence_ids in batches:
            if 0 < self.max_batch_size < len(sentence_ids):
                np.random.shuffle(sentence_ids)
                sentence_ids = sentence_ids[:self.max_batch_size]
            pos = self.pos[sentence_ids]
            sent = [self.sent[a:b] for a, b in pos]
            sent = self.batch_sentences(sent)
            yield (sent, sentence_ids) if return_indices else sent

    def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, seed=None, return_indices=False):
        """
        Return an iterator over padded sentence batches.

        Args:
            shuffle: shuffle the sentence order and the batch order.
            group_by_size: sort selected sentences by length before batching
                (requires shuffle=True).
            n_sentences: number of sentences to iterate over (-1 = all).
            seed: optional int seed for a dedicated RandomState (requires shuffle=True).
            return_indices: also yield the sentence ids of each batch.
        """
        assert seed is None or shuffle is True and type(seed) is int
        rng = np.random.RandomState(seed)
        n_sentences = len(self.pos) if n_sentences == -1 else n_sentences
        assert 0 < n_sentences <= len(self.pos)
        assert type(shuffle) is bool and type(group_by_size) is bool
        assert group_by_size is False or shuffle is True

        # sentence lengths (+2 for the two EOS tokens added by batch_sentences)
        lengths = self.lengths + 2

        # select sentences to iterate over
        if shuffle:
            indices = rng.permutation(len(self.pos))[:n_sentences]
        else:
            indices = np.arange(n_sentences)

        # group sentences by lengths (stable sort keeps the shuffled order among ties)
        if group_by_size:
            indices = indices[np.argsort(lengths[indices], kind='mergesort')]

        # create batches - either have a fixed number of sentences, or a similar number of tokens
        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        # optionally shuffle batches
        if shuffle:
            rng.shuffle(batches)

        # sanity checks
        assert n_sentences == sum([len(x) for x in batches])
        assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])

        # return the iterator
        return self.get_batches_iterator(batches, return_indices)


class ParallelDataset(Dataset):
    """
    Aligned pair of datasets: sentence k of `sent1` is paired with sentence k of `sent2`.
    """

    def __init__(self, sent1, pos1, sent2, pos2, params):
        """
        Args:
            sent1, sent2: 1-D arrays of token ids, one EOS after each sentence.
            pos1, pos2: (n_sentences, 2) arrays of [start, end) spans into sent1 / sent2.
            params: object exposing eos_index, pad_index, batch_size,
                tokens_per_batch and max_batch_size.
        """
        self.eos_index = params.eos_index
        self.pad_index = params.pad_index
        self.batch_size = params.batch_size
        self.tokens_per_batch = params.tokens_per_batch
        self.max_batch_size = params.max_batch_size

        self.sent1 = sent1
        self.sent2 = sent2
        self.pos1 = pos1
        self.pos2 = pos2
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]

        # check number of sentences
        assert len(self.pos1) == (self.sent1 == self.eos_index).sum()
        assert len(self.pos2) == (self.sent2 == self.eos_index).sum()

        # remove empty sentences (a pair is dropped if either side is empty)
        self.remove_empty_sentences()

        # sanity checks
        self.check()

    def __len__(self):
        """
        Number of sentences in the dataset.
        """
        return len(self.pos1)

    def check(self):
        """
        Sanity checks.
        """
        eos = self.eos_index
        assert len(self.pos1) == len(self.pos2) > 0                          # check number of sentences
        assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum()  # check sentences indices
        assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum()  # check sentences indices
        assert eos <= self.sent1.min() < self.sent1.max()                    # check dictionary indices
        assert eos <= self.sent2.min() < self.sent2.max()                    # check dictionary indices
        assert self.lengths1.min() > 0                                       # check empty sentences
        assert self.lengths2.min() > 0                                       # check empty sentences

    def remove_empty_sentences(self):
        """
        Remove pairs in which either sentence is empty.
        """
        init_size = len(self.pos1)
        indices = np.arange(len(self.pos1))
        indices = indices[self.lengths1[indices] > 0]
        indices = indices[self.lengths2[indices] > 0]
        self.pos1 = self.pos1[indices]
        self.pos2 = self.pos2[indices]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        logger.info("Removed %i empty sentences." % (init_size - len(indices)))
        self.check()

    def remove_long_sentences(self, max_len):
        """
        Remove pairs in which either sentence exceeds `max_len` tokens (0 disables).
        """
        assert max_len >= 0
        if max_len == 0:
            return
        init_size = len(self.pos1)
        indices = np.arange(len(self.pos1))
        indices = indices[self.lengths1[indices] <= max_len]
        indices = indices[self.lengths2[indices] <= max_len]
        self.pos1 = self.pos1[indices]
        self.pos2 = self.pos2[indices]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        logger.info("Removed %i too long sentences." % (init_size - len(indices)))
        self.check()

    def cut_long_sentences(self, max_len1, max_len2):
        """
        Truncate (rather than drop) sentences longer than max_len1 / max_len2 tokens.

        Spans and lengths are shortened in place. A cut sentence's `pos[:, 1]`
        no longer points at an EOS token, so `check()` fails on this dataset
        afterwards. `batch_sentences` still appends an EOS when batching.
        """
        assert max_len1 > 0 and max_len2 > 0

        def _cut(length, pos, max_len):
            # indices to cut
            indices = np.arange(len(pos))
            indices = indices[length[indices] > max_len]
            pos[indices, 1] = pos[indices, 0] + max_len
            length[indices] = max_len

        _cut(self.lengths1, self.pos1, max_len1)
        _cut(self.lengths2, self.pos2, max_len2)

    def select_data(self, a, b):
        """
        Only keep sentence pairs [a, b), and re-base positions / tokens on that range.
        """
        assert 0 <= a < b <= len(self.pos1)
        logger.info("Selecting sentences from %i to %i ..." % (a, b))

        # sub-select
        self.pos1 = self.pos1[a:b]
        self.pos2 = self.pos2[a:b]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]

        # re-index (out-of-place: the [a:b] slices above are views of the old arrays)
        min_pos1 = self.pos1.min()
        max_pos1 = self.pos1.max()
        min_pos2 = self.pos2.min()
        max_pos2 = self.pos2.max()
        self.pos1 = self.pos1 - min_pos1
        self.pos2 = self.pos2 - min_pos2
        self.sent1 = self.sent1[min_pos1:max_pos1 + 1]
        self.sent2 = self.sent2[min_pos2:max_pos2 + 1]

        # sanity checks are skipped here, presumably because spans truncated by
        # `cut_long_sentences` no longer end on EOS and `check()` would fail
        # (TODO confirm)
        # self.check()

    def get_batches_iterator(self, batches, return_indices):
        """
        Yield padded (sent1, sent2) batches for the given lists of sentence ids.

        Batches larger than `max_batch_size` (when > 0) are randomly
        sub-sampled down to that size.
        """
        assert type(return_indices) is bool

        for sentence_ids in batches:
            if 0 < self.max_batch_size < len(sentence_ids):
                np.random.shuffle(sentence_ids)
                sentence_ids = sentence_ids[:self.max_batch_size]
            pos1 = self.pos1[sentence_ids]
            pos2 = self.pos2[sentence_ids]
            sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1])
            sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2])
            yield (sent1, sent2, sentence_ids) if return_indices else (sent1, sent2)

    def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False, seed=None):
        """
        Return an iterator over padded batches of sentence pairs.

        Args:
            shuffle: shuffle the sentence order and the batch order.
            group_by_size: sort selected pairs by combined length before batching.
            n_sentences: number of pairs to iterate over (-1 = all).
            return_indices: also yield the sentence ids of each batch.
            seed: optional int (requires shuffle=True). When given, shuffling uses
                a dedicated np.random.RandomState(seed) so the order is
                reproducible, as `Dataset.get_iterator` allows. When None, the
                global NumPy RNG is used (previous behavior).
        """
        assert seed is None or shuffle is True and type(seed) is int
        rng = np.random if seed is None else np.random.RandomState(seed)
        n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences
        assert 0 < n_sentences <= len(self.pos1)
        assert type(shuffle) is bool and type(group_by_size) is bool

        # sentence lengths (+4 for the two EOS tokens added to each side)
        lengths = self.lengths1 + self.lengths2 + 4

        # select sentences to iterate over
        if shuffle:
            indices = rng.permutation(len(self.pos1))[:n_sentences]
        else:
            indices = np.arange(n_sentences)

        # group sentences by lengths
        if group_by_size:
            indices = indices[np.argsort(lengths[indices], kind='mergesort')]

        # create batches - either have a fixed number of sentences, or a similar number of tokens
        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        # optionally shuffle batches
        if shuffle:
            rng.shuffle(batches)

        # sanity checks
        assert n_sentences == sum([len(x) for x in batches])
        assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])

        # return the iterator
        return self.get_batches_iterator(batches, return_indices)


class TripleDataset(Dataset):
    """
    Three aligned datasets: sentence k of `sent1`, `sent2` and `sent3` form one example.
    """

    def __init__(self, sent1, pos1, sent2, pos2, sent3, pos3, params):
        """
        Args:
            sent1, sent2, sent3: 1-D arrays of token ids, one EOS after each sentence.
            pos1, pos2, pos3: (n_sentences, 2) arrays of [start, end) spans into
                the corresponding token arrays.
            params: object exposing eos_index, pad_index, batch_size,
                tokens_per_batch and max_batch_size.
        """
        self.eos_index = params.eos_index
        self.pad_index = params.pad_index
        self.batch_size = params.batch_size
        self.tokens_per_batch = params.tokens_per_batch
        self.max_batch_size = params.max_batch_size

        self.sent1 = sent1
        self.sent2 = sent2
        self.sent3 = sent3
        self.pos1 = pos1
        self.pos2 = pos2
        self.pos3 = pos3
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        self.lengths3 = self.pos3[:, 1] - self.pos3[:, 0]

        # check number of sentences
        assert len(self.pos1) == (self.sent1 == self.eos_index).sum()
        assert len(self.pos2) == (self.sent2 == self.eos_index).sum()
        assert len(self.pos3) == (self.sent3 == self.eos_index).sum()

        # remove empty sentences (an example is dropped if any side is empty)
        self.remove_empty_sentences()

        # sanity checks
        self.check()

    def __len__(self):
        """
        Number of sentences in the dataset.
        """
        return len(self.pos1)

    def check(self):
        """
        Sanity checks.
        """
        eos = self.eos_index
        assert len(self.pos1) == len(self.pos2) == len(self.pos3) > 0       # check number of sentences
        assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum()  # check sentences indices
        assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum()  # check sentences indices
        assert len(self.pos3) == (self.sent3[self.pos3[:, 1]] == eos).sum()  # check sentences indices
        assert eos <= self.sent1.min() < self.sent1.max()                    # check dictionary indices
        assert eos <= self.sent2.min() < self.sent2.max()                    # check dictionary indices
        assert eos <= self.sent3.min() < self.sent3.max()                    # check dictionary indices
        assert self.lengths1.min() > 0                                       # check empty sentences
        assert self.lengths2.min() > 0                                       # check empty sentences
        assert self.lengths3.min() > 0                                       # check empty sentences

    def remove_empty_sentences(self):
        """
        Remove examples in which any of the three sentences is empty.
        """
        init_size = len(self.pos1)
        indices = np.arange(len(self.pos1))
        indices = indices[self.lengths1[indices] > 0]
        indices = indices[self.lengths2[indices] > 0]
        indices = indices[self.lengths3[indices] > 0]
        self.pos1 = self.pos1[indices]
        self.pos2 = self.pos2[indices]
        self.pos3 = self.pos3[indices]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        self.lengths3 = self.pos3[:, 1] - self.pos3[:, 0]
        logger.info("Removed %i empty sentences." % (init_size - len(indices)))
        self.check()

    def remove_long_sentences(self, max_len):
        """
        Remove examples in which any sentence exceeds `max_len` tokens (0 disables).
        """
        assert max_len >= 0
        if max_len == 0:
            return
        init_size = len(self.pos1)
        indices = np.arange(len(self.pos1))
        indices = indices[self.lengths1[indices] <= max_len]
        indices = indices[self.lengths2[indices] <= max_len]
        indices = indices[self.lengths3[indices] <= max_len]
        self.pos1 = self.pos1[indices]
        self.pos2 = self.pos2[indices]
        self.pos3 = self.pos3[indices]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        self.lengths3 = self.pos3[:, 1] - self.pos3[:, 0]
        logger.info("Removed %i too long sentences." % (init_size - len(indices)))
        self.check()

    def cut_long_sentences(self, max_len1, max_len2, max_len3):
        """
        Truncate (rather than drop) sentences longer than max_len1/2/3 tokens.

        Spans and lengths are shortened in place. A cut sentence's `pos[:, 1]`
        no longer points at an EOS token, so a later `check()` (e.g. the one in
        `select_data`) fails. `batch_sentences` still appends an EOS when batching.
        """
        assert max_len1 > 0 and max_len2 > 0 and max_len3 > 0

        def _cut(length, pos, max_len):
            # indices to cut
            indices = np.arange(len(pos))
            indices = indices[length[indices] > max_len]
            pos[indices, 1] = pos[indices, 0] + max_len
            length[indices] = max_len

        _cut(self.lengths1, self.pos1, max_len1)
        _cut(self.lengths2, self.pos2, max_len2)
        _cut(self.lengths3, self.pos3, max_len3)

    def select_data(self, a, b):
        """
        Only keep examples [a, b), and re-base positions / tokens on that range.
        """
        assert 0 <= a < b <= len(self.pos1)
        logger.info("Selecting sentences from %i to %i ..." % (a, b))

        # sub-select
        self.pos1 = self.pos1[a:b]
        self.pos2 = self.pos2[a:b]
        self.pos3 = self.pos3[a:b]
        self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
        self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
        self.lengths3 = self.pos3[:, 1] - self.pos3[:, 0]

        # re-index (out-of-place: the [a:b] slices above are views of the old arrays)
        min_pos1 = self.pos1.min()
        max_pos1 = self.pos1.max()
        min_pos2 = self.pos2.min()
        max_pos2 = self.pos2.max()
        min_pos3 = self.pos3.min()
        max_pos3 = self.pos3.max()
        self.pos1 = self.pos1 - min_pos1
        self.pos2 = self.pos2 - min_pos2
        self.pos3 = self.pos3 - min_pos3
        self.sent1 = self.sent1[min_pos1:max_pos1 + 1]
        self.sent2 = self.sent2[min_pos2:max_pos2 + 1]
        self.sent3 = self.sent3[min_pos3:max_pos3 + 1]

        # sanity checks
        self.check()

    def get_batches_iterator(self, batches, return_indices):
        """
        Yield padded (sent1, sent2, sent3) batches for the given lists of sentence ids.

        Batches larger than `max_batch_size` (when > 0) are randomly
        sub-sampled down to that size.
        """
        assert type(return_indices) is bool

        for sentence_ids in batches:
            if 0 < self.max_batch_size < len(sentence_ids):
                np.random.shuffle(sentence_ids)
                sentence_ids = sentence_ids[:self.max_batch_size]
            pos1 = self.pos1[sentence_ids]
            pos2 = self.pos2[sentence_ids]
            pos3 = self.pos3[sentence_ids]
            sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1])
            sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2])
            sent3 = self.batch_sentences([self.sent3[a:b] for a, b in pos3])
            yield (sent1, sent2, sent3, sentence_ids) if return_indices else (sent1, sent2, sent3)

    def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False, seed=None):
        """
        Return an iterator over padded batches of sentence triples.

        Args:
            shuffle: shuffle the sentence order and the batch order.
            group_by_size: sort selected examples by combined length before batching.
            n_sentences: number of examples to iterate over (-1 = all).
            return_indices: also yield the sentence ids of each batch.
            seed: optional int (requires shuffle=True). When given, shuffling uses
                a dedicated np.random.RandomState(seed) so the order is
                reproducible, as `Dataset.get_iterator` allows. When None, the
                global NumPy RNG is used (previous behavior).
        """
        assert seed is None or shuffle is True and type(seed) is int
        rng = np.random if seed is None else np.random.RandomState(seed)
        n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences
        assert 0 < n_sentences <= len(self.pos1)
        assert type(shuffle) is bool and type(group_by_size) is bool

        # sentence lengths (+6 for the two EOS tokens added to each of the three sides)
        lengths = self.lengths1 + self.lengths2 + self.lengths3 + 6

        # select sentences to iterate over
        if shuffle:
            indices = rng.permutation(len(self.pos1))[:n_sentences]
        else:
            indices = np.arange(n_sentences)

        # group sentences by lengths
        if group_by_size:
            indices = indices[np.argsort(lengths[indices], kind='mergesort')]

        # create batches - either have a fixed number of sentences, or a similar number of tokens
        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        # optionally shuffle batches
        if shuffle:
            rng.shuffle(batches)

        # sanity checks
        assert n_sentences == sum([len(x) for x in batches])
        assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])

        # return the iterator
        return self.get_batches_iterator(batches, return_indices)

In [None]:
import os
import numpy as np
import torch
from logging import getLogger


logger = getLogger()


# Reserved tokens; Dictionary.check_valid requires them at ids 0, 1, 2 and 3.
BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD = '<s>', '</s>', '<pad>', '<unk>'

# Template and count of the extra reserved tokens <special0> ... <special9>.
SPECIAL_WORD = '<special%i>'
SPECIAL_WORDS = 10

# The first two special tokens are aliased as SEP_WORD and MASK_WORD.
SEP_WORD, MASK_WORD = SPECIAL_WORD % 0, SPECIAL_WORD % 1


class Dictionary(object):
    """
    Bidirectional word <-> id vocabulary with per-word frequency counts.

    Ids 0-3 are reserved for BOS/EOS/PAD/UNK, and ids 4 .. 4 + SPECIAL_WORDS - 1
    for the `<specialN>` tokens. `check_valid` enforces this layout.
    """

    def __init__(self, id2word, word2id, counts):
        """
        Args:
            id2word: dict mapping integer id -> word.
            word2id: dict mapping word -> integer id (inverse of `id2word`).
            counts: dict mapping word -> frequency count (same keys as `word2id`).
        """
        assert len(id2word) == len(word2id) == len(counts)
        self.id2word = id2word
        self.word2id = word2id
        self.counts = counts
        self.bos_index = word2id[BOS_WORD]
        self.eos_index = word2id[EOS_WORD]
        self.pad_index = word2id[PAD_WORD]
        self.unk_index = word2id[UNK_WORD]
        self.check_valid()

    def __len__(self):
        """
        Returns the number of words in the dictionary.
        """
        return len(self.id2word)

    def __getitem__(self, i):
        """
        Returns the word of the specified index.
        """
        return self.id2word[i]

    def __contains__(self, w):
        """
        Returns whether a word is in the dictionary.
        """
        return w in self.word2id

    def __eq__(self, y):
        """
        Compare this dictionary with another one (same words at the same ids).
        """
        self.check_valid()
        y.check_valid()
        if len(self.id2word) != len(y):
            return False
        return all(self.id2word[i] == y[i] for i in range(len(y)))

    def check_valid(self):
        """
        Check that the dictionary is valid: reserved/special ids, consistent
        mappings, and non-increasing counts for regular words.
        """
        assert self.bos_index == 0
        assert self.eos_index == 1
        assert self.pad_index == 2
        assert self.unk_index == 3
        assert all(self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS))
        assert len(self.id2word) == len(self.word2id) == len(self.counts)
        assert set(self.word2id.keys()) == set(self.counts.keys())
        for i in range(len(self.id2word)):
            assert self.word2id[self.id2word[i]] == i
        last_count = 1e18
        # NOTE(review): the `- 1` bound means the last word's count is never
        # compared with the previous one; kept as-is so existing vocabularies
        # keep validating.
        for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1):
            count = self.counts[self.id2word[i]]
            assert count <= last_count
            last_count = count

    def index(self, word, no_unk=False):
        """
        Returns the index of the specified word.

        Unknown words map to `unk_index`, unless `no_unk` is True, in which
        case a KeyError is raised.
        """
        if no_unk:
            return self.word2id[word]
        else:
            return self.word2id.get(word, self.unk_index)

    def max_vocab(self, max_vocab):
        """
        Limit the vocabulary size (keep ids < max_vocab).
        """
        assert max_vocab >= 1
        init_size = len(self)
        self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
        self.word2id = {v: k for k, v in self.id2word.items()}
        self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
        self.check_valid()
        logger.info("Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)."
                    % (max_vocab, init_size, len(self), init_size - len(self)))

    def min_count(self, min_count):
        """
        Threshold on the word frequency counts (reserved/special tokens are always kept).
        """
        assert min_count >= 0
        init_size = len(self)
        self.id2word = {k: v for k, v in self.id2word.items() if self.counts[v] >= min_count or k < 4 + SPECIAL_WORDS}
        self.word2id = {v: k for k, v in self.id2word.items()}
        self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
        self.check_valid()
        logger.info("Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)."
                    % (min_count, init_size, len(self), init_size - len(self)))

    @staticmethod
    def read_vocab(vocab_path):
        """
        Create a dictionary from a vocabulary file.

        Each valid line is `<word> <count>`. These lines are skipped:
        - lines containing U+2028,
        - lines that do not split into exactly two fields,
        - duplicate words,
        - lines whose count is not made of digits.
        Kept words receive consecutive ids after the reserved/special tokens,
        in file order.
        """
        skipped = 0
        assert os.path.isfile(vocab_path), vocab_path
        word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3}
        for i in range(SPECIAL_WORDS):
            word2id[SPECIAL_WORD % i] = 4 + i
        counts = {k: 0 for k in word2id.keys()}
        with open(vocab_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                # U+2028 is whitespace for str.split() and would break the 2-field format
                if '\u2028' in line:
                    skipped += 1
                    continue
                line = line.rstrip().split()
                if len(line) != 2:
                    skipped += 1
                    continue
                if line[0] in word2id:
                    skipped += 1
                    print('%s already in vocab' % line[0])
                    continue
                # Non-numeric counts are skipped. An `assert line[1].isdigit()` used
                # to precede this check, which made the branch unreachable.
                if not line[1].isdigit():
                    skipped += 1
                    print('Empty word at line %s with count %s' % (i, line))
                    continue
                word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped  # shift because of extra words
                counts[line[0]] = int(line[1])
        id2word = {v: k for k, v in word2id.items()}
        dico = Dictionary(id2word, word2id, counts)
        logger.info("Read %i words from the vocabulary file." % len(dico))
        if skipped > 0:
            logger.warning("Skipped %i empty lines!" % skipped)
        return dico

    @staticmethod
    def index_data(path, bin_path, dico):
        """
        Index sentences with a dictionary.

        Every line of `path` becomes one sentence: its token ids followed by
        EOS. Empty lines are reported but kept (as a lone EOS), so sentence k
        always corresponds to line k. If `bin_path` exists it is loaded
        instead. Otherwise, when `bin_path` is not None, the result is saved there.

        Returns:
            dict with keys 'dico', 'positions' ((n, 2) int64 [start, end)
            spans), 'sentences' (flat uint16/int32 token ids) and 'unk_words'
            (unknown word -> occurrence count).
        """
        import inspect

        if bin_path is not None and os.path.isfile(bin_path):
            print("Loading data from %s ..." % bin_path)
            # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
            # PyTorch >= 2.6 defaults to weights_only=True, which cannot rebuild the
            # pickled Dictionary, so opt out explicitly when the argument exists.
            if 'weights_only' in inspect.signature(torch.load).parameters:
                data = torch.load(bin_path, weights_only=False)
            else:
                data = torch.load(bin_path)
            assert dico == data['dico']
            return data

        positions = []
        sentences = []
        unk_words = {}
        n_reserved = 4 + SPECIAL_WORDS

        # index sentences
        with open(path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i % 1000000 == 0 and i > 0:
                    print(i)
                s = line.rstrip().split()
                # empty sentences are reported, not skipped (keeps line alignment)
                if len(s) == 0:
                    print("Empty sentence in line %i." % i)
                # index sentence words
                indexed = []
                for w in s:
                    word_id = dico.index(w, no_unk=False)
                    # drop just this word if it maps to a reserved/special id other than UNK
                    if 0 <= word_id < n_reserved and word_id != dico.unk_index:
                        logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id))
                        continue
                    assert word_id >= 0
                    indexed.append(word_id)
                    if word_id == dico.unk_index:
                        unk_words[w] = unk_words.get(w, 0) + 1
                # add sentence
                positions.append([len(sentences), len(sentences) + len(indexed)])
                sentences.extend(indexed)
                sentences.append(dico.eos_index)

        # tensorize data
        positions = np.int64(positions)
        if len(dico) < 1 << 16:
            sentences = np.uint16(sentences)
        elif len(dico) < 1 << 31:
            sentences = np.int32(sentences)
        else:
            raise Exception("Dictionary is too big.")
        assert sentences.min() >= 0
        data = {
            'dico': dico,
            'positions': positions,
            'sentences': sentences,
            'unk_words': unk_words,
        }
        if bin_path is not None:
            print("Saving the data to %s ..." % bin_path)
            torch.save(data, bin_path, pickle_protocol=4)

        return data

In [None]:
from logging import getLogger
import os
import numpy as np
import torch




# Module-level logger; getLogger() with no name returns the root logger.
logger = getLogger()


def process_binarized(data, params):
    """
    Process a binarized dataset and log main statistics.

    Optionally restricts the vocabulary (`params.max_vocab`, `params.min_count`)
    and remaps ids that fall out of the dictionary to UNK in
    `data['sentences']` (in place). int32 sentences are downcast to uint16
    when the dictionary fits.

    Args:
        data: dict with 'dico', 'sentences', 'positions' and 'unk_words' entries.
        params: object exposing `max_vocab` (-1 = no limit) and `min_count`.

    Returns:
        The same `data` dict, updated in place.
    """
    dico = data['dico']
    assert ((data['sentences'].dtype == np.uint16) and (len(dico) < 1 << 16) or
            (data['sentences'].dtype == np.int32) and (1 << 16 <= len(dico) < 1 << 31))

    # word tokens = all tokens minus the one EOS stored per sentence
    n_words = len(data['sentences']) - len(data['positions'])

    def _percent(count):
        # share of word tokens in %; 0 for an empty dataset instead of dividing by zero
        return 100. * count / n_words if n_words > 0 else 0.

    n_unk = sum(data['unk_words'].values())
    logger.info("%i words (%i unique) in %i sentences. %i unknown words (%i unique) covering %.2f%% of the data." % (
        n_words,
        len(dico), len(data['positions']),
        n_unk, len(data['unk_words']),
        _percent(n_unk)
    ))
    if params.max_vocab != -1:
        assert params.max_vocab > 0
        logger.info("Selecting %i most frequent words ..." % params.max_vocab)
        dico.max_vocab(params.max_vocab)
        # ids >= max_vocab were just removed from the dictionary -> map them to UNK
        data['sentences'][data['sentences'] >= params.max_vocab] = dico.index(UNK_WORD)
        unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum()
        logger.info("Now %i unknown words covering %.2f%% of the data."
                    % (unk_count, _percent(unk_count)))
    if params.min_count > 0:
        logger.info("Selecting words with >= %i occurrences ..." % params.min_count)
        dico.min_count(params.min_count)
        # assumes min_count keeps a contiguous id prefix (counts are expected to be
        # non-increasing by id), so every removed id is >= the new len(dico)
        data['sentences'][data['sentences'] >= len(dico)] = dico.index(UNK_WORD)
        unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum()
        logger.info("Now %i unknown words covering %.2f%% of the data."
                    % (unk_count, _percent(unk_count)))
    if (data['sentences'].dtype == np.int32) and (len(dico) < 1 << 16):
        logger.info("Less than 65536 words. Moving data from int32 to uint16 ...")
        data['sentences'] = data['sentences'].astype(np.uint16)
    return data


def load_binarized(path, params):
    """
    Load a binarized dataset (.pth) and post-process it with `process_binarized`.

    With `params.debug_train`, 'train' in the path is replaced by 'valid'.
    When `params.multi_gpu` is set and a per-rank file
    `<path without .pth>.<local_rank>.pth` exists, that file is loaded
    instead; this requires `params.split_data` to be False.
    """
    import inspect

    assert path.endswith('.pth')
    if params.debug_train:
        path = path.replace('train', 'valid')
    if getattr(params, 'multi_gpu', False):
        split_path = '%s.%i.pth' % (path[:-4], params.local_rank)
        if os.path.isfile(split_path):
            assert params.split_data is False
            path = split_path
    assert os.path.isfile(path), path
    logger.info("Loading data from %s ..." % path)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    # PyTorch >= 2.6 defaults to weights_only=True, which cannot rebuild the
    # pickled Dictionary stored under 'dico', so opt out explicitly when the
    # argument exists (it was added in PyTorch 1.13).
    if 'weights_only' in inspect.signature(torch.load).parameters:
        data = torch.load(path, weights_only=False)
    else:
        data = torch.load(path)
    data = process_binarized(data, params)
    return data


def set_dico_parameters(params, data, dico):
    """
    Record `dico` in `data` and mirror its size and special indices onto `params`.

    The first call stores n_words and bos/eos/pad/unk/mask indices on `params`.
    Later calls (detected via an existing `params.bos_index`) assert that the
    new dictionary agrees with what was stored.
    """
    if 'dico' not in data:
        data['dico'] = dico
    else:
        assert data['dico'] == dico

    values = {
        'n_words': len(dico),
        'bos_index': dico.index(BOS_WORD),
        'eos_index': dico.index(EOS_WORD),
        'pad_index': dico.index(PAD_WORD),
        'unk_index': dico.index(UNK_WORD),
        'mask_index': dico.index(MASK_WORD),
    }
    already_set = hasattr(params, 'bos_index')
    for attr, value in values.items():
        if already_set:
            assert getattr(params, attr) == value
        else:
            setattr(params, attr, value)


def load_mono_data(params, data):
    """
    Load monolingual data.

    For each language in `params.mono_dataset` and each split, this fills
    `data['mono_stream'][lang][split]` with a StreamDataset and, when the language
    is used for denoising auto-encoding (`params.ae_steps`) or as a back-translation
    source (`params.bt_src_langs`), also `data['mono'][lang][split]` with a batched
    Dataset. The training split is skipped when `params.eval_only` is set.
    """
    data['mono'] = {}
    data['mono_stream'] = {}

    for lang in params.mono_dataset:

        logger.info('============ Monolingual data (%s)' % lang)

        assert lang in params.langs and lang not in data['mono']
        batched_sets = data['mono'][lang] = {}
        stream_sets = data['mono_stream'][lang] = {}

        for splt in ('train', 'valid', 'test'):
            is_train = splt == 'train'

            # training data is never needed when only evaluating
            if is_train and params.eval_only:
                continue

            # load the split, then check / register its dictionary
            mono_data = load_binarized(params.mono_dataset[lang][splt], params)
            set_dico_parameters(params, data, mono_data['dico'])

            # stream dataset (batch size 1 outside of training)
            stream = StreamDataset(mono_data['sentences'], mono_data['positions'],
                                   params.batch_size if is_train else 1, params)
            stream_sets[splt] = stream

            # several processes on the same machine: each one keeps its own slice of batches
            if is_train and params.split_data and 1 < params.n_gpu_per_node <= stream.n_batches:
                per_gpu = stream.n_batches // params.n_gpu_per_node
                start = per_gpu * params.local_rank
                stream.select_data(start, start + per_gpu)

            # denoising auto-encoding and online back-translation need a non-stream (batched) dataset
            if lang in params.ae_steps or lang in params.bt_src_langs:
                batched = Dataset(mono_data['sentences'], mono_data['positions'], params)

                # drop empty and too long sentences (applied to every split, not only train)
                batched.remove_empty_sentences()
                batched.remove_long_sentences(params.max_len)

                # several processes on the same machine: each one keeps its own slice of sentences
                if is_train and params.n_gpu_per_node > 1 and params.split_data:
                    per_gpu = len(batched) // params.n_gpu_per_node
                    start = per_gpu * params.local_rank
                    batched.select_data(start, start + per_gpu)

                batched_sets[splt] = batched

            logger.info("")

    logger.info("")


def load_para_data(params, data):
    """
    Load parallel data.

    For each language pair in `params.para_dataset` and each split, the source and
    target binarized files are loaded into a ParallelDataset stored in
    `data['para'][(src, tgt)][split]`; empty and too long sentence pairs are removed.
    The training split is skipped when `params.eval_only` is set, or when no
    training step (CLM, MLM, PC, MT, S2SLM) uses the pair in either direction.
    """
    data['para'] = {}

    # pairs whose parallel *training* split is consumed by a training step. This must match
    # the set built in check_data_params (which includes the S2SLM steps): otherwise train
    # files verified there would silently never be loaded here.
    required_para_train = set(
        params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps
        + getattr(params, 's2slm_steps', [])
    )

    for src, tgt in params.para_dataset.keys():

        logger.info('============ Parallel data (%s-%s)' % (src, tgt))

        assert (src, tgt) not in data['para']
        data['para'][(src, tgt)] = {}

        for splt in ['train', 'valid', 'test']:

            # no need to load training data for evaluation
            if splt == 'train' and params.eval_only:
                continue

            # only load training data for pairs used by a training step
            # (e.g. back-translation pairs are only needed for valid / test evaluation)
            if splt == 'train' and (src, tgt) not in required_para_train and (tgt, src) not in required_para_train:
                continue

            # load binarized datasets
            src_path, tgt_path = params.para_dataset[(src, tgt)][splt]
            src_data = load_binarized(src_path, params)
            tgt_data = load_binarized(tgt_path, params)

            # update dictionary parameters
            set_dico_parameters(params, data, src_data['dico'])
            set_dico_parameters(params, data, tgt_data['dico'])

            # create ParallelDataset
            dataset = ParallelDataset(
                src_data['sentences'], src_data['positions'],
                tgt_data['sentences'], tgt_data['positions'],
                params
            )

            # remove empty and too long sentences (applied to every split, not only train)
            dataset.remove_empty_sentences()
            dataset.remove_long_sentences(params.max_len)

            # for validation and test set, enumerate sentence per sentence
            if splt != 'train':
                dataset.tokens_per_batch = -1

            # if there are several processes on the same machine, we can split the dataset
            if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data:
                n_sent = len(dataset) // params.n_gpu_per_node
                a = n_sent * params.local_rank
                b = a + n_sent
                dataset.select_data(a, b)

            data['para'][(src, tgt)][splt] = dataset
            logger.info("")

    logger.info("")


def _parse_lm_steps(steps, langs):
    """
    Parse a comma-separated list of language-model steps (e.g. 'en,fr,en-fr').

    A single language `l1` becomes `(l1, None)` (monolingual step) and a pair `l1-l2`
    becomes `(l1, l2)` (parallel step). Every language must belong to `langs` and the
    steps must be unique.
    """
    parsed = [s.split('-') for s in steps.split(',') if len(s) > 0]
    parsed = [(s[0], None) if len(s) == 1 else tuple(s) for s in parsed]
    assert all((l1 in langs) and (l2 in langs or l2 is None) for l1, l2 in parsed)
    assert len(parsed) == len(set(parsed))
    return parsed


def check_data_params(params):
    """
    Check datasets parameters.

    Parses the step specifications of `params` (comma-separated strings) into lists of
    language tuples, sets `params.langs`, `params.id2lang`, `params.lang2id`,
    `params.n_langs`, `params.bt_src_langs`, `params.mono_dataset` and
    `params.para_dataset`, and asserts that every required binarized file exists
    under `params.data_path` (missing files are logged first).
    """
    # data path
    assert os.path.isdir(params.data_path), params.data_path

    # check languages; ids follow alphabetical order, independently of the order in `lgs`
    params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en']
    assert len(params.langs) == len(set(params.langs)) >= 1
    params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))}
    params.lang2id = {k: v for v, k in params.id2lang.items()}
    params.n_langs = len(params.langs)

    # CLM / MLM-TLM / S2SLM steps: 'l1' (monolingual) or 'l1-l2' (parallel)
    params.clm_steps = _parse_lm_steps(params.clm_steps, params.langs)
    params.mlm_steps = _parse_lm_steps(params.mlm_steps, params.langs)
    params.s2slm_steps = _parse_lm_steps(params.s2slm_steps, params.langs)

    # parallel classification steps
    params.pc_steps = [tuple(s.split('-')) for s in params.pc_steps.split(',') if len(s) > 0]
    assert all(len(x) == 2 for x in params.pc_steps)
    assert all(l1 in params.langs and l2 in params.langs for l1, l2 in params.pc_steps)
    assert all(l1 != l2 for l1, l2 in params.pc_steps)
    assert len(params.pc_steps) == len(set(params.pc_steps))

    # machine translation steps (the encoder_only restriction is deliberately not enforced)
    params.mt_steps = [tuple(s.split('-')) for s in params.mt_steps.split(',') if len(s) > 0]
    assert all(len(x) == 2 for x in params.mt_steps)
    assert all(l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps)
    assert all(l1 != l2 for l1, l2 in params.mt_steps)
    assert len(params.mt_steps) == len(set(params.mt_steps))

    # denoising auto-encoder steps (the encoder_only restriction is deliberately not enforced)
    params.ae_steps = [s for s in params.ae_steps.split(',') if len(s) > 0]
    assert all(lang in params.langs for lang in params.ae_steps)
    assert len(params.ae_steps) == len(set(params.ae_steps))

    # back-translation steps 'l1-l2-l1'
    params.bt_steps = [tuple(s.split('-')) for s in params.bt_steps.split(',') if len(s) > 0]
    assert all(len(x) == 3 for x in params.bt_steps)
    assert all(l1 in params.langs and l2 in params.langs and l3 in params.langs for l1, l2, l3 in params.bt_steps)
    assert all(l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps)
    assert len(params.bt_steps) == len(set(params.bt_steps))
    assert len(params.bt_steps) == 0 or not params.encoder_only
    params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]

    # check monolingual datasets
    required_mono = set(
        [l1 for l1, l2 in (params.mlm_steps + params.clm_steps + params.s2slm_steps) if l2 is None]
        + params.ae_steps + params.bt_src_langs
    )
    # NOTE: monolingual files are named '<lang>.<split>.pth' (language first)
    params.mono_dataset = {
        lang: {
            splt: os.path.join(params.data_path, '%s.%s.pth' % (lang, splt))
            for splt in ['train', 'valid', 'test']
        } for lang in params.langs if lang in required_mono
    }
    for paths in params.mono_dataset.values():
        for p in paths.values():
            if not os.path.isfile(p):
                logger.error(f"{p} not found")
    assert all(all(os.path.isfile(p) for p in paths.values()) for paths in params.mono_dataset.values()), params.mono_dataset

    # check parallel datasets ('<split>.<src>-<tgt>.<lang>.pth', with src < tgt)
    required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps + params.s2slm_steps)
    required_para = required_para_train | set([(l2, l3) for _, l2, l3 in params.bt_steps])
    params.para_dataset = {
        (src, tgt): {
            splt: (os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)),
                   os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt)))
            for splt in ['train', 'valid', 'test']
            if splt != 'train' or (src, tgt) in required_para_train or (tgt, src) in required_para_train
        } for src in params.langs for tgt in params.langs
        if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para)
    }
    for paths in params.para_dataset.values():
        for p1, p2 in paths.values():
            if not os.path.isfile(p1):
                logger.error(f"{p1} not found")
            if not os.path.isfile(p2):
                logger.error(f"{p2} not found")
    assert all(all(os.path.isfile(p1) and os.path.isfile(p2) for p1, p2 in paths.values()) for paths in params.para_dataset.values())

    # check that we can evaluate on BLEU
    assert params.eval_bleu is False or len(params.mt_steps + params.bt_steps) > 0


def load_data(params):
    """
    Load the monolingual and parallel datasets and log a per-split summary.

    The returned dictionary contains:
        - 'dico': the shared dictionary (registered while loading datasets)
        - 'mono_stream': {lang: {split: StreamDataset}}
        - 'mono': {lang: {split: Dataset}} (languages needing batched data only)
        - 'para': {(src, tgt): {split: ParallelDataset}}
    """
    data = {}
    load_mono_data(params, data)   # fills data['mono'] and data['mono_stream']
    load_para_data(params, data)   # fills data['para']

    row_format = '{: <18} - {: >5} - {: >12}:{: >10}'
    logger.info('============ Data summary')
    for lang, per_split in data['mono_stream'].items():
        for split_name, dataset in per_split.items():
            logger.info(row_format.format('Monolingual data', split_name, lang, len(dataset)))
    for (src, tgt), per_split in data['para'].items():
        for split_name, dataset in per_split.items():
            logger.info(row_format.format('Parallel data', split_name, '%s-%s' % (src, tgt), len(dataset)))

    logger.info("")
    return data

In [None]:
import math
import itertools
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F



class HashingMemory(nn.Module):
    """
    Base class of the hashing / product-key memory layer.

    An input of dimension `input_dim` is projected by a query network into `heads`
    queries of dimension `k_dim`; for each query the `knn` best keys are retrieved
    (`get_indices`), their scores are optionally re-scored, and the output is the
    score-weighted sum of the selected value vectors (dimension `v_dim`, which is
    `output_dim` unless `params.mem_v_dim > 0`). Subclasses implement `init_keys`
    and `get_indices`; `build` picks the subclass from `params.mem_implementation`.
    """

    MEM_VALUES_PARAMS = '.values.weight'  # name suffix of the value table parameter (self.values.weight)
    VALUES = None                         # value weights shared by all memories when mem_share_values is set
    EVAL_MEMORY = True                    # store the last indices / scores in eval mode (usage statistics)
    _ids = itertools.count(0)             # per-instance id generator

    def __init__(self, input_dim, output_dim, params):
        """
        Args:
            input_dim: input feature dimension (channel dimension for 2D inputs).
            output_dim: used as value dimension when `params.mem_v_dim <= 0`.
            params: memory parameters, already processed by `HashingMemory.check_params`
                (which sets `mem_product_quantization`, `n_indices` and `mem_size`).
        """
        super().__init__()
        self.id = next(self._ids)

        # global parameters
        self.input2d = params.mem_input2d
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.size = params.mem_size
        self.modulo_size = params.mem_modulo_size
        self.n_indices = params.n_indices
        self.k_dim = params.mem_k_dim
        self.v_dim = params.mem_v_dim if params.mem_v_dim > 0 else output_dim
        self.heads = params.mem_heads
        self.knn = params.mem_knn
        self.shuffle_indices = params.mem_shuffle_indices
        self.keys_normalized_init = params.mem_keys_normalized_init
        self.product_quantization = params.mem_product_quantization
        # either no modulo (one value per index) or indices folded onto a smaller value table
        assert self.modulo_size == -1 and self.size == self.n_indices or self.n_indices > self.size == self.modulo_size >= 1

        # keys / queries
        self.keys_type = params.mem_keys_type
        self.learn_keys = params.mem_keys_learn
        self.use_different_keys = params.mem_use_different_keys
        self.query_detach_input = params.mem_query_detach_input
        self.query_net_learn = params.mem_query_net_learn
        self.multi_query_net = params.mem_multi_query_net
        self.shuffle_query = params.mem_shuffle_query
        assert self.use_different_keys is False or self.keys_type in ['gaussian', 'uniform']
        assert self.use_different_keys is False or self.heads >= 2 or self.product_quantization
        assert self.multi_query_net is False or self.heads >= 2 or self.product_quantization
        assert self.shuffle_query is False or self.heads > 1 and params.mem_query_layer_sizes == ''
        assert self.shuffle_query is False or self.input_dim % (2 ** self.heads) == 0

        # scoring / re-scoring
        self.normalize_query = params.mem_normalize_query
        self.temperature = params.mem_temperature
        self.score_softmax = params.mem_score_softmax
        self.score_subtract = params.mem_score_subtract
        self.score_normalize = params.mem_score_normalize
        assert self.score_subtract in ['', 'min', 'mean', 'median']
        assert self.score_subtract == '' or self.knn >= 2
        assert not (self.score_normalize and self.score_softmax and self.score_subtract == '')

        # dropout
        self.input_dropout = params.mem_input_dropout
        self.query_dropout = params.mem_query_dropout
        self.value_dropout = params.mem_value_dropout

        # initialize keys (subclass-specific)
        self.init_keys()

        # values: a weighted sum of the selected rows is computed in a single EmbeddingBag call
        self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.mem_sparse)

        # optionally use the same values for all memories
        if params.mem_share_values:
            if HashingMemory.VALUES is None:
                HashingMemory.VALUES = self.values.weight
            else:
                self.values.weight = HashingMemory.VALUES

        # values initialization
        # NOTE(review): with shared values, this re-initializes the shared tensor for every new memory
        if params.mem_value_zero_init:
            nn.init.zeros_(self.values.weight)
        else:
            nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)

        # no query network
        if len(params.mem_query_layer_sizes) == 0:
            assert self.heads == 1 or self.use_different_keys or self.shuffle_query
            assert self.input_dim == self.k_dim
            self.query_proj = QueryIdentity(self.input_dim, self.heads, self.shuffle_query)

        # query network
        if len(params.mem_query_layer_sizes) > 0:
            assert not self.shuffle_query

            # layer sizes / number of features (first and last sizes are placeholders set here)
            l_sizes = list(params.mem_query_layer_sizes)
            assert len(l_sizes) >= 2 and l_sizes[0] == l_sizes[-1] == 0
            l_sizes[0] = self.input_dim
            l_sizes[-1] = (self.k_dim // 2) if self.multi_query_net else (self.heads * self.k_dim)

            # convolutional or feedforward
            if self.input2d:
                self.query_proj = QueryConv(
                    self.input_dim, self.heads, self.k_dim, self.product_quantization,
                    self.multi_query_net, l_sizes, params.mem_query_kernel_sizes,
                    bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv
                )
            else:
                assert params.mem_query_kernel_sizes == ''
                assert not params.mem_query_residual
                self.query_proj = QueryMLP(
                    self.input_dim, self.heads, self.k_dim, self.product_quantization,
                    self.multi_query_net, l_sizes,
                    bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv
                )

        # shuffle indices for different heads: one fixed permutation of the indices per head
        if self.shuffle_indices:
            head_permutations = [torch.randperm(self.n_indices).unsqueeze(0) for i in range(self.heads)]
            self.register_buffer('head_permutations', torch.cat(head_permutations, 0))

        # do not learn the query network
        if self.query_net_learn is False:
            for p in self.query_proj.parameters():
                p.requires_grad = False

    def forward(self, input):
        """
        Read from the memory.

        Args:
            input: tensor whose last dimension is `input_dim`, or, when `input2d` is set,
                a 4D tensor (n_images, input_dim, height, width).

        Returns:
            (bs, v_dim) for a 2D input, input.shape[:-1] + (v_dim,) for inputs of rank >= 3,
            or (n_images, v_dim, height, width) when `input2d` is set.
        """
        # detach input
        if self.query_detach_input:
            input = input.detach()

        # input dimensions
        if self.input2d:
            assert input.shape[1] == self.input_dim
            n_images, _, height, width = input.shape
            prefix_shape = (n_images, width, height)
        else:
            assert input.shape[-1] == self.input_dim
            prefix_shape = input.shape[:-1]

        # compute query / store it
        # int(): np.prod(()) is the float 1.0, which `.view` below would reject
        bs = int(np.prod(prefix_shape))
        input = F.dropout(input, p=self.input_dropout, training=self.training)    # input shape
        query = self.query_proj(input)                                            # (bs * heads, k_dim)
        query = F.dropout(query, p=self.query_dropout, training=self.training)    # (bs * heads, k_dim)
        assert query.shape == (bs * self.heads, self.k_dim)

        # get indices
        scores, indices = self.get_indices(query, self.knn)                       # (bs * heads, knn) ** 2

        # optionally shuffle indices for different heads
        if self.shuffle_indices:
            indices = indices.view(bs, self.heads, -1).chunk(self.heads, 1)
            indices = [p[idx] for p, idx in zip(self.head_permutations, indices)]
            indices = torch.cat(indices, 1).view(bs * self.heads, -1)

        # take indices modulo the memory size
        if self.modulo_size != -1:
            indices = indices % self.modulo_size

        # re-scoring
        if self.temperature != 1:
            scores = scores / self.temperature                                    # (bs * heads, knn)
        if self.score_softmax:
            scores = F.softmax(scores.float(), dim=-1).type_as(scores)            # (bs * heads, knn)
        if self.score_subtract != '':
            if self.score_subtract == 'min':
                to_sub = scores.min(1, keepdim=True)[0]                           # (bs * heads, 1)
            if self.score_subtract == 'mean':
                to_sub = scores.mean(1, keepdim=True)                             # (bs * heads, 1)
            if self.score_subtract == 'median':
                to_sub = scores.median(1, keepdim=True)[0]                        # (bs * heads, 1)
            scores = scores - to_sub                                              # (bs * heads, knn)
        if self.score_normalize:
            scores = scores / scores.norm(p=1, dim=1, keepdim=True)               # (bs * heads, knn)

        # merge heads / knn (since we sum heads)
        indices = indices.view(bs, self.heads * self.knn)                         # (bs, heads * knn)
        scores = scores.view(bs, self.heads * self.knn)                           # (bs, heads * knn)

        # weighted sum of values (EmbeddingBag in 'sum' mode with per-sample weights)
        output = self.values(
            indices,
            per_sample_weights=scores.to(self.values.weight.data)
        ).to(scores)                                                              # (bs, v_dim)
        output = F.dropout(output, p=self.value_dropout, training=self.training)  # (bs, v_dim)

        # reshape output
        if self.input2d:
            output = output.view(n_images, width, height, self.v_dim)             # (n_images, width, height, v_dim)
            output = output.transpose(1, 3)                                       # (n_images, v_dim, height, width)
        else:
            if len(prefix_shape) >= 2:
                output = output.view(prefix_shape + (self.v_dim,))                # (..., v_dim)

        # store indices / scores (eval mode only - for usage statistics)
        if not self.training and HashingMemory.EVAL_MEMORY:
            self.last_indices = indices.view(bs, self.heads, self.knn).detach().cpu()
            self.last_scores = scores.view(bs, self.heads, self.knn).detach().cpu().float()

        return output

    def init_keys(self):
        """Create the memory keys (implemented by subclasses)."""
        raise NotImplementedError("Not implemented!")

    def _get_indices(self, query, knn, keys):
        """Score `query` against `keys` and return the `knn` best (implemented by subclasses)."""
        raise NotImplementedError("Not implemented!")

    def get_indices(self, query, knn):
        """Return (scores, indices) of the `knn` best keys per query (implemented by subclasses)."""
        raise NotImplementedError("Not implemented!")

    @staticmethod
    def register_args(parser):
        """
        Register memory parameters
        """
        # memory implementation
        parser.add_argument("--mem_implementation", type=str, default="pq_fast",
                            help="Memory implementation (flat, pq_default, pq_fast)")

        # optimization
        parser.add_argument("--mem_grouped_conv", type=bool_flag, default=False,
                            help="Use grouped convolutions in the query network")
        parser.add_argument("--mem_values_optimizer", type=str, default="",
                            help="Memory values optimizer ('' for the same optimizer as the rest of the model)")
        parser.add_argument("--mem_sparse", type=bool_flag, default=False,
                            help="Perform sparse updates for the values")

        # global parameters
        parser.add_argument("--mem_input2d", type=bool_flag, default=False,
                            help="Convolutional query network")
        parser.add_argument("--mem_k_dim", type=int, default=16,
                            help="Memory keys dimension")
        parser.add_argument("--mem_v_dim", type=int, default=-1,
                            help="Memory values dimension (-1 for automatic output dimension)")
        parser.add_argument("--mem_heads", type=int, default=1,
                            help="Number of memory reading heads")
        parser.add_argument("--mem_knn", type=int, default=10,
                            help="Number of memory slots to read / update - k-NN to the query")
        parser.add_argument("--mem_share_values", type=bool_flag, default=False,
                            help="Share values across memories")
        parser.add_argument("--mem_shuffle_indices", type=bool_flag, default=False,
                            help="Shuffle indices for different heads")
        parser.add_argument("--mem_shuffle_query", type=bool_flag, default=False,
                            help="Shuffle query dimensions (when the query network is the identity and there are multiple heads)")
        parser.add_argument("--mem_modulo_size", type=int, default=-1,
                            help="Effective memory size: indices are taken modulo this parameter. -1 to disable.")

        # keys
        parser.add_argument("--mem_keys_type", type=str, default="uniform",
                            help="Memory keys type (binary,gaussian,uniform)")
        parser.add_argument("--mem_n_keys", type=int, default=512,
                            help="Number of keys")
        parser.add_argument("--mem_keys_normalized_init", type=bool_flag, default=False,
                            help="Normalize keys at initialization")
        parser.add_argument("--mem_keys_learn", type=bool_flag, default=False,
                            help="Learn keys")
        parser.add_argument("--mem_use_different_keys", type=bool_flag, default=False,
                            help="Use different keys for each head / product quantization")

        # queries
        parser.add_argument("--mem_query_detach_input", type=bool_flag, default=False,
                            help="Detach input")
        parser.add_argument("--mem_query_layer_sizes", type=str, default="",
                            help="Query MLP layer sizes ('', '0', '0,512,0')")
        parser.add_argument("--mem_query_kernel_sizes", type=str, default="",
                            help="Query MLP kernel sizes (2D inputs only)")
        parser.add_argument("--mem_query_bias", type=bool_flag, default=True,
                            help="Query MLP bias")
        parser.add_argument("--mem_query_batchnorm", type=bool_flag, default=True,
                            help="Query MLP batch norm")
        parser.add_argument("--mem_query_net_learn", type=bool_flag, default=True,
                            help="Query MLP learn")
        parser.add_argument("--mem_query_residual", type=bool_flag, default=False,
                            help="Use a bottleneck with a residual layer in the query MLP")
        parser.add_argument("--mem_multi_query_net", type=bool_flag, default=False,
                            help="Use multiple query MLP (one for each head)")

        # values initialization
        parser.add_argument("--mem_value_zero_init", type=bool_flag, default=False,
                            help="Initialize values with zeros")

        # scoring
        parser.add_argument("--mem_normalize_query", type=bool_flag, default=False,
                            help="Normalize queries")
        parser.add_argument("--mem_temperature", type=float, default=1,
                            help="Divide scores by a temperature")
        parser.add_argument("--mem_score_softmax", type=bool_flag, default=True,
                            help="Apply softmax on scores")
        parser.add_argument("--mem_score_subtract", type=str, default="",
                            help="Subtract scores ('', min, mean, median)")
        parser.add_argument("--mem_score_normalize", type=bool_flag, default=False,
                            help="L1 normalization of the scores")

        # dropout
        parser.add_argument("--mem_input_dropout", type=float, default=0,
                            help="Input dropout")
        parser.add_argument("--mem_query_dropout", type=float, default=0,
                            help="Query dropout")
        parser.add_argument("--mem_value_dropout", type=float, default=0,
                            help="Value dropout")

    @staticmethod
    def build(input_dim, output_dim, params):
        """
        Instantiate the memory implementation selected by `params.mem_implementation`.
        """
        if params.mem_implementation == 'flat':
            M = HashingMemoryFlat
        elif params.mem_implementation == 'pq_default':
            M = HashingMemoryProduct
        elif params.mem_implementation == 'pq_fast':
            M = HashingMemoryProductFast
        else:
            raise Exception("Unknown memory implementation!")
        return M(input_dim, output_dim, params)

    @staticmethod
    def check_params(params):
        """
        Check and initialize memory parameters.

        Sets `mem_product_quantization`, `n_indices` and `mem_size`, resolves
        `mem_values_optimizer`, and parses `mem_query_layer_sizes` /
        `mem_query_kernel_sizes` from comma-separated strings into int lists.
        """
        # memory
        assert params.mem_implementation in ['flat', 'pq_default', 'pq_fast']
        params.mem_product_quantization = params.mem_implementation != 'flat'

        # optimization
        assert params.mem_grouped_conv is False or params.mem_multi_query_net
        params.mem_values_optimizer = params.optimizer if params.mem_values_optimizer == '' else params.mem_values_optimizer
        params.mem_values_optimizer = params.mem_values_optimizer.replace('adam', 'sparseadam') if params.mem_sparse else params.mem_values_optimizer

        # even number of key dimensions for product quantization
        assert params.mem_k_dim >= 2
        assert params.mem_product_quantization is False or params.mem_k_dim % 2 == 0

        # memory type
        assert params.mem_keys_type in ['binary', 'gaussian', 'uniform']

        # number of indices
        if params.mem_keys_type == 'binary':
            assert params.mem_keys_normalized_init is False
            assert 1 << params.mem_k_dim == params.mem_n_keys
        if params.mem_product_quantization:
            params.n_indices = params.mem_n_keys ** 2
        else:
            params.n_indices = params.mem_n_keys

        # actual memory size
        if params.mem_modulo_size == -1:
            params.mem_size = params.n_indices
        else:
            assert 1 <= params.mem_modulo_size < params.n_indices
            params.mem_size = params.mem_modulo_size

        # different keys / different query MLP / shuffle hidden dimensions when no query network
        assert not params.mem_use_different_keys or params.mem_keys_type in ['gaussian', 'uniform']
        assert not params.mem_use_different_keys or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_query_layer_sizes not in ['', '0,0']
        assert not params.mem_shuffle_query or params.mem_heads > 1 and params.mem_query_layer_sizes == ''

        # query network
        if params.mem_query_layer_sizes == '':
            assert params.mem_heads == 1 or params.mem_use_different_keys or params.mem_shuffle_query
        else:
            s = [int(x) for x in filter(None, params.mem_query_layer_sizes.split(','))]
            assert len(s) >= 2 and s[0] == s[-1] == 0
            params.mem_query_layer_sizes = s
            assert not params.mem_query_residual or params.mem_input2d

        # convolutional query network kernel sizes
        if params.mem_query_kernel_sizes == '':
            assert not params.mem_input2d or params.mem_query_layer_sizes == ''
        else:
            assert params.mem_input2d
            s = [int(x) for x in filter(None, params.mem_query_kernel_sizes.split(','))]
            params.mem_query_kernel_sizes = s
            assert all(ks % 2 == 1 for ks in s)
            assert len(params.mem_query_kernel_sizes) == len(params.mem_query_layer_sizes) - 1 >= 1

        # scoring
        assert params.mem_score_subtract in ['', 'min', 'mean', 'median']
        assert params.mem_score_subtract == '' or params.mem_knn >= 2
        assert not (params.mem_score_normalize and params.mem_score_softmax and params.mem_score_subtract == '')

        # dropout
        assert 0 <= params.mem_input_dropout < 1
        assert 0 <= params.mem_query_dropout < 1
        assert 0 <= params.mem_value_dropout < 1


class HashingMemoryFlat(HashingMemory):
    """
    Hashing memory with a flat (non-product) key table: every memory slot owns
    a full k_dim key and retrieval is an exact top-k over all keys.
    """

    def __init__(self, input_dim, output_dim, params):
        super().__init__(input_dim, output_dim, params)
        assert self.use_different_keys is False or self.heads >= 2
        assert not self.product_quantization

    def init_keys(self):
        """
        Create the key table and register it as a parameter (learned keys)
        or as a buffer (fixed keys).

        - binary: the 2 ** k_dim sign patterns of {-1, +1} ** k_dim, scaled by 1 / sqrt(k_dim)
        - gaussian / uniform: random keys of shape (n_indices, k_dim), or
          (heads, n_indices, k_dim) with one seeded table per head when
          use_different_keys is set
        """
        assert self.keys_type in ['binary', 'gaussian', 'uniform']

        if self.keys_type == 'binary':
            n_rows = 2 ** self.k_dim
            keys = torch.FloatTensor(n_rows, self.k_dim)
            for row in range(n_rows):
                for bit in range(self.k_dim):
                    # 1 if bit `bit` of `row` is set, else 0
                    keys[row, bit] = int((1 << bit) & row > 0)
            keys *= 2
            keys -= 1
            keys /= math.sqrt(self.k_dim)

        elif self.keys_type in ['gaussian', 'uniform']:
            make_keys = get_gaussian_keys if self.keys_type == 'gaussian' else get_uniform_keys
            if not self.use_different_keys:
                keys = torch.from_numpy(make_keys(self.n_indices, self.k_dim, self.keys_normalized_init, seed=0))
            else:
                per_head = [
                    make_keys(self.n_indices, self.k_dim, self.keys_normalized_init, seed=head)
                    for head in range(self.heads)
                ]
                keys = torch.from_numpy(np.array(per_head)).view(self.heads, self.n_indices, self.k_dim)

        if not self.learn_keys:
            self.register_buffer('keys', keys)
        else:
            self.keys = nn.Parameter(keys)

    def _get_indices(self, query, knn, keys):
        """
        Score queries against a single key table and keep the best `knn`.

        query: (bs, k_dim) unnormalized queries; keys: (n_keys, k_dim), as
        required by F.linear. Returns (scores, indices), both (bs, knn),
        sorted by decreasing score.
        """
        assert query.dim() == 2 and query.size(1) == self.k_dim

        if self.normalize_query:
            query = query / query.norm(2, 1, keepdim=True).expand_as(query)    # (bs, kdim)

        dense_scores = F.linear(query, keys, bias=None)                         # (bs, n_keys)
        scores, indices = dense_scores.topk(knn, dim=1, largest=True, sorted=True)

        assert scores.shape == indices.shape == (query.shape[0], knn)
        return scores, indices

    def get_indices(self, query, knn):
        """
        Generate scores and indices given unnormalized queries.

        With use_different_keys, rows of `query` are grouped by `heads`
        consecutive rows and head h is searched in self.keys[h].
        """
        assert query.dim() == 2 and query.size(1) == self.k_dim
        if self.use_different_keys is False:
            return self._get_indices(query, knn, self.keys)

        bs = len(query)
        per_head_query = query.view(-1, self.heads, self.k_dim)
        results = [
            self._get_indices(per_head_query[:, head], knn, self.keys[head])
            for head in range(self.heads)
        ]
        scores = torch.stack([s for s, _ in results], 1).view(bs, knn)
        indices = torch.stack([i for _, i in results], 1).view(bs, knn)
        return scores, indices


class HashingMemoryProduct(HashingMemory):
    """
    Product-key memory: each query is split in two halves, each half is scored
    against its own sub-key table of n_keys = sqrt(n_indices) entries, and a
    memory slot id is i1 * n_keys + i2 for sub-key indices (i1, i2).
    """

    def __init__(self, input_dim, output_dim, params):
        super().__init__(input_dim, output_dim, params)
        assert self.k_dim % 2 == 0
        assert self.product_quantization

    def create_keys(self):
        """
        Build and return the sub-key tensor (not registered on the module).

        - binary: the 2 ** half sign patterns of {-1, +1} ** half, scaled by 1 / sqrt(k_dim)
        - gaussian / uniform: (n_keys, half), or (heads, 2, n_keys, half) with one
          seeded pair of sub-key tables per head when use_different_keys is set
        """
        assert self.keys_type in ['binary', 'gaussian', 'uniform']
        half = self.k_dim // 2
        n_keys = int(self.n_indices ** 0.5)

        # binary keys
        if self.keys_type == 'binary':
            keys = torch.FloatTensor(2 ** half, half)
            for i in range(keys.shape[0]):
                for j in range(keys.shape[1]):
                    keys[i, j] = int((1 << j) & i > 0)
            keys *= 2
            keys -= 1
            keys /= math.sqrt(self.k_dim)

        # random keys from Gaussian or uniform distributions
        if self.keys_type in ['gaussian', 'uniform']:
            init = get_gaussian_keys if self.keys_type == 'gaussian' else get_uniform_keys
            if self.use_different_keys:
                keys = torch.from_numpy(np.array([
                    init(n_keys, half, self.keys_normalized_init, seed=(2 * i + j))
                    for i in range(self.heads)
                    for j in range(2)
                ])).view(self.heads, 2, n_keys, half)
            else:
                keys = torch.from_numpy(init(n_keys, half, self.keys_normalized_init, seed=0))

        return keys

    def init_keys(self):
        """
        Initialize keys: register create_keys() as a parameter (learned keys)
        or as a buffer (fixed keys).
        """
        keys = self.create_keys()

        # learned or fixed keys
        if self.learn_keys:
            self.keys = nn.Parameter(keys)
        else:
            self.register_buffer('keys', keys)

    def _get_indices(self, query, knn, keys1, keys2):
        """
        Product-key top-k search for one set of sub-keys.

        query: (bs, k_dim) unnormalized queries; its first half is scored
        against keys1 and its second half against keys2 (each (n_keys, k_dim // 2)).
        The best `knn` sub-keys of each half are combined into knn ** 2
        candidates, the best `knn` are kept, and their scores are recomputed
        outside torch.no_grad() so they are differentiable w.r.t. query and keys.
        Returns (scores, indices), both (bs, knn); indices lie in [0, n_keys ** 2).
        """
        assert query.dim() == 2 and query.size(1) == self.k_dim
        assert len(keys1) == len(keys2)
        half = self.k_dim // 2
        n_keys = len(keys1)

        # split query for product quantization
        q1 = query[:, :half]                                                                            # (bs, half)
        q2 = query[:, half:]                                                                            # (bs, half)

        # optionally normalize queries
        if self.normalize_query:
            q1 = q1 / q1.norm(2, 1, keepdim=True).expand_as(q1)                                         # (bs, half)
            q2 = q2 / q2.norm(2, 1, keepdim=True).expand_as(q2)                                         # (bs, half)

        # compute memory value indices
        with torch.no_grad():

            # `get_knn` is the FAISS search when FAISS is available and the PyTorch
            # fallback otherwise (calling get_knn_faiss directly crashes without FAISS).
            # Column slices of `query` are not contiguous, which the FAISS pointer
            # helpers require, hence .contiguous().
            scores1, indices1 = get_knn(keys1.float(), q1.float().contiguous(), knn, distance='dot_product')  # (bs, knn) ** 2
            scores2, indices2 = get_knn(keys2.float(), q2.float().contiguous(), knn, distance='dot_product')  # (bs, knn) ** 2

            # cartesian product on best candidate keys
            concat_scores = cartesian_product(scores1, scores2)                                         # (bs, knn ** 2, 2)
            concat_indices = cartesian_product(indices1, indices2)                                      # (bs, knn ** 2, 2)

            all_scores = concat_scores.sum(2)                                                           # (bs, knn ** 2)
            all_indices = concat_indices[:, :, 0] * n_keys + concat_indices[:, :, 1]                    # (bs, knn ** 2)

            _, best_indices = torch.topk(all_scores, k=knn, dim=1, largest=True, sorted=True)           # (bs, knn)
            indices = all_indices.gather(1, best_indices)                                               # (bs, knn)

        # compute value scores - for some reason, this part is extremely slow when the keys are learned
        # Recover sub-key ids with integer floor division: `indices / n_keys` is true
        # division on current PyTorch and yields float tensors that cannot index keys1.
        indices1 = torch.div(indices, n_keys, rounding_mode='floor')
        indices2 = indices % n_keys
        scores1 = (keys1[indices1] * q1.unsqueeze(1)).sum(2)
        scores2 = (keys2[indices2] * q2.unsqueeze(1)).sum(2)
        scores = scores1 + scores2

        # return scores with indices
        assert scores.shape == indices.shape == (query.shape[0], knn)
        return scores, indices

    def get_indices(self, query, knn):
        """
        Generate scores and indices given unnormalized queries.

        With use_different_keys, rows of `query` are grouped by `heads`
        consecutive rows and head i uses sub-key tables self.keys[i][0] / self.keys[i][1].
        """
        assert query.dim() == 2 and query.size(1) == self.k_dim
        if self.use_different_keys is False:
            return self._get_indices(query, knn, self.keys, self.keys)
        else:
            bs = len(query)
            query = query.view(-1, self.heads, self.k_dim)
            outputs = [
                self._get_indices(query[:, i], knn, self.keys[i][0], self.keys[i][1])
                for i in range(self.heads)
            ]
            scores = torch.cat([s.unsqueeze(1) for s, _ in outputs], 1).view(bs, knn)
            indices = torch.cat([idx.unsqueeze(1) for _, idx in outputs], 1).view(bs, knn)
            return scores, indices


class HashingMemoryProductFast(HashingMemoryProduct):
    """
    Product-key memory whose sub-key search is a dense matmul followed by topk
    (no k-NN helper), so the whole retrieval stays inside autograd.
    """

    def __init__(self, input_dim, output_dim, params):
        super().__init__(input_dim, output_dim, params)

    def _get_indices(self, query, knn, keys1, keys2):
        """
        Generate scores and indices given keys and unnormalized queries.

        query: (bs, k_dim); its halves are scored against keys1 / keys2
        (each (n_keys, k_dim // 2)). The knn best sub-keys per half are paired
        into knn ** 2 candidates, whose ids are i1 * n_keys + i2, and the knn
        best candidates are returned as (scores, indices), both (bs, knn).
        """
        assert query.dim() == 2 and query.size(1) == self.k_dim
        assert len(keys1) == len(keys2)
        bs = query.size(0)
        half = self.k_dim // 2
        n_keys = len(keys1)

        q1, q2 = query[:, :half], query[:, half:]                                   # (bs, half) each
        if self.normalize_query:
            q1 = q1 / q1.norm(2, 1, keepdim=True).expand_as(q1)
            q2 = q2 / q2.norm(2, 1, keepdim=True).expand_as(q2)

        # best sub-keys for each half
        top1_scores, top1_ids = F.linear(q1, keys1, bias=None).topk(knn, dim=1, largest=True, sorted=True)   # (bs, knn)
        top2_scores, top2_ids = F.linear(q2, keys2, bias=None).topk(knn, dim=1, largest=True, sorted=True)   # (bs, knn)

        # every (i, j) pairing of half-1 candidate i with half-2 candidate j, flattened row-major
        pair_scores = (top1_scores.view(bs, knn, 1) + top2_scores.view(bs, 1, knn)).view(bs, -1)              # (bs, knn ** 2)
        pair_ids = (top1_ids.view(bs, knn, 1) * n_keys + top2_ids.view(bs, 1, knn)).view(bs, -1)              # (bs, knn ** 2)

        # keep the overall best pairs
        scores, best = torch.topk(pair_scores, k=knn, dim=1, largest=True, sorted=True)                      # (bs, knn)
        indices = pair_ids.gather(1, best)                                                                    # (bs, knn)

        assert scores.shape == indices.shape == (bs, knn)
        return scores, indices

In [None]:
import torch
from torch import nn



def mlp(sizes, bias=True, batchnorm=True, groups=1):
    """
    Generate a feedforward neural network from a list of layer sizes.

    sizes: [d_0, d_1, ..., d_L] with L >= 1; layer i maps to groups * d_{i+1} features.
    The first layer is always an nn.Linear (fanning out to `groups` copies when
    groups > 1); later layers are GroupedLinear when groups != 1. An optional
    BatchNorm1d follows every linear layer and a ReLU separates consecutive layers.
    Returns an nn.Sequential.
    """
    assert len(sizes) >= 2
    n_layers = len(sizes) - 1
    modules = []

    for idx in range(n_layers):
        d_in, d_out = sizes[idx], sizes[idx + 1]
        if groups != 1 and idx != 0:
            modules.append(GroupedLinear(groups * d_in, groups * d_out, bias=bias, groups=groups))
        else:
            modules.append(nn.Linear(d_in, groups * d_out, bias=bias))
        if batchnorm:
            modules.append(nn.BatchNorm1d(groups * d_out))
        if idx != n_layers - 1:
            modules.append(nn.ReLU())

    return nn.Sequential(*modules)


def convs(channel_sizes, kernel_sizes, bias=True, batchnorm=True, residual=False, groups=1):
    """
    Generate a convolutional neural network.

    channel_sizes: [c_0, ..., c_L]; kernel_sizes: L odd kernel sizes (square kernels).
    Layer i outputs c_{i+1} * groups channels; every layer after the first is a
    grouped convolution with `groups` groups. Without `residual`: Conv2d
    ("same" padding) + optional BatchNorm2d, with ReLU between layers. With
    `residual`: BottleneckResidualConv2d blocks followed by a final 1x1 Conv2d.
    Returns an nn.Sequential.
    """
    assert len(channel_sizes) >= 2
    assert len(channel_sizes) == len(kernel_sizes) + 1
    n_layers = len(channel_sizes) - 1
    modules = []

    for idx in range(n_layers):
        is_last = idx == n_layers - 1
        ks = (kernel_sizes[idx], kernel_sizes[idx])
        in_group = groups if idx else 1
        c_in = channel_sizes[idx] * in_group
        c_out = channel_sizes[idx + 1] * groups

        if residual:
            modules.append(BottleneckResidualConv2d(
                c_in, c_out, ks, bias=bias,
                batchnorm=batchnorm, groups=in_group
            ))
            if is_last:
                modules.append(nn.Conv2d(c_out, c_out, (1, 1), bias=bias))
            continue

        modules.append(nn.Conv2d(c_in, c_out, ks, padding=[s // 2 for s in ks], bias=bias, groups=in_group))
        if batchnorm:
            modules.append(nn.BatchNorm2d(c_out))
        if not is_last:
            modules.append(nn.ReLU())

    return nn.Sequential(*modules)


class GroupedLinear(nn.Module):
    """
    Linear layer split into `groups` independent blocks, implemented as a 1x1
    grouped nn.Conv1d. Input (batch, in_features) -> output (batch, out_features).
    in_features and out_features must be divisible by groups (Conv1d enforces it).
    """

    def __init__(self, in_features, out_features, bias=True, groups=1):

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        # constructor flag (bool), not a tensor; the actual bias lives in self.layer
        self.bias = bias
        # NOTE: the default groups=1 is rejected; callers must pass groups > 1
        assert groups > 1

        self.layer = nn.Conv1d(in_features, out_features, bias=bias, kernel_size=1, groups=groups)

    def forward(self, input):
        assert input.dim() == 2 and input.size(1) == self.in_features
        return self.layer(input.unsqueeze(2)).squeeze(2)

    def extra_repr(self):
        # self.bias is the boolean flag, so `self.bias is not None` was always True;
        # report the flag itself.
        return 'in_features={}, out_features={}, groups={}, bias={}'.format(
            self.in_features, self.out_features, self.groups, bool(self.bias)
        )


class BottleneckResidualConv2d(nn.Module):
    """
    Two "same"-padded convolutions (through min(input, output) hidden channels)
    with optional BatchNorm, plus a skip connection (identity, or a bias-free 1x1
    convolution when the channel counts differ), followed by a final ReLU.
    kernel_size must be a pair of odd sizes.
    """

    def __init__(self, input_channels, output_channels, kernel_size, bias=True, batchnorm=True, groups=1):

        super().__init__()
        hidden_channels = min(input_channels, output_channels)
        assert all(k % 2 == 1 for k in kernel_size)
        same_padding = [k // 2 for k in kernel_size]

        # creation order kept: conv1, conv2, act, bn1, bn2, residual
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size, padding=same_padding, bias=bias, groups=groups)
        self.conv2 = nn.Conv2d(hidden_channels, output_channels, kernel_size, padding=same_padding, bias=bias, groups=groups)
        self.act = nn.ReLU()

        self.batchnorm = batchnorm
        if batchnorm:
            self.bn1 = nn.BatchNorm2d(hidden_channels)
            self.bn2 = nn.BatchNorm2d(output_channels)

        self.residual = (
            nn.Sequential() if input_channels == output_channels
            else nn.Conv2d(input_channels, output_channels, (1, 1), bias=False, groups=groups)
        )

    def forward(self, input):
        hidden = self.conv1(input)
        if self.batchnorm:
            hidden = self.bn1(hidden)
        hidden = self.conv2(self.act(hidden))
        if self.batchnorm:
            hidden = self.bn2(hidden)
        return self.act(hidden + self.residual(input))


class QueryIdentity(nn.Module):
    """
    Parameter-free query generator: turns each hidden state of size input_dim
    into `heads` queries of the same size, either plain copies or, with
    shuffle_hidden, head-specific permutations of hidden-dimension slices.
    """

    def __init__(self, input_dim, heads, shuffle_hidden):
        super().__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.shuffle_query = shuffle_hidden
        assert shuffle_hidden is False or heads > 1
        assert shuffle_hidden is False or self.input_dim % (2 ** self.heads) == 0
        if shuffle_hidden:
            self.slices = {head: get_slices(input_dim, head) for head in range(heads)}

    def forward(self, input):
        """
        Generate queries from hidden states by either
        repeating them or creating some shuffled version.
        Input (..., input_dim) -> output (n * heads, input_dim), n = number of hidden states.
        """
        assert input.shape[-1] == self.input_dim
        flat = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input
        n = len(flat)

        if self.heads == 1:
            query = flat
        elif self.shuffle_query:
            pieces = [
                flat[:, lo:hi]
                for head in range(self.heads)
                for lo, hi in self.slices[head]
            ]
            query = torch.cat(pieces, 1).view(n * self.heads, self.input_dim)
        else:
            query = flat.unsqueeze(1).repeat(1, self.heads, 1).view(n * self.heads, self.input_dim)

        assert query.shape == (n * self.heads, self.input_dim)
        return query


class QueryMLP(nn.Module):
    """
    Query network built from MLPs (see `mlp`).

    Maps hidden states of size input_dim to `heads` queries of size k_dim.
    With multi_query_net there are 2 * heads groups, each producing k_dim // 2
    features, so each head's query is the concatenation of two consecutive group
    outputs; otherwise a single MLP produces heads * k_dim features.
    """

    def __init__(
        self, input_dim, heads, k_dim, product_quantization, multi_query_net,
        sizes, bias=True, batchnorm=True, grouped_conv=False
    ):
        super().__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.k_dim = k_dim
        self.sizes = sizes
        self.grouped_conv = grouped_conv
        assert not multi_query_net or product_quantization or heads >= 2
        assert sizes[0] == input_dim
        # The conditional must be parenthesized: `a == b if c else d` parses as
        # `(a == b) if c else d`, which skipped the check whenever multi_query_net
        # was False (heads * k_dim is always truthy).
        expected_out = (k_dim // 2) if multi_query_net else (heads * k_dim)
        assert sizes[-1] == expected_out, "wrong output size %s, expected %s" % (sizes[-1], expected_out)
        assert self.grouped_conv is False or len(sizes) > 2

        # number of required MLPs
        self.groups = (2 * heads) if multi_query_net else 1

        # MLPs
        if self.grouped_conv:
            self.query_mlps = mlp(sizes, bias=bias, batchnorm=batchnorm, groups=self.groups)
        elif len(self.sizes) == 2:
            # a single linear layer: widen its output instead of building `groups` copies
            sizes_ = list(sizes)
            sizes_[-1] = sizes_[-1] * self.groups
            self.query_mlps = mlp(sizes_, bias=bias, batchnorm=batchnorm, groups=1)
        else:
            self.query_mlps = nn.ModuleList([
                mlp(sizes, bias=bias, batchnorm=batchnorm, groups=1)
                for _ in range(self.groups)
            ])

    def forward(self, input):
        """
        Compute queries using either grouped 1D convolutions or ModuleList + concat.
        Input (..., input_dim) -> output (n * heads, k_dim), n = number of hidden states.
        """
        assert input.shape[-1] == self.input_dim
        input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input
        bs = len(input)

        if self.grouped_conv or len(self.sizes) == 2:
            query = self.query_mlps(input)
        else:
            outputs = [m(input) for m in self.query_mlps]
            query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0]

        assert query.shape == (bs, self.heads * self.k_dim)
        return query.view(bs * self.heads, self.k_dim)


class QueryConv(nn.Module):
    """
    Convolutional query network built from CNNs (see `convs`).

    Maps a (bs, input_dim, h, w) feature map to bs * w * h * heads queries of
    size k_dim. With multi_query_net there are 2 * heads groups producing
    k_dim // 2 channels each; otherwise one CNN produces heads * k_dim channels.
    """

    def __init__(
        self, input_dim, heads, k_dim, product_quantization, multi_query_net,
        sizes, kernel_sizes, bias=True, batchnorm=True,
        residual=False, grouped_conv=False
    ):
        super().__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.k_dim = k_dim
        self.sizes = sizes
        self.grouped_conv = grouped_conv
        assert not multi_query_net or product_quantization or heads >= 2
        assert sizes[0] == input_dim
        # The conditional must be parenthesized: `a == b if c else d` parses as
        # `(a == b) if c else d`, which skipped the check whenever multi_query_net
        # was False (heads * k_dim is always truthy).
        expected_out = (k_dim // 2) if multi_query_net else (heads * k_dim)
        assert sizes[-1] == expected_out, "wrong output size %s, expected %s" % (sizes[-1], expected_out)
        assert self.grouped_conv is False or len(sizes) > 2
        assert len(sizes) == len(kernel_sizes) + 1 >= 2 and all(ks % 2 == 1 for ks in kernel_sizes)

        # number of required CNNs
        self.groups = (2 * heads) if multi_query_net else 1

        # CNNs
        if self.grouped_conv:
            self.query_convs = convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=self.groups)
        elif len(self.sizes) == 2:
            # a single conv layer: widen its output instead of building `groups` copies
            sizes_ = list(sizes)
            sizes_[-1] = sizes_[-1] * self.groups
            self.query_convs = convs(sizes_, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1)
        else:
            self.query_convs = nn.ModuleList([
                convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1)
                for _ in range(self.groups)
            ])

    def forward(self, input):
        """
        Input (bs, input_dim, h, w) -> queries (bs * w * h * heads, k_dim),
        rows ordered by (batch, width, height, head) after transpose(1, 3).
        """
        bs, nf, h, w = input.shape
        assert nf == self.input_dim

        if self.grouped_conv or len(self.sizes) == 2:
            query = self.query_convs(input)
        else:
            outputs = [m(input) for m in self.query_convs]
            query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0]

        assert query.shape == (bs, self.heads * self.k_dim, h, w)
        query = query.transpose(1, 3).contiguous().view(bs * w * h * self.heads, self.k_dim)
        return query

In [None]:
import sys
import math
import numpy as np
import torch


# Load the FAISS library if possible (dramatically accelerates the nearest neighbor search).
# FAISS_AVAILABLE is True only when the imported build exposes StandardGpuResources (GPU build).
try:
    import faiss
except ImportError:
    FAISS_AVAILABLE = False
    sys.stderr.write("Impossible to import FAISS library!!\n")
else:
    FAISS_AVAILABLE = hasattr(faiss, 'StandardGpuResources')
    sys.stderr.write("Impossible to import FAISS library!!\n")


def get_gaussian_keys(n_keys, dim, normalized, seed):
    """
    Generate random Gaussian keys.

    Returns a float32 array of shape (n_keys, dim) drawn from a standard normal
    with a RandomState seeded by `seed`; rows are L2-normalized when `normalized`.
    """
    keys = np.random.RandomState(seed).randn(n_keys, dim)
    if normalized:
        keys = keys / np.linalg.norm(keys, axis=1, keepdims=True)
    return keys.astype(np.float32)


def get_uniform_keys(n_keys, dim, normalized, seed):
    """
    Generate random uniform keys (same initialization as nn.Linear).

    Returns a float32 array of shape (n_keys, dim) drawn from U(-1/sqrt(dim), 1/sqrt(dim))
    with a RandomState seeded by `seed`; rows are L2-normalized when `normalized`.
    """
    half_width = 1 / math.sqrt(dim)
    keys = np.random.RandomState(seed).uniform(-half_width, half_width, (n_keys, dim))
    if normalized:
        keys = keys / np.linalg.norm(keys, axis=1, keepdims=True)
    return keys.astype(np.float32)


def get_slices(dim, head_id):
    """
    Generate slices of hidden dimensions.
    Used when there are multiple heads and/or different set of keys,
    and that there is no query network.

    Head 0 gets the whole range [(0, dim)]. Head h > 0 cuts [0, dim) into
    chunks of width dim // 2 ** (h + 1) and returns the even-numbered chunks
    followed by the odd-numbered ones, as (start, end) pairs.
    """
    if head_id == 0:
        return [(0, dim)]
    width = dim // (2 ** (head_id + 1))
    starts = np.arange(0, dim, width)
    even_chunks = [(s, s + width) for s in starts[0::2]]
    odd_chunks = [(s, s + width) for s in starts[1::2]]
    return even_chunks + odd_chunks


def cartesian_product(a, b):
    """
    Compute the batched cartesian product between two matrices.
    Input:
        a: Tensor(n, d1)
        b: Tensor(n, d2)
    Output:
        output: Tensor(n, d1 * d2, 2) where output[k, i * d2 + j] == (a[k, i], b[k, j])
    """
    n1, d1 = a.shape
    n2, d2 = b.shape
    assert n1 == n2
    left = a.unsqueeze(2).expand(n1, d1, d2)     # left[k, i, j]  = a[k, i]
    right = b.unsqueeze(1).expand(n2, d1, d2)    # right[k, i, j] = b[k, j]
    return torch.stack([left, right], dim=-1).reshape(n1, d1 * d2, 2)


def swig_ptr_from_FloatTensor(x):
    """
    Return a FAISS (SWIG) float* to the first element of a contiguous float32 tensor.

    Uses Tensor.data_ptr(), which already accounts for the storage offset;
    Tensor.storage() is deprecated (TypedStorage) and warns on recent PyTorch.
    """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    return faiss.cast_integer_to_float_ptr(x.data_ptr())


def swig_ptr_from_LongTensor(x):
    """
    Return a FAISS (SWIG) long* to the first element of a contiguous int64 tensor.

    Uses Tensor.data_ptr(), which already accounts for the storage offset;
    Tensor.storage() is deprecated (TypedStorage) and warns on recent PyTorch.
    """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    return faiss.cast_integer_to_long_ptr(x.data_ptr())


def get_knn_pytorch(a, b, k, distance='dot_product'):
    """
    Exact k-nearest-neighbor search with plain PyTorch ops (fallback for FAISS).

    Input:
        - a: keys, matrix of size (m, d)
        - b: queries, matrix of size (n, d)
        - k: number of nearest neighbors (> 0)
        - distance: 'dot_product', 'cosine' or 'l2'
    Output:
        - `scores`  matrix of size (n, k) with nearest neighbors scores
        - `indices` matrix of size (n, k) with nearest neighbors indices
    Higher scores are closer; for 'l2' the score is the negated squared distance.
    """
    _, d = a.size()
    _, _ = b.size()
    assert b.size(1) == d
    assert k > 0
    assert distance in ['dot_product', 'cosine', 'l2']

    with torch.no_grad():
        sim = a.mm(b.t())                                               # (m, n)
        if distance == 'cosine':
            sim = sim / (a.norm(2, 1)[:, None] + 1e-9)
            sim = sim / (b.norm(2, 1)[None, :] + 1e-9)
        elif distance == 'l2':
            # 2 a.b - |a|^2 - |b|^2 == -|a - b|^2
            sim = sim * 2 - (a ** 2).sum(1)[:, None] - (b ** 2).sum(1)[None, :]
        top_scores, top_indices = sim.topk(k=k, dim=0, largest=True)   # (k, n)
        top_scores, top_indices = top_scores.t(), top_indices.t()      # (n, k)

    return top_scores, top_indices


def get_knn_faiss(xb, xq, k, distance='dot_product'):
    """
    Exact k-NN search with faiss.bruteForceKnn directly on PyTorch tensors.

    xb: database / keys (nb, d); xq: queries (nq, d). Both must be contiguous
    float32 tensors on the same device (checked by the swig pointer helpers).
    distance: 'dot_product' (faiss.METRIC_INNER_PRODUCT) or 'l2' (faiss.METRIC_L2).
    Returns (D, I): float32 scores and int64 indices of shape (nq, k),
    allocated on xb's device. Uses the module-level FAISS_RES resources.
    Adapted from https://github.com/facebookresearch/faiss/blob/master/gpu/test/test_pytorch_faiss.py
    """
    assert xb.device == xq.device
    assert distance in ['dot_product', 'l2']
    if distance == 'dot_product':
        metric = faiss.METRIC_INNER_PRODUCT
    else:
        metric = faiss.METRIC_L2

    query_ptr = swig_ptr_from_FloatTensor(xq)
    base_ptr = swig_ptr_from_FloatTensor(xb)

    n_queries, query_dim = xq.size()
    n_base, base_dim = xb.size()
    assert query_dim == base_dim

    distances = torch.empty(n_queries, k, device=xb.device, dtype=torch.float32)
    labels = torch.empty(n_queries, k, device=xb.device, dtype=torch.int64)

    faiss.bruteForceKnn(
        FAISS_RES, metric,
        base_ptr, n_base,
        query_ptr, n_queries,
        query_dim, k,
        swig_ptr_from_FloatTensor(distances), swig_ptr_from_LongTensor(labels)
    )

    return distances, labels


# Choose the k-NN backend once, at import time: FAISS (GPU) when available,
# the pure-PyTorch implementation otherwise.
if not FAISS_AVAILABLE:
    sys.stderr.write("Switching to standard nearest neighbors search implementation, this will be significantly slower.\n")
    get_knn = get_knn_pytorch
else:
    FAISS_RES = faiss.StandardGpuResources()
    FAISS_RES.setDefaultNullStreamAllDevices()
    # temporary scratch memory; presumably in bytes (~1200 MiB) - confirm against FAISS docs
    FAISS_RES.setTempMemory(1200 * 1024 * 1024)
    get_knn = get_knn_faiss

Impossible to import FAISS library!!
Switching to standard nearest neighbors search implementation, this will be significantly slower.


In [None]:
!pip install git+https://github.com/Maluuba/nlg-eval.git@master

Collecting git+https://github.com/Maluuba/nlg-eval.git@master
  Cloning https://github.com/Maluuba/nlg-eval.git (to revision master) to /tmp/pip-req-build-243osu4o
  Running command git clone -q https://github.com/Maluuba/nlg-eval.git /tmp/pip-req-build-243osu4o
Collecting nltk>=3.4.5
  Downloading nltk-3.7-py3-none-any.whl (1.5 MB)
[K     |████████████████████████████████| 1.5 MB 5.1 MB/s 
Collecting psutil>=5.6.2
  Downloading psutil-5.9.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (280 kB)
[K     |████████████████████████████████| 280 kB 45.5 MB/s 
Collecting gensim~=3.8.3
  Downloading gensim-3.8.3-cp37-cp37m-manylinux1_x86_64.whl (24.2 MB)
[K     |████████████████████████████████| 24.2 MB 1.2 MB/s 
[?25hCollecting Theano>=0.8.1
  Downloading Theano-1.0.5.tar.gz (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 42.9 MB/s 
Collecting xdg
  Downloading xdg-5.1.1-py3-none-any.whl (5.0 kB)
Collecting regex>=2021.8.3

In [None]:

import nlgeval
import nlgeval.utils

In [None]:
from logging import getLogger
import os
import copy
import time
import json
import math
import numpy as np
from collections import OrderedDict, defaultdict

import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
# import rouge
from nlgeval import NLGEval



# Languages for which XPersona.setup_vocab_mask builds a decoding mask: each needs a
# "<lang>.vocab" file and one entry in params.decode_vocab_sizes.
# Uncomment one of the alternatives below to use a different English/X pair.
XPersona_LANGS = ["en", "zh"]
# XPersona_LANGS = ["en", "fr"]
# XPersona_LANGS = ["en", "it"]
# XPersona_LANGS = ["en", "id"]
# XPersona_LANGS = ["en", "jp"]
# XPersona_LANGS = ["en", "ko"]

logger = getLogger()
# Former ROUGE evaluator, kept for reference (its `rouge` import is commented out too).
# evaluator = rouge.Rouge(
#   metrics=['rouge-n', 'rouge-l'],
#   max_n=2,
#   limit_length=True,
#   length_limit=100,
#   length_limit_type='words',
#   alpha=0.5, # Default F1_score
#   weight_factor=1.2,
#   stemming=False)
# nlg-eval metric evaluator; the flags presumably skip the skip-thought and GloVe
# embedding models, and CIDEr is excluded from the computed metrics.
# NOTE(review): this rebinds `nlgeval`, shadowing the `nlgeval` module imported in an
# earlier cell; from here on the name refers to this NLGEval instance.
nlgeval = NLGEval(
  no_skipthoughts=True,no_glove=True,metrics_to_omit=['CIDEr'])


def get_parameters(model, train_layers_str):
  """
  Collect the parameters of `model` to optimize.

  train_layers_str is "fr,to": transformer layers fr..to (inclusive, 1-based)
  are included. When fr == 0, the word, position and language embeddings are
  added first and the layer range starts at 1. Cross-attention is never added.
  Returns a flat list of parameters.
  """
  first, last = map(int, train_layers_str.split(","))
  assert first >= 0
  collected = []

  if first == 0:
    embedding_groups = (
      ("embeddings", "Adding embedding parameters"),
      ("position_embeddings", "Adding positional embedding parameters"),
      ("lang_embeddings", "Adding language embedding parameters"),
    )
    for attr_name, message in embedding_groups:
      collected.extend(getattr(model, attr_name).parameters())
      logger.info(message)
    first = 1

  assert first <= last
  # NOTE cross attention is not added
  for layer_id in range(first, last + 1):
    idx = layer_id - 1
    for module_list in (model.attentions, model.layer_norm1, model.ffns, model.layer_norm2):
      collected.extend(module_list[idx].parameters())
    logger.info("Adding layer-%s parameters to optimizer" % layer_id)

  return collected


def tokens2words(toks):
  """
  Merge BPE sub-tokens into words: a token ending in "@@" is glued (without
  the "@@") to the token that follows it. Returns a new list of words.
  """
  merged = []
  for piece in toks:
    if merged and merged[-1].endswith("@@"):
      merged[-1] = merged[-1][:-2] + piece
      continue
    merged.append(piece)
  return merged


class XPersona(object):
  """Fine-tune and evaluate an encoder/decoder pair on XPersona dialogues.

  Entry points:
    run()        -- train for params.n_epochs, evaluating valid + test each epoch
    test()       -- evaluate on the test split only (no checkpointing)
    gen_resp_()  -- log generated responses for the test split

  Datasets are loaded lazily per (x_lang, y_lang, split) by get_or_load_data.
  Directions are "x-y" strings from params.train_directions /
  params.eval_directions (comma separated).
  """

  def __init__(self, encoder, decoder, scores, dico, params):
    self.encoder = encoder
    self.decoder = decoder
    self.params = params
    self.scores = scores
    self.dico = dico

    # live training iterators keyed by (split, x_lang, y_lang); see next_batch
    self.iter_cache = {}
    # loaded datasets; each entry point resets this, it is initialised here so
    # that get_or_load_data also works when called directly
    self.data = {}

  # ------------------------------------------------------------------ helpers

  @staticmethod
  def _parse_directions(directions_str):
    """Parse ``"en-en,zh-zh"`` into ``[["en", "en"], ["zh", "zh"]]``."""
    return [d.split("-") for d in directions_str.split(",")]

  @staticmethod
  def _join_words(words, lang):
    """Join words into one string.

    For Chinese targets the words are concatenated and every character is
    space-separated (presumably so word-overlap metrics work per character).
    """
    if lang.endswith("zh"):
      return " ".join("".join(words))
    return " ".join(words)

  def _decode_hypothesis(self, sent, y_lang):
    """Turn one column of decoder output into a detokenised response string.

    The column must start with EOS; everything up to the next EOS (or the end
    of the column if there is none) is kept.
    """
    delimiters = (sent == self.params.eos_index).nonzero().view(-1)
    assert len(delimiters) >= 1 and delimiters[0].item() == 0
    sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
    words = tokens2words([self.dico[t] for t in sent.tolist()])
    return self._join_words(words, y_lang)

  def _source_words(self, sent_x, len_x, j):
    """Return the words of the j-th source column without its boundary tokens.

    Mirrors the slicing used for references in gen_references_v2
    (first and last positions dropped).
    """
    x_sent = sent_x[1:int(len_x[j]) - 1, j]
    return tokens2words([self.dico[t] for t in x_sent.tolist()])

  def _generate(self, encoded, len_x, y_lang_id, vocab_mask):
    """Greedy decoding when params.beam_size == 1, beam search otherwise."""
    params = self.params
    if params.beam_size == 1:
      decoded, _ = self.decoder.generate(
        encoded, len_x, y_lang_id, max_len=params.max_dec_len,
        vocab_mask=vocab_mask)
    else:
      decoded, _ = self.decoder.generate_beam(
        encoded, len_x, y_lang_id, beam_size=params.beam_size,
        length_penalty=0.9, early_stopping=False,
        max_len=params.max_dec_len, vocab_mask=vocab_mask)
    return decoded

  def _setup_optimizer_and_references(self, eval_directions):
    """Create the optimizer for params.train_layers, pre-compute references
    for eval_directions and, if params.decode_with_vocab, the vocab masks."""
    params = self.params
    if params.train_layers == "all":
      parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
    elif params.train_layers == "decoder":
      parameters = self.decoder.parameters()
    elif params.train_layers == "encoder":
      parameters = self.encoder.parameters()
    else:
      # "FROM,TO" range of encoder layers
      parameters = get_parameters(self.encoder, params.train_layers)
    self.optimizer = get_optimizer(parameters, params.optimizer)

    self.gen_references_v2(self.dico, eval_directions)
    if params.decode_with_vocab:
      self.setup_vocab_mask(self.dico)

  def _get_references(self, split, x_lang, y_lang):
    """References aligned with the `split` data of direction x_lang -> y_lang.

    Falls back to the legacy per-target-language table when no per-direction
    entry exists.
    """
    by_direction = getattr(self, "direction_references", {}).get(split, {})
    if (x_lang, y_lang) in by_direction:
      return by_direction[(x_lang, y_lang)]
    return self.references[split][y_lang]

  # ------------------------------------------------------------------ setup

  def setup_vocab_mask(self, dico):
    """Build one decoding mask per language of XPersona_LANGS.

    For each language, the first ``sz`` tokens of ``<vocab_path>/<lang>.vocab``
    (``sz`` from the comma-separated params.decode_vocab_sizes, one value per
    language) are allowed, plus EOS and UNK.  Vocab tokens missing from
    `dico` are counted and reported.  Masks are uint8 tensors of length
    len(dico), stored in ``self.vocab_mask[lang]``.
    """
    n_words = len(dico)
    params = self.params

    self.vocab_mask = {}

    decode_vocab_sizes = [int(s) for s in params.decode_vocab_sizes.split(",")]
    assert len(decode_vocab_sizes) == len(XPersona_LANGS)

    for lang, sz in zip(XPersona_LANGS, decode_vocab_sizes):
      fn = os.path.join(params.vocab_path, lang + ".vocab")
      assert os.path.isfile(fn), fn

      mask = torch.zeros(n_words, dtype=torch.uint8)
      mask[dico.eos_index] = 1
      # TODO generate unk?
      mask[dico.unk_index] = 1
      count = 0
      # vocab files contain non-ASCII tokens; do not depend on the locale
      with open(fn, encoding="utf-8") as fp:
        for line, _ in zip(fp, range(sz)):
          # the token is the first field before any tab / space
          tok = line.strip().split("\t")[0].split(" ")[0]
          if tok in dico.word2id:
            mask[dico.word2id[tok]] = 1
          else:
            count += 1

      logger.warning("%d tokens not in dico" % count)
      self.vocab_mask[lang] = mask

  def gen_references_v2(self, dico, eval_directions):
    """Pre-compute detokenised reference responses for the valid/test splits.

    References are stored per direction in
    ``self.direction_references[split][(x_lang, y_lang)]``: the parallel
    dataset is filtered on both sides (see get_or_load_data), so two
    directions with the same target language need not yield the same,
    identically ordered references.  For backward compatibility they are
    also stored in ``self.references[split][y_lang]`` (the last direction
    with that target wins, as before).
    """
    self.references = {}
    self.direction_references = {}
    for split in ["valid", "test"]:
      self.references[split] = {}
      self.direction_references[split] = {}
      for direction in eval_directions:
        x_lang, y_lang = direction
        key = (x_lang, y_lang)
        if key in self.direction_references[split]:
          continue  # repeated direction: identical data, identical references
        refs = []
        for batch in self.get_iterator(split, x_lang, y_lang):
          _, (sent_y, len_y), _ = batch
          for j in range(len(len_y)):
            # drop the first and last (sentence boundary) positions
            ref_sent = sent_y[1:len_y[j] - 1, j]
            ref_words = tokens2words([dico[t] for t in ref_sent.tolist()])
            refs.append(self._join_words(ref_words, y_lang))
        self.direction_references[split][key] = refs
        self.references[split][y_lang] = refs

  # ------------------------------------------------------------------ data

  def _parse_lang(self, lang):
    """Normalise a language spec to a (lang1, lang2) tuple.

    Accepts a 2-tuple, a single code ("zh" -> ("zh", "zh")) or "x2y"
    ("en2zh" -> ("en", "zh")); every code must be in XPersona_LANGS.
    Returns None for other input types.
    """
    if isinstance(lang, tuple):
      assert len(lang) == 2
      lang1, lang2 = lang
      assert lang1 in XPersona_LANGS
      assert lang2 in XPersona_LANGS
      return (lang1, lang2)
    if isinstance(lang, str):
      if lang in XPersona_LANGS:
        return (lang, lang)
      lang1, lang2 = lang.split("2")
      assert lang1 in XPersona_LANGS
      assert lang2 in XPersona_LANGS
      return (lang1, lang2)

  def get_iterator(self, splt, x_lang, y_lang):
    """Return a fresh batch iterator over split `splt` (shuffled only for train)."""
    x_lang = self._parse_lang(x_lang)
    y_lang = self._parse_lang(y_lang)
    logger.info("Getting iterator -- x_lang: (%s, %s), y_lang: (%s, %s) split:%s" % (
      x_lang[0], x_lang[1], y_lang[0], y_lang[1], splt))
    return self.get_or_load_data(x_lang, y_lang, splt).get_iterator(
      shuffle=(splt == 'train'),
      group_by_size=self.params.group_by_size,
      return_indices=True)

  def next_batch(self, splt, x_lang, y_lang):
    """Next batch from a cached iterator, restarting it once it is exhausted."""
    key = (splt, x_lang, y_lang)
    if key not in self.iter_cache:
      self.iter_cache[key] = self.get_iterator(splt, x_lang, y_lang)
    try:
      ret = next(self.iter_cache[key])
    except StopIteration:
      self.iter_cache[key] = self.get_iterator(splt, x_lang, y_lang)
      ret = next(self.iter_cache[key])
    return ret

  def lang2str(self, lang):
    """("zh", "zh") -> "zh"; ("en", "zh") -> "en-zh" (used in data file names)."""
    lang1, lang2 = lang
    if lang1 == lang2:
      return lang1
    return "%s-%s" % (lang1, lang2)

  def get_or_load_data(self, x_lang, y_lang, splt):
    """Return the ParallelDataset for (x_lang, y_lang, splt), loading it once.

    `x_lang` / `y_lang` are (lang1, lang2) tuples from _parse_lang.  Data is
    read from ``<data_path>/eval/<ds_name>/<splt>.{x,y}.<lang>.pth``; empty
    pairs are removed, long pairs are cut via cut_long_sentences(max_len,
    max_len) and the train split is optionally truncated (params.cut_dataset).
    """
    params = self.params
    data = self.data

    lang = (x_lang, y_lang)
    if lang in data:
      if splt in data[lang]:
        return data[lang][splt]
    else:
      data[lang] = {}

    dpath = os.path.join(params.data_path, "eval", params.ds_name)

    x = load_binarized(os.path.join(dpath, "%s.x.%s.pth" % (
      splt, self.lang2str(x_lang))), params)
    y = load_binarized(os.path.join(dpath, "%s.y.%s.pth" % (
      splt, self.lang2str(y_lang))), params)
    data["dico"] = data.get("dico", x["dico"])
    set_dico_parameters(params, data, x["dico"])
    set_dico_parameters(params, data, y["dico"])

    data[lang][splt] = ParallelDataset(
      x["sentences"], x["positions"],
      y["sentences"], y["positions"],
      params)
    data[lang][splt].remove_empty_sentences()
    data[lang][splt].cut_long_sentences(params.max_len, params.max_len)

    if params.cut_dataset > 0 and splt == "train":
      data[lang][splt].select_data(0, params.cut_dataset + 1)

    return data[lang][splt]

  # ------------------------------------------------------------------ entry points

  def run(self):
    """Train for params.n_epochs epochs; after each epoch evaluate on valid
    (checkpointing on a new best perplexity) and on test."""
    params = self.params

    train_directions = self._parse_directions(params.train_directions)
    eval_directions = self._parse_directions(params.eval_directions)

    self.data = {}
    # encoder on device1, decoder on device2 (train/eval move tensors accordingly)
    self.encoder.to(params.device1)
    self.decoder.to(params.device2)
    self._setup_optimizer_and_references(eval_directions)

    # initial threshold: the first validation perplexity below it is saved
    self.best_ppl = 1000000

    for epoch in range(params.n_epochs):
      self.epoch = epoch
      logger.info("XPersona - Training epoch %d ..." % epoch)
      self.train(train_directions)
      logger.info("XPersona - Evaluating epoch %d ..." % epoch)
      self.eval(eval_directions, "valid", True)
      self.eval(eval_directions, "test", False)

  def gen_resp_(self):
    """Log generated responses for the test split (models on the default GPU)."""
    params = self.params

    eval_directions = self._parse_directions(params.eval_directions)

    self.data = {}
    # generate_response moves batches with to_cuda, i.e. to the default GPU
    self.encoder.cuda()
    self.decoder.cuda()
    self._setup_optimizer_and_references(eval_directions)

    self.best_ppl = 0
    self.generate_response(eval_directions, "test")

  def test(self):
    """Evaluate on the test split once, without checkpointing."""
    params = self.params

    eval_directions = self._parse_directions(params.eval_directions)

    self.data = {}
    self.encoder.to(params.device1)
    self.decoder.to(params.device2)
    self._setup_optimizer_and_references(eval_directions)

    self.best_ppl = 0
    # fixed tag used in the names of the dumped output files
    self.epoch = 100
    self.eval(eval_directions, "test", False)

  # ------------------------------------------------------------------ loops

  def generate_response(self, eval_directions, split="test"):
    """Decode every example of `split` and log source, hypothesis and reference.

    No loss or metrics are computed; batches are moved with to_cuda.
    """
    params = self.params
    encoder = self.encoder
    decoder = self.decoder
    encoder.eval()
    decoder.eval()

    for direction in eval_directions:
      x_lang, y_lang = direction
      logger.info("Evaluating %s-%s-xpersona on %s set" % (x_lang, y_lang, split))

      X, Y = [], []
      x_lang_id = params.lang2id[x_lang[-2:]]
      y_lang_id = params.lang2id[y_lang[-2:]]
      vocab_mask = self.vocab_mask[y_lang[-2:]] if params.decode_with_vocab else None

      for batch in self.get_iterator(split, x_lang, y_lang):
        (sent_x, len_x), _, _ = batch
        lang_x = sent_x.clone().fill_(x_lang_id)
        sent_x, len_x, lang_x = to_cuda(sent_x, len_x, lang_x)

        with torch.no_grad():
          encoded = encoder(
            "fwd", x=sent_x, lengths=len_x, langs=lang_x, causal=False)
          encoded = encoded.transpose(0, 1)
          decoded = self._generate(encoded, len_x, y_lang_id, vocab_mask)

        for j in range(decoded.size(1)):
          Y.append(self._decode_hypothesis(decoded[:, j], y_lang))
          X.append(self._source_words(sent_x, len_x, j))

      all_refs = self._get_references(split, x_lang, y_lang)
      logger.info("%d res %d ref" % (len(Y), len(all_refs)))
      for i in range(min(len(X), len(all_refs))):
        logger.info("%d X: %s\nGenerated: %s\nReference: %s\n" % (
            i, " ".join(X[i]), Y[i], all_refs[i]))

  def train(self, train_directions):
    """Run params.epoch_size optimisation steps, cycling round-robin through
    `train_directions`.

    Average loss and words/s are logged roughly every 100 batches (exactly
    every 100 when the batch size is constant).
    """
    params = self.params
    encoder = self.encoder
    decoder = self.decoder

    encoder.train()
    decoder.train()

    # training variables
    losses = []
    ns = 0  # number of sentences
    nw = 0  # number of words
    t = time.time()

    n_train_drt = len(train_directions)

    for step_idx in range(params.epoch_size):

      x_lang, y_lang = train_directions[step_idx % n_train_drt]
      x_lang_id = params.lang2id[x_lang[-2:]]
      y_lang_id = params.lang2id[y_lang[-2:]]

      batch = self.next_batch("train", x_lang, y_lang)
      (sent_x, len_x), (sent_y, len_y), _ = batch
      lang_x = sent_x.clone().fill_(x_lang_id)
      lang_y = sent_y.clone().fill_(y_lang_id)
      # predict every target token after the first one
      alen = torch.arange(len_y.max(), dtype=torch.long, device=len_y.device)
      pred_mask = alen[:, None] < len_y[None] - 1
      y = sent_y[1:].masked_select(pred_mask[:-1])
      assert len(y) == (len_y-1).sum().item()

      # encoder on device1, decoder (and its targets) on device2
      sent_x, len_x, lang_x = sent_x.to(params.device1), len_x.to(params.device1), lang_x.to(params.device1)
      sent_y, len_y, lang_y, y = sent_y.to(params.device2), len_y.to(params.device2), lang_y.to(params.device2), y.to(params.device2)

      enc_x = self.encoder("fwd", x=sent_x, lengths=len_x, langs=lang_x, causal=False)
      enc_x = enc_x.transpose(0, 1)

      enc_x, len_x = enc_x.to(params.device2), len_x.to(params.device2)

      dec_y = self.decoder('fwd', x=sent_y, lengths=len_y, langs=lang_y,
        causal=True, src_enc=enc_x, src_len=len_x)

      _, loss = self.decoder("predict", tensor=dec_y, pred_mask=pred_mask, y=y,
        get_scores=False)

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      bs = len(len_y)
      ns += bs
      nw += len_y.sum().item()
      losses.append(loss.item())

      # log
      if ns % (100 * bs) < bs:
        logger.info(
          "XPersona - Epoch %i - Train iter %7i - %.1f words/s - Loss: %.4f" % (
            self.epoch, ns, nw / (time.time() - t), sum(losses) / len(losses)))
        nw, t = 0, time.time()
        losses = []

  def eval(self, eval_directions, split="test", save=True):
    """Evaluate every direction on `split`.

    For each direction: teacher-forced perplexity of the references, decoded
    responses (greedy or beam), nlgeval metrics against the references, and
    logging.  With save=True the model is checkpointed when the perplexity
    beats self.best_ppl.  On the test split, hypotheses, references and a
    side-by-side dump are written to params.dump_path.
    """
    params = self.params
    encoder = self.encoder
    decoder = self.decoder
    encoder.eval()
    decoder.eval()

    for direction in eval_directions:
      x_lang, y_lang = direction
      logger.info("Evaluating %s-%s-xpersona on %s set" % (x_lang, y_lang, split))

      X, Y = [], []
      x_lang_id = params.lang2id[x_lang[-2:]]
      y_lang_id = params.lang2id[y_lang[-2:]]
      vocab_mask = self.vocab_mask[y_lang[-2:]] if params.decode_with_vocab else None

      # summed NLL and token count, so perplexity is a per-token average rather
      # than an average of per-batch means (assumes decoder("predict") returns
      # the mean loss over the `y` tokens, as PredLayer's cross_entropy does)
      total_nll = 0.0
      total_tokens = 0
      for batch in self.get_iterator(split, x_lang, y_lang):
        (sent_x, len_x), (sent_y, len_y), _ = batch
        lang_x = sent_x.clone().fill_(x_lang_id)
        lang_y = sent_y.clone().fill_(y_lang_id)

        sent_x, len_x, lang_x = sent_x.to(params.device1), len_x.to(params.device1), lang_x.to(params.device1)
        sent_y, len_y, lang_y = sent_y.to(params.device2), len_y.to(params.device2), lang_y.to(params.device2)

        with torch.no_grad():
          encoded = encoder(
            "fwd", x=sent_x, lengths=len_x, langs=lang_x, causal=False)
          encoded = encoded.transpose(0, 1)

          encoded, len_x = encoded.to(params.device2), len_x.to(params.device2)

          # teacher-forced loss on the reference responses
          alen = torch.arange(len_y.max(), dtype=torch.long, device=len_y.device)
          pred_mask = alen[:, None] < len_y[None] - 1
          y = sent_y[1:].masked_select(pred_mask[:-1])

          dec_y = decoder('fwd', x=sent_y, lengths=len_y, langs=lang_y, causal=True, src_enc=encoded, src_len=len_x)
          _, loss = decoder("predict", tensor=dec_y, pred_mask=pred_mask, y=y, get_scores=False)
          n_tokens = y.numel()
          total_nll += loss.item() * n_tokens
          total_tokens += n_tokens

          decoded = self._generate(encoded, len_x, y_lang_id, vocab_mask)

        for j in range(decoded.size(1)):
          Y.append(self._decode_hypothesis(decoded[:, j], y_lang))
          X.append(self._source_words(sent_x, len_x, j))

      all_refs = self._get_references(split, x_lang, y_lang)
      refs = all_refs[:len(Y)]
      logger.info("%d res %d ref" % (len(Y), len(all_refs)))
      # a few samples; bounded so tiny evaluation sets do not raise IndexError
      for i in range(min(5, len(refs))):
        logger.info("%d X: %s\nGenerated: %s\nReference: %s\n" % (
            i, " ".join(X[i]), Y[i], refs[i]))
      eval_res = nlgeval.compute_metrics([refs], Y)

      direction_str = "-".join(direction)

      # use perplexity to select checkpoints
      perplexity = math.exp(total_nll / total_tokens) if total_tokens > 0 else float("nan")
      # NOTE(review): best_ppl is shared by all eval directions, while the
      # checkpoint name is per direction; with several directions a later one
      # is compared against an earlier one's best — confirm this is intended.
      if save:
        if perplexity < self.best_ppl:
          logger.info("New best Perplexity: %.5f! Saving model..." % perplexity)
          self.best_ppl = perplexity
          self.save("best_%s_Perplexity" % direction_str)

      if split == "test":
        print("writing down output and refenerences ....")
        assert len(X) == len(refs) == len(Y)

        suffix = str(self.epoch) + "_" + str(y_lang) + ".txt"
        # outputs contain non-ASCII text (zh/jp/ko): write UTF-8 explicitly
        with open(os.path.join(params.dump_path, "output_" + suffix), "w", encoding="utf-8") as f:
          for sent in Y:
            f.write(sent + "\n")
        with open(os.path.join(params.dump_path, "ref_" + suffix), "w", encoding="utf-8") as f:
          for sent in refs:
            f.write(sent + "\n")
        with open(os.path.join(params.dump_path, "persona_chat_" + suffix), "w", encoding="utf-8") as f:
          for persona_and_history, output, reference in zip(X, Y, refs):
            f.write("=====================================================\n")
            f.write("History:\n")
            f.write(" ".join(persona_and_history))
            f.write('\n')
            f.write("Response: ")
            f.write(output)
            f.write('\n')
            f.write("Ref: ")
            f.write(reference)
            f.write('\n')

      logger.info("XPersona - %s - Epoch %d - Current Perplexity %.5f - Best Perplexity: %.5f - Metrics Scores: %s" % (
        direction_str, self.epoch, perplexity, self.best_ppl, eval_res))

  def save(self, name):
    """Checkpoint encoder/decoder weights, their model params, the dictionary
    and the run params to ``<dump_path>/<name>.pth``."""
    path = os.path.join(self.params.dump_path, "%s.pth" % name)
    logger.info("Saving %s to %s ..." % (name, path))
    data = {
      "epoch": getattr(self, "epoch", 0),
      "encoder": self.encoder.state_dict(),
      "decoder": self.decoder.state_dict(),
      "enc_params": {
        k: v for k, v in self.params.encoder_model_params.__dict__.items()},
      "dec_params": {
        k: v for k, v in self.params.decoder_model_params.__dict__.items()},
      "dico_id2word": self.dico.id2word,
      "dico_word2id": self.dico.word2id,
      "dico_counts": self.dico.counts,
      "params": {k: v for k, v in self.params.__dict__.items()}
    }
    torch.save(data, path)

In [None]:
from logging import getLogger
import math
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



N_MAX_POSITIONS = 512  # maximum input sequence length (rows of the position embedding table)

# State-dict key templates ("%i" = layer index) of parameters that only a
# decoder has: the extra layer norm and the encoder (cross-)attention.
# Presumably used when initialising models from checkpoints — usage is not
# visible in this cell.
DECODER_ONLY_PARAMS = [
    'layer_norm15.%i.weight', 'layer_norm15.%i.bias',
    'encoder_attn.%i.q_lin.weight', 'encoder_attn.%i.q_lin.bias',
    'encoder_attn.%i.k_lin.weight', 'encoder_attn.%i.k_lin.bias',
    'encoder_attn.%i.v_lin.weight', 'encoder_attn.%i.v_lin.bias',
    'encoder_attn.%i.out_lin.weight', 'encoder_attn.%i.out_lin.bias'
]

# State-dict key templates ("%i" = layer index) of one transformer layer:
# self-attention, first layer norm, feed-forward block, second layer norm
# (the same per-layer attributes that get_parameters selects).
TRANSFORMER_LAYER_PARAMS = [
    'attentions.%i.q_lin.weight', 'attentions.%i.q_lin.bias',
    'attentions.%i.k_lin.weight', 'attentions.%i.k_lin.bias',
    'attentions.%i.v_lin.weight', 'attentions.%i.v_lin.bias',
    'attentions.%i.out_lin.weight', 'attentions.%i.out_lin.bias',
    'layer_norm1.%i.weight', 'layer_norm1.%i.bias',
    'ffns.%i.lin1.weight', 'ffns.%i.lin1.bias',
    'ffns.%i.lin2.weight', 'ffns.%i.lin2.bias',
    'layer_norm2.%i.weight', 'layer_norm2.%i.bias'
]


logger = getLogger()


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """
    Build an ``nn.Embedding`` whose weights are drawn from
    N(0, embedding_dim ** -0.5); the ``padding_idx`` row, if any, is zeroed.
    """
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(layer.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is None:
        return layer
    nn.init.constant_(layer.weight[padding_idx], 0)
    return layer


def Linear(in_features, out_features, bias=True):
    """
    Factory for the linear layers of the modules in this cell.

    Returns a plain ``nn.Linear`` with PyTorch's default initialisation; it is
    kept as a single place to hook a custom init if one is needed.
    """
    return nn.Linear(in_features, out_features, bias)


def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill ``out`` (an ``(n_pos, dim)`` tensor such as an embedding weight) with
    fixed sinusoidal position encodings and freeze it.

    Column j of position pos holds sin(pos / 10000 ** (2 * (j // 2) / dim))
    for even j and the cosine of the same angle for odd j.

    ``requires_grad`` is cleared *before* writing: an in-place slice write into
    a leaf tensor that requires grad (e.g. an ``nn.Parameter``) raises a
    RuntimeError in PyTorch.
    """
    positions = np.arange(n_pos)[:, None]
    j = np.arange(dim)[None, :]
    position_enc = positions / np.power(10000, 2 * (j // 2) / dim)
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()


def gelu(x):
    """
    Exact (erf-based) GELU activation: x * Phi(x), with Phi the standard
    normal CDF (not the tanh approximation).
    https://arxiv.org/abs/1606.08415
    https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
    https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/modeling.py
    """
    half_x = 0.5 * x
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def get_masks(slen, lengths, causal):
    """
    Build the padding mask and the attention mask.

    Returns ``(mask, attn_mask)``: ``mask[b, t]`` is True for the first
    ``lengths[b]`` positions; ``attn_mask`` is ``mask`` itself when not causal,
    otherwise a ``(bs, slen, slen)`` lower-triangular mask (query i sees keys
    j <= i).
    """
    assert lengths.max().item() <= slen
    bs = lengths.size(0)
    positions = torch.arange(slen, dtype=torch.long, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

    if not causal:
        attn_mask = mask
    else:
        # [0, i, j] = (j <= i), then one copy per batch element
        lower = positions.view(1, 1, slen) <= positions.view(1, slen, 1)
        attn_mask = lower.repeat(bs, 1, 1)

    # sanity check
    assert mask.size() == (bs, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class PredLayer(nn.Module):
    """
    Output projection plus loss: a full softmax (``F.cross_entropy``) or, when
    ``params.asm`` is set, an adaptive softmax
    (``nn.AdaptiveLogSoftmaxWithLoss``).
    """
    def __init__(self, params):
        super().__init__()
        self.asm = params.asm
        self.n_words = params.n_words
        self.pad_index = params.pad_index
        dim = params.emb_dim

        if params.asm is not False:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=params.n_words,
                cutoffs=params.asm_cutoffs,
                div_value=params.asm_div_value,
                head_bias=True,  # default is False
            )
        else:
            self.proj = Linear(dim, params.n_words, bias=True)

    def forward(self, x, y, get_scores=False):
        """
        Return ``(scores, loss)`` for hidden states ``x`` and targets ``y``
        (which must not contain padding).  In adaptive-softmax mode ``scores``
        are log-probabilities, or None unless ``get_scores`` is True.
        """
        assert (y == self.pad_index).sum().item() == 0

        if self.asm is not False:
            _, loss = self.proj(x, y)
            scores = self.proj.log_prob(x) if get_scores else None
            return scores, loss

        scores = self.proj(x).view(-1, self.n_words)
        loss = F.cross_entropy(scores, y, reduction='mean')
        return scores, loss

    def get_scores(self, x):
        """
        Scores for 2-D hidden states: logits, or log-probabilities with asm.
        """
        assert x.dim() == 2
        if self.asm:
            return self.proj.log_prob(x)
        return self.proj(x)


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional key/value caching
    for incremental decoding (self-attention or attention over `kv`).
    """

    # class-wide counter: every instance gets a unique `layer_id`, which is its
    # key in the decoding `cache` dict
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, dropout, tf_cls):
        """
        n_heads -- number of heads (must divide `dim`)
        dim     -- model dimension of queries / keys / values
        dropout -- dropout probability applied to the attention weights
        tf_cls  -- class of the owning model; any class other than
                   TransformerModel switches on the `ms2s` caching variant
        """
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.dim = dim
        self.n_heads = n_heads
        self.ms2s = (tf_cls != TransformerModel)
        self.dropout = dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = Linear(dim, dim)
        self.k_lin = Linear(dim, dim)
        self.v_lin = Linear(dim, dim)
        self.out_lin = Linear(dim, dim)

    def forward(self, input, mask, kv=None, cache=None, ms2s=False):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        input -- (bs, qlen, dim)
        mask  -- (bs, klen) padding mask, or (bs, qlen, klen) when 3-D; zero
                 entries are masked out
        kv    -- (bs, klen, dim) source states for cross-attention, or None
        cache -- optional dict for incremental decoding; must contain 'slen'
                 (length already decoded) when used with self-attention
        ms2s  -- unused; the variant is chosen by `self.ms2s` (set in __init__)
        Returns (bs, qlen, dim).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            # cross-attention: project the source once; later steps reuse the cache
            k = v = kv
            k = shape(self.k_lin(k))                                          # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(v))                                          # (bs, n_heads, qlen, dim_per_head)

        if cache is not None:
            if self.ms2s:
                # NOTE(review): in this variant the cache keeps all positions
                # except the newest one (k[:,:,:-1,:]), presumably because the
                # last input position is re-fed / re-computed at the next step
                # — confirm against the ms2s decoding loop.  For cross-attention
                # (kv given) this also drops the last *source* position while
                # klen above is still kv.size(1); verify the mask shapes agree.
                if self.layer_id in cache:
                    if kv is None:
                        k_, v_ = cache[self.layer_id]
                        k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                        # print(v_.size(2), v.size(2))
                        v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                        cache[self.layer_id] = (k[:,:,:-1,:], v[:,:,:-1,:])
                    else:
                        k, v = cache[self.layer_id]
                else: cache[self.layer_id] = (k[:,:,:-1,:], v[:,:,:-1,:])
                # NOTE(review): called on every cached forward pass; releasing
                # the CUDA allocator cache this often is likely costly.
                torch.cuda.empty_cache()
            else:
                # standard incremental cache: self-attention appends the new
                # keys/values, cross-attention reuses the cached projections
                if self.layer_id in cache:
                    if kv is None:
                        k_, v_ = cache[self.layer_id]
                        k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                        v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    else:
                        k, v = cache[self.layer_id]
                cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)                                       # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)               # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))                              # (bs, n_heads, qlen, klen)

        # softmax in float32 for stability, then cast back to the scores dtype
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        return self.out_lin(context)


class TransformerFFN(nn.Module):
    """
    Position-wise feed-forward block: lin1 -> activation (GELU or ReLU) ->
    lin2 -> dropout.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, dropout, gelu_activation):
        super().__init__()
        self.dropout = dropout
        self.lin1 = Linear(in_dim, dim_hidden)
        self.lin2 = Linear(dim_hidden, out_dim)
        if gelu_activation:
            self.act = gelu
        else:
            self.act = F.relu

    def forward(self, input):
        hidden = self.act(self.lin1(input))
        projected = self.lin2(hidden)
        return F.dropout(projected, p=self.dropout, training=self.training)


class TransformerModel(nn.Module):
    """
    Transformer encoder or decoder, optionally with an output (word prediction) layer.

    Forward passes go through `forward(mode, **kwargs)`, which dispatches to `fwd`
    or `predict`, so that distributed wrappers (which call `forward`) can reach
    both. `generate` and `generate_beam` are decoding helpers for decoders.
    """

    ATTRIBUTES = ['encoder', 'with_output', 'eos_index', 'pad_index', 'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value']

    def __init__(self, params, dico, is_encoder, with_output):
        """
        Transformer model (encoder or decoder).

        `params`      : hyper-parameter namespace (n_langs, n_words, eos_index, pad_index,
                        id2lang, lang2id, emb_dim, n_heads, n_layers, dropout, ...)
        `dico`        : vocabulary; `len(dico)` must equal `params.n_words`
        `is_encoder`  : True for an encoder, False for a decoder (adds encoder attention)
        `with_output` : whether to build the output prediction layer
        """
        super().__init__()

        # encoder / decoder, output layer
        self.is_encoder = is_encoder
        self.is_decoder = not is_encoder
        self.with_output = with_output

        # dictionary / languages
        self.n_langs = params.n_langs
        self.n_words = params.n_words
        self.eos_index = params.eos_index
        self.pad_index = params.pad_index
        self.dico = dico
        self.id2lang = params.id2lang
        self.lang2id = params.lang2id
        self.use_lang_emb = getattr(params, 'use_lang_emb', True)
        assert len(self.dico) == self.n_words
        assert len(self.id2lang) == len(self.lang2id) == self.n_langs

        # model parameters
        self.dim = params.emb_dim       # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
        self.n_heads = params.n_heads   # 8 by default
        if is_encoder:
            # the encoder depth may differ from the decoder's; fall back to n_layers
            self.n_layers = getattr(params, 'n_enc_layers', params.n_layers)
        else:
            self.n_layers = params.n_layers
        self.dropout = params.dropout
        self.attention_dropout = params.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
        self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim)
        if params.sinusoidal_embeddings:
            create_sinusoidal_embeddings(N_MAX_POSITIONS, self.dim, out=self.position_embeddings.weight)
        if params.n_langs > 1 and self.use_lang_emb:
            self.lang_embeddings = Embedding(self.n_langs, self.dim)
        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)

        # transformer layers
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
        if self.is_decoder:
            self.layer_norm15 = nn.ModuleList()
            self.encoder_attn = nn.ModuleList()

        # memories
        self.memories = nn.ModuleDict()
        if getattr(params, 'use_memory', False):
            mem_positions = params.mem_enc_positions if is_encoder else params.mem_dec_positions
            for layer_id, pos in mem_positions:
                # validate against this model's own depth (an encoder may use n_enc_layers)
                assert 0 <= layer_id <= self.n_layers - 1
                assert pos in ['in', 'after']
                self.memories['%i_%s' % (layer_id, pos)] = HashingMemory.build(self.dim, self.dim, params)

        for layer_id in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout, tf_cls=self.__class__))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
            if self.is_decoder:
                self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
                self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout, tf_cls=self.__class__))
            if ('%i_in' % layer_id) in self.memories:
                # an input-side memory replaces the FFN of this layer (see `fwd`)
                self.ffns.append(None)
            else:
                self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

        # output layer
        if self.with_output:
            self.pred_layer = PredLayer(params)
            if params.share_inout_emb:
                # tie the output projection to the input word embeddings
                self.pred_layer.proj.weight = self.embeddings.weight

    def forward(self, mode, **kwargs):
        """
        Forward function with different forward modes ('fwd' or 'predict').
        ### Small hack to handle PyTorch distributed.

        Raises:
            Exception: for any other `mode`.
        """
        if mode == 'fwd':
            return self.fwd(**kwargs)
        elif mode == 'predict':
            return self.predict(**kwargs)
        else:
            raise Exception("Unknown mode: %s" % mode)

    def fwd(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
        """
        Inputs:
            `x` LongTensor(slen, bs), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
            `causal` Boolean, if True, the attention is only done over previous hidden states
            `src_enc` batch-first encoder output (`src_enc.size(0) == bs`), decoder only
            `src_len` LongTensor(bs), source lengths; given if and only if `src_enc` is
            `positions` LongTensor(slen, bs), containing word positions
            `langs` LongTensor(slen, bs), containing language IDs
            `cache` dict with key 'slen' (number of positions already processed) plus
                per-layer key/value states; when given, only the last
                `slen - cache['slen']` positions are computed and `cache['slen']` is advanced
        Output:
            FloatTensor(n, bs, dim), n == slen (or the number of new positions with `cache`)
        """
        # check inputs
        slen, bs = x.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        x = x.transpose(0, 1)  # batch size as dimension 0
        assert (src_enc is None) == (src_len is None)
        if src_enc is not None:
            assert self.is_decoder
            assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, causal)
        if self.is_decoder and src_enc is not None:
            src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # positions
        if positions is None:
            positions = x.new(slen).long()
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
            assert positions.size() == (slen, bs)
            positions = positions.transpose(0, 1)

        # langs
        if langs is not None:
            assert langs.size() == (slen, bs)
            langs = langs.transpose(0, 1)

        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
            x = x[:, -_slen:]
            positions = positions[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(x)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        for i in range(self.n_layers):

            # self attention
            attn = self.attentions[i](tensor, attn_mask, cache=cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
            if self.is_decoder and src_enc is not None:
                attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
                attn = F.dropout(attn, p=self.dropout, training=self.training)
                tensor = tensor + attn
                tensor = self.layer_norm15[i](tensor)

            # FFN (or the memory that replaces it on this layer)
            if ('%i_in' % i) in self.memories:
                tensor = tensor + self.memories['%i_in' % i](tensor)
            else:
                tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)

            # memory
            if ('%i_after' % i) in self.memories:
                tensor = tensor + self.memories['%i_after' % i](tensor)
            # TODO: add extra layer norm here?

            # zero out padded positions
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        # move back sequence length to dimension 0
        tensor = tensor.transpose(0, 1)

        return tensor

    def predict(self, tensor, pred_mask, y, get_scores):
        """
        Given the last hidden state, compute word scores and/or the loss.
            `pred_mask` is a ByteTensor/BoolTensor of shape (slen, bs), filled with 1 when
                we need to predict a word
            `y` is a LongTensor of shape (pred_mask.sum(),)
            `get_scores` is a boolean specifying whether we need to return scores
        Returns `(scores, loss)` as produced by the prediction layer.
        """
        # boolean mask indexing (uint8 mask indexing is deprecated in PyTorch)
        selection = pred_mask.unsqueeze(-1).expand_as(tensor).bool()
        masked_tensor = tensor[selection].view(-1, self.dim)
        scores, loss = self.pred_layer(masked_tensor, y, get_scores)
        return scores, loss

    def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None, vocab_mask=None):
        """
        Decode sentences token by token (greedy or sampled), conditioned on the source.

        `src_enc`            : batch-first encoder output, `src_enc.size(0) == bs`
        `src_len`            : LongTensor(bs), source lengths
        `tgt_lang_id`        : language ID used at every output position
        `max_len`            : maximum output length, including the leading <EOS> (used as <BOS>)
        `sample_temperature` : None for greedy decoding, otherwise the softmax temperature
                               used for multinomial sampling
        `vocab_mask`         : optional 1-D tensor over the vocabulary; words whose entry
                               is 0 get a score of -inf and are never generated

        Returns `(generated, gen_len)`:
            `generated` LongTensor(cur_len, bs), columns look like <EOS> W1 ... Wn <EOS> <PAD> ...
            `gen_len`   LongTensor(bs), sentence lengths including both <EOS>
        """

        if vocab_mask is not None:
            assert len(vocab_mask.size()) == 1
            # boolean "forbidden word" mask of shape (1, n_words); `1 - mask` fails on bool
            # tensors and turns integer masks into index lists
            vocab_mask = (vocab_mask == 0).unsqueeze(0)

        # input batch
        bs = len(src_len)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_len.new(max_len, bs)  # upcoming output
        generated.fill_(self.pad_index)       # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)    # we use <EOS> for <BOS> everywhere

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)

        # language IDs
        langs = src_len.new(max_len).long().fill_(tgt_lang_id)
        langs = langs.unsqueeze(1).expand(max_len, bs)

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # cache compute states
        cache = {'slen': 0}

        while cur_len < max_len:

            # compute word scores (only the newest position is computed thanks to the cache)
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs, self.dim), (cur_len, max_len, src_enc.size(), tensor.size(), (1, bs, self.dim))
            tensor = tensor.data[-1, :, :].type_as(src_enc)  # (bs, dim)
            scores = self.pred_layer.get_scores(tensor)      # (bs, n_words)

            if vocab_mask is not None:
                scores[vocab_mask.expand_as(scores)] = -float('inf')
            # select next words: sample or greedy
            if sample_temperature is None:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
            assert next_words.size() == (bs,)

            # update generations / lengths / finished sentences / current length
            generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences (masked_fill_ requires a boolean mask)
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == 2 * bs

        return generated[:cur_len], gen_len

    def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200, vocab_mask=None):
        """
        Decode sentences with beam search, conditioned on the source.

        `src_enc`        : batch-first encoder output, `src_enc.size(0) == bs`
        `src_len`        : LongTensor(bs), source lengths
        `tgt_lang_id`    : language ID used at every output position
        `beam_size`      : number of hypotheses kept per sentence (>= 1)
        `length_penalty` : exponent applied to the hypothesis length when scoring finished hypotheses
        `early_stopping` : stop a sentence as soon as `beam_size` hypotheses are finished
        `max_len`        : maximum output length, including the leading <EOS> (used as <BOS>)
        `vocab_mask`     : optional 1-D tensor over the vocabulary; words whose entry
                           is 0 get a score of -inf and are never generated

        Returns `(decoded, tgt_len)`:
            `decoded` LongTensor(max(tgt_len), bs), best hypothesis per sentence, ending with <EOS>,
                      padded with <PAD>
            `tgt_len` LongTensor(bs), lengths including both <EOS>
        """

        if vocab_mask is not None:
            assert len(vocab_mask.size()) == 1
            # boolean "forbidden word" mask of shape (1, n_words)
            vocab_mask = (vocab_mask == 0).unsqueeze(0)

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)                   # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)                # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # language IDs
        langs = positions.clone().fill_(tgt_lang_id)

        # scores for each sentence in the beam; only the first beam is live at step 1 so that
        # the identical initial beams do not produce duplicate candidates
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :]               # (bs * beam_size, dim)
            scores = self.pred_layer.get_scores(tensor)  # (bs * beam_size, n_words)
            if vocab_mask is not None:
                scores[vocab_mask.expand_as(scores)] = -float('inf')
            scores = F.log_softmax(scores, dim=-1)       # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)

            # 2 * beam_size candidates guarantee at least beam_size non-<EOS> continuations
            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence (plain Python numbers: one device sync per sentence)
                for idx, value in zip(next_words[sent_id].tolist(), next_scores[sent_id].tolist()):

                    # get beam and word IDs from the flattened (beam_size * n_words) index
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value)
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content: empty only at the last step, where every candidate
                # was moved to generated_hyps (parenthesized so the check actually runs)
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'slen':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * bs

        return decoded, tgt_len


class BeamHypotheses(object):
    """
    Bounded n-best list of finished beam-search hypotheses.

    Entries are `(score, hyp)` pairs where `score` is the summed log-probability
    divided by `len(hyp) ** length_penalty`.
    """

    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.

        `n_hyp`          : capacity of the list
        `max_len`        : maximum generation length, including the leading <BOS>
        `length_penalty` : exponent applied to the hypothesis length when scoring
        `early_stopping` : if True, the sentence is done once `n_hyp` hypotheses are stored
        """
        self.max_len = max_len - 1  # ignoring <BOS>
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        # sentinel value, replaced by the first accepted hypothesis
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses currently stored.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Insert `hyp` if the list is not full or it beats the current worst entry;
        when the list overflows, the lowest-scoring entry is evicted.
        """
        normalized = sum_logprobs / len(hyp) ** self.length_penalty
        has_room = len(self) < self.n_hyp
        if not (has_room or normalized > self.worst_score):
            return
        self.hyp.append((normalized, hyp))
        if len(self) <= self.n_hyp:
            self.worst_score = min(normalized, self.worst_score)
            return
        # over capacity: drop the lowest entry; the runner-up becomes the new floor
        ranked = sorted((entry_score, position) for position, (entry_score, _) in enumerate(self.hyp))
        del self.hyp[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs):
        """
        True when the list is full and either early stopping is on, or even the best
        running hypothesis, normalized at the maximum length, cannot beat the worst entry.
        """
        if len(self) < self.n_hyp:
            return False
        if self.early_stopping:
            return True
        best_reachable = best_sum_logprobs / self.max_len ** self.length_penalty
        return self.worst_score >= best_reachable

In [None]:
from logging import getLogger
import io
import numpy as np
import torch


logger = getLogger()


def load_fasttext_model(path):
    """
    Load a binarized fastText model.

    Tries the current `fasttext` Python module first, then the legacy `fastText`
    module name used by older releases of the bindings.

    Raises:
        Exception: if neither module can be imported (the ImportError is chained).
    """
    try:
        import fasttext
    except ImportError:
        try:
            import fastText as fasttext
        except ImportError as e:
            raise Exception("Unable to import fastText. Please install fastText for Python: "
                            "https://github.com/facebookresearch/fastText") from e
    return fasttext.load_model(path)


def read_txt_embeddings(path, params):
    """
    Reload pretrained embeddings from a text file.

    The first line is a header "<count> <dim>" whose dim must equal `params.emb_dim`;
    every other line is "<word> <v1> ... <v_dim>". Duplicate words, lines without a
    vector (including blank lines) and vectors of the wrong dimension are skipped
    with a warning.

    Returns `(word2id, embeddings)`: a dict word -> row index and a float32 tensor
    of shape (len(word2id), params.emb_dim).
    """
    word2id = {}
    vectors = []

    # load pretrained embeddings
    _emb_dim_file = params.emb_dim
    with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        for i, line in enumerate(f):
            if i == 0:
                # header: "<count> <dim>"
                split = line.split()
                assert len(split) == 2
                assert _emb_dim_file == int(split[1])
                continue
            parts = line.rstrip().split(' ', 1)
            if len(parts) != 2:
                # a word with no vector values (or an empty line) would crash the unpacking
                logger.warning("Missing vector in line %i, skipping." % i)
                continue
            word, vect = parts
            vect = np.fromstring(vect, sep=' ')
            if word in word2id:
                logger.warning("Word \"%s\" found twice!" % word)
                continue
            if not vect.shape == (_emb_dim_file,):
                logger.warning("Invalid dimension (%i) for word \"%s\" in line %i."
                               % (vect.shape[0], word, i))
                continue
            word2id[word] = len(word2id)
            vectors.append(vect[None])

    assert len(word2id) == len(vectors)
    logger.info("Loaded %i pretrained word embeddings from %s" % (len(vectors), path))

    # compute new vocabulary / embeddings
    embeddings = np.concatenate(vectors, 0)
    embeddings = torch.from_numpy(embeddings).float()

    assert embeddings.size() == (len(word2id), params.emb_dim)
    return word2id, embeddings


def load_bin_embeddings(path, params):
    """
    Reload pretrained embeddings from a fastText binary file.

    Returns `(word2id, embeddings)` where `word2id` maps each entry of
    `model.get_labels()` to its row in `embeddings`, a float32 tensor of shape
    (number of words, params.emb_dim).
    """
    ft_model = load_fasttext_model(path)
    assert ft_model.get_dimension() == params.emb_dim
    vocab = ft_model.get_labels()
    logger.info("Loaded binary model from %s" % path)

    # one row per word, in vocabulary order
    rows = [ft_model.get_word_vector(word)[None] for word in vocab]
    embeddings = torch.from_numpy(np.concatenate(rows, 0)).float()
    word2id = dict(zip(vocab, range(len(vocab))))
    logger.info("Generated embeddings for %i words." % len(vocab))

    assert embeddings.size() == (len(word2id), params.emb_dim)
    return word2id, embeddings


def load_embeddings(path, params):
    """
    Reload pretrained embeddings.

    Paths ending in '.bin' are read as fastText binary models; anything else is
    parsed as a text embedding file. Returns `(word2id, embeddings)`.
    """
    loader = load_bin_embeddings if path.endswith('.bin') else read_txt_embeddings
    return loader(path, params)

In [None]:
from logging import getLogger
import torch

from torch import nn



logger = getLogger()


class SentenceEmbedder(object):
    """
    Wrapper around a pretrained Transformer used as a sentence encoder.

    `get_embeddings` returns the last-layer hidden state at the first position of
    each sentence; `get_parameters` selects which layers to fine-tune.
    """

    @staticmethod
    def reload(path, params, cls_name=TransformerModel):
        """
        Create a sentence embedder from a pretrained model.

        `path`     : checkpoint with keys 'model', 'dico_id2word', 'dico_word2id',
                     'dico_counts' and 'params'
        `params`   : current run parameters; `use_task_emb` is read, and
                     `max_batch_size` is set to 0 on it
        `cls_name` : model class to instantiate
        """
        # reload model; mapping to CPU lets GPU-saved checkpoints load on CPU-only hosts
        # (call `.cuda()` on the embedder afterwards when needed).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        reloaded = torch.load(path, map_location='cpu')
        state_dict = reloaded['model']

        # handle models from multi-GPU checkpoints (strip the 'module.' wrapper prefix)
        if 'checkpoint' in path:
            state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

        # reload dictionary and model parameters
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        pretrain_params = AttrDict(reloaded['params'])
        pretrain_params.n_words = len(dico)
        pretrain_params.bos_index = dico.index(BOS_WORD)
        pretrain_params.eos_index = dico.index(EOS_WORD)
        pretrain_params.pad_index = dico.index(PAD_WORD)
        pretrain_params.unk_index = dico.index(UNK_WORD)
        pretrain_params.mask_index = dico.index(MASK_WORD)

        # TODO config n layers to load

        # build model and reload weights
        # NOTE(review): five positional arguments are passed; the TransformerModel class in
        # this notebook takes only (params, dico, is_encoder, with_output), so `cls_name`
        # must be a class that also accepts `use_task_emb` -- confirm against the caller.
        model = cls_name(pretrain_params, dico, True, True, params.use_task_emb)
        # task embeddings are not included in the Facebook XLM15 checkpoint, hence strict=False
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        # adding missing parameters
        params.max_batch_size = 0

        return SentenceEmbedder(model, dico, pretrain_params)

    def __init__(self, model, dico, pretrain_params):
        """
        Wrapper on top of the different sentence embedders.
        Returns sequence-wise or single-vector sentence representations.
        """
        self.pretrain_params = dict(pretrain_params.__dict__)
        self.model = model
        self.dico = dico
        self.n_layers = model.n_layers
        self.out_dim = model.dim
        self.n_words = model.n_words

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def cuda(self):
        self.model.cuda()

    def parallel(self, params):
        """
        Wrap the model in DistributedDataParallel on device `params.local_rank`.
        """
        self.model = nn.parallel.DistributedDataParallel(
            self.model, device_ids=[params.local_rank],
            output_device=params.local_rank, broadcast_buffers=False)

    def get_parameters(self, params):
        """
        Collect the parameters to fine-tune, as selected by `params.finetune_layers`.

        `params.finetune_layers` is "i:j" ('_' stands for a minus sign, e.g. "0:_1");
        negative values count from the end (-1 == n_layers). Index 0 stands for the
        embeddings and index k >= 1 for Transformer layer k; the range is inclusive.
        Embedding groups can be excluded with the `fixed_*embeddings` flags.
        Returns an empty list when i > j.
        """
        layer_range = params.finetune_layers

        s = layer_range.split(':')
        assert len(s) == 2
        i, j = int(s[0].replace('_', '-')), int(s[1].replace('_', '-'))

        # negative indexing
        i = self.n_layers + i + 1 if i < 0 else i
        j = self.n_layers + j + 1 if j < 0 else j

        # sanity check
        assert 0 <= i <= self.n_layers
        assert 0 <= j <= self.n_layers

        if i > j:
            return []

        # reach the submodules even after `parallel()` wrapped the model
        model = getattr(self.model, 'module', self.model)

        parameters = []

        # embeddings
        if i == 0:
            # embeddings
            if not params.fixed_embeddings:
                parameters += model.embeddings.parameters()
                logger.info("Adding embedding parameters to optimizer")
            # positional embeddings
            if self.pretrain_params['sinusoidal_embeddings'] is False \
                    and not params.fixed_position_embeddings:
                parameters += model.position_embeddings.parameters()
                logger.info("Adding positional embedding parameters to optimizer")
            # language embeddings
            if hasattr(model, 'lang_embeddings') and \
                    not params.fixed_lang_embeddings:
                parameters += model.lang_embeddings.parameters()
                logger.info("Adding language embedding parameters to optimizer")
            # task embeddings
            if hasattr(model, "task_embeddings") and \
                    not params.fixed_task_embeddings:
                parameters += model.task_embeddings.parameters()
                logger.info("Adding task embedding parameters to optimizer")
            parameters += model.layer_norm_emb.parameters()

        # layers (0-based indices of Transformer layers i..j)
        for layer in range(max(i - 1, 0), j):
            parameters += model.attentions[layer].parameters()
            parameters += model.layer_norm1[layer].parameters()
            ffn = model.ffns[layer]
            # None when an input-side memory replaces the FFN; memory parameters are not collected here
            if ffn is not None:
                parameters += ffn.parameters()
            parameters += model.layer_norm2[layer].parameters()
            logger.info("Adding layer-%s parameters to optimizer" % (layer + 1))

        logger.info("Optimizing on %i Transformer elements." % sum([p.nelement() for p in parameters]))

        return parameters

    def get_embeddings(self, x, lengths, positions=None, langs=None):
        """
        Inputs:
            `x`        : LongTensor of shape (slen, bs)
            `lengths`  : LongTensor of shape (bs,)
        Outputs:
            `sent_emb` : FloatTensor of shape (bs, out_dim)
        With out_dim == emb_dim
        """
        slen, bs = x.size()
        assert lengths.size(0) == bs and lengths.max().item() == slen

        # get transformer last hidden layer
        tensor = self.model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False)
        assert tensor.size() == (slen, bs, self.out_dim)

        # single-vector sentence representation (first column of last layer)
        return tensor[0]

In [None]:
import sys
# Debug aid: show the kernel's command line. The output below holds the ipykernel
# launcher's `-f <connection file>` arguments rather than script flags.
sys.argv

['/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py',
 '-f',
 '/root/.local/share/jupyter/runtime/kernel-44b48dc6-ae97-40e9-b586-c786c1e953bd.json']

In [None]:
%%writefile xnlg-ft.py

import os
import io
import argparse
import torch
import copy
import sys

import nltk
# Fetch NLTK's 'punkt' tokenizer resource each time this script is run;
# nltk.download may need network access.
nltk.download('punkt')




def get_params(args=None):
  """
  Build the XNLG / XPersona fine-tuning argument parser and parse the flags.

  Args:
    args: optional list of argument strings to parse instead of sys.argv[1:].
      Useful when calling from a notebook, whose argv holds kernel-launcher flags.

  Returns:
    argparse.Namespace with one attribute per option declared below.
  """

  def bool_flag(s):
    """Parse a boolean flag value: off/false/0 or on/true/1 (case-insensitive)."""
    # Defined locally: this file is written standalone by %%writefile, so a
    # bool_flag helper living in the notebook kernel is not visible here
    # (referencing it raised NameError).
    value = s.lower()
    if value in ('off', 'false', '0'):
      return False
    if value in ('on', 'true', '1'):
      return True
    raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")

  # parse parameters
  parser = argparse.ArgumentParser(description='Train on XNLG')

  # main parameters
  parser.add_argument("--exp_name", type=str, default="",
                      help="Experiment name")
  parser.add_argument("--dump_path", type=str, default="",
                      help="Experiment dump path")
  parser.add_argument("--exp_id", type=str, default="",
                      help="Experiment ID")

  parser.add_argument("--model_path", type=str, default="",
                      help="Model location")

  # data
  parser.add_argument("--data_path", type=str, default="",
                      help="Data path")
  parser.add_argument("--ds_name", type=str, default="xpersona",
                      help="name of dataset: xsumm or xgiga")
  parser.add_argument("--max_vocab", type=int, default=-1,
                      help="Maximum vocabulary size (-1 to disable)")
  parser.add_argument("--min_count", type=int, default=0,
                      help="Minimum vocabulary count")

  # batch parameters
  parser.add_argument("--max_len", type=int, default=256,
                      help="Maximum length of sentences (after BPE)")
  parser.add_argument("--max_len_q", type=int, default=256,
                      help="Maximum length of sentences (after BPE)")
  parser.add_argument("--max_len_a", type=int, default=256,
                      help="Maximum length of sentences (after BPE)")
  parser.add_argument("--max_len_e", type=int, default=256,
                      help="Maximum length of sentences (after BPE)")
  parser.add_argument("--group_by_size", type=bool_flag, default=False,
                      help="Sort sentences by size during the training")
  parser.add_argument("--batch_size", type=int, default=32,
                      help="Number of sentences per batch")
  parser.add_argument("--max_batch_size", type=int, default=0,
                      help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)")
  parser.add_argument("--tokens_per_batch", type=int, default=-1,
                      help="Number of tokens per batch")

  # model / optimization
  parser.add_argument("--finetune_layers", type=str, default='0:_1',
                      help="Layers to finetune. 0 = embeddings, _1 = last encoder layer")
  parser.add_argument("--weighted_training", type=bool_flag, default=False,
                      help="Use a weighted loss during training")
  parser.add_argument("--dropout", type=float, default=0,
                      help="Fine-tuning dropout")
  parser.add_argument("--optimizer_e", type=str, default="adam,lr=0.0001",
                      help="Embedder (pretrained model) optimizer")
  parser.add_argument("--optimizer_p", type=str, default="adam,lr=0.0001",
                      help="Projection (classifier) optimizer")
  parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001",
                      help="Projection (classifier) optimizer")
  parser.add_argument("--n_epochs", type=int, default=100,
                      help="Maximum number of epochs")
  parser.add_argument("--epoch_size", type=int, default=-1,
                      help="Epoch size (-1 for full pass over the dataset)")

  # debug
  parser.add_argument("--debug_train", type=bool_flag, default=False,
                      help="Use valid sets for train sets (faster loading)")
  parser.add_argument("--debug_slurm", type=bool_flag, default=False,
                      help="Debug multi-GPU / multi-node within a SLURM job")
  parser.add_argument("--sample_alpha", type=float, default=0,
                      help="Exponent for transforming word counts to probabilities (~word2vec sampling)")
  parser.add_argument("--word_pred", type=float, default=0.15,
                      help="Fraction of words for which we need to make a prediction")

  parser.add_argument("--max_dec_len", type=int, default=80,
                      help="Maximum length of target sentence (after BPE)")

  # decode with vocab
  parser.add_argument("--decode_with_vocab", type=bool_flag, default=False,
                      help="Decode with vocab")
  parser.add_argument("--decode_vocab_sizes", type=str, default="26000,20000",
                      help="decode_vocab_sizes")
  parser.add_argument("--vocab_path", type=str, default="",
                      help="vocab_path")

  # multi-gpu
  parser.add_argument("--local_rank", type=int, default=-1,
                      help="Multi-GPU - Local rank")
  parser.add_argument("--multi_gpu", type=bool_flag, default=False,
                      help="multi-gpu")

  # encoder / decoder structure and frozen parts
  parser.add_argument("--train_layers", type=str, default="",
                      help="train layers of encoder")
  parser.add_argument("--n_enc_layers", type=int, default=0,
                      help="")
  parser.add_argument("--n_dec_layers", type=int, default=0,
                      help="")
  parser.add_argument("--fixed_embeddings", type=bool_flag, default=False,
                      help="fixed_embeddings")
  parser.add_argument("--fixed_position_embeddings", type=bool_flag, default=False,
                      help="fixed_position_embeddings")
  parser.add_argument("--fixed_lang_embeddings", type=bool_flag, default=False,
                      help="fixed_lang_embeddings")
  parser.add_argument("--fixed_task_embeddings", type=bool_flag, default=False,
                      help="fixed_task_embeddings")
  parser.add_argument("--beam_size", type=int, default=1,
                      help="")
  parser.add_argument("--no_init", type=str, default="None",
                      help="dont init with pretrained models")

  parser.add_argument("--train_directions", type=str, default="en-en",
                      help="")
  parser.add_argument("--eval_directions", type=str, default="",
                      help="")
  # -1 keeps the checkpoint's emb_dim (see the params.emb_dim != -1 check in run_xnlg)
  parser.add_argument("--emb_dim", type=int, default=-1,
                      help="Embedding dimension override (-1 to keep the pretrained model's value)")
  parser.add_argument("--reload_emb", type=str, default="",
                      help="path to .vec produced by fasttext")
  parser.add_argument("--cut_dataset", type=int, default=-1,
                      help="Number of sentences in dataset. -1 for full dataset.")

  parser.add_argument("--device1", type=int, default=0, help="device id for the encoder")
  parser.add_argument("--device2", type=int, default=0, help="device id for the decoder")

  params = parser.parse_args(args)

  return params


def read_txt_embeddings(logger, path):
  """
  Reload pretrained embeddings from a text file (fastText .vec layout).

  The first line is a header "<n_words> <emb_dim>"; each following line is
  "<word> <v_1> ... <v_emb_dim>". Duplicate words, lines with a wrong number
  of values, and blank / word-only lines are skipped with a warning.

  Args:
    logger: logger used for warnings and the final count.
    path: path to the text file (read as UTF-8, undecodable bytes ignored).

  Returns:
    (word2id, embeddings): word2id maps each kept word to its row index and
    embeddings is a float32 tensor of shape (len(word2id), emb_dim).

  Raises:
    AssertionError: if the header line does not have exactly two fields.
    ValueError: if no valid vector is found in the file.
  """
  import numpy as np
  word2id = {}
  vectors = []
  emb_dim = 0

  with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
    for i, line in enumerate(f):
      if i == 0:
        header = line.split()
        assert len(header) == 2
        emb_dim = int(header[1])
        continue
      parts = line.rstrip().split(' ', 1)
      if len(parts) != 2:
        # blank line (e.g. trailing newline) or a word without a vector:
        # unpacking would raise ValueError, so skip it instead
        logger.warning("Skipping malformed line %i." % i)
        continue
      word, vect = parts
      vect = np.fromstring(vect, sep=' ')
      if word in word2id:
        logger.warning("Word \"%s\" found twice!" % word)
        continue
      if vect.shape != (emb_dim,):
        logger.warning("Invalid dimension (%i) for word \"%s\" in line %i."
                       % (vect.shape[0], word, i))
        continue
      word2id[word] = len(word2id)
      vectors.append(vect[None])

  assert len(word2id) == len(vectors)
  logger.info("Loaded %i pretrained word embeddings from %s" % (len(vectors), path))
  if not vectors:
    # np.concatenate([]) would fail with an opaque message
    raise ValueError("No valid embeddings found in %s" % path)

  # stack into a single (n_words, emb_dim) float tensor
  embeddings = np.concatenate(vectors, 0)
  embeddings = torch.from_numpy(embeddings).float()

  return word2id, embeddings


def load_bin_embeddings(logger, path):
  """
  Reload pretrained embeddings from a fastText binary model.

  Args:
    logger: logger for progress messages.
    path: path to the fastText .bin model.

  Returns:
    (word2id, embeddings): word2id maps each vocabulary word to its row index
    and embeddings is a float32 tensor with one row per word.

  Raises:
    ValueError: if the model's vocabulary is empty.
  """
  import fasttext
  import numpy as np
  model = fasttext.load_model(path)
  # get_words() is the word vocabulary for every model type. get_labels() only
  # coincides with it for unsupervised models; for supervised models it returns
  # the class labels, which are not what we want to embed.
  words = model.get_words()
  logger.info("Loaded binary model from %s" % path)
  if len(words) == 0:
    raise ValueError("No words found in fastText model %s" % path)

  # compute new vocabulary / embeddings: one row per word
  embeddings = np.stack([model.get_word_vector(w) for w in words], 0)
  embeddings = torch.from_numpy(embeddings).float()
  word2id = {w: i for i, w in enumerate(words)}
  logger.info("Generated embeddings for %i words." % len(words))

  return word2id, embeddings


def set_pretrain_emb(logger, model, dico, word2id, embeddings):
  """
  Overwrite rows of a model's word embeddings with pretrained vectors.

  For every dictionary index i whose word dico[i] appears in word2id, row i of
  model.embeddings.weight (and of model.pred_layer.proj.weight, when the model
  has that output projection) is set to embeddings[word2id[dico[i]]].

  Args:
    logger: logger for the coverage summary.
    model: module exposing `embeddings.weight` and optionally `pred_layer.proj.weight`.
    dico: indexable vocabulary (len(dico) entries, dico[i] -> word).
    word2id: mapping word -> row in `embeddings`.
    embeddings: tensor of pretrained vectors, one row per word2id entry.
  """
  emb_weight = model.embeddings.weight
  # models built without an output layer have no pred_layer; resolve once
  proj = getattr(getattr(model, 'pred_layer', None), 'proj', None)
  proj_weight = getattr(proj, 'weight', None)

  n_found = 0
  with torch.no_grad():
    for i in range(len(dico)):
      idx = word2id.get(dico[i], None)
      if idx is None:
        continue
      n_found += 1
      vec = embeddings[idx]
      # copy onto whatever device the weights live on (no hard-coded .cuda(),
      # which failed on CPU-only hosts and when the model is still on CPU)
      emb_weight[i] = vec.to(emb_weight.device)
      if proj_weight is not None:
        proj_weight[i] = vec.to(proj_weight.device)

  n_total = len(dico)
  pct = 100. * n_found / n_total if n_total else 0.
  logger.info("Pretrained %i/%i words (%.3f%%)."
              % (n_found, n_total, pct))


def str_to_class(str):
  """Return the module-level object (typically a class) named `str` in this module."""
  this_module = sys.modules[__name__]
  found = getattr(this_module, str)
  return found


def _resolve_side_params(reloaded, key, n_layers, model_params):
  """
  Choose the hyper-parameters for one side (encoder or decoder) of the model.

  Args:
    reloaded: checkpoint dict returned by torch.load.
    key: "enc_params" or "dec_params"; used directly when present in `reloaded`.
    n_layers: requested number of layers (0 keeps the checkpoint's value).
    model_params: AttrDict built from reloaded['params'].

  Returns:
    AttrDict(reloaded[key]) if that entry exists; otherwise `model_params`
    itself when n_layers is 0 or equals model_params.n_layers; otherwise a new
    AttrDict built from reloaded['params'] with n_layers overridden.
  """
  if key in reloaded:
    return AttrDict(reloaded[key])
  if n_layers == model_params.n_layers or n_layers == 0:
    return model_params
  side_params = AttrDict(reloaded['params'])
  side_params.n_layers = n_layers
  # compare values: `is not` on ints tests object identity, not equality
  assert model_params.n_layers != side_params.n_layers
  return side_params


def run_xnlg():
  """
  Script entry point: parse flags, rebuild encoder/decoder from the checkpoint
  at --model_path, optionally initialise their weights, then hand them to
  XPersona(...).run().
  """
  params = get_params()

  # initialize the experiment / build sentence embedder
  logger = initialize_exp(params)

  # a token budget per batch forces size-grouped batching
  if params.tokens_per_batch > -1:
    params.group_by_size = True

  # check parameters
  assert os.path.isdir(params.data_path)
  assert os.path.isfile(params.model_path)

  reloaded = torch.load(params.model_path)
  model_params = AttrDict(reloaded['params'])
  logger.info(
    "Supported languages: %s" % ", ".join(model_params.lang2id.keys()))
  params.n_langs = model_params['n_langs']
  params.id2lang = model_params['id2lang']
  params.lang2id = model_params['lang2id']

  encoder_model_params = _resolve_side_params(reloaded, "enc_params", params.n_enc_layers, model_params)
  decoder_model_params = _resolve_side_params(reloaded, "dec_params", params.n_dec_layers, model_params)

  params.encoder_model_params = encoder_model_params
  params.decoder_model_params = decoder_model_params

  # --emb_dim (when not -1) overrides the checkpoint's embedding size on both sides
  if params.emb_dim != -1:
    encoder_model_params.emb_dim = params.emb_dim
    decoder_model_params.emb_dim = params.emb_dim

  # build dictionary / build encoder / build decoder / reload weights
  dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])

  for p in [params, encoder_model_params, decoder_model_params]:
    p.n_words = len(dico)
    p.bos_index = dico.index(BOS_WORD)
    p.eos_index = dico.index(EOS_WORD)
    p.pad_index = dico.index(PAD_WORD)
    p.unk_index = dico.index(UNK_WORD)
    p.mask_index = dico.index(MASK_WORD)

  encoder = TransformerModel(encoder_model_params, dico, is_encoder=True, with_output=False)
  decoder = TransformerModel(decoder_model_params, dico, is_encoder=False, with_output=True)

  def _process_state_dict(state_dict):
    # strip the "module." prefix that (Distributed)DataParallel wrappers add to keys
    return {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

  if params.no_init == "all":
    logger.info("All Models will not load state dict.!!!")
  elif params.reload_emb != "":
    # NOTE: in this branch no checkpoint weights are loaded, only the embeddings
    logger.info("Reloading embedding from %s ..." % params.reload_emb)
    word2id, embeddings = read_txt_embeddings(logger, params.reload_emb)
    set_pretrain_emb(logger, encoder, dico, word2id, embeddings)
    set_pretrain_emb(logger, decoder, dico, word2id, embeddings)
  else:
    if "model" in reloaded:
      # single shared checkpoint: both sides are initialised from the same weights
      if params.no_init != "encoder":
        encoder.load_state_dict(_process_state_dict(reloaded['model']), strict=False)
      if params.no_init != "decoder":
        decoder.load_state_dict(_process_state_dict(reloaded['model']), strict=False)
    else:
      if params.no_init != "encoder":
        encoder.load_state_dict(_process_state_dict(reloaded['encoder']), strict=False)
      if params.no_init != "decoder":
        decoder.load_state_dict(_process_state_dict(reloaded['decoder']))

  scores = {}

  XPersona(encoder, decoder, scores, dico, params).run()


if __name__ == "__main__":
  run_xnlg()


Overwriting xnlg-ft.py


NameError: ignored

In [None]:
# NOTE(review): this invocation omits --dump_path, --model_path and --data_path,
# which all default to "". run_xnlg asserts that data_path is a directory and
# model_path is a file, so this run cannot succeed even once the NameError is
# fixed. The next cell repeats the run with those paths supplied.
!python xnlg-ft.py --exp_name xpersona --exp_id ftOnZh --optimizer adam,lr=0.00001 --batch_size 8 --n_epochs 200 --epoch_size 3000 --max_len 120 --max_vocab 95000 --train_layers 1,10 --decode_with_vocab False --n_enc_layers 10 --n_dec_layers 6 --ds_name xpersona

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Traceback (most recent call last):
  File "xnlg-ft.py", line 315, in <module>
    run_xnlg()
  File "xnlg-ft.py", line 229, in run_xnlg
    params = get_params()
  File "xnlg-ft.py", line 50, in get_params
    parser.add_argument("--group_by_size", type=bool_flag, default=False,
NameError: name 'bool_flag' is not defined


In [None]:
# Fine-tune the pretrained en-zh XNLG checkpoint on English dialogues
# (--train_directions en-en) and evaluate on Chinese (--eval_directions zh-zh).
# --n_enc_layers / --n_dec_layers are only used when the checkpoint has no
# enc_params / dec_params entry.
!python xnlg-ft.py --exp_name xpersona --exp_id ftOnZh --dump_path ./dump --model_path ./data/pretrained_XNLG/en-zh_valid-en-zh.pth --data_path ./data/processed/XNLG --optimizer adam,lr=0.00001 --batch_size 8 --n_epochs 200 --epoch_size 3000 --max_len 120 --max_vocab 95000 --train_layers 1,10 --decode_with_vocab False --n_enc_layers 10 --n_dec_layers 6 --ds_name xpersona --train_directions en-en --eval_directions zh-zh

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Traceback (most recent call last):
  File "xnlg-ft.py", line 315, in <module>
    run_xnlg()
  File "xnlg-ft.py", line 229, in run_xnlg
    params = get_params()
  File "xnlg-ft.py", line 50, in get_params
    parser.add_argument("--group_by_size", type=bool_flag, default=False,
NameError: name 'bool_flag' is not defined


In [None]:
%%writefile test.py
# NOTE(review): %%writefile produces a standalone script; it cannot see globals
# defined in the notebook kernel. Names used below such as initialize_exp,
# AttrDict, Dictionary, TransformerModel, XPersona and the BOS_WORD / EOS_WORD /
# PAD_WORD / UNK_WORD / MASK_WORD constants are not imported here and must be
# imported explicitly (from the project package that defines them) for this
# script to run.
import os
import io
import argparse
import torch
import copy
import sys

import nltk
# Fetches nltk's 'punkt' data at import time on every run (a no-op when the
# package is already up to date, as the run logs show).
nltk.download('punkt')




def get_params(args=None):
    """
    Build the XPersona evaluation argument parser and parse the flags.

    Args:
        args: optional list of argument strings to parse instead of sys.argv[1:].
            Useful when calling from a notebook, whose argv holds kernel-launcher flags.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """

    def bool_flag(s):
        """Parse a boolean flag value: off/false/0 or on/true/1 (case-insensitive)."""
        # Defined locally: this file is written standalone by %%writefile, so a
        # bool_flag helper living in the notebook kernel is not visible here
        # (referencing it raised NameError).
        value = s.lower()
        if value in ('off', 'false', '0'):
            return False
        if value in ('on', 'true', '1'):
            return True
        raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")

    # parse parameters
    parser = argparse.ArgumentParser(description='Train on XNLG')

    # main parameters
    parser.add_argument("--exp_name", type=str, default="",
                        help="Experiment name")
    parser.add_argument("--dump_path", type=str, default="",
                        help="Experiment dump path")
    parser.add_argument("--exp_id", type=str, default="",
                        help="Experiment ID")

    parser.add_argument("--model_path", type=str, default="",
                        help="Model location")
    parser.add_argument("--saved_path", type=str, default="",
                        help="saved location")

    # data
    parser.add_argument("--data_path", type=str, default="",
                        help="Data path")
    parser.add_argument("--ds_name", type=str, default="xpersona",
                        help="name of dataset: xsumm or xgiga")
    parser.add_argument("--max_vocab", type=int, default=-1,
                        help="Maximum vocabulary size (-1 to disable)")
    parser.add_argument("--min_count", type=int, default=0,
                        help="Minimum vocabulary count")

    # batch parameters
    parser.add_argument("--max_len", type=int, default=256,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--max_len_q", type=int, default=256,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--max_len_a", type=int, default=256,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--max_len_e", type=int, default=256,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--group_by_size", type=bool_flag, default=False,
                        help="Sort sentences by size during the training")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Number of sentences per batch")
    parser.add_argument("--max_batch_size", type=int, default=0,
                        help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)")
    parser.add_argument("--tokens_per_batch", type=int, default=-1,
                        help="Number of tokens per batch")

    # model / optimization
    parser.add_argument("--finetune_layers", type=str, default='0:_1',
                        help="Layers to finetune. 0 = embeddings, _1 = last encoder layer")
    parser.add_argument("--weighted_training", type=bool_flag, default=False,
                        help="Use a weighted loss during training")
    parser.add_argument("--dropout", type=float, default=0,
                        help="Fine-tuning dropout")
    parser.add_argument("--optimizer_e", type=str, default="adam,lr=0.0001",
                        help="Embedder (pretrained model) optimizer")
    parser.add_argument("--optimizer_p", type=str, default="adam,lr=0.0001",
                        help="Projection (classifier) optimizer")
    parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001",
                        help="Projection (classifier) optimizer")
    parser.add_argument("--n_epochs", type=int, default=100,
                        help="Maximum number of epochs")
    parser.add_argument("--epoch_size", type=int, default=-1,
                        help="Epoch size (-1 for full pass over the dataset)")

    # debug
    parser.add_argument("--debug_train", type=bool_flag, default=False,
                        help="Use valid sets for train sets (faster loading)")
    parser.add_argument("--debug_slurm", type=bool_flag, default=False,
                        help="Debug multi-GPU / multi-node within a SLURM job")
    parser.add_argument("--sample_alpha", type=float, default=0,
                        help="Exponent for transforming word counts to probabilities (~word2vec sampling)")
    parser.add_argument("--word_pred", type=float, default=0.15,
                        help="Fraction of words for which we need to make a prediction")

    parser.add_argument("--max_dec_len", type=int, default=80,
                        help="Maximum length of target sentence (after BPE)")

    # decode with vocab
    parser.add_argument("--decode_with_vocab", type=bool_flag, default=False,
                        help="Decode with vocab")
    parser.add_argument("--decode_vocab_sizes", type=str, default="26000,20000",
                        help="decode_vocab_sizes")
    parser.add_argument("--vocab_path", type=str, default="",
                        help="vocab_path")

    # multi-gpu
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Multi-GPU - Local rank")
    parser.add_argument("--multi_gpu", type=bool_flag, default=False,
                        help="multi-gpu")

    # encoder / decoder structure and frozen parts
    parser.add_argument("--train_layers", type=str, default="",
                        help="train layers of encoder")
    parser.add_argument("--n_enc_layers", type=int, default=0,
                        help="")
    parser.add_argument("--n_dec_layers", type=int, default=0,
                        help="")
    parser.add_argument("--fixed_embeddings", type=bool_flag, default=False,
                        help="fixed_embeddings")
    parser.add_argument("--fixed_position_embeddings", type=bool_flag, default=False,
                        help="fixed_position_embeddings")
    parser.add_argument("--fixed_lang_embeddings", type=bool_flag, default=False,
                        help="fixed_lang_embeddings")
    parser.add_argument("--fixed_task_embeddings", type=bool_flag, default=False,
                        help="fixed_task_embeddings")
    parser.add_argument("--beam_size", type=int, default=1,
                        help="")
    parser.add_argument("--no_init", type=str, default="None",
                        help="dont init with pretrained models")

    parser.add_argument("--train_directions", type=str, default="en-en",
                        help="")
    parser.add_argument("--eval_directions", type=str, default="",
                        help="")
    parser.add_argument("--emb_dim", type=int, default=-1,
                        help="Embedding dimension override (-1 to disable)")
    parser.add_argument("--reload_emb", type=str, default="",
                        help="path to .vec produced by fasttext")
    parser.add_argument("--cut_dataset", type=int, default=-1,
                        help="Number of sentences in dataset. -1 for full dataset.")

    parser.add_argument("--device1", type=int, default=3, help="device id for the encoder")
    parser.add_argument("--device2", type=int, default=4, help="device id for the decoder")

    params = parser.parse_args(args)

    return params


def _resolve_side_params(reloaded, key, n_layers, model_params):
    """
    Choose the hyper-parameters for one side (encoder or decoder) of the model.

    Args:
        reloaded: checkpoint dict returned by torch.load.
        key: "enc_params" or "dec_params"; used directly when present in `reloaded`.
        n_layers: requested number of layers (0 keeps the checkpoint's value).
        model_params: AttrDict built from reloaded['params'].

    Returns:
        AttrDict(reloaded[key]) if that entry exists; otherwise `model_params`
        itself when n_layers is 0 or equals model_params.n_layers; otherwise a
        new AttrDict built from reloaded['params'] with n_layers overridden.
    """
    if key in reloaded:
        return AttrDict(reloaded[key])
    if n_layers == model_params.n_layers or n_layers == 0:
        return model_params
    side_params = AttrDict(reloaded['params'])
    side_params.n_layers = n_layers
    # compare values: `is not` on ints tests object identity, not equality
    assert model_params.n_layers != side_params.n_layers
    return side_params


def run_test():
    """
    Script entry point: parse flags, rebuild encoder/decoder from the
    fine-tuned checkpoint at --saved_path (loaded onto CPU) and run
    XPersona(...).test() on them.
    """
    params = get_params()

    # initialize the experiment / build sentence embedder
    logger = initialize_exp(params)

    # a token budget per batch forces size-grouped batching
    if params.tokens_per_batch > -1:
        params.group_by_size = True

    # check parameters
    assert os.path.isdir(params.data_path)
    assert os.path.isfile(params.saved_path)
    device = torch.device('cpu')
    reloaded = torch.load(params.saved_path, map_location=device)
    model_params = AttrDict(reloaded['params'])
    logger.info(
        "Supported languages: %s" % ", ".join(model_params.lang2id.keys()))
    params.n_langs = model_params['n_langs']
    params.id2lang = model_params['id2lang']
    params.lang2id = model_params['lang2id']

    encoder_model_params = _resolve_side_params(reloaded, "enc_params", params.n_enc_layers, model_params)
    decoder_model_params = _resolve_side_params(reloaded, "dec_params", params.n_dec_layers, model_params)

    params.encoder_model_params = encoder_model_params
    params.decoder_model_params = decoder_model_params

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])

    for p in [params, encoder_model_params, decoder_model_params]:
        p.n_words = len(dico)
        p.bos_index = dico.index(BOS_WORD)
        p.eos_index = dico.index(EOS_WORD)
        p.pad_index = dico.index(PAD_WORD)
        p.unk_index = dico.index(UNK_WORD)
        p.mask_index = dico.index(MASK_WORD)

    encoder = TransformerModel(encoder_model_params, dico, is_encoder=True, with_output=False)
    decoder = TransformerModel(decoder_model_params, dico, is_encoder=False, with_output=True)

    def _strip_module_prefix(state_dict):
        # drop a "module." key prefix as added by (Distributed)DataParallel
        # wrappers; keys without it are returned unchanged
        return {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

    encoder.load_state_dict(_strip_module_prefix(reloaded["encoder"]))
    decoder.load_state_dict(_strip_module_prefix(reloaded["decoder"]))

    scores = {}
    XPersona(encoder, decoder, scores, dico, params).test()


# Without this entry point, `python test.py` would only define functions and exit.
if __name__ == "__main__":
    run_test()



Writing test.py


In [None]:
# Evaluate the best zh-zh checkpoint from the fine-tuning run (presumably saved
# under dump_path/exp_name/exp_id = ./dump/xpersona/ftOnZh) with test.py.
# Keep --n_enc_layers / --n_dec_layers equal to the fine-tuning run unless the
# checkpoint stores enc_params / dec_params.
!python test.py --exp_name testonZh --dump_path ./dump --saved_path ./dump/xpersona/ftOnZh/best_zh-zh_Perplexity.pth --data_path ./data/processed/XNLG --optimizer adam,lr=0.00001 --batch_size 8 --n_epochs 200 --epoch_size 3000 --max_len 120 --max_vocab 95000 --train_layers 1,10 --decode_with_vocab False --n_enc_layers 10 --n_dec_layers 6 --ds_name xpersona --train_directions en-en --eval_directions zh-zh

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Traceback (most recent call last):
  File "test.py", line 207, in <module>
    run_test()
  File "test.py", line 145, in run_test
    params = get_params()
  File "test.py", line 51, in get_params
    parser.add_argument("--group_by_size", type=bool_flag, default=False,
NameError: name 'bool_flag' is not defined
