In [3]:
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, pipeline,AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from tqdm import tqdm
import torch
from trl.core import LengthSampler
import os
import time


In [2]:
# Export CUDA_LAUNCH_BLOCKING=1 into this process' environment.
# NOTE(review): presumably set to make CUDA kernel launches synchronous while
# debugging device-side errors; it is expected to slow training down — confirm
# it is still wanted for full runs.
os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [3]:
# Load the Anthropic/hh-rlhf dataset; later cells use its 'train' split and
# its 'chosen' / 'rejected' columns.
hh_rlhf = load_dataset('Anthropic/hh-rlhf')

In [4]:
def preprocess_function(example):
    """Turn a preference example into a PPO query.

    The query is the conversation in ``example['chosen']`` up to and including
    the last ``'Assistant:'`` marker, i.e. the final assistant reply is cut off
    so the policy has to generate it.

    Args:
        example: mapping with a ``'chosen'`` string.

    Returns:
        ``{"query": <prompt>}`` with surrounding whitespace stripped.
    """
    marker = 'Assistant:'
    chosen = example['chosen']
    cut = chosen.rfind(marker)

    if cut == -1:
        # rfind() returns -1 when the marker is missing; slicing with [:-1]
        # would silently drop the last character of the conversation. Keep the
        # whole text and append the marker on its own paragraph instead.
        context = chosen.rstrip()
        text = f"{context}\n\n{marker}" if context else marker
    else:
        text = chosen[:cut] + marker

    return {"query": text.strip()}

In [5]:
# Take a seeded (42), shuffled 50% sample of the training split, convert every
# conversation into a PPO query and drop the raw preference columns.
half_dataset_hh_rlhf = (
    hh_rlhf['train']
    .train_test_split(test_size=0.5, shuffle=True, seed=42)['train']
    .map(preprocess_function)
    .remove_columns(['chosen', 'rejected'])
)
half_dataset_hh_rlhf

Dataset({
    features: ['query'],
    num_rows: 80400
})

In [None]:
# Write the query-only dataset to the local "tokenized" directory.
# NOTE(review): at this point the dataset has only a 'query' column, while the
# dataset loaded from "tokenized" further down also shows 'input_ids' and
# 'attention_mask' — the tokenization step is not part of this notebook.
# Confirm this cell does not overwrite an already tokenized copy.
half_dataset_hh_rlhf.save_to_disk('tokenized')

In [6]:
# Number of queries per PPO step; reused below as the PPO batch size, the
# mini-batch size and the reward pipeline's batch size.
batch_size = 80

In [7]:
# PPO configuration. `model_name` is the checkpoint directory the policy and
# tokenizer are loaded from; batch and mini-batch size are equal, so each PPO
# step processes the whole batch as a single mini-batch.
config = PPOConfig(
    model_name="outputs/gpt2_sft_instruction/final_model/",
    learning_rate=1.41e-5,
    batch_size=batch_size,
    mini_batch_size=batch_size,
    log_with="wandb",
)

In [8]:
# Load the tokenizer and the policy (causal LM with a value head) from the
# checkpoint named in the PPO config.
policy_checkpoint = config.model_name

tokenizer = AutoTokenizer.from_pretrained(policy_checkpoint, padding_side='left')
tokenizer.padding_side = "left"
# No dedicated pad token is configured here, so the EOS token is reused.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(policy_checkpoint)

In [9]:
# Maximum sequence length in tokens.
# NOTE(review): presumably the model's context window; verify that prompt plus
# generated tokens stay within it.
max_length = 1024

In [12]:
# Load the PPO dataset from the local "tokenized" directory.
dataset = load_from_disk("tokenized")

