In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from stable_baselines3 import PPO, A2C
from models.GPT2ActorCriticPolicy import GPT2ActorCriticPolicy
from envs.AdaptEnv import AdaptEnv
import os

prompt = 'Tell me a story. '

out_dir = f'output/happy/story/DialoGPT-medium/2022-05-07/18-42-22/'
os.makedirs(out_dir, exist_ok=True)
with open(f'{out_dir}/prompt.txt', 'w+') as prompt_file:
    prompt_file.write(prompt)

cache_dir = '/om2/user/rogerjin/.cache/transformers'
hf_id = 'microsoft/DialoGPT-medium'
# hf_id = 'gpt2-large'
device = 'cuda:0'
save_path = '/om2/user/rogerjin/6.884/adaPT/checkpoints/happy/story/DialoGPT-small/2022-05-06/15-27-24/DialoGPT-small_ppo_10000_lr=5e-06.zip'
tokenizer = GPT2Tokenizer.from_pretrained(hf_id, cache_dir=cache_dir, device=device)

control = True
run_name = 'be_sadder'

if control:
    model = GPT2LMHeadModel.from_pretrained(hf_id, cache_dir=cache_dir).to(device)
    out_path = f'{out_dir}/control.txt'
else:
    ppo = PPO.load(save_path)
    model = ppo.policy._modules['mlp_extractor'].policy_latent.lm # todo: figure out if the weights are different
    out_path = f'{out_dir}/{run_name}.txt'

In [3]:
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left
tokenizer.padding_side = "left" 
tokenizer.pad_token = tokenizer.eos_token # to avoid an error

num_sequences = 1000
sentences = [prompt] * num_sequences
batch_size = 20

output_sequences = []

for i in range(0, num_sequences, batch_size):
    batch = sentences[i:i+batch_size]
    inputs = tokenizer(batch, return_tensors="pt", padding=True)

    output_sequences.extend(model.generate(
        input_ids=inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        do_sample=True, # disable sampling to test if batching affects output
        min_length=25,
        max_length = 60,
        temperature=0.6,
        top_k=50,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2
    ))

decodings = []

for i in range(len(sentences)):
    decoding = tokenizer.decode(output_sequences[i], skip_special_tokens=True)
    print(decoding)
    decodings.append(f'<|DECODING|>{decoding}')
#     decodings.append(f'{decoding}\n')
    # you can use skip_special_tokens=True in decode() to remove padding token
    # but note that it will also remove other special_tokens

with open(out_path, 'w+') as out:
    out.writelines(decodings)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end gene

Tell me a story. ive seen the movie. I can't tell you much about it. It's worth watching if you have.
Tell me a story. ive heard of this. I don't even know what to do with my life now haha...
Tell me a story. ive heard it's really bad. I can see how it could be awful to watch, but is your friend in the movie?
Tell me a story. ive been reading the manga, but never read anime. I'll tell you one though! One day i will start watching anime
Tell me a story. ive been watching it since the first episode and I have no idea what's going on. It was pretty good
Tell me a story. ive never really heard of it. theres literally no one who can give you any info on this
Tell me a story. ive been trying to explain why it has the same effect on my body and brain. Haha
Tell me a story. ive never had to deal with it. You're the 3rd person I ve heard about this happening
Tell me a story. ive been trying to figure out if I can get my account back. my friends and family are not happy about it, as soon i told 