# Marxist-Leninist GRPO Training

This notebook fine-tunes **DeepSeek-R1-0528-Qwen3-8B** on the ProleWiki corpus using GRPO (Group Relative Policy Optimization).

**Goal:** Train the model to reason through political theory questions using dialectical materialist analysis, showing reasoning in `<think>` tags.

**Hardware:** Optimized for A40 (48GB VRAM)

**Dataset:** 1,058 Q&A pairs from ProleWiki covering:
- Revisionism and opportunism
- Dialectical and historical materialism
- Anti-colonial theory (Fanon, Rodney, Nkrumah)
- Revolutionary theory (Jackson, Sankara, PFLP)
- Marxist political economy

---

**Based on:** [Unsloth GRPO Notebook](https://github.com/unslothai/notebooks)

## Installation

In [None]:
%%capture
import os

os.environ["UNSLOTH_VLLM_STANDBY"] = "1"  # Extra 30% context lengths!

# Install dependencies
!pip install unsloth vllm sentence-transformers numpy
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2

# spaCy with TRANSFORMER model for better semantic understanding
!pip install spacy spacy-curated-transformers
!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl

## Load Model

We load the DeepSeek-R1-0528-Qwen3-8B model with **16-bit precision** and vLLM fast inference.

**GRPO configuration:** Unlike the 4-bit SFT setup, this notebook sets `load_in_4bit=False` (16-bit LoRA rather than QLoRA). The relevant settings:
1. `load_in_4bit=False` trains the LoRA adapters against 16-bit base weights, so policy gradients are not computed through quantized weights
2. `fast_inference=True` routes generation through vLLM, which handles generation memory efficiently
3. `gpu_memory_utilization=0.6` is conservative, leaving headroom for reward model inference
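
As a rough sanity check on these memory settings, the budget can be estimated with simple arithmetic. The sketch below assumes about 8 billion parameters at 2 bytes each and treats `gpu_memory_utilization` as the fraction of the 48 GB card reserved for vLLM. Both figures are approximations; the real footprint also depends on the KV cache, activations, and optimizer state.

```python
# Back-of-the-envelope VRAM estimate (assumed, approximate figures).
TOTAL_VRAM_GB = 48.0            # A40
GPU_MEMORY_UTILIZATION = 0.6    # same value as the configuration cell
N_PARAMS = 8e9                  # assumption: roughly 8B parameters
BYTES_PER_PARAM = 2             # 16-bit weights

vllm_budget_gb = TOTAL_VRAM_GB * GPU_MEMORY_UTILIZATION
weights_gb = N_PARAMS * BYTES_PER_PARAM / 1e9
headroom_gb = TOTAL_VRAM_GB - vllm_budget_gb

print(f"vLLM budget:      {vllm_budget_gb:.1f} GB")
print(f"16-bit weights:   {weights_gb:.1f} GB")
print(f"Outside vLLM:     {headroom_gb:.1f} GB for training and reward models")
```

The remaining memory has to hold the LoRA gradients, optimizer state, and the reward models (sentence embedder, NLI model, spaCy pipeline), which is why the utilization value is kept well below 1.0.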

In [None]:
from unsloth import FastLanguageModel

# Configuration
MODEL_NAME = "unsloth/DeepSeek-R1-0528-Qwen3-8B"
MAX_SEQ_LENGTH = 2048  # Longer for detailed political theory responses
LORA_RANK = 32
GPU_MEMORY_UTILIZATION = 0.6  # Conservative for 16-bit GRPO

# IMPORTANT: This notebook trains LoRA against 16-bit base weights,
# so load_in_4bit stays False (no QLoRA quantization)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=False,  # Must be False for GRPO LoRA 16bit
    fast_inference=True,  # Enable vLLM
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
)

print(f"Model type: {model.config.model_type}")

## Apply LoRA

Apply LoRA adapters to attention and feed-forward layers for efficient fine-tuning.

**GRPO Configuration:**
- `lora_alpha = lora_rank` (same value, not doubled like some SFT configs)
- `use_gradient_checkpointing = "unsloth"` for 30% VRAM reduction
- Targets all attention + MLP layers for comprehensive adaptation
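
To make the configuration concrete, the sketch below estimates how many trainable parameters rank-32 adapters add and what `lora_alpha = lora_rank` means for the update scale. The layer dimensions are assumptions about the Qwen3-8B architecture (hidden size 4096, key/value projection width 1024, MLP width 12288, 36 layers) and should be checked against `model.config`; the scaling factor `alpha / r` is the standard LoRA convention.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


RANK = 32
ALPHA = 32
scaling = ALPHA / RANK  # standard LoRA scales the update by alpha / r

# Assumed dimensions (hypothetical until verified against model.config).
HIDDEN, KV_DIM, MLP, LAYERS = 4096, 1024, 12288, 36

shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

per_layer = sum(lora_param_count(d_in, d_out, RANK) for d_in, d_out in shapes.values())
total = per_layer * LAYERS

print(f"LoRA scaling factor: {scaling}")
print(f"Adapter parameters per layer: {per_layer:,}")
print(f"Adapter parameters in total:  {total:,}")
```

With `lora_alpha` equal to the rank, the scaling factor is 1.0, so the adapter output is added to the base layer without extra amplification. The printed total can be compared with the `Trainable:` line from the cell below.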

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=LORA_RANK,  # Same as r for GRPO (not r*2)
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

## Chat Template

The DeepSeek-R1 distilled model uses a specific chat template. Let's verify it works correctly.

In [None]:
# Find special tokens
reasoning_start = None
reasoning_end = None

for token in tokenizer.get_added_vocab():
    if "think" in token and "/" in token:
        reasoning_end = token
    elif "think" in token:
        reasoning_start = token

print(f"Reasoning start: {reasoning_start}")
print(f"Reasoning end: {reasoning_end}")

# Our system prompt for Marxist-Leninist reasoning
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""

print(f"\nSystem prompt:\n{SYSTEM_PROMPT}")

In [None]:
# Test chat template
test_messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "What is revisionism?"},
]

print(tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True))

## Load Dataset

Load the GRPO-formatted dataset from `grpo_dataset.jsonl`.
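
The loading cell indexes each record as `sample['prompt'][1]['content']` and `sample['answer']`, so every JSONL line is expected to hold a list of chat messages plus a reference answer string. The sketch below shows one plausible record and a small validator. `validate_record` is a hypothetical helper, the message order (system first, user second) is an assumption, and the answer text is a placeholder rather than real dataset content.

```python
import json


def validate_record(line: str) -> dict:
    """Parse one JSONL line and check the fields the notebook relies on."""
    record = json.loads(line)
    prompt = record.get("prompt")
    if not isinstance(prompt, list) or len(prompt) < 2:
        raise ValueError("prompt must be a list with at least two messages")
    for message in prompt:
        if not {"role", "content"} <= set(message):
            raise ValueError("each message needs 'role' and 'content' keys")
    if not isinstance(record.get("answer"), str):
        raise ValueError("answer must be a string")
    return record


# Illustrative record only; the contents are placeholders.
example_line = json.dumps(
    {
        "prompt": [
            {"role": "system", "content": "System prompt goes here."},
            {"role": "user", "content": "What is revisionism?"},
        ],
        "answer": "Placeholder reference answer.",
    }
)

record = validate_record(example_line)
print(record["prompt"][1]["content"])
```

Running a check like this over the file before training catches malformed lines early, instead of partway through a reward computation.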

In [None]:
from pathlib import Path

from datasets import Dataset

DATA_PATH = Path("grpo_dataset.jsonl")

if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Dataset not found: {DATA_PATH}\n" "Run 'python transform_to_grpo.py' first!"
    )

dataset = Dataset.from_json(str(DATA_PATH))
print(f"Loaded {len(dataset)} examples")

# Show sample
sample = dataset[0]
print(f"\nSample prompt: {sample['prompt'][1]['content'][:80]}...")
print(f"Sample answer: {sample['answer'][:80]}...")

## Reward Functions

GRPO uses reward functions to guide the model toward desired behaviors:

1. **Format rewards** - Encourage proper `<think>...</think>` structure
2. **NLI coherence** - Check if response ENTAILS ground truth (defeats word soup)
3. **Self-consistency** - Check for internal contradictions (no external ideology)
4. **Structural coherence** - Verify terms in proper syntactic roles
5. **Completeness reward** - Reward thorough, detailed responses
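
For context, GRPO does not use these rewards as absolute scores. Each completion is compared with the other completions sampled for the same prompt. The sketch below shows the usual group-relative normalization in plain Python, assuming the per-function rewards have already been summed into one scalar per completion. `group_relative_advantages` and its `eps` value are illustrative; the trainer's actual implementation may differ in details such as the standard deviation estimator.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one prompt's group of sampled completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions sampled for the same prompt (made-up reward totals).
group_rewards = [4.5, 1.0, -2.0, 1.0]
advantages = group_relative_advantages(group_rewards)

for reward, advantage in zip(group_rewards, advantages):
    print(f"reward {reward:+.1f} -> advantage {advantage:+.2f}")
```

Because only the within-group differences matter, a reward function that returns the same value for every completion in a group contributes no learning signal for that prompt.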

### Research Basis
- [NLI as reward paradigm](https://arxiv.org/abs/2508.18212)
- [MO-GRPO mitigating reward hacking](https://arxiv.org/abs/2509.22047)
- [Process-based rewards](https://arxiv.org/abs/2508.05170)

### Why NLI instead of keyword matching?
Simple terminology matching can be gamed with "word soup" - random Marxist terms without coherent meaning. NLI checks logical consistency, not word overlap.
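
The "word soup" problem is easy to reproduce. The sketch below uses a small stand-in term set and the same counting scheme as a keyword reward (+0.3 per term, capped at +2.0). The two example sentences are made up for illustration.

```python
# Small stand-in for a terminology list (illustrative subset).
TERMS = {
    "bourgeoisie",
    "proletariat",
    "surplus value",
    "imperialism",
    "hegemony",
    "alienation",
    "dialectical",
}


def keyword_score(text: str, terms=TERMS) -> float:
    """Keyword reward: +0.3 per term present, capped at +2.0."""
    text = text.lower()
    hits = sum(1 for term in terms if term in text)
    return min(hits * 0.3, 2.0)


word_soup = "Hegemony proletariat alienation dialectical imperialism surplus value bourgeoisie."
coherent = "Revisionism abandons class struggle in favor of gradual reform."

print(f"word soup: {keyword_score(word_soup)}")
print(f"coherent:  {keyword_score(coherent)}")
```

The incoherent string reaches the cap while the meaningful sentence scores zero, which is the failure mode the NLI-based rewards are meant to counter.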

In [None]:
import re

import numpy as np
from sentence_transformers import SentenceTransformer

# Reasoning format tokens
REASONING_START = "<think>"
REASONING_END = "</think>"

# Regex to match format
SOLUTION_END_REGEX = re.compile(rf"{REASONING_END}(.*)", re.DOTALL)

# Lazy-load models to avoid loading at import time
_embedder = None
_nli_pipeline = None
_spacy_nlp = None


def get_embedder():
    """Get or initialize the sentence transformer embedder."""
    global _embedder
    if _embedder is None:
        print("[Reward] Loading sentence-transformers embedder...")
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return _embedder


def get_nli_pipeline():
    """Get or initialize the NLI pipeline (BART-large-MNLI)."""
    global _nli_pipeline
    if _nli_pipeline is None:
        print("[Reward] Loading NLI model (bart-large-mnli)...")
        import torch
        from transformers import pipeline

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _nli_pipeline = pipeline(
            "text-classification",
            model="facebook/bart-large-mnli",
            device=device,
        )
    return _nli_pipeline


def get_spacy_nlp():
    """Get or initialize spaCy NLP pipeline.

    Uses en_core_web_trf (transformer-based) for superior semantic understanding.
    Falls back to en_core_web_md (word vectors) or en_core_web_sm if unavailable.
    """
    global _spacy_nlp
    if _spacy_nlp is None:
        import spacy

        # Try transformer model first (best semantic understanding)
        models_to_try = ["en_core_web_trf", "en_core_web_md", "en_core_web_sm"]

        for model_name in models_to_try:
            try:
                print(f"[Reward] Loading spaCy model: {model_name}...")
                _spacy_nlp = spacy.load(model_name)
                print(f"[Reward] Loaded {model_name} successfully")
                break
            except OSError:
                print(f"[Reward] {model_name} not found, trying next...")
                continue

        if _spacy_nlp is None:
            raise OSError("No spaCy model found!")
    return _spacy_nlp


# Discourse connectives indicating logical structure
DISCOURSE_CONNECTIVES = {
    "because",
    "therefore",
    "thus",
    "hence",
    "consequently",
    "however",
    "although",
    "whereas",
    "nevertheless",
    "moreover",
    "furthermore",
    "additionally",
    "specifically",
    "namely",
    "as a result",
    "due to",
    "in order to",
    "so that",
    "on the other hand",
    "in contrast",
    "similarly",
    "likewise",
}

# Marxist concept equivalences for topic matching
CONCEPT_EQUIVALENCES = {
    # Class terms
    "bourgeoisie": {"capitalist class", "ruling class", "capitalists", "bourgeois", "capital"},
    "proletariat": {"working class", "workers", "wage laborers", "labor", "labourers"},
    "petty bourgeoisie": {"petit bourgeoisie", "small business", "middle class", "petty bourgeois"},
    # Economic concepts
    "surplus value": {"unpaid labor", "profit", "extraction", "surplus labor"},
    "means of production": {"productive forces", "capital goods", "factories", "industry"},
    "exploitation": {"extraction", "appropriation", "expropriation"},
    # Political concepts
    "dictatorship of the proletariat": {
        "workers state",
        "proletarian dictatorship",
        "workers government",
    },
    "vanguard party": {"vanguard", "communist party", "revolutionary party"},
    # Imperialism
    "imperialism": {"colonialism", "neo-colonialism", "empire", "colonial"},
    "national liberation": {"decolonization", "anti-colonial", "liberation movement"},
    # Ideology
    "revisionism": {"opportunism", "reformism", "right deviation"},
    "hegemony": {"ideological hegemony", "cultural hegemony", "domination"},
    # Philosophy
    "dialectical materialism": {"diamat", "materialist dialectics", "dialectics"},
    "historical materialism": {"histmat", "materialist conception of history"},
}

# Question words to ignore when extracting topics
QUESTION_WORDS = {"what", "how", "why", "who", "when", "where", "which", "whom"}
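
An equivalence table like `CONCEPT_EQUIVALENCES` is most useful when it works in both directions: a question about the "proletariat" should match an answer about "workers", and the reverse. The sketch below shows one way to expand a topic set with a small stand-in table. `EQUIV` and `expand` are illustrative names, not the notebook's own helpers.

```python
# Stand-in table with the same shape as CONCEPT_EQUIVALENCES.
EQUIV = {
    "proletariat": {"working class", "workers"},
    "revisionism": {"opportunism", "reformism"},
}


def expand(topics, table=EQUIV):
    """Add the canonical term and all synonyms for any topic found in the table."""
    topics = set(topics)
    expanded = set(topics)
    for canonical, synonyms in table.items():
        group = {canonical} | synonyms
        if group & topics:  # topic is the canonical term or one of its synonyms
            expanded |= group
    return expanded


print(sorted(expand({"workers"})))
print(sorted(expand({"state"})))
```

One consequence of bidirectional expansion is that a synonym listed under two keys links both groups. In the table above, "extraction" appears under both "surplus value" and "exploitation", so a topic of "extraction" expands to both concept families.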

In [None]:
# Marxist terminology for vocabulary reward
MARXIST_TERMS = {
    # Core concepts
    "dialectical",
    "materialism",
    "historical materialism",
    "dialectical materialism",
    # Classes
    "bourgeoisie",
    "proletariat",
    "petty bourgeois",
    "petty bourgeoisie",
    "lumpenproletariat",
    "working class",
    "ruling class",
    # Class struggle
    "class struggle",
    "class consciousness",
    "class war",
    "class conflict",
    # Political economy
    "surplus value",
    "commodity",
    "use value",
    "exchange value",
    "labor power",
    "means of production",
    "relations of production",
    "forces of production",
    "mode of production",
    "primitive accumulation",
    "exploitation",
    "capital accumulation",
    # Imperialism
    "imperialism",
    "colonialism",
    "neo-colonialism",
    "settler colonialism",
    "national liberation",
    "self-determination",
    # State and revolution
    "dictatorship of the proletariat",
    "vanguard",
    "vanguard party",
    "democratic centralism",
    "withering away of the state",
    "proletarian dictatorship",
    # Ideology
    "hegemony",
    "superstructure",
    "base",
    "ideology",
    "false consciousness",
    # Revisionism
    "revisionism",
    "opportunism",
    "reformism",
    "social democracy",
    "ultra-leftism",
    # Alienation
    "alienation",
    "fetishism",
    "commodity fetishism",
    "reification",
    # Historical
    "paris commune",
    "october revolution",
    "bolshevik",
    "menshevik",
    # Anti-colonial
    "decolonization",
    "third world",
    "global south",
    "national bourgeoisie",
    "comprador",
}
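
Several reward functions below test these terms with plain substring checks (`term in text`). Short entries such as "base" then also match unrelated words like "based". A word-boundary regex is one possible alternative. The sketch below is a hypothetical helper (`contains_term`), not something the notebook currently uses.

```python
import re


def contains_term(text: str, term: str) -> bool:
    """Match a term only at word boundaries, ignoring case."""
    pattern = rf"\b{re.escape(term)}\b"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None


sentence = "This claim is based on evidence."

print("substring match:", "base" in sentence.lower())
print("boundary match: ", contains_term(sentence, "base"))
print("real usage:     ", contains_term("The economic base shapes the superstructure.", "base"))
```

The trade-off is that strict boundaries also stop matching inflected forms that substring checks happen to catch, for example "bolsheviks" for the term "bolshevik", so the term list may need plural entries if this approach is adopted.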

In [None]:
# =============================================================================
# DEPTH ANALYSIS CONSTANTS (Anti-Buzzword-Salad)
# =============================================================================

# Explanatory phrases indicating concept is being explained, not just dropped
EXPLANATORY_PHRASES = {
    # Causal explanations
    "because the",
    "because of",
    "this is because",
    "since the",
    "due to the",
    "as a result of",
    "results from",
    "caused by",
    "leads to",
    "results in",
    "enables",
    "produces",
    # Definitional explanations
    "is defined as",
    "refers to",
    "means that",
    "denotes",
    "that is",
    "in other words",
    "namely",
    "i.e.",
    # Elaboration
    "specifically",
    "in particular",
    "for example",
    "such as",
    "this means",
    "which means",
    "this implies",
    "therefore",
    # Mechanism explanations
    "this occurs when",
    "this happens because",
    "the mechanism",
    "through the process of",
    "by means of",
    "works by",
}

# Hollow buzzwords - activist jargon that substitutes for analysis
HOLLOW_BUZZWORDS = {
    # Vague connectors
    "interconnected",
    "interrelated",
    "intersects with",
    "it's all connected",
    "everything is connected",
    "systemic",
    # Performative language
    "centered",
    "centering",
    "uplift",
    "uplifting",
    "do the work",
    "the work",
    "unpack",
    "unpacking",
    "unlearn",
    "unlearning",
    "hold space",
    "sit with",
    "lean into",
    "problematic",
    "harmful",
    "toxic",
    # Vague abstractions
    "in a way",
    "sort of",
    "kind of",
    "essentially",
    "basically",
    "generally speaking",
    "broadly",
}

# Depth markers - phrases indicating analytical depth
DEPTH_MARKERS = {
    # Historical specificity
    "in 1",
    "in 2",
    "during the",
    "after the",
    "before the",
    # Citations and references
    "marx argued",
    "lenin wrote",
    "engels noted",
    "gramsci",
    "according to",
    "as marx",
    "as lenin",
    # Concrete examples
    "for example",
    "such as",
    "in the case of",
    "consider",
    # Definitions
    "defined as",
    "meaning",
    "specifically",
}

In [None]:
def match_format_exactly(completions, **kwargs):
    """
    Reward +3.0 if response contains proper </think> tag.
    This encourages the model to use the reasoning format.
    """
    scores = []
    for completion in completions:
        score = 0.0
        response = completion[0]["content"]
        if SOLUTION_END_REGEX.search(response) is not None:
            score += 3.0
        scores.append(score)
    return scores


def match_format_approximately(completions, **kwargs):
    """
    Reward partial format matching.
    +0.5 for exactly one <think> tag
    +0.5 for exactly one </think> tag
    -1.0 instead for each tag that is missing or repeated
    """
    scores = []
    for completion in completions:
        score = 0.0
        response = completion[0]["content"]

        start_count = response.count(REASONING_START)
        end_count = response.count(REASONING_END)

        score += 0.5 if start_count == 1 else -1.0
        score += 0.5 if end_count == 1 else -1.0

        scores.append(score)
    return scores
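
A quick way to see how the tag-counting reward behaves is to run it on hand-written completions. The sketch below is a compact standalone copy of the scoring rule (`tag_score` is an illustrative name) applied to the nested structure the reward functions receive: one list per completion, holding a single assistant message.

```python
REASONING_START = "<think>"
REASONING_END = "</think>"


def tag_score(response: str) -> float:
    """+0.5 per tag that appears exactly once, -1.0 per tag otherwise."""
    score = 0.0
    score += 0.5 if response.count(REASONING_START) == 1 else -1.0
    score += 0.5 if response.count(REASONING_END) == 1 else -1.0
    return score


completions = [
    [{"role": "assistant", "content": "<think>Weigh the evidence.</think>Final answer."}],
    [{"role": "assistant", "content": "Weigh the evidence.</think>Final answer."}],
    [{"role": "assistant", "content": "Final answer with no reasoning tags."}],
]

scores = [tag_score(completion[0]["content"]) for completion in completions]
print(scores)
```

The second case matters in practice: if the chat template's generation prompt already ends with `<think>` (worth checking in the template printout above), completions start after the opening tag and would be penalized for "missing" it.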

In [None]:
def semantic_similarity_reward(prompts, completions, answer, **kwargs):
    """
    Reward responses that are semantically similar to ground truth.
    Uses sentence-transformers to compute cosine similarity.

    Scoring:
        > 0.75 similarity: +5.0
        > 0.60 similarity: +3.0
        > 0.45 similarity: +1.0
        > 0.30 similarity: -1.0
        <= 0.30 similarity: -3.0
    """
    embedder = get_embedder()
    scores = []

    for completion, true_answer in zip(completions, answer, strict=False):
        response = completion[0]["content"]

        # Extract answer after </think> if present
        if REASONING_END in response:
            response = response.split(REASONING_END, 1)[1].strip()

        # Handle empty response
        if not response or len(response.strip()) < 10:
            scores.append(-3.0)
            continue

        # Compute cosine similarity
        emb_response = embedder.encode(response, normalize_embeddings=True)
        emb_truth = embedder.encode(true_answer, normalize_embeddings=True)
        similarity = float(np.dot(emb_response, emb_truth))

        # Scale to reward
        if similarity > 0.75:
            score = 5.0
        elif similarity > 0.60:
            score = 3.0
        elif similarity > 0.45:
            score = 1.0
        elif similarity > 0.30:
            score = -1.0
        else:
            score = -3.0

        scores.append(score)

    return scores
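
Because the embeddings are encoded with `normalize_embeddings=True`, the dot product in the function above equals cosine similarity. The sketch below spells out that relationship and the bucket mapping with plain Python lists in place of real embeddings. `cosine` and `similarity_to_reward` are illustrative names.

```python
import math


def cosine(u, v):
    """Cosine similarity; equals the dot product when both vectors have unit length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def similarity_to_reward(similarity: float) -> float:
    """Same thresholds as semantic_similarity_reward (strict greater-than)."""
    if similarity > 0.75:
        return 5.0
    if similarity > 0.60:
        return 3.0
    if similarity > 0.45:
        return 1.0
    if similarity > 0.30:
        return -1.0
    return -3.0


for sim in (0.80, 0.75, 0.50, 0.30):
    print(f"similarity {sim:.2f} -> reward {similarity_to_reward(sim):+.1f}")
```

The comparisons are strict, so a similarity of exactly 0.75 lands in the +3.0 bucket and exactly 0.30 in the -3.0 bucket.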

In [None]:
def terminology_reward(completions, **kwargs):
    """
    Reward use of proper Marxist terminology.
    +0.3 per unique term found, capped at +2.0
    """
    scores = []

    for completion in completions:
        response = completion[0]["content"].lower()

        # Count unique terms present
        term_count = sum(1 for term in MARXIST_TERMS if term in response)

        # Reward: 0.3 per term, capped at 2.0
        score = min(term_count * 0.3, 2.0)
        scores.append(score)

    return scores


def completeness_reward(completions, answer, **kwargs):
    """
    Reward thorough, detailed responses.
    Compares response length to ground truth length.

    Scoring:
        50-150% of target length: +2.0
        30-50% or 150-200% of target length: +1.0
        < 20% (too short): -2.0
        20-30% (short) or > 200% (too verbose): -0.5
    """
    scores = []

    for completion, true_answer in zip(completions, answer, strict=False):
        response = completion[0]["content"]

        # Extract answer after </think>
        if REASONING_END in response:
            answer_part = response.split(REASONING_END, 1)[1].strip()
        else:
            answer_part = response

        answer_len = len(answer_part.split())
        true_len = len(true_answer.split())

        if true_len == 0:
            scores.append(0.0)
            continue

        ratio = answer_len / true_len

        if 0.5 <= ratio <= 1.5:
            score = 2.0
        elif 0.3 <= ratio <= 2.0:
            score = 1.0
        elif ratio < 0.2:
            score = -2.0
        else:
            score = -0.5

        scores.append(score)

    return scores
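
The branch order in `completeness_reward` means that ratios between 0.2 and 0.3 fall through to the final `else` and receive the same -0.5 as overly long answers. The sketch below isolates the mapping from length ratio to score so the boundaries are easy to inspect. `length_ratio_reward` is an illustrative name.

```python
def length_ratio_reward(ratio: float) -> float:
    """Map answer_length / reference_length to a score, mirroring completeness_reward."""
    if 0.5 <= ratio <= 1.5:
        return 2.0
    if 0.3 <= ratio <= 2.0:
        return 1.0
    if ratio < 0.2:
        return -2.0
    return -0.5  # covers 0.2-0.3 and anything above 2.0


ratios = (1.0, 0.4, 1.8, 0.1, 0.25, 2.5)
rewards = [length_ratio_reward(r) for r in ratios]

for ratio, reward in zip(ratios, rewards):
    print(f"ratio {ratio:.2f} -> {reward:+.1f}")
```

Lengths are counted in whitespace-separated words, and only the text after `</think>` is measured, so long reasoning traces do not affect this score.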

In [None]:
# =============================================================================
# NLI-BASED COHERENCE REWARDS (Research-backed, defeats word soup)
# =============================================================================


def nli_coherence_reward(completions, answer, **kwargs):
    """
    Reward responses that logically ENTAIL the ground truth answer.

    Uses Natural Language Inference (facebook/bart-large-mnli) to check
    if the response is logically consistent with the expected answer.

    This is intended to defeat "word soup" attacks: random terminology is
    unlikely to entail the reference answer and will usually be classified
    as NEUTRAL.

    Scoring:
        entailment: +3.0 (response supports/implies ground truth)
        neutral: -1.0 (off-topic or incoherent)
        contradiction: -3.0 (contradicts ground truth)
        empty or under 20 characters: -2.0
        NLI error: 0.0
    """
    nli = get_nli_pipeline()
    scores = []

    for completion, true_answer in zip(completions, answer, strict=False):
        response = completion[0]["content"]

        # Extract answer part after </think>
        if REASONING_END in response:
            response = response.split(REASONING_END, 1)[1].strip()

        # Handle empty or very short responses
        if not response or len(response.strip()) < 20:
            scores.append(-2.0)
            continue

        # Truncate by characters (a rough proxy; the model's limit is in tokens)
        response_truncated = response[:512]
        truth_truncated = true_answer[:512]

        try:
            input_text = f"{response_truncated}</s></s>{truth_truncated}"
            result = nli(input_text)[0]
            label = result["label"].lower()

            if label == "entailment":
                score = 3.0
            elif label == "neutral":
                score = -1.0
            else:  # contradiction
                score = -3.0

            scores.append(score)
        except Exception as e:
            print(f"[NLI Reward] Error: {e}")
            scores.append(0.0)

    return scores


def self_consistency_reward(completions, **kwargs):
    """
    Reward responses that are internally self-consistent.

    Checks whether nearby sentences CONTRADICT each other: each sentence is
    compared with the next two, stopping after roughly 10 pairs.
    This avoids external ideological bias by only checking within-document
    coherence.

    Scoring:
        No contradictions found: +1.0
        Internal contradiction detected: -2.0
        Fewer than 2 usable sentences: 0.0
    """
    nli = get_nli_pipeline()
    nlp = get_spacy_nlp()
    scores = []

    for completion in completions:
        response = completion[0]["content"]

        # Parse into sentences
        doc = nlp(response)
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]

        # Need at least 2 sentences to check consistency
        if len(sentences) < 2:
            scores.append(0.0)
            continue

        # Check pairs of sentences for contradictions
        has_contradiction = False
        max_pairs_to_check = 10
        pairs_checked = 0

        for i, sent_a in enumerate(sentences[:-1]):
            if pairs_checked >= max_pairs_to_check:
                break
            for j in range(i + 1, min(i + 3, len(sentences))):
                sent_b = sentences[j]
                try:
                    input_text = f"{sent_a[:256]}</s></s>{sent_b[:256]}"
                    result = nli(input_text)[0]
                    if result["label"].lower() == "contradiction":
                        has_contradiction = True
                        break
                    pairs_checked += 1
                except Exception:
                    pass
            if has_contradiction:
                break

        scores.append(-2.0 if has_contradiction else 1.0)

    return scores


def structural_coherence_reward(completions, **kwargs):
    """
    Reward responses with proper linguistic structure.

    Uses spaCy dependency parsing to verify:
    1. Marxist terms appear in meaningful syntactic roles (subject, object)
    2. Response contains logical discourse connectives
    3. Response has proper sentence structure (not word soup)

    Scoring:
        +0.3 per term in subject/object position (max +1.5)
        +0.2 per discourse connective (max +1.0)
        -1.0 if no complete sentences detected
    """
    nlp = get_spacy_nlp()
    scores = []

    for completion in completions:
        response = completion[0]["content"]
        doc = nlp(response)
        score = 0.0

        # Check 1: Are there actual sentences?
        sentences = list(doc.sents)
        if len(sentences) < 1:
            scores.append(-1.0)
            continue

        # Check 2: Marxist terms in meaningful syntactic roles
        terms_in_context = 0
        response_lower = response.lower()

        for term in MARXIST_TERMS:
            if term not in response_lower:
                continue

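            for token in doc:
                # Reward if the term is found in a meaningful syntactic role
                # or attached directly to the root verb.
                # Note: the match is per token, so multi-word terms such as
                # "surplus value" never match here; only single-word terms count.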
                is_meaningful_role = token.dep_ in (
                    "nsubj",
                    "nsubjpass",
                    "dobj",
                    "pobj",
                    "attr",
                    "appos",
                )
                is_verb_root = token.head.pos_ == "VERB" and token.head.dep_ == "ROOT"
                if term in token.text.lower() and (is_meaningful_role or is_verb_root):
                    terms_in_context += 1
                    break

        score += min(terms_in_context * 0.3, 1.5)

        # Check 3: Discourse connectives
        connective_count = sum(1 for conn in DISCOURSE_CONNECTIVES if conn in response_lower)
        score += min(connective_count * 0.2, 1.0)

        scores.append(score)

    return scores


def robust_coherence_reward(completions, answer, **kwargs):
    """
    Multi-layered coherence check combining NLI, self-consistency, and structure.

    This is the RECOMMENDED reward function for robust evaluation that defeats
    reward hacking via word soup or other adversarial strategies.

    Layers:
    1. NLI coherence: Does response entail ground truth?
    2. Self-consistency: Does response contradict itself?
    3. Structural coherence: Are terms used in meaningful syntactic roles?
    """
    nli_scores = nli_coherence_reward(completions, answer, **kwargs)
    consistency_scores = self_consistency_reward(completions, **kwargs)
    structure_scores = structural_coherence_reward(completions, **kwargs)

    combined = []
    for nli, consistency, structure in zip(
        nli_scores, consistency_scores, structure_scores, strict=False
    ):
        if nli <= -3.0:
            combined.append(-3.0)  # Contradiction dominates
        elif consistency <= -2.0:
            combined.append(-2.0)  # Internal contradiction
        else:
            total = nli + (consistency * 0.5) + (structure * 0.5)
            combined.append(total)

    return combined
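
The combination rule in `robust_coherence_reward` can be checked without loading any models. The sketch below isolates the rule in a standalone function (`combine` is an illustrative name) and evaluates it on a few hand-picked component scores.

```python
def combine(nli: float, consistency: float, structure: float) -> float:
    """Mirror of the layering rule: hard failures dominate, otherwise a weighted sum."""
    if nli <= -3.0:
        return -3.0  # contradiction with the reference answer dominates
    if consistency <= -2.0:
        return -2.0  # internal contradiction
    return nli + 0.5 * consistency + 0.5 * structure


cases = {
    "best case": (3.0, 1.0, 2.5),
    "contradicts reference": (-3.0, 1.0, 2.5),
    "contradicts itself": (3.0, -2.0, 2.5),
    "neutral, plain prose": (-1.0, 1.0, 0.0),
}

for name, parts in cases.items():
    print(f"{name:24s} -> {combine(*parts):+.2f}")
```

With the component caps used above (+3.0, +1.0, and +2.5), the combined score tops out at +4.75, and either kind of contradiction overrides any credit earned for structure.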

In [None]:
# =============================================================================
# INTERCONNECTION DEPTH REWARD (Anti-Buzzword-Salad)
# =============================================================================


def _compute_depth_ratio(text: str) -> float:
    """Compute words per unique Marxist concept (depth over breadth)."""
    text_lower = text.lower()
    words = text_lower.split()
    word_count = len(words)

    if word_count < 20:
        return 0.0

    concepts_found = sum(1 for term in MARXIST_TERMS if term in text_lower)

    if concepts_found == 0:
        return float(word_count)

    return word_count / concepts_found


def _count_hollow_buzzwords(text: str) -> int:
    """Count hollow buzzwords in text."""
    text_lower = text.lower()
    return sum(1 for phrase in HOLLOW_BUZZWORDS if phrase in text_lower)


def _count_depth_markers(text: str) -> int:
    """Count depth markers (historical specificity, citations, examples)."""
    text_lower = text.lower()
    return sum(1 for marker in DEPTH_MARKERS if marker in text_lower)


def _count_explanatory_phrases(text: str) -> int:
    """Count explanatory phrases."""
    text_lower = text.lower()
    return sum(1 for phrase in EXPLANATORY_PHRASES if phrase in text_lower)


def interconnection_depth_reward(completions, **kwargs):
    """
    Reward deep, meaningful interconnections; penalize buzzword salad.

    Distinguishes:
    - GOOD: "Surplus value relates to imperialism BECAUSE capital export..."
    - BAD: "Surplus value intersects with imperialism, colonialism, patriarchy..."

    Signals:
    1. Depth ratio: words per unique Marxist concept
       - > 20: +1.0 (deep analysis - few concepts well-explained)
       - > 10: +0.5 (adequate depth)
       - > 5: -0.5 (shallow)
       - <= 5: -1.5 (severe buzzword soup)
       - answers under 20 words get no depth score
    2. Hollow buzzwords: -0.3 per buzzword beyond the first 2 (max -1.5)
    3. Depth markers: +0.3 each (max +1.5)
    4. Explanation ratio (explanatory phrases per concept):
       +0.5 if >= 0.5; -0.5 if < 0.1 with more than 3 concepts

    Score range: -3.5 to +3.0
    """
    scores = []

    for completion in completions:
        response = completion[0]["content"]

        # Extract answer after </think>
        if REASONING_END in response:
            response = response.split(REASONING_END, 1)[1].strip()

        if not response or len(response.strip()) < 20:
            scores.append(0.0)
            continue

        score = 0.0

        # Signal 1: Depth ratio (words per concept)
        depth_ratio = _compute_depth_ratio(response)

        if depth_ratio > 20:
            score += 1.0  # Deep analysis
        elif depth_ratio > 10:
            score += 0.5  # Adequate depth
        elif depth_ratio > 5:
            score += -0.5  # Shallow
        elif depth_ratio > 0:
            score += -1.5  # Severe buzzword soup

        # Signal 2: Hollow buzzword penalty
        hollow_count = _count_hollow_buzzwords(response)
        if hollow_count > 2:
            penalty = min((hollow_count - 2) * 0.3, 1.5)
            score -= penalty

        # Signal 3: Depth markers bonus
        depth_marker_count = _count_depth_markers(response)
        score += min(depth_marker_count * 0.3, 1.5)

        # Signal 4: Explanation ratio
        text_lower = response.lower()
        concepts_found = sum(1 for term in MARXIST_TERMS if term in text_lower)
        explanations = _count_explanatory_phrases(response)

        if concepts_found > 0:
            explanation_ratio = explanations / concepts_found
            if explanation_ratio >= 0.5:
                score += 0.5  # Well-explained concepts
            elif explanation_ratio < 0.1 and concepts_found > 3:
                score -= 0.5  # Many concepts, no explanations

        scores.append(score)

    return scores
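
The depth-ratio signal is the least intuitive part of this reward, so a worked example helps. The sketch below reproduces the ratio and its tiers as standalone functions that take word and concept counts directly. `depth_ratio` and `depth_score` are illustrative names.

```python
def depth_ratio(word_count: int, concepts: int) -> float:
    """Words per unique concept, mirroring _compute_depth_ratio."""
    if word_count < 20:
        return 0.0
    if concepts == 0:
        return float(word_count)
    return word_count / concepts


def depth_score(ratio: float) -> float:
    """Tiered score for the depth ratio."""
    if ratio > 20:
        return 1.0
    if ratio > 10:
        return 0.5
    if ratio > 5:
        return -0.5
    if ratio > 0:
        return -1.5
    return 0.0  # short answers (ratio 0.0) get no depth score


examples = [(120, 4), (60, 4), (40, 10), (15, 3), (50, 0)]
results = [depth_score(depth_ratio(words, concepts)) for words, concepts in examples]

for (words, concepts), score in zip(examples, results):
    print(f"{words:3d} words, {concepts:2d} concepts -> {score:+.1f}")
```

The last example shows a side effect worth knowing: an answer of 20 or more words that uses none of the listed terms is treated as having a high ratio and earns the depth bonus, so this signal works best alongside the topic and NLI rewards.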

In [None]:
# =============================================================================
# TOPIC RELEVANCE REWARD (Question-Answer Alignment)
# =============================================================================


def _extract_noun_with_preps(token):
    """
    Extract a noun and its prepositional phrase children.
    For "dictatorship of the proletariat", returns:
    {"dictatorship", "proletariat", "dictatorship of proletariat"}
    """
    topics = set()

    if token.pos_ in ("NOUN", "PROPN"):
        topics.add(token.lemma_.lower())

        # Check for compound modifiers (e.g., "surplus value")
        modifiers = []
        for child in token.children:
            if child.dep_ in ("compound", "amod") and child.pos_ in ("NOUN", "ADJ"):
                modifiers.append(child.text.lower())

        if modifiers:
            full_term = " ".join([*modifiers, token.text.lower()])
            topics.add(full_term)

        # Follow prepositional phrases
        for child in token.children:
            if child.dep_ == "prep":
                for pobj in child.children:
                    if pobj.dep_ == "pobj":
                        topics.add(pobj.lemma_.lower())
                        full_phrase = f"{token.text.lower()} {child.text} {pobj.text.lower()}"
                        topics.add(full_phrase)
                        topics.update(_extract_noun_with_preps(pobj))

    return topics


def _extract_question_topics(doc):
    """
    Extract core topics from a question using spaCy dependency parsing.
    For "What is revisionism?", extracts {"revisionism"}
    """
    topics = set()

    # Find the ROOT
    root = None
    for token in doc:
        if token.dep_ == "ROOT":
            root = token
            break

    if root:
        for child in root.children:
            if child.dep_ in ("nsubj", "dobj", "attr", "nsubjpass"):
                if child.text.lower() in QUESTION_WORDS:
                    continue
                topics.update(_extract_noun_with_preps(child))

            if child.dep_ == "prep":
                for pobj in child.children:
                    if pobj.dep_ == "pobj":
                        topics.update(_extract_noun_with_preps(pobj))

    # Fallback: noun chunks
    if not topics:
        for chunk in doc.noun_chunks:
            if chunk.root.text.lower() not in QUESTION_WORDS:
                topics.add(chunk.root.lemma_.lower())

    topics = {t for t in topics if t not in QUESTION_WORDS}
    return topics


def _extract_answer_topics(doc):
    """Extract topics from an answer. Strips determiners for better matching."""
    topics = set()
    determiners = {"the", "a", "an", "this", "that", "these", "those"}

    for chunk in doc.noun_chunks:
        topics.add(chunk.root.lemma_.lower())

        words = chunk.text.lower().strip().split()
        if words and words[0] in determiners:
            words = words[1:]
        chunk_text = " ".join(words)

        if " " in chunk_text and len(chunk_text) < 50:
            topics.add(chunk_text)

    for ent in doc.ents:
        words = ent.text.lower().split()
        if words and words[0] in determiners:
            words = words[1:]
        topics.add(" ".join(words))

    return topics


def _expand_with_synonyms(topics):
    """Expand topics with Marxist concept synonyms."""
    expanded = set(topics)

    for topic in topics:
        if topic in CONCEPT_EQUIVALENCES:
            expanded.update(CONCEPT_EQUIVALENCES[topic])
        for canonical, synonyms in CONCEPT_EQUIVALENCES.items():
            if topic in synonyms or topic == canonical:
                expanded.add(canonical)
                expanded.update(synonyms)

    return expanded


def _compute_topic_coverage(q_topics, a_topics, nlp):
    """Compute the fraction of question topics covered by the answer topics."""
    if not q_topics:
        return 0.5

    # A question topic counts as covered when it, or any of its synonyms,
    # appears among the answer topics. Counting per question topic keeps
    # several synonyms of a single topic from inflating the score.
    covered = {t for t in q_topics if _expand_with_synonyms({t}) & a_topics}
    direct_coverage = len(covered) / len(q_topics)

    if direct_coverage >= 0.5:
        return direct_coverage

    # Fallback: semantic similarity for topics without a lexical match
    unmatched_q = q_topics - covered
    # Parse each answer topic once instead of once per question topic
    a_docs = [d for d in (nlp(a_topic) for a_topic in a_topics) if d.has_vector]
    semantic_matches = 0

    for q_topic in unmatched_q:
        q_doc = nlp(q_topic)
        if not q_doc.has_vector:
            continue

        best_sim = max((q_doc.similarity(a_doc) for a_doc in a_docs), default=0.0)

        if best_sim > 0.6:
            semantic_matches += 1

    total_matched = len(covered) + semantic_matches
    return min(total_matched / len(q_topics), 1.0)


def topic_relevance_reward(prompts, completions, **kwargs):
    """
    Reward answers that are ON-TOPIC with respect to the question.

    Checks that f(Q) is covered by f(A), where f extracts semantic topics.
    This ensures the model answers WHAT WAS ASKED, not just generates
    coherent Marxist text about something else.

    Scoring:
        > 80% coverage: +2.0 (answer fully addresses question topics)
        > 60% coverage: +1.5 (answer mostly on-topic)
        > 40% coverage: +1.0 (answer partially on-topic)
        > 20% coverage: 0.0 (answer tangentially related)
        <= 20% coverage: -1.5 (answer off-topic)
    """
    nlp = get_spacy_nlp()
    scores = []

    for prompt, completion in zip(prompts, completions, strict=False):
        question = prompt[-1]["content"]
        response = completion[0]["content"]

        if REASONING_END in response:
            response = response.split(REASONING_END, 1)[1].strip()

        if not response or len(response.strip()) < 20:
            scores.append(-1.5)
            continue

        q_doc = nlp(question)
        a_doc = nlp(response[:2000])

        q_topics = _extract_question_topics(q_doc)
        a_topics = _extract_answer_topics(a_doc)

        if not q_topics:
            scores.append(0.5 if len(a_topics) > 3 else 0.0)
            continue

        coverage = _compute_topic_coverage(q_topics, a_topics, nlp)

        if coverage > 0.8:
            score = 2.0
        elif coverage > 0.6:
            score = 1.5
        elif coverage > 0.4:
            score = 1.0
        elif coverage > 0.2:
            score = 0.0
        else:
            score = -1.5

        scores.append(score)

    return scores


def full_coherence_reward(prompts, completions, answer, **kwargs):
    """
    Complete coherence check: robust_coherence + topic_relevance + depth.

    This is the MOST COMPREHENSIVE reward function, checking:
    1. NLI coherence (A entails ground truth)
    2. Self-consistency (A doesn't contradict itself)
    3. Structural coherence (terms in proper syntactic roles)
    4. Topic relevance (A addresses what Q asked about)
    5. Interconnection depth (rewards deep analysis, penalizes buzzword salad)

    Use this for maximum robustness against reward hacking.
    """
    robust_scores = robust_coherence_reward(completions, answer, **kwargs)
    relevance_scores = topic_relevance_reward(prompts, completions, **kwargs)
    depth_scores = interconnection_depth_reward(completions, **kwargs)

    combined = []
    for robust, relevance, depth in zip(
        robust_scores, relevance_scores, depth_scores, strict=False
    ):
        if relevance <= -1.5:
            combined.append(-2.0)  # Severely off-topic
        elif robust <= -2.0:
            combined.append(robust)  # Robust check failed
        elif depth <= -1.5:
            combined.append(-1.5)  # Buzzword salad detected
        else:
            total = robust + (relevance * 0.4) + (depth * 0.3)
            combined.append(total)

    return combined

In [None]:
# Debug reward function for monitoring during training
_PRINT_COUNTER = 0
_PRINT_EVERY = 10


def debug_print_reward(prompts, completions, answer, **kwargs):
    """
    Print sample outputs periodically for monitoring.
    Returns 0.0 for every completion (no effect on training).
    """
    global _PRINT_COUNTER

    if _PRINT_COUNTER % _PRINT_EVERY == 0:
        question = prompts[0][-1]["content"]
        response = completions[0][0]["content"]
        true_answer = answer[0]

        print("=" * 60)
        print(f"Step {_PRINT_COUNTER}")
        print(f"Question: {question[:100]}...")
        print(f"Response: {response[:200]}...")
        print(f"Expected: {true_answer[:100]}...")
        print("=" * 60)

    _PRINT_COUNTER += 1

    return [0.0] * len(completions)

In [None]:
# =============================================================================
# WEIGHTS & BIASES LOGGING
# =============================================================================

USE_WANDB = False  # Set to True to enable W&B logging

# Lazy wandb import
_wandb_module = None


def _get_wandb():
    """Lazily import wandb module."""
    global _wandb_module
    if _wandb_module is None:
        try:
            import wandb

            _wandb_module = wandb
        except ImportError:
            print("[WandbLogging] wandb not installed. Install with: pip install wandb")
            _wandb_module = False
    return _wandb_module if _wandb_module else None


class WandbSampleLogger:
    """Logs sample tables to W&B for debugging reward functions."""

    def __init__(self, log_every_n_steps=10, max_samples_per_log=4):
        self.log_every_n_steps = log_every_n_steps
        self.max_samples_per_log = max_samples_per_log
        self._samples = []
        self._table_columns = [
            "step",
            "question",
            "response",
            "ground_truth",
            "format_exact",
            "nli_coherence",
            "topic_relevance",
            "depth",
            "completeness",
            "total",
        ]

    def add_sample(self, step, question, response, ground_truth, rewards):
        """Add a sample to the buffer."""
        self._samples.append(
            {
                "step": step,
                "question": question[:500],
                "response": response[:500],
                "ground_truth": ground_truth[:300],
                "rewards": rewards,
            }
        )
        # Keep only recent samples
        max_buffer = self.max_samples_per_log * 3
        if len(self._samples) > max_buffer:
            self._samples = self._samples[-max_buffer:]

    def should_log(self, step):
        """Check if we should log at this step."""
        return step > 0 and step % self.log_every_n_steps == 0

    def log_table(self, step):
        """Log accumulated samples as a wandb.Table."""
        wandb = _get_wandb()
        if wandb is None or not self._samples:
            return

        samples_to_log = self._samples[-self.max_samples_per_log :]
        table = wandb.Table(columns=self._table_columns)

        for sample in samples_to_log:
            rewards = sample["rewards"]
            total = sum(rewards.values())
            row = [
                sample["step"],
                sample["question"],
                sample["response"],
                sample["ground_truth"],
                rewards.get("format_exact", 0.0),
                rewards.get("nli_coherence", 0.0),
                rewards.get("topic_relevance", 0.0),
                rewards.get("interconnection_depth", 0.0),
                rewards.get("completeness", 0.0),
                total,
            ]
            table.add_data(*row)

        wandb.log({"samples": table}, step=step)
        print(f"[WandbLogging] Logged {len(samples_to_log)} samples at step {step}")

    def clear(self):
        """Clear the sample buffer."""
        self._samples.clear()


def log_reward_metrics(step, reward_scores):
    """Log reward metrics to wandb."""
    wandb = _get_wandb()
    if wandb is None:
        return

    metrics = {}
    for name, scores in reward_scores.items():
        if not scores:
            continue
        metrics[f"rewards/{name}"] = sum(scores) / len(scores)
        metrics[f"rewards/{name}_min"] = min(scores)
        metrics[f"rewards/{name}_max"] = max(scores)

    # Compute total
    if reward_scores:
        all_totals = []
        num_samples = len(next(iter(reward_scores.values())))
        for i in range(num_samples):
            total = sum(scores[i] for scores in reward_scores.values() if i < len(scores))
            all_totals.append(total)
        if all_totals:
            metrics["rewards/total"] = sum(all_totals) / len(all_totals)

    wandb.log(metrics, step=step)


# Global step counter for logging reward
_WANDB_LOGGING_STEP = 0


def create_logging_reward(sample_logger=None, compute_all_rewards=True):
    """
    Create a reward function that logs metrics and samples to wandb.
    The returned function always yields [0.0] * len(completions), so it has
    no effect on training.
    """
    global _WANDB_LOGGING_STEP

    def logging_reward(prompts, completions, answer, **kwargs):
        global _WANDB_LOGGING_STEP
        _WANDB_LOGGING_STEP += 1
        step = _WANDB_LOGGING_STEP

        wandb = _get_wandb()
        if wandb is None or wandb.run is None:
            return [0.0] * len(completions)

        # Compute all reward scores if requested
        if compute_all_rewards:
            reward_scores = {}
            try:
                reward_scores["format_exact"] = match_format_exactly(completions, **kwargs)
            except Exception:
                reward_scores["format_exact"] = [0.0] * len(completions)
            try:
                reward_scores["nli_coherence"] = nli_coherence_reward(completions, answer, **kwargs)
            except Exception:
                reward_scores["nli_coherence"] = [0.0] * len(completions)
            try:
                reward_scores["topic_relevance"] = topic_relevance_reward(
                    prompts, completions, **kwargs
                )
            except Exception:
                reward_scores["topic_relevance"] = [0.0] * len(completions)
            try:
                reward_scores["interconnection_depth"] = interconnection_depth_reward(
                    completions, **kwargs
                )
            except Exception:
                reward_scores["interconnection_depth"] = [0.0] * len(completions)
            try:
                reward_scores["completeness"] = completeness_reward(completions, answer, **kwargs)
            except Exception:
                reward_scores["completeness"] = [0.0] * len(completions)

            log_reward_metrics(step, reward_scores)
        else:
            reward_scores = {}

        # Log samples periodically
        if sample_logger and sample_logger.should_log(step):
            for i in range(min(sample_logger.max_samples_per_log, len(prompts))):
                question = prompts[i][-1]["content"]
                response = completions[i][0]["content"]
                truth = answer[i] if i < len(answer) else ""

                sample_rewards = {
                    name: scores[i] if i < len(scores) else 0.0
                    for name, scores in reward_scores.items()
                }

                sample_logger.add_sample(
                    step=step,
                    question=question,
                    response=response,
                    ground_truth=truth,
                    rewards=sample_rewards,
                )

            sample_logger.log_table(step)

        return [0.0] * len(completions)

    return logging_reward


def init_wandb_logging(project, config, name=None, tags=None):
    """Initialize W&B logging for GRPO training."""
    wandb = _get_wandb()
    if wandb is None:
        return None

    run = wandb.init(
        project=project,
        config=config,
        name=name,
        tags=tags or ["grpo", "marxist-leninist"],
    )
    print(f"[WandbLogging] Initialized run: {run.name}")
    print(f"[WandbLogging] View at: {run.url}")
    return run


def finish_wandb_logging(summary=None):
    """Finish the wandb run with optional summary statistics."""
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    if summary:
        for key, value in summary.items():
            wandb.run.summary[key] = value

    wandb.finish()
    print("[WandbLogging] Run finished.")

## Weights & Biases Logging (Optional)

Enable W&B logging for comprehensive training observability:
- Per-step reward metrics (mean/min/max for each reward function)
- Sample tables showing question → response → reward breakdowns
- Training curves for debugging reward function behavior

**To enable:** Run `wandb login`, then set `USE_WANDB = True` in the W&B logging cell.
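
The per-step metrics listed above reduce each reward function's scores to a mean, minimum, and maximum. A minimal sketch of that aggregation, with no W&B dependency, might look like the following. `summarize_rewards` and the sample scores are hypothetical and exist only for illustration; the `rewards/...` key layout mirrors the logging helper in this notebook.

```python
def summarize_rewards(reward_scores):
    """Reduce {reward_name: [score, ...]} to mean/min/max metrics."""
    metrics = {}
    for name, scores in reward_scores.items():
        if not scores:  # skip reward functions that produced no scores
            continue
        metrics[f"rewards/{name}"] = sum(scores) / len(scores)
        metrics[f"rewards/{name}_min"] = min(scores)
        metrics[f"rewards/{name}_max"] = max(scores)
    return metrics


# Hypothetical scores for one group of four completions
example = {
    "format_exact": [3.0, 3.0, 0.0, 3.0],
    "topic_relevance": [2.0, 1.5, -1.5, 1.0],
}
summary = summarize_rewards(example)
print(summary["rewards/format_exact"])         # 2.25
print(summary["rewards/topic_relevance_min"])  # -1.5
```

A wide gap between the minimum and maximum of one reward within a step is usually the first thing worth inspecting in the sample table.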

## Training Configuration

Configure GRPO training with A40-optimized settings.
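
As a quick orientation for the schedule settings, the sketch below works out how many warmup steps `WARMUP_RATIO` implies and what a linear warmup followed by linear decay looks like. The constants are copied from the configuration cell in this notebook. `linear_lr` is a hypothetical helper, and the formula is an assumption about how a `"linear"` scheduler with warmup typically behaves rather than a copy of the trainer's code.

```python
import math

# Values copied from the configuration cell in this notebook
MAX_STEPS = 250
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.1

WARMUP_STEPS = math.ceil(MAX_STEPS * WARMUP_RATIO)


def linear_lr(step, peak=LEARNING_RATE, total=MAX_STEPS, warmup=WARMUP_STEPS):
    """Learning rate at `step`: linear ramp up to `peak`, then linear decay to 0."""
    if step < warmup:
        return peak * step / max(1, warmup)
    return peak * max(0.0, (total - step) / max(1, total - warmup))


print(WARMUP_STEPS)  # 25
for step in (0, 25, 250):
    print(step, linear_lr(step))
```

Under these assumptions the learning rate peaks at step 25 and reaches zero at step 250, so extending `MAX_STEPS` also stretches the warmup.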

In [None]:
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

# Training hyperparameters
MAX_STEPS = 250
SAVE_STEPS = 50
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.1

# A40 optimized batch settings
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 2
NUM_GENERATIONS = 4

# Sequence lengths
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 1500

# Output paths
OUTPUT_DIR = "outputs/marxist-grpo"
LORA_OUTPUT = "outputs/marxist-grpo-lora"

In [None]:
# vLLM sampling parameters (temperature is set in GRPOConfig)
vllm_sampling_params = SamplingParams(
    min_p=0.1,
    top_p=1.0,  # No nucleus sampling
    top_k=-1,
    seed=3407,
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
)

# GRPO training configuration
training_args = GRPOConfig(
    # vLLM
    vllm_sampling_params=vllm_sampling_params,
    temperature=1.0,  # For GRPO training dynamics
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=0.001,
    warmup_ratio=WARMUP_RATIO,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    # Batch settings
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    num_generations=NUM_GENERATIONS,
    # Sequence lengths
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    # Training duration
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    # Logging
    logging_steps=1,
    report_to="none",
    # Output
    output_dir=OUTPUT_DIR,
)

In [None]:
# Three reward configurations available:
# 1. FULL (recommended) - Uses robust_coherence + topic_relevance + depth (maximum robustness)
# 2. ROBUST - Uses NLI + self-consistency + structural analysis
# 3. LEGACY - Original shallow rewards (vulnerable to word soup)

REWARD_MODE = "FULL"  # Options: "FULL", "ROBUST", "LEGACY"

# Initialize wandb logging if enabled
wandb_run = None
sample_logger = None
logging_reward_func = None

if USE_WANDB:
    wandb_run = init_wandb_logging(
        project="marxist-grpo",
        config={
            "model": MODEL_NAME,
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "gradient_accumulation": GRADIENT_ACCUMULATION,
            "num_generations": NUM_GENERATIONS,
            "max_steps": MAX_STEPS,
            "reward_mode": REWARD_MODE,
        },
        tags=["grpo", "marxist-leninist", "prolewiki"],
    )
    sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
    logging_reward_func = create_logging_reward(sample_logger, compute_all_rewards=True)
    print("[W&B] Logging enabled - reward metrics and sample tables will be logged")

if REWARD_MODE == "FULL":
    print("Initializing GRPO trainer with FULL reward functions:")
    print("  - match_format_exactly (+3.0 for </think>)")
    print("  - match_format_approximately (tag validation)")
    print("  - full_coherence_reward (NLI + structure + topic + depth)")
    print("  - completeness_reward (length comparison)")
    print("  - debug_print_reward (monitoring)")
    if USE_WANDB:
        print("  - logging_reward (W&B metrics + sample tables)")
    print("\nNote: First run will download NLI model (~1.6GB) + spaCy transformer (~436MB)")

    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        full_coherence_reward,
        completeness_reward,
        debug_print_reward,
    ]
elif REWARD_MODE == "ROBUST":
    print("Initializing GRPO trainer with ROBUST reward functions:")
    print("  - match_format_exactly (+3.0 for </think>)")
    print("  - match_format_approximately (tag validation)")
    print("  - robust_coherence_reward (NLI + self-consistency + structure)")
    print("  - completeness_reward (length comparison)")
    print("  - debug_print_reward (monitoring)")
    if USE_WANDB:
        print("  - logging_reward (W&B metrics + sample tables)")
    print("\nNote: First run will download NLI model (~1.6GB)")

    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        robust_coherence_reward,
        completeness_reward,
        debug_print_reward,
    ]
else:  # LEGACY
    print("Initializing GRPO trainer with LEGACY reward functions:")
    print("  - match_format_exactly (+3.0 for </think>)")
    print("  - match_format_approximately (tag validation)")
    print("  - semantic_similarity_reward (+5.0 to -3.0)")
    print("  - terminology_reward (+0 to +2.0) [VULNERABLE TO WORD SOUP]")
    print("  - completeness_reward (length comparison)")
    print("  - debug_print_reward (monitoring)")
    if USE_WANDB:
        print("  - logging_reward (W&B metrics + sample tables)")

    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        semantic_similarity_reward,
        terminology_reward,
        completeness_reward,
        debug_print_reward,
    ]

# Add wandb logging reward if enabled
if USE_WANDB and logging_reward_func is not None:
    reward_funcs.append(logging_reward_func)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset,
)

## Train!

Run GRPO training. Monitor the `reward` column; it should trend upward over time, although individual steps are noisy.

**Expected behavior (rough guide):**
- Steps 0-50: Format rewards stabilize
- Steps 50-100: Semantic similarity improves
- Steps 100-250: Content quality improves

**Estimated time:** ~1-2 hours on A40
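
The `reward` column reports raw rewards, but GRPO updates the policy from each completion's reward relative to the other completions sampled for the same prompt. A minimal sketch of that normalization follows, assuming the common mean and standard deviation form. The function name, the `eps` value, and the use of the population standard deviation are illustrative assumptions, not the trainer's exact code.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std; implementations differ
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical total rewards for NUM_GENERATIONS = 4 completions of one prompt
adv = group_relative_advantages([4.5, 2.0, -2.0, 2.0])
print(adv)

# If every completion scores the same, the group carries no learning signal
flat = group_relative_advantages([1.0, 1.0, 1.0, 1.0])
print(flat)
```

This is why a reward that is constant across a group, such as the zero-valued debug and logging rewards, does not change the update.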

In [None]:
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Steps: {MAX_STEPS}")
print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
print(f"Learning rate: {LEARNING_RATE}")
if USE_WANDB:
    print(f"W&B Run: {wandb_run.name if wandb_run else 'N/A'}")
print()

trainer.train()

# Finish wandb logging
if USE_WANDB:
    finish_wandb_logging(
        summary={
            "final_step": MAX_STEPS,
            "reward_mode": REWARD_MODE,
        }
    )

## Save LoRA

Save the trained LoRA adapter.

In [None]:
import os

os.makedirs(LORA_OUTPUT, exist_ok=True)

model.save_lora(LORA_OUTPUT)
print(f"LoRA saved to: {LORA_OUTPUT}")

In [None]:
# Verify LoRA is actually trained (non-zero weights)
from safetensors import safe_open

with safe_open(f"{LORA_OUTPUT}/adapter_model.safetensors", framework="pt") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        # Count zero entries; an untrained LoRA B matrix is entirely zero
        n_zeros = (tensor == 0).sum().item()
        assert n_zeros != tensor.numel(), f"Layer {key} is all zeros!"

print("LoRA verification passed - all layers have non-zero weights")

## Inference Testing

Test the model with and without the trained LoRA adapter.
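
When comparing outputs, it can help to separate the reasoning from the final answer, as the reward functions do before scoring. The sketch below is one way to do that. `split_reasoning` is a hypothetical helper, and the tag values are assumed to match the `<think>` and `</think>` markers this notebook uses.

```python
REASONING_START = "<think>"  # assumed to match the notebook's tags
REASONING_END = "</think>"


def split_reasoning(text):
    """Return (reasoning, answer); reasoning is empty when no end tag is present."""
    if REASONING_END not in text:
        return "", text.strip()
    reasoning, answer = text.split(REASONING_END, 1)
    reasoning = reasoning.strip()
    # The opening tag may be missing if the chat template already emitted it
    if reasoning.startswith(REASONING_START):
        reasoning = reasoning[len(REASONING_START):].strip()
    return reasoning, answer.strip()


sample = "<think>Define the term first.</think>\nThe term refers to ..."
reasoning, answer = split_reasoning(sample)
print(repr(reasoning))
print(repr(answer))
```

Printing only the answer part makes the with-LoRA and without-LoRA outputs easier to compare side by side.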

In [None]:
TEST_QUESTIONS = [
    "What is revisionism in the Marxist sense?",
    "Explain the concept of surplus value.",
    "What is the dictatorship of the proletariat?",
    "How does dialectical materialism differ from idealism?",
]

test_sampling_params = SamplingParams(
    temperature=0.7,
    top_k=50,
    max_tokens=1024,
)

In [None]:
print("=" * 60)
print("TESTING WITHOUT LORA")
print("=" * 60)

for question in TEST_QUESTIONS[:2]:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    output = (
        model.fast_generate(text, sampling_params=test_sampling_params, lora_request=None)[0]
        .outputs[0]
        .text
    )

    print(f"\nQ: {question}")
    print(f"A: {output[:500]}...")

In [None]:
print("=" * 60)
print("TESTING WITH LORA")
print("=" * 60)

lora_request = model.load_lora(LORA_OUTPUT)  # Load the adapter once and reuse it

for question in TEST_QUESTIONS[:2]:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    output = (
        model.fast_generate(
            text,
            sampling_params=test_sampling_params,
            lora_request=lora_request,
        )[0]
        .outputs[0]
        .text
    )

    print(f"\nQ: {question}")
    print(f"A: {output[:500]}...")

## Save & Export

Options for saving and exporting the trained model.

In [None]:
# Save to 16bit for vLLM deployment
if False:
    model.save_pretrained_merged("model-16bit", tokenizer, save_method="merged_16bit")

# Save to 4bit
if False:
    model.save_pretrained_merged("model-4bit", tokenizer, save_method="merged_4bit")

In [None]:
# Save to GGUF for llama.cpp / Ollama
if False:
    # Q8_0 - Fast conversion, high quality
    model.save_pretrained_gguf("model-gguf", tokenizer)

if False:
    # Q4_K_M - Recommended balance of size/quality
    model.save_pretrained_gguf("model-gguf", tokenizer, quantization_method="q4_k_m")

if False:
    # Multiple quantizations at once
    model.save_pretrained_gguf(
        "model-gguf",
        tokenizer,
        quantization_method=["q4_k_m", "q8_0", "q5_k_m"],
    )

## Ollama Integration

To use with Ollama after GGUF export:

```bash
# Create Modelfile
cat > Modelfile << 'EOF'
FROM ./model-gguf-Q4_K_M.gguf

SYSTEM """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""

PARAMETER temperature 0.7
PARAMETER top_k 50
PARAMETER top_p 0.9
EOF

# Create and run
ollama create marxist-assistant -f Modelfile
ollama run marxist-assistant
```
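
Once the model is created, it can also be queried from Python through Ollama's local HTTP API. The sketch below assumes a default Ollama install listening on `localhost:11434` and the `marxist-assistant` model name from the commands above. `build_payload` and `ask` are hypothetical helper names.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # assumed default local endpoint


def build_payload(prompt, model="marxist-assistant"):
    """Build a non-streaming generate request body."""
    return {"model": model, "prompt": prompt, "stream": False}


def ask(prompt, url=OLLAMA_URL):
    """Send the prompt to a running Ollama server and return the generated text."""
    data = json.dumps(build_payload(prompt)).encode("utf-8")
    request = urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())["response"]


# Usage (requires a running Ollama server with the model created above):
# print(ask("Explain the concept of surplus value."))
```

The system prompt and sampling parameters come from the Modelfile, so the request only needs the model name and the prompt.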

## Training Complete!

**Next steps:**
1. Test the model with various political theory questions
2. Export to GGUF if satisfied with results
3. Create Ollama Modelfile for deployment
4. Consider extended training (more steps) for better results

---

**Resources:**
- [Unsloth Documentation](https://docs.unsloth.ai/)
- [TRL GRPO Documentation](https://huggingface.co/docs/trl/main/en/grpo_trainer)
- [ProleWiki](https://en.prolewiki.org/)