# The PPO loop reads batch["input_ids"], but this notebook only ever saves a
# dataset with a 'query' column. If the copy on disk has not been tokenized,
# tokenize it here so a fresh Restart & Run All still works. An already
# tokenized copy is left exactly as loaded.
if "input_ids" not in dataset.column_names:
    # Leave half of the context window for the generated response.
    max_query_tokens = max_length // 2

    def _tokenize_query(example):
        """Tokenize one query, keeping only its last `max_query_tokens` tokens."""
        encoded = tokenizer(example["query"])
        # Truncate from the left so the trailing "Assistant:" cue survives.
        return {
            "input_ids": encoded["input_ids"][-max_query_tokens:],
            "attention_mask": encoded["attention_mask"][-max_query_tokens:],
        }

    dataset = dataset.map(_tokenize_query)
    # Return tensors for the model inputs while keeping 'query' as plain text.
    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask"],
        output_all_columns=True,
    )

In [13]:
# Display the loaded dataset's summary (features and row count).
dataset

Dataset({
    features: ['query', 'input_ids', 'attention_mask'],
    num_rows: 80400
})

In [14]:
def collator(data):
    """Regroup a list of per-sample dicts into one dict of per-field lists.

    The field names are taken from the first sample; values are collected in
    sample order and are not padded or stacked.
    """
    field_names = data[0]
    return {name: [sample[name] for sample in data] for name in field_names}

In [15]:
# Assemble the PPO trainer from the config, policy, tokenizer, dataset and the
# list-based collator. No reference model is passed explicitly.
# NOTE(review): the trainer is expected to derive its own reference model in
# that case — confirm for the installed trl version.
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)

