In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sudoku import SudokuGenerator, Sudoku
import math
import subprocess
import gc


In [None]:
# ============================================================================
# CONFIGURATION PARAMETERS
# ============================================================================
# Modify these parameters to experiment with different model configurations

# --- Model Architecture Parameters ---
HIDDEN_DIM = 128              # Network width (default: 256, reduced for memory)
NUM_LAYERS = 9                # Number of residual conv blocks (default: 6, reduced for memory)
KERNEL_SIZE = 3               # Conv kernel size
NUM_GROUPS = 8                # GroupNorm groups (must divide HIDDEN_DIM evenly)
NUM_TIMESTEPS = 81            # Fixed at 81 for Sudoku (9x9 grid)

# --- Embedding Parameters ---
USE_LEARNED_EMBEDDINGS = True  # Use learned embeddings from LLM model
EMBEDDING_MODEL_PATH = './sudoku2vec_trained_model.pt'  # Path to saved embedding model
EMBEDDING_DIM = 15            # Dimension of learned embeddings (from LLM model)

# --- Training Hyperparameters ---
DATASET_SIZE = 20000          # Number of diffusion sequences to pre-generate
NUM_EPOCHS = 4000             # Number of training epochs
BATCH_SIZE = 1024              # Batch size (reduced from 1024 for memory)
LEARNING_RATE = 1e-3          # Optimizer learning rate
WEIGHT_DECAY = 1e-4           # AdamW weight decay for regularization
GRAD_CLIP_MAX_NORM = 1.0      # Gradient clipping threshold

# --- Logging & Evaluation ---
LOG_INTERVAL = 10             # Log metrics every N epochs
EVAL_INTERVAL = 50           # Evaluate and sample every N epochs

# --- Diffusion Parameters ---
K_MAX = 6                     # Maximum number of forward steps for multi-step prediction loss (reduced from 10)

# --- Device Configuration ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Wandb & Checkpointing Configuration ---
USE_WANDB = True              # Enable Weights & Biases logging
WANDB_PROJECT = "sudoku-diffusion"  # Wandb project name
WANDB_ENTITY = None           # Wandb entity (None = default)
CHECKPOINT_DIR = "./checkpoints"  # Directory to save model checkpoints
CHECKPOINT_INTERVAL = 50      # Save checkpoint every N epochs
RESUME_FROM_CHECKPOINT = None  # Path to checkpoint to resume from (None = start fresh)


# --- Sudoku2Vec ---
ATTENTION_DIM = 9
N_HEADS = 9

print("Configuration loaded:")
print(f"  Model: hidden_dim={HIDDEN_DIM}, num_layers={NUM_LAYERS}, kernel_size={KERNEL_SIZE}")
print(f"  Embeddings: use_learned={USE_LEARNED_EMBEDDINGS}, embedding_dim={EMBEDDING_DIM}")
print(f"  Training: dataset_size={DATASET_SIZE}, epochs={NUM_EPOCHS}, batch_size={BATCH_SIZE}, lr={LEARNING_RATE}")
print(f"  Diffusion: k_max={K_MAX}")
print(f"  Device: {DEVICE}")
print(f"  Wandb: enabled={USE_WANDB}, project={WANDB_PROJECT}")
print(f"  Checkpointing: dir={CHECKPOINT_DIR}, interval={CHECKPOINT_INTERVAL}")
print("\n‚ö†Ô∏è  Memory optimizations applied:")
print(f"  - Reduced BATCH_SIZE to {BATCH_SIZE} (from 1024)")
print(f"  - Reduced K_MAX to {K_MAX} (from 10)")
print(f"  - Reduced DATASET_SIZE to {DATASET_SIZE} (from 10000)")
print("  - Added memory cleanup in training loop")
print("\n‚ö° Performance optimizations enabled:")
print("  - torch.compile() for JIT compilation (2-3x speedup)")
print("  - Mixed precision training (AMP) (1.5-2x speedup)")
print("  - Fused AdamW optimizer")
print("  - Optimized compute_loss() (removed unnecessary clones/ops)")
print("  - Expected combined speedup: 3-6x faster training")
print("\nüí° If you still get OOM errors, restart the kernel first!")


Configuration loaded:
  Model: hidden_dim=64, num_layers=9, kernel_size=3
  Embeddings: use_learned=True, embedding_dim=15
  Training: dataset_size=20000, epochs=4000, batch_size=1024, lr=0.001
  Diffusion: k_max=6
  Device: cuda

‚ö†Ô∏è  Memory optimizations applied:
  - Reduced BATCH_SIZE to 1024 (from 1024)
  - Reduced K_MAX to 6 (from 10)
  - Reduced DATASET_SIZE to 20000 (from 10000)
  - Added memory cleanup in training loop

‚ö° Performance optimizations enabled:
  - torch.compile() for JIT compilation (2-3x speedup)
  - Mixed precision training (AMP) (1.5-2x speedup)
  - Fused AdamW optimizer
  - Optimized compute_loss() (removed unnecessary clones/ops)
  - Expected combined speedup: 3-6x faster training

üí° If you still get OOM errors, restart the kernel first!


In [11]:
# ============================================================================
# LOAD EMBEDDING MODEL (from llm_on_sudoku.ipynb)
# ============================================================================
# We need the Sudoku2Vec class definition to load the trained model


class PositionalEncoding(nn.Module):
    """Positional encoding on unit circle for 9x9 Sudoku grid"""
    def __init__(self):
        super(PositionalEncoding, self).__init__()
        # Create a grid of positions (0-8 for both x and y)
        x_coords = torch.arange(0, 9).unsqueeze(0).repeat(9, 1)
        y_coords = torch.arange(0, 9).unsqueeze(1).repeat(1, 9)
        
        # Convert grid positions to linear indices (0-80)
        linear_indices = y_coords * 9 + x_coords  # shape: (9, 9)
        
        # Convert linear indices to angles on unit circle
        angles = 2 * math.pi * linear_indices / 81  # shape: (9, 9)
        
        # Compute x, y coordinates on unit circle
        x_circle = torch.cos(angles)
        y_circle = torch.sin(angles)
        
        # Stack and add batch dimension
        pos_encoding = torch.stack([x_circle, y_circle], dim=-1).unsqueeze(0)  # shape: (1, 9, 9, 2)
        self.register_buffer('pos_encoding', pos_encoding)
    
    def get_embedding_for_position(self, pos):
        # input (batch, 2) where pos[:, 0] is x and pos[:, 1] is y
        linear_indices = pos[:, 1] * 9 + pos[:, 0]  # shape: (batch,)
        angles = 2 * math.pi * linear_indices / 81  # shape: (batch,)
        x_circle = torch.cos(angles).unsqueeze(1)  # shape: (batch, 1)
        y_circle = torch.sin(angles).unsqueeze(1)  # shape: (batch, 1)
        return torch.cat([x_circle, y_circle], dim=1)  # shape: (batch, 2)
    
    def forward(self, x):
        # x is a (batch, 9, 9, embedding_dim) grid
        # output (batch, 9, 9, embedding_dim + 2) grid by adding pos_encoding to x
        batch_size = x.shape[0]
        pos_expanded = self.pos_encoding.repeat(batch_size, 1, 1, 1)
        return torch.cat([x, pos_expanded], dim=-1)

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
# Helper function to support different mask shapes.
# Output shape supports (batch_size, number of heads, seq length, seq length)
# If 2D: broadcasted over batch size and number of heads
# If 3D: broadcasted over number of heads
# If 4D: leave as is
def expand_mask(mask):
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask

