In [1]:
from sudoku_mrv import generate_board, verify_board

In [6]:
# Sample one board from sudoku_mrv with completeness=20 as a quick sanity check
# (in the run recorded below, 16 of the 81 cells came out filled).
a = generate_board(outer_grid_size=9, completeness=20)

In [7]:
# Display the raw board as nested lists; 0 marks an empty cell.
a

[[6, 4, 3, 7, 2, 9, 8, 1, 5],
 [9, 2, 5, 0, 0, 0, 4, 0, 0],
 [1, 7, 8, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [4]:
# Sanity check: the partially-filled board above should be accepted.
# NOTE(review): presumably verify_board checks that no filled cell breaks a Sudoku
# row/column/box constraint (zeros ignored) — confirm in sudoku_mrv.
verify_board(a)

True

In [6]:
import io
import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from matplotlib.colors import ListedColormap
from PIL import Image


def visualize_sudoku(board, title=None, cmap=None, show_values=True, figsize=(6, 6), outer_grid_size=9):
    """
    Render a Sudoku board as an image, colouring each cell by its value.

    Args:
        board: Square array-like (numpy array, list of lists, ...) with side
            ``outer_grid_size``; 0 marks an empty cell.
        title: Optional title for the plot.
        cmap: Optional custom colormap, honoured for every grid size. When None,
            9x9 boards get a fixed pastel palette and other sizes get a
            red-to-blue ramp; in both default palettes 0 (empty) is white.
        show_values: Whether to write the value inside each non-empty cell.
        figsize: Size of the figure (width, height) in inches.
        outer_grid_size: Side length of the board. Thick box borders are drawn
            every ``isqrt(outer_grid_size)`` cells.

    Returns:
        A PIL Image of the visualization. The matplotlib figure used to draw it
        is always closed (even on error), and no other open figure is touched.
    """
    board = np.asarray(board)

    if cmap is None:
        if outer_grid_size == 9:
            # One colour per value 0-9, where 0 is empty.
            colors = [
                '#FFFFFF',  # 0: White (empty)
                '#FFB3BA',  # 1: Light pink
                '#FFDFBA',  # 2: Light orange
                '#FFFFBA',  # 3: Light yellow
                '#BAFFC9',  # 4: Light green
                '#BAE1FF',  # 5: Light blue
                '#D0BAFF',  # 6: Light purple
                '#FFB3F6',  # 7: Light magenta
                '#C4C4C4',  # 8: Light gray
                '#FFD700',  # 9: Gold
            ]
        else:
            # White for 0 (empty), then a linear red -> blue ramp for 1..outer_grid_size.
            colors = ['#FFFFFF']
            for i in range(outer_grid_size):
                r = int(255 * (outer_grid_size - i) / outer_grid_size)
                b = int(255 * i / outer_grid_size)
                colors.append(f'#{r:02x}00{b:02x}')
        cmap = ListedColormap(colors)

    # Own exactly one figure and always release it. There is no need to clear or
    # close any other figure: subplots() returns a brand-new, unshared figure.
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Fixed vmin/vmax keep the value -> colour mapping independent of which
        # values happen to be present (with the default palettes, value k -> colour k).
        ax.imshow(board, cmap=cmap, vmin=0, vmax=outer_grid_size)

        # Thin lines between cells, thick lines around each box. max(1, ...)
        # avoids a modulo-by-zero for degenerate sizes.
        inner_grid_size = max(1, math.isqrt(outer_grid_size))
        for k in range(outer_grid_size + 1):
            lw = 2 if k % inner_grid_size == 0 else 0.5
            ax.axhline(k - 0.5, color='black', linewidth=lw)
            ax.axvline(k - 0.5, color='black', linewidth=lw)

        if show_values:
            for i in range(outer_grid_size):
                for j in range(outer_grid_size):
                    if board[i, j] != 0:
                        # imshow places row i at y=i and column j at x=j.
                        ax.text(j, i, str(board[i, j]), ha='center', va='center',
                                fontsize=12, fontweight='bold')

        ax.set_xticks([])
        ax.set_yticks([])

        if title:
            ax.set_title(title)

        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        plt.close(fig)


def fig_to_pil(fig):
    """
    Convert a matplotlib figure to a PIL Image.

    The figure is rendered to an in-memory PNG, cropped to its tight bounding box,
    and then decoded eagerly. This means any decoding problem surfaces here rather
    than at first use, and the returned image no longer needs to read from the buffer.

    Args:
        fig: The matplotlib figure to render. It is NOT closed here; the caller
            owns its lifetime.

    Returns:
        A fully loaded ``PIL.Image.Image``.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    # Image.open only reads the header; force the pixel data to be decoded now.
    img.load()
    return img

In [10]:
# trainer

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import random

# Custom dataset for generating Sudoku boards
class SudokuDataset(Dataset):
    """
    Map-style dataset that generates a Sudoku board on every access.

    Nothing is stored. ``__getitem__`` ignores ``idx`` and calls ``generate_board``
    each time, so ``num_samples`` only sets the epoch length reported by ``len()``.
    (Whether repeated reads return different boards depends on ``generate_board``
    being randomized — presumably it is; confirm in sudoku_mrv.)

    Args:
        num_samples: Number of items reported by ``len()`` per epoch.
        board_size: Side length of each board, passed as ``outer_grid_size``.
        completeness: Passed through to ``generate_board``. The default of 100
            reproduces the previously hard-coded value.
    """

    def __init__(self, num_samples=10000, board_size=9, completeness=100):
        self.num_samples = num_samples
        self.board_size = board_size
        self.completeness = completeness

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        board = generate_board(completeness=self.completeness, outer_grid_size=self.board_size)
        # generate_board yields nested lists of Python ints; make the integer dtype explicit.
        return torch.tensor(board, dtype=torch.long)

# Training function

def _pick_default_device():
    """Return the best available torch device name: 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def _warmup_linear_decay(warmup_steps, total_steps):
    """Build a LambdaLR multiplier: linear 0 -> 1 over warmup_steps, then linear 1 -> 0 at total_steps."""
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup phase
            return float(current_step) / float(max(1, warmup_steps))
        # Linear decay phase, clamped at 0 once total_steps is passed
        return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))
    return lr_lambda


