In [4]:
from sudoku_mrv import generate_board, verify_board

In [5]:
# Ask the sudoku_mrv generator for a board with completeness=100; the grid
# displayed below has no empty (0) cells.
a = generate_board(completeness=100)

In [6]:
# Display the generated board as the cell's output.
a

[[8, 4, 3, 9, 1, 6, 7, 2, 5],
 [6, 9, 2, 5, 7, 8, 3, 4, 1],
 [7, 1, 5, 4, 2, 3, 6, 8, 9],
 [2, 3, 4, 6, 5, 9, 8, 1, 7],
 [9, 7, 6, 1, 8, 4, 5, 3, 2],
 [1, 5, 8, 2, 3, 7, 4, 9, 6],
 [3, 2, 7, 8, 9, 5, 1, 6, 4],
 [4, 8, 9, 7, 6, 1, 2, 5, 3],
 [5, 6, 1, 3, 4, 2, 9, 7, 8]]

In [7]:
# Check the generated board with the project's validator (prints True below).
verify_board(a)

True

In [8]:
import torch 
import math 
# sampling helpers

def log(t, eps = 1e-20):
    """Natural log of `t` after clamping it from below at `eps`.

    The clamp keeps zero (or negative) entries from producing -inf / NaN.
    """
    floored = torch.clamp(t, min = eps)
    return floored.log()

