In [145]:
%load_ext nb_black
%config IPCompleter.greedy=True

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [2]:
import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange

from transformers import GPT2Tokenizer
from transformers.modeling_gpt2 import GPT2LMHeadModel
from transformers.file_utils import cached_path

I0504 23:21:58.036933 139800908659584 file_utils.py:41] PyTorch version 1.4.0+cpu available.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  from ._conv import register_converters as _register_converters


<IPython.core.display.Javascript object>

In [3]:
# Loss-type selectors passed around as `loss_type`:
# bag-of-words only, discriminator only, or both combined.
PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
# Tiny epsilon used to keep probabilities / gradient norms away from exact zero.
SMALL_CONST = 1e-15
# Large magnitude; top_k_filter writes -BIG_CONST into logits it wants to suppress.
BIG_CONST = 1e10

# Bag-of-words id -> URL of a newline-separated word list.
# get_bag_of_words_indices resolves these through cached_path; any id that is
# not a key here is treated as a local file path instead.
BAG_OF_WORDS_ARCHIVE_MAP = {
    "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

# Discriminator name -> settings consumed by get_classifier / run_pplm_example:
#   url              location of the classifier-head weights (a "path" key is
#                    accepted instead; set_generic_model_params adds one)
#   class_size       number of output classes of the head
#   embed_size       input width of the head
#   class_vocab      label name -> class id accepted as `class_label`
#   default_class    class id used when `class_label` is not recognised
#   pretrained_model language model the head is paired with
DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        # The head has 5 outputs, but only two of them are named in class_vocab.
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}

<IPython.core.display.Javascript object>

In [4]:
def top_k_filter(logits, k, probs=False):
    """
    Keep the k largest entries of each row and suppress all the others.

    Entries strictly below a row's k-th largest value are overwritten with
    0.0 when `probs` is True (input holds probabilities) or with -BIG_CONST
    otherwise (input holds logits, so a later softmax gives them practically
    no mass). Entries tied with the k-th value are kept. `k == 0` disables
    filtering and returns `logits` itself.

    The indexing used here assumes a 2-D input with one distribution per row.
    """
    if k == 0:
        return logits

    top_values, _ = torch.topk(logits, k)
    # Smallest retained value of each row, broadcast over the whole row.
    cutoff = top_values[:, -1].view(-1, 1).expand_as(logits)
    fill_value = 0.0 if probs else -BIG_CONST
    return torch.where(logits < cutoff, torch.ones_like(logits) * fill_value, logits)

<IPython.core.display.Javascript object>

