[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/salesforce/GeDi/blob/master/GeDi_guided_GPT_2_XL.ipynb)

Official implementation of generation with the topic GeDi (pronounced *Jedi*) model, based on our paper [GeDi: Generative Discriminator Guided Sequence Generation](https://arxiv.org/abs/2009.06367).

See our GitHub repository for more options (such as detoxification and sentiment control): https://github.com/salesforce/GeDi

In [None]:
!wget https://storage.googleapis.com/sfr-gedi-data/GeDi.zip
import zipfile
with zipfile.ZipFile('GeDi.zip', 'r') as zip_ref:
    zip_ref.extractall('./')

In [None]:
cd GeDi

In [None]:
#Installing transformers v2.8

!pip install transformers==2.8
!pip install -r hf_requirements.txt

#Downloading GeDi topic model checkpoints

!wget https://storage.googleapis.com/sfr-gedi-data/gedi_topic.zip

with zipfile.ZipFile('gedi_topic.zip', 'r') as zip_ref:
    zip_ref.extractall('./')

In [None]:
import numpy as np
import torch


from modeling_gpt2 import GPT2LMHeadModel

from transformers import (
    GPT2Config,
    GPT2Tokenizer
)

In [None]:
mode = "topic"
code_desired = "true"
code_undesired = "false"
model_type = 'gpt2'
gen_type = "gedi"
gen_model_name_or_path = "gpt2-xl"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_CLASSES = {"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),}
config_class, model_class, tokenizer_class = MODEL_CLASSES["gpt2"]
tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, do_lower_case=False)



The next step needs to download and convert the GPT2-XL model. 

This takes a while (usually about 3 minutes to download and another 5 or so to convert). 

The good news is that once the model is loaded, you can quickly sample from many different prompts and topics.

In [None]:
#Load the GPT2-XL model (1.5B-parameter LM) below; this can take a while.
#Loading the pretrained weights into a new model requires additional CPU memory overhead.
#Because of CPU memory constraints on Colab, we load the model in half precision (load_in_half_prec=True).
#Due to this change, generations may not always exactly match the samples in the paper, but they sometimes do and seem to be similar in quality.
#If you run the notebook with enough CPU RAM (most likely 16GB+), you can try setting load_in_half_prec=False.

model = model_class.from_pretrained(gen_model_name_or_path, load_in_half_prec=True)
model = model.to(device)
model = model.float()

gedi_model_name_or_path = 'gedi_topic'
gedi_model = model_class.from_pretrained(gedi_model_name_or_path)
gedi_model.to(device)

### Set arguments for generation

You can change the max generation length, or play around with hyperparameter settings. 

The defaults below are the hyperparameters used for the topic model in the paper.

For more aggressive topic steering, increase `disc_weight` or `filter_p` (`filter_p` should always be less than 1).
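
To make the roles of `disc_weight`, `filter_p`, and `target_p` concrete, here is a toy single-step sketch of the weighting and filtering heuristics as described in the paper. It is a simplification and not the code in `modeling_gpt2`: the function name `gedi_next_token_dist` and its inputs `lm_probs` and `class_probs` are hypothetical, the class probabilities are taken as given rather than computed by the GeDi model, all probabilities are assumed to be strictly positive, and `class_bias` is ignored.

```python
import numpy as np


def gedi_next_token_dist(lm_probs, class_probs, disc_weight=30.0,
                         filter_p=0.8, target_p=0.8):
    """Toy illustration of one GeDi-guided decoding step (hypothetical helper).

    lm_probs:    base LM next-token probabilities, shape (vocab,)
    class_probs: classifier probability of the desired topic if each
                 candidate token were chosen next, shape (vocab,)
    """
    lm_probs = np.asarray(lm_probs, dtype=np.float64)
    class_probs = np.asarray(class_probs, dtype=np.float64)

    # Weighted decoding: P_w is proportional to P_LM * P(class | token)^omega.
    # Work in log space and subtract the max for numerical stability.
    log_w = np.log(lm_probs) + disc_weight * np.log(class_probs)
    log_w -= log_w.max()
    p_w = np.exp(log_w)
    p_w /= p_w.sum()

    # Filtering: rank tokens by class probability and keep the smallest
    # top-ranked set whose weighted mass reaches rho = 1 - filter_p.
    rho = 1.0 - filter_p
    order = np.argsort(-class_probs)
    cumulative = np.cumsum(p_w[order])
    m = min(int(np.searchsorted(cumulative, rho)) + 1, len(p_w))
    keep = np.zeros(len(p_w), dtype=bool)
    keep[order[:m]] = True

    # Tokens the classifier is already confident about (tau) are always kept.
    keep |= class_probs > target_p

    # Zero out filtered tokens and renormalize.
    filtered = np.where(keep, p_w, 0.0)
    return filtered / filtered.sum()
```

In this sketch, a larger `disc_weight` sharpens the distribution towards tokens the classifier favors, and a larger `filter_p` lowers the mass threshold so fewer tokens survive the filter.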

In [None]:
#setting arguments for generation
#max generation length
gen_length = 200
#omega from paper, higher disc_weight means more aggressive topic steering
disc_weight = 30
#1 - rho from paper, should be between 0 and 1; higher filter_p means more aggressive topic steering
filter_p = 0.8
#tau from paper, preserves tokens that are classified as correct topic
target_p = 0.8
#hyperparameter that determines class prior, set to uniform by default
class_bias = 0

#GPT-2 has a maximum context length of 1024 tokens, so cap the generation length there
length = min(gen_length, 1024)

### Specify prompt and topic to GeDi


The topic and prompt can be specified as strings with the `secondary_code` and `prompt` variables below.

Note that our GeDi topic model was trained on only four topics (`world`, `sports`, `business`, and `science`), so it performs best when steering GPT-2 generation towards these topics. However, it also shows some promising zero-shot results on new topics, e.g. `education`, `food`, `fire`, `space`, `cars`, and `climate`.

Generic short prompts tend to work the best.
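
The next cell warns when a topic splits into more than one BPE token, since the model was not trained for that case. If you want to screen several candidate topics at once, a small helper along these lines may be convenient. This is only a sketch: `split_topics_by_token_count` is a hypothetical name, and `encode` is assumed to be any callable that maps a string to a list of token ids, such as `tokenizer.encode` from this notebook.

```python
def split_topics_by_token_count(encode, topics):
    """Partition topics into those encoded as one token and those that are not.

    encode: callable mapping a string to a list of token ids
            (assumed here to be something like tokenizer.encode)
    topics: iterable of candidate topic strings
    """
    single, multi = [], []
    for topic in topics:
        if len(encode(topic)) == 1:
            single.append(topic)
        else:
            multi.append(topic)
    return single, multi


# Example usage inside this notebook, once `tokenizer` has been created:
# single, multi = split_topics_by_token_count(
#     tokenizer.encode, ["education", "food", "space", "climate"])
```

Topics that land in the second list can still be tried, but generation is less likely to match them.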

In [None]:
#Specify what topic you want to generate on using the secondary_code variable

secondary_code = 'climate'
bpe_tokens = tokenizer.encode(secondary_code)
if len(bpe_tokens) > 1:
  print("Warning! The number of BPE tokens for " + secondary_code + " is greater than 1. The model isn't trained for this, so generation is less likely to match the topic.")

In [None]:
#Specify prompt below
prompt = "In a shocking finding"

start_len=0
text_ids = tokenizer.encode(prompt)
encoded_prompts=torch.LongTensor(text_ids).unsqueeze(0).to(device)

multi_code = tokenizer.encode(secondary_code)
attr_class = 1

generated_sequence = model.generate(
    input_ids=encoded_prompts,
    pad_lens=None,
    max_length=length,
    top_k=None,
    top_p=None,
    repetition_penalty=1.2,
    rep_penalty_scale=10,
    eos_token_ids=tokenizer.eos_token_id,
    pad_token_id=0,
    do_sample=False,
    penalize_cond=True,
    gedi_model=gedi_model,
    tokenizer=tokenizer,
    disc_weight=disc_weight,
    filter_p=filter_p,
    target_p=target_p,
    class_bias=class_bias,
    attr_class=attr_class,
    code_0=code_undesired,
    code_1=code_desired,
    multi_code=multi_code,
)

text = tokenizer.decode(generated_sequence.tolist()[0], clean_up_tokenization_spaces=True)
print('\n')
print(text)