def gumbel_noise(t):
    """Draw Gumbel(0, 1) noise with the same shape, dtype and device as `t`.

    Uses the inverse-CDF form -log(-log(u)) with u ~ Uniform[0, 1); the
    clamped `log` helper keeps u == 0 from producing infinities.
    """
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along `dim` from logits `t` with the Gumbel-max trick.

    The logits are divided by `temperature` (floored at 1e-10 so that a
    temperature of zero does not divide by zero), perturbed with Gumbel
    noise, and the argmax along `dim` is returned.
    """
    safe_temperature = max(temperature, 1e-10)
    perturbed = (t / safe_temperature) + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def top_k(logits, thres = 0.9):
    """Keep only the largest logits along the last dimension.

    Args:
        logits: tensor of any rank >= 1 whose last dimension indexes classes.
        thres: fraction of the classes to discard; the top
            ceil((1 - thres) * num_classes) entries are kept (at least one).

    Returns:
        A tensor shaped like `logits` in which the kept entries retain their
        values and every other entry is -inf.
    """
    num_classes = logits.shape[-1]
    # Always keep at least one candidate: with thres == 1.0 the raw formula
    # gives k == 0, which would turn every logit into -inf.
    k = max(1, math.ceil((1 - thres) * num_classes))
    val, ind = logits.topk(k, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    # Scatter along the same (last) dimension topk used; the previous
    # hard-coded dim of 2 only worked for 3-D inputs.
    filtered.scatter_(-1, ind, val)
    return filtered

# noise schedules

def cosine_schedule(t):
    """Cosine schedule: maps t to cos(t * pi / 2), so t = 0 -> 1 and t = 1 -> ~0.

    `t` is a tensor; the result has the same shape.
    """
    angle = t * math.pi * 0.5
    return torch.cos(angle)


In [9]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from PIL import Image
import io
from matplotlib import gridspec

def visualize_sudoku(board, title=None, cmap=None, show_values=True, figsize=(6, 6)):
    """
    Render a Sudoku board as an image, one background color per digit.

    Args:
        board: A 9x9 array-like (nested lists, numpy array, ...) of integers
            0-9, where 0 is drawn as an empty cell
        title: Optional title for the plot
        cmap: Optional custom colormap (default is a 10-color pastel colormap)
        show_values: Whether to write the digit inside every non-zero cell
        figsize: Size of the figure (width, height) in inches

    Returns:
        A PIL Image of the visualization
    """
    # Accept any array-like, not only lists.
    board = np.asarray(board)

    # Create a default colormap if none provided
    if cmap is None:
        # 10 colors: index 0 is the empty cell, 1-9 are the digits
        colors = ['#FFFFFF',  # 0: White (empty)
                  '#FFB3BA',  # 1: Light pink
                  '#FFDFBA',  # 2: Light orange
                  '#FFFFBA',  # 3: Light yellow
                  '#BAFFC9',  # 4: Light green
                  '#BAE1FF',  # 5: Light blue
                  '#D0BAFF',  # 6: Light purple
                  '#FFB3F6',  # 7: Light magenta
                  '#C4C4C4',  # 8: Light gray
                  '#FFD700']  # 9: Gold
        cmap = ListedColormap(colors)

    # Exactly one figure is created per call. The previous version also
    # opened an unused plt.figure() that was never closed (it showed up as an
    # empty "<Figure ... with 0 Axes>" output) and called plt.close('all'),
    # which destroyed unrelated figures owned by the caller.
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Plot the board; vmin/vmax pin digit d to color index d
        ax.imshow(board, cmap=cmap, vmin=0, vmax=9)

        # Grid lines: thick on the 3x3 box boundaries, thin elsewhere
        for i in range(10):
            lw = 2 if i % 3 == 0 else 0.5
            ax.axhline(i - 0.5, color='black', linewidth=lw)
            ax.axvline(i - 0.5, color='black', linewidth=lw)

        # Add values to cells if requested
        if show_values:
            for i in range(9):
                for j in range(9):
                    if board[i, j] != 0:
                        ax.text(j, i, str(board[i, j]), ha='center', va='center',
                                fontsize=12, fontweight='bold')

        # Remove ticks
        ax.set_xticks([])
        ax.set_yticks([])

        # Add title if provided
        if title:
            ax.set_title(title)

        fig.tight_layout()

        # Convert to PIL image
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering raised, so repeated
        # calls do not accumulate open figures or display them inline.
        plt.close(fig)

def fig_to_pil(fig):
    """Render a matplotlib figure to an in-memory PNG and open it as a PIL Image."""
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, bbox_inches='tight', format='png')
    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    return Image.open(png_buffer)

In [10]:
from einops import rearrange
import torch
import torch.nn as nn
from tqdm import tqdm 

class FeedForward(nn.Module):
    """Two-layer GELU MLP with a built-in skip connection.

    The hidden width is dim * mult * 2. Note that `forward` already adds its
    input to the MLP output.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = dim * mult * 2
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Return net(x) + x; the output has the same shape as `x`."""
        projected = self.net(x)
        return projected + x

class Attention(nn.Module):
    """Multi-head self-attention over a (batch, tokens, dim) input.

    A single bias-free linear layer produces queries, keys and values, which
    are split into `heads` heads, passed through PyTorch's
    scaled_dot_product_attention (no mask), merged back and projected.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended features, same shape as `x`."""
        fused_qkv = self.to_qkv(x)
        # Split the fused projection into q, k, v and move heads to dim 1.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in fused_qkv.chunk(3, dim=-1)
        )
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # Fold the heads back into the feature dimension.
        merged = rearrange(attended, "b h n d -> b n (h d)", h=self.heads)
        return self.to_out(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual add.

    The model width is head_dim * heads.
    """

    def __init__(self, head_dim, heads=8):
        super().__init__()
        width = head_dim * heads
        self.attn = Attention(width, heads)
        self.ff = FeedForward(width)
        self.attn_norm = nn.LayerNorm(width)
        self.ff_norm = nn.LayerNorm(width)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers; shape is preserved."""
        after_attn = x + self.attn(self.attn_norm(x))
        after_ff = after_attn + self.ff(self.ff_norm(after_attn))
        return after_ff

class Transformer(nn.Module):
    """Token classifier: embeddings + learned positions -> transformer blocks -> per-token logits.

    The position table has 81 entries, so sequences may be at most 81 tokens
    long. `ff_mult` and `dropout` are accepted but not used by this class.
    """

    def __init__(self, head_dim=64, heads=8, num_classes=10, depth=12, ff_mult=4, dropout=0.0):
        super().__init__()
        width = head_dim * heads
        self.embed = nn.Embedding(num_classes, width)
        self.pos_emb = nn.Embedding(81, width)
        blocks = [TransformerBlock(head_dim, heads) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(width)
        self.to_logits = nn.Linear(width, num_classes)

    def forward(self, x):
        """Map integer token ids of shape (batch, length) to logits of shape (batch, length, num_classes)."""
        tokens = self.embed(x)
        # One learned embedding per position, broadcast over the batch.
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        hidden = tokens + self.pos_emb(positions)
        for block in self.layers:
            hidden = block(hidden)
        return self.to_logits(self.norm(hidden))


class DiscreteDiffusion(nn.Module):
    """Masked-token discrete diffusion model over flattened 9x9 boards.

    Token id 0 doubles as the mask token. Training masks a cosine-scheduled
    number of cells and predicts them; generation starts from a fully masked
    board and iteratively re-masks the least confident predictions.

    Args:
        num_classes, head_dim, heads, depth, ff_mult, dropout: forwarded to
            the underlying `Transformer`.
        no_mask_token_prob: fraction of the selected cells that are left
            visible in the input while still being predicted.
        full_mask_token_prob: per-sample probability of masking the whole
            board during training.
    """

    def __init__(self, num_classes=10, head_dim=64, heads=8, depth=12, ff_mult=4, dropout=0.0, no_mask_token_prob=0.0, full_mask_token_prob=0.025):
        super().__init__()
        self.model = Transformer(head_dim, heads, num_classes, depth, ff_mult, dropout)
        self.no_mask_token_prob = no_mask_token_prob
        self.full_mask_token_prob = full_mask_token_prob

    @staticmethod
    def _get_mask_subset_prob(mask, prob, min_mask=0):
        """Randomly pick about `prob` of the True entries in every row of `mask`.

        Returns a boolean tensor shaped like `mask` that is True only at a
        random subset of positions where `mask` is True.
        """
        batch, seq = mask.shape
        not_mask = ~mask
        num_to_select = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)
        # Random keys; positions outside the mask get -1 so they sort first.
        noise = torch.rand((batch, seq), device=mask.device).masked_fill(not_mask, -1)
        ranks = noise.argsort(dim=-1).argsort(dim=-1).float()
        # Shift so that the masked positions are ranked 0, 1, 2, ...
        ranks = ranks - not_mask.sum(dim=-1, keepdim=True)
        return (ranks < num_to_select) & mask

    def forward(self, board_bl, labels=None, ignore_index=-1):
        """
        Run the transformer and optionally compute the cross-entropy loss.

        Args:
            board_bl: integer token ids, shape (batch, length).
            labels: optional target ids, shape (batch, length); positions
                equal to `ignore_index` do not contribute to the loss.
            ignore_index: label value to skip in the loss.

        Returns:
            (logits, loss); loss is the int 0 when `labels` is None.
        """
        preds_bld = self.model(board_bl)
        if labels is not None:
            loss = nn.functional.cross_entropy(
                preds_bld.reshape(-1, preds_bld.shape[-1]),
                labels.flatten(0),
                ignore_index=ignore_index,
            )
        else:
            loss = 0

        return preds_bld, loss

    def forward_loss(self, board_bhw, ignore_index=-1):
        """
        Mask a random subset of cells and compute the loss for predicting them.

        Args:
            board_bhw: integer boards, shape (batch, height, width).
            ignore_index: label value assigned to cells that are not predicted.

        Returns:
            (logits, loss) from `forward` on the masked boards.
        """
        mask_id = 0
        b, h, w = board_bhw.shape
        board_bl = board_bhw.flatten(1)
        _, l = board_bl.shape
        device = board_bl.device

        # Per-sample mask ratio from the cosine schedule, at least one cell.
        rand_time = torch.rand(b, device=device)
        rand_mask_probs = cosine_schedule(rand_time)
        num_token_masked = (l * rand_mask_probs).round().clamp(min = 1)

        # A random permutation per row; cells whose rank is below the target
        # count are the ones to mask.
        batch_randperm = torch.rand((b, l), device = device).argsort(dim = -1)
        mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

        # Only the cells selected here are supervised.
        labels = board_bl.masked_fill(~mask, ignore_index)

        if self.no_mask_token_prob > 0.:
            # Leave some supervised cells visible in the input.
            no_mask_mask = self._get_mask_subset_prob(mask, self.no_mask_token_prob)
            mask &= ~no_mask_mask

        if self.full_mask_token_prob > 0.:
            # Choose rows with a *boolean* Bernoulli draw. Indexing with the
            # 0/1 long tensor (as before) gathered rows 0 and 1 instead of the
            # sampled rows. Labels were fixed above, so the extra cells hidden
            # here are not added to the loss.
            full_rows = torch.bernoulli(
                torch.full((b,), self.full_mask_token_prob, device=device)
            ).bool()
            mask = mask | full_rows[:, None]

        board_bl = board_bl.masked_fill(mask, mask_id)

        # Use the same ignore_index the labels were built with.
        preds_bld, loss = self.forward(board_bl, labels=labels, ignore_index=ignore_index)

        return preds_bld, loss

    @torch.no_grad()
    def generate(self, batch_size=32, timesteps=128, temperature=1.0, topk_filter_thres = 0.9, can_remask_prev_masked = False,):
        """
        Sample boards by iterative unmasking.

        Args:
            batch_size: number of boards to sample.
            timesteps: number of refinement iterations.
            temperature: starting sampling temperature, annealed linearly to 0.
            topk_filter_thres: threshold passed to `top_k`.
            can_remask_prev_masked: if False, cells that were not masked in
                the current step can never be selected for re-masking.

        Returns:
            Long tensor of shape (batch_size, 9, 9).
        """
        if can_remask_prev_masked:
            assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token'

        device = next(self.parameters()).device
        shape = (batch_size, 9, 9)
        board_bhw = torch.full(shape, 0, dtype=torch.long, device=device)
        scores_bhw = torch.zeros(shape, dtype = torch.float32, device = device)
        board_bl = board_bhw.flatten(1)
        scores_bl = scores_bhw.flatten(1)
        seq_len = board_bl.shape[1]

        starting_temperature = temperature
        mask_id = 0

        for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps):
            # Number of cells to (re-)mask this step, at least one.
            rand_mask_prob = cosine_schedule(timestep)
            num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1)

            # Highest score = least confident prediction from the last step.
            masked_indices = scores_bl.topk(num_token_masked, dim = -1).indices

            board_bl = board_bl.scatter(1, masked_indices, mask_id)

            logits, _ = self.forward(board_bl)

            filtered_logits = top_k(logits, topk_filter_thres)

            temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed

            pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            is_mask = board_bl == mask_id

            # Only masked cells take the new prediction.
            board_bl = torch.where(
                is_mask,
                pred_ids,
                board_bl
            )

            probs_without_temperature = logits.softmax(dim = -1)

            # Score = 1 - probability assigned to the sampled id.
            scores_bl = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
            scores_bl = rearrange(scores_bl, '... 1 -> ...')

            if not can_remask_prev_masked:
                # Make cells that were not masked this step rank last for re-masking.
                scores_bl = scores_bl.masked_fill(~is_mask, -1e5)

        board_bhw = board_bl.reshape(board_bhw.shape)

        return board_bhw
        





In [11]:
# trainer

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import random

# Custom dataset for generating Sudoku boards
class SudokuDataset(Dataset):
    """Dataset that generates a fresh board on every access.

    `num_samples` only sets the reported length; the index passed to
    `__getitem__` is ignored. `board_size` is stored but not otherwise used.
    """

    def __init__(self, num_samples=10000, board_size=9):
        self.num_samples = num_samples
        self.board_size = board_size

    def __len__(self):
        """Nominal number of samples per epoch."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a newly generated board (completeness=100) as a tensor."""
        solved_grid = generate_board(completeness=100)
        return torch.tensor(solved_grid)

