In [3]:
import json
import math
import os
import sys
from itertools import islice

import numpy as np
import torch
import tritonclient.grpc as client_util
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from tritonclient.utils import np_to_triton_dtype

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    PPOConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

# PPO training config for the HH prompts. The base values target GPT-J 6B; the
# assignments after the constructor override them for the Pythia-125M SFT model.
default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=10000,
        total_steps=10000,
        batch_size=4,
        eval_batch_size=32,
        checkpoint_interval=10000,
        eval_interval=500,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
        checkpoint_dir="checkpoints/ppo_hh",
    ),
    model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2),
    # Left truncation keeps the end of long prompts.
    tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=8e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    # NOTE(review): eta_min equals the optimizer lr (8e-6); if this maps to torch's
    # CosineAnnealingLR the learning rate stays constant. T_max is also not updated
    # when total_steps is overridden to 1500 below — confirm both are intended.
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=8e-6)),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=64,
        # chunk_size=16,
        chunk_size=4,
        ppo_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=1,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1,
        scale_reward="running",
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        # Sampling from the full distribution: top_k=0 and top_p=1.0 disable
        # truncation in Hugging Face generate().
        gen_kwargs=dict(
            max_new_tokens=128,
            top_k=0,
            top_p=1.0,
            do_sample=True,
        ),
    ),
)


# Overrides for the smaller Pythia-125M SFT policy (tokenizer from GPT-NeoX-20B).
default_config.train.batch_size = 8
default_config.train.total_steps = 1500
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_125M"
default_config.model.model_path = "Dahoas/pythia-125M-static-sft"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.num_rollouts = 128


def prepare_tensor(name: str, input):
    """Wrap a NumPy array as a Triton gRPC ``InferInput`` called ``name``.

    The Triton datatype is derived from the array's NumPy dtype, and the array's
    data is attached to the returned input object.
    """
    triton_dtype = np_to_triton_dtype(input.dtype)
    infer_input = client_util.InferInput(name, input.shape, triton_dtype)
    infer_input.set_data_from_numpy(input)
    return infer_input


def create_reward_fn():  # noqa:  C901
    """Build a reward function backed by the ``Dahoas/gptj-rm-static`` reward model.

    A GPT-J transformer trunk is given a bias-free linear head producing one scalar
    per token. The weights from the downloaded checkpoint are loaded into it, and
    the model is frozen and put in eval mode.

    Returns:
        ``reward_fn(samples, prompts, original_output, **kwargs)``. ``samples`` are
        full texts (prompt + generated completion). ``prompts`` / ``original_output``
        are parallel lists of prompts and reference completions. The result is a
        1-D tensor with one score per sample. With ``delta_reward`` enabled, each
        score is ``reward(sample) - reward(prompt + original_output)``.

    Raises:
        FileNotFoundError: if the downloaded snapshot has no ``.pt``/``.bin`` file.
    """
    reward_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so batches can be padded.
    reward_tokenizer.pad_token = reward_tokenizer.eos_token
    # Truncate from the left so the end of the text (the response) is kept.
    reward_tokenizer.truncation_side = "left"

    class RewardModel(nn.Module):
        """Causal-LM trunk plus a linear head; scores each sequence at its first EOS."""

        def __init__(self, checkpoint_path, eos_token_id):
            super().__init__()
            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
            self.transformer = model.transformer
            self.v_head = nn.Linear(model.config.n_embd, 1, bias=False)
            self.eos_token_id = eos_token_id

        def forward(self, input_ids):
            """Return one score per row: the head output at the first EOS token."""
            states = self.transformer(input_ids)[0]
            rewards = self.v_head(states).squeeze(-1)
            # argmax over the 0/1 EOS mask picks the first EOS index (0 if none).
            ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1)
            returns = torch.gather(rewards, 1, ends).squeeze(-1)
            return returns

    print("Reward model creation begin...")
    reward_model = RewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id)
    print("Reward model created!")
    directory = snapshot_download("Dahoas/gptj-rm-static", revision="676bfd4d")
    # Pick the first checkpoint file in sorted order, so the choice is deterministic
    # and does not depend on os.listdir ordering.
    checkpoint = next(
        (
            os.path.join(directory, fname)
            for fname in sorted(os.listdir(directory))
            if fname.endswith((".pt", ".bin"))
        ),
        None,
    )
    if checkpoint is None:
        raise FileNotFoundError(f"No .pt/.bin checkpoint found in {directory}")

    print("Begin loading reward model...")
    # Load tensors onto the CPU first so a GPU-saved checkpoint does not try to
    # materialise on whatever device it was saved from.
    reward_model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    print("Reward model loaded!")
    reward_model.eval()
    reward_model.requires_grad_(False)
    # reward_device = torch.cuda.device_count() - 1
    reward_device = "cpu"
    print("Reward device:", reward_device, "of", torch.cuda.device_count())
    # fp16 matmul/layernorm kernels are missing on CPU in many PyTorch releases, so
    # only halve the weights when the reward model actually runs on a GPU.
    if torch.device(reward_device).type == "cuda":
        reward_model = reward_model.half()
    reward_model = reward_model.to(reward_device)
    # reward_batch_size = 48
    # test a smaller batch size
    reward_batch_size = 1
    delta_reward = True

    def get_reward(samples):
        """Score ``samples`` (list of str) in micro-batches; return a 1-D tensor."""
        if not samples:
            return torch.empty(0, device=reward_device)
        encoded = reward_tokenizer(
            samples,
            padding=True,
            truncation=True,
            max_length=reward_tokenizer.max_len_single_sentence,
            return_tensors="pt",
        ).to(reward_device)

        out = []
        for start in range(0, len(samples), reward_batch_size):
            input_ids = encoded.input_ids[start : start + reward_batch_size]
            out.append(reward_model(input_ids))
        return torch.cat(out)

    def reward_fn(samples, prompts, original_output, **kwargs):
        """Score generated samples, optionally relative to the reference completions."""
        # Append EOS so RewardModel.forward scores at the end of each text.
        samples = [s + reward_tokenizer.eos_token for s in samples]
        rewards = get_reward(samples)

        if not delta_reward:
            return rewards

        original_samples = [p + o + reward_tokenizer.eos_token for p, o in zip(prompts, original_output)]
        original_rewards = get_reward(original_samples)
        return rewards - original_rewards

    return reward_fn

