# Imports

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Loading Models

In [None]:
qwen2_5_sft = AutoModelForCausalLM.from_pretrained("VerlTool/Qwen2.5-Math-1.5B-TIR-SFT", dtype=torch.bfloat16)
qwen2_5_rl = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct", dtype=torch.bfloat16)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Fetching 2 files: 100%|██████████| 2/2 [01:43<00:00, 51.63s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.94s/it]


In [9]:
def compute_deltas(sft_model, rl_model):
    rl_state_dict = rl_model.state_dict()
    sft_state_dict = sft_model.state_dict()

    missing_in_rl = [name for name in sft_state_dict if name not in rl_state_dict]
    missing_in_sft = [name for name in rl_state_dict if name not in sft_state_dict]

    if missing_in_rl or missing_in_sft:
        print("Missing in RL:", missing_in_rl)
        print("Missing in SFT:", missing_in_sft)

    all_deltas = []
    num_nonzero_dict = {}

    with torch.no_grad():
        for name, _ in tqdm(sft_model.named_parameters(), desc="Computing deltas"):
            try:
                delta = rl_state_dict[name] - sft_state_dict[name]
                num_nonzero = (delta != 0).sum().item()
                num_nonzero_dict[name] = num_nonzero / delta.numel()
                all_deltas.append(delta.view(-1))
            except Exception as e:
                print(f"Error in {name}: {e}")
    return all_deltas, num_nonzero_dict

In [10]:
all_deltas, num_nonzero_dict = compute_deltas(qwen2_5_sft, qwen2_5_rl)

Computing deltas: 338it [00:14, 23.74it/s]


In [11]:
all_deltas_tensor = torch.cat(all_deltas, dim=0)
print(f"\nDeltas shape: {all_deltas_tensor.size()}")

pct_zeros = (all_deltas_tensor == 0).sum().item() / len(all_deltas_tensor)
print(f"Percentage of 0 values in the task vector: {pct_zeros:.4f}")


Deltas shape: torch.Size([1543714304])
Percentage of 0 values in the task vector: 0.0074


In [13]:
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]

for tol in tolerances:
    fraction_close_to_zero = torch.isclose(
        all_deltas_tensor, torch.tensor(0, dtype=all_deltas_tensor.dtype), atol=tol
    ).sum().item() / all_deltas_tensor.numel()
    print(f"Tolerance = {tol:.0e} -> Fraction close to zero: {fraction_close_to_zero:.4f}")

Tolerance = 1e-05 -> Fraction close to zero: 0.0075
Tolerance = 1e-04 -> Fraction close to zero: 0.0125
Tolerance = 1e-03 -> Fraction close to zero: 0.1368
Tolerance = 1e-02 -> Fraction close to zero: 0.8575