# Training function
def _sample_and_save_boards(model, step, num_boards=10):
    """Sample boards from `model`, print whether each is valid, and save them side by side.

    The image is written to generated_boards_{step}.png in the working directory.
    """
    with torch.no_grad():
        generated_bhw = model.generate(batch_size=num_boards)

    # One nested 9x9 list per sampled board.
    sampled_boards = generated_bhw.tolist()
    results = [verify_board(board) for board in sampled_boards]
    print(results)

    board_figs = [visualize_sudoku(board) for board in sampled_boards]
    canvas = None
    try:
        canvas_width = board_figs[0].width * len(board_figs)
        canvas_height = board_figs[0].height

        # Paste the board images left to right on one white canvas.
        canvas = Image.new("RGB", (canvas_width, canvas_height), 'white')
        for i, board_fig in enumerate(board_figs):
            canvas.paste(board_fig, (board_fig.width * i, 0))
        canvas.save(f"generated_boards_{step}.png")
    finally:
        # Release the PIL images even if composing or saving failed.
        if canvas is not None:
            canvas.close()
        for board_fig in board_figs:
            board_fig.close()


# Training function
def train_diffusion_model(model, num_epochs=10, batch_size=32, lr=1e-4, device="cuda" if torch.cuda.is_available() else "cpu", eval_every_n_step=100, warmup_steps = 1000):
    """Train `model` on freshly generated boards with Adam and a warmup/decay schedule.

    Args:
        model: object exposing `forward_loss(boards)` and `generate(batch_size=...)`.
        num_epochs: number of passes over the (synthetic) dataset.
        batch_size: boards per optimization step.
        lr: peak learning rate.
        device: device the model and batches are moved to.
        eval_every_n_step: sample and save boards every this many steps.
        warmup_steps: steps of linear warmup before the linear decay to zero.

    Returns:
        The trained model (moved to `device`).
    """
    # Create dataset and dataloader
    dataset = SudokuDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Total number of optimizer steps over the entire training run
    total_steps = len(dataloader) * num_epochs

    # Learning-rate multiplier: linear warmup followed by linear decay
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup phase
            return float(current_step) / float(max(1, warmup_steps))
        # Linear decay phase
        return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Move model to device
    model = model.to(device)

    total_step = 0

    # Training loop
    for epoch in range(num_epochs):
        model.train()

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for boards in progress_bar:
            boards = boards.to(device)
            optimizer.zero_grad(set_to_none=True)
            preds_bld, loss = model.forward_loss(boards)
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Read the loss once per step; each .item() forces a device sync.
            loss_value = loss.item()
            progress_bar.set_postfix({"loss": loss_value, "lr": optimizer.param_groups[0]['lr']})

            total_step += 1
            if total_step % eval_every_n_step == 0:
                model.eval()
                _sample_and_save_boards(model, total_step)
                # Back to training mode for the next batch.
                model.train()

    return model

