In [3]:
# !unzip smollm2-reward-model-final.zip -d /content/smollm2-reward-model-final

In [4]:
# !pip uninstall -y trl
# !pip install -q git+https://github.com/huggingface/trl.git
# !pip install -q transformers accelerate peft datasets
# !pip install -U bitsandbytes

In [7]:
import torch
import math
import time
import numpy as np
import pandas as pd
from collections import defaultdict
from accelerate import Accelerator
from accelerate.utils import gather_object, broadcast
import gc
from transformers import GenerationConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
from datasets import load_dataset
from trl.experimental.ppo import PPOConfig, PPOTrainer
from trl.trainer.utils import (
    batch_generation,
    first_true_indices,
    forward,
    get_reward,
    selective_log_softmax,
    truncate_response,
    empty_cache,
    log_table_to_comet_experiment,
    print_rich_table
)
from trl.models.utils import unwrap_model_for_generation

In [8]:
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of ``values`` over the positions selected by ``mask``.

    Args:
        values: tensor of values.
        mask: tensor (expected to have the same shape as ``values``) whose non-zero / True
            entries are included in the statistic.
        unbiased: if True, apply Bessel's correction ``n / (n - 1)`` with ``n = mask.sum()``.

    Returns:
        A scalar tensor with the (optionally bias-corrected) masked variance.

    Raises:
        ValueError: if ``unbiased`` is True and the mask selects no elements.
    """
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # note that if mask_sum == 1, then there is a division by zero issue
        # (the correction becomes inf and the result inf/nan);
        # to avoid it you just need to use a larger minibatch_size
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Normalize ``values`` to unit masked variance.

    The masked mean and (unbiased) masked variance are computed over the entries
    selected by ``mask``. When ``shift_mean`` is True the result is also centred at
    zero; when False the original masked mean is added back after scaling.
    """
    center = masked_mean(values, mask)
    spread = masked_var(values, mask)
    result = (values - center) * torch.rsqrt(spread + 1e-8)
    if shift_mean:
        return result
    # Keep the original location, only rescale.
    result += center
    return result