def _evaluate_samples(model, step, num_samples):
    """Sample boards from model, report how many pass verify_board, and save them side by side as a PNG."""
    model.eval()
    with torch.no_grad():
        generated_boards = model.generate(batch_size=num_samples)
    # One nested Python list per sampled board (equivalent to chunking and squeezing each row).
    sample_boards = generated_boards.cpu().tolist()
    grid_size = model.outer_grid_size

    results = [verify_board(board, outer_grid_size=grid_size) for board in sample_boards]
    # tqdm.write keeps the message from corrupting the active progress bar.
    tqdm.write(f"step {step}: {sum(results)}/{len(results)} valid boards {results}")

    board_images = [visualize_sudoku(board, outer_grid_size=grid_size) for board in sample_boards]
    if not board_images:
        return results
    try:
        # Lay images out left to right using their actual sizes, so they neither
        # overlap nor get clipped if their tight bounding boxes differ.
        canvas_width = sum(img.width for img in board_images)
        canvas_height = max(img.height for img in board_images)
        canvas = Image.new("RGB", (canvas_width, canvas_height), 'white')
        x_offset = 0
        for img in board_images:
            canvas.paste(img, (x_offset, 0))
            x_offset += img.width
        canvas.save(f"generated_boards_{step}.png")
        canvas.close()
    finally:
        for img in board_images:
            img.close()
    return results


