In [None]:
%env CUDA_VISIBILE_DEVICES=0,1,2,3

In [None]:
import html
import os
from functools import partial

import torch
from IPython.display import display, HTML
from metaflow import Run
from omegaconf import DictConfig
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config
from torchtune.config._utils import _get_component_from_path
from vllm import LLM, SamplingParams

from utils import fetch_and_load_weights

## Load from flow

In [None]:
# Local directory where downloaded model checkpoints are cached.
checkpoint_cache = "./trained_models"

# Properties of upstream Metaflow run.
artifact_name = "model_ref"

# Reward server version selection. The import is conditional because each
# version lives in its own module and exposes a class with the same name.
version = 'v1'
if version == 'v0':
    from rewards_gsm8k_aaa_v0 import RewardServer
    reward_tag = 'reward:gsm8k_aaa_v0'
elif version == 'v1':
    from rewards_gsm8k_aaa_v1 import RewardServer
    reward_tag = 'reward:gsm8k_aaa_v1'
else:
    # Fail here instead of with a NameError on RewardServer / reward_tag later.
    raise ValueError(f"Unsupported reward server version: {version!r} (expected 'v0' or 'v1')")

# Properties of torchtune / finetuning run.
dataset_component = 'torchtune.dev.grpo.gsm8k.gsm8k_dataset'
dataset_partition = '3-5/100'

# Inference server properties.
# Count the GPUs from the correctly spelled CUDA_VISIBLE_DEVICES; the misspelled
# CUDA_VISIBILE_DEVICES is only consulted as a fallback for older setups.
_visible_devices = (
    os.environ.get('CUDA_VISIBLE_DEVICES')
    or os.environ.get('CUDA_VISIBILE_DEVICES')
)
if not _visible_devices:
    raise RuntimeError(
        "CUDA_VISIBLE_DEVICES is not set; set it (e.g. '0,1,2,3') before running this cell."
    )
n_gpu = len(_visible_devices.split(','))
batch_size = 2
grpo_size = 2

### Option 1: Fetch by tags

In [None]:
# Tags identifying which upstream run's checkpoint to fetch.
model_tag = 'model:meta-llama/Llama-3.2-3B-Instruct'
flow_name = 'TorchtuneGRPOSingleNode'

# Resolve the weights and keep the resulting directory; it is passed to the
# vLLM constructor below (any other inference server could consume it too).
model_dir = fetch_and_load_weights(
    model_tag=model_tag,
    reward_tag=reward_tag,
    flow_name=flow_name,
    checkpoint_cache=checkpoint_cache,
)

### Option 2: Fetch with specific run

In [None]:
# NOTE: run 9195 was trained with the aaa_v1 reward; the original author
# flagged that trial design as flawed.
# Running this cell overwrites any `model_dir` assigned earlier.
run_pathspec = 'TorchtuneGRPOSingleNode/9195'
run = Run(run_pathspec)
model_dir = fetch_and_load_weights(
    run=run,
    reward_tag=reward_tag,
    checkpoint_cache=checkpoint_cache,
)

## Load model server

In [None]:
# Build the vLLM engine from the fetched checkpoint directory.
# tensor_parallel_size=n_gpu spreads the model over every configured GPU.
llm = LLM(
    model=model_dir,
    task="generate",
    dtype='bfloat16',
    tensor_parallel_size=n_gpu,
    trust_remote_code=True,
)

## Do inference, unrolling a single batch

In [None]:
## Setup torchtune dependencies.
# world_size / rank feed the DistributedSampler built in a later cell; this
# notebook acts as rank 0 of `n_gpu` replicas.
world_size = n_gpu
rank = 0
cfg_dataset = DictConfig({'_component_': dataset_component, 'partition': dataset_partition})
cfg_tokenizer = DictConfig({
    '_component_': 'torchtune.models.llama3.llama3_tokenizer',
    'path': os.path.join(model_dir, 'original/tokenizer.model'),
    # Must be a real None: in a Python dict the string 'null' stays the
    # string "null" (only YAML parses it as a null value).
    'max_seq_len': None,
})
# Dotted path of the collate function; kept under its own name so `collate_fn`
# only ever refers to the resolved callable.
collate_fn_path = 'torchtune.dev.grpo.data.padded_collate_rl'

tokenizer = config.instantiate(cfg_tokenizer)
ds = config.instantiate(cfg_dataset, tokenizer)
collate_fn = _get_component_from_path(collate_fn_path)

In [None]:
# Sampler configured as replica `rank` out of `world_size`, shuffling with a fixed seed.
sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True, seed=42)