# Example usage
# Build a DiscreteDiffusion with its default hyperparameters and train it with
# the default training settings; this runs the full training loop and writes
# generated_boards_<step>.png snapshots as a side effect.
model = DiscreteDiffusion()
trained_model = train_diffusion_model(model)


100%|██████████| 128/128 [00:01<00:00, 117.46it/s]13.75it/s, loss=2.22, lr=1e-5]  


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.92it/s] 13.75it/s, loss=2.21, lr=2e-5]   


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.77it/s] 13.66it/s, loss=2.2, lr=3e-5]    


[False, False, False, False, False, False, False, False, False, False]


Epoch 1/10: 100%|██████████| 313/313 [00:35<00:00,  8.90it/s, loss=2.19, lr=3.13e-5]
100%|██████████| 128/128 [00:01<00:00, 125.96it/s]13.78it/s, loss=2.19, lr=4e-5]   
Epoch 2/10:  28%|██▊       | 87/313 [00:11<02:24,  1.57it/s, loss=2.19, lr=4e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 126.15it/s] 13.77it/s, loss=2.18, lr=5e-5]   
Epoch 2/10:  60%|█████▉    | 187/313 [00:21<01:13,  1.73it/s, loss=2.18, lr=5e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 126.11it/s] 13.75it/s, loss=2.18, lr=6e-5]   
Epoch 2/10:  92%|█████████▏| 287/313 [00:32<00:16,  1.61it/s, loss=2.18, lr=6e-5]