def train_diffusion_model(
    model,
    outer_grid_size=9,
    num_epochs=10,
    batch_size=32,
    lr=1e-4,
    device=_pick_default_device(),
    eval_every_n_step=100,
    warmup_steps=200,
    compiled=False,
    num_eval_samples=10,
):
    """
    Train a Sudoku diffusion model on boards generated on the fly.

    The model is optimized with AdamW on ``model.forward_loss(boards)`` under a
    linear-warmup / linear-decay learning-rate schedule. Every
    ``eval_every_n_step`` optimizer steps, boards are sampled from the model, the
    number that pass ``verify_board`` is printed, and the samples are saved as
    ``generated_boards_<step>.png`` in the current working directory.

    Args:
        model: Module exposing ``forward_loss(boards) -> (predictions, loss)``,
            ``generate(batch_size=...)`` returning a tensor of boards, and an
            ``outer_grid_size`` attribute.
        outer_grid_size: Side length of the training boards.
        num_epochs: Number of passes over the synthetic dataset.
        batch_size: Boards per optimizer step.
        lr: Peak learning rate, reached at the end of warmup.
        device: Torch device to train on. Defaults to CUDA if available, else
            Apple MPS if available, else CPU.
        eval_every_n_step: Sampling interval in optimizer steps; 0/None disables sampling.
        warmup_steps: Number of linear-warmup steps.
        compiled: If True, wrap the model with ``torch.compile``.
            NOTE(review): presumably only ``forward`` gets compiled;
            ``forward_loss``/``generate`` are reached through the wrapper — confirm.
        num_eval_samples: Number of boards sampled at each evaluation.

    Returns:
        The trained model (the ``torch.compile`` wrapper when ``compiled`` is True).
    """
    dataset = SudokuDataset(board_size=outer_grid_size)
    # Pinned host memory only speeds up host -> CUDA copies; elsewhere it just warns.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=str(device).startswith("cuda"),
    )

    # Move the model first so the optimizer is built over the on-device parameters.
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.997), weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _warmup_linear_decay(warmup_steps, total_steps))

    if compiled:
        model = torch.compile(model)

    global_step = 0
    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}")
        for boards in progress_bar:
            boards = boards.to(device)
            optimizer.zero_grad(set_to_none=True)
            _, loss = model.forward_loss(boards)
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()
            epoch_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value, "lr": optimizer.param_groups[0]['lr']})

            global_step += 1
            if eval_every_n_step and global_step % eval_every_n_step == 0:
                _evaluate_samples(model, global_step, num_eval_samples)
                model.train()  # sampling switched the model to eval mode

        mean_loss = epoch_loss / max(1, len(dataloader))
        tqdm.write(f"Epoch {epoch}/{num_epochs}: mean loss {mean_loss:.4f}")

    return model

# Example usage
# NOTE(review): `DiscreteDiffusion` is not defined or imported in any visible cell
# (only generate_board/verify_board are imported from sudoku_mrv), so this line
# raises NameError on Restart & Run All — import it explicitly from wherever the
# model lives. train_diffusion_model requires it to provide forward_loss(boards)
# -> (preds, loss), generate(batch_size=...), and an outer_grid_size attribute.
# NOTE(review): consider moving these lines into their own cell, so the definitions
# above can be re-run without starting a full training run (10 epochs x 313 steps
# in the recorded output).
outer_grid_size = 9
model = DiscreteDiffusion(outer_grid_size=outer_grid_size)
trained_model = train_diffusion_model(model, outer_grid_size=outer_grid_size)