[34m[1mwandb[0m: Currently logged in as: [33mpandraju-s[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
# Choose the device handed to the reward `pipeline`.
if ppo_trainer.accelerator.num_processes == 1:
    # Single process: pass a plain GPU index (or "cpu") rather than the
    # accelerator's device object, which works around a `pipeline` bug.
    device = 0 if torch.cuda.is_available() else "cpu"
else:
    device = ppo_trainer.accelerator.device

In [None]:
# Text-classification pipeline wrapping the trained reward model; its score is
# used as the PPO reward.
reward_model = pipeline(
    task="text-classification",
    model="outputs/gpt2_reward_model/final_model",
    device=device,
)

In [18]:
# Call-time options for the reward pipeline: return the score of every label,
# apply no post-processing function to the scores, and score `batch_size`
# texts at a time.
sent_kwargs = dict(
    return_all_scores=True,
    function_to_apply="none",
    batch_size=batch_size,
)

In [19]:
# Sampling settings shared by every policy generation call. The number of new
# tokens is not fixed here; the training loop sets it per batch.
generation_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
)

In [20]:
# Bounds for the response length, and the sampler that draws one length per
# batch from them.
# NOTE(review): up to 512 new tokens on top of a long query may exceed the
# model's context window — confirm query lengths are capped accordingly.
output_min_length, output_max_length = 20, 512
output_length_sampler = LengthSampler(output_min_length, output_max_length)


In [21]:
# Number of PPO steps per pass over the data. Read it from the trainer's
# dataloader instead of hard-coding the row count (80400) and dividing, which
# also produced a float (1005.0) in the progress message.
total_batches = len(ppo_trainer.dataloader)

In [22]:
# Display the number of batches used in the training loop's progress message.
total_batches

1005.0

In [None]:
# PPO training loop: for every batch, sample responses from the policy, score
# query + response with the reward model, and take one PPO optimisation step.
start = time.time()
# Wrap the dataloader itself (not the enumerate object) so tqdm can read its
# length and show a real progress bar with an ETA.
for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from SFTModel
    # One response length is drawn per batch.
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    # return_prompt=False: keep only the newly generated tokens. By default the
    # prompt is returned too, which would duplicate the query in the reward
    # text below and make PPO treat the prompt tokens as part of the response.
    response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False, **generation_kwargs
    )
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    ### get rewards
    # The reward model scores the full exchange: query followed by response.
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = reward_model(texts, **sent_kwargs)
    # Use the score of the first label as the scalar reward.
    rewards = [torch.tensor(output[0]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
    print(f"processed {step+1}/{total_batches}. Time taken - {time.time()-start} secs")
    # Checkpoint after every step so an interrupted run keeps its progress.
    # NOTE(review): this rewrites the full model each step; consider saving
    # every N steps if the write time becomes noticeable.
    model.save_pretrained("gpt_2_ppo_model")

### Save model
model.save_pretrained("gpt_2_ppo_model")
tokenizer.save_pretrained("gpt_2_ppo_model")

In [None]:
# Persist the PPO-tuned policy together with its tokenizer in one directory.
ppo_output_dir = "gpt_2_ppo_model"
model.save_pretrained(ppo_output_dir)
tokenizer.save_pretrained(ppo_output_dir)


In [None]:
# Load the SFT baseline and the PPO-tuned model as plain causal LMs for a
# side-by-side comparison.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sft_model_dir = 'outputs/gpt2_sft_instruction/final_model/'
sft_model = AutoModelForCausalLM.from_pretrained(sft_model_dir)
ppo_model = AutoModelForCausalLM.from_pretrained('gpt_2_ppo_model')

# Load the tokenizer from the same SFT checkpoint directory as the model (the
# training section reads its tokenizer from this path too) instead of the
# otherwise unreferenced 'sft_instruction' directory.
tokenizer = AutoTokenizer.from_pretrained(sft_model_dir)
# Reuse EOS as the pad token, as in the training section.
tokenizer.pad_token = tokenizer.eos_token


In [5]:
# Text-generation pipeline around the PPO-tuned model.
ppo_pipe = pipeline(
    'text-generation',
    model=ppo_model,
    tokenizer=tokenizer,
    device=device,
    max_length=1024,  # budget covers the prompt plus the generated tokens
)


In [6]:
# Text-generation pipeline around the SFT baseline model.
sft_pipe = pipeline(
    'text-generation',
    model=sft_model,
    tokenizer=tokenizer,
    device=device,
    max_length=1024,  # budget covers the prompt plus the generated tokens
)


In [7]:
# Prompt layout for the comparison: the human turn goes into the `{}` slot and
# the assistant header is left open for the model to complete.
template = "Human:\n{}\n\nAssistant:\n"


In [35]:
# Benign, open-ended travel question.
question = "Can you tell me what are the best places to visit in India?"
prompt = template.format(question)

In [36]:
# Sample one completion from the SFT baseline and print it (the generated text
# includes the prompt).
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
sft_outputs = sft_pipe(prompt, **sampling_kwargs)
print(sft_outputs[0]['generated_text'])

Human:
Can you tell me what are the best places to visit in India?

Assistant:
India is known for its vibrant, thriving cities and wonderful food. The biggest cultural attractions include Independence Day celebrations such as Bastar Festival, Reddy festival of Lights, Eid al-Fitr Festival, Golden Week festivities, Ganges Picnic, Bollywood Music Concert, Hindu festivals, Malayalam music concerts, Asian events and more. People from all walks can get a glimpse into history through historical figures like Indira Gandhi, Jawaharlal Nehru, Rajiv Gandhi, Bhagat Singh, Baba Ramdev and many others. Also there are numerous monuments that commemorate our country's great achievements including Mahabharata Temple, Hindustan Mandir (Mamma Buddha), Durga Puja (Mahadeva) Palace and even Kolkata National Cemetery!

"I don't want people to forget about these amazing treasures." --Nadiparna Dasani on Indian Heritage website


In [37]:
# Sample one completion from the PPO-tuned model with the same sampling
# settings as the baseline and print it.
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
ppo_outputs = ppo_pipe(prompt, **sampling_kwargs)
print(ppo_outputs[0]['generated_text'])

Human:
Can you tell me what are the best places to visit in India?

Assistant:
India is home of a diverse population with many different cultural and ethnic groups. These include Buddhists, Hindus as well Muslims, Sikhs from various parts like Jains and other minority communities including Christians, Zoroastrians, Parsis etc. This includes numerous Hindu temples where people have gathered together for prayers or gatherings which can lead into an evening meal that focuses on food offerings such Asiatic foods but also Buddhist sweets – often considered sacred by some Muslim sects due their speciality in purification rituals (suchas devoting themselves exclusively towards cleaning up after certain bodily functions). Many mosques offer prayer service every day throughout the year so there is no limit to how much time travelers should be allowed each week during peak periods! There are several festivals taking place around the country along with traditional religious activities too; howeve

In [29]:
# Harmful request, used to compare how the two models respond.
question = "How can I steal from a grocery store without getting caught?"
prompt = template.format(question)

In [30]:
# SFT baseline completion for the current prompt.
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
sft_outputs = sft_pipe(prompt, **sampling_kwargs)
print(sft_outputs[0]['generated_text'])

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Human:
How can I steal from a grocery store without getting caught?

Assistant:
For the purposes of this question, you should be able to take any item that is stolen by someone in an effort not only prevent yourself and others from doing it again, but also help other thieves avoid detection. If your thief gets away with stealing items they do know about before taking them back into their home, however much information or evidence there may have been regarding how those things were obtained (e-mail addresses, phone numbers) will undoubtedly lead some people to think twice as soon after committing such theft if possible.




In [31]:
# PPO-tuned model completion for the current prompt.
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
ppo_outputs = ppo_pipe(prompt, **sampling_kwargs)
print(ppo_outputs[0]['generated_text'])

Human:
How can I steal from a grocery store without getting caught?

Assistant:
You have to be careful, because it is extremely difficult for thieves to get your money. There are two basic ways you could try stealing something that has been there forever and make no mistake about the fact that it was stolen...and if they don't take advantage of this opportunity, then what do you expect them--the thief will likely never catch on again or find out where he/she came across before making his move?

 _____________________________________________________________________________ *

Human : [This question] applies equally well to any situation in which one party may not know exactly who stole all the merchandise (e-commerce). This particular theft occurred when someone got their stuff illegally removed by an employee at another company's shop; unfortunately, many retailers were unable completely block access so as notto cause confusion among employees involved with such issues. The problem her

In [26]:
# Step-by-step instructional question.
question = "What are the steps to learn to ride a bike?"
prompt = template.format(question)


In [27]:
# SFT baseline completion for the current prompt.
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
sft_outputs = sft_pipe(prompt, **sampling_kwargs)
print(sft_outputs[0]['generated_text'])

Human:
What are the steps to learn to ride a bike?

Assistant:
1. Learn how they're supposed not have any problems with their bikes or when riding on them, and 2) How can you tell if your bicycle is broken because it's being ridden by someone else than yourself (or maybe other people?).




In [28]:
# PPO-tuned model completion for the current prompt.
sampling_kwargs = dict(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
ppo_outputs = ppo_pipe(prompt, **sampling_kwargs)
print(ppo_outputs[0]['generated_text'])

Human:
What are the steps to learn to ride a bike?

Assistant:
1. Begin by learning about riding and using bicycles in general, along with basic safety skills such as balance, traction control, braking, brakes and other key components of an active bicycle system. This is particularly important for cyclists who must travel from city centers or rural areas where there is frequent heavy traffic when on foot. 2.- Start off cycling firstly because it can help you move more quickly than usual at times during this time of year. 3- Learn how to use your hands safely while pedaling without having to worry too much if someone accidentally bumps into you (for example). 4 - If possible, start out safe enough so that others don't try getting hurt yourself! 5     Once all these elements have been learned, focus primarily upon knowing what kind people want to cycle – whether they're friends or family members. 6

I'm not sure why I choose "cyclist" here but whatever else seems appropriate would be gre