# BitNet b1.58 Training - Microsoft Architecture

This notebook implements **proper BitNet training** following Microsoft's approach from:
- [BitNet b1.58 Paper](https://arxiv.org/abs/2402.17764)
- [BitNet b1.58 2B4T Technical Report](https://arxiv.org/abs/2504.12285)
- [Microsoft BitNet GitHub](https://github.com/microsoft/BitNet)

## Key Architecture Differences from Standard LLMs

| Component | **Microsoft BitNet** | **Standard LLaMA** |
|-----------|---------------------|-------------------|
| FFN Activation | **Squared ReLU (ReLU²)** | SwiGLU |
| Normalization | **SubLN** (pre-norm plus an extra RMSNorm before each sublayer's output projection) | Pre-norm only |
| Bias Terms | **None** | Optional |
| Weight Decay | **Two-stage** in the 2B4T report (cosine, peak 0.1, then 0.0); this notebook uses **0.0** throughout | 0.01-0.1 |
| Weight Quantization | Ternary {-1, 0, +1} via absmean | Full precision |
| Activation Quantization | INT8 per-token | Full precision |
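
The two quantizers in the table can be traced by hand. The framework-free sketch below mirrors the arithmetic in `BitLinear`: an absmean scale plus RoundClip to {-1, 0, +1} for weights, and a per-token absmax scale into the INT8 range for activations. The function names are hypothetical and the sample values are made up for illustration.

```python
def absmean_ternary(weights, eps=1e-8):
    """Quantize a weight list to {-1, 0, +1} using the absmean scale."""
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    # RoundClip: round to the nearest integer, then clip to [-1, 1]
    q = [max(-1, min(1, round(w / scale))) for w in weights]
    return q, scale


def absmax_int8(token, eps=1e-8):
    """Quantize one token's activation vector to INT8 using its absmax."""
    scale = max(abs(a) for a in token) + eps
    q = [max(-128, min(127, round(a / scale * 127))) for a in token]
    return q, scale


w_q, w_scale = absmean_ternary([0.9, -0.05, 0.3, -1.2])
a_q, a_scale = absmax_int8([0.5, -2.0, 1.2])
print(w_q, round(w_scale, 4))  # [1, 0, 0, -1] 0.6125
print(a_q)                     # [32, -127, 76]
```

Small latent weights (here `0.3` and `-0.05`) land on 0, which is where ternary sparsity comes from. Dequantizing multiplies back by the scales (`w_q[i] * w_scale`, `a_q[i] * a_scale / 127`), which is what the final rescale in `BitLinear.forward` applies to the matmul output.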

## Training Pipeline
1. **Pre-training**: Large-scale next-token prediction
2. **SFT**: Supervised fine-tuning on instruction data
3. **DPO** (optional): Direct preference optimization

In [None]:
# @title Install Dependencies
!pip install -q torch transformers datasets tqdm matplotlib

In [None]:
# @title Configuration
# @markdown Model size preset
MODEL_SIZE = "125M"  # @param ["125M", "350M", "1B", "3B"]

# @markdown Training settings
MAX_STEPS = 10000  # @param {type:"integer"}
BATCH_SIZE = 8  # @param {type:"integer"}
MAX_SEQ_LENGTH = 512  # @param {type:"integer"}
LEARNING_RATE = 1e-3  # @param {type:"number"}

# @markdown Data source
DATA_SOURCE = "cognitive_kernel"  # @param ["cognitive_kernel", "swe_rebench", "huggingface"]
HF_DATASET = "roneneldan/TinyStories"  # @param {type:"string"}

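# Model configs (size labels are nominal: the tied vocab_size x hidden_dim
# embedding is not counted, and it is large with a 128k-vocab tokenizer)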
MODEL_CONFIGS = {
    "125M": {"hidden_dim": 768, "num_layers": 12, "num_heads": 12, "mlp_ratio": 4},
    "350M": {"hidden_dim": 1024, "num_layers": 24, "num_heads": 16, "mlp_ratio": 4},
    "1B": {"hidden_dim": 2048, "num_layers": 24, "num_heads": 16, "mlp_ratio": 4},
    "3B": {"hidden_dim": 2560, "num_layers": 32, "num_heads": 20, "mlp_ratio": 4},
}

config = MODEL_CONFIGS[MODEL_SIZE]
print(f"Model config: {MODEL_SIZE}")
print(f"  Hidden dim: {config['hidden_dim']}")
print(f"  Layers: {config['num_layers']}")
print(f"  Heads: {config['num_heads']}")
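
The preset names are nominal. A rough count of the matrices in `BitNetLM` shows where the parameters go. The helper below is hypothetical, ignores the small RMSNorm gain vectors, and assumes the LLaMA-3 tokenizer loaded later in this notebook, whose full vocabulary (`len(tokenizer)`) is 128,256 tokens. For the 125M preset the tied embedding alone is about 98.5M parameters.

```python
def estimate_params(hidden_dim, num_layers, mlp_ratio, vocab_size):
    """Approximate BitNetLM parameter count (RMSNorm gains ignored)."""
    mlp_dim = int(hidden_dim * mlp_ratio)
    attn = 4 * hidden_dim * hidden_dim   # q, k, v, o projections (BitLinear)
    ffn = 2 * hidden_dim * mlp_dim       # up and down projections (BitLinear)
    embed = vocab_size * hidden_dim      # full precision, tied with the LM head
    return {
        "embedding": embed,
        "bitlinear": num_layers * (attn + ffn),
        "total": embed + num_layers * (attn + ffn),
    }


for name, (dim, layers) in {"125M": (768, 12), "1B": (2048, 24)}.items():
    counts = estimate_params(dim, layers, 4, 128_256)
    print(name, {k: f"{v / 1e6:.1f}M" for k, v in counts.items()})
```

With these assumptions the "125M" preset comes out near 183M parameters in total, of which only about 85M sit in ternary `BitLinear` layers.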

In [None]:
# @title Core BitNet Implementation (Microsoft-style)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


# =============================================================================
# Straight-Through Estimators for Quantization
# =============================================================================

class STETernary(torch.autograd.Function):
    """STE for ternary quantization to {-1, 0, +1}."""
    @staticmethod
    def forward(ctx, x):
        return torch.clamp(torch.round(x), -1, 1)
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # Straight-through: pass gradient unchanged


class STERound(torch.autograd.Function):
    """STE for INT8 rounding."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


# =============================================================================
# BitLinear - The Core BitNet Layer
# =============================================================================

class BitLinear(nn.Module):
    """
    BitNet b1.58 Linear Layer.
    
    Key differences from nn.Linear:
    - Weights quantized to {-1, 0, +1} using absmean scaling
    - Activations quantized to INT8 using absmax scaling (per-token)
    - NO BIAS (Microsoft BitNet has no bias anywhere)
    """
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Latent weights in FP32 (updated by optimizer)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        
        # NO BIAS - critical for BitNet!
        self.register_parameter('bias', None)
        
        # Initialize
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # === Weight Quantization (absmean scaling) ===
        # Scale = mean(|W|) - better than max for ternary
        w_scale = self.weight.abs().mean() + 1e-8
        w_normalized = self.weight / w_scale
        w_quant = STETernary.apply(w_normalized)  # {-1, 0, +1}
        
        # === Activation Quantization (absmax per-token) ===
        a_scale = x.abs().amax(dim=-1, keepdim=True) + 1e-8
        a_normalized = x / a_scale * 127.0
        a_quant = STERound.apply(torch.clamp(a_normalized, -128, 127))
        
        # === Compute and Rescale ===
        y = F.linear(a_quant, w_quant, None)
        y = y * (w_scale * a_scale / 127.0)
        
        return y


# =============================================================================
# RMSNorm - Root Mean Square Normalization
# =============================================================================

class RMSNorm(nn.Module):
    """RMSNorm without bias (Microsoft BitNet style)."""
    
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # NO BIAS
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.weight


# =============================================================================
# Rotary Position Embeddings (RoPE)
# =============================================================================

def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute rotary embedding frequencies."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary embeddings to queries and keys.
    
    xq, xk: (batch, num_heads, seq_len, head_dim)
    freqs_cis: (max_seq_len, head_dim // 2), complex
    """
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # Frequencies are indexed by sequence position (dim -2) and broadcast
    # over batch and heads: (1, 1, seq_len, head_dim // 2)
    seq_len = xq.shape[-2]
    freqs_cis = freqs_cis[:seq_len].view(1, 1, seq_len, -1)
    
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)


# =============================================================================
# BitNet Transformer Block with SubLN and ReLU²
# =============================================================================

class BitNetBlock(nn.Module):
    """
    BitNet Transformer Block following the BitNet b1.58 2B4T architecture:
    - Pre-norm residuals plus SubLN: an extra RMSNorm inside each sublayer,
      right before its output projection (o_proj / down_proj)
    - Squared ReLU (ReLU²) activation instead of SwiGLU
    - No bias anywhere
    """
    
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        max_seq_len: int = 4096,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.mlp_dim = int(hidden_dim * mlp_ratio)
        
        # === Attention ===
        self.q_proj = BitLinear(hidden_dim, hidden_dim)
        self.k_proj = BitLinear(hidden_dim, hidden_dim)
        self.v_proj = BitLinear(hidden_dim, hidden_dim)
        self.o_proj = BitLinear(hidden_dim, hidden_dim)
        
        # === FFN with Squared ReLU ===
        # The 2B4T report uses ReLU² instead of SwiGLU
        self.up_proj = BitLinear(hidden_dim, self.mlp_dim)
        self.down_proj = BitLinear(self.mlp_dim, hidden_dim)
        
        # === Pre-norms: normalize each sublayer's input ===
        self.attn_norm = RMSNorm(hidden_dim)
        self.ffn_norm = RMSNorm(hidden_dim)
        
        # === SubLN: extra norms right before the output projections ===
        self.attn_sub_norm = RMSNorm(hidden_dim)
        self.ffn_sub_norm = RMSNorm(self.mlp_dim)
        
        # RoPE
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(self.head_dim, max_seq_len),
            persistent=False
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        
        # === Attention sublayer: x + o_proj(SubNorm(Attn(Norm(x)))) ===
        h = self.attn_norm(x)
        # Shapes after transpose: (batch, num_heads, seq_len, head_dim)
        q = self.q_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Apply RoPE (the buffer follows the module's .to(device))
        q, k = apply_rotary_emb(q, k, self.freqs_cis)
        
        # Causal scaled dot-product attention (PyTorch >= 2.0)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        x = x + self.o_proj(self.attn_sub_norm(out))
        
        # === FFN sublayer: x + down_proj(SubNorm(ReLU²(up_proj(Norm(x))))) ===
        ffn_out = F.relu(self.up_proj(self.ffn_norm(x))) ** 2  # Squared ReLU
        x = x + self.down_proj(self.ffn_sub_norm(ffn_out))
        
        return x


# =============================================================================
# Full BitNet Model
# =============================================================================

class BitNetLM(nn.Module):
    """BitNet Language Model following Microsoft architecture."""
    
    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        max_seq_len: int = 4096,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        
        # Token embeddings (full precision)
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        
        # Transformer blocks
        self.layers = nn.ModuleList([
            BitNetBlock(hidden_dim, num_heads, mlp_ratio, max_seq_len)
            for _ in range(num_layers)
        ])
        
        # Final norm
        self.norm = RMSNorm(hidden_dim)
        
        # LM head (tied to embeddings)
        self.lm_head = None  # Will use embed_tokens.weight
        
        # Initialize embeddings
        nn.init.normal_(self.embed_tokens.weight, std=0.02)
    
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed_tokens(input_ids)
        
        for layer in self.layers:
            x = layer(x)
        
        x = self.norm(x)
        
        # Tied LM head
        logits = F.linear(x, self.embed_tokens.weight)
        
        return logits
    
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50, temperature: float = 0.7):
        """Simple greedy/sampling generation."""
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self.forward(input_ids)
                next_logits = logits[:, -1, :] / temperature
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids
    
    def num_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters())


print("BitNet architecture defined successfully!")
print("Key features:")
print("  - Squared ReLU (ReLU²) activation")
print("  - SubLN (extra RMSNorm before o_proj / down_proj)")
print("  - No bias anywhere")
print("  - Absmean weight quantization")
print("  - Per-token INT8 activation quantization")
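
RoPE only carries position information if each sequence position gets its own rotation. The framework-free sketch below applies the same pairwise rotation as `precompute_freqs_cis` / `apply_rotary_emb` (pairs `(2i, 2i+1)`, frequency `theta**(-2i/d)`) to single vectors and checks the property attention relies on: the dot product of a rotated query and key depends only on the offset between their positions. `rope_rotate` and `dot` are hypothetical helpers, and the vectors are arbitrary.

```python
import cmath


def rope_rotate(vec, pos, theta=10000.0):
    """Rotate consecutive pairs (vec[2i], vec[2i+1]) by pos * theta**(-2i/d)."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        freq = 1.0 / (theta ** (i / d))  # i steps by 2, matching arange(0, dim, 2) / dim
        z = complex(vec[i], vec[i + 1]) * cmath.exp(1j * pos * freq)
        out.extend([z.real, z.imag])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.1, 0.4, -0.7, 0.9]

# Same offset (2) at different absolute positions gives the same score
print(dot(rope_rotate(q, 5), rope_rotate(k, 3)))
print(dot(rope_rotate(q, 12), rope_rotate(k, 10)))
# Same position for query and key gives back the unrotated dot product
print(dot(rope_rotate(q, 7), rope_rotate(k, 7)), dot(q, k))
```

The last line is why the frequencies must be indexed by sequence position: if every position in a head received the same angle, the rotations would cancel in `q @ k.T` and the scores would carry no positional signal.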

In [None]:
# @title Data Loading

import json
import os
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm


class MultiFormatDataset(Dataset):
    """
    Dataset supporting multiple formats:
    - cognitive_kernel: {"prompt": ..., "response": ..., "type": "cognitive_kernel"}
    - swe_rebench: {"prompt": ..., "response": ...}
    - plain: {"text": ...}
    - messages: {"messages": [{"role": ..., "content": ...}, ...]}
    """
    
    def __init__(self, paths: list, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []
        
        for path in paths:
            if not os.path.exists(path):
                print(f"Warning: {path} not found")
                continue
            
            print(f"Loading: {path}")
            with open(path, 'r') as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                        text = self._extract_text(item)
                        if text and len(text) > 20:
                            self.samples.append(text)
                    except json.JSONDecodeError:
                        continue
        
        print(f"Loaded {len(self.samples)} samples")
    
    def _extract_text(self, item: dict) -> str:
        """Extract training text from various formats."""
        # Prompt/response format (cognitive_kernel, swe_rebench)
        if 'prompt' in item and 'response' in item:
            return f"### Instruction:\n{item['prompt']}\n\n### Response:\n{item['response']}"
        
        # Plain text
        if 'text' in item:
            return item['text']
        
        # Messages format
        if 'messages' in item:
            parts = []
            for msg in item['messages']:
                role = msg.get('role', 'user')
                content = msg.get('content', '')
                parts.append(f"### {role.title()}:\n{content}")
            return "\n\n".join(parts)
        
        return ""
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        text = self.samples[idx]
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )
        input_ids = tokens['input_ids'].squeeze(0)
        return {'input_ids': input_ids}


def load_from_huggingface(dataset_name: str, tokenizer, max_length: int, num_samples: int = 10000):
    """Load dataset from HuggingFace."""
    from datasets import load_dataset
    
    print(f"Loading from HuggingFace: {dataset_name}")
    dataset = load_dataset(dataset_name, split="train", streaming=True)
    
    samples = []
    for i, item in enumerate(tqdm(dataset, total=num_samples)):
        if i >= num_samples:
            break
        
        # Get text
        if 'text' in item:
            text = item['text']
        elif 'content' in item:
            text = item['content']
        else:
            continue
        
        if text and len(text) > 20:
            samples.append(text)
    
    print(f"Loaded {len(samples)} samples from HuggingFace")
    return samples


print("Data loading utilities defined!")

In [None]:
# @title Load Tokenizer and Create Model

from transformers import AutoTokenizer

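# Use the LLaMA-3 tokenizer (128k vocab), as in the 2B4T report.
# meta-llama repos are gated (license acceptance plus an HF token);
# fall back to a public LLaMA-2 tokenizer mirror if loading fails.
try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    print("Loaded LLaMA-3 tokenizer")
except Exception:
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
    print("Loaded LLaMA-2 tokenizer (fallback)")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# len(tokenizer) includes added special tokens; tokenizer.vocab_size does not,
# and LLaMA-3's BOS/EOS ids (128000+) would index past the embedding table.
VOCAB_SIZE = len(tokenizer)
print(f"Vocab size: {VOCAB_SIZE}")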

# Create model
model = BitNetLM(
    vocab_size=VOCAB_SIZE,
    hidden_dim=config['hidden_dim'],
    num_layers=config['num_layers'],
    num_heads=config['num_heads'],
    mlp_ratio=config['mlp_ratio'],
    max_seq_len=MAX_SEQ_LENGTH * 2,
).to(device)

num_params = model.num_parameters()
print(f"\nModel created: {num_params / 1e6:.1f}M parameters")
print(f"FP32 weights: {num_params * 4 / 1e9:.2f} GB (AdamW training needs ~4x this, plus activations)")
print(f"Ternary weights at 1.58 bits: {num_params * 1.58 / 8 / 1e9:.2f} GB (lower bound; embeddings stay full precision)")
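
The footprint lines above are rough. A slightly fuller estimate, written as a hypothetical helper: FP32 AdamW training keeps weights, gradients and two moment buffers (`exp_avg`, `exp_avg_sq`), so about 16 bytes per parameter before activations. At inference only the `BitLinear` weights are ternary, while the tied embedding stays in higher precision (assumed 16-bit here). Activation memory and packing overhead are ignored.

```python
def memory_estimate_bytes(total_params, embed_params, embed_bytes=2):
    """Rough memory estimate in bytes (activations and packing overhead ignored)."""
    fp32 = 4
    ternary_params = total_params - embed_params  # BitLinear weights (+ tiny norm gains)
    return {
        "fp32_weights": total_params * fp32,
        # weights + grads + AdamW exp_avg + exp_avg_sq, all FP32
        "adamw_training": total_params * fp32 * 4,
        # 1.58 bits per ternary weight (log2(3)); embedding kept at embed_bytes each
        "inference": ternary_params * 1.58 / 8 + embed_params * embed_bytes,
    }


# Example with the approximate 125M-preset counts (LLaMA-3 vocab assumed)
est = memory_estimate_bytes(183_435_264, 98_500_608)
print({k: f"{v / 1e9:.2f} GB" for k, v in est.items()})
```

For small presets with a 128k vocabulary, the full-precision embedding dominates the inference footprint, so the 1.58-bit figure alone understates it considerably.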

In [None]:
# @title Load Training Data

# Define data paths based on source
if DATA_SOURCE == "cognitive_kernel":
    data_paths = [
        "../data/distillation/cognitive_kernel_v2.jsonl",
        "../data/distillation/train_full_10k.jsonl",
    ]
    dataset = MultiFormatDataset(data_paths, tokenizer, MAX_SEQ_LENGTH)
    
elif DATA_SOURCE == "swe_rebench":
    data_paths = [
        "../data/distillation/swe_rebench_verified.jsonl",
        "../data/distillation/swe_rebench_10k.jsonl",
    ]
    dataset = MultiFormatDataset(data_paths, tokenizer, MAX_SEQ_LENGTH)
    
else:  # HuggingFace
    samples = load_from_huggingface(HF_DATASET, tokenizer, MAX_SEQ_LENGTH)
    
    # Create simple dataset
    class SimpleDataset(Dataset):
        def __init__(self, samples, tokenizer, max_length):
            self.samples = samples
            self.tokenizer = tokenizer
            self.max_length = max_length
        
        def __len__(self):
            return len(self.samples)
        
        def __getitem__(self, idx):
            tokens = self.tokenizer(
                self.samples[idx],
                truncation=True,
                max_length=self.max_length,
                padding='max_length',
                return_tensors='pt',
            )
            return {'input_ids': tokens['input_ids'].squeeze(0)}
    
    dataset = SimpleDataset(samples, tokenizer, MAX_SEQ_LENGTH)

# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(device == "cuda"),
    drop_last=True,
)

print(f"\nDataset size: {len(dataset)} samples")
print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {len(dataloader)}")

In [None]:
# @title Training Loop (weight decay disabled)

import time
import matplotlib.pyplot as plt

# === Weight decay ===
# The 2B4T report uses a two-stage schedule: a cosine schedule peaking at 0.1
# in the first stage, then weight decay disabled (0.0) in the second stage.
# This notebook keeps it at 0.0 for the whole run, on the premise that weight
# decay combined with ternary quantization risks weight collapse.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.0,  # disabled for the whole run
)

# Learning rate schedule: linear warmup, then cosine decay to 0
def get_lr(step, warmup_steps=100, max_steps=MAX_STEPS):
    if step < warmup_steps:
        return LEARNING_RATE * (step + 1) / warmup_steps  # step 0 gets a non-zero lr
    # Cosine decay (max() guards against max_steps <= warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return LEARNING_RATE * 0.5 * (1 + math.cos(math.pi * progress))


def check_weight_stats(model):
    """Check weight statistics to detect collapse."""
    stats = {}
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() >= 2:
            stats[name] = {
                'std': param.std().item(),
                'abs_mean': param.abs().mean().item(),
            }
    
    # Check for collapse (std < 0.001)
    collapsed = [k for k, v in stats.items() if v['std'] < 0.001]
    healthy = [k for k, v in stats.items() if v['std'] >= 0.001]
    
    return {
        'collapsed': len(collapsed),
        'healthy': len(healthy),
        'total': len(stats),
        'avg_std': sum(v['std'] for v in stats.values()) / len(stats) if stats else 0,
    }


# Training
print("\n" + "=" * 60)
print("TRAINING BITNET (Microsoft-style)")
print("=" * 60)
print(f"Steps: {MAX_STEPS}")
print(f"Learning rate: {LEARNING_RATE}")
print("Weight decay: 0.0 (disabled for this run)")
print("=" * 60)

model.train()
losses = []
data_iter = iter(dataloader)
start_time = time.time()

progress = tqdm(range(MAX_STEPS), desc="Training")
for step in progress:
    # Get batch
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(dataloader)
        batch = next(data_iter)
    
    input_ids = batch['input_ids'].to(device)
    
    # Update learning rate
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Forward
    logits = model(input_ids)
    
    # Compute loss (next token prediction)
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, VOCAB_SIZE),
        shift_labels.view(-1),
        ignore_index=tokenizer.pad_token_id,
    )
    
    # Backward
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    losses.append(loss.item())
    
    # Logging
    if (step + 1) % 50 == 0:
        avg_loss = sum(losses[-50:]) / len(losses[-50:])
        health = check_weight_stats(model)
        elapsed = time.time() - start_time
        
        progress.set_postfix({
            'loss': f"{avg_loss:.4f}",
            'lr': f"{lr:.2e}",
            'health': f"{health['healthy']}/{health['total']}",
        })

print("\nTraining complete!")
recent = losses[-100:]
print(f"Final loss (mean of last {len(recent)} steps): {sum(recent) / len(recent):.4f}")

# Check final weight health
final_health = check_weight_stats(model)
print(f"Weight health: {final_health['healthy']}/{final_health['total']} healthy")
print(f"Average std: {final_health['avg_std']:.6f}")

# Plot loss
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)

# Smoothed: trailing moving average over the last `window` steps
plt.subplot(1, 2, 2)
window = 50
smoothed = [
    sum(losses[max(0, i - window + 1):i + 1]) / min(i + 1, window)
    for i in range(len(losses))
]
plt.plot(smoothed)
plt.xlabel('Step')
plt.ylabel('Smoothed Loss')
plt.title('Smoothed Training Loss')
plt.grid(True)

plt.tight_layout()
plt.show()
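
For comparison with the constant `weight_decay=0.0` used above, the 2B4T report describes a two-stage weight-decay schedule: a cosine schedule peaking at 0.1 during the first stage, then weight decay disabled. The sketch below is a hypothetical reconstruction; the stage boundary (`stage1_frac`) and the choice to start at the peak and decay from there are assumptions, not values taken from the report.

```python
import math


def two_stage_weight_decay(step, total_steps, stage1_frac=0.5, peak=0.1):
    """Hypothetical two-stage schedule: cosine from `peak` in stage 1, then 0.0."""
    stage1_steps = max(1, int(total_steps * stage1_frac))
    if step >= stage1_steps:
        return 0.0                     # stage 2: weight decay disabled
    progress = step / stage1_steps     # goes 0 -> 1 across stage 1
    return peak * 0.5 * (1 + math.cos(math.pi * progress))


for s in (0, 250, 499, 500, 999):
    print(s, round(two_stage_weight_decay(s, 1000), 4))
```

To try it, set `param_group['weight_decay']` next to `param_group['lr']` inside the training loop; PyTorch's `AdamW` reads `weight_decay` from each param group on every `step()`.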

In [None]:
# @title Test Generation

print("\n" + "=" * 60)
print("GENERATION TEST")
print("=" * 60)

model.eval()

test_prompts = [
    "def fibonacci(n):",
    "### Instruction:\nWrite a function to reverse a string.\n\n### Response:",
    "The quick brown fox",
]

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    print("-" * 40)
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(input_ids, max_new_tokens=50, temperature=0.7)
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    print(f"Output: {output_text}")

In [None]:
# @title Save Model

save_path = f"bitnet_{MODEL_SIZE}_step{MAX_STEPS}.pt"

torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    'vocab_size': VOCAB_SIZE,
    'step': MAX_STEPS,
    'final_loss': sum(losses[-100:]) / len(losses[-100:]) if losses else 0,
}, save_path)

print(f"Model saved to: {save_path}")

# Download link (for Colab)
try:
    from google.colab import files
    files.download(save_path)
except ImportError:
    print("(Not in Colab, skipping download)")

## Summary

This notebook implements **proper BitNet training** following Microsoft's architecture:

### Key Differences from Standard LLMs

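1. **Squared ReLU (ReLU²)** instead of SwiGLU
   - Implemented here without a gate: a single up-projection followed by `relu(x) ** 2`
   - Chosen in the 2B4T report for its potential to improve sparsity in the 1-bit setting

2. **SubLN** on top of pre-norm residuals
   - `x = x + Out(SubNorm(F(Norm(x))))` instead of plain pre-norm `x = x + Out(F(Norm(x)))`
   - Reported to improve training stability, especially in quantized training

3. **No bias anywhere**
   - Follows LLaMA in removing all bias terms
   - Slightly fewer parameters

4. **Weight decay**
   - The 2B4T report uses two stages: a cosine schedule peaking at 0.1, then weight decay disabled (0.0)
   - This notebook disables it for the whole run, on the premise that decay plus ternary quantization risks weight collapse; `check_weight_stats` tracks weight std to catch this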
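
Post-norm and SubLN place the normalization differently, which is easiest to see on a single FFN sublayer. This framework-free sketch uses hypothetical helpers and toy weights, with the RMSNorm gains fixed at 1. With the output projection zeroed, the SubLN sublayer passes the residual stream through unchanged, while post-norm still rescales it; an untouched identity path is a common intuition for why pre-norm style residuals train more stably.

```python
import math


def rms_norm(v, eps=1e-6):
    """RMSNorm with the learned gain fixed to 1."""
    rms = math.sqrt(sum(t * t for t in v) / len(v) + eps)
    return [t / rms for t in v]


def relu2(v):
    return [max(t, 0.0) ** 2 for t in v]


def matvec(w, v):
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def ffn_post_norm(x, w_up, w_down):
    """Post-norm: Norm(x + W_down(ReLU²(W_up(x))))."""
    h = matvec(w_down, relu2(matvec(w_up, x)))
    return rms_norm([a + b for a, b in zip(x, h)])


def ffn_subln(x, w_up, w_down):
    """Pre-norm + SubLN: x + W_down(SubNorm(ReLU²(W_up(Norm(x)))))."""
    h = relu2(matvec(w_up, rms_norm(x)))
    return [a + b for a, b in zip(x, matvec(w_down, rms_norm(h)))]


x = [3.0, -1.0, 2.0]
w_up = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]  # 3 -> 4
w_down = [[0.0] * 4 for _ in range(3)]                # zeroed output projection
print(ffn_subln(x, w_up, w_down))      # residual stream unchanged
print(ffn_post_norm(x, w_up, w_down))  # rescaled to unit RMS
```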

### Training Pipeline

For production models, Microsoft uses:
1. **Pre-training**: 4T tokens, next-token prediction
2. **SFT**: Instruction tuning with sum loss aggregation
3. **DPO**: Alignment with human preferences
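
The SFT step's sum loss aggregation means the loss is summed over response tokens rather than averaged, so examples with more supervised tokens contribute proportionally more to the gradient. In PyTorch this corresponds to `F.cross_entropy(..., reduction="sum")` with prompt and padding positions set to `ignore_index`. The framework-free sketch below compares the two reductions on made-up per-token losses; `sft_loss` is a hypothetical helper, and any further normalization used in the 2B4T setup is not shown.

```python
def sft_loss(token_losses, loss_mask, reduction="sum"):
    """Aggregate per-token losses over response tokens only (mask == 1)."""
    kept = [loss for loss, m in zip(token_losses, loss_mask) if m]
    total = sum(kept)
    return total if reduction == "sum" else total / len(kept)


# Two made-up examples: prompt tokens masked out (0), response tokens kept (1)
short = ([2.0, 2.0, 1.0, 3.0], [0, 0, 1, 1])
long_ = ([2.0, 1.0, 3.0, 1.0, 3.0, 2.0], [0, 1, 1, 1, 1, 1])

for name, (token_losses, mask) in {"short": short, "long": long_}.items():
    print(name, "sum:", sft_loss(token_losses, mask), "mean:", sft_loss(token_losses, mask, "mean"))
```

Both examples have the same mean per-token loss, but under sum aggregation the longer response weighs 2.5 times as much.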

### References

- [BitNet Paper](https://arxiv.org/abs/2402.17764)
- [BitNet 2B4T Report](https://arxiv.org/abs/2504.12285)
- [Microsoft BitNet GitHub](https://github.com/microsoft/BitNet)