[False, False, False, False, False, False, False, False, False, False]


Epoch 2/10: 100%|██████████| 313/313 [00:34<00:00,  8.99it/s, loss=2.17, lr=6.26e-5]
100%|██████████| 128/128 [00:01<00:00, 125.93it/s]13.74it/s, loss=2.15, lr=7e-5]   


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.67it/s] 13.78it/s, loss=1.9, lr=8e-5]    


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.88it/s] 13.79it/s, loss=1.71, lr=9e-5]   


[False, False, False, False, False, False, False, False, False, False]


Epoch 3/10: 100%|██████████| 313/313 [00:34<00:00,  9.04it/s, loss=1.72, lr=9.39e-5]
100%|██████████| 128/128 [00:01<00:00, 126.01it/s]13.77it/s, loss=1.63, lr=0.0001] 
Epoch 4/10:  19%|█▉        | 61/313 [00:08<02:30,  1.68it/s, loss=1.63, lr=0.0001]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.70it/s] 13.75it/s, loss=1.56, lr=9.53e-5]
Epoch 4/10:  51%|█████▏    | 161/313 [00:20<01:41,  1.50it/s, loss=1.56, lr=9.53e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.99it/s] 13.80it/s, loss=1.75, lr=9.06e-5]
Epoch 4/10:  83%|████████▎ | 261/313 [00:30<00:30,  1.72it/s, loss=1.75, lr=9.06e-5]

[True, True, True, True, True, True, True, True, True, True]


Epoch 4/10: 100%|██████████| 313/313 [00:34<00:00,  9.00it/s, loss=1.75, lr=8.82e-5]
100%|██████████| 128/128 [00:01<00:00, 125.87it/s]13.75it/s, loss=1.35, lr=8.59e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.93it/s] 13.77it/s, loss=1.41, lr=8.12e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.80it/s] 13.76it/s, loss=1.5, lr=7.65e-5] 


