# Marxist-Leninist GRPO Training (TRL Direct)

This notebook fine-tunes **DeepSeek-R1-0528-Qwen3-8B** on the ProleWiki corpus using GRPO (Group Relative Policy Optimization).

**This version uses TRL directly** without Unsloth, avoiding torch.compile issues that cause hangs on RunPod/Jupyter.

**Goal:** Train the model to reason through political theory questions using dialectical materialist analysis, showing reasoning in `<think>` tags.

**Hardware:** Optimized for A40 (48GB VRAM)

**Dataset:** 1,058 Q&A pairs from ProleWiki covering:
- Revisionism and opportunism
- Dialectical and historical materialism
- Anti-colonial theory (Fanon, Rodney, Nkrumah)
- Revolutionary theory (Jackson, Sankara, PFLP)
- Marxist political economy

---

**Stack:**
- [TRL GRPOTrainer](https://huggingface.co/docs/trl/main/en/grpo_trainer) - HuggingFace's RL training library
- [PEFT](https://huggingface.co/docs/peft) - Parameter-Efficient Fine-Tuning (LoRA)
- [vLLM](https://github.com/vllm-project/vllm) - Fast inference for generation
- [transformers](https://huggingface.co/docs/transformers) - Model loading

## Installation

In [None]:
%%capture
# Install dependencies
!pip install torch torchvision torchaudio
!pip install transformers accelerate
!pip install "trl[vllm]"  # TRL with vLLM support (quoted so the shell does not expand the brackets)
!pip install peft bitsandbytes
!pip install datasets
!pip install sentence-transformers numpy

# spaCy with TRANSFORMER model for better semantic understanding
!pip install spacy spacy-curated-transformers
!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl

print("Installation complete!")

## Load Model

We load the DeepSeek-R1-0528-Qwen3-8B model using standard HuggingFace transformers.

**Key differences from Unsloth:**
- Uses `AutoModelForCausalLM` instead of `FastLanguageModel`
- No torch.compile optimization (avoids hanging)
- Standard PyTorch gradient checkpointing

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configuration
MODEL_NAME = "unsloth/DeepSeek-R1-0528-Qwen3-8B"  # Or use deepseek-ai/ version
MAX_SEQ_LENGTH = 2048
LORA_RANK = 32

# Check GPU
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "left"  # Required for GRPO
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
print(f"Vocab size: {len(tokenizer)}")
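Left padding matters because generation appends new tokens to the right end of each row. With right padding, shorter prompts would have pad tokens sitting between the prompt and the generated text. The sketch below pads hypothetical token-ID lists by hand (pad ID 0 is an arbitrary stand-in, not the tokenizer's real pad ID) and builds the matching attention mask, the way a tokenizer with `padding_side="left"` does.

```python
def left_pad(batch, pad_id=0):
    """Left-pad token-ID lists to equal length; return (input_ids, attention_mask)."""
    width = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = width - len(seq)
        input_ids.append([pad_id] * n_pad + seq)  # pads go first
        attention_mask.append([0] * n_pad + [1] * len(seq))  # mask out the pads
    return input_ids, attention_mask


ids, mask = left_pad([[11, 12, 13], [21]])
print(ids)   # [[11, 12, 13], [0, 0, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Every row now ends with a real prompt token, so the first generated token follows the prompt directly.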

In [None]:
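# Load model in bfloat16 (full 16-bit weights for GRPO)
# For 4-bit quantization, uncomment the import and bnb_config below,
# plus the quantization_config argument

# Optional: 4-bit quantization (saves VRAM but may affect quality)
# from transformers import BitsAndBytesConfig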
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
# )

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    # quantization_config=bnb_config,  # Uncomment for 4-bit
)

print(f"Model loaded: {model.config.model_type}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
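For a rough sizing check, weight memory is the parameter count times the bytes per parameter: 2 bytes in bfloat16 and about half a byte with 4-bit NF4. This estimate ignores quantization constants, activations, and optimizer state. The 8.2B figure below is an assumed round number for an 8B-class model; substitute the count printed above.

```python
def weight_memory_gb(n_params, bytes_per_param):
    """Approximate weight memory in GB (1 GB = 1e9 bytes), weights only."""
    return n_params * bytes_per_param / 1e9


N_PARAMS = 8.2e9  # assumption: approximate size of an 8B-class model
print(f"bf16: {weight_memory_gb(N_PARAMS, 2):.1f} GB")    # bf16: 16.4 GB
print(f"nf4:  {weight_memory_gb(N_PARAMS, 0.5):.1f} GB")  # nf4:  4.1 GB
```

In `vllm_mode="colocate"`, vLLM loads its own copy of the weights on the same GPU. This is one reason the training config caps `vllm_gpu_memory_utilization`.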

## Apply LoRA (Low-Rank Adaptation)

We apply LoRA with the PEFT library directly instead of Unsloth's wrapper.

**Gradient checkpointing:** We use transformers' `gradient_checkpointing_enable()`. It is built on PyTorch's `torch.utils.checkpoint` and does not rely on torch.compile, so it avoids the hangs mentioned above.

In [None]:
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_RANK,  # Same as r for GRPO (scaling = 1.0)
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

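# Enable gradient checkpointing before wrapping with PEFT. This is transformers'
# wrapper around torch.utils.checkpoint (no torch.compile). Non-reentrant
# checkpointing plus enable_input_require_grads() lets gradients reach the
# LoRA weights even though the frozen base model's inputs don't require grad.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()

# Apply LoRA
model = get_peft_model(model, lora_config)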

# Print trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
model.print_trainable_parameters()
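LoRA adds two low-rank matrices, A (r × in) and B (out × r), to each targeted linear layer. Each adapter therefore costs r·(in + out) parameters, and its update is scaled by lora_alpha / r (1.0 here). The sketch below estimates the total for the seven target modules. It assumes Qwen3-8B-like dimensions (hidden size 4096, MLP size 12288, 8 key/value heads of dimension 128, 36 layers); check these against `model.config` before relying on the number.

```python
# Assumed Qwen3-8B-like dimensions; verify against model.config.
HIDDEN = 4096
INTERMEDIATE = 12288
KV_DIM = 8 * 128  # num_key_value_heads * head_dim
N_LAYERS = 36
RANK = 32

# (in_features, out_features) for each LoRA target module
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(rank, shapes, n_layers):
    """Count LoRA parameters: A is (rank x in), B is (out x rank)."""
    per_layer = sum(rank * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_layer * n_layers


print(f"{lora_params(RANK, MODULE_SHAPES, N_LAYERS):,}")  # 87,293,952
```

If the trainable count printed by PEFT differs, the most likely cause is that the assumed dimensions don't match the loaded checkpoint.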

## Chat Template

Verify the chat template works correctly for DeepSeek-R1.

In [None]:
# Our system prompt for Marxist-Leninist reasoning
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""

# Test chat template
test_messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "What is revisionism?"},
]

print("Chat template test:")
print(tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True))

## Load Dataset

Load the GRPO-formatted dataset from `grpo_dataset.jsonl` (produced by `transform_to_grpo.py`).

In [None]:
from pathlib import Path

from datasets import Dataset

DATA_PATH = Path("grpo_dataset.jsonl")

if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Dataset not found: {DATA_PATH}\n" "Run 'python transform_to_grpo.py' first!"
    )

dataset = Dataset.from_json(str(DATA_PATH))
print(f"Loaded {len(dataset)} examples")

# Show sample
sample = dataset[0]
print(f"\nSample prompt: {sample['prompt'][1]['content'][:80]}...")
print(f"Sample answer: {sample['answer'][:80]}...")
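The cell above assumes each JSONL row has two fields. `prompt` is a list of chat messages with the user turn at index 1, and `answer` is a reference string. TRL forwards extra columns such as `answer` to the reward functions as keyword arguments. The record below is a hypothetical example of that shape. It is inferred from how this notebook indexes the data, not from `transform_to_grpo.py` itself, and the small validator is likewise illustrative.

```python
import json

# Hypothetical record in the shape this notebook indexes: prompt[1] is the user turn.
line = json.dumps({
    "prompt": [
        {"role": "system", "content": "You are a Marxist-Leninist assistant..."},
        {"role": "user", "content": "What is revisionism?"},
    ],
    "answer": "Revisionism is ...",
})


def check_record(record):
    """Return True if a record matches the prompt/answer layout assumed above."""
    prompt = record.get("prompt")
    return (
        isinstance(prompt, list)
        and len(prompt) >= 2
        and prompt[0].get("role") == "system"
        and prompt[1].get("role") == "user"
        and isinstance(record.get("answer"), str)
    )


print(check_record(json.loads(line)))  # True
```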

## Reward Functions

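GRPO samples several completions per prompt and scores each one with the reward functions below. It then pushes the model toward completions that score above their group's average.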

These reward functions are carried over from the Unsloth version of this notebook and already follow TRL's reward-function interface.
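TRL calls each reward function with keyword arguments. These include the batch's `prompts` and `completions` plus every extra dataset column (here `answer`), and the function must return one float per completion. In conversational datasets each completion is a list of message dicts, which is why the functions below unwrap `completion[0]["content"]`. The following is a minimal sketch of that contract. `empty_penalty` is a hypothetical reward and the batch is hand-built.

```python
def completion_text(completion):
    """Return the text of a completion in either string or conversational form."""
    if isinstance(completion, list):
        return completion[0]["content"] if completion else ""
    return completion


def empty_penalty(completions, answer, **kwargs):
    """Hypothetical reward: -1.0 for an empty completion, else 0.0."""
    return [-1.0 if not completion_text(c).strip() else 0.0 for c in completions]


batch = [
    [{"role": "assistant", "content": "<think>...</think> Answer."}],
    [{"role": "assistant", "content": ""}],
]
print(empty_penalty(completions=batch, answer=["ref", "ref"], prompts=None))  # [0.0, -1.0]
```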

In [None]:
import re

from sentence_transformers import SentenceTransformer

# Reasoning format tokens
REASONING_START = "<think>"
REASONING_END = "</think>"

# Matches </think> and captures everything after it (the final answer)
SOLUTION_END_REGEX = re.compile(rf"{re.escape(REASONING_END)}(.*)", re.DOTALL)

# Lazy-load models to avoid loading at import time
_embedder = None
_nli_pipeline = None
_spacy_nlp = None


def get_embedder():
    """Get or initialize the sentence transformer embedder."""
    global _embedder
    if _embedder is None:
        print("[Reward] Loading sentence-transformers embedder...")
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return _embedder


def get_nli_pipeline():
    """Get or initialize the NLI pipeline (BART-large-MNLI)."""
    global _nli_pipeline
    if _nli_pipeline is None:
        print("[Reward] Loading NLI model (bart-large-mnli)...")
        import torch
        from transformers import pipeline

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _nli_pipeline = pipeline(
            "text-classification",
            model="facebook/bart-large-mnli",
            device=device,
        )
    return _nli_pipeline


def get_spacy_nlp():
    """Get or initialize spaCy NLP pipeline."""
    global _spacy_nlp
    if _spacy_nlp is None:
        import spacy

        models_to_try = ["en_core_web_trf", "en_core_web_md", "en_core_web_sm"]

        for model_name in models_to_try:
            try:
                print(f"[Reward] Loading spaCy model: {model_name}...")
                _spacy_nlp = spacy.load(model_name)
                print(f"[Reward] Loaded {model_name} successfully")
                break
            except OSError:
                print(f"[Reward] {model_name} not found, trying next...")
                continue

        if _spacy_nlp is None:
            raise OSError("No spaCy model found!")
    return _spacy_nlp

In [None]:
# Marxist terminology for vocabulary reward
MARXIST_TERMS = {
    # Core concepts
    "dialectical",
    "materialism",
    "historical materialism",
    "dialectical materialism",
    # Classes
    "bourgeoisie",
    "proletariat",
    "petty bourgeois",
    "petty bourgeoisie",
    "lumpenproletariat",
    "working class",
    "ruling class",
    # Class struggle
    "class struggle",
    "class consciousness",
    "class war",
    "class conflict",
    # Political economy
    "surplus value",
    "commodity",
    "use value",
    "exchange value",
    "labor power",
    "means of production",
    "relations of production",
    "forces of production",
    "mode of production",
    "primitive accumulation",
    "exploitation",
    "capital accumulation",
    # Imperialism
    "imperialism",
    "colonialism",
    "neo-colonialism",
    "settler colonialism",
    "national liberation",
    "self-determination",
    # State and revolution
    "dictatorship of the proletariat",
    "vanguard",
    "vanguard party",
    "democratic centralism",
    "withering away of the state",
    "proletarian dictatorship",
    # Ideology
    "hegemony",
    "superstructure",
    "base",
    "ideology",
    "false consciousness",
    # Revisionism
    "revisionism",
    "opportunism",
    "reformism",
    "social democracy",
    "ultra-leftism",
    # Alienation
    "alienation",
    "fetishism",
    "commodity fetishism",
    "reification",
    # Historical
    "paris commune",
    "october revolution",
    "bolshevik",
    "menshevik",
    # Anti-colonial
    "decolonization",
    "third world",
    "global south",
    "national bourgeoisie",
    "comprador",
}

# Discourse connectives indicating logical structure
DISCOURSE_CONNECTIVES = {
    "because",
    "therefore",
    "thus",
    "hence",
    "consequently",
    "however",
    "although",
    "whereas",
    "nevertheless",
    "moreover",
    "furthermore",
    "additionally",
    "specifically",
    "namely",
    "as a result",
    "due to",
    "in order to",
    "so that",
    "on the other hand",
    "in contrast",
    "similarly",
    "likewise",
}

In [None]:
def match_format_exactly(completions, **kwargs):
    """
    Reward +3.0 if response contains proper </think> tag.
    This encourages the model to use the reasoning format.
    """
    scores = []
    for completion in completions:
        # Handle both string and conversational formats
        if isinstance(completion, list):
            response = completion[0]["content"] if completion else ""
        else:
            response = completion

        score = 3.0 if SOLUTION_END_REGEX.search(response) else 0.0
        scores.append(score)
    return scores


def match_format_approximately(completions, **kwargs):
    """
    Reward partial format matching, scored per tag:
    <think>:  +0.5 if it appears exactly once, else -1.0 (missing or repeated)
    </think>: +0.5 if it appears exactly once, else -1.0 (missing or repeated)
    Totals range from -2.0 to +1.0.
    """
    scores = []
    for completion in completions:
        if isinstance(completion, list):
            response = completion[0]["content"] if completion else ""
        else:
            response = completion

        score = 0.0
        start_count = response.count(REASONING_START)
        end_count = response.count(REASONING_END)

        score += 0.5 if start_count == 1 else -1.0
        score += 0.5 if end_count == 1 else -1.0

        scores.append(score)
    return scores


def completeness_reward(completions, answer, **kwargs):
    """
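    Reward answers whose length is close to the reference answer's.
    Compares the word count after </think> with the ground-truth word count:
    ratio 0.5-1.5 -> +2.0; else 0.3-2.0 -> +1.0; below 0.2 -> -2.0; else -0.5.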
    """
    scores = []

    for completion, true_answer in zip(completions, answer, strict=False):
        if isinstance(completion, list):
            response = completion[0]["content"] if completion else ""
        else:
            response = completion

        # Extract answer after </think>
        if REASONING_END in response:
            answer_part = response.split(REASONING_END, 1)[1].strip()
        else:
            answer_part = response

        answer_len = len(answer_part.split())
        true_len = len(true_answer.split())

        if true_len == 0:
            scores.append(0.0)
            continue

        ratio = answer_len / true_len

        if 0.5 <= ratio <= 1.5:
            score = 2.0
        elif 0.3 <= ratio <= 2.0:
            score = 1.0
        elif ratio < 0.2:
            score = -2.0
        else:
            score = -0.5

        scores.append(score)

    return scores

In [None]:
def nli_coherence_reward(completions, answer, **kwargs):
    """
    Reward responses that logically ENTAIL the ground truth answer.
    Uses Natural Language Inference (facebook/bart-large-mnli).
    """
    nli = get_nli_pipeline()
    scores = []

    for completion, true_answer in zip(completions, answer, strict=False):
        if isinstance(completion, list):
            response = completion[0]["content"] if completion else ""
        else:
            response = completion

        # Extract answer part after </think>
        if REASONING_END in response:
            response = response.split(REASONING_END, 1)[1].strip()

        # Handle empty or very short responses
        if not response or len(response.strip()) < 20:
            scores.append(-2.0)
            continue

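        # Truncate each side to 512 characters (not tokens) to stay well under
        # BART's 1024-token input limit
        response_truncated = response[:512]
        truth_truncated = true_answer[:512]

        try:
            # BART-MNLI separates premise and hypothesis with "</s></s>":
            # the response is the premise, the reference answer the hypothesis
            input_text = f"{response_truncated}</s></s>{truth_truncated}"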
            result = nli(input_text)[0]
            label = result["label"].lower()

            if label == "entailment":
                score = 3.0
            elif label == "neutral":
                score = -1.0
            else:  # contradiction
                score = -3.0

            scores.append(score)
        except Exception as e:
            print(f"[NLI Reward] Error: {e}")
            scores.append(0.0)

    return scores


def structural_coherence_reward(completions, **kwargs):
    """
    Reward responses with proper linguistic structure.
    Uses spaCy dependency parsing.
    """
    nlp = get_spacy_nlp()
    scores = []

    for completion in completions:
        if isinstance(completion, list):
            response = completion[0]["content"] if completion else ""
        else:
            response = completion

        doc = nlp(response)
        score = 0.0

        # Check 1: Are there actual sentences?
        sentences = list(doc.sents)
        if len(sentences) < 1:
            scores.append(-1.0)
            continue

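        # Check 2: Marxist terms in meaningful syntactic roles.
        # Multi-word terms ("surplus value") never equal a single token, so find
        # whole-word matches and inspect the syntactic root of each matched span.
        terms_in_context = 0
        meaningful_deps = {"nsubj", "nsubjpass", "dobj", "pobj", "attr", "appos"}

        for term in MARXIST_TERMS:
            pattern = rf"\b{re.escape(term)}\b"
            for match in re.finditer(pattern, response, re.IGNORECASE):
                span = doc.char_span(match.start(), match.end(), alignment_mode="expand")
                if span is None:
                    continue
                root = span.root
                is_meaningful_role = root.dep_ in meaningful_deps
                is_verb_root = root.head.pos_ == "VERB" and root.head.dep_ == "ROOT"
                if is_meaningful_role or is_verb_root:
                    terms_in_context += 1
                    break  # Count each term at most once

        score += min(terms_in_context * 0.3, 1.5)

        # Check 3: Discourse connectives, matched as whole words so that
        # "thus" does not count inside "enthusiasm"
        connective_count = sum(
            1
            for conn in DISCOURSE_CONNECTIVES
            if re.search(rf"\b{re.escape(conn)}\b", response, re.IGNORECASE)
        )
        score += min(connective_count * 0.2, 1.0)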

        scores.append(score)

    return scores
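A plain substring test such as `"thus" in text` also fires inside longer words: "thus" occurs in "enthusiasm" and "hence" in "whence". For that reason, whole-word matching is safer when counting terms and connectives. The helper below is a hypothetical sketch of word-boundary matching, not part of the notebook's reward code.

```python
import re


def count_whole_phrases(text, phrases):
    """Count how many phrases occur in text as whole words (case-insensitive)."""
    text = text.lower()
    return sum(
        1 for phrase in phrases if re.search(rf"\b{re.escape(phrase)}\b", text)
    )


sample = "Their enthusiasm grew; thus the strike spread."
print("thus" in "their enthusiasm grew")               # True (false positive)
print(count_whole_phrases(sample, {"thus", "hence"}))  # 1
```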

In [None]:
def full_coherence_reward(prompts, completions, answer, **kwargs):
    """
    Complete coherence check combining:
    1. NLI coherence (response entails ground truth)
    2. Structural coherence (terms in proper syntactic roles)

    A contradiction (-3.0) overrides everything; otherwise the score is
    nli + 0.5 * structure. Recommended over using either signal alone.
    """
    nli_scores = nli_coherence_reward(completions, answer, **kwargs)
    structure_scores = structural_coherence_reward(completions, **kwargs)

    combined = []
    for nli, structure in zip(nli_scores, structure_scores, strict=False):
        if nli <= -3.0:
            combined.append(-3.0)  # Contradiction dominates
        else:
            total = nli + (structure * 0.5)
            combined.append(total)

    return combined

In [None]:
# Debug reward function for monitoring during training
_PRINT_COUNTER = 0
_PRINT_EVERY = 10


def debug_print_reward(prompts, completions, answer, **kwargs):
    """
    Print sample outputs periodically for monitoring.
    Returns 0.0 (no effect on training).
    """
    global _PRINT_COUNTER

    if _PRINT_COUNTER % _PRINT_EVERY == 0:
        # Handle conversational (list of messages) and plain-string formats
        question = prompts[0][-1]["content"] if isinstance(prompts[0], list) else prompts[0]
        response = (
            completions[0][0]["content"] if isinstance(completions[0], list) else completions[0]
        )
        true_answer = answer[0]

        print("=" * 60)
        print(f"Reward call {_PRINT_COUNTER}")  # Counts reward calls, not trainer steps
        print(f"Question: {question[:100]}...")
        print(f"Response: {response[:200]}...")
        print(f"Expected: {true_answer[:100]}...")
        print("=" * 60)

    _PRINT_COUNTER += 1

    return [0.0] * len(completions)

## Training Configuration

Configure GRPO training using TRL's `GRPOConfig`.

**Key differences from Unsloth:**
- Uses `GRPOConfig` directly instead of wrapper
- vLLM configured via `use_vllm` and `vllm_mode`
- No torch.compile optimizations

In [None]:
from trl import GRPOConfig, GRPOTrainer

# Training hyperparameters
MAX_STEPS = 250
SAVE_STEPS = 50
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.1

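# A40 batch settings. In GRPO, BATCH_SIZE counts completions, not prompts:
# on one GPU, 2 x 2 = 4 completions per optimizer step = 1 prompt x 4 generations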
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 2
NUM_GENERATIONS = 4

# Sequence lengths
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 1500

# Output paths
OUTPUT_DIR = "outputs/marxist-grpo-trl"
LORA_OUTPUT = "outputs/marxist-grpo-trl-lora"

In [None]:
# GRPO training configuration
training_args = GRPOConfig(
    # Output
    output_dir=OUTPUT_DIR,
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=0.001,
    warmup_ratio=WARMUP_RATIO,
    lr_scheduler_type="linear",
    optim="adamw_torch",  # Standard PyTorch AdamW (an 8-bit bitsandbytes optimizer would save memory)
    # Batch settings
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    num_generations=NUM_GENERATIONS,
    # Sequence lengths
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    # Training duration
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    # Logging
    logging_steps=1,
    report_to="none",
    # GRPO specific
    temperature=1.0,  # For GRPO training dynamics
    beta=0.0,  # No KL penalty, so no reference model is needed
    scale_rewards=True,
    # vLLM for fast generation
    use_vllm=True,
    vllm_mode="colocate",  # Share GPU with training model
    vllm_gpu_memory_utilization=0.5,  # Conservative to leave room for training
    # Precision
    bf16=True,
)

print("Training configuration:")
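print(
    f"  Completions per optimizer step: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}"
)
print(f"  Prompts per optimizer step: {BATCH_SIZE * GRADIENT_ACCUMULATION // NUM_GENERATIONS}")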
print(f"  Generations per prompt: {NUM_GENERATIONS}")
print(f"  Max steps: {MAX_STEPS}")
print("  vLLM mode: colocate")
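In TRL's GRPO, `per_device_train_batch_size` counts completions rather than prompts, and each prompt is sampled `num_generations` times. The number of distinct prompts per optimizer step is therefore smaller than the batch size suggests. The arithmetic below rests on two assumptions: TRL generates once per gradient-accumulation cycle, and, as we understand recent TRL versions, it requires that generation batch to be divisible by `num_generations`. `grpo_batch_math` is a hypothetical helper.

```python
def grpo_batch_math(per_device_batch, grad_accum, num_generations, num_gpus=1):
    """Return (completions, unique prompts) per optimizer step, assuming one
    generation round per gradient-accumulation cycle."""
    completions = per_device_batch * grad_accum * num_gpus
    if completions % num_generations != 0:
        raise ValueError("completions per step must be divisible by num_generations")
    return completions, completions // num_generations


completions, prompts = grpo_batch_math(per_device_batch=2, grad_accum=2, num_generations=4)
print(completions, prompts)  # 4 1
```

With these settings, 250 steps see about 250 prompts, which is under a quarter of the 1,058 examples.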

In [None]:
# Define reward functions
print("Initializing GRPO trainer with reward functions:")
print("  - match_format_exactly (+3.0 for </think>)")
print("  - match_format_approximately (tag validation)")
print("  - full_coherence_reward (NLI + structure)")
print("  - completeness_reward (length comparison)")
print("  - debug_print_reward (monitoring)")
print("\nNote: First run will download NLI model (~1.6GB) + spaCy transformer (~436MB)")

reward_funcs = [
    match_format_exactly,
    match_format_approximately,
    full_coherence_reward,
    completeness_reward,
    debug_print_reward,
]

# Create trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset,
)
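GRPOTrainer sums the per-function rewards for each completion, weighted by `reward_weights` if set in the config. It then scores each completion relative to the other completions for the same prompt. The numpy sketch below mirrors that group normalization as we understand TRL's implementation with `scale_rewards=True`. The reward values are made up and the epsilon is an assumption. A constant reward such as `debug_print_reward`'s 0.0 shifts every completion equally and so cancels out.

```python
import numpy as np


def group_advantages(rewards_per_func, num_generations, weights=None, eps=1e-4):
    """rewards_per_func: (num_completions, num_funcs), consecutive rows per prompt."""
    rewards_per_func = np.asarray(rewards_per_func, dtype=float)
    if weights is None:
        weights = np.ones(rewards_per_func.shape[1])
    rewards = rewards_per_func @ np.asarray(weights, dtype=float)  # total per completion
    groups = rewards.reshape(-1, num_generations)  # one row per prompt
    mean = groups.mean(axis=1, keepdims=True)
    std = groups.std(axis=1, ddof=1, keepdims=True)  # sample std, like torch.std
    return ((groups - mean) / (std + eps)).reshape(-1)


# One prompt, four completions, two reward functions (the second is constant 0.0).
adv = group_advantages([[3.0, 0.0], [3.0, 0.0], [0.0, 0.0], [0.0, 0.0]], num_generations=4)
print(np.round(adv, 3))
```

Completions that earned the format bonus get a positive advantage (about +0.87) and the others an equal negative one. If every completion for a prompt scores the same, all its advantages are zero and that prompt contributes no learning signal.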

## Train!

Run GRPO training. Metrics are logged every step (`logging_steps=1`). Watch the overall `reward` and the per-reward-function values, which should trend upward.

**Rough expectations (not guarantees):**
- Steps 0-50: Format rewards stabilize
- Steps 50-100: NLI coherence (entailment) rewards improve
- Steps 100-250: Content quality improves

**Estimated time:** roughly 1-2 hours on an A40

In [None]:
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Steps: {MAX_STEPS}")
print(f"Completions per step: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} ({NUM_GENERATIONS} per prompt)")
print(f"Learning rate: {LEARNING_RATE}")
print()

trainer.train()

## Save LoRA

Save the trained LoRA adapter.

In [None]:
import os

os.makedirs(LORA_OUTPUT, exist_ok=True)

# Save the PEFT adapter
model.save_pretrained(LORA_OUTPUT)
tokenizer.save_pretrained(LORA_OUTPUT)

print(f"LoRA saved to: {LORA_OUTPUT}")

In [None]:
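# Verify the LoRA adapter was actually trained.
# PEFT initializes lora_A randomly but lora_B to zeros, so non-zero lora_B
# weights (not merely non-zero tensors) show that training updated the adapter.
import os

from safetensors import safe_open

adapter_path = os.path.join(LORA_OUTPUT, "adapter_model.safetensors")
if os.path.exists(adapter_path):
    with safe_open(adapter_path, framework="pt") as f:
        lora_b_keys = [key for key in f.keys() if "lora_B" in key]
        n_updated = 0
        for i, key in enumerate(lora_b_keys):
            tensor = f.get_tensor(key)
            n_nonzero = (tensor != 0).sum().item()
            n_updated += int(n_nonzero > 0)
            if i < 5:  # Print the first few for inspection
                print(f"{key}: {n_nonzero}/{tensor.numel()} non-zero")
    if lora_b_keys and n_updated == len(lora_b_keys):
        print(f"\nLoRA verification passed: all {n_updated} lora_B matrices are non-zero.")
    else:
        print(f"\nWARNING: only {n_updated}/{len(lora_b_keys)} lora_B matrices are non-zero.")
else:
    print(f"Adapter not found at {adapter_path}")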

## Inference Testing

Test the model with the trained LoRA adapter.

In [None]:
TEST_QUESTIONS = [
    "What is revisionism in the Marxist sense?",
    "Explain the concept of surplus value.",
    "What is the dictatorship of the proletariat?",
    "How does dialectical materialism differ from idealism?",
]

# Put model in eval mode
model.eval()

print("=" * 60)
print("TESTING WITH TRAINED LORA")
print("=" * 60)

for question in TEST_QUESTIONS[:2]:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]

    text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_k=50,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True)

    print(f"\nQ: {question}")
    print(f"A: {response[:500]}...")

## Merge and Export (Optional)

Merge LoRA weights into base model for deployment.

In [None]:
# Merge LoRA into base model
MERGE = False  # Set to True to merge
if MERGE:
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained("outputs/marxist-grpo-merged")
    tokenizer.save_pretrained("outputs/marxist-grpo-merged")
    print("Merged model saved!")
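Merging folds the adapter into the base weights. With scaling s = lora_alpha / r, the merged matrix is W + s·B·A, so the merged layer produces the same outputs as the base layer plus the adapter branch. The numpy check below uses small random stand-in matrices, not the notebook's weights, to illustrate that equivalence. It also shows why a zero-initialized B leaves the base model unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 16, 8, 4, 4  # toy sizes; scaling = alpha / rank = 1.0
scaling = alpha / rank

W = rng.normal(size=(d_out, d_in))  # frozen base weight
A = rng.normal(size=(rank, d_in))   # LoRA down-projection
B = rng.normal(size=(d_out, rank))  # LoRA up-projection (non-zero after training)
x = rng.normal(size=d_in)

adapter_output = W @ x + scaling * (B @ (A @ x))  # base + LoRA branch
merged_W = W + scaling * (B @ A)                  # what merging computes
print(np.allclose(merged_W @ x, adapter_output))  # True

# With B = 0 (PEFT's default initialization), the update is exactly zero.
print(np.allclose(W + scaling * (np.zeros_like(B) @ A), W))  # True
```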

## Training Complete!

**Next steps:**
1. Test the model with various political theory questions
2. Merge LoRA if satisfied with results
3. Convert the merged model to GGUF with `llama.cpp` (its `convert_hf_to_gguf.py` script) for Ollama deployment

---

**Resources:**
- [TRL GRPOTrainer Documentation](https://huggingface.co/docs/trl/main/en/grpo_trainer)
- [PEFT Documentation](https://huggingface.co/docs/peft)
- [ProleWiki](https://en.prolewiki.org/)