In [70]:
from sudoku_mrv import generate_board, verify_board

In [71]:
# Generate a Sudoku board with sudoku_mrv. Presumably `completeness` is the percentage of
# cells left filled; with 100 the displayed result has every cell filled.
a = generate_board(completeness=100)

In [72]:
# Display the generated board (a list of 9 rows of 9 ints).
a

[[5, 1, 6, 3, 7, 4, 2, 8, 9],
 [3, 4, 7, 9, 8, 2, 5, 6, 1],
 [2, 8, 9, 6, 1, 5, 4, 3, 7],
 [8, 9, 2, 7, 5, 3, 6, 1, 4],
 [4, 3, 1, 8, 2, 6, 9, 7, 5],
 [7, 6, 5, 1, 4, 9, 3, 2, 8],
 [9, 7, 3, 5, 6, 1, 8, 4, 2],
 [1, 5, 4, 2, 3, 8, 7, 9, 6],
 [6, 2, 8, 4, 9, 7, 1, 5, 3]]

In [73]:
# Validate the generated board with sudoku_mrv's checker (presumably row/column/box
# constraints — confirm in sudoku_mrv); it returned True for this board.
verify_board(a)

True

In [None]:
from einops import rearrange
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise MLP branch: ``Linear -> GELU -> Linear``, mapping ``dim`` back to ``dim``.

    Like ``Attention``, this returns only the branch output; the residual connection
    is the caller's job (``TransformerBlock`` adds it around the normalized input).

    Args:
        dim: input and output feature width.
        mult: expansion factor; the hidden width is ``dim * mult * 2``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        # NOTE(review): the extra factor of 2 (an 8x expansion at mult=4) looks like a
        # leftover from a gated (GEGLU-style) variant — confirm it is intended.
        hidden = dim * mult * 2
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        # No ``+ x`` here: TransformerBlock already adds the residual, and adding it in
        # both places made each block compute x + LN(x) + MLP(LN(x)).
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention in which every head has width ``dim``.

    Inputs of shape (batch, seq, dim) are projected to ``heads`` query/key/value heads
    of width ``dim`` each (total inner width ``dim * heads``). The heads are attended
    with ``scaled_dot_product_attention``, concatenated, and projected back to ``dim``.

    Args:
        dim: model width, also used as the per-head width.
        heads: number of attention heads.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        inner_dim = dim * heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # The concatenated heads are ``dim * heads`` wide, so the output projection must
        # accept that width; ``Linear(dim, dim)`` raised a shape error for heads > 1.
        self.to_out = nn.Linear(inner_dim, dim)

    def _split_heads(self, t):
        """Reshape (b, n, heads*d) -> (b, heads, n, d); heads is the outer factor."""
        return t.reshape(t.shape[0], t.shape[1], self.heads, -1).transpose(1, 2)

    def forward(self, x):
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))
        attn_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # (b, heads, n, d) -> (b, n, heads*d), inverse of _split_heads.
        out = attn_out.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1)
        return self.to_out(out)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: ``x + Attn(LN(x))`` followed by ``x + FF(LN(x))``.

    Args:
        dim: model width.
        heads: number of attention heads.
        ff_mult: expansion factor forwarded to ``FeedForward``.
        dropout: dropout probability applied to each residual branch (0.0 = off).
    """

    def __init__(self, dim, heads=8, ff_mult=4, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads)
        self.ff = FeedForward(dim, mult=ff_mult)
        self.attn_norm = nn.LayerNorm(dim)
        self.ff_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.attn_norm(x)))
        x = x + self.dropout(self.ff(self.ff_norm(x)))
        return x


class Transformer(nn.Module):
    """Token-level transformer: embedding -> ``depth`` pre-norm blocks -> LayerNorm -> logits.

    Args:
        head_dim: model width (``Attention`` also uses it as the per-head width).
        heads: number of attention heads per block.
        num_classes: vocabulary size, used for both input tokens and output logits.
        depth: number of ``TransformerBlock`` layers.
        ff_mult: feed-forward expansion factor, forwarded to every block.
        dropout: residual-branch dropout probability, forwarded to every block.
    """

    def __init__(self, head_dim=64, heads=8, num_classes=10, depth=12, ff_mult=4, dropout=0.0):
        super().__init__()
        self.embed = nn.Embedding(num_classes, head_dim)
        # NOTE(review): there is no positional embedding, so the stack is permutation-
        # equivariant and cannot tell board cells apart by position. A Sudoku model
        # almost certainly needs one — confirm and add it.
        self.layers = nn.ModuleList(
            [TransformerBlock(head_dim, heads, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(head_dim)
        self.to_logits = nn.Linear(head_dim, num_classes)

    def forward(self, x):
        """Map integer token ids in ``[0, num_classes)`` to per-token logits.

        Presumably ``x`` is (batch, seq); the output is then (batch, seq, num_classes).
        """
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.to_logits(x)


class DiscreteDiffusion(nn.Module):
    """Run a token-level model over flattened boards and score its per-cell predictions.

    NOTE(review): despite the name, no noising/corruption schedule exists yet; ``forward``
    is a plain per-cell classification loss.

    Args:
        model: an ``nn.Module`` mapping (batch, cells) integer tokens to
            (batch, cells, num_classes) logits. If anything other than an ``nn.Module``
            is passed (e.g. ``None``), a ``Transformer`` is built from the remaining
            hyperparameters instead.
        num_classes, head_dim, heads, depth, ff_mult, dropout: forwarded to
            ``Transformer`` when no module is supplied.
    """

    def __init__(self, model=None, num_classes=10, head_dim=64, heads=8, depth=12, ff_mult=4, dropout=0.0):
        super().__init__()
        # The draft ignored ``model`` entirely. Use a supplied module, and fall back to the
        # hyperparameter-built Transformer for any other value (as before).
        if isinstance(model, nn.Module):
            self.model = model
        else:
            self.model = Transformer(head_dim, heads, num_classes, depth, ff_mult, dropout)

    def forward(self, board_bhw, labels=None):
        """Run the model on a batch of boards and optionally compute the loss.

        Args:
            board_bhw: integer token tensor of shape (batch, height, width).
            labels: optional integer targets, presumably the same shape as ``board_bhw``.

        Returns:
            ``(preds, loss)``. ``preds`` is the model output on the (batch, height*width)
            flattened board: (batch, height*width, num_classes) for the default
            Transformer. ``loss`` is the mean cross-entropy over all cells, or ``0``
            when ``labels`` is None.

        Raises:
            ValueError: if ``board_bhw`` is not 3-D.
        """
        if board_bhw.dim() != 3:
            raise ValueError(
                f"expected a (batch, height, width) board, got shape {tuple(board_bhw.shape)}"
            )
        board_bl = board_bhw.flatten(1)
        preds_blc = self.model(board_bl)
        if labels is None:
            return preds_blc, 0
        # cross_entropy wants the class dim second: logits (B, C, L) against targets (B, L).
        loss = nn.functional.cross_entropy(preds_blc.transpose(1, 2), labels.flatten(1))
        return preds_blc, loss

    def gen_sample(self, shape=(1, 9, 9)):
        """Generate boards of ``shape`` by iterative denoising — not implemented yet.

        Raises:
            NotImplementedError: always; the sampling loop has not been written.
        """
        # TODO: start from torch.zeros(shape, dtype=torch.long). nn.Embedding needs
        # integer tokens, so the draft's float torch.zeros(shape) would be rejected.
        raise NotImplementedError("gen_sample: sampling loop is not implemented yet")



