In [None]:
import os

# Restrict this process to GPUs 0-3 (matches n_gpu = 4 in the config cell).
# This must run before torch / vLLM are imported in the next cell, since CUDA
# reads the variable when it initializes. Note the correct spelling: a
# misspelled name is silently ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [None]:
import os
from pprint import pprint
from functools import partial
from IPython.display import display, HTML
from vllm import LLM, SamplingParams
from omegaconf import DictConfig
import torch
from torchtune import config
from torchtune.config._utils import _get_component_from_path
from torch.utils.data import DataLoader, DistributedSampler
from metaflow import Run, Task
from utils import fetch_and_load_weights, load_gutenberg_dataset

## Load the trained model from the upstream Metaflow run

In [None]:
# Local directory that the trained checkpoint gets downloaded into.
checkpoint_cache = "./trained_models"

# Upstream Metaflow run: which reward module / tag and base model it was trained with.
from rewards_gutenberg_v1 import RewardServer
reward_tag = 'reward:gutenberg_eras_v1'
model_tag = 'model:meta-llama/Llama-3.2-3B-Instruct'
flow_name = 'GutenbergErasGRPOPostTrain'
artifact_name = "model_ref"

# Dataset: the custom Gutenberg dataset (loaded via load_gutenberg_dataset),
# not a torchtune built-in dataset component.

# Inference settings.
n_gpu = 4        # vLLM tensor-parallel degree (also used as the sampler's world size)
batch_size = 2   # prompts per batch
grpo_size = 2    # completions sampled per prompt (GRPO group size)

In [None]:
# Pin the exact upstream Metaflow task whose trained weights we evaluate.
# Metaflow pathspec format: <flow_name>/<run_id>/<step_name>/<task_id>.
# Reuses flow_name from the config cell instead of hard-coding it a second time.
run_id = 9153
step_name = 'train'
task_id = 68625
task = Task(f'{flow_name}/{run_id}/{step_name}/{task_id}')

# Download (or reuse from checkpoint_cache) the weights; model_dir is a local path.
model_dir = fetch_and_load_weights(
    task=task,
    reward_tag=reward_tag,
    checkpoint_cache=checkpoint_cache,
)

In [None]:
# Load the fine-tuned checkpoint into vLLM in bfloat16, sharded across n_gpu
# GPUs with tensor parallelism.
# NOTE(review): trust_remote_code=True allows the checkpoint directory to run
# custom Python code; presumably acceptable because model_dir is fetched from
# our own Metaflow training task — confirm before pointing this at other models.
llm = LLM(
    model=model_dir,
    dtype='bfloat16',
    tensor_parallel_size=n_gpu,
    trust_remote_code=True,
    task="generate",
)

#### Do inference, unrolling a single batch

In [None]:
## Setup torchtune dependencies.
# Emulate one rank of an n_gpu-way data-parallel run; the sampler in the next
# cell then draws only this rank's shard of the dataset.
world_size = n_gpu
rank = 0

# NOTE: This repo contains a single validation file, small enough to fit in git repo.
data_path = os.path.join(os.getcwd(), "gutenberg_dataset")

cfg_tokenizer = DictConfig({
    '_component_': 'torchtune.models.llama3.llama3_tokenizer',
    'path': os.path.join(model_dir, 'original/tokenizer.model'),
    # Python None (YAML `null`), i.e. the builder's default of no max length.
    # The string 'null' would be passed through verbatim as a str, not None.
    'max_seq_len': None,
})
collate_fn_path = 'torchtune.dev.grpo.data.padded_collate_rl'

tokenizer = config.instantiate(cfg_tokenizer)
ds = load_gutenberg_dataset(tokenizer, data_path=data_path)
# Resolve the dotted path to the actual collate function object.
collate_fn = _get_component_from_path(collate_fn_path)

In [None]:
# Shard the dataset as if this process were `rank` of `world_size` data-parallel
# workers, shuffling deterministically with a fixed seed.
sampler = DistributedSampler(
    ds,
    rank=rank,
    num_replicas=world_size,
    seed=42,
    shuffle=True,
)