Epoch 1/10:  32%|████████████████▊                                    | 99/313 [00:48<01:28,  2.41it/s, loss=2.2, lr=5e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  1%|▋                                                                                    | 1/128 [00:02<05:25,  2.56s/it][A
  2%|█▉                                                                                   | 3/128 [00:02<01:28,  1.41it/s][A
  6%|█████▎                                                                               | 8/128 [00:02<00:25,  4.76it/s][A
 10%|████████▌                                                                           | 13/128 [00:02<00:13,  8.74it/s][A
 14%|███████████▊                                                                        | 18/128 [00:03<00:08, 13.13it/s][A
 18%|███████████████                                                                     | 23/128 [00:03<00:05, 17.83it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 1/10:  64%|███████████████████████████████▏                 | 199/313 [01:37<00:47,  2.38it/s, loss=2.19, lr=0.0001]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 49.75it/s][A
  9%|███████▏                                                                            | 11/128 [00:00<00:02, 47.80it/s][A
 12%|██████████▌                                                                         | 16/128 [00:00<00:02, 47.08it/s][A
 16%|█████████████▊                                                                      | 21/128 [00:00<00:02, 46.41it/s][A
 20%|█████████████████                                                                   | 26/128 [00:00<00:02, 46.09it/s][A
 24%|████████████████████▎                                                               | 31/128 [00:00<00:02, 45.87it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 1/10:  96%|█████████████████████████████████████████████▊  | 299/313 [02:23<00:05,  2.40it/s, loss=2.18, lr=9.66e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.80it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.39it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.56it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.02it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.88it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.86it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 1/10: 100%|████████████████████████████████████████████████| 313/313 [02:33<00:00,  2.04it/s, loss=2.19, lr=9.61e-5]
Epoch 2/10:  27%|█████████████▍                                   | 86/313 [00:36<01:37,  2.33it/s, loss=2.15, lr=9.32e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.23it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.65it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.70it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.16it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.82it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 2/10:  59%|████████████████████████████▌                   | 186/313 [01:22<00:52,  2.40it/s, loss=1.94, lr=8.98e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.17it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 46.96it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.40it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 45.90it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.53it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.51it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 2/10:  91%|███████████████████████████████████████████▊    | 286/313 [02:07<00:11,  2.38it/s, loss=1.63, lr=8.63e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.63it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.68it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.84it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.61it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 46.25it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.92it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 2/10: 100%|████████████████████████████████████████████████| 313/313 [02:22<00:00,  2.20it/s, loss=1.82, lr=8.55e-5]
Epoch 3/10:  23%|███████████▍                                     | 73/313 [00:31<01:38,  2.42it/s, loss=1.76, lr=8.29e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.87it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.56it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.57it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.14it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.82it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 3/10:  55%|██████████████████████████▌                     | 173/313 [01:16<00:57,  2.45it/s, loss=1.46, lr=7.95e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.81it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.81it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.91it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.43it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.62it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.10it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 3/10:  87%|█████████████████████████████████████████▊      | 273/313 [02:02<00:16,  2.38it/s, loss=1.62, lr=7.61e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 51.11it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 47.20it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 47.31it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 46.87it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 46.60it/s][A
 25%|█████████████████████                                                               | 32/128 [00:00<00:02, 46.48it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 3/10: 100%|████████████████████████████████████████████████| 313/313 [02:23<00:00,  2.18it/s, loss=1.78, lr=7.48e-5]
Epoch 4/10:  19%|█████████▍                                       | 60/313 [00:26<01:45,  2.39it/s, loss=1.55, lr=7.27e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.90it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.33it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.61it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.23it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 46.05it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 4/10:  51%|████████████████████████▌                       | 160/313 [01:12<01:04,  2.38it/s, loss=1.61, lr=6.93e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 50.75it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 47.72it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 46.39it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 45.88it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 45.59it/s][A
 25%|█████████████████████                                                               | 32/128 [00:00<00:02, 45.74it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 4/10:  83%|███████████████████████████████████████▊        | 260/313 [01:58<00:22,  2.36it/s, loss=1.69, lr=6.59e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 46.00it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 44.96it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 44.49it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 44.80it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.00it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 44.41it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 4/10: 100%|████████████████████████████████████████████████| 313/313 [02:25<00:00,  2.16it/s, loss=1.56, lr=6.41e-5]
Epoch 5/10:  15%|███████▎                                         | 47/313 [00:20<01:52,  2.36it/s, loss=1.55, lr=6.25e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 50.27it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 45.84it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 44.88it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 44.14it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 43.99it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 5/10:  47%|███████████████████████▍                          | 147/313 [01:06<01:11,  2.32it/s, loss=1.7, lr=5.9e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.95it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 45.39it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 44.89it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 44.66it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 44.45it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 44.76it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 5/10:  79%|█████████████████████████████████████▉          | 247/313 [01:53<00:27,  2.36it/s, loss=1.54, lr=5.56e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 50.67it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 47.45it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 45.72it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 45.56it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 45.47it/s][A
 25%|█████████████████████                                                               | 32/128 [00:00<00:02, 45.34it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 5/10: 100%|████████████████████████████████████████████████| 313/313 [02:26<00:00,  2.14it/s, loss=1.31, lr=5.34e-5]
Epoch 6/10:  11%|█████▎                                           | 34/313 [00:14<01:57,  2.37it/s, loss=1.48, lr=5.22e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.40it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 46.92it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.73it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.41it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.99it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 6/10:  43%|████████████████████▌                           | 134/313 [01:01<01:15,  2.39it/s, loss=1.61, lr=4.88e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.06it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.18it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.24it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.13it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.63it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.45it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 6/10:  75%|███████████████████████████████████▉            | 234/313 [01:47<00:33,  2.38it/s, loss=1.64, lr=4.54e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.06it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.33it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 45.79it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 45.65it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.63it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.53it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 6/10: 100%|████████████████████████████████████████████████| 313/313 [02:23<00:00,  2.18it/s, loss=1.69, lr=4.27e-5]
Epoch 7/10:   7%|███▎                                              | 21/313 [00:09<02:01,  2.41it/s, loss=1.59, lr=4.2e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.78it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.60it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.42it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.09it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.17it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 7/10:  39%|██████████████████▌                             | 121/313 [00:55<01:22,  2.34it/s, loss=1.66, lr=3.86e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.68it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 45.31it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 43.99it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 43.01it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 42.25it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 42.08it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 7/10:  71%|█████████████████████████████████▉              | 221/313 [01:41<00:39,  2.34it/s, loss=1.69, lr=3.52e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.31it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 46.58it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 45.53it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 45.66it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.51it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.51it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 7/10: 100%|█████████████████████████████████████████████████| 313/313 [02:23<00:00,  2.17it/s, loss=1.52, lr=3.2e-5]
Epoch 8/10:   3%|█▎                                                | 8/313 [00:03<02:07,  2.39it/s, loss=1.67, lr=3.17e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 49.06it/s][A
  9%|███████▏                                                                            | 11/128 [00:00<00:02, 47.54it/s][A
 12%|██████████▌                                                                         | 16/128 [00:00<00:02, 47.01it/s][A
 16%|█████████████▊                                                                      | 21/128 [00:00<00:02, 46.67it/s][A
 20%|█████████████████                                                                   | 26/128 [00:00<00:02, 46.34it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 8/10:  35%|████████████████▉                                | 108/313 [00:49<01:25,  2.39it/s, loss=1.5, lr=2.83e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.79it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.24it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.28it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 45.68it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.39it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.12it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 8/10:  66%|███████████████████████████████▉                | 208/313 [01:35<00:43,  2.41it/s, loss=1.54, lr=2.49e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.94it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.73it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.53it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.23it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.99it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.46it/s

[False, False, False, False, False, False, False, False, False, False]


Epoch 8/10:  98%|███████████████████████████████████████████████▏| 308/313 [02:21<00:02,  2.40it/s, loss=1.49, lr=2.15e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 49.51it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 47.46it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 46.54it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 46.03it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.88it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 45.70it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 8/10: 100%|████████████████████████████████████████████████| 313/313 [02:26<00:00,  2.13it/s, loss=1.44, lr=2.14e-5]
Epoch 9/10:  30%|███████████████▏                                  | 95/313 [00:40<01:30,  2.40it/s, loss=1.6, lr=1.81e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 48.89it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 46.62it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 45.63it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 45.21it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 45.18it/s][

[False, False, False, False, False, False, False, False, False, False]


Epoch 9/10:  62%|█████████████████████████████▉                  | 195/313 [01:27<00:54,  2.18it/s, loss=1.42, lr=1.47e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  4%|███▎                                                                                 | 5/128 [00:00<00:02, 48.34it/s][A
  8%|██████▌                                                                             | 10/128 [00:00<00:02, 45.32it/s][A
 12%|█████████▊                                                                          | 15/128 [00:00<00:02, 43.83it/s][A
 16%|█████████████▏                                                                      | 20/128 [00:00<00:02, 44.45it/s][A
 20%|████████████████▍                                                                   | 25/128 [00:00<00:02, 43.20it/s][A
 23%|███████████████████▋                                                                | 30/128 [00:00<00:02, 43.16it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 9/10:  94%|█████████████████████████████████████████████▏  | 295/313 [02:13<00:07,  2.39it/s, loss=1.47, lr=1.13e-5]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 51.07it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 47.08it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 45.92it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 45.48it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 45.29it/s][A
 25%|█████████████████████                                                               | 32/128 [00:00<00:02, 45.33it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 9/10: 100%|████████████████████████████████████████████████| 313/313 [02:24<00:00,  2.17it/s, loss=1.58, lr=1.07e-5]
Epoch 10/10:  26%|████████████▌                                   | 82/313 [00:35<01:39,  2.33it/s, loss=1.65, lr=7.85e-6]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 49.69it/s][A
  9%|███████▏                                                                            | 11/128 [00:00<00:02, 47.58it/s][A
 12%|██████████▌                                                                         | 16/128 [00:00<00:02, 46.17it/s][A
 16%|█████████████▊                                                                      | 21/128 [00:00<00:02, 43.52it/s][A
 20%|█████████████████                                                                   | 26/128 [00:00<00:02, 43.99it/s][

[True, True, True, True, True, True, True, True, True, True]


Epoch 10/10:  58%|███████████████████████████▎                   | 182/313 [01:22<00:55,  2.37it/s, loss=1.75, lr=4.44e-6]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 49.10it/s][A
  9%|███████▏                                                                            | 11/128 [00:00<00:02, 46.74it/s][A
 12%|██████████▌                                                                         | 16/128 [00:00<00:02, 45.66it/s][A
 16%|█████████████▊                                                                      | 21/128 [00:00<00:02, 45.26it/s][A
 20%|█████████████████                                                                   | 26/128 [00:00<00:02, 45.15it/s][A
 24%|████████████████████▎                                                               | 31/128 [00:00<00:02, 44.84it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 10/10:  90%|██████████████████████████████████████████▎    | 282/313 [02:09<00:13,  2.37it/s, loss=1.76, lr=1.02e-6]
  0%|                                                                                             | 0/128 [00:00<?, ?it/s][A
  5%|███▉                                                                                 | 6/128 [00:00<00:02, 50.24it/s][A
  9%|███████▉                                                                            | 12/128 [00:00<00:02, 47.41it/s][A
 13%|███████████▏                                                                        | 17/128 [00:00<00:02, 45.94it/s][A
 17%|██████████████▍                                                                     | 22/128 [00:00<00:02, 45.47it/s][A
 21%|█████████████████▋                                                                  | 27/128 [00:00<00:02, 44.73it/s][A
 25%|█████████████████████                                                               | 32/128 [00:00<00:02, 44.69it/s

[True, True, True, True, True, True, True, True, True, True]


Epoch 10/10: 100%|█████████████████████████████████████████████████████| 313/313 [02:26<00:00,  2.14it/s, loss=1.87, lr=0]


<Figure size 600x600 with 0 Axes>