# Bind the tokenizer's padding id into the collate function once, up front.
collate_with_padding = partial(collate_fn, padding_idx=tokenizer.pad_id)

dataloader = DataLoader(
    ds,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=collate_with_padding,
    # Incomplete final batches are dropped; per the original note this avoids
    # shape issues with compile + flex attention.
    drop_last=True,
)

## View `batch_size=2` sample

In [None]:
# Pull one batch through the public iterator protocol (not the private
# DataLoader._get_iterator()).
batch = next(iter(dataloader))
tokens = batch["tokens"]         # tokenized prompts
answers = batch["answers"]       # untokenized answers
tokens_ls = tokens.tolist()

out = []        # human-readable "prompt + ground truth" entries, one per prompt
_prompts = []   # each decoded prompt repeated grpo_size times: [p0, p0, ..., p1, p1, ...]
_answers = []   # answers repeated in the same order as _prompts
for prompt_ids, answer in zip(tokens_ls, answers):
    prompt = tokenizer.decode(prompt_ids)
    _prompts.extend([prompt] * grpo_size)
    _answers.extend([answer] * grpo_size)
    out.append(prompt + '\n' + '-'*24 + '\n' + 'GROUND_TRUTH_ANSWER: ' + answer)
sep = '\n' + "-"*24 + '\n'

# Escape before injecting into HTML so tag-like text in the prompt or answer is
# displayed literally instead of being interpreted (and hidden) by the browser.
formatted_output = html.escape(sep.join(out), quote=False).replace('\n', '<br>')
display(HTML(f"<div style='max-width:500px'>{formatted_output}</div>"))

## Sample the LLM

In [None]:
# Decoding settings shared by every prompt in the batch.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)

# Generate for the repeated prompt list; results are stored in `output`.
output = llm.generate(_prompts, sampling_params)

In [None]:
# Not referenced in this cell; kept because later cells may rely on it.
# NOTE(review): presumably the Llama 3 end-of-text / end-of-turn / end-of-message ids — confirm.
stop_token_ids = [
    128001,
    128009,
    128008
]

# Derive both values from their single source of truth rather than repeating
# literals that would silently drift if the tokenizer or sampling config changes.
pad_id = tokenizer.pad_id
max_tokens = sampling_params.max_tokens

# Right-pad every generated sequence to max_tokens so the rows stack into one tensor.
data = []
for o in output:
    out_tokens = list(o.outputs[0].token_ids)
    # A non-positive repeat count yields an empty list, so no length check is needed.
    out_tokens += [pad_id] * (max_tokens - len(out_tokens))
    data.append(out_tokens)

# Regroup the flat list of generations as [batch_size, grpo_size, max_tokens].
responses = torch.tensor(data, dtype=torch.int32).reshape(batch_size, grpo_size, max_tokens)

## Pluggable Reward Server

In [None]:
# Instantiate whichever RewardServer class was imported by the version switch in the config cell.
reward_server = RewardServer()

In [None]:
# Score the padded responses against the ground-truth answers.
# NOTE(review): `answers*2` hard-codes the repeat count (instead of grpo_size) and
# produces [a0, a1, a0, a1], whereas the prompts were repeated as [p0, p0, p1, p1]
# (the layout `_answers` follows). Confirm which layout
# batch_shaped_correctness_reward expects before changing batch_size or grpo_size.
rewards, successes, details = reward_server.batch_shaped_correctness_reward(
    tokenizer, responses, answers*2, details_report=True
)

In [None]:
# Inspect the rewards returned by the reward server.
rewards

In [None]:
# Inspect the success indicators returned by the reward server.
successes

In [None]:
# Standardize rewards along dim 1: subtract the mean and divide by the standard
# deviation, with a small epsilon to avoid division by zero.
# NOTE(review): assumes dim 1 is the grpo_size axis — confirm against rewards.shape.
group_mean = rewards.mean(1, keepdim=True)
group_std = rewards.std(1, keepdim=True)
advantages = (rewards - group_mean) / (group_std + 1e-4)

In [None]:
# Sanity-check the shapes of the three tensors side by side.
rewards.shape, advantages.shape, responses.shape

In [None]:
# Inspect the normalized advantages.
advantages

In [None]:
# Inspect the detail report requested via details_report=True.
details

In [None]:
# Let the reward server build its HTML report, then render it inline.
responses_html = reward_server.display_responses(
    responses=responses,
    tokenizer=tokenizer,
    grpo_size=grpo_size,
    advantages=advantages,
    rewards=rewards,
    successes=successes,
    details=details,
)
display(HTML(responses_html))