dataloader = DataLoader(
    dataset=ds,
    sampler=sampler,
    batch_size=batch_size,
    # Pad the prompts in each batch with the tokenizer's pad id.
    collate_fn=partial(collate_fn, padding_idx=tokenizer.pad_id),
    # dropping last avoids shape issues with compile + flex attention
    drop_last=True,
)

## Inspect one batch of prompts (`batch_size` prompts, each repeated `grpo_size` times for sampling)

In [None]:
# Pull a single batch via the public iterator protocol (not the private
# DataLoader._get_iterator()).
batch = next(iter(dataloader))
tokens = batch["tokens"]     # padded tokenized prompts, [batch_size x num_tokens_per_prompt]
answers = batch["answers"]   # untokenized answers, one per prompt

# Decode each prompt once, then repeat each prompt / answer grpo_size times so
# every GRPO group member gets its own vLLM request with an aligned answer.
prompts = [tokenizer.decode(row) for row in tokens.tolist()]
_prompts = [p for p in prompts for _ in range(grpo_size)]
_answers = [a for a in answers for _ in range(grpo_size)]

# Show every prompt in the batch (not just the last loop variable).
for p in prompts:
    pprint(p)

## Sample the LLM

In [None]:
# Per-completion generation budget; completions are later padded to this length.
max_tokens = 512

# Nucleus sampling. No per-request seed on purpose: each prompt is duplicated
# grpo_size times, and identical seeds would make the group members identical.
sampling_params = SamplingParams(max_tokens=max_tokens, top_p=0.95, temperature=0.8)
output = llm.generate(_prompts, sampling_params)

In [None]:
# Pad every completion to exactly max_tokens so they stack into one tensor.
# Take the pad id from the tokenizer instead of hard-coding the Llama value
# (128004), so this cell stays correct if the model/tokenizer changes.
pad_id = tokenizer.pad_id

data = []
for o in output:
    out_tokens = list(o.outputs[0].token_ids)
    # SamplingParams caps generation at max_tokens; fail loudly rather than
    # produce a ragged list if that ever stops holding.
    assert len(out_tokens) <= max_tokens, len(out_tokens)
    out_tokens += [pad_id] * (max_tokens - len(out_tokens))
    data.append(out_tokens)

# [num_prompts, grpo_size, max_tokens]; -1 infers num_prompts from the number of
# completions instead of assuming a full batch_size.
responses = torch.tensor(data, dtype=torch.int32).reshape(-1, grpo_size, max_tokens)

## Pluggable Reward Server

In [None]:
# Instantiate the pluggable reward implementation (RewardServer imported from rewards_gutenberg_v1).
reward_server = RewardServer()

In [None]:
# Score every completion against its reference answer. `_answers` follows the
# same (prompt, group member) order used to build `responses`. With
# details_report=True, per-completion details are also returned; they are
# indexed as details[batch_idx][group_member_idx] further down.
rewards, successes, details = reward_server.batch_shaped_correctness_reward(
    tokenizer=tokenizer,
    answers=_answers,
    completions=responses,
    details_report=True,
)

In [None]:
# Inspect the reward breakdown for a single completion: group member 0 of prompt 0.
batch_idx, group_member_idx = 0, 0
reward_server.print_reward_details_summary(details[batch_idx][group_member_idx])

In [None]:
# GRPO group-normalized advantages: standardize each reward against the other
# completions along dim 1 (presumably the grpo_size group axis, assuming rewards
# is [batch, grpo_size] — confirm). The 1e-4 epsilon avoids division by zero
# when every reward in a group is identical (std == 0).
advantages = (
    rewards.sub(rewards.mean(1, keepdim=True))
    .div(rewards.std(1, keepdim=True).add(1e-4))
)

In [None]:
# Render the sampled completions with their rewards, success flags, details and
# advantages; display_responses is expected to return an HTML string.
display(
    HTML(
        reward_server.display_responses(
            responses,
            tokenizer,
            grpo_size,
            rewards=rewards,
            successes=successes,
            details=details,
            advantages=advantages,
        )
    )
)