In [4]:
# config = TRLConfig.update(default_config, hparams)

# Load the Dahoas/rm-static prompts. Each record pairs a prompt with its "chosen"
# reply, stored as `original_output` (the baseline reward_fn compares against).
dataset = load_dataset("Dahoas/rm-static")
prompts = [dict(prompt=row["prompt"], original_output=row["chosen"]) for row in dataset["train"]]
print(prompts[:10])
eval_prompts = [
    dict(prompt=row["prompt"], original_output=row["chosen"])
    for row in islice(dataset["test"], 280)
]
reward_fn = create_reward_fn()

# Training is disabled while the reward model is being inspected:
# trlx.train(
#     prompts=prompts,
#     eval_prompts=eval_prompts,
#     reward_fn=reward_fn,
#     config=config,
#     stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
# )

Found cached dataset parquet (/home/thw/.cache/huggingface/datasets/Dahoas___parquet/default-b9d2c4937d617106/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 2/2 [00:00<00:00, 10.40it/s]


[{'prompt': '\n\nHuman: Can you describe the steps to clean fingerprints and smudges from a laptop screen\n\nAssistant: Yes, certainly. To clean your screen, you first need to use a microfiber cloth or soft, damp cloth to gently wipe down the surface of the screen. Next, you’ll want to grab a soft, lint-free, microfiber cleaning cloth and gently rub it back and forth across the screen to remove fingerprints and smudges.\n\nHuman: Can I spray isopropyl alcohol onto the cloth and clean it that way?\n\nAssistant:', 'original_output': ' Yes, you can do that to help the cloth pick up even more dirt from the screen. Be sure to always use a clean, soft cloth, not a piece of scratchy, roughened, or textured material, and make sure it’s lint-free.'}, {'prompt': '\n\nHuman: What are some foods that are good for diabetics?\n\nAssistant: To be honest, some of these are better than others, and they’re a little more like opinions than facts. For example, many of the diets say to limit vegetables wit

In [18]:
# Smoke-test the reward function on one training example. reward_fn takes parallel
# lists of strings: full samples (prompt + completion), prompts, and reference
# completions. Using the reference completion as the sample, the delta reward is
# expected to be (near) zero — a quick sanity check.
example = prompts[0]
reward_fn(
    [example["prompt"] + example["original_output"]],
    [example["prompt"]],
    [example["original_output"]],
)

In [1]:
import json
import math
import os
import sys
from itertools import islice

import numpy as np
import torch
import tritonclient.grpc as client_util
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from tritonclient.utils import np_to_triton_dtype

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    SACConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.trainer.accelerate_sac_trainer import AccelerateSACTrainer
from trlx.utils.modeling import (
    freeze_bottom_causal_layers,
)

# SAC training config for the HH prompts. The base values target GPT-J 6B; the
# assignments after the constructor override them for the Pythia-125M SFT model.
default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=10000,
        total_steps=10000,
        batch_size=4,
        eval_batch_size=32,
        max_history_size=128,
        checkpoint_interval=10000,
        eval_interval=500,
        pipeline="PromptPipeline",
        trainer="AccelerateSACTrainer",
        checkpoint_dir="checkpoints/sac_hh",
    ),
    model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2),
    tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=8e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    # NOTE(review): eta_min equals the optimizer lr, and T_max is not updated when
    # total_steps is overridden below — confirm the schedule is intended.
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=8e-6)),
    method=SACConfig(
        name="SACConfig",
        num_rollouts=64,
        # chunk_size=16,
        chunk_size=4,
        sac_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        alpha=1,
        beta=1,
        # NOTE(review): gamma=0.001 is unusually small if this is a discount factor — confirm.
        gamma=0.001,
        lam=0.95,
        actor_reg_coef=0.9,
        # cliprange=0.2,
        # cliprange_value=0.2,
        # vf_coef=1,
        scale_reward="running",
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs=dict(
            max_new_tokens=128,
            top_k=0,
            top_p=1.0,
            do_sample=True,
        ),
    ),
)