def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor:
    """Compute mean of tensor with a masked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()

# Constant from trl: placeholder written into the padding positions of the log-prob
# tensors. Real log-probabilities are <= 0, so 1.0 can never collide with a valid value.
INVALID_LOGPROB: float = 1.0

class DensePPOTrainer(PPOTrainer):
    """
    A custom PPO Trainer that implements DENSE rewards by overriding the train method.

    trl's stock ``PPOTrainer.train`` adds each response's reward-model score once, right
    after the last valid response token. This subclass instead spreads the score evenly
    over all valid (non-padding) response tokens. The score contributions of one response
    therefore sum exactly to its score. The per-token KL penalty, GAE, PPO epochs, logging
    and checkpointing follow the original trl loop.
    """

    @staticmethod
    def _add_dense_scores(rewards: torch.Tensor, scores: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        """Return ``rewards`` plus each sequence's score split evenly over its valid tokens.

        Args:
            rewards: per-token rewards, shape (batch, response_len).
            scores: one scalar score per sequence, shape (batch,).
            padding_mask: bool tensor, shape (batch, response_len); True marks padding.

        Returns:
            A new tensor. Row ``i`` receives ``scores[i] / n_i`` on each of its ``n_i`` valid
            tokens, so the row's score contributions sum to ``scores[i]``. A row with no valid
            tokens (its response starts with the pad token) receives its whole score at
            position 0. That is the position trl's sparse code uses for it, so such
            responses are not exempt from the reward model.
        """
        recipients = ~padding_mask  # fresh tensor, safe to modify in place
        empty_rows = ~recipients.any(dim=1)
        recipients[empty_rows, 0] = True
        weights = recipients.float()
        weights = weights / weights.sum(dim=1, keepdim=True)  # every row now sums to 1
        return rewards + scores.unsqueeze(1) * weights

    def train(self):
        """Run the PPO training loop with dense (per-token) reward-model scores.

        Mirrors trl's ``PPOTrainer.train``. The only algorithmic change is how the
        sequence-level score becomes per-token rewards (see ``_add_dense_scores``).
        """
        # --- BOILERPLATE SETUP (Copied from original) ---
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        ref_policy = self.ref_model
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            # Cycle over the prompt dataloader forever; the loop below bounds the number of batches.
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        accelerator.print("===training policy (DENSE REWARDS)===")
        start_time = time.time()
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        self.state.max_steps = args.num_total_batches
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len

        # Helper for logging/saving: values < 1 are interpreted as a fraction of max_steps
        if args.logging_steps is not None:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) if args.logging_steps < 1 else args.logging_steps
        if args.eval_steps is not None:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) if args.eval_steps < 1 else args.eval_steps
        if args.save_steps is not None:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) if args.save_steps < 1 else args.save_steps

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if self.is_deepspeed_enabled:
            self.deepspeed = self.model
            self.model_wrapped = self.model

        # --- TRAINING LOOP ---
        for update in range(1, args.num_total_batches + 1):
            self.state.episode += 1 * args.batch_size
            data = next(iter_dataloader)

            with torch.no_grad():
                queries = data["input_ids"].to(device)
                context_length = queries.shape[1]

                # --- ROLLOUT GENERATION (Copied from original) ---
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                values = []

                with unwrap_model_for_generation(self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
                    query_responses, logitss = batch_generation(
                        unwrapped_model.policy, queries, args.local_rollout_forward_batch_size,
                        processing_class.pad_token_id, generation_config
                    )

                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                    query = queries[i : i + args.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]
                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
                    logprob = selective_log_softmax(logits, response)
                    del logits
                    empty_cache()

                    if ref_policy is None:
                        # No separate reference model: evaluate the policy inside null_ref_context().
                        with self.null_ref_context():
                            ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                    else:
                        ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)

                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= args.temperature + 1e-7
                    ref_logprob = selective_log_softmax(ref_logits, response)
                    del ref_output, ref_logits
                    empty_cache()

                    # Response Processing
                    postprocessed_response = response
                    if self.stop_token_id is not None:
                        postprocessed_response = truncate_response(self.stop_token_id, processing_class.pad_token_id, response)

                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    # Index of the last valid (non-pad) response token; -1 if the response starts with pad.
                    sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1

                    unwrapped_value_model = accelerator.unwrap_model(model).value_model
                    full_value, _, _ = get_reward(unwrapped_value_model, query_response, processing_class.pad_token_id, context_length)
                    value = full_value[:, context_length - 1 : -1].squeeze(-1)

                    # GET SCORES (Scalar per sequence)
                    _, score, _ = get_reward(reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length)

                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)
                    values.append(value)

                # Concatenate batches
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                values = torch.cat(values, 0)

                del (logprob, ref_logprob, full_value, value, score)
                empty_cache()
                gc.collect()

                # Filter completion / EOS Check
                contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty

                # Masks: padding_mask marks positions after the last valid token;
                # padding_mask_p1 additionally keeps one extra position for the value function.
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
                sequence_lengths_p1 = sequence_lengths + 1
                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
                values = torch.masked_fill(values, padding_mask_p1, 0)

                # --- REWARD COMPUTATION START ---
                logr = ref_logprobs - logprobs
                kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr
                non_score_reward = -args.kl_coef * kl
                rewards = non_score_reward.clone()

                # ==============================================================
                # === MODIFIED SECTION: DENSE REWARD IMPLEMENTATION ============
                # ==============================================================
                # Original Sparse Code:
                # actual_start = torch.arange(rewards.size(0), device=rewards.device)
                # actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
                # rewards[actual_start, actual_end] += scores
                #
                # Dense: token_reward = score / (number of valid tokens). sequence_lengths is the
                # *index* of the last valid token, so a response has sequence_lengths + 1 valid
                # tokens. The helper counts them directly from ~padding_mask.
                rewards = self._add_dense_scores(rewards, scores, padding_mask)
                # ==============================================================
                # === END MODIFIED SECTION =====================================
                # ==============================================================

                if args.whiten_rewards:
                    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

                # --- ADVANTAGE & GAE (Copied from original) ---
                lastgaelam = 0
                advantages_reversed = []
                gen_length = responses.shape[1]
                for t in reversed(range(gen_length)):
                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                    lastgaelam = delta + args.gamma * args.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = torch.stack(advantages_reversed[::-1], axis=1)
                returns = advantages + values
                advantages = masked_whiten(advantages, ~padding_mask)
                advantages = torch.masked_fill(advantages, padding_mask, 0)
                empty_cache()

            # --- PPO EPOCHS (Copied from original) ---
            for ppo_epoch_idx in range(args.num_ppo_epochs):
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                        with accelerator.accumulate(model):
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                            mb_advantage = advantages[micro_batch_inds]
                            mb_responses = responses[micro_batch_inds]
                            mb_query_responses = query_responses[micro_batch_inds]
                            mb_logprobs = logprobs[micro_batch_inds]
                            mb_return = returns[micro_batch_inds]
                            mb_values = values[micro_batch_inds]

                            output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7
                            new_logprobs = selective_log_softmax(logits, mb_responses)
                            new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB)
                            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                            vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                            vpredclipped = torch.clamp(vpred, mb_values - args.cliprange_value, mb_values + args.cliprange_value)

                            # Clipped value loss
                            vf_losses1 = torch.square(vpred - mb_return)
                            vf_losses2 = torch.square(vpredclipped - mb_return)
                            vf_loss_max = torch.max(vf_losses1, vf_losses2)
                            vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                            vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds])

                            # Clipped policy-gradient loss
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])

                            loss = pg_loss + args.vf_coef * vf_loss
                            accelerator.backward(loss)
                            optimizer.step()
                            optimizer.zero_grad()

                            with torch.no_grad():
                                pg_clipfrac = masked_mean((pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds])
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                approxkl = 0.5 * (logprobs_diff**2).mean()
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                            gradient_accumulation_idx += 1
                    minibatch_idx += 1

                    del (output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs)
                    empty_cache()

            # --- LOGGING ---
            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                mean_entropy = (-logprobs).sum(1).mean()
                mean_non_score_reward = non_score_reward.sum(1).mean()
                rlhf_reward = mean_non_score_reward + scores.mean()
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
                metrics["objective/non_score_reward"] = self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode
                self.state.epoch = self.state.episode / self.train_dataset_len
                self.state.global_step += 1
                self.log(metrics)

            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
                self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)

            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
            empty_cache()
            gc.collect()

            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                self.generate_completions(sampling=True)
                empty_cache()

            del (query_responses, responses, postprocessed_responses, logprobs, ref_logprobs, values, sequence_lengths, contain_eos_token, sequence_lengths_p1, response_idxs, padding_mask, padding_mask_p1, rewards, advantages, returns)
            empty_cache()

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

In [9]:
# Release unreachable Python objects first, then return cached CUDA blocks to the driver.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()


In [10]:
# trl's PPOTrainer *derives* batch_size, mini_batch_size, local_batch_size, ... in its
# __init__ from per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
# (times world_size) and overwrites any value passed for them. They are therefore not set
# here; the sizes are expressed through the inputs the trainer actually reads:
# 8 * 8 * 1 = 64 rollouts per update on one device, processed as a single PPO mini-batch.
config = PPOConfig(
    output_dir="./smollm2-ppo-results",
    num_ppo_epochs=2,
    num_train_epochs=5,  # with total_episodes unset, trl uses num_train_epochs * len(train_dataset)
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    num_mini_batches=1,
    learning_rate=5e-5,
    stop_token="eos",
    response_length=53,
)

In [11]:
# Base checkpoint for the policy/value/reward backbones, and the local directory
# holding the reward model's LoRA adapter.
model_id, reward_model_path = (
    "HuggingFaceTB/smollm2-135M-SFT-Only",
    "./smollm2-reward-model-final",
)

In [12]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
if "<|im_end|>" in tokenizer.get_vocab():
    print("Found <|im_end|> in vocab. Setting as EOS.")
    tokenizer.eos_token = "<|im_end|>"
else:
    # Fallback only if it really doesn't exist
    print("Adding <|im_end|> as special token.")
    tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

# Use a dedicated pad token that differs from EOS. Every prompt contains <|im_end|>
# (ChatML user turn), and trl's generation / forward / get_reward helpers treat every
# position equal to pad_token_id as padding: it is masked out of attention and replaced
# by token 0. With pad == eos == <|im_end|>, the <|im_end|> inside each prompt would be
# hidden from the policy, value and reward models. trl's own PPO example adds "[PAD]" the same way.
# NOTE(review): with a distinct pad, get_reward reads the score at the <|im_end|> token
# rather than at the token before it. Confirm this matches how the reward model was trained.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

tokenizer.padding_side = "left"
tokenizer.model_max_length = 512

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

Found <|im_end|> in vocab. Setting as EOS.


In [13]:
# Policy: full-precision-architecture fp16 model; LoRA adapters are added by the trainer via peft_config.
policy_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,  # `torch_dtype` is deprecated in this transformers version
    device_map="auto"
)
policy_model.gradient_checkpointing_enable()

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

peft_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)

# Value (critic) model: 4-bit backbone plus a newly initialized single-output `score` head.
# NOTE(review): bitsandbytes 4-bit weights are not trainable, so presumably only the
# non-quantized parameters (e.g. `score`) are updated by PPO. Confirm this is intended.
value_model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=1,
    quantization_config=bnb_config_4bit,
    dtype=torch.float16,
    device_map="auto"
)

# Reward model: 4-bit backbone + the LoRA adapter stored at reward_model_path, frozen.
rm_base = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=1,
    quantization_config=bnb_config_4bit,
    device_map="auto"
)
reward_model = PeftModel.from_pretrained(rm_base, reward_model_path)
reward_model.eval().requires_grad_(False)

# Grow the policy's embedding matrix to cover any tokens added to the tokenizer.
policy_model.resize_token_embeddings(len(tokenizer))
policy_model.generation_config.eos_token_id = tokenizer.eos_token_id
policy_model.generation_config.pad_token_id = tokenizer.pad_token_id

config.json:   0%|          | 0.00/804 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/smollm2-135M-SFT-Only and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/smollm2-135M-SFT-Only and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Raw Orca DPO pairs; only the `question` column is used to build PPO prompts.
dataset = load_dataset(path="Intel/orca_dpo_pairs", split="train")

def format_prompt(sample, max_length=512, tok=None):
    """Tokenize one example as a ChatML user turn followed by an open assistant turn.

    Args:
        sample: dataset row; only ``sample['question']`` is read.
        max_length: maximum prompt length in tokens.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        ``{"input_ids": [...]}``. A prompt that fits in ``max_length`` is the tokenization
        of the full prompt string. A longer prompt is shortened by cutting the *question*,
        never the template, so it always ends with the ``<|im_start|>assistant`` header.
        Plain right-truncation would drop that header and have the policy continue the
        user's question instead of answering it. If the template alone exceeds
        ``max_length``, the result is just the template.
    """
    tok = tokenizer if tok is None else tok
    prefix = "<|im_start|>user\n"
    suffix = "<|im_end|>\n<|im_start|>assistant\n"
    question = sample['question']

    prompt_ids = tok.encode(f"{prefix}{question}{suffix}")
    if len(prompt_ids) <= max_length:
        return {"input_ids": prompt_ids}

    # Too long: keep the chat template intact and truncate only the question tokens.
    prefix_ids = tok.encode(prefix)  # keeps any special tokens the tokenizer prepends
    suffix_ids = tok.encode(suffix, add_special_tokens=False)
    question_ids = tok.encode(question, add_special_tokens=False)
    budget = max(max_length - len(prefix_ids) - len(suffix_ids), 0)
    return {"input_ids": prefix_ids + question_ids[:budget] + suffix_ids}

N_TRAIN_PROMPTS = 2000  # Optional: Limit dataset size for speed
N_EVAL_PROMPTS = 20
raw_columns = dataset.column_names

# Eval prompts come from rows *after* the training subset. The sample generations
# produced during training are then on prompts the policy is not optimized on.
eval_dataset = dataset.select(range(N_TRAIN_PROMPTS, N_TRAIN_PROMPTS + N_EVAL_PROMPTS)).map(
    format_prompt, batched=False, remove_columns=raw_columns
)
dataset = dataset.select(range(N_TRAIN_PROMPTS)).map(
    format_prompt, batched=False, remove_columns=raw_columns
)

README.md:   0%|          | 0.00/196 [00:00<?, ?B/s]

orca_rlhf.jsonl:   0%|          | 0.00/36.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12859 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [15]:
print("Initializing Trainer...")
# ref_model=None: train() then computes reference log-probs with model.policy inside
# self.null_ref_context() (with peft_config set, presumably the adapter-disabled base).
trainer_kwargs = dict(
    args=config,
    processing_class=tokenizer,
    model=policy_model,
    ref_model=None,
    peft_config=peft_config,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
)
trainer = DensePPOTrainer(**trainer_kwargs)

Initializing Trainer...


In [16]:
def _policy_gc_enable(self, **kwargs):
    """Forward gradient-checkpointing enablement to the wrapped policy."""
    return self.policy.gradient_checkpointing_enable(**kwargs)


def _policy_gc_disable(self):
    """Forward gradient-checkpointing disablement to the wrapped policy."""
    return self.policy.gradient_checkpointing_disable()


# Patch the trainer's policy/value wrapper class (affects every instance of that class)
# so gradient-checkpointing toggles reach the underlying policy model.
wrapper_class = type(trainer.model)
wrapper_class.gradient_checkpointing_enable = _policy_gc_enable
wrapper_class.gradient_checkpointing_disable = _policy_gc_disable

In [17]:
dense_output_dir = "./smollm2-ppo-dense-final"

print("Starting PPO Training...")
trainer.train()
trainer.save_model(dense_output_dir)
print("Training Complete!")

Starting PPO Training...
===training policy (DENSE REWARDS)===


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m27100046[0m ([33m27100046-lahore-university-of-management-sciences[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
























Training Complete!


In [18]:
from google.colab import files
import shutil

# Zip the saved policy directory and trigger a browser download (Colab only).
archive_stem = 'smollm2-ppo-dense-final'
shutil.make_archive(archive_stem, 'zip', f'./{archive_stem}')
files.download(f'{archive_stem}.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>