In [5]:
def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device="cuda",
):
    """
    Shift the cached `past` along the gradient of the attribute loss.

    For `num_iterations` rounds the function adds the current perturbation to
    `past`, runs `model(last, past=...)`, builds a loss from the bag-of-words
    term and/or the discriminator term (plus an optional KL term against the
    unperturbed distribution), back-propagates to the perturbation and
    accumulates a normalised, window-masked gradient step.

    Args:
        past: sequence of tensors, one per layer; each is unpacked as a 5-D
            shape whose 4th dimension is the current sequence length.
        model: language model called as `model(last, past=...)` and
            `model(past=..., inputs_embeds=...)`; must return
            `(logits, past, all_hidden_states)`.
        last: most recent token ids fed together with `past`.
        unpert_past: unperturbed past, used as the starting point of the
            discriminator look-ahead.
        unpert_logits: logits of the unperturbed model; the last position is
            the reference distribution of the KL term.
        accumulated_hidden: sum of last-layer hidden states seen so far
            (treated as 0 when None).
        grad_norms: running per-layer gradient norms from the previous call,
            or None.
        stepsize: step size applied to the normalised gradient.
        one_hot_bows_vectors: list of one-hot matrices, one per bag of words.
        classifier: discriminator head applied to the mean hidden state.
        class_label: target class id for the discriminator loss.
        loss_type: one of PPLM_BOW, PPLM_DISCRIM, PPLM_BOW_DISCRIM.
        num_iterations: number of gradient rounds.
        horizon_length: number of look-ahead steps for the discriminator loss.
        window_length: if > 0, only the first `window_length` positions of the
            sequence axis receive gradient once the past is longer than that.
        decay: if True, scale the window positions by a linear ramp ending at 1.
        gamma: exponent applied to the gradient norm during normalisation.
        kl_scale: weight of the KL term (disabled when <= 0).
        device: device on which masks, labels and perturbations are placed.

    Returns:
        Tuple `(pert_past, new_accumulated_hidden, grad_norms, loss_per_iter)`.

    NOTE(review): `to_var` is not defined in the notebook cells visible here;
    confirm it is defined in the kernel before this function is called.
    """
    # Generate initial perturbed past (a zero perturbation for every layer)
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    # The decay ramp is only used inside the window branch below, which needs
    # window_length > 0. Guarding here avoids the ZeroDivisionError that
    # `1.0 / window_length` raised for decay=True with the default window_length=0.
    if decay and window_length > 0:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # Build the mask that restricts the gradient to a window of the past:
    # ones (optionally scaled by the decay ramp) on the first `window_length`
    # positions of the sequence axis, zeros on the remaining positions.
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([window_length])
            + tuple(past[0].shape[-1:])
        )

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([curr_length - window_length])
            + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        # Move the sequence axis last so the 1-D decay ramp broadcasts over it,
        # then move it back.
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2
        ).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        # Fresh leaf tensors each round so .grad holds only this round's gradient.
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True, device=device)
            for p_ in grad_accumulator
        ]

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            # Negative log of the total probability mass on each bag's words.
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type in (PPLM_DISCRIM, PPLM_BOW_DISCRIM):
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            # Look ahead `horizon_length` steps, feeding the expected embedding
            # under the current distribution instead of a sampled token.
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past, inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1
                )

            # The classifier sees the hidden state averaged over all positions.
            prediction = classifier(
                new_accumulated_hidden / (curr_length + 1 + horizon_length)
            )

            label = torch.tensor(
                prediction.shape[0] * [class_label], device=device, dtype=torch.long
            )
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            # Lift (near-)zero probabilities by SMALL_CONST so the ratio and
            # the log below stay finite.
            unpert_probs = (
                unpert_probs
                + SMALL_CONST
                * (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            correction = (
                SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            )
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            print(" kl_loss", kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms: BoW-only runs keep a running maximum across
        # calls, every other configuration recomputes the norm each round
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize
            * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [
        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
        for p_ in grad_accumulator
    ]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter

<IPython.core.display.Javascript object>

In [6]:
class ClassificationHead(nn.Module):
    """Single linear layer that maps a hidden representation to class logits."""

    def __init__(self, class_size, embed_size):
        super().__init__()
        self.class_size, self.embed_size = class_size, embed_size
        # The attribute must stay named `mlp`: get_classifier restores weights
        # with load_state_dict, whose keys are derived from attribute names.
        self.mlp = nn.Linear(in_features=embed_size, out_features=class_size)

    def forward(self, hidden_state):
        """Return un-normalised class scores for `hidden_state`."""
        return self.mlp(hidden_state)

<IPython.core.display.Javascript object>

In [7]:
def get_classifier(
    name: Optional[str], class_label: Union[str, int], device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """
    Build the discriminator head registered under `name` and resolve its target class.

    Args:
        name: key into DISCRIMINATOR_MODELS_PARAMS, or None for "no discriminator".
        class_label: label name (str) or class id (int) to steer towards.
        device: device on which the head is placed and its weights are mapped.

    Returns:
        `(classifier, label_id)`, or `(None, None)` when `name` is None. The
        classifier is returned in eval mode. A label that is not listed in the
        discriminator's class_vocab is reported on stdout and replaced by its
        default_class; a label that is neither str nor int silently becomes
        default_class.

    Raises:
        ValueError: if the parameters contain neither "url" nor "path".
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(
        class_size=params["class_size"], embed_size=params["embed_size"]
    ).to(device)

    # Weights come from a (cached) download when a url is given, otherwise
    # from a local path.
    if "url" not in params and "path" not in params:
        raise ValueError(
            "Either url or path have to be specified "
            "in the discriminator model parameters"
        )
    weights_file = cached_path(params["url"]) if "url" in params else params["path"]

    classifier.load_state_dict(torch.load(weights_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        recognised = class_label in params["class_vocab"]
        if recognised:
            label_id = params["class_vocab"][class_label]
    elif isinstance(class_label, int):
        recognised = class_label in set(params["class_vocab"].values())
        if recognised:
            label_id = class_label
    else:
        # Neither a name nor an id: use the default without any message.
        return classifier, params["default_class"]

    if not recognised:
        label_id = params["default_class"]
        print("class_label {} not in class_vocab".format(class_label))
        print("available values are: {}".format(params["class_vocab"]))
        print("using default class {}".format(label_id))

    return classifier, label_id

<IPython.core.display.Javascript object>

In [8]:
def get_bag_of_words_indices(
    bag_of_words_ids_or_paths: List[str], tokenizer
) -> List[List[List[int]]]:
    """
    Load each bag of words and encode every word with `tokenizer`.

    An entry that is a key of BAG_OF_WORDS_ARCHIVE_MAP is resolved through
    cached_path; any other entry is opened as a local file path. Files are
    read as one word per line.

    Returns:
        One list per bag; each holds one token-id list per word, in file order.
    """
    encoded_bags = []
    for source in bag_of_words_ids_or_paths:
        is_known_id = source in BAG_OF_WORDS_ARCHIVE_MAP
        location = (
            cached_path(BAG_OF_WORDS_ARCHIVE_MAP[source]) if is_known_id else source
        )

        with open(location, "r") as handle:
            lines = handle.read().strip().split("\n")

        encoded_words = []
        for line in lines:
            encoded_words.append(tokenizer.encode(line.strip(), add_prefix_space=True))
        encoded_bags.append(encoded_words)
    return encoded_bags

<IPython.core.display.Javascript object>

In [9]:
def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
    """
    Turn encoded bags of words into one-hot matrices over the vocabulary.

    Args:
        bow_indices: list of bags; each bag is a list of token-id lists (one
            per word), as produced by get_bag_of_words_indices. May be None.
        tokenizer: object exposing `vocab_size`.
        device: device on which the matrices are created.

    Returns:
        None if `bow_indices` is None; otherwise a list with one float tensor
        of shape (num_single_token_words, vocab_size) per bag, where row i has
        a 1 at the token id of the i-th retained word.

    Raises:
        ValueError: if a bag contains no word that encodes to exactly one token.
    """
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        # Only words encoded as exactly one token can be one-hot rows.
        # The previous `len(x) <= 1` test also kept empty encodings, which made
        # the list ragged and broke torch.tensor below.
        single_token_words = [ids for ids in single_bow if len(ids) == 1]
        if not single_token_words:
            # Without this check the failure surfaces inside scatter_ as an
            # index/self dimension mismatch that does not mention the bag.
            raise ValueError(
                "Bag of words contains no word that encodes to a single token"
            )
        single_bow = torch.tensor(single_token_words).to(device)
        num_words = single_bow.shape[0]
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors

<IPython.core.display.Javascript object>

In [10]:
def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    bag_of_words=None,
    discrim=None,
    class_label=None,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
    **kwargs
):
    """
    Generate one unperturbed continuation and `num_samples` perturbed ones.

    `bag_of_words` is a ";"-separated list of bag ids or file paths and
    `discrim` names a discriminator; at least one of them must be given.
    Extra keyword arguments are accepted and ignored.

    Returns:
        `(unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time)`:
        the unperturbed token tensor, the list of perturbed token tensors, the
        discriminator loss of each sample as a NumPy value (empty without a
        discriminator) and the per-step loss traces of each sample.

    Raises:
        Exception: if neither a bag of words nor a discriminator is specified.
    """
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = (
        get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
        if bag_of_words
        else []
    )

    # Pick the loss from whichever attribute models are available.
    if bag_of_words:
        if classifier:
            print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
            loss_type = PPLM_BOW_DISCRIM
        else:
            loss_type = PPLM_BOW
            print("Using PPLM-BoW")
    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")
    else:
        raise Exception("Specify either a bag of words or a discriminator")

    # Baseline continuation without any perturbation.
    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        repetition_penalty=repetition_penalty,
    )
    if device == "cuda":
        torch.cuda.empty_cache()

    # Every perturbed sample is generated with the same settings.
    perturbed_settings = dict(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        perturb=True,
        bow_indices=bow_indices,
        classifier=classifier,
        class_label=class_id,
        loss_type=loss_type,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )

    pert_gen_tok_texts, discrim_losses, losses_in_time = [], [], []
    for _ in range(num_samples):
        tokens, sample_discrim_loss, sample_losses = generate_text_pplm(
            **perturbed_settings
        )
        pert_gen_tok_texts.append(tokens)
        if classifier is not None:
            discrim_losses.append(sample_discrim_loss.data.cpu().numpy())
        losses_in_time.append(sample_losses)

    if device == "cuda":
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


<IPython.core.display.Javascript object>

In [11]:
def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
    """
    Generate `length` tokens one at a time, optionally steering each step with PPLM.

    Per step the function:
      1. runs the model over the whole sequence so far to obtain the
         unperturbed logits, past and hidden states;
      2. when `perturb` is true, `num_iterations` is non-zero and a past
         exists, replaces the past by the result of perturb_past;
      3. runs the model on the last token with that past, applies
         `temperature` and `repetition_penalty`;
      4. when `perturb` is true, fuses perturbed and unperturbed probabilities
         as `pert ** gm_scale * unpert ** (1 - gm_scale)`, keeps the top_k
         entries and rescales; otherwise applies top-k on the logits;
      5. samples (`sample=True`) or takes the arg-max, appends the token and
         prints the decoded text so far.

    Args:
        context: token ids of the prompt; tensors with fewer than two
            dimensions are given leading dimensions until they are 2-D.
        past: optional pre-computed past; computed from the context if None.
        bow_indices, classifier, class_label, loss_type, stepsize,
        num_iterations, horizon_length, window_length, decay, gamma,
        kl_scale: forwarded to build_bows_one_hot_vectors / perturb_past.
        grad_length: steps at or beyond this index use a step size of 0.

    Returns:
        `(output_so_far, unpert_discrim_loss, loss_in_time)`: the full token
        tensor (context plus generated tokens), the discriminator loss on the
        unperturbed hidden states from the last step (0 without a classifier)
        and the list of per-step loss traces returned by perturb_past.

    NOTE(review): with an empty/None `context`, `output_so_far` stays None and
    is indexed in the repetition-penalty loop below; callers presumably always
    pass a non-empty context — confirm.
    """
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # On the first step (no past yet) split the context into
        # "everything but the last token" (-> past) and the last token.
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        # Unperturbed forward pass over the full sequence generated so far.
        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length; beyond it the step size is
        # zero, so perturb_past still runs but applies no update
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            # Sum of last-layer hidden states over every position except the
            # final one; perturb_past adds the final position's state itself.
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                # No past to perturb (single-token context on the first step).
                pert_past = past

        # Forward the last token with the (possibly perturbed) past; the
        # returned past becomes the cache for the next step.
        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

        # Repetition penalty on tokens already present in the sequence:
        # negative logits are multiplied, non-negative ones divided.
        # Only row 0 is read and modified here.
        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = F.softmax(pert_logits, dim=-1)

        # Discriminator loss of the unperturbed text (reporting only; it is
        # overwritten each step, so the returned value is from the last step).
        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
        if perturb:

            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            # Weighted geometric mean of perturbed and unperturbed probabilities.
            pert_probs = (pert_probs ** gm_scale) * (
                unpert_probs ** (1 - gm_scale)
            )  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale so the kept entries sum to 1 again
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = (
            last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
        )

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time

<IPython.core.display.Javascript object>

In [12]:
def set_generic_model_params(discrim_weights, discrim_meta):
    """
    Register a user-supplied discriminator under the name "generic".

    Reads the JSON file `discrim_meta`, stores `discrim_weights` under its
    "path" key and saves the result as DISCRIMINATOR_MODELS_PARAMS["generic"].

    Raises:
        ValueError: if `discrim_weights` or `discrim_meta` is None
            (`discrim_weights` is checked first).
    """
    required = (
        ("discrim_weights", discrim_weights),
        ("discrim_meta", discrim_meta),
    )
    for arg_name, arg_value in required:
        if arg_value is None:
            raise ValueError(
                "When using a generic discriminator, "
                "{} need to be specified".format(arg_name)
            )

    with open(discrim_meta, "r") as meta_handle:
        generic_params = json.load(meta_handle)
    generic_params["path"] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS["generic"] = generic_params

<IPython.core.display.Javascript object>

In [13]:
def run_pplm_example(
    pretrained_model="gpt2-medium",
    cond_text="",
    uncond=False,
    num_samples=1,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    class_label=-1,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    seed=0,
    no_cuda=False,
    colorama=False,
    repetition_penalty=1.0,
):
    """
    End-to-end PPLM demo: load the model, generate, and print the texts.

    Seeds torch and NumPy, loads and freezes the GPT-2 model (replaced by the
    discriminator's own `pretrained_model` when `discrim` is given), encodes
    the prompt, calls full_text_generation and prints the unperturbed text
    followed by each perturbed text. With `colorama` true, single-token
    bag-of-words hits in the perturbed text are wrapped in red colour codes.

    With `uncond` false and an empty `cond_text`, the prompt is requested
    interactively through input(). The function returns None.
    """
    # set Random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    cuda_enabled = torch.cuda.is_available() and not no_cuda
    device = "cuda" if cuda_enabled else "cpu"

    if discrim == "generic":
        set_generic_model_params(discrim_weights, discrim_meta)

    # A discriminator dictates which language model it has to be paired with.
    if discrim is not None:
        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        print(
            "discrim = {}, pretrained_model set "
            "to discriminator's = {}".format(discrim, pretrained_model)
        )

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # Freeze GPT-2 weights
    for weight in model.parameters():
        weight.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
    else:
        prompt = cond_text
        while not prompt:
            print("Did you forget to add `--cond_text`? ")
            prompt = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + prompt)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts; the two loss outputs of
    # full_text_generation are not needed here
    generation_result = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = generation_result

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    # Token ids to highlight: only words encoded as a single token qualify.
    bow_word_ids = set()
    if bag_of_words and colorama:
        for encoded_bag in get_bag_of_words_indices(bag_of_words.split(";"), tokenizer):
            for encoded_word in encoded_bag:
                if len(encoded_word) <= 1:
                    bow_word_ids.add(encoded_word[0])

    # iterate through the perturbed texts
    for sample_number, pert_gen_tok_text in enumerate(pert_gen_tok_texts, start=1):
        try:
            # untokenize perturbed text
            if colorama:
                # Imported lazily (and inside the try) so a missing package is
                # reported by the handler below instead of aborting the run.
                import colorama as colorama_module

                pieces = []
                for word_id in pert_gen_tok_text.tolist()[0]:
                    decoded = tokenizer.decode([word_id])
                    if word_id in bow_word_ids:
                        decoded = "{}{}{}".format(
                            colorama_module.Fore.RED,
                            decoded,
                            colorama_module.Style.RESET_ALL,
                        )
                    pieces.append(decoded)
                pert_gen_text = "".join(pieces)
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])

            print("= Perturbed generated text {} =".format(sample_number))
            print(pert_gen_text)
            print()
        except Exception as exc:
            print("Ignoring error while generating perturbed text:", exc)

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append(
            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
        )

    return

<IPython.core.display.Javascript object>

In [155]:
# Experiment settings. The keys match run_pplm_example's keyword arguments
# (no_cuda is omitted).
args = {
    "pretrained_model": "gpt2-medium",
    "cond_text": "The lake",
    # run_pplm_example starts from the BOS token alone when this is true,
    # in which case cond_text is not used.
    "uncond": True,
    "num_samples": 1,
    # Key of BAG_OF_WORDS_ARCHIVE_MAP or a file path; several can be joined with ";".
    "bag_of_words": "military",
    # Key of DISCRIMINATOR_MODELS_PARAMS. With both a bag of words and a
    # discriminator set, full_text_generation selects PPLM_BOW_DISCRIM.
    "discrim": "sentiment",
    # Only consulted by set_generic_model_params, i.e. when discrim == "generic".
    "discrim_weights": None,
    "discrim_meta": None,
    # -1 is not a class id of the built-in discriminators, so get_classifier
    # falls back to the discriminator's default_class.
    "class_label": -1,
    "length": 100,
    "stepsize": 0.03,
    "temperature": 1.0,
    "top_k": 10,
    "sample": True,
    "num_iterations": 3,
    "grad_length": 10000,
    # Restrict the perturbation to a 5-position window of the past, scaled by
    # a linear decay ramp (see perturb_past).
    "window_length": 5,
    "horizon_length": 1,
    "decay": True,
    "gamma": 1.5,
    "gm_scale": 0.9,
    "kl_scale": 0.01,
    "seed": 0,
    "colorama": False,
    "repetition_penalty": 1.0,
}

<IPython.core.display.Javascript object>

In [15]:
# Keep the device as a plain string: the functions in this notebook take it as
# a str (get_classifier's `device: str`, run_pplm_example's "cuda"/"cpu") and
# full_text_generation compares it with `device == "cuda"`. Module.to() and
# the tensor factories accept the string form as well.
device = "cuda" if torch.cuda.is_available() else "cpu"

<IPython.core.display.Javascript object>

In [16]:
#run_pplm_example(**args)

<IPython.core.display.Javascript object>

In [17]:
# Seed PyTorch's and NumPy's global random number generators from args["seed"]
# so that sampling in the cells below is repeatable.
torch.manual_seed(args["seed"])
np.random.seed(args["seed"])

<IPython.core.display.Javascript object>

In [18]:
# Resolve which GPT-2 checkpoint to load, then load it in evaluation mode.
#
# All settings are read from `args`: this cell runs at notebook top level, so
# the bare names `discrim_weights` / `discrim_meta` are not bound by anything
# visible here and would raise NameError on the "generic" path.
if args["discrim"] == "generic":
    set_generic_model_params(args["discrim_weights"], args["discrim_meta"])

# Start from the configured checkpoint so `pretrained_model` is always bound,
# even when no discriminator is requested.
pretrained_model = args["pretrained_model"]
if args["discrim"] is not None:
    # A discriminator head is tied to one base model; its registry entry wins.
    pretrained_model = DISCRIMINATOR_MODELS_PARAMS[args["discrim"]]["pretrained_model"]
    print(
        "discrim = {}, pretrained_model set "
        "to discriminator's = {}".format(args["discrim"], pretrained_model)
    )

# load pretrained model
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
model.to(device)
# Trailing semicolon suppresses the module repr in the cell output.
model.eval();

discrim = sentiment, pretrained_model set to discriminator's = gpt2-medium


I0504 23:22:34.408077 139800908659584 configuration_utils.py:283] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json from cache at /home/u37216/.cache/torch/transformers/98aa65385e18b0efd17acd8bf64dcdf21406bb0c99c801c2d3c9f6bfd1f48f29.42ee920fcfb8dd7cf21fdc10f45b4545a7050ec5dd5463e844c310dd9beeae87
I0504 23:22:34.410663 139800908659584 configuration_utils.py:319] Model config GPT2Config {
  "_num_labels": 2,
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bad_words_ids": null,
  "bos_token_id": 50256,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_epsilon": 1

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2):

<IPython.core.display.Javascript object>

In [19]:
# Display the checkpoint name chosen by the model-loading cell.
pretrained_model

'gpt2-medium'

<IPython.core.display.Javascript object>

In [20]:
# Display the registry entry for the "sentiment" discriminator.
DISCRIMINATOR_MODELS_PARAMS["sentiment"]

{'url': 'https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt',
 'class_size': 5,
 'embed_size': 1024,
 'class_vocab': {'very_positive': 2, 'very_negative': 3},
 'default_class': 3,
 'pretrained_model': 'gpt2-medium'}

<IPython.core.display.Javascript object>

In [21]:
# Load the GPT-2 tokenizer for the same checkpoint name that was used to load `model`.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

I0504 23:22:55.666390 139800908659584 tokenization_utils.py:504] loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json from cache at /home/u37216/.cache/torch/transformers/f20f05d3ae37c4e3cd56764d48e566ea5adeba153dcee6eb82a18822c9c731ec.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
I0504 23:22:55.667993 139800908659584 tokenization_utils.py:504] loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt from cache at /home/u37216/.cache/torch/transformers/6d882670c55563617571fe0c97df88626fb5033927b40fc18a8acf98dafd4946.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda


<IPython.core.display.Javascript object>

In [22]:
# Display the tokenizer object's repr.
tokenizer

<transformers.tokenization_gpt2.GPT2Tokenizer at 0x7f255a64d240>

<IPython.core.display.Javascript object>

In [23]:
# Freeze the language model: none of its parameters should track gradients.
for frozen_param in model.parameters():
    frozen_param.requires_grad_(False)

<IPython.core.display.Javascript object>

In [24]:
# Build the token-id prefix that generation starts from.
# Honour args["uncond"] instead of always taking the unconditional path:
#   * unconditional  -> only the BOS token is encoded;
#   * conditional    -> BOS followed by args["cond_text"].
# With the configuration above ("uncond": True) the result is unchanged.
if args["uncond"]:
    tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
else:
    tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + args["cond_text"])

<IPython.core.display.Javascript object>

In [33]:
# Fetch the registry entry for the discriminator named in args['discrim'] and display it.
params = DISCRIMINATOR_MODELS_PARAMS[args['discrim']]
params

{'url': 'https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt',
 'class_size': 5,
 'embed_size': 1024,
 'class_vocab': {'very_positive': 2, 'very_negative': 3},
 'default_class': 3,
 'pretrained_model': 'gpt2-medium'}

<IPython.core.display.Javascript object>

In [34]:
# Build the discriminator head with the sizes from its registry entry,
# then place it on the selected device.
classifier = ClassificationHead(
    class_size=params["class_size"],
    embed_size=params["embed_size"],
)
classifier = classifier.to(device)

<IPython.core.display.Javascript object>

In [43]:
# Locate the discriminator weights: a "url" entry goes through cached_path,
# otherwise a "path" entry is used as-is.
if "url" in params:
    resolved_archive_file = cached_path(params["url"])
else:
    if "path" not in params:
        raise ValueError(
            "Either url or path have to be specified "
            "in the discriminator model parameters"
        )
    resolved_archive_file = params["path"]
print(resolved_archive_file)

# NOTE(review): torch.load unpickles the file; only use trusted weight files.
head_state_dict = torch.load(resolved_archive_file, map_location=device)
classifier.load_state_dict(head_state_dict)
# Switch to evaluation mode; as the last expression this also shows the module repr.
classifier.eval()

/home/u37216/.cache/torch/transformers/42357c6dbedfbfd9f1a59709d4553429b37f25cb0aff8e3c7c8b8291c55dbbc9.a0039fb75d9c0c460975276b6875387756815e189a964d9478dc215178f9cb0c


ClassificationHead(
  (mlp): Linear(in_features=1024, out_features=5, bias=True)
)

<IPython.core.display.Javascript object>

In [47]:
# Display the discriminator's 'class_vocab' entry (class name -> integer id in the recorded output).
params['class_vocab']

{'very_positive': 2, 'very_negative': 3}

<IPython.core.display.Javascript object>

In [152]:
# Override two settings in place: steer towards the 'very_positive' class and
# (re)set the perturbation window to 5.
args.update(class_label='very_positive', window_length=5)

<IPython.core.display.Javascript object>

In [49]:
# Ask get_classifier (defined outside this excerpt) for the discriminator head and
# the class id, given the discriminator name, the requested class label and the device.
# This rebinds `classifier`, replacing the head that was built by hand in the cells above.
classifier, class_id = get_classifier(args["discrim"], args["class_label"], device)

<IPython.core.display.Javascript object>

In [28]:
# Display the class id returned by get_classifier.
# NOTE(review): the recorded output (3) has execution count 28, which is earlier than
# the recorded runs of the get_classifier cell (49) and of the cell that sets
# class_label to 'very_positive' (152) -- re-run in order to confirm the id in use.
class_id

3

<IPython.core.display.Javascript object>

In [None]:
# Split the semicolon-separated bag-of-words spec into a list of ids/paths.
# NOTE(review): no cell visible in this part of the notebook reads
# bag_of_words_ids_or_paths; the next cell repeats the split inline.
bag_of_words_ids_or_paths = args['bag_of_words'].split(';')

In [53]:
# Resolve the bag-of-words token indices (if any) and decide which PPLM loss to use.
#
# loss_type ends up as one of the module constants:
#   PPLM_BOW_DISCRIM - both a bag of words and a discriminator are configured
#   PPLM_BOW         - only a bag of words
#   PPLM_DISCRIM     - only a discriminator
# Raises Exception when neither is configured.
bow_indices = []
if args['bag_of_words']:
    bow_indices = get_bag_of_words_indices(args['bag_of_words'].split(";"), tokenizer)

if args['bag_of_words'] and classifier is not None:
    print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
    loss_type = PPLM_BOW_DISCRIM

# Read the setting from `args`: the bare name `bag_of_words` is not bound by any
# cell visible here, so a BoW-only configuration would hit NameError.
elif args['bag_of_words']:
    loss_type = PPLM_BOW
    print("Using PPLM-BoW")

elif classifier is not None:
    loss_type = PPLM_DISCRIM
    print("Using PPLM-Discrim")

else:
    raise Exception("Specify either a bag of words or a discriminator")

Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.


<IPython.core.display.Javascript object>

In [71]:
# Alias the token-id prefix as `context`, the name the next cell reads.
context = tokenized_cond_text

<IPython.core.display.Javascript object>

In [72]:
# Turn the (possibly empty) context into a tensor with at least two dimensions;
# output_so_far stays None when there is no context.
output_so_far = None
if context:
    context_t = torch.tensor(context, device=device, dtype=torch.long)
    # Prepend singleton dimensions until the tensor is at least 2-D.
    while context_t.dim() < 2:
        context_t = context_t.unsqueeze(0)
    output_so_far = context_t

# One-hot vectors for every bag of words, built by a helper defined outside this excerpt.
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

<IPython.core.display.Javascript object>

In [79]:
# Shape of the one-hot matrix for the first bag of words.
one_hot_bows_vectors[0].shape

torch.Size([149, 50257])

<IPython.core.display.Javascript object>

In [90]:
# Number of entries in `single_bow`.
# NOTE(review): `single_bow` is not assigned in any cell visible in this part of the
# notebook -- presumably left over from a scratch cell; confirm it exists on a fresh
# kernel. The recorded value (149) equals the first dimension of one_hot_bows_vectors[0].
len(single_bow)

149

<IPython.core.display.Javascript object>

In [112]:
# Initial state for the generation loop in the next cell.
grad_norms = None
last = None
unpert_discrim_loss = 0
loss_in_time = []
past = None
perturb = True
# Take the tunables from `args` instead of re-typing their values, so this cell
# cannot silently drift from the configuration dict (the values are currently equal).
num_iterations = args["num_iterations"]
temperature = args["temperature"]
repetition_penalty = args["repetition_penalty"]
# Target class id for the discriminator loss.
class_label = class_id

<IPython.core.display.Javascript object>

In [116]:
# Token-by-token generation loop.
#
# NOTE: the unconditional `break` below is a deliberate step-through aid: only the
# unperturbed forward pass and the step-size selection run, and the cells that
# follow replay the rest of the body by hand. Everything after the `break` is
# therefore not executed here; it is kept so the loop can be run by deleting that
# one line. All settings in that part are read from `args`, because the bare names
# (horizon_length, window_length, decay, gamma, kl_scale, gm_scale, top_k, sample)
# are not assigned in any cell visible in this part of the notebook.
for i in trange(args['length'], ascii=True):

    # Get past/probs for current output, except for last word
    # Note that GPT takes 2 inputs: past + current_token

    # run model forward to obtain unperturbed
    if past is None and output_so_far is not None:
        last = output_so_far[:, -1:]
        print(last.shape)
        if output_so_far.shape[1] > 1:
            _, past, _ = model(output_so_far[:, :-1])

    unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
    unpert_last_hidden = unpert_all_hidden[-1]

    # check if we are above grad max length
    if i >= args['grad_length']:
        current_stepsize = args['stepsize'] * 0
    else:
        current_stepsize = args['stepsize']

    # Step-through stop: remove to run the full loop.
    break

    # modify the past if necessary
    if not perturb or num_iterations == 0:
        pert_past = past

    else:
        # Sum of the last-layer hidden states of every position except the final one.
        accumulated_hidden = unpert_last_hidden[:, :-1, :]
        accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

        if past is not None:
            pert_past, _, grad_norms, loss_this_iter = perturb_past(
                past,
                model,
                last,
                unpert_past=unpert_past,
                unpert_logits=unpert_logits,
                accumulated_hidden=accumulated_hidden,
                grad_norms=grad_norms,
                stepsize=current_stepsize,
                one_hot_bows_vectors=one_hot_bows_vectors,
                classifier=classifier,
                class_label=class_label,
                loss_type=loss_type,
                num_iterations=num_iterations,
                horizon_length=args['horizon_length'],
                window_length=args['window_length'],
                decay=args['decay'],
                gamma=args['gamma'],
                kl_scale=args['kl_scale'],
                device=device,
            )
            loss_in_time.append(loss_this_iter)
        else:
            pert_past = past

    pert_logits, past, pert_all_hidden = model(last, past=pert_past)
    pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

    # Repetition penalty: push down the logit of every token already generated
    # (multiply negative logits, divide non-negative ones).
    for token_idx in set(output_so_far[0].tolist()):
        if pert_logits[0, token_idx] < 0:
            pert_logits[0, token_idx] *= repetition_penalty
        else:
            pert_logits[0, token_idx] /= repetition_penalty

    pert_probs = F.softmax(pert_logits, dim=-1)

    # Discriminator loss of the *unperturbed* hidden states, printed for reference.
    if classifier is not None:
        ce_loss = torch.nn.CrossEntropyLoss()
        prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
        label = torch.tensor([class_label], device=device, dtype=torch.long)
        unpert_discrim_loss = ce_loss(prediction, label)
        print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
    else:
        unpert_discrim_loss = 0

    # Fuse the modified model and original model
    if perturb:

        unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

        # Geometric-mean fusion weighted by gm_scale.
        pert_probs = (pert_probs ** args['gm_scale']) * (
            unpert_probs ** (1 - args['gm_scale'])
        )  # + SMALL_CONST
        pert_probs = top_k_filter(pert_probs, k=args['top_k'], probs=True)  # + SMALL_CONST

        # rescale
        if torch.sum(pert_probs) <= 1:
            pert_probs = pert_probs / torch.sum(pert_probs)

    else:
        pert_logits = top_k_filter(pert_logits, k=args['top_k'])  # + SMALL_CONST
        pert_probs = F.softmax(pert_logits, dim=-1)

    # sample or greedy
    if args['sample']:
        last = torch.multinomial(pert_probs, num_samples=1)

    else:
        _, last = torch.topk(pert_probs, k=1, dim=-1)

    # update context/output_so_far appending the new token
    output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

    print(tokenizer.decode(output_so_far.tolist()[0]))



  0%|          | 0/100 [00:00<?, ?it/s]


<IPython.core.display.Javascript object>

In [117]:
# Manual replay of the first part of the generation loop body (one step, outside the loop).
# When there is no cached `past` yet, `last` becomes the final token of the sequence and,
# if the sequence has more than one token, `past` is computed from everything before it.
if past is None and output_so_far is not None:
    last = output_so_far[:, -1:]
    print(last.shape)
    if output_so_far.shape[1] > 1:
        _, past, _ = model(output_so_far[:, :-1])

# Unperturbed forward pass over the whole sequence so far.
# The shapes noted below are the author's observations for this run.
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) # can take tensor([[50256]]) as input
# unpert_logits: torch.Size([1, 1, 50257]), 
# unpert_past[0]: torch.Size([2, 1, 16, 1, 64])  #length: 24
# unpert_all_hidden[0]: torch.Size([1, 1, 1024]) #length: 25
unpert_last_hidden = unpert_all_hidden[-1] #torch.Size([1, 1, 1024])

<IPython.core.display.Javascript object>

In [136]:
# Step size for this position: zeroed once the loop index has reached grad_length.
current_stepsize = (
    args['stepsize'] * 0 if i >= args['grad_length'] else args['stepsize']
)

# With perturbation disabled (or zero iterations) the cached past is used unchanged.
if not (perturb and num_iterations != 0):
    pert_past = past

<IPython.core.display.Javascript object>

In [137]:
# Sum the last-layer hidden states over every position except the final one.
accumulated_hidden = unpert_last_hidden[:, :-1, :].sum(dim=1)

<IPython.core.display.Javascript object>

In [144]:
# Display accumulated_hidden. The recorded output is all zeros; presumably because the
# sequence held a single token, leaving nothing to sum after dropping the last position.
accumulated_hidden

tensor([[0., 0., 0.,  ..., 0., 0., 0.]])

<IPython.core.display.Javascript object>

### perturb_past

In [146]:
# One zero-filled float32 NumPy array per element of `past`, each with that element's shape.
# NOTE(review): iterating fails if `past` is still None -- confirm an earlier cell set it.
grad_accumulator = [
    np.zeros(layer_past.shape, dtype="float32") for layer_past in past
]

<IPython.core.display.Javascript object>

In [153]:
# Treat a missing accumulated hidden state as zero.
accumulated_hidden = 0 if accumulated_hidden is None else accumulated_hidden

# With decay on, build evenly spaced weights in steps of 1/window_length and drop the
# leading 0; with decay off, a scalar 1.0 leaves the mask unscaled.
decay_mask = (
    torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (args['window_length']))[1:]
    if args['decay']
    else 1.0
)


<IPython.core.display.Javascript object>

In [156]:
# Unpack the five dimensions of the first element of `past`, keeping only the fourth
# as curr_length (the unpack itself fails if the shape does not have exactly 5 entries).
# NOTE(review): requires `past` to be populated; it is initialised to None above.
_, _, _, curr_length, _ = past[0].shape

<IPython.core.display.Javascript object>

In [168]:
# Build the mask that limits (and optionally decays) the perturbation along the
# second-to-last axis of the past tensors.
# In the recorded run curr_length is smaller than window_length, so the `else` branch
# (an all-ones mask) is the one that executes.
if 0 < args['window_length'] < curr_length:
    leading_dims = tuple(past[0].shape[:-2])
    feature_dim = tuple(past[0].shape[-1:])
    ones_key_val_shape = leading_dims + (args['window_length'],) + feature_dim
    zeros_key_val_shape = (
        leading_dims + (curr_length - args['window_length'],) + feature_dim
    )

    # Swap the last two axes so decay_mask broadcasts over the window axis,
    # scale, then swap back.
    ones_mask = (
        decay_mask * torch.ones(ones_key_val_shape).permute(0, 1, 2, 4, 3)
    ).permute(0, 1, 2, 4, 3)

    # Windowed part first, zeros for the remaining positions.
    window_mask = torch.cat(
        (ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2
    ).to(device)
else:
    window_mask = torch.ones_like(past[0]).to(device)

<IPython.core.display.Javascript object>

In [169]:
# Scratch check of the zeros-shape expression from the window-mask cell, evaluated
# outside its guard. The recorded result has a negative dimension because
# curr_length is smaller than window_length; the guarded cell never builds this shape.
tuple(past[0].shape[:-2]) + tuple([curr_length - args['window_length']]) + tuple(past[0].shape[-1:])

(2, 1, 16, -4, 64)

<IPython.core.display.Javascript object>

In [176]:
# Initial state for the perturbation iterations in the next cell.
loss_per_iter = []
new_accumulated_hidden = None

<IPython.core.display.Javascript object>

In [261]:
# Work-in-progress step-through of a single perturbation iteration (this section is
# headed "perturb_past"). The trailing `break` stops after the first iteration, and
# the loss is only assembled up to the discriminator prediction -- there is no
# backward pass yet.
# Shape comments are the author's observations for this run.
for i in range(num_iterations):
    print("Iteration ", i + 1)
    # Move to the target device *before* enabling gradients, so the tensor that is
    # added to `past` is itself the autograd leaf. With the opposite order the
    # result of .to(device) is a non-leaf copy on CUDA (identical on CPU).
    curr_perturbation = [
        torch.from_numpy(p_).to(device).requires_grad_(True) for p_ in grad_accumulator
    ]  # torch.Size([2, 1, 16, 1, 64])
    perturbed_past = list(map(add, past, curr_perturbation))
    _, _, _, curr_length, _ = curr_perturbation[0].shape  # 1
    all_logits, _, all_hidden = model(last, past=perturbed_past)
    # torch.Size([1, 1, 50257]),
    # torch.Size([2, 1, 16, 2, 64]), length: 24
    # torch.Size([1, 1, 1024])
    hidden = all_hidden[-1]
    new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
    logits = all_logits[:, -1, :]
    probs = F.softmax(logits, dim=-1)

    loss = 0.0
    loss_list = []

    # Bag-of-words loss: negative log of the probability mass on each bag's tokens.
    if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
        for one_hot_bow in one_hot_bows_vectors:
            bow_logits = torch.mm(probs, torch.t(one_hot_bow))
            bow_loss = -torch.log(torch.sum(bow_logits))
            loss += bow_loss
            # Record this bag's own loss, not the running total.
            loss_list.append(bow_loss)
        print("pplm_bow_loss:", loss.data.cpu().numpy())

    # Discriminator branch (named constants instead of the literals 2 and 3).
    if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
        loss_fn = nn.CrossEntropyLoss()
        curr_unpert_past = unpert_past
        curr_probs = torch.unsqueeze(probs, dim=1)
        wte = model.resize_token_embeddings()  # Embedding(50257, 1024)
        # Roll the model forward horizon_length steps, feeding the
        # probability-weighted mix of token embeddings as input.
        for _ in range(args['horizon_length']):
            inputs_embeds = torch.matmul(curr_probs, wte.weight.data)  # torch.Size([1, 1, 1024])
            _, curr_unpert_past, curr_all_hidden = model(
                past=curr_unpert_past, inputs_embeds=inputs_embeds
            )
            curr_hidden = curr_all_hidden[-1]
            new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                curr_hidden, dim=1
            )  # torch.Size([1, 1024])

        # Classify the accumulated hidden state divided by
        # (curr_length + 1 + horizon_length).
        prediction = classifier(
            new_accumulated_hidden / (curr_length + 1 + args['horizon_length'])
        )  # torch.Size([1, 5])
        label = torch.tensor(
            prediction.shape[0] * [class_label], device=device, dtype=torch.long
        )

    # Step-through stop: only the first iteration is executed.
    break

Iteration  1
pplm_bow_loss: 14.222659


<IPython.core.display.Javascript object>

In [264]:
# Shape of the discriminator output from the perturbation step above.
prediction.shape

torch.Size([1, 5])

<IPython.core.display.Javascript object>