# Overrides for the smaller Pythia-125M SFT policy.
default_config.train.batch_size = 8
default_config.train.total_steps = 1500
# Use a SAC-specific directory so these checkpoints do not overwrite the PPO run's.
default_config.train.checkpoint_dir = "checkpoints/sac_hh_125M"
default_config.model.model_path = "Dahoas/pythia-125M-static-sft"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.num_rollouts = 128
# Disable experiment tracking for this interactive inspection.
default_config.train.tracker = None


# Resolve the final config; no hyperparameter overrides are applied here.
config = TRLConfig.update(default_config, dict())

# Dataset and reward setup are disabled while inspecting the model architecture:
# dataset = load_dataset("Dahoas/rm-static")
# prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in dataset["train"]]
# eval_prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in islice(dataset["test"], 280)]
# reward_fn = create_reward_fn()

acceleratesactrainer = AccelerateSACTrainer(config)
model = acceleratesactrainer.get_arch(config)
# Freeze the bottom causal layers of the base model, presumably keeping the top
# `num_layers_unfrozen` trainable (see the requires_grad dump in the next cell).
freeze_bottom_causal_layers(model.base_model, config.model.num_layers_unfrozen)
model.eval()
print(model)

# trlx.train(
#     prompts=prompts,
#     eval_prompts=eval_prompts,
#     reward_fn=reward_fn,
#     config=config,
#     stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
# )

  from .autonotebook import tqdm as notebook_tqdm


[2023-07-15 16:03:36,472] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# List every base-model parameter with its requires_grad flag to verify freezing.
for name, param in model.base_model.named_parameters():
    print(f"{name} {param.requires_grad}")

gpt_neox.embed_in.weight True
gpt_neox.layers.0.input_layernorm.weight False
gpt_neox.layers.0.input_layernorm.bias False
gpt_neox.layers.0.post_attention_layernorm.weight False
gpt_neox.layers.0.post_attention_layernorm.bias False
gpt_neox.layers.0.attention.query_key_value.weight False
gpt_neox.layers.0.attention.query_key_value.bias False
gpt_neox.layers.0.attention.dense.weight False
gpt_neox.layers.0.attention.dense.bias False
gpt_neox.layers.0.mlp.dense_h_to_4h.weight False
gpt_neox.layers.0.mlp.dense_h_to_4h.bias False
gpt_neox.layers.0.mlp.dense_4h_to_h.weight False
gpt_neox.layers.0.mlp.dense_4h_to_h.bias False
gpt_neox.layers.1.input_layernorm.weight False
gpt_neox.layers.1.input_layernorm.bias False
gpt_neox.layers.1.post_attention_layernorm.weight False
gpt_neox.layers.1.post_attention_layernorm.bias False
gpt_neox.layers.1.attention.query_key_value.weight False
gpt_neox.layers.1.attention.query_key_value.bias False
gpt_neox.layers.1.attention.dense.weight False
gpt_neox.la

In [13]:
# Smoke-test a forward pass on a random batch of 8 sequences of 10 token ids.
# The attention mask must be 0/1 (1 = attend); there is no padding, so all ones.
# A local generator makes the batch reproducible without touching the global RNG.
token_gen = torch.Generator().manual_seed(0)
tokens = torch.randint(0, 10, size=[8, 10], generator=token_gen)
attention_mask = torch.ones_like(tokens)
# Only inspecting outputs, so skip building an autograd graph.
with torch.no_grad():
    outputs = model(tokens, attention_mask, return_dict=True)

In [15]:
# Show which fields the SAC model's forward output provides.
outputs.keys()

odict_keys(['logits', 'past_key_values', 'hidden_states', 'value'])

In [15]:
# List the value-head parameters with their requires_grad flags.
for name, param in model.v_head.named_parameters():
    print(f"{name} {param.requires_grad}")

0.weight True
0.bias True
2.weight True
2.bias True


wandb: Network error (ConnectionError), entering retry loop.