[False, False, False, False, False, False, False, False, False, False]


Epoch 5/10: 100%|██████████| 313/313 [00:34<00:00,  8.99it/s, loss=1.6, lr=7.35e-5] 
100%|██████████| 128/128 [00:01<00:00, 125.83it/s]13.69it/s, loss=1.66, lr=7.18e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.92it/s] 13.77it/s, loss=1.65, lr=6.71e-5]


[True, True, True, True, True, True, True, True, True, True]


100%|██████████| 128/128 [00:01<00:00, 125.85it/s] 13.72it/s, loss=1.53, lr=6.24e-5]


[False, False, False, False, False, False, False, False, False, False]


Epoch 6/10: 100%|██████████| 313/313 [00:34<00:00,  8.96it/s, loss=1.76, lr=5.88e-5]
100%|██████████| 128/128 [00:01<00:00, 125.91it/s]12.78it/s, loss=1.48, lr=5.77e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.80it/s] 13.76it/s, loss=1.64, lr=5.31e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.88it/s] 13.72it/s, loss=1.69, lr=4.84e-5]


[False, False, False, False, False, False, False, False, False, False]


Epoch 7/10: 100%|██████████| 313/313 [00:34<00:00,  9.07it/s, loss=1.48, lr=4.41e-5]
100%|██████████| 128/128 [00:01<00:00, 125.92it/s]6.71it/s, loss=1.59, lr=4.37e-5]
Epoch 8/10:   3%|▎         | 9/313 [00:05<03:56,  1.28it/s, loss=1.59, lr=4.37e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 126.02it/s] 13.75it/s, loss=1.38, lr=3.9e-5] 
Epoch 8/10:  35%|███▍      | 109/313 [00:16<02:31,  1.35it/s, loss=1.38, lr=3.9e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.64it/s] 13.72it/s, loss=1.42, lr=3.43e-5]
Epoch 8/10:  67%|██████▋   | 209/313 [00:27<01:00,  1.72it/s, loss=1.42, lr=3.43e-5]

[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 125.89it/s] 13.76it/s, loss=1.52, lr=2.96e-5]
Epoch 8/10:  99%|█████████▊| 309/313 [00:38<00:02,  1.72it/s, loss=1.52, lr=2.96e-5]

[False, False, False, False, False, False, False, False, False, False]


Epoch 8/10: 100%|██████████| 313/313 [00:38<00:00,  8.08it/s, loss=1.72, lr=2.94e-5]
100%|██████████| 128/128 [00:01<00:00, 125.90it/s]13.76it/s, loss=1.66, lr=2.49e-5]


[False, False, False, False, False, False, False, False, False, False]


100%|██████████| 128/128 [00:01<00:00, 126.05it/s] 13.70it/s, loss=1.49, lr=2.02e-5]


[True, True, True, True, True, True, True, True, True, True]


100%|██████████| 128/128 [00:01<00:00, 126.03it/s] 13.76it/s, loss=1.38, lr=1.55e-5]


[False, False, False, False, False, False, False, False, False, False]


Epoch 9/10: 100%|██████████| 313/313 [00:34<00:00,  9.12it/s, loss=1.25, lr=1.47e-5]
100%|██████████| 128/128 [00:01<00:00, 125.95it/s] 13.74it/s, loss=1.43, lr=1.08e-5]
Epoch 10/10:  27%|██▋       | 83/313 [00:11<03:02,  1.26it/s, loss=1.43, lr=1.08e-5]

[True, True, True, True, True, True, True, True, True, True]


100%|██████████| 128/128 [00:01<00:00, 126.01it/s], 13.64it/s, loss=1.27, lr=6.1e-6] 
Epoch 10/10:  58%|█████▊    | 183/313 [00:22<01:15,  1.72it/s, loss=1.27, lr=6.1e-6]

[True, True, True, True, True, True, True, True, True, True]


100%|██████████| 128/128 [00:01<00:00, 126.05it/s], 13.72it/s, loss=1.46, lr=1.41e-6]
Epoch 10/10:  90%|█████████ | 283/313 [00:33<00:17,  1.72it/s, loss=1.46, lr=1.41e-6]

[True, True, True, True, True, True, True, True, True, True]


Epoch 10/10: 100%|██████████| 313/313 [00:35<00:00,  8.82it/s, loss=1.68, lr=0]      


<Figure size 432x432 with 0 Axes>