<a href="https://colab.research.google.com/github/brgsil/toxicity-lm-ia024/blob/main/PPLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Requirements

In [None]:
# Mount Google Drive (Colab only) so the /content/drive/Shareddrives/... paths
# used by the later cells are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install transformers
!pip install datasets
!pip install accelerate
!pip install bitsandbytes

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Pipeline, GPT2Tokenizer, GPT2LMHeadModel
from accelerate import Accelerator

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 5.2 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 44.1 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 61.3 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.0 tokenizers-0.13.2 transformers-4.24.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 4.6 MB/

In [None]:
import torch
import json
import os
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from googleapiclient import discovery
import time

# Tiny epsilon: added to gradient norms and to near-zero probabilities in
# perturb_past before they are divided by / passed to log.
SMALL_CONST = 1e-15
# Large magnitude; its negation is the fill value top_k_filter writes into
# logits that fall outside the top k.
BIG_CONST = 1e10

In [None]:
# Select the execution device: fp16 mixed precision when a GPU is visible,
# no mixed precision on CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
accelerator = Accelerator(mixed_precision='no' if dev == "cpu" else 'fp16')

# The Accelerator, not `dev`, decides the device object used by later cells.
device = accelerator.device
print('Using {}'.format(device))

Using cuda


# Datasets

In [None]:
class RTPDataset(Dataset):
    """Toxicity-stratified sample of RealToxicityPrompts, tokenized up front.

    Prompts that carry a toxicity score are bucketed into ten bins with edges
    ``np.linspace(0.1, 1, 10)``; up to ``example_per_bin`` prompts are drawn
    from each of bins 0-9 after a shuffle seeded with 42. Prompts whose score
    is >= 1.0 land in bin 10 and are never sampled.

    Each item is the tuple
    ``(prompt_text, prompt_toxicity, challenging, token_ids, attention_mask)``.

    Args:
        tokenizer: callable tokenizer exposing ``name_or_path`` and
            ``bos_token``. Its ``padding_side`` is set to ``'left'`` as a
            side effect.
        example_per_bin: maximum number of prompts kept per toxicity bin.
    """

    def __init__(self, tokenizer, example_per_bin: int = 1000):

        dataset = load_dataset('allenai/real-toxicity-prompts', split="train")

        # No BOS prefix when the tokenizer path contains 'opt'; otherwise the
        # tokenizer's own bos_token is prepended to every prompt.
        bos_token = '' if 'opt' in tokenizer.name_or_path else tokenizer.bos_token

        # Single pass over the dataset; prompts without a toxicity score are dropped.
        toxicities, prompts, challenging = [], [], []
        for data in dataset:
            prompt = data['prompt']
            if prompt['toxicity'] is None:
                continue
            toxicities.append(prompt['toxicity'])
            prompts.append(prompt['text'])
            challenging.append(data['challenging'])

        self.prompt_toxicity = np.asarray(toxicities)
        self.prompts = np.asarray(prompts)
        text = np.asarray([bos_token + prompt_text for prompt_text in prompts])
        self.challenging = np.asarray(challenging)

        bins = np.digitize(self.prompt_toxicity, bins=np.linspace(0.1,1,10))

        # Seeds NumPy's global RNG so the per-bin sample is reproducible.
        np.random.seed(42)
        bins_idxs = [np.where(bins == i)[0] for i in range(10)]
        for arr in bins_idxs:
            np.random.shuffle(arr)
        # concatenate (rather than asarray + reshape) so bins holding fewer than
        # example_per_bin prompts do not produce a ragged array.
        bins_idxs = np.concatenate([arr[:example_per_bin] for arr in bins_idxs])

        self.prompt_toxicity = self.prompt_toxicity[bins_idxs]
        self.prompts = self.prompts[bins_idxs]
        text = text[bins_idxs]
        self.challenging = self.challenging[bins_idxs]

        tokenizer.padding_side = 'left'
        # No padding is requested, so each example keeps its own length.
        tokenized = tokenizer(text.tolist())
        self.tokens = tokenized.input_ids
        self.attentions = tokenized.attention_mask

    def __len__(self):
        """Number of sampled prompts."""
        return len(self.prompts)

    def __getitem__(self, index):
        """Return (prompt, toxicity, challenging, token_ids, attention_mask) for ``index``."""
        return self.prompts[index], self.prompt_toxicity[index], self.challenging[index], self.tokens[index], self.attentions[index]

# PPLM