class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # stack all weight matrices 1...h together for efficiency
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, input_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        if mask is not None:
            mask = expand_mask(mask)
        qkv = self.qkv_proj(x)

        # seperate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0,2,1,3) # [batch, head, seqlen, dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0,2,1,3) # [batch, seqlen, head, dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values) # [batch, seq_length, 81]

        if return_attention:
            return o, attention
        else:
            return o

class Sudoku2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim, attention_dim=ATTENTION_DIM, num_heads=N_HEADS, device='cpu'):
        super(Sudoku2Vec, self).__init__()
        self.device = device
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        self.pe = PositionalEncoding()
        self.embed = nn.Embedding(vocab_size, embedding_dim) # this will provide the key queries and values
        self.total_dim = self.embedding_dim + 2

        self.mha = MultiheadAttention(
            input_dim=self.total_dim,
            embed_dim=attention_dim,
            num_heads=num_heads
        )
        
        # Move model to device
        self.to(device)
    
    def get_embeddings(self):
        """
        Returns the learned token embeddings.
        
        Returns:
            torch.Tensor: Embedding weight matrix of shape [vocab_size, embedding_dim]
        """
        return self.embed.weight.detach()
    
    def get_embedding_for_token(self, token):
        """
        Get the embedding vector for a specific token.
        
        Args:
            token: Integer token ID or tensor of token IDs
            
        Returns:
            torch.Tensor: Embedding vector(s) for the given token(s)
        """
        if isinstance(token, int):
            token = torch.tensor([token], device=self.device)
        elif not isinstance(token, torch.Tensor):
            token = torch.tensor(token, device=self.device)
        return self.embed(token).detach()
    
    def save_model(self, filepath):
        """
        Save the model in a portable format that can be easily loaded.
        This saves the model architecture and weights in a single file.
        
        Args:
            filepath: Path where to save the model (should end with .pt or .pth)
        """
        save_dict = {
            'model_state_dict': self.state_dict(),
            'model_config': {
                'vocab_size': self.embed.num_embeddings,
                'embedding_dim': self.embedding_dim,
                'attention_dim': self.mha.embed_dim,
                'num_heads': self.num_heads,
            },
            'model_class': 'Sudoku2Vec',
        }
        torch.save(save_dict, filepath)
        print(f"Model saved to {filepath}")
    
    @classmethod
    def load_model(cls, filepath, device='cpu'):
        """
        Load a saved Sudoku2Vec model from file.
        
        Args:
            filepath: Path to the saved model file
            device: Device to load the model on ('cpu', 'cuda', 'mps')
            
        Returns:
            Sudoku2Vec: Loaded model instance
        """
        checkpoint = torch.load(filepath, map_location=device)
        
        # Extract configuration
        config = checkpoint['model_config']
        
        # Create model instance
        model = cls(
            vocab_size=config['vocab_size'],
            embedding_dim=config['embedding_dim'],
            attention_dim=config['attention_dim'],
            num_heads=config['num_heads'],
            device=device
        )
        
        # Load weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()  # Set to evaluation mode by default
        
        print(f"Model loaded from {filepath}")
        print(f"Configuration: vocab_size={config['vocab_size']}, "
              f"embedding_dim={config['embedding_dim']}, "
              f"attention_dim={config['attention_dim']}, "
              f"num_heads={config['num_heads']}")
        
        return model
    
    def forward(self, target, position, sudoku_grid, mask=True):
        # target - the token in the target blank space we try to predict shape [batch] i.e [0, 3, 3, 5, 1, ...]
        # position - the (x, y) position of the target shape [batch, 2] - [[1, 1], [0, 3], [7,7], ...]
        # sudoku_grid - the sudoku grid for the problem with target we want to predict shape [batch, 9, 9]
        batch_size = target.shape[0]
        
        target_token_embeddings = self.embed(target) # shape [batch, embedding_dim]
        target_position_vectors = self.pe.get_embedding_for_position(position) # [batch, 2]
        target_token_with_position = torch.cat([target_token_embeddings, target_position_vectors], dim=-1)  # shape [batch, total_dim]

        # mask the target in the grid
        sudoku_grid_masked = sudoku_grid
        if mask:
            batch_indices = torch.arange(sudoku_grid.shape[0], device=self.device)
            sudoku_grid_masked = sudoku_grid.clone()
            sudoku_grid_masked[batch_indices, position[:, 1], position[:, 0]] = 0 # 0 is a mask token aka blank
        
        masked_sudoku_grid_embeddings = self.embed(sudoku_grid_masked)
        masked_sudoku_grid_with_position = self.pe(masked_sudoku_grid_embeddings) # shape [batch, 9, 9, total_dim]
        # Reshape grid to sequence: [batch, 81, total_dim]
        masked_grid_seq = masked_sudoku_grid_with_position.view(batch_size, 81, self.total_dim)

        grid_seq_embeddings = self.embed(sudoku_grid)
        grid_seq_embeddings = grid_seq_embeddings.view(batch_size, 81, self.embedding_dim) 
        
        # Query from target token: [batch, 1, total_dim]
        # query = target_token_with_position.unsqueeze(1)

        output, attention = self.mha(masked_grid_seq, return_attention=True)
        # output is shape [batch, 81, total_dim]
        
        return output, attention, target_token_with_position, grid_seq_embeddings
    

