In [1]:
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer, PreTrainedModel


def compute_response_kl(
    actor_model: PreTrainedModel, 
    reference_model: PreTrainedModel, 
    tokenizer: PreTrainedTokenizer, 
    prompt: str, 
    response: str, 
    device: str = "cuda:1" if torch.cuda.is_available() else "cpu"
) -> float:
    """Estimate the actor-vs-reference KL divergence over the response tokens.

    ``prompt + response`` is tokenized once and fed to both models without
    gradient tracking. For every position, the log-probability that each
    model assigns to the token that actually follows is looked up, and the
    differences ``log p_actor - log p_reference`` are summed over the
    positions that predict response tokens. This is a sample-based estimate
    computed on the given tokens only, not a sum over the full vocabulary.

    Args:
        actor_model: Model whose distribution is being measured.
        reference_model: Model the actor is compared against.
        tokenizer: Tokenizer used for both the full text and the prompt.
        prompt: Text preceding the response; its tokens are excluded.
        response: Text whose tokens contribute to the estimate.
        device: Device the tokenized inputs are moved to. Only the inputs are
            moved here, so both models are expected to already live on it.

    Returns:
        The summed per-token log-ratio over the response tokens, or ``0.0``
        when the text has fewer than two tokens.

    Note:
        The prompt length is measured by tokenizing ``prompt`` on its own.
        This assumes those tokens form a prefix of the tokenization of
        ``prompt + response`` (no merge across the boundary and no trailing
        special tokens) -- TODO confirm for the tokenizer in use.
    """
    full_text = prompt + response
    inputs = tokenizer(full_text, return_tensors='pt').to(device)
    print(f"inputs: {inputs}")
    input_ids = inputs["input_ids"]
    print(f"input_len: {input_ids.shape[-1]}")

    # Only the number of prompt tokens is needed, so the ids stay on the CPU.
    # The deprecated `as_target_tokenizer()` context is not required for this.
    prompt_ids = tokenizer(prompt, return_tensors='pt')["input_ids"]
    prompt_len = prompt_ids.shape[-1]
    print(f"prompt_len: {prompt_len}")

    # With fewer than two tokens there is no (position, next token) pair to
    # score, and a zero-length sequence cannot be fed to the models.
    if input_ids.shape[-1] < 2:
        return 0.0

    with torch.no_grad():
        actor_logits = actor_model(input_ids).logits
        reference_logits = reference_model(input_ids).logits
        print(f"actor logits shape: {actor_logits.shape}")
        print(f"reference logits shape: {reference_logits.shape}")

    # Autoregressive alignment: the logits at position t predict token t + 1.
    # Drop the last position because nothing follows it, so it has no label.
    actor_logits = actor_logits[:, :-1, :]
    reference_logits = reference_logits[:, :-1, :]
    print(f"actor logits shape: {actor_logits.shape}")
    print(f"reference logits shape: {reference_logits.shape}")
    labels = input_ids[:, 1:]  # drop the first token: it is never a label

    # Normalize the logits into log-probabilities over the vocabulary.
    log_probs_actor = F.log_softmax(actor_logits, dim=-1)
    log_probs_reference = F.log_softmax(reference_logits, dim=-1)
    print(f"log_probs_actor: {log_probs_actor.shape}")
    print(f"log_probs_reference: {log_probs_reference.shape}")

    # Pick out the log-probability of the token that actually came next.
    label_index = labels.unsqueeze(-1)
    logp_actor = log_probs_actor.gather(dim=-1, index=label_index).squeeze(-1)
    logp_reference = log_probs_reference.gather(dim=-1, index=label_index).squeeze(-1)
    print(f"logp_actor shape: {logp_actor.shape}")
    print(f"logp_reference shape: {logp_reference.shape}")

    # labels[j] is input_ids[j + 1], so the first response token sits at label
    # index prompt_len - 1. Clamp at 0: for an empty prompt the index would be
    # -1 and the slice would select only the last token.
    response_start = max(prompt_len - 1, 0)
    response_mask = torch.zeros_like(labels, dtype=torch.float)
    response_mask[:, response_start:] = 1.0

    # Keep only the generated (response) tokens.
    kl_tokenwise = (logp_actor - logp_reference) * response_mask
    kl_total = kl_tokenwise.sum().item()  # summed KL estimate

    # Optional: return the per-token average instead.
    # avg_kl = kl_total / response_mask.sum().item()

    return kl_total


if __name__ == '__main__':
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Choose the device once and use it for the models and the inputs alike.
    # Prefer the second GPU as before, but fall back to the first GPU or the
    # CPU instead of crashing when "cuda:1" does not exist.
    if torch.cuda.device_count() > 1:
        device = "cuda:1"
    elif torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    actor_model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to(device)
    ref_model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to(device)

    prompt = "What is the capital of France?\n"
    response = "The capital of France is Paris."

    # Actor and reference are the same checkpoint here, so the result is 0.
    kl = compute_response_kl(
        actor_model, ref_model, tokenizer, prompt, response, device=device
    )
    print(f"KL Divergence: {kl:.4f}")


inputs: {'input_ids': tensor([[2061,  318,  262, 3139,  286, 4881,   30,  198,  464, 3139,  286, 4881,
          318, 6342,   13]], device='cuda:1'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:1')}
input_len: 15
prompt_len: 8




actor logits shape: torch.Size([1, 15, 50257])
reference logits shape: torch.Size([1, 15, 50257])
actor logits shape: torch.Size([1, 14, 50257])
reference logits shape: torch.Size([1, 14, 50257])
log_probs_actor: torch.Size([1, 14, 50257])
log_probs_reference: torch.Size([1, 14, 50257])
logp_actor shape: torch.Size([1, 14])
logp_reference shape: torch.Size([1, 14])
KL Divergence: 0.0000
