# 10 — Kimi K2: Joint RL with Verifiable + Rubric Rewards

> **Purpose:** Understand Moonshot AI's dual-reward RL system: verifiable rewards for objective tasks + self-critique rubrics for open-ended tasks.

**Key insight:** Combine hard verification (unit tests, math checks) with soft self-critique (rubric-based evaluation) in a unified RL framework.

| Reward Type | For Tasks | Mechanism | Example |
|-------------|-----------|-----------|--------|
| **Verifiable** | Code, Math | Rule-based pass/fail | Unit tests |
| **Rubric-based** | Open-ended | Self-critique scoring | Helpfulness rubric |

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Callable
from dataclasses import dataclass
import re

torch.manual_seed(42)

## 1. Kimi K2 Architecture Overview

```
┌──────────────────────────────────────────────────────────────┐
│                    Kimi K2 (1T params)                       │
│                    32B active / token                        │
├──────────────────────────────────────────────────────────────┤
│  MoE: 384 experts, 8+1 active per token                      │
│  61 layers, 64 attention heads                               │
│  Context: 128K-256K tokens                                   │
├──────────────────────────────────────────────────────────────┤
│                   Joint RL Training                          │
│  ┌─────────────────┐    ┌─────────────────┐                  │
│  │ Verifiable      │    │ Self-Critique   │                  │
│  │ Rewards         │    │ Rubric Rewards  │                  │
│  │ (Code, Math)    │    │ (Open-ended)    │                  │
│  └────────┬────────┘    └────────┬────────┘                  │
│           └──────────┬───────────┘                           │
│                      ▼                                       │
│              Combined Reward → Policy Update                 │
└──────────────────────────────────────────────────────────────┘
```
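The "8+1 active per token" line describes Kimi K2's routing. A router selects 8 of the 384 experts for each token, and 1 shared expert processes every token. The toy layer below sketches that pattern at tiny sizes. It is a minimal illustration, not K2's implementation. The following are assumptions: the softmax over the selected router logits, the SiLU feed-forward experts, and the names `TinyMoE` and `make_ffn`. Real MoE layers also add load balancing and batched expert dispatch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_ffn(d_model: int, d_ff: int) -> nn.Module:
    # Simple two-layer feed-forward expert (assumed shape, not K2's)
    return nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))


class TinyMoE(nn.Module):
    """Hypothetical toy MoE layer: top-k routed experts plus one shared expert."""

    def __init__(self, d_model: int = 16, n_experts: int = 384,
                 top_k: int = 8, d_ff: int = 32):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList([make_ffn(d_model, d_ff) for _ in range(n_experts)])
        self.shared = make_ffn(d_model, d_ff)  # always active for every token

    def forward(self, x: torch.Tensor):
        # x: (num_tokens, d_model)
        logits = self.router(x)                              # (num_tokens, n_experts)
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)  # k experts per token
        gates = F.softmax(top_vals, dim=-1)                  # weights over the chosen k
        shared_out = self.shared(x)
        rows = []
        for t in range(x.size(0)):
            routed = sum(g * self.experts[int(e)](x[t])
                         for g, e in zip(gates[t], top_idx[t]))
            rows.append(shared_out[t] + routed)
        return torch.stack(rows), top_idx


torch.manual_seed(0)
moe = TinyMoE()
tokens = torch.randn(5, 16)
out, chosen = moe(tokens)
print(out.shape, chosen.shape)  # one output row and 8 routed expert ids per token
```

For each token, only 9 expert FFNs run out of 385: the 8 routed experts plus the shared one. This sparsity is what keeps the active parameter count (about 32B) a small fraction of the total (about 1T).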

## 2. Verifiable Reward System

For tasks with objective ground truth (code, math).

In [None]:
@dataclass
class VerificationResult:
    """Result of verifiable reward computation."""
    passed: bool
    score: float  # 0.0 or 1.0 for binary
    details: str


class CodeVerifier:
    """
    Verify code correctness via test execution.
    
    Kimi K2 uses this for coding tasks where
    unit tests provide objective verification.
    """
    
    def verify(self, code: str, test_cases: List[Dict]) -> VerificationResult:
        """
        Run code against test cases.
        
        Args:
            code: Generated code solution
            test_cases: List of {'input': ..., 'expected': ...}
        
        Returns:
            VerificationResult with pass/fail
        """
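        passed_count = 0
        total = len(test_cases)
        
        # WARNING: exec() runs model-generated code in this process.
        # Production systems execute it in an isolated sandbox; this is a demo.
        # One dict serves as globals so recursive calls and helper
        # functions defined in `code` resolve correctly.
        namespace: Dict = {}
        try:
            exec(code, namespace)
        except Exception as e:
            return VerificationResult(
                passed=False, score=0.0,
                details=f"Code failed to execute: {e!r}"
            )
        
        func_name = self._extract_function_name(code)
        func = namespace.get(func_name) if func_name else None
        if not callable(func):
            return VerificationResult(
                passed=False, score=0.0,
                details="No callable function found in code"
            )
        
        for test in test_cases:
            try:
                if func(test['input']) == test['expected']:
                    passed_count += 1
            except Exception:
                pass  # a raised exception counts as a failed test
        
        # Binary reward: every test must pass (an empty test list never passes)
        success = total > 0 and passed_count == total
        score = 1.0 if success else 0.0
        
        return VerificationResult(
            passed=success,
            score=score,
            details=f"{passed_count}/{total} tests passed"
        )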
    
    def _extract_function_name(self, code: str) -> Optional[str]:
        """Extract the main function name from code."""
        match = re.search(r'def (\w+)\(', code)
        return match.group(1) if match else None


class MathVerifier:
    """
    Verify mathematical answer correctness.
    
    Supports exact matching and numerical tolerance.
    """
    
    def verify(self, response: str, expected: str, 
               tolerance: float = 1e-6) -> VerificationResult:
        """
        Check if extracted answer matches expected.
        
        Args:
            response: Model's full response
            expected: Ground truth answer
            tolerance: Numerical tolerance for float comparison
        
        Returns:
            VerificationResult
        """
        # Extract the answer: "Answer: X", then \boxed{X}, then a trailing "= X"
        extracted = self._extract_answer(response)
        
        if extracted is None:
            return VerificationResult(
                passed=False, score=0.0,
                details="No answer found in response"
            )
        
        # Try numerical comparison
        try:
            extracted_num = float(extracted)
            expected_num = float(expected)
            passed = abs(extracted_num - expected_num) < tolerance
        except ValueError:
            # String comparison
            passed = extracted.strip().lower() == expected.strip().lower()
        
        return VerificationResult(
            passed=passed,
            score=1.0 if passed else 0.0,
            details=f"Extracted '{extracted}', expected '{expected}'"
        )
    
    def _extract_answer(self, response: str) -> Optional[str]:
        """Extract final answer from response."""
        patterns = [
            r'[Aa]nswer[:\s]+([\d\.\-]+)',
            r'\\boxed{([^}]+)}',
            r'= ([\d\.\-]+)$',
        ]
        for pattern in patterns:
            match = re.search(pattern, response)
            if match:
                return match.group(1)
        return None
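`CodeVerifier` runs generated code with `exec()` inside the notebook process. An infinite loop therefore hangs the cell, and malicious code gets full access. One step toward sandboxed execution is to run each candidate in a separate Python process with a timeout. The sketch below does this with the standard library. `run_tests_in_subprocess` and the JSON harness are hypothetical names. Process isolation plus a timeout is not a real security sandbox; production setups add containers, resource limits, and no network access. The sketch also assumes test inputs and expected outputs are JSON-serializable.

```python
import json
import subprocess
import sys
from typing import Dict, List

# Script run in the child process: reads {"code", "func", "tests"} as JSON
# from stdin, runs each test, and prints the number that passed.
_HARNESS = """
import json, sys
payload = json.load(sys.stdin)
ns = {}
exec(payload["code"], ns)
fn = ns[payload["func"]]
passed = 0
for t in payload["tests"]:
    try:
        if fn(t["input"]) == t["expected"]:
            passed += 1
    except Exception:
        pass
print(passed)
"""


def run_tests_in_subprocess(code: str, func_name: str, test_cases: List[Dict],
                            timeout_s: float = 2.0) -> float:
    """Binary reward: 1.0 if all tests pass in a separate process, else 0.0."""
    payload = json.dumps({"code": code, "func": func_name, "tests": test_cases})
    try:
        proc = subprocess.run(
            [sys.executable, "-I", "-c", _HARNESS],  # -I: isolated mode
            input=payload, capture_output=True, text=True, timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return 0.0  # infinite loops or very slow code earn nothing
    if proc.returncode != 0:
        return 0.0  # syntax error, missing function, or crash
    # The harness prints its count last, after anything the candidate printed
    lines = proc.stdout.strip().splitlines()
    passed = int(lines[-1]) if lines else 0
    return 1.0 if test_cases and passed == len(test_cases) else 0.0


print(run_tests_in_subprocess("def add_numbers(x):\n    return x[0] + x[1]\n",
                              "add_numbers", [{"input": [1, 2], "expected": 3}]))
```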

In [None]:
# TEST: Verifiable rewards

code_verifier = CodeVerifier()
math_verifier = MathVerifier()

# Code verification test
code = '''
def add_numbers(x):
    return x[0] + x[1]
'''

test_cases = [
    {'input': [1, 2], 'expected': 3},
    {'input': [5, 7], 'expected': 12},
]

code_result = code_verifier.verify(code, test_cases)
print(f"Code verification: {code_result}")

# Math verification test
response = "Let x = 5. Then 2x + 3 = 13. Answer: 13"
math_result = math_verifier.verify(response, "13")
print(f"Math verification: {math_result}")
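`MathVerifier` compares numbers with a fixed absolute tolerance and otherwise falls back to exact string matching. As a result, equivalent answers such as `1/2` and `0.5`, or `1,000` and `1000`, are scored as wrong. The sketch below shows a slightly more forgiving comparison. Its normalization rules are assumptions, and symbolic answers would still need a computer algebra system.

```python
import math
from fractions import Fraction
from typing import Optional


def parse_number(text: str) -> Optional[float]:
    """Parse '13', '13.', '1,000', '$42', '-0.5', or '1/2' into a float; None if not numeric."""
    # Assumes commas are thousands separators and '$' is a currency prefix
    cleaned = text.strip().replace(",", "").replace("$", "").rstrip(".")
    try:
        # Fraction accepts integers, decimals, exponents, and 'a/b' strings
        return float(Fraction(cleaned))
    except (ValueError, ZeroDivisionError):
        return None


def answers_match(extracted: str, expected: str,
                  rel_tol: float = 1e-6, abs_tol: float = 1e-9) -> bool:
    """Numeric comparison with relative + absolute tolerance, else case-insensitive string match."""
    a, b = parse_number(extracted), parse_number(expected)
    if a is not None and b is not None:
        return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
    return extracted.strip().lower() == expected.strip().lower()


for pair in [("13", "13.0"), ("1/2", "0.5"), ("1,000", "1000"), ("12", "13")]:
    print(pair, answers_match(*pair))
```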

## 3. Self-Critique Rubric System

For open-ended tasks where objective verification is impossible.

In [None]:
@dataclass
class RubricCriterion:
    """Single criterion in a rubric."""
    name: str
    description: str
    weight: float = 1.0


@dataclass
class RubricScore:
    """Score for a single criterion."""
    criterion: str
    score: float  # 0-1
    reasoning: str


class SelfCritiqueRubric:
    """
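    Simplified self-critique rubric scorer, modeled on Kimi K2.
    
    The policy model critiques its own outputs against
    rubric criteria, and the parsed scores serve as reward
    signals for RL training. This notebook scores each
    response pointwise (0-10 per criterion). The Kimi K2
    report describes the critic ranking responses through
    pairwise comparisons.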
    """
    
    def __init__(self, criteria: List[RubricCriterion]):
        self.criteria = criteria
    
    def generate_critique_prompt(self, query: str, response: str) -> str:
        """
        Generate prompt for self-critique.
        
        Args:
            query: Original user query
            response: Model's response to evaluate
        
        Returns:
            Prompt for the critique model
        """
        criteria_text = "\n".join([
            f"- {c.name}: {c.description}" 
            for c in self.criteria
        ])
        
        return f"""Evaluate the following response against these criteria:

CRITERIA:
{criteria_text}

QUERY: {query}

RESPONSE: {response}

For each criterion, provide:
1. Score (0-10)
2. Brief reasoning

Format: [CRITERION]: [SCORE]/10 - [REASONING]"""
    
    def parse_critique(self, critique_output: str) -> List[RubricScore]:
        """
        Parse critique model output into scores.
        
        Args:
            critique_output: Raw model critique
        
        Returns:
            List of RubricScore objects
        """
        scores = []
        
        for criterion in self.criteria:
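            # Match a line such as "Accuracy: 8/10 - reasoning". The "- reasoning"
            # tail is optional, re.escape guards names with regex metacharacters,
            # and re.MULTILINE anchors ^/$ to each line so the reasoning never
            # spills over from the next criterion's line.
            pattern = (rf"^\s*{re.escape(criterion.name)}\s*:\s*(\d+(?:\.\d+)?)\s*/\s*10"
                       r"(?:[ \t]*-[ \t]*(.*))?$")
            match = re.search(pattern, critique_output, re.IGNORECASE | re.MULTILINE)
            
            if match:
                # Normalize to [0, 1], clamping out-of-range scores such as 12/10
                score_val = min(max(float(match.group(1)) / 10.0, 0.0), 1.0)
                reasoning = (match.group(2) or "").strip()
            else:
                # Default to a neutral middle score if parsing fails
                score_val = 0.5
                reasoning = "Unable to parse"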
            
            scores.append(RubricScore(
                criterion=criterion.name,
                score=score_val,
                reasoning=reasoning
            ))
        
        return scores
    
    def compute_reward(self, scores: List[RubricScore]) -> float:
        """
        Compute weighted reward from rubric scores.
        
        Args:
            scores: List of RubricScore objects
        
        Returns:
            Weighted average reward in [0, 1]
        """
        total_weight = sum(c.weight for c in self.criteria)
        weighted_sum = 0.0
        
        for score, criterion in zip(scores, self.criteria):
            weighted_sum += score.score * criterion.weight
        
        return weighted_sum / total_weight

In [None]:
# Example helpfulness rubric (illustrative criteria and weights, not Kimi K2's published rubric)

KIMI_HELPFULNESS_RUBRIC = SelfCritiqueRubric([
    RubricCriterion(
        name="Accuracy",
        description="Information is factually correct and up-to-date",
        weight=2.0
    ),
    RubricCriterion(
        name="Relevance",
        description="Response directly addresses the user's query",
        weight=1.5
    ),
    RubricCriterion(
        name="Completeness",
        description="All aspects of the question are covered",
        weight=1.5
    ),
    RubricCriterion(
        name="Clarity",
        description="Response is well-organized and easy to understand",
        weight=1.0
    ),
    RubricCriterion(
        name="Helpfulness",
        description="Response provides actionable or useful information",
        weight=1.0
    ),
])

# Generate critique prompt
query = "How do I learn Python effectively?"
response = "To learn Python: 1) Start with basics. 2) Practice daily. 3) Build projects."

prompt = KIMI_HELPFULNESS_RUBRIC.generate_critique_prompt(query, response)
print("Self-Critique Prompt:")
print("=" * 50)
print(prompt)

In [None]:
# Simulate critique output and parse

simulated_critique = """
Accuracy: 8/10 - The advice is generally correct but lacks specifics.
Relevance: 9/10 - Directly addresses the question about learning Python.
Completeness: 6/10 - Missing details on resources and timeline.
Clarity: 9/10 - Well-structured numbered list format.
Helpfulness: 7/10 - Provides a starting framework but could be more actionable.
"""

scores = KIMI_HELPFULNESS_RUBRIC.parse_critique(simulated_critique)
reward = KIMI_HELPFULNESS_RUBRIC.compute_reward(scores)

print("Parsed Scores:")
for s in scores:
    print(f"  {s.criterion}: {s.score:.1f} - {s.reasoning}")
print(f"\nWeighted Rubric Reward: {reward:.3f}")
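The rubric above scores each response on its own (pointwise). An alternative is to have the critic compare responses to the same prompt in pairs, then convert those preferences into per-response rewards. The sketch below uses win rate over all ordered pairs. `judge` stands in for a critic call and is hypothetical. Win-rate aggregation is one simple choice among several; fitting a Bradley-Terry model is another.

```python
from itertools import permutations
from typing import Callable, List


def pairwise_win_rates(responses: List[str],
                       judge: Callable[[str, str], bool]) -> List[float]:
    """
    Reward each response by its win rate over all ordered pairs.

    judge(a, b) returns True if the critic prefers a over b (hypothetical critic call).
    Asking about both (a, b) and (b, a) dampens the critic's position bias.
    """
    n = len(responses)
    if n < 2:
        return [0.5] * n  # nothing to compare against: neutral reward
    wins = [0] * n
    for i, j in permutations(range(n), 2):
        if judge(responses[i], responses[j]):
            wins[i] += 1
        else:
            wins[j] += 1
    games_per_response = 2 * (n - 1)  # each response appears in 2(n-1) ordered pairs
    return [w / games_per_response for w in wins]


def prefer_longer(a: str, b: str) -> bool:
    # Toy stand-in critic; a real critic would apply the rubric criteria
    return len(a) > len(b)


print(pairwise_win_rates(["Practice.", "Practice daily and build projects.",
                          "Start with basics."], prefer_longer))  # → [0.0, 1.0, 0.5]
```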

## 4. Joint Reward Combiner

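Kimi K2's joint RL framework draws on both reward sources. The combiner below merges them into a single scalar per response. Tasks with a verifier get a weighted sum (0.7 verifiable + 0.3 rubric by default; these weights are illustrative choices for this notebook, not published values). Open-ended tasks use the rubric reward alone.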

In [None]:
class JointRewardSystem:
    """
    Kimi K2's joint reward system combining
    verifiable and self-critique rewards.
    """
    
    def __init__(self,
                 code_verifier: CodeVerifier,
                 math_verifier: MathVerifier,
                 rubric: SelfCritiqueRubric,
                 verifiable_weight: float = 0.7,
                 rubric_weight: float = 0.3):
        
        self.code_verifier = code_verifier
        self.math_verifier = math_verifier
        self.rubric = rubric
        self.verifiable_weight = verifiable_weight
        self.rubric_weight = rubric_weight
    
    def compute_reward(self,
                       response: str,
                       task_type: str,
                       ground_truth: Optional[str] = None,
                       test_cases: Optional[List[Dict]] = None,
                       critique_output: Optional[str] = None) -> Dict:
        """
        Compute joint reward for a response.
        
        Args:
            response: Model's response
            task_type: 'code', 'math', or 'open_ended'
            ground_truth: Expected answer (for math)
            test_cases: Test cases (for code)
            critique_output: Self-critique result (for open-ended)
        
        Returns:
            Dict with 'total_reward', 'verifiable_reward', 'rubric_reward'
        """
        verifiable_reward = 0.0
        rubric_reward = 0.0
        has_verifiable = False
        
        # Verifiable rewards
        if task_type == 'code' and test_cases:
            result = self.code_verifier.verify(response, test_cases)
            verifiable_reward = result.score
            has_verifiable = True
        
        elif task_type == 'math' and ground_truth:
            result = self.math_verifier.verify(response, ground_truth)
            verifiable_reward = result.score
            has_verifiable = True
        
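        # Rubric reward (whenever a critique is provided)
        if critique_output:
            scores = self.rubric.parse_critique(critique_output)
            rubric_reward = self.rubric.compute_reward(scores)
        
        # Combine rewards
        if has_verifiable and critique_output:
            # Weighted combination of both signals
            total = (self.verifiable_weight * verifiable_reward + 
                     self.rubric_weight * rubric_reward)
        elif has_verifiable:
            # No critique available: use the verifiable reward alone
            # rather than treating the missing rubric score as 0
            total = verifiable_reward
        else:
            # Open-ended (or no verifier data): rubric reward only
            total = rubric_reward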
        
        return {
            'total_reward': total,
            'verifiable_reward': verifiable_reward,
            'rubric_reward': rubric_reward,
            'task_type': task_type,
        }

In [None]:
# TEST: Joint reward system

joint_system = JointRewardSystem(
    code_verifier=CodeVerifier(),
    math_verifier=MathVerifier(),
    rubric=KIMI_HELPFULNESS_RUBRIC,
    verifiable_weight=0.7,
    rubric_weight=0.3,
)

# Code task
code_response = "def add(x): return x[0] + x[1]"
code_reward = joint_system.compute_reward(
    response=code_response,
    task_type='code',
    test_cases=[{'input': [1, 2], 'expected': 3}],
    critique_output="Accuracy: 10/10\nRelevance: 10/10\nCompleteness: 8/10\nClarity: 9/10\nHelpfulness: 9/10"
)
print(f"Code Task: {code_reward}")

# Math task
math_response = "2 + 2 = 4. Answer: 4"
math_reward = joint_system.compute_reward(
    response=math_response,
    task_type='math',
    ground_truth="4",
    critique_output="Accuracy: 10/10\nRelevance: 10/10\nCompleteness: 10/10\nClarity: 10/10\nHelpfulness: 10/10"
)
print(f"Math Task: {math_reward}")

# Open-ended task
open_response = "Here's how to cook pasta..."
open_reward = joint_system.compute_reward(
    response=open_response,
    task_type='open_ended',
    critique_output="Accuracy: 8/10\nRelevance: 9/10\nCompleteness: 7/10\nClarity: 8/10\nHelpfulness: 8/10"
)
print(f"Open-ended Task: {open_reward}")
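The code-task reward can be checked by hand. The rubric weights are 2.0, 1.5, 1.5, 1.0, and 1.0 (total 7.0), and the critique scores are 10, 10, 8, 9, and 9 out of 10. The weighted rubric reward is therefore 6.5 / 7 ≈ 0.929. When all tests pass, the joint reward is 0.7 × 1.0 + 0.3 × 0.929 ≈ 0.979. The same arithmetic in plain Python:

```python
# Rubric weights from the helpfulness rubric and scores from the code-task critique
weights = {"Accuracy": 2.0, "Relevance": 1.5, "Completeness": 1.5,
           "Clarity": 1.0, "Helpfulness": 1.0}
scores = {"Accuracy": 10, "Relevance": 10, "Completeness": 8,
          "Clarity": 9, "Helpfulness": 9}

# Weighted average of normalized (0-1) criterion scores
rubric_reward = sum(weights[k] * scores[k] / 10 for k in weights) / sum(weights.values())

# Joint reward for a task whose verifier passed (verifiable_reward = 1.0)
verifiable_reward = 1.0
total_reward = 0.7 * verifiable_reward + 0.3 * rubric_reward
print(f"rubric={rubric_reward:.4f}, total={total_reward:.4f}")  # → rubric=0.9286, total=0.9786
```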

## 5. Agentic Tool-Use Reward

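Kimi K2 is trained for long-horizon, multi-step tool orchestration, and Moonshot reports that the K2 Thinking variant can sustain 200-300 sequential tool calls. The reward shaping below (per-step reward, failure penalty, goal bonus, efficiency bonus) is an illustrative design for this notebook, not a published Kimi K2 reward.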

In [None]:
@dataclass
class ToolCall:
    """Representation of a tool call."""
    tool_name: str
    arguments: Dict
    result: Optional[str] = None
    success: bool = True


class AgenticRewardSystem:
    """
    Reward system for multi-step agentic tasks.
    
    Evaluates tool call sequences for correctness,
    efficiency, and goal achievement.
    """
    
    def __init__(self,
                 step_reward: float = 0.01,
                 success_bonus: float = 1.0,
                 failure_penalty: float = -0.1,
                 efficiency_weight: float = 0.2):
        
        self.step_reward = step_reward
        self.success_bonus = success_bonus
        self.failure_penalty = failure_penalty
        self.efficiency_weight = efficiency_weight
    
    def compute_trajectory_reward(
        self,
        tool_calls: List[ToolCall],
        goal_achieved: bool,
        optimal_steps: Optional[int] = None
    ) -> Dict:
        """
        Compute reward for a sequence of tool calls.
        
        Args:
            tool_calls: List of tool calls made
            goal_achieved: Whether the final goal was met
            optimal_steps: Known optimal number of steps (for efficiency)
        
        Returns:
            Reward breakdown
        """
        num_steps = len(tool_calls)
        successful_steps = sum(1 for t in tool_calls if t.success)
        failed_steps = num_steps - successful_steps
        
        # Base reward for progress
        step_rewards = successful_steps * self.step_reward
        step_penalties = failed_steps * self.failure_penalty
        
        # Goal achievement bonus
        goal_reward = self.success_bonus if goal_achieved else 0.0
        
        # Efficiency bonus (if optimal is known)
        efficiency_reward = 0.0
        if optimal_steps and goal_achieved:
            # Clamp to 1.0 so a run shorter than the reference "optimal" can't exceed the max bonus
            efficiency = min(1.0, optimal_steps / max(num_steps, 1))
            efficiency_reward = self.efficiency_weight * efficiency
        
        total = step_rewards + step_penalties + goal_reward + efficiency_reward
        
        return {
            'total_reward': total,
            'step_rewards': step_rewards,
            'step_penalties': step_penalties,
            'goal_reward': goal_reward,
            'efficiency_reward': efficiency_reward,
            'num_steps': num_steps,
            'success_rate': successful_steps / max(num_steps, 1),
        }

In [None]:
# TEST: Agentic reward

agentic_system = AgenticRewardSystem()

# Simulate a successful multi-step task
tool_calls = [
    ToolCall("search", {"query": "Python tutorial"}, "Found 10 results", True),
    ToolCall("read_url", {"url": "python.org"}, "Content loaded", True),
    ToolCall("extract", {"selector": "code"}, "Code examples", True),
    ToolCall("write_file", {"path": "notes.md"}, "Saved", True),
]

result = agentic_system.compute_trajectory_reward(
    tool_calls=tool_calls,
    goal_achieved=True,
    optimal_steps=4
)

print("Agentic Task Reward:")
for k, v in result.items():
    print(f"  {k}: {v}")
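This shaping has a property worth checking. Each successful step earns +0.01, while the efficiency bonus shrinks only as 1/n. A goal-achieving trajectory padded with extra successful calls therefore eventually out-scores the optimal one. The standalone function below reproduces the reward formula for all-successful, goal-achieving trajectories. The `cap_step_rewards` option is a hypothetical fix that pays per-step reward only up to `optimal_steps`.

```python
def trajectory_reward(num_steps: int, optimal_steps: int,
                      step_reward: float = 0.01, success_bonus: float = 1.0,
                      efficiency_weight: float = 0.2,
                      cap_step_rewards: bool = False) -> float:
    """Reward for a goal-achieving trajectory whose tool calls all succeed.

    cap_step_rewards (hypothetical fix): pay per-step reward only up to optimal_steps.
    """
    paid_steps = min(num_steps, optimal_steps) if cap_step_rewards else num_steps
    efficiency = min(1.0, optimal_steps / max(num_steps, 1))
    return paid_steps * step_reward + success_bonus + efficiency_weight * efficiency


for n in (4, 8, 20, 40, 100):
    print(f"{n:>3} steps: uncapped={trajectory_reward(n, 4):.3f}  "
          f"capped={trajectory_reward(n, 4, cap_step_rewards=True):.3f}")
```

Without the cap, a 40-step trajectory earns 1.42 versus 1.24 for the optimal 4-step one, so the policy is paid to pad. With the cap, the optimal trajectory scores highest. A small per-step cost, or dropping per-step rewards entirely, would be equally reasonable alternatives.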

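## 6. Joint RL Batch Reward Computation

This is a simplified sketch of the reward side of Kimi K2-style joint RL. Each sample in a mixed batch is routed either to the agentic reward system or to the joint verifiable + rubric system. The policy-update step is not implemented here.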

In [None]:
class KimiK2JointRLTrainer:
    """
    Simplified Kimi K2 Joint RL Trainer.
    
    Combines:
    - Verifiable rewards (code, math)
    - Self-critique rubric rewards
    - Agentic trajectory rewards
    """
    
    def __init__(self,
                 joint_reward: JointRewardSystem,
                 agentic_reward: AgenticRewardSystem):
        
        self.joint_reward = joint_reward
        self.agentic_reward = agentic_reward
    
    def compute_batch_rewards(
        self,
        batch: List[Dict]
    ) -> List[float]:
        """
        Compute rewards for a batch of samples.
        
        Args:
            batch: List of dicts with task info
        
        Returns:
            List of reward values
        """
        rewards = []
        
        for sample in batch:
            task_type = sample.get('task_type', 'open_ended')
            
            if task_type == 'agentic':
                # Use agentic reward system
                result = self.agentic_reward.compute_trajectory_reward(
                    tool_calls=sample.get('tool_calls', []),
                    goal_achieved=sample.get('goal_achieved', False),
                )
                rewards.append(result['total_reward'])
            else:
                # Use joint verifiable + rubric system
                result = self.joint_reward.compute_reward(
                    response=sample.get('response', ''),
                    task_type=task_type,
                    ground_truth=sample.get('ground_truth'),
                    test_cases=sample.get('test_cases'),
                    critique_output=sample.get('critique_output'),
                )
                rewards.append(result['total_reward'])
        
        return rewards

In [None]:
# TEST: Complete Joint RL system

trainer = KimiK2JointRLTrainer(
    joint_reward=joint_system,
    agentic_reward=agentic_system,
)

# Mixed batch of tasks
batch = [
    {'task_type': 'math', 'response': 'Answer: 42', 'ground_truth': '42',
     'critique_output': 'Accuracy: 10/10\nRelevance: 10/10\nCompleteness: 10/10\nClarity: 10/10\nHelpfulness: 10/10'},
    {'task_type': 'code', 'response': 'def f(x): return x*2',
     'test_cases': [{'input': 5, 'expected': 10}],  # f(5) == 10; input [5] would give [5, 5]
     'critique_output': 'Accuracy: 8/10\nRelevance: 9/10\nCompleteness: 7/10\nClarity: 8/10\nHelpfulness: 8/10'},
    {'task_type': 'agentic',
     'tool_calls': [ToolCall('search', {}, 'ok', True), ToolCall('parse', {}, 'ok', True)],
     'goal_achieved': True},
]

rewards = trainer.compute_batch_rewards(batch)

print("Batch Rewards:")
for i, (sample, reward) in enumerate(zip(batch, rewards)):
    print(f"  Sample {i+1} ({sample['task_type']}): {reward:.3f}")
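The trainer above stops at scalar rewards. A common way to turn them into a policy update is a group-relative, REINFORCE-style objective. Sample several responses per prompt, normalize each reward against its prompt group's mean and standard deviation, and weight each response's log-probability by that advantage. This is a generic sketch, not the policy-optimization objective used to train Kimi K2. `group_ids` (which prompt each sample came from) and the stand-in log-probabilities are assumptions. Per-group normalization also puts the agentic rewards (around 1.0-1.2 here) and the joint rewards (in [0, 1]) on a comparable scale, as long as each group contains a single task type.

```python
import torch


def group_normalized_advantages(rewards: torch.Tensor, group_ids: torch.Tensor,
                                eps: float = 1e-6) -> torch.Tensor:
    """Normalize each reward by the mean and (population) std of its prompt group."""
    adv = torch.zeros_like(rewards)
    for g in group_ids.unique():
        mask = group_ids == g
        r = rewards[mask]
        std = ((r - r.mean()) ** 2).mean().sqrt()  # 0 for a single-sample group
        adv[mask] = (r - r.mean()) / (std + eps)
    return adv


def policy_gradient_loss(seq_logprobs: torch.Tensor,
                         advantages: torch.Tensor) -> torch.Tensor:
    """REINFORCE-style loss -mean(A * log pi(y|x)); advantages are treated as constants."""
    return -(advantages.detach() * seq_logprobs).mean()


# Two prompts, two sampled responses each (rewards as produced by the reward systems)
rewards = torch.tensor([1.0, 0.0, 1.02, 1.24])
group_ids = torch.tensor([0, 0, 1, 1])
adv = group_normalized_advantages(rewards, group_ids)

# Stand-in for each response's summed token log-probs under the current policy
seq_logprobs = torch.tensor([-1.0, -2.0, -1.5, -0.5], requires_grad=True)
loss = policy_gradient_loss(seq_logprobs, adv)
loss.backward()
print(adv, seq_logprobs.grad)
```

Minimizing this loss raises the log-probability of responses that beat their group average and lowers it for the rest.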

## 7. Summary: Kimi K2 Joint RL

| Component | Purpose | Implementation |
|-----------|---------|----------------|
| **Verifiable Rewards** | Objective tasks | Code tests, math checks |
| **Self-Critique Rubrics** | Open-ended tasks | LLM evaluates own output |
| **Agentic Rewards** | Tool orchestration | Trajectory success + efficiency |
| **Joint Weighting** | Combine signals | Weighted sum (this notebook's default: 0.7 verifiable + 0.3 rubric) |

### Key Innovations

1. **Dual-reward system** covers both verifiable (code, math) and open-ended tasks in one RL framework
2. **Self-critique** provides scalable feedback without human labels
3. **Online critic improvement** using verifiable task data
4. **Agentic synthesis pipeline** for tool-use training data
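Item 3 refers to using tasks with verifiable outcomes to keep the critic honest. This notebook does not reproduce the report's exact procedure. One simple diagnostic in that spirit is to have the critic score responses that also have a verifier result. You then measure how often it ranks a verified-correct response above a verified-incorrect one. This pairwise ranking accuracy equals ROC AUC when ties count as half. `critic_agreement` is a hypothetical helper:

```python
from itertools import product
from typing import List


def critic_agreement(critic_scores: List[float], verified: List[bool]) -> float:
    """
    Fraction of (correct, incorrect) response pairs where the critic scores the
    verified-correct response higher (ties count half). 1.0 = perfect, 0.5 = chance.
    """
    pos = [s for s, ok in zip(critic_scores, verified) if ok]
    neg = [s for s, ok in zip(critic_scores, verified) if not ok]
    if not pos or not neg:
        raise ValueError("need at least one verified-correct and one incorrect response")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


# Rubric rewards for four responses, two of which passed their unit tests
print(critic_agreement([0.9, 0.8, 0.3, 0.85], [True, True, False, False]))  # → 0.75
```

A score drifting toward 0.5 would signal that the critic's judgments no longer track objective correctness.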

---
**Tier 3 Progress:** Section 10 Complete!