In [None]:
class SudokuDiffusionDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for pre-generated diffusion sequences.
    
    Each item is a diffusion sequence of shape (82, 9, 9) where:
    - Index 0: completely masked grid (all zeros)
    - Index 81: completely solved grid
    - Indices 1-80: intermediate states with progressively more cells revealed
    """
    def __init__(self, sequences):
        """
        Args:
            sequences: Tensor of shape (dataset_size, 82, 9, 9) containing diffusion sequences
        """
        self.sequences = sequences
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        return self.sequences[idx]


class SudokuDiffusionModel(nn.Module):
    """
    Diffusion model for Sudoku puzzles inspired by DDPM.
    
    Forward process: Progressively mask cells from a complete sudoku (T=81) to empty grid (T=0)
    Reverse process: Learn to predict which cells to reveal to go from T to T+1
    
    The model learns to reverse the masking process, predicting which cell should be revealed
    at each timestep given the current partially revealed grid.
    """
    def __init__(self, hidden_dim=256, num_layers=6, kernel_size=3, num_groups=8, 
                 embedding_layer=None, device='cuda'):
        super().__init__()
        self.device = device
        self.num_timesteps = 81  # 81 cells in a sudoku grid
        self.embedding_layer = embedding_layer
        
        # Determine input channels based on whether we use embeddings
        if embedding_layer is not None:
            # Using learned embeddings: embedding_dim channels
            input_channels = embedding_layer.embedding_dim
            self.use_embeddings = True
        else:
            # Using simple normalization: 1 channel
            input_channels = 1
            self.use_embeddings = False
        
        # Time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Convolutional layers for processing the sudoku grid
        self.conv_in = nn.Conv2d(input_channels, hidden_dim, kernel_size=kernel_size, padding=kernel_size//2)
        
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size//2),
                nn.GroupNorm(num_groups, hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size//2),
                nn.GroupNorm(num_groups, hidden_dim),
                nn.SiLU()
            ) for _ in range(num_layers)
        ])
        
        # Output: dual heads for position and value prediction
        self.conv_out = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size//2)
        
        # Position head: which cell to reveal (81 possibilities)
        self.position_head = nn.Sequential(
            nn.Linear(hidden_dim * 9 * 9, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 81)  # Logits for 81 cells
        )
        
        # Value head: what value to place (10 classes: 0-9)
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim * 9 * 9, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 10)  # Logits for 10 classes
        )
        
    def forward(self, x, t):
        """
        Predict which cell should be revealed next and what value to place.
        
        Args:
            x: (batch, 9, 9) sudoku grids at timestep t (0 = masked, 1-9 = revealed)
            t: (batch,) timesteps (0 to 80)
            
        Returns:
            position_logits: (batch, 81) logits for which cell should be revealed next
            value_logits: (batch, 10) logits for what value (0-9) to place
        """
        batch_size = x.shape[0]
        
        # Process input: either use embeddings or simple normalization
        if self.use_embeddings:
            # Use learned embeddings: (batch, 9, 9) -> (batch, 9, 9, embedding_dim)
            x_embedded = self.embedding_layer(x.long())  # (batch, 9, 9, embedding_dim)
            x_norm = x_embedded.permute(0, 3, 1, 2)  # (batch, embedding_dim, 9, 9)
        else:
            # Simple normalization to [-1, 1] range
            x_norm = (x / 4.5) - 1.0
            x_norm = x_norm.unsqueeze(1)  # (batch, 1, 9, 9)
        
        # Time embedding
        t_norm = t.float().unsqueeze(1) / self.num_timesteps  # (batch, 1)
        t_emb = self.time_embed(t_norm)  # (batch, hidden_dim)
        
        # Process through conv layers
        h = self.conv_in(x_norm)  # (batch, hidden_dim, 9, 9)
        
        # Add time embedding to spatial features
        t_emb_spatial = t_emb.view(batch_size, -1, 1, 1).expand(-1, -1, 9, 9)
        h = h + t_emb_spatial
        
        # Apply conv blocks with residual connections
        for block in self.conv_blocks:
            h = h + block(h)
        
        # Output processing
        h = self.conv_out(h)  # (batch, hidden_dim, 9, 9)
        h_flat = h.reshape(batch_size, -1)  # (batch, hidden_dim * 81)
        
        # Dual predictions
        position_logits = self.position_head(h_flat)  # (batch, 81)
        value_logits = self.value_head(h_flat)  # (batch, 10)
        
        return position_logits, value_logits
    
    def compute_loss(self, sequences, k_max=10):
        """
        Compute the diffusion loss for training using K-step iterative prediction.
        
        The forward diffusion process (from sudoku.py) goes from empty (T=0) to complete (T=81).
        We learn to predict K steps ahead by iteratively applying the model.
        
        Args:
            sequences: (batch, 82, 9, 9) diffusion sequences where:
                      - sequences[:, 0] is completely masked (all zeros)
                      - sequences[:, 81] is completely solved
            k_max: Maximum number of forward steps for multi-step prediction
                      
        Returns:
            loss: scalar combined loss (position + value)
            accuracy: prediction accuracy for logging
        """
        batch_size = sequences.shape[0]
        
        # Sample random starting timestep B from [0, 81-k_max]
        max_start = max(1, self.num_timesteps - k_max)
        B = torch.randint(0, max_start, (batch_size,), device=self.device)
        
        # Sample random K from [1, k_max]
        K = torch.randint(1, k_max + 1, (batch_size,), device=self.device)
        
        # Get starting grids at timestep B (remove unnecessary clone)
        x_current = sequences[torch.arange(batch_size), B].float()  # (batch, 9, 9)
        
        # Track losses and accuracies across all K steps
        total_position_loss = 0.0
        total_value_loss = 0.0
        total_position_acc = 0.0
        total_value_acc = 0.0
        
        # Pre-allocate batch_indices outside loop
        batch_indices = torch.arange(batch_size, device=self.device)
        
        # Iteratively predict K steps
        for step in range(k_max):
            # Current timestep for each batch element
            t_current = B + step
            
            # Only compute loss for elements where step < K[i] and t_current < 81
            active_mask = (step < K) & (t_current < self.num_timesteps)
            
            if not active_mask.any():
                break
            
            # Get target grid at next timestep
            t_next = torch.clamp(t_current + 1, max=self.num_timesteps)
            x_target = sequences[batch_indices, t_next].float()  # (batch, 9, 9)
            
            # Find which cell was revealed (difference between current and target)
            diff = (x_target != x_current).view(batch_size, 81)  # (batch, 81)
            
            # Check if there's actually a difference (cell was revealed)
            has_diff = diff.any(dim=1)  # (batch,)
            active_mask = active_mask & has_diff  # Only process if there's a change
            
            if not active_mask.any():
                break
            
            target_position = diff.float().argmax(dim=1)  # (batch,)
            
            # Get target values at revealed positions (vectorized for efficiency)
            rows = target_position // 9
            cols = target_position % 9
            target_values = x_target[batch_indices, rows, cols].long()
            
            # Predict position and value
            position_logits, value_logits = self.forward(x_current, t_current)  # (batch, 81), (batch, 10)
            
            # Mask out already revealed cells in position prediction (in-place operation)
            already_revealed = (x_current.view(batch_size, 81) != 0)  # (batch, 81)
            position_logits.masked_fill_(already_revealed, float('-inf'))
            
            # Compute losses only for active batch elements
            if active_mask.any():
                position_loss = F.cross_entropy(position_logits[active_mask], target_position[active_mask], reduction='sum')
                value_loss = F.cross_entropy(value_logits[active_mask], target_values[active_mask], reduction='sum')
                
                # Accumulate losses (keep in computation graph for backprop)
                total_position_loss = total_position_loss + position_loss
                total_value_loss = total_value_loss + value_loss
                
                # Compute accuracy for logging
                with torch.no_grad():
                    pred_position = position_logits[active_mask].argmax(dim=1)
                    pred_value = value_logits[active_mask].argmax(dim=1)
                    total_position_acc += (pred_position == target_position[active_mask]).float().sum()
                    total_value_acc += (pred_value == target_values[active_mask]).float().sum()
            
            # Update x_current with ground truth for next iteration (remove unnecessary clone)
            x_current = x_target
        
        # Average losses over all active predictions
        num_predictions = K.float().sum()
        
        # Avoid division by zero
        if num_predictions == 0:
            return torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)
        
        total_position_loss = total_position_loss / num_predictions
        total_value_loss = total_value_loss / num_predictions
        
        # Combine losses
        loss = total_position_loss + total_value_loss
        
        # Average accuracy
        accuracy = (total_position_acc + total_value_acc) / (2 * num_predictions)
        
        return loss, accuracy
    
    @torch.no_grad()
    def sample(self, batch_size=1, return_trajectory=False):
        """
        Generate sudoku puzzles by running the reverse diffusion process.
        Start from empty grid (T=0) and progressively reveal cells to T=81.
        
        Args:
            batch_size: number of puzzles to generate
            return_trajectory: if True, return full trajectory of generation
            
        Returns:
            samples: (batch_size, 9, 9) generated sudoku grids
            trajectory: (batch_size, 82, 9, 9) if return_trajectory=True
        """
        # Start from completely masked grid (T=0)
        x = torch.zeros(batch_size, 9, 9, device=self.device)
        
        if return_trajectory:
            trajectory = torch.zeros(batch_size, 82, 9, 9, device=self.device)
            trajectory[:, 0] = x
        
        # Progressively reveal cells using model predictions
        for t in tqdm(range(self.num_timesteps), desc='Sampling'):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            
            # Predict position and value
            position_logits, value_logits = self.forward(x, t_batch)
            
            # Mask out already revealed cells
            already_revealed = (x.view(batch_size, 81) != 0)
            position_logits = position_logits.masked_fill(already_revealed, float('-inf'))
            
            # Sample or take argmax for position
            position_probs = F.softmax(position_logits, dim=-1)
            cell_idx = torch.multinomial(position_probs, 1).squeeze(-1)  # (batch,)
            
            # Take argmax for value (deterministic)
            value_probs = F.softmax(value_logits, dim=-1)
            values = torch.argmax(value_probs, dim=-1)  # (batch,)
            
            # Update grid with predicted values
            for b in range(batch_size):
                idx = cell_idx[b].item()
                row = idx // 9
                col = idx % 9
                x[b, row, col] = values[b]
            
            if return_trajectory:
                trajectory[:, t + 1] = x
        
        if return_trajectory:
            return x.long(), trajectory.long()
        return x.long()


def train_sudoku_diffusion(model, dataset, num_epochs=1000, batch_size=32, lr=1e-4, 
                           weight_decay=1e-4, grad_clip_max_norm=1.0,
                           log_interval=10, eval_interval=100, k_max=10, device='cuda',
                           use_wandb=False, wandb_project="sudoku-diffusion", wandb_entity=None,
                           checkpoint_dir="./checkpoints", checkpoint_interval=50, 
                           resume_from=None):
    """
    Train the sudoku diffusion model with proper logging and performance optimizations.
    
    Args:
        model: SudokuDiffusionModel instance
        dataset: SudokuDiffusionDataset instance with pre-generated sequences
        num_epochs: number of training epochs
        batch_size: batch size for training
        lr: learning rate
        weight_decay: weight decay for AdamW optimizer
        grad_clip_max_norm: max norm for gradient clipping
        log_interval: log metrics every N epochs
        eval_interval: evaluate and sample every N epochs
        k_max: maximum number of forward steps for multi-step prediction
        device: device to train on
        use_wandb: whether to log to Weights & Biases
        wandb_project: wandb project name
        wandb_entity: wandb entity (None = default)
        checkpoint_dir: directory to save checkpoints
        checkpoint_interval: save checkpoint every N epochs
        resume_from: path to checkpoint to resume from (None = start fresh)
    """
    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Create DataLoader for batching and shuffling with optimizations
    # Note: pin_memory should be False when data is already on GPU
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Keep 0 for GPU tensors
        pin_memory=False,  # Data is already on GPU, no need to pin
        persistent_workers=False  # No workers, so this doesn't apply
    )
    
    # Use fused optimizer for faster updates (requires CUDA)
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=lr, 
        weight_decay=weight_decay,
        fused=True if device == 'cuda' else False
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    # Initialize GradScaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler() if device == 'cuda' else None
    
    # Training metrics
    train_losses = []
    train_accuracies = []
    epoch_losses = []
    epoch_accuracies = []
    
    # Starting epoch
    start_epoch = 0
    
    # Resume from checkpoint if specified
    if resume_from is not None and os.path.exists(resume_from):
        print(f"Loading checkpoint from {resume_from}...")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if scaler is not None and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        epoch_losses = checkpoint.get('epoch_losses', [])
        epoch_accuracies = checkpoint.get('epoch_accuracies', [])
        print(f"‚úì Resumed from epoch {start_epoch}")
        print(f"  Previous best loss: {min(epoch_losses) if epoch_losses else 'N/A'}")
    
    # Initialize wandb
    if use_wandb:
        wandb_config = {
            "hidden_dim": model.conv_in.out_channels,
            "num_layers": len(model.conv_blocks),
            "batch_size": batch_size,
            "learning_rate": lr,
            "weight_decay": weight_decay,
            "num_epochs": num_epochs,
            "k_max": k_max,
            "dataset_size": len(dataset),
        }
        if start_epoch == 0:
            wandb.init(project=wandb_project, entity=wandb_entity, config=wandb_config)
        else:
            # Resume wandb run if checkpoint has run_id
            run_id = checkpoint.get('wandb_run_id', None)
            wandb.init(project=wandb_project, entity=wandb_entity, config=wandb_config, 
                      id=run_id, resume="allow")
        print(f"‚úì Wandb initialized: {wandb.run.name}")
    
    model.train()
    if start_epoch == 0:
        print(f"Starting training for {num_epochs} epochs...")
    else:
        print(f"Resuming training from epoch {start_epoch} to {num_epochs}...")
    print(f"Dataset size: {len(dataset)}, Batch size: {batch_size}, Batches per epoch: {len(dataloader)}")
    print(f"Learning rate: {lr}")
    print(f"Using mixed precision: {scaler is not None}")
    print("-" * 60)
    
    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0
        
        # Iterate through batches in the dataset
        for batch_sequences in dataloader:
            # Mixed precision training
            if scaler is not None:
                # Forward pass with autocast
                with torch.cuda.amp.autocast():
                    loss, accuracy = model.compute_loss(batch_sequences, k_max=k_max)
                
                # Backward pass with gradient scaling
                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard training (CPU or non-CUDA)
                loss, accuracy = model.compute_loss(batch_sequences, k_max=k_max)
                
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
                optimizer.step()
            
            # Accumulate metrics
            epoch_loss += loss.item()
            epoch_acc += accuracy.item()
            num_batches += 1
        
        # Average metrics over all batches in the epoch
        avg_epoch_loss = epoch_loss / num_batches
        avg_epoch_acc = epoch_acc / num_batches
        
        # Record metrics
        epoch_losses.append(avg_epoch_loss)
        epoch_accuracies.append(avg_epoch_acc)
        
        # Update learning rate scheduler
        scheduler.step()
        
        # Log to wandb
        if use_wandb:
            wandb.log({
                "epoch": epoch + 1,
                "train/loss": avg_epoch_loss,
                "train/accuracy": avg_epoch_acc,
                "train/learning_rate": scheduler.get_last_lr()[0],
            })
        
        # Logging
        if (epoch + 1) % log_interval == 0:
            current_lr = scheduler.get_last_lr()[0]
            print(f"Epoch {epoch + 1:4d}/{num_epochs} | "
                  f"Loss: {avg_epoch_loss:.4f} | "
                  f"Acc: {avg_epoch_acc:.4f} | "
                  f"LR: {current_lr:.6f}")
        
        # Evaluation and sampling
        if (epoch + 1) % eval_interval == 0:
            print(f"\n{'='*60}")
            print(f"Evaluation at epoch {epoch + 1}")
            print(f"{'='*60}")
            
            model.eval()
            with torch.no_grad():
                # Generate a sample
                sample = model.sample(batch_size=1)
                print("\nGenerated Sudoku:")
                print(sample[0].cpu().numpy())
                
                # Check if valid
                sudoku_obj = Sudoku(sample[0].cpu().numpy(), backend='numpy')
                is_valid = sudoku_obj.is_valid()
                print(f"\nIs valid: {is_valid}")
                
                # Log to wandb
                if use_wandb:
                    wandb.log({
                        "eval/is_valid": int(is_valid),
                        "eval/sample": wandb.Table(
                            data=[[str(sample[0].cpu().numpy())]], 
                            columns=["sudoku_grid"]
                        )
                    })
            
            model.train()
            print(f"{'='*60}\n")
        
        # Save checkpoint
        if (epoch + 1) % checkpoint_interval == 0 or (epoch + 1) == num_epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pt")
            checkpoint_data = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch_losses': epoch_losses,
                'epoch_accuracies': epoch_accuracies,
                'config': {
                    'batch_size': batch_size,
                    'learning_rate': lr,
                    'weight_decay': weight_decay,
                    'k_max': k_max,
                }
            }
            if scaler is not None:
                checkpoint_data['scaler_state_dict'] = scaler.state_dict()
            if use_wandb:
                checkpoint_data['wandb_run_id'] = wandb.run.id
            
            torch.save(checkpoint_data, checkpoint_path)
            print(f"‚úì Checkpoint saved: {checkpoint_path}")
            
            # Also save as "latest" for easy resumption
            latest_path = os.path.join(checkpoint_dir, "checkpoint_latest.pt")
            torch.save(checkpoint_data, latest_path)
            print(f"‚úì Latest checkpoint updated: {latest_path}")
    
    print("\nTraining completed!")
    
    # Finish wandb run
    if use_wandb:
        wandb.finish()
        print("‚úì Wandb run finished")
    
    # Plot training curves
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    ax1.plot(epoch_losses)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss')
    ax1.grid(True)
    
    ax2.plot(epoch_accuracies)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training Accuracy')
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    return model, epoch_losses, epoch_accuracies


In [13]:
# Check GPU information and clear memory

if torch.cuda.is_available():
    print("GPU Information:")
    print(f"  Device: {torch.cuda.get_device_name(0)}")
    print(f"  Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    print(f"\nInitial GPU Memory Usage:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Cached: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
    
    # Aggressively clear GPU memory
    print("\n‚ö†Ô∏è  Clearing GPU memory...")
    
    # Delete all variables in the current namespace that might hold GPU tensors
    if 'model' in dir():
        del model
    if 'generator' in dir():
        del generator
    if 'sequences' in dir():
        del sequences
    
    # Force garbage collection
    gc.collect()
    
    # Clear PyTorch cache
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    
    # Reset peak memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    
    print(f"\nAfter clearing:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Cached: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
    print(f"  Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1024**3:.2f} GB")
    
    # If still not enough memory, suggest kernel restart
    free_memory = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1024**3
    if free_memory < 1.0:
        print("\n‚ö†Ô∏è  WARNING: Very little GPU memory available!")
        print("   Consider: Kernel -> Restart Kernel to fully clear GPU memory")
else:
    print("CUDA not available, will use CPU")


GPU Information:
  Device: NVIDIA RTX A4000
  Total Memory: 15.63 GB

Initial GPU Memory Usage:
  Allocated: 0.54 GB
  Cached: 3.44 GB

‚ö†Ô∏è  Clearing GPU memory...

After clearing:
  Allocated: 0.54 GB
  Cached: 0.58 GB
  Free: 15.05 GB


In [14]:
# ============================================================================
# CELL 5: LOAD EMBEDDING MODEL
# ============================================================================
# Run this cell once per session to load the pre-trained Sudoku2Vec embeddings.
# This is fast (~1 second) and only needs to run once.

print(f"Using device: {DEVICE}")

# Clear GPU memory if using CUDA
if DEVICE == 'cuda':
    torch.cuda.empty_cache()
    print(f"\nGPU Memory before setup:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
    print(f"  Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1024**3:.2f} GB")

# Load embedding model if configured
embedding_layer = None
if USE_LEARNED_EMBEDDINGS:
    print(f"\n{'='*60}")
    print("Loading learned embeddings from LLM model...")
    print(f"{'='*60}")
    try:
        sudoku2vec_model = Sudoku2Vec.load_model(EMBEDDING_MODEL_PATH, device=DEVICE)
        embedding_layer = sudoku2vec_model.embed
        print(f"‚úì Successfully loaded embedding layer with {embedding_layer.num_embeddings} tokens")
        print(f"  Embedding dimension: {embedding_layer.embedding_dim}")
    except FileNotFoundError:
        print(f"‚ö†Ô∏è  WARNING: Embedding model not found at {EMBEDDING_MODEL_PATH}")
        print("   Falling back to simple normalization")
        USE_LEARNED_EMBEDDINGS = False
    except Exception as e:
        print(f"‚ö†Ô∏è  WARNING: Failed to load embedding model: {e}")
        print("   Falling back to simple normalization")
        USE_LEARNED_EMBEDDINGS = False
else:
    print("\nUsing simple normalization (no learned embeddings)")

print("\n‚úì Embedding layer ready!")


Using device: cuda

GPU Memory before setup:
  Allocated: 0.54 GB
  Reserved: 0.58 GB
  Free: 15.05 GB

Loading learned embeddings from LLM model...
Model loaded from ./sudoku2vec_trained_model.pt
Configuration: vocab_size=10, embedding_dim=15, attention_dim=9, num_heads=9
‚úì Successfully loaded embedding layer with 10 tokens
  Embedding dimension: 15

‚úì Embedding layer ready!


In [15]:
# ============================================================================
# CELL 6: GENERATE TRAINING DATASET
# ============================================================================
# Run this cell ONCE to generate the training dataset.
# This is SLOW (~4 minutes for 20k sequences) but you only need to run it once!
# 
# After running this cell, you can:
# - Rerun Cell 7 to try different model architectures
# - Rerun Cell 8 to try different training hyperparameters
# - All without regenerating this expensive dataset!

print(f"\n{'='*60}")
print(f"Pre-generating training dataset...")
print(f"{'='*60}")

# Initialize generator
generator = SudokuGenerator(backend='torch', device=DEVICE)

# Generate diffusion sequences
print(f"Generating {DATASET_SIZE} diffusion sequences...")
import time
start_time = time.time()

sequences = generator.generate_diffusion_sequence(size=DATASET_SIZE)

generation_time = time.time() - start_time
print(f"‚úì Dataset generated in {generation_time:.2f} seconds ({generation_time/DATASET_SIZE*1000:.2f} ms per sequence)")

# Report dataset statistics
sequence_shape = sequences.shape
memory_mb = sequences.element_size() * sequences.nelement() / (1024 ** 2)
print(f"\nDataset statistics:")
print(f"  Shape: {sequence_shape}")
print(f"  Memory: {memory_mb:.2f} MB")
print(f"  Device: {sequences.device}")

# Create PyTorch Dataset
train_dataset = SudokuDiffusionDataset(sequences)
print(f"‚úì Created SudokuDiffusionDataset with {len(train_dataset)} sequences")
print(f"\nüí° Dataset is ready! You can now rerun Cells 7 & 8 with different hyperparameters.")



Pre-generating training dataset...
Generating 20000 diffusion sequences...
‚úì Dataset generated in 262.51 seconds (13.13 ms per sequence)

Dataset statistics:
  Shape: torch.Size([20000, 82, 9, 9])
  Memory: 506.74 MB
  Device: cuda:0
‚úì Created SudokuDiffusionDataset with 20000 sequences

üí° Dataset is ready! You can now rerun Cells 7 & 8 with different hyperparameters.


In [16]:
# ============================================================================
# CELL 7: INITIALIZE DIFFUSION MODEL
# ============================================================================
# Run this cell to create and compile the diffusion model.
# This is FAST (~5 seconds) and you can rerun it to experiment with:
# - Different model sizes (HIDDEN_DIM, NUM_LAYERS)
# - Different architectures (KERNEL_SIZE, NUM_GROUPS)
# 
# The dataset from Cell 6 will be reused!

print(f"\n{'='*60}")
print("Initializing Diffusion Model...")
print(f"{'='*60}")

model = SudokuDiffusionModel(
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    kernel_size=KERNEL_SIZE,
    num_groups=NUM_GROUPS,
    embedding_layer=embedding_layer,
    device=DEVICE
).to(DEVICE)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"‚úì Model initialized successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Using embeddings: {model.use_embeddings}")

# Compile model for faster training (PyTorch 2.0+)
if DEVICE == 'cuda':
    print(f"\n‚ö° Compiling model with torch.compile for faster training...")
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print(f"‚úì Model compiled successfully")
    except Exception as e:
        print(f"‚ö†Ô∏è  Could not compile model: {e}")
        print(f"   Continuing without compilation")

if DEVICE == 'cuda':
    print(f"\nGPU Memory after model loading:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

print(f"\n‚úì Model is ready for training!")



Initializing Diffusion Model...
‚úì Model initialized successfully
  Total parameters: 1,386,673
  Trainable parameters: 1,386,673
  Using embeddings: True

‚ö° Compiling model with torch.compile for faster training...
‚úì Model compiled successfully

GPU Memory after model loading:
  Allocated: 1.04 GB
  Reserved: 1.07 GB

‚úì Model is ready for training!


In [None]:
# ============================================================================
# CELL 8: TRAIN THE MODEL
# ============================================================================
# Run this cell to train the model with the current hyperparameters.
# You can rerun this cell to experiment with different:
# - Learning rates (LEARNING_RATE)
# - Batch sizes (BATCH_SIZE)
# - Training strategies (K_MAX, WEIGHT_DECAY, GRAD_CLIP_MAX_NORM)
# - Logging intervals (LOG_INTERVAL, EVAL_INTERVAL)
# 
# The model from Cell 7 and dataset from Cell 6 will be used!

print(f"\n{'='*60}")
print("Starting Training...")
print(f"{'='*60}")

model, losses, accuracies = train_sudoku_diffusion(
    model=model,
    dataset=train_dataset,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    grad_clip_max_norm=GRAD_CLIP_MAX_NORM,
    log_interval=LOG_INTERVAL,
    eval_interval=EVAL_INTERVAL,
    k_max=K_MAX,
    device=DEVICE,
    use_wandb=USE_WANDB,
    wandb_project=WANDB_PROJECT,
    wandb_entity=WANDB_ENTITY,
    checkpoint_dir=CHECKPOINT_DIR,
    checkpoint_interval=CHECKPOINT_INTERVAL,
    resume_from=RESUME_FROM_CHECKPOINT
)



Starting Training...
Starting training for 4000 epochs...
Dataset size: 20000, Batch size: 1024, Batches per epoch: 20
Learning rate: 0.001
Using mixed precision: True
------------------------------------------------------------


  scaler = torch.cuda.amp.GradScaler() if device == 'cuda' else None
  with torch.cuda.amp.autocast():


Epoch   10/4000 | Loss: 5.7406 | Acc: 0.0849 | LR: 0.001000
Epoch   20/4000 | Loss: 5.7233 | Acc: 0.0941 | LR: 0.001000
Epoch   30/4000 | Loss: 5.7025 | Acc: 0.1001 | LR: 0.001000
Epoch   40/4000 | Loss: 5.7024 | Acc: 0.1018 | LR: 0.001000
Epoch   50/4000 | Loss: 5.6892 | Acc: 0.1047 | LR: 0.001000

Evaluation at epoch 50


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 304.37it/s]



Generated Sudoku:
[[7 8 6 9 6 9 1 1 9]
 [2 3 1 7 3 6 7 1 7]
 [5 7 5 5 1 6 4 1 6]
 [7 5 4 8 6 8 9 6 5]
 [3 7 9 4 5 7 7 3 6]
 [7 7 1 3 3 7 8 9 6]
 [9 4 8 2 5 2 4 2 3]
 [6 4 3 6 3 4 6 5 8]
 [2 9 4 7 1 5 8 2 9]]

Is valid: False

Epoch   60/4000 | Loss: 5.6712 | Acc: 0.1076 | LR: 0.000999
Epoch   70/4000 | Loss: 5.6674 | Acc: 0.1077 | LR: 0.000999
Epoch   80/4000 | Loss: 5.6630 | Acc: 0.1081 | LR: 0.000999
Epoch   90/4000 | Loss: 5.6691 | Acc: 0.1079 | LR: 0.000999
Epoch  100/4000 | Loss: 5.6532 | Acc: 0.1098 | LR: 0.000998

Evaluation at epoch 100


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 322.67it/s]



Generated Sudoku:
[[2 2 1 9 6 6 6 2 7]
 [6 3 6 4 1 1 4 5 1]
 [8 4 6 8 3 2 5 3 7]
 [6 3 1 7 2 5 7 5 5]
 [5 7 4 7 8 5 8 1 5]
 [6 9 9 7 2 4 6 3 6]
 [1 7 4 1 2 8 9 5 3]
 [7 8 9 8 9 4 9 9 2]
 [2 3 6 9 8 3 5 3 7]]

Is valid: False

Epoch  110/4000 | Loss: 5.6550 | Acc: 0.1084 | LR: 0.000998
Epoch  120/4000 | Loss: 5.6514 | Acc: 0.1083 | LR: 0.000998
Epoch  130/4000 | Loss: 5.6371 | Acc: 0.1113 | LR: 0.000997
Epoch  140/4000 | Loss: 5.6469 | Acc: 0.1091 | LR: 0.000997
Epoch  150/4000 | Loss: 5.6456 | Acc: 0.1097 | LR: 0.000997

Evaluation at epoch 150


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 310.68it/s]



Generated Sudoku:
[[7 9 3 9 7 3 8 1 1]
 [2 4 3 3 7 1 8 5 7]
 [5 2 3 7 1 4 9 1 6]
 [3 2 8 5 5 5 2 6 3]
 [8 1 6 5 7 2 7 9 4]
 [3 6 4 8 2 2 9 6 8]
 [5 8 2 6 9 1 4 7 1]
 [4 2 7 9 9 6 8 4 6]
 [5 6 4 9 4 3 5 8 4]]

Is valid: False

Epoch  160/4000 | Loss: 5.6405 | Acc: 0.1115 | LR: 0.000996
Epoch  170/4000 | Loss: 5.6294 | Acc: 0.1121 | LR: 0.000996
Epoch  180/4000 | Loss: 5.6372 | Acc: 0.1119 | LR: 0.000995
Epoch  190/4000 | Loss: 5.6313 | Acc: 0.1126 | LR: 0.000994
Epoch  200/4000 | Loss: 5.6402 | Acc: 0.1120 | LR: 0.000994

Evaluation at epoch 200


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 314.93it/s]



Generated Sudoku:
[[4 1 9 1 4 9 7 9 6]
 [6 6 5 6 1 9 2 9 6]
 [7 4 7 9 8 2 8 4 1]
 [4 5 7 8 6 9 3 2 7]
 [1 3 8 4 8 7 8 5 7]
 [4 5 5 8 9 2 4 4 1]
 [7 6 5 2 5 7 3 5 6]
 [3 1 3 2 2 7 3 6 8]
 [1 5 3 3 1 2 1 6 3]]

Is valid: False

Epoch  210/4000 | Loss: 5.6367 | Acc: 0.1126 | LR: 0.000993
Epoch  220/4000 | Loss: 5.6234 | Acc: 0.1144 | LR: 0.000993
Epoch  230/4000 | Loss: 5.6316 | Acc: 0.1134 | LR: 0.000992
Epoch  240/4000 | Loss: 5.6276 | Acc: 0.1158 | LR: 0.000991
Epoch  250/4000 | Loss: 5.6203 | Acc: 0.1162 | LR: 0.000990

Evaluation at epoch 250


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 311.69it/s]



Generated Sudoku:
[[9 2 7 5 1 3 9 1 9]
 [8 2 5 7 3 3 4 3 4]
 [3 6 8 8 9 1 2 6 4]
 [1 1 8 5 5 1 3 8 5]
 [5 5 4 6 7 1 4 2 9]
 [6 2 8 3 3 9 2 3 8]
 [6 4 7 2 4 6 6 7 8]
 [7 9 9 2 1 4 4 7 1]
 [5 9 7 2 7 8 6 9 5]]

Is valid: False

Epoch  260/4000 | Loss: 5.6331 | Acc: 0.1162 | LR: 0.000990
Epoch  270/4000 | Loss: 5.6262 | Acc: 0.1164 | LR: 0.000989
Epoch  280/4000 | Loss: 5.6305 | Acc: 0.1192 | LR: 0.000988
Epoch  290/4000 | Loss: 5.6133 | Acc: 0.1195 | LR: 0.000987
Epoch  300/4000 | Loss: 5.6106 | Acc: 0.1209 | LR: 0.000986

Evaluation at epoch 300


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 314.46it/s]



Generated Sudoku:
[[8 7 3 3 4 8 7 1 1]
 [6 4 2 1 3 1 4 9 1]
 [2 5 8 9 3 7 4 6 8]
 [5 8 8 2 1 2 3 7 8]
 [6 7 9 2 4 9 4 1 6]
 [6 3 5 6 4 5 6 2 7]
 [5 2 4 3 5 9 6 9 7]
 [7 9 5 8 9 5 8 5 7]
 [3 1 1 6 1 2 7 3 2]]

Is valid: False

Epoch  310/4000 | Loss: 5.6196 | Acc: 0.1205 | LR: 0.000985
Epoch  320/4000 | Loss: 5.6047 | Acc: 0.1226 | LR: 0.000984
Epoch  330/4000 | Loss: 5.5877 | Acc: 0.1244 | LR: 0.000983
Epoch  340/4000 | Loss: 5.5990 | Acc: 0.1250 | LR: 0.000982
Epoch  350/4000 | Loss: 5.5935 | Acc: 0.1259 | LR: 0.000981

Evaluation at epoch 350


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 316.32it/s]



Generated Sudoku:
[[3 3 4 7 8 4 9 9 9]
 [3 2 5 8 5 7 3 4 6]
 [3 7 9 6 1 8 7 3 6]
 [8 6 4 2 3 1 2 1 1]
 [7 9 7 8 2 8 4 3 9]
 [4 7 4 1 2 5 2 5 7]
 [5 7 4 6 8 5 8 5 1]
 [1 6 3 8 2 5 9 1 9]
 [1 9 2 5 2 6 6 3 9]]

Is valid: False

Epoch  360/4000 | Loss: 5.5967 | Acc: 0.1265 | LR: 0.000980
Epoch  370/4000 | Loss: 5.5826 | Acc: 0.1265 | LR: 0.000979
Epoch  380/4000 | Loss: 5.5867 | Acc: 0.1271 | LR: 0.000978
Epoch  390/4000 | Loss: 5.5739 | Acc: 0.1296 | LR: 0.000977
Epoch  400/4000 | Loss: 5.5669 | Acc: 0.1308 | LR: 0.000976

Evaluation at epoch 400


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 313.33it/s]



Generated Sudoku:
[[4 9 1 4 7 4 6 5 2]
 [1 7 3 9 1 3 6 2 4]
 [2 2 7 8 8 8 8 5 6]
 [1 9 9 3 2 4 6 7 2]
 [5 2 1 2 5 9 8 6 2]
 [3 6 5 7 7 3 1 5 3]
 [5 5 4 8 4 7 9 1 4]
 [3 5 3 3 1 6 6 4 8]
 [6 1 9 5 7 2 9 8 9]]

Is valid: False

Epoch  410/4000 | Loss: 5.5813 | Acc: 0.1335 | LR: 0.000974
Epoch  420/4000 | Loss: 5.5682 | Acc: 0.1323 | LR: 0.000973
Epoch  430/4000 | Loss: 5.5570 | Acc: 0.1340 | LR: 0.000972
Epoch  440/4000 | Loss: 5.5390 | Acc: 0.1354 | LR: 0.000970
Epoch  450/4000 | Loss: 5.5479 | Acc: 0.1373 | LR: 0.000969

Evaluation at epoch 450


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 318.81it/s]



Generated Sudoku:
[[7 1 2 4 1 9 8 6 2]
 [1 2 7 2 3 4 5 5 1]
 [7 8 1 2 9 3 9 1 7]
 [5 3 3 5 8 7 7 7 6]
 [8 2 6 8 3 6 5 2 1]
 [5 4 4 9 6 5 4 3 4]
 [8 3 9 5 7 5 6 5 1]
 [6 9 1 2 9 9 4 4 8]
 [2 1 3 7 6 3 9 6 2]]

Is valid: False

Epoch  460/4000 | Loss: 5.5398 | Acc: 0.1401 | LR: 0.000968
Epoch  470/4000 | Loss: 5.5263 | Acc: 0.1392 | LR: 0.000966
Epoch  480/4000 | Loss: 5.5302 | Acc: 0.1402 | LR: 0.000965
Epoch  490/4000 | Loss: 5.5264 | Acc: 0.1421 | LR: 0.000963
Epoch  500/4000 | Loss: 5.5039 | Acc: 0.1449 | LR: 0.000962

Evaluation at epoch 500


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 323.01it/s]



Generated Sudoku:
[[8 2 1 7 5 7 5 3 7]
 [5 1 5 2 7 3 2 5 3]
 [5 4 2 9 1 8 4 5 2]
 [2 8 9 8 3 9 7 6 9]
 [7 6 4 4 8 7 9 3 6]
 [2 1 9 9 7 5 9 5 2]
 [6 6 1 1 3 4 8 6 6]
 [9 1 4 4 3 8 3 2 4]
 [7 7 3 4 1 8 6 8 6]]

Is valid: False

Epoch  510/4000 | Loss: 5.5139 | Acc: 0.1454 | LR: 0.000960
Epoch  520/4000 | Loss: 5.5264 | Acc: 0.1457 | LR: 0.000959
Epoch  530/4000 | Loss: 5.5002 | Acc: 0.1472 | LR: 0.000957
Epoch  540/4000 | Loss: 5.4960 | Acc: 0.1499 | LR: 0.000956
Epoch  550/4000 | Loss: 5.4949 | Acc: 0.1516 | LR: 0.000954

Evaluation at epoch 550


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 320.61it/s]



Generated Sudoku:
[[8 3 5 3 9 2 2 7 6]
 [8 4 8 7 3 2 7 2 1]
 [8 9 6 1 4 2 3 8 2]
 [9 3 5 6 7 9 8 5 2]
 [1 6 9 9 4 7 4 5 6]
 [8 7 9 4 3 7 4 5 9]
 [5 1 6 8 3 9 1 2 7]
 [1 4 7 1 1 8 2 5 6]
 [8 5 6 6 1 3 2 3 5]]

Is valid: False

Epoch  560/4000 | Loss: 5.4898 | Acc: 0.1514 | LR: 0.000952
Epoch  570/4000 | Loss: 5.4803 | Acc: 0.1515 | LR: 0.000951
Epoch  580/4000 | Loss: 5.4719 | Acc: 0.1547 | LR: 0.000949
Epoch  590/4000 | Loss: 5.4818 | Acc: 0.1531 | LR: 0.000947
Epoch  600/4000 | Loss: 5.4679 | Acc: 0.1559 | LR: 0.000946

Evaluation at epoch 600


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 326.52it/s]



Generated Sudoku:
[[7 9 4 9 6 1 3 5 3]
 [7 2 2 7 4 7 7 4 5]
 [1 8 6 2 1 5 6 7 7]
 [5 4 5 4 9 8 9 3 1]
 [4 7 8 9 2 1 6 3 1]
 [3 3 1 5 8 4 5 8 2]
 [6 7 6 1 2 9 6 6 2]
 [9 2 3 8 9 3 8 2 8]
 [5 9 9 1 4 3 1 7 5]]

Is valid: False

Epoch  610/4000 | Loss: 5.4717 | Acc: 0.1544 | LR: 0.000944
Epoch  620/4000 | Loss: 5.4517 | Acc: 0.1583 | LR: 0.000942
Epoch  630/4000 | Loss: 5.4586 | Acc: 0.1585 | LR: 0.000940
Epoch  640/4000 | Loss: 5.4290 | Acc: 0.1621 | LR: 0.000938
Epoch  650/4000 | Loss: 5.4496 | Acc: 0.1609 | LR: 0.000936

Evaluation at epoch 650


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 318.54it/s]



Generated Sudoku:
[[3 5 2 4 1 9 2 6 8]
 [5 9 7 7 4 6 3 5 9]
 [2 5 1 8 8 3 3 3 9]
 [3 1 8 2 9 2 6 9 9]
 [8 5 8 1 6 4 5 8 6]
 [9 6 3 7 7 8 7 5 2]
 [4 2 2 1 7 4 6 9 1]
 [7 5 1 4 3 5 8 4 4]
 [6 2 3 7 1 2 7 9 4]]

Is valid: False

Epoch  660/4000 | Loss: 5.4484 | Acc: 0.1612 | LR: 0.000934
Epoch  670/4000 | Loss: 5.4248 | Acc: 0.1646 | LR: 0.000932
Epoch  680/4000 | Loss: 5.4293 | Acc: 0.1641 | LR: 0.000930
Epoch  690/4000 | Loss: 5.4130 | Acc: 0.1671 | LR: 0.000928
Epoch  700/4000 | Loss: 5.4116 | Acc: 0.1672 | LR: 0.000926

Evaluation at epoch 700


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 314.13it/s]



Generated Sudoku:
[[4 5 7 4 1 9 6 5 8]
 [6 7 3 1 4 8 2 9 7]
 [4 5 2 5 5 6 9 1 6]
 [6 1 3 6 6 5 1 9 8]
 [4 4 4 2 3 8 9 8 9]
 [3 5 2 7 8 4 3 2 7]
 [6 6 7 7 2 5 3 8 1]
 [8 9 4 7 3 3 2 9 2]
 [9 8 2 1 1 3 7 6 5]]

Is valid: False

Epoch  710/4000 | Loss: 5.4185 | Acc: 0.1676 | LR: 0.000924
Epoch  720/4000 | Loss: 5.4100 | Acc: 0.1694 | LR: 0.000922
Epoch  730/4000 | Loss: 5.3988 | Acc: 0.1716 | LR: 0.000920
Epoch  740/4000 | Loss: 5.3887 | Acc: 0.1720 | LR: 0.000918
Epoch  750/4000 | Loss: 5.3824 | Acc: 0.1743 | LR: 0.000916

Evaluation at epoch 750


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 335.43it/s]



Generated Sudoku:
[[8 5 4 1 4 4 3 9 4]
 [1 4 1 6 5 7 1 5 9]
 [3 4 1 5 6 2 9 6 8]
 [9 3 9 7 2 7 6 4 7]
 [1 5 2 3 6 8 2 9 7]
 [9 8 5 8 5 3 4 2 2]
 [3 3 5 7 7 8 4 7 1]
 [2 1 6 5 8 2 2 6 8]
 [7 6 1 6 9 3 8 9 9]]

Is valid: False

Epoch  760/4000 | Loss: 5.3986 | Acc: 0.1725 | LR: 0.000914
Epoch  770/4000 | Loss: 5.3785 | Acc: 0.1745 | LR: 0.000911
Epoch  780/4000 | Loss: 5.3683 | Acc: 0.1759 | LR: 0.000909
Epoch  790/4000 | Loss: 5.3612 | Acc: 0.1780 | LR: 0.000907
Epoch  800/4000 | Loss: 5.3448 | Acc: 0.1788 | LR: 0.000905

Evaluation at epoch 800


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 318.46it/s]



Generated Sudoku:
[[3 8 1 5 2 6 1 4 6]
 [6 2 4 5 6 3 8 7 2]
 [6 8 5 9 6 4 4 8 5]
 [2 6 7 8 2 9 5 2 2]
 [9 9 5 1 9 1 7 3 5]
 [4 6 3 4 3 7 9 5 7]
 [1 9 3 5 1 7 2 1 7]
 [9 7 3 4 5 4 2 8 6]
 [3 9 8 8 2 1 8 6 1]]

Is valid: False

Epoch  810/4000 | Loss: 5.3633 | Acc: 0.1778 | LR: 0.000902
Epoch  820/4000 | Loss: 5.3515 | Acc: 0.1790 | LR: 0.000900
Epoch  830/4000 | Loss: 5.3582 | Acc: 0.1796 | LR: 0.000897
Epoch  840/4000 | Loss: 5.3494 | Acc: 0.1806 | LR: 0.000895
Epoch  850/4000 | Loss: 5.3457 | Acc: 0.1806 | LR: 0.000893

Evaluation at epoch 850


Sampling: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 81/81 [00:00<00:00, 328.65it/s]



Generated Sudoku:
[[2 9 3 8 7 5 8 7 5]
 [2 4 7 2 3 2 7 5 2]
 [7 3 6 1 4 2 6 9 2]
 [9 6 7 1 6 9 5 4 4]
 [9 2 5 7 2 4 6 4 7]
 [1 1 6 1 4 9 6 5 8]
 [3 2 6 3 8 7 5 8 3]
 [3 4 8 9 8 8 9 1 5]
 [3 4 9 1 3 5 8 1 1]]

Is valid: False



# ============================================================================
# RESUMING TRAINING
# ============================================================================
# To resume training from a checkpoint:
# 
# 1. Set RESUME_FROM_CHECKPOINT to the path of your checkpoint:
#    RESUME_FROM_CHECKPOINT = "./checkpoints/checkpoint_latest.pt"
#    or
#    RESUME_FROM_CHECKPOINT = "./checkpoints/checkpoint_epoch_500.pt"
# 
# 2. Rerun cells 7 & 8 (model initialization and training)
# 
# The training will automatically:
# - Load the model weights from the checkpoint
# - Load the optimizer state
# - Resume from the saved epoch
# - Continue logging to the same wandb run (if using wandb)
# 
# Your checkpoints are saved in: ./checkpoints/
# - checkpoint_latest.pt: Always contains the most recent checkpoint
# - checkpoint_epoch_N.pt: Checkpoints saved at specific epochs
# 
# Example workflow after SSH disconnection:
# 1. Reconnect to SSH
# 2. Open this notebook
# 3. Run cells 1-2 (config and imports)
# 4. Set: RESUME_FROM_CHECKPOINT = "./checkpoints/checkpoint_latest.pt"
# 5. Run cells 4-7 (load embeddings, dataset, initialize model)
# 6. Run cell 8 (training will resume from checkpoint)