In [None]:
def get_bag_of_words_indices(
        tokenizer,
        filepath='/content/drive/Shareddrives/IA024-Final/toxicity_classifier/bad_words.py'):
    """Tokenize the bag-of-words file, one word (or phrase) per line.

    Args:
        tokenizer: object whose ``encode(text, add_special_tokens=False)``
            returns a list of token ids.
        filepath: text file with one entry per line. Defaults to the
            shared-drive word list this notebook was written against.

    Returns:
        A list containing a single bag: ``[[ids_for_line_1, ids_for_line_2, ...]]``.
        Lines that are empty after stripping are skipped, so no bag entry is
        an encoding of the empty string.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        words = [line.strip() for line in f.read().split("\n")]

    bag = [
        tokenizer.encode(word, add_special_tokens=False)
        for word in words
        if word
    ]
    # Wrapped in an outer list: downstream code iterates over a list of bags.
    return [bag]


def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda', vocab_size=None):
    """Turn each bag of tokenized words into a (num_words, vocab_size) one-hot matrix.

    Only words encoded as exactly one token are kept; multi-token words and
    empty encodings are dropped.

    Args:
        bow_indices: list of bags, each bag a list of token-id lists; ``None``
            is passed straight through.
        tokenizer: unused; kept so existing call sites stay valid.
        device: device the resulting matrices are moved to.
        vocab_size: number of columns of each matrix. It has to match the
            width of the logits the matrices are multiplied with. When
            omitted, 250880 is used (the value this notebook hard-coded for
            BLOOM).

    Returns:
        ``None`` if ``bow_indices`` is ``None``; otherwise a list with one
        float tensor per bag, one row per kept word. A bag with no
        single-token word yields a tensor with zero rows.
    """
    if bow_indices is None:
        return None

    if vocab_size is None:
        vocab_size = 250880  # BLOOM width used by the original notebook

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        # `== 1` (not `<= 1`): an empty encoding would make the id list ragged.
        single_token_ids = [ids for ids in single_bow if len(ids) == 1]
        # Explicit dtype/shape so an empty bag still gives a valid (0, 1) index.
        index = torch.tensor(single_token_ids, dtype=torch.long).reshape(-1, 1).to(device)
        one_hot_bow = torch.zeros(index.shape[0], vocab_size).to(device)
        one_hot_bow.scatter_(1, index, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors

In [None]:
def top_k_filter(logits, k, probs=False):
    """
    Keep the k largest entries of every row and overwrite the rest.

    Entries strictly below the row's k-th largest value are replaced by 0.0
    when `probs` is True (the input holds probabilities) and by -BIG_CONST
    otherwise (the input holds logits, so the masked entries vanish under a
    softmax). Entries tied with the k-th largest value are kept.
    With k == 0 the input is returned untouched.
    """
    if k == 0:
        return logits

    # Smallest of the k largest values in each row, broadcast to the input shape.
    kth_largest = torch.topk(logits, k)[0][:, -1]
    threshold = kth_largest.view(-1, 1).expand_as(logits)

    fill_value = 0.0 if probs else -BIG_CONST
    return torch.where(logits < threshold,
                       torch.ones_like(logits) * fill_value,
                       logits)

In [None]:
def perturb_past(
        past,
        model,
        last,
        gen_length=None,
        unpert_past=None,
        unpert_logits=None,
        accumulated_hidden=None,
        grad_norms=None,
        stepsize=0.01,
        one_hot_bows_vectors=None,
        num_iterations=3,
        horizon_length=1,
        window_length=0,
        gamma=1.5,
        kl_scale=0.01,
        device='cuda'
):
    """Shift the cached keys/values so the next-token distribution moves away from the bag of words.

    For ``num_iterations`` steps the model is run on ``last`` with
    ``past + perturbation``; the loss is the log of the probability mass put
    on the bag-of-words tokens, plus ``kl_scale`` times the KL divergence
    from the unperturbed distribution. The perturbation is moved by
    ``-stepsize * grad / grad_norm ** gamma``, i.e. it descends that loss.

    Args:
        past: per-layer ``(key, value)`` tensors as returned in
            ``past_key_values``. The key and value of a layer may have
            different shapes; every perturbation is shaped like the tensor
            it is added to.
        model: callable accepting ``(input_ids, past_key_values=...,
            attention_mask=...)`` and returning an object with ``.logits``.
        last: tensor holding the most recent token id(s) fed to the model.
        gen_length: length of the all-ones attention mask passed to the model
            (cached positions plus current token). When ``None`` no mask is
            passed and the model uses its own default.
        unpert_past, accumulated_hidden, horizon_length, window_length:
            accepted for call compatibility; not used by this implementation.
        unpert_logits: logits of the unperturbed model; required when
            ``kl_scale > 0``.
        grad_norms: per-layer running gradient norms from a previous call, or
            ``None`` on the first call.
        stepsize: step length of each update.
        one_hot_bows_vectors: list of (num_words, vocab) one-hot matrices, or
            ``None`` for no bag-of-words term.
        num_iterations: number of gradient steps.
        gamma: exponent applied to the gradient norm in the update.
        kl_scale: weight of the KL term; ``0`` disables it.
        device: device used for the small correction tensors of the KL term.

    Returns:
        ``(pert_past, grad_norms)``: the perturbed cache as a list of
        ``(key, value)`` tuples, and the updated per-layer gradient norms.
    """
    # Cut the cache off from any earlier graph once, up front: gradients are
    # only needed with respect to the perturbation tensors created below.
    past = [(layer[0].detach(), layer[1].detach()) for layer in past]

    # One float32 accumulator per cached tensor, shaped like that tensor.
    grad_accumulator = [
        (torch.zeros_like(layer[0], dtype=torch.float32),
         torch.zeros_like(layer[1], dtype=torch.float32))
        for layer in past
    ]

    model_kwargs = {}
    if gen_length is not None:
        model_kwargs['attention_mask'] = last.new_ones((1, gen_length))

    # accumulate perturbations for num_iterations
    for _ in range(num_iterations):
        # Fresh leaf tensors each iteration so .grad holds only this step's gradient.
        curr_perturbation = [
            (key_acc.clone().requires_grad_(True),
             value_acc.clone().requires_grad_(True))
            for key_acc, value_acc in grad_accumulator
        ]

        # Compute logits using the perturbed past
        perturbed_past = tuple(
            (layer[0] + delta[0], layer[1] + delta[1])
            for layer, delta in zip(past, curr_perturbation)
        )
        out = model(last, past_key_values=perturbed_past, **model_kwargs)
        logits = out.logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0

        for one_hot_bow in (one_hot_bows_vectors or []):
            # The bag matrices may have been built with a different width than
            # this model's logits. Align them so the product is defined:
            # missing columns are zero-padded, surplus columns are dropped
            # (presumably all zero, since token ids index the model's vocabulary).
            vocab = probs.shape[-1]
            width = one_hot_bow.shape[1]
            if width > vocab:
                one_hot_bow = one_hot_bow[:, :vocab]
            elif width < vocab:
                one_hot_bow = F.pad(one_hot_bow, (0, vocab - width))
            bow_probs = torch.mm(probs, torch.t(one_hot_bow))
            # Positive log-probability: descending it lowers the bag's probability.
            loss = loss + torch.log(torch.sum(bow_probs))

        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            # Nudge exact/near zeros so the ratio and the log below stay finite.
            unpert_probs = (
                    unpert_probs + SMALL_CONST *
                    (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
                device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            loss = loss + kl_loss

        # No bag of words and no KL term: nothing to differentiate.
        if not torch.is_tensor(loss):
            break

        # compute gradients
        loss.backward()

        # Norm over the key and value gradients of each layer taken together.
        layer_norms = [
            torch.norm(torch.cat((key_delta.grad.reshape(-1),
                                  value_delta.grad.reshape(-1))))
            for key_delta, value_delta in curr_perturbation
        ]
        if grad_norms is not None:
            grad_norms = [
                torch.max(previous, current)
                for previous, current in zip(grad_norms, layer_norms)
            ]
        else:
            grad_norms = [current + SMALL_CONST for current in layer_norms]

        # normalized gradient step, accumulated into the running perturbation
        grad_accumulator = [
            (key_acc - stepsize * key_delta.grad / grad_norms[index] ** gamma,
             value_acc - stepsize * value_delta.grad / grad_norms[index] ** gamma)
            for index, ((key_acc, value_acc), (key_delta, value_delta))
            in enumerate(zip(grad_accumulator, curr_perturbation))
        ]

    # apply the accumulated perturbations to the past
    pert_past = [
        (layer[0] + delta[0], layer[1] + delta[1])
        for layer, delta in zip(past, grad_accumulator)
    ]

    return pert_past, grad_norms

In [None]:
def generate(model,
              context,
              perturb=True,
              length=10,
              stepsize=0.02,
              temperature=1.0,
              top_k=10,
              sample=True,
              num_iterations=10,
              grad_length=10000,
              horizon_length=1,
              window_length=0,
              gamma=1.0,
              gm_scale=0.9,
              kl_scale=0.01):
    """Generate ``length`` tokens after ``context``, steering with the bag of words.

    Relies on the notebook-level globals ``tokenizer`` and ``device``.

    Args:
        model: causal LM returning ``.logits``, ``.past_key_values`` and
            ``.hidden_states`` (so it must be configured to output hidden states).
        context: list of prompt token ids.
        perturb: when False the cache is used unperturbed.
        length: number of tokens to generate.
        stepsize: perturbation step length; steps at index >= ``grad_length``
            use a step of 0.
        temperature: divisor applied to the perturbed logits.
        top_k: number of candidates kept before sampling (0 keeps all).
        sample: sample from the filtered distribution if True, else take the argmax.
        num_iterations, horizon_length, window_length, gamma, kl_scale:
            forwarded to ``perturb_past``.
        grad_length: number of leading generation steps that use ``stepsize``.
        gm_scale: exponent of the perturbed distribution in the fusion
            ``pert ** gm_scale * unpert ** (1 - gm_scale)``.

    Returns:
        Tensor of shape (1, len(context) + length) with prompt and generated ids.
    """
    output_so_far = torch.tensor([context]).to(device)

    bow_indices = get_bag_of_words_indices(tokenizer)
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    past = None

    for i in range(length):

        # First step: cache everything except the last token, which is fed
        # separately together with the (perturbed) cache.
        if past is None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                past = model(output_so_far[:, :-1]).past_key_values

        # run model forward to obtain unperturbed
        out = model(output_so_far)
        unpert_logits, unpert_past, unpert_all_hidden = out.logits, out.past_key_values, out.hidden_states
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # summed hidden states of every position but the last
        accumulated_hidden = torch.sum(unpert_last_hidden[:, :-1, :], dim=1)

        # modify the past
        if perturb and past is not None:
            pert_past, grad_norms = perturb_past(
                past,
                model,
                last,
                # cached positions + the current token
                gen_length=output_so_far.shape[1],
                unpert_past=unpert_past,
                unpert_logits=unpert_logits,
                accumulated_hidden=accumulated_hidden,
                grad_norms=grad_norms,
                stepsize=current_stepsize,
                one_hot_bows_vectors=one_hot_bows_vectors,
                num_iterations=num_iterations,
                horizon_length=horizon_length,
                window_length=window_length,
                gamma=gamma,
                kl_scale=kl_scale,
                device=device
            )
            pert_past = tuple((layer[0], layer[1]) for layer in pert_past)
        else:
            # Either perturbation is disabled or there is no cache yet
            # (single-token context); `past` may be None here.
            pert_past = past

        out = model(last, past_key_values=pert_past)
        pert_logits, past = out.logits, out.past_key_values
        pert_logits = pert_logits[:, -1, :] / temperature
        pert_probs = F.softmax(pert_logits, dim=-1)

        # Fuse the modified model and original model
        unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

        pert_probs = ((pert_probs ** gm_scale) * (
                unpert_probs ** (1 - gm_scale)))
        pert_probs = top_k_filter(pert_probs, k=top_k,
                                    probs=True)

        # rescale
        pert_probs = pert_probs / torch.sum(pert_probs, dim=-1).unsqueeze(1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)
        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = torch.cat((output_so_far, last), dim=1)

    return output_so_far

In [None]:
# Number of continuations generated for every prompt.
num_generations = 5
# Model identifiers loaded one after another by the generation loop and used
# to name the per-model result files. Commented-out entries are skipped.
test_models = ["EleutherAI/gpt-neo-125M",
               "EleutherAI/gpt-neo-1.3B",
              #  #"EleutherAI/gpt-neo-2.7B",
               "gpt2",
               "gpt2-medium",
               "gpt2-large",
               "gpt2-xl",
               "facebook/opt-125m",
               "facebook/opt-350m",
               "facebook/opt-1.3b",
              #  "facebook/opt-2.7b",
            #    "facebook/opt-6.7b",
               "facebook/galactica-125m",
               "facebook/galactica-1.3b",
               #"facebook/galactica-6.7b",
               "bigscience/bloom-560m",
               "bigscience/bloom-1b1",
               "bigscience/bloom-1b7",
            #    "bigscience/bloom-3b",
            #    "bigscience/bloom-7b1"
               ]

In [None]:
# Generate `num_generations` continuations per sampled prompt for every model
# and store them as one JSON file per model. Models whose file already exists
# are skipped, so the cell can be re-run after an interruption.
for model_name in test_models:

    save_path = '/content/drive/Shareddrives/IA024-Final/PPLM/RTP/' + model_name.replace('/','_') + '.json'

    if os.path.exists(save_path):
        continue

    # Create the output folder up front so a missing directory fails now and
    # not after all generations have been computed.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    if tokenizer.pad_token is None:
        if 'galactica' in model_name:
            tokenizer.bos_token_id = 0
            tokenizer.pad_token_id = 1
            tokenizer.eos_token_id = 2
            tokenizer.unk_token_id = 3
        else:
            tokenizer.pad_token = tokenizer.eos_token

    # Drop the reference to the previous model before loading the next one so
    # two models do not have to be held at the same time.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 use_cache=True,
                                                 output_hidden_states=True,
                                                 device_map="auto",
                                                 cache_dir="/content/drive/Shareddrives/IA024-Final/models")
    model.eval()

    dataset = RTPDataset(tokenizer, example_per_bin=5)

    # Only the cache perturbation needs gradients, never the weights.
    for param in model.parameters():
        param.requires_grad = False

    generated_outputs = []

    pbar = tqdm(dataset, total=len(dataset))
    for prompt, toxicity, challenge, input_ids, _mask in pbar:
        gen_texts = []
        for _ in range(num_generations):
            gen = generate(model, input_ids, length=20)
            # Keep only the continuation: drop the prompt tokens.
            gen_texts.append(tokenizer.decode(gen[0, len(input_ids):], skip_special_tokens=True))

        generated_outputs.append({'prompt': str(prompt),
                                  'toxicity': float(toxicity),
                                  'challenging': str(challenge),
                                  'model_generated': gen_texts})

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(generated_outputs, f, ensure_ascii=False, indent=4)

In [None]:
# Score every generated continuation with the Perspective API TOXICITY
# attribute and write `<model>_scored.json` next to the generations.
# Progress is checkpointed, so the cell resumes where it stopped.

# The API key is read from the environment; it must never be committed with
# the notebook.
API_KEY = os.environ.get('PERSPECTIVE_API_KEY')

# Upper bound on attempts per text (one per second while waiting for quota).
MAX_SCORE_ATTEMPTS = 120


def _score_toxicity(client, text):
    """Return the TOXICITY summary score of `text`.

    Empty or whitespace-only text scores 0 without calling the API.
    Returns None when the request is rejected with a 4xx status other than
    429, or when MAX_SCORE_ATTEMPTS attempts have failed.
    """
    if text.strip() == '':
        return 0

    analyze_request = {
        'comment': { 'text': text },
        'requestedAttributes': {'TOXICITY': {}},
        'languages': ['en']
    }

    for _ in range(MAX_SCORE_ATTEMPTS):
        try:
            response = client.comments().analyze(body=analyze_request).execute()
            return response['attributeScores']['TOXICITY']['summaryScore']['value']
        except Exception as err:
            status = getattr(getattr(err, 'resp', None), 'status', None)
            # A client error other than 429 will not go away by waiting.
            if isinstance(status, int) and 400 <= status < 500 and status != 429:
                print('Perspective API rejected a sample ({}): {}'.format(status, err))
                return None
            time.sleep(1) # Wait for more quota

    print('Perspective API: giving up on a sample after {} attempts'.format(MAX_SCORE_ATTEMPTS))
    return None


client = None  # built on first use so nothing is required when there is nothing to score

for model_name in test_models:

    load_path = '/content/drive/Shareddrives/IA024-Final/PPLM/RTP/' + model_name.replace('/','_') + '.json'
    save_path = '/content/drive/Shareddrives/IA024-Final/PPLM/RTP/' + model_name.replace('/','_') + '_scored' + '.json'

    if not os.path.exists(load_path):
        continue

    with open(load_path, 'r', encoding='utf-8') as f:
        gen_data = json.load(f)

    if client is None:
        if not API_KEY:
            raise RuntimeError('Set the PERSPECTIVE_API_KEY environment variable before running this cell.')
        client = discovery.build(
            "commentanalyzer",
            "v1alpha1",
            developerKey=API_KEY,
            discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1"
        )

    # Resume from a previous checkpoint: entries are appended in order, so the
    # number already scored is the index to continue from.
    scored_data = []
    if os.path.exists(save_path):
        with open(save_path, 'r', encoding='utf-8') as f:
            scored_data = json.load(f)

    pbar = tqdm(range(len(scored_data), len(gen_data)), total=len(gen_data)-len(scored_data))
    for data_idx in pbar:
        data = gen_data[data_idx]
        # One score per continuation, in the same order as 'model_generated'.
        data['scores'] = [_score_toxicity(client, sample) for sample in data['model_generated']]
        scored_data.append(data)

        # Save checkpoint of processed data
        if data_idx % 100 == 0:
            start_save_time = time.perf_counter()
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(scored_data, f, ensure_ascii=False, indent=4)
            pbar.set_postfix({'Save Time': time.perf_counter() - start_save_time})

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(scored_data, f, ensure_ascii=False, indent=4)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/199 [00:00<?, ?it/s]