In [None]:
!pip install torchvision



In [None]:
import numpy as np
import argparse
import os

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def quantisize(image, levels):
    """Map values in [0, 1) onto integer bins 0..levels-1.

    Bin k covers [k / levels, (k + 1) / levels). Values >= (levels - 1) / levels
    land in the top bin, and values below 0 come out as -1.
    """
    thresholds = np.arange(levels) / levels
    bin_index = np.digitize(image, thresholds)
    return bin_index - 1


def str2bool(s):
    """Parse a command-line style boolean.

    Booleans are returned unchanged. Strings are matched case-insensitively against
    yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(s, bool):
        return s
    lowered = s.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')


def nearest_square(num):
    """Return the perfect square whose root is the integer nearest to sqrt(num)."""
    root = round(np.sqrt(num))
    return root ** 2


def save_samples(samples, dirname, filename):
    """Save a batch of images to ``dirname/filename`` as a single grid image.

    Args:
        samples: batch tensor accepted by ``torchvision.utils.save_image``; dimension 0 is
            the number of images.
        dirname: output directory. It is created, including any missing parents, if it
            does not exist.
        filename: image file name inside ``dirname``.

    The grid is square when the number of images is a perfect square; otherwise all
    images are laid out in a single row.
    """
    # makedirs(exist_ok=True) also handles nested paths and an already-existing directory,
    # where os.mkdir would raise.
    os.makedirs(dirname, exist_ok=True)

    count = samples.size(0)
    count_sqrt = int(count ** 0.5)
    nrow = count_sqrt if count_sqrt ** 2 == count else count

    save_image(samples, os.path.join(dirname, filename), nrow=nrow)

def duplicate_image_horizontal(image, copies=6):
    """Tile ``image`` side by side ``copies`` times along its last dimension.

    Args:
        image: tensor; it is concatenated with itself along dim -1 (the width).
        copies: number of copies. The default of 6 makes the result 6 times as wide.

    Returns:
        A new tensor whose last dimension is ``copies`` times the input's.
    """
    # Imported locally: this notebook cell does not import torch at the top, so the
    # function would otherwise raise NameError on a fresh kernel.
    import torch
    return torch.cat([image] * copies, dim=-1)

def get_loaders(dataset_name, batch_size, color_levels, train_root, test_root):
    """Build train/test DataLoaders for a torchvision dataset with quantised pixel values.

    Args:
        dataset_name: one of ``'mnist'``, ``'fashionmnist'``, ``'cifar'``.
        batch_size: batch size for both loaders. Incomplete last batches are dropped.
        color_levels: number of quantisation levels passed to ``quantisize``.
        train_root, test_root: download/cache directories for the torchvision datasets.

    Returns:
        ``(train_loader, test_loader, height, width)``, where height/width describe the
        tensors the transform actually produces. The grayscale datasets are tiled 6x
        horizontally; CIFAR is not.

    Raises:
        AttributeError: if ``dataset_name`` is not supported.
    """
    dataset_mappings = {'mnist': 'MNIST', 'fashionmnist': 'FashionMNIST', 'cifar': 'CIFAR10'}
    # Check explicitly rather than wrapping everything in try/except KeyError, which would
    # also hide KeyErrors raised while constructing the datasets.
    if dataset_name not in dataset_mappings:
        raise AttributeError("Unsupported dataset")

    normalize = transforms.Lambda(lambda image: np.array(image) / 255)

    discretize = transforms.Compose([
        transforms.Lambda(lambda image: quantisize(image, color_levels)),
        transforms.ToTensor()
    ])

    to_rgb = transforms.Compose([
        # Scale to [0, 1] first: quantisize's bin edges are k / color_levels (the CIFAR path
        # already normalises; without this every non-zero 0..255 pixel lands in the top level).
        normalize,
        discretize,
        transforms.Lambda(lambda image_tensor: image_tensor.repeat(3, 1, 1)),  # Convert grayscale to RGB by repeating channels
        transforms.Lambda(duplicate_image_horizontal)  # Tile the image 6x horizontally
    ])

    transform_mappings = {'mnist': to_rgb, 'fashionmnist': to_rgb, 'cifar': transforms.Compose([normalize, discretize])}
    hw_mappings = {'mnist': (28, 28), 'fashionmnist': (28, 28), 'cifar': (32, 32)}
    # Only the grayscale datasets go through to_rgb's 6x horizontal tiling.
    width_multipliers = {'mnist': 6, 'fashionmnist': 6, 'cifar': 1}

    dataset = dataset_mappings[dataset_name]
    transform = transform_mappings[dataset_name]

    train_dataset = getattr(datasets, dataset)(root=train_root, train=True, download=True, transform=transform)
    test_dataset = getattr(datasets, dataset)(root=test_root, train=False, download=True, transform=transform)

    h, w = hw_mappings[dataset_name]
    w *= width_multipliers[dataset_name]

    print(f"train: {train_dataset}")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=True)

    return train_loader, test_loader, h, w

In [None]:
!git clone https://github.com/RobinXL/Handwritten-Math-Equation-Image-Generator.git


fatal: destination path 'Handwritten-Math-Equation-Image-Generator' already exists and is not an empty directory.


In [None]:
### Custom equation loader

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from PIL import Image
import random

# Maps each generated equation string (e.g. "3+4=07") to its generation index 0..80.
# Filled in by generate_equation_dataset.
equations = {}

def quantisize(image, levels):
    """Return the 0-based colour level of each value, using ``levels`` equal-width bins on [0, 1)."""
    # np.digitize returns 1-based indices for the lower bin edges 0, 1/levels, 2/levels, ...
    return -1 + np.digitize(image, bins=np.arange(levels) / levels)

def duplicate_image_horizontal(image):
    """Return ``image`` tiled six times side by side along its last (width) dimension."""
    # Repeat factor 1 for every leading dimension and 6 for the last one.
    repeat_factors = [1] * (image.dim() - 1) + [6]
    return image.repeat(*repeat_factors)

def generate_equation_image(digit_images, equation):
    """Render ``equation`` as one grayscale strip with a randomly chosen glyph per character.

    ``digit_images`` maps each character of ``equation`` to a list of candidate PIL images.
    The canvas is 168x28 pixels (six 28-pixel cells) on a white background, and the glyph
    for character i is pasted with its top-left corner at x = 28 * i.
    """
    cell = 28
    canvas = Image.new('L', (cell * 6, cell), color=255)
    for position, symbol in enumerate(equation):
        glyph = random.choice(digit_images[symbol])
        canvas.paste(glyph, (position * cell, 0))
    return canvas

def generate_equation_dataset(digit_images, output_dir, num_samples):
    """Write ``num_samples`` PNGs for every equation ``i+j=NN`` with i, j in 1..9.

    Each equation gets its own sub-folder of ``output_dir`` named after it (e.g. ``3+4=07``),
    containing ``0.png`` .. ``{num_samples - 1}.png``. As a side effect, the module-level
    ``equations`` dict maps each equation string to its index 0..80 in generation order.
    The dict's size is printed at the end.
    """
    os.makedirs(output_dir, exist_ok=True)
    pairs = ((a, b) for a in range(1, 10) for b in range(1, 10))
    for index, (a, b) in enumerate(pairs):
        equation = f"{a}+{b}={a + b:02d}"
        equations[equation] = index
        folder = os.path.join(output_dir, equation)
        os.makedirs(folder, exist_ok=True)
        for sample_idx in range(num_samples):
            rendered = generate_equation_image(digit_images, equation)
            rendered.save(os.path.join(folder, f"{sample_idx}.png"))
    print(len(equations))

class EquationDataset(Dataset):
    """Dataset over generated equation images laid out as ``data_dir/<equation>/<n>.png``.

    Each sub-folder name is an equation such as ``3+4=07``. The integer after the last
    ``=`` (the result, e.g. 7) is the label of every image in that folder.

    Entries are sorted, so the index -> (path, label) mapping is deterministic across runs
    and filesystems. Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) and non-directory entries at the top level are skipped.

    Args:
        data_dir: root directory containing one sub-folder per equation.
        color_levels: stored on the instance; not used by this class itself.
        transform: optional callable applied to each loaded grayscale PIL image.
    """
    def __init__(self, data_dir, color_levels, transform=None):
        self.data_dir = data_dir
        self.color_levels = color_levels
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # os.listdir order is arbitrary; sort for a reproducible sample order.
        for label_folder in sorted(os.listdir(data_dir)):
            folder_path = os.path.join(data_dir, label_folder)
            if label_folder.startswith('.') or not os.path.isdir(folder_path):
                continue
            # The label is the result part of the equation, e.g. "3+4=07" -> 7.
            result = int(label_folder.split('=')[-1])
            for img_name in sorted(os.listdir(folder_path)):
                if img_name.startswith('.'):
                    continue  # e.g. macOS .DS_Store files
                self.image_paths.append(os.path.join(folder_path, img_name))
                self.labels.append(result)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed when a transform was given."""
        # Use a context manager so the file handle opened by PIL is closed promptly.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("L")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label  # Returns image and numerical label


def get_custom_loader(batch_size, color_levels, data_dir="/content/Handwritten-Math-Equation-Image-Generator/data",
                      num_train_samples=6000, num_test_samples=1000,
                      train_output_dir='/content/generated_train', test_output_dir='/content/generated_test'):
    """Generate the equation image dataset on disk and wrap it in DataLoaders.

    Renders ``num_train_samples`` / ``num_test_samples`` images for each of the 81
    equations ``i+j=NN`` from the handwritten symbols in ``data_dir``. The images are written
    to ``train_output_dir`` / ``test_output_dir`` and then loaded back with quantised pixel
    values.

    Args:
        batch_size: batch size for both loaders. Incomplete last batches are dropped.
        color_levels: number of quantisation levels.
        data_dir: folder with one sub-folder of symbol images per character.
        num_train_samples, num_test_samples: images generated per equation.
        train_output_dir, test_output_dir: where the generated PNGs are written.

    Returns:
        ``(train_loader, test_loader, height, width)`` = ``(..., ..., 28, 168)``. Each image
        tensor is ``(3, 28, 168)`` and holds integer colour levels ``0..color_levels-1``.
    """
    symbol_size = 28          # each handwritten symbol occupies a 28x28 cell
    symbols_per_equation = 6  # "i+j=NN"
    height, width = symbol_size, symbol_size * symbols_per_equation

    transform = transforms.Compose([
        # Keep the whole 6-symbol strip (a no-op for the generated 28x168 images). Resizing to
        # a single 28x28 cell would squash all six symbols together.
        transforms.Resize((height, width)),
        transforms.Lambda(lambda image: quantisize(np.array(image) / 255, color_levels)),
        transforms.ToTensor(),
        transforms.Lambda(lambda image_tensor: image_tensor.repeat(3, 1, 1)),  # Convert to RGB
    ])

    # Generate datasets
    digit_images = load_images(data_dir, (symbol_size, symbol_size))  # Load all digit and symbol images into memory

    generate_equation_dataset(digit_images, train_output_dir, num_train_samples)
    generate_equation_dataset(digit_images, test_output_dir, num_test_samples)

    # Load datasets
    train_dataset = EquationDataset(train_output_dir, color_levels, transform=transform)
    test_dataset = EquationDataset(test_output_dir, color_levels, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader, height, width

def load_images(data_dir, target_size):
    """
    Load images and organize them by label in a dictionary.

    Each key is a sub-folder name of ``data_dir`` (one folder per digit or symbol), and each
    value is a list of grayscale PIL images resized to ``target_size``. Hidden entries
    (names starting with ``.``, e.g. macOS ``.DS_Store``) are skipped, and folders and files
    are visited in sorted order so results do not depend on filesystem listing order.
    Files that PIL cannot read are reported and skipped.
    """
    image_dict = {}
    # Resize works directly on PIL images. The previous ToTensor -> ToPILImage round trip
    # added nothing and can shift pixel values by one due to float truncation.
    resize = transforms.Resize(target_size)

    for label in sorted(os.listdir(data_dir)):
        label_path = os.path.join(data_dir, label)
        if label.startswith('.') or not os.path.isdir(label_path):
            continue
        image_dict[label] = []
        for img_name in sorted(os.listdir(label_path)):
            if img_name.startswith('.'):
                continue  # e.g. .DS_Store, which PIL cannot open
            img_path = os.path.join(label_path, img_name)
            try:
                # Context manager closes the file handle once the image has been converted.
                with Image.open(img_path) as img:
                    image_dict[label].append(resize(img.convert("L")))
            except Exception as e:
                # Best effort: a corrupt or non-image file should not abort loading.
                print(f"Error loading image {img_path}: {e}")

    return image_dict

# Build the equation loaders once at notebook level (batch size 32, 2 colour levels);
# the training main() reads these globals. With get_custom_loader's defaults this renders
# 81 equations x (6000 train + 1000 test) PNG files under /content.
global_train_loader, global_test_loader, global_HEIGHT, global_WIDTH = get_custom_loader(32, 2)


Error loading image /content/Handwritten-Math-Equation-Image-Generator/data/+/.DS_Store: cannot identify image file '/content/Handwritten-Math-Equation-Image-Generator/data/+/.DS_Store'
Error loading image /content/Handwritten-Math-Equation-Image-Generator/data/times/.DS_Store: cannot identify image file '/content/Handwritten-Math-Equation-Image-Generator/data/times/.DS_Store'
Error loading image /content/Handwritten-Math-Equation-Image-Generator/data/div/.DS_Store: cannot identify image file '/content/Handwritten-Math-Equation-Image-Generator/data/div/.DS_Store'
Error loading image /content/Handwritten-Math-Equation-Image-Generator/data/0/.DS_Store: cannot identify image file '/content/Handwritten-Math-Equation-Image-Generator/data/0/.DS_Store'
Error loading image /content/Handwritten-Math-Equation-Image-Generator/data/-/.DS_Store: cannot identify image file '/content/Handwritten-Math-Equation-Image-Generator/data/-/.DS_Store'
Error loading image /content/Handwritten-Math-Equation-Ima

In [None]:
import torch
import torch.nn as nn

import numpy as np


class CroppedConv2d(nn.Conv2d):
    """nn.Conv2d that returns two row-cropped views of its output.

    Used for the vertical stack of CausalBlock/GatedBlock, which build it with kernel
    ``(k // 2 + 1, k)`` and padding ``(k // 2 + 1, k // 2)``. With that padding, both views
    returned by ``forward`` have the same height as the input.

    ``forward`` returns ``(res, shifted_up_res)``:
      * ``res`` drops the first row and the last ``kernel_height`` rows of the conv output;
      * ``shifted_up_res`` drops the last ``kernel_height + 1`` rows. Its row ``r`` is the
        conv output row directly above ``res``'s row ``r`` (``shifted_up_res[r] == res[r - 1]``).
    Presumably the one-row shift keeps the horizontal stack from seeing the current row
    through the vertical stack; confirm against the gated PixelCNN design.
    """
    def __init__(self, *args, **kwargs):
        super(CroppedConv2d, self).__init__(*args, **kwargs)

    def forward(self, x):
        x = super(CroppedConv2d, self).forward(x)

        kernel_height, _ = self.kernel_size
        # Both slices keep H_out - kernel_height - 1 rows, offset from each other by one row.
        res = x[:, :, 1:-kernel_height, :]
        shifted_up_res = x[:, :, :-kernel_height-1, :]

        return res, shifted_up_res


class MaskedConv2d(nn.Conv2d):
    """nn.Conv2d whose weights are multiplied by a fixed causal mask on every forward pass.

    Spatial mask (per kernel, with centre ``(yc, xc) = (height // 2, width // 2)``): all rows
    above the centre row, plus the centre row up to and including the centre column, are
    kept. Everything below or to the right is zeroed.

    Channel mask at the centre tap: channels are grouped by index modulo ``data_channels``.
    At the centre position, an output channel of group ``o`` does not see input channels of
    any later group ``i > o``. With ``mask_type='A'`` it also does not see its own group
    ``o`` at the centre; mask ``'B'`` keeps that connection.

    Args:
        mask_type: ``'A'`` or ``'B'`` (keyword-only); anything else fails the assertion.
        data_channels: number of channel groups used for the modulo grouping (keyword-only).
    """
    def __init__(self, *args, mask_type, data_channels, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)

        assert mask_type in ['A', 'B'], 'Invalid mask type.'

        out_channels, in_channels, height, width = self.weight.size()
        yc, xc = height // 2, width // 2

        mask = np.zeros(self.weight.size(), dtype=np.float32)
        # Rows strictly above the centre row.
        mask[:, :, :yc, :] = 1
        # Centre row, columns up to and including the centre column.
        mask[:, :, yc, :xc + 1] = 1

        def cmask(out_c, in_c):
            # Boolean (out_channels, in_channels) matrix selecting pairs whose output channel
            # is in group out_c and whose input channel is in group in_c (index % data_channels).
            a = (np.arange(out_channels) % data_channels == out_c)[:, None]
            b = (np.arange(in_channels) % data_channels == in_c)[None, :]
            return a * b

        # At the centre tap, group o must not see later input groups i > o.
        for o in range(data_channels):
            for i in range(o + 1, data_channels):
                mask[cmask(o, i), yc, xc] = 0

        # Type 'A' additionally hides each group from itself at the centre tap.
        if mask_type == 'A':
            for c in range(data_channels):
                mask[cmask(c, c), yc, xc] = 0

        mask = torch.from_numpy(mask).float()

        # A buffer moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Zero the masked taps in place before every convolution, so they are zero for this
        # call regardless of any earlier updates to the weights.
        self.weight.data *= self.mask
        x = super(MaskedConv2d, self).forward(x)
        return x


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalBlock(nn.Module):
    """First gated layer: turns the input image into vertical and horizontal feature stacks.

    Both stacks use a gated activation ``tanh(a) * sigmoid(b)``, where ``a`` and ``b`` are the
    two ``out_channels``-sized halves of a ``2 * out_channels`` convolution output. Unlike
    GatedBlock, it has no label conditioning, no skip output and no residual connection,
    and its horizontal masks are type 'A'.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of each returned stack.
        kernel_size: spatial kernel size ``k``. Padding ``k // 2`` keeps the width unchanged
            only for odd ``k`` (presumably always odd here; confirm).
        data_channels: channel-group count passed to the masked convolutions.
    """
    def __init__(self, in_channels, out_channels, kernel_size, data_channels):
        super(CausalBlock, self).__init__()
        # Chunk size for torch.split: separates the tanh half from the sigmoid half.
        self.split_size = out_channels

        # Vertical stack: (k//2 + 1) x k kernel whose output is row-cropped by CroppedConv2d.
        self.v_conv = CroppedConv2d(in_channels,
                                    2 * out_channels,
                                    (kernel_size // 2 + 1, kernel_size),
                                    padding=(kernel_size // 2 + 1, kernel_size // 2))
        # 1x1 conv of the input, added to the vertical conv output.
        self.v_fc = nn.Conv2d(in_channels,
                              2 * out_channels,
                              (1, 1))
        # 1x1 projection of the shifted vertical features into the horizontal stack.
        self.v_to_h = nn.Conv2d(2 * out_channels,
                                2 * out_channels,
                                (1, 1))

        # Horizontal stack: 1 x k row convolution with a type-'A' mask.
        self.h_conv = MaskedConv2d(in_channels,
                                   2 * out_channels,
                                   (1, kernel_size),
                                   mask_type='A',
                                   data_channels=data_channels,
                                   padding=(0, kernel_size // 2))
        # Type-'A' 1x1 conv applied after the horizontal gate.
        self.h_fc = MaskedConv2d(out_channels,
                                 out_channels,
                                 (1, 1),
                                 mask_type='A',
                                 data_channels=data_channels)

    def forward(self, image):
        """Return ``(v_out, h_out)``, each with ``out_channels`` channels.

        Assumes ``image`` is ``(N, in_channels, H, W)``; TODO confirm for other callers.
        """
        # Vertical stack: cropped conv + 1x1 conv, then gate.
        v_out, v_shifted = self.v_conv(image)
        v_out += self.v_fc(image)
        v_out_tanh, v_out_sigmoid = torch.split(v_out, self.split_size, dim=1)
        v_out = torch.tanh(v_out_tanh) * torch.sigmoid(v_out_sigmoid)

        # Horizontal stack: masked row conv + projected one-row-shifted vertical features, then gate.
        h_out = self.h_conv(image)
        v_shifted = self.v_to_h(v_shifted)
        h_out += v_shifted
        h_out_tanh, h_out_sigmoid = torch.split(h_out, self.split_size, dim=1)
        h_out = torch.tanh(h_out_tanh) * torch.sigmoid(h_out_sigmoid)
        h_out = self.h_fc(h_out)

        return v_out, h_out


class GatedBlock(nn.Module):
    """Label-conditioned gated layer with skip and residual connections.

    ``forward`` takes a 4-entry container indexed ``0..3`` holding
    ``(vertical stack, horizontal stack, accumulated skip, labels)``. It returns a dict
    with the same keys, so blocks can be chained (e.g. in ``nn.Sequential``).

    A label embedding (81 ids -> ``2 * out_channels``) is added as a per-channel bias to both
    stacks before their gates. The residual adds ``v_out + v_in`` / ``h_out + h_in`` require
    ``in_channels == out_channels``. All masks in this block are type 'B'.

    Args:
        in_channels: channels of the incoming stacks.
        out_channels: channels of the outgoing stacks and of the skip tensor.
        kernel_size: spatial kernel size ``k`` (presumably odd so padding ``k // 2`` keeps the width).
        data_channels: channel-group count passed to the masked convolutions.
    """
    def __init__(self, in_channels, out_channels, kernel_size, data_channels):
        super(GatedBlock, self).__init__()
        # Chunk size for torch.split: separates the tanh half from the sigmoid half.
        self.split_size = out_channels

        # Vertical stack: (k//2 + 1) x k kernel whose output is row-cropped by CroppedConv2d.
        self.v_conv = CroppedConv2d(in_channels,
                                    2 * out_channels,
                                    (kernel_size // 2 + 1, kernel_size),
                                    padding=(kernel_size // 2 + 1, kernel_size // 2))
        self.v_fc = nn.Conv2d(in_channels,
                              2 * out_channels,
                              (1, 1))
        # Masked 1x1 projection of the shifted vertical features into the horizontal stack.
        self.v_to_h = MaskedConv2d(2 * out_channels,
                                   2 * out_channels,
                                   (1, 1),
                                   mask_type='B',
                                   data_channels=data_channels)

        # Horizontal stack: 1 x k row convolution with a type-'B' mask.
        self.h_conv = MaskedConv2d(in_channels,
                                   2 * out_channels,
                                   (1, kernel_size),
                                   mask_type='B',
                                   data_channels=data_channels,
                                   padding=(0, kernel_size // 2))
        self.h_fc = MaskedConv2d(out_channels,
                                 out_channels,
                                 (1, 1),
                                 mask_type='B',
                                 data_channels=data_channels)

        # 1x1 conv producing this block's contribution to the skip sum.
        self.h_skip = MaskedConv2d(out_channels,
                                   out_channels,
                                   (1, 1),
                                   mask_type='B',
                                   data_channels=data_channels)

        # Label ids must be < 81.
        self.label_embedding = nn.Embedding(81, 2*out_channels)

    def forward(self, x):
        """Return ``{0: v_out, 1: h_out, 2: skip, 3: label}``, with ``label`` passed through unchanged."""
        v_in, h_in, skip, label = x[0], x[1], x[2], x[3]

        # (N, 2*out_channels, 1, 1), so it broadcasts over H and W; assumes label is (N,).
        label_embedded = self.label_embedding(label).unsqueeze(2).unsqueeze(3)

        v_out, v_shifted = self.v_conv(v_in)
        v_out += self.v_fc(v_in)
        v_out += label_embedded
        v_out_tanh, v_out_sigmoid = torch.split(v_out, self.split_size, dim=1)
        v_out = torch.tanh(v_out_tanh) * torch.sigmoid(v_out_sigmoid)

        h_out = self.h_conv(h_in)
        v_shifted = self.v_to_h(v_shifted)
        h_out += v_shifted
        h_out += label_embedded
        h_out_tanh, h_out_sigmoid = torch.split(h_out, self.split_size, dim=1)
        h_out = torch.tanh(h_out_tanh) * torch.sigmoid(h_out_sigmoid)

        # skip connection
        skip = skip + self.h_skip(h_out)

        h_out = self.h_fc(h_out)

        # residual connections
        h_out = h_out + h_in
        v_out = v_out + v_in

        return {0: v_out, 1: h_out, 2: skip, 3: label}


class PixelCNN(nn.Module):
    """Label-conditioned gated PixelCNN over 3-channel images.

    Architecture:
      * a CausalBlock on the raw image;
      * ``cfg["hidden_layers"]`` chained GatedBlocks, whose skip outputs are accumulated
        starting from zeros;
      * a per-label bias and ReLU;
      * two masked 1x1 convolutions producing ``color_levels`` logits per channel and pixel.

    ``cfg`` keys read: ``hidden_fmaps``, ``color_levels``, ``causal_ksize``, ``hidden_ksize``,
    ``hidden_layers``, ``out_hidden_fmaps``. Construction does not depend on any
    notebook-level globals.
    """
    def __init__(self, cfg):
        super(PixelCNN, self).__init__()

        DATA_CHANNELS = 3

        self.hidden_fmaps = cfg["hidden_fmaps"]
        self.color_levels = cfg["color_levels"]

        self.causal_conv = CausalBlock(DATA_CHANNELS,
                                       cfg["hidden_fmaps"],
                                       cfg["causal_ksize"],
                                       data_channels=DATA_CHANNELS)

        self.hidden_conv = nn.Sequential(
            *[GatedBlock(cfg["hidden_fmaps"], cfg["hidden_fmaps"], cfg["hidden_ksize"], DATA_CHANNELS) for _ in range(cfg["hidden_layers"])]
        )

        # Per-label bias added to the summed skip connections. It uses 81 ids, matching the
        # embedding size inside GatedBlock.
        self.label_embedding = nn.Embedding(81, self.hidden_fmaps)

        self.out_hidden_conv = MaskedConv2d(cfg["hidden_fmaps"],
                                            cfg["out_hidden_fmaps"],
                                            (1, 1),
                                            mask_type='B',
                                            data_channels=DATA_CHANNELS)

        self.out_conv = MaskedConv2d(cfg["out_hidden_fmaps"],
                                     DATA_CHANNELS * cfg["color_levels"],
                                     (1, 1),
                                     mask_type='B',
                                     data_channels=DATA_CHANNELS)

    def forward(self, image, label):
        """Return unnormalised logits of shape ``(N, color_levels, C, H, W)``.

        Args:
            image: ``(N, C, H, W)`` float tensor. The training loop feeds colour levels
                scaled to [0, 1].
            label: integer label ids, presumably of shape ``(N,)``; each must be < 81.
        """
        count, data_channels, height, width = image.size()

        v, h = self.causal_conv(image)

        # Chain the gated blocks; entry 2 accumulates the skip connections starting from zeros.
        _, _, out, _ = self.hidden_conv({0: v,
                                         1: h,
                                         2: image.new_zeros((count, self.hidden_fmaps, height, width), requires_grad=True),
                                         3: label}).values()

        label_embedded = self.label_embedding(label).unsqueeze(2).unsqueeze(3)

        # add label bias
        out += label_embedded
        out = F.relu(out)
        out = F.relu(self.out_hidden_conv(out))
        out = self.out_conv(out)

        # out_conv has C * color_levels channels; channel k becomes (level k // C, data channel k % C),
        # which matches MaskedConv2d's grouping of channels by index % data_channels.
        out = out.view(count, self.color_levels, data_channels, height, width)

        return out

    def sample(self, shape, count, label=None, device='cuda'):
        """Autoregressively sample ``count`` images of ``shape = (C, H, W)``.

        Pixels are generated in raster order (row, then column, then channel). Each step
        reruns the full network and draws a level from the softmax over ``color_levels``
        logits. Sampled levels are stored scaled to [0, 1] (``level / (color_levels - 1)``),
        matching the model's input convention.

        Args:
            shape: ``(channels, height, width)`` of each sample.
            count: number of samples.
            label: label id used for every sample, or ``None`` to draw labels uniformly
                from ``0..9``.
            device: device on which to sample.
        """
        channels, height, width = shape

        samples = torch.zeros(count, *shape).to(device)
        if label is None:
            # NOTE(review): EquationDataset labels are equation results (2..18), so drawing
            # 0..9 covers only part of that range; confirm the intended label distribution.
            labels = torch.randint(high=10, size=(count,)).to(device)
        else:
            labels = (label*torch.ones(count)).to(device).long()

        with torch.no_grad():
            for i in range(height):
                for j in range(width):
                    for c in range(channels):
                        unnormalized_probs = self.forward(samples, labels)
                        pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
                        # squeeze(-1) keeps a (count,) vector even when count == 1.
                        sampled_levels = torch.multinomial(pixel_probs, 1).squeeze(-1).float() / (self.color_levels - 1)
                        samples[:, c, i, j] = sampled_levels

        return samples


In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np

import argparse
import os

from tqdm import tqdm
import wandb


# Download roots for torchvision datasets. In this notebook they are only referenced by the
# commented-out get_loaders call inside main().
TRAIN_DATASET_ROOT = '.data/train/'
TEST_DATASET_ROOT = '.data/test/'

# Checkpoint directory. main() saves into this directory under a file name built from the
# config, not under MODEL_PARAMS_OUTPUT_FILENAME.
MODEL_PARAMS_OUTPUT_DIR = 'model'
MODEL_PARAMS_OUTPUT_FILENAME = 'params.pth'

# Directory where the per-epoch sample grids are written.
TRAIN_SAMPLES_DIR = 'train_samples'

import re

# def tokenize_equation(equation):
#     """Tokenizes an equation string into individual tokens of numbers and operators."""
#     # Split the equation into tokens (numbers, operators, etc.)
#     tokens = re.findall(r'\d+|\+|=|-|\*|\/', equation)
#     return tokens

def train(cfg, model, device, train_loader, optimizer, scheduler, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    The batch's integer colour levels are the cross-entropy targets, while the model input
    is the same batch scaled to [0, 1]. Gradients are clipped to ``cfg["max_norm"]``.

    ``scheduler`` is stepped after every optimiser step. The scheduler used in this notebook
    is ``CyclicLR``, which PyTorch documents as a per-batch schedule (``step()`` after each
    batch).
    """
    model.train()

    for images, labels in tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, cfg["epochs"])):
        optimizer.zero_grad()

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Model input: levels scaled to [0, 1]. Targets: the raw integer levels.
        normalized_images = images.float() / (cfg["color_levels"] - 1)

        outputs = model(normalized_images, labels)
        loss = F.cross_entropy(outputs, images)
        loss.backward()

        clip_grad_norm_(model.parameters(), max_norm=cfg["max_norm"])

        optimizer.step()
        # Per-batch step. Stepping CyclicLR once per epoch kept the learning rate essentially
        # at base_lr for the whole run (its default half-cycle is 2000 steps).
        scheduler.step()


def test_and_sample(cfg, model, device, test_loader, height, width, losses, params, epoch):
    """Evaluate on ``test_loader``, record loss and weights, and save a grid of samples.

    The test loss is the mean cross-entropy per predicted value (per sample, channel and
    pixel) over the whole loader. It is logged to W&B, printed, and appended to ``losses``
    as a float. A CPU copy of the current ``state_dict`` is appended to ``params``.
    Finally, ``cfg["epoch_samples"]`` samples of ``(3, height, width)`` are saved to
    ``TRAIN_SAMPLES_DIR``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    total_loss = 0.0
    total_elements = 0

    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            normalized_images = images.float() / (cfg["color_levels"] - 1)
            outputs = model(normalized_images, labels)

            # Sum here and divide by the element count once. The old "sum of per-element
            # tensors, mean, then / len(dataset)" under-reported the loss by ~batch_size.
            total_loss += F.cross_entropy(outputs, images, reduction='sum').item()
            total_elements += images.numel()

    if total_elements == 0:
        raise ValueError("test_loader yielded no batches; cannot compute the test loss")
    test_loss = total_loss / total_elements

    wandb.log({
        "Test loss": test_loss
    })
    print("Average test loss: {}".format(test_loss))

    losses.append(test_loss)
    # state_dict() returns references to the live tensors, which the optimiser keeps
    # updating. Snapshot them, otherwise every entry would end up holding the final weights.
    params.append({name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()})

    samples = model.sample((3, height, width), cfg["epoch_samples"], device=device)
    save_samples(samples, TRAIN_SAMPLES_DIR, 'epoch{}_samples.png'.format(epoch + 1))


# Path of the checkpoint written by the most recent training run of main().
# Stays None until main() trains and saves a model.
SAVED_MODEL_PATH = None

def main(
    epochs=15,
    batch_size=32,
    dataset='mnist',
    causal_ksize=7,
    hidden_ksize=7,
    color_levels=2,
    hidden_fmaps=30,
    out_hidden_fmaps=10,
    hidden_layers=6,
    learning_rate=0.0001,
    weight_decay=0.0001,
    max_norm=1.0,
    epoch_samples=3,
    cuda=True,
    model_path=None,
    output_fname='samples.png',
    label=-1,
    count=64,
    height=28,
    width=168
):
    """Train a PixelCNN on the global equation loaders, or sample from a saved checkpoint.

    When ``model_path`` is None:
      * train for ``epochs`` epochs, logging to Weights & Biases;
      * evaluate and save a sample grid after every epoch;
      * save the parameters with the lowest test loss under ``MODEL_PARAMS_OUTPUT_DIR``
        and record the path in the global ``SAVED_MODEL_PATH``. Nothing is saved if no
        epoch ran.

    Otherwise, load ``model_path`` and save ``count`` samples of shape ``(3, height, width)``
    to ``TRAIN_SAMPLES_DIR/output_fname``. ``label=-1`` means random labels.

    Note: training uses the module-level ``global_train_loader`` / ``global_test_loader`` and
    their height/width. ``dataset`` only appears in the checkpoint file name, and
    ``batch_size`` is only recorded in the config.
    """
    global SAVED_MODEL_PATH  # Access the global variable

    # Configuration dictionary
    cfg = {
        "epochs": epochs,
        "batch_size": batch_size,
        "dataset": dataset,
        "causal_ksize": causal_ksize,
        "hidden_ksize": hidden_ksize,
        "color_levels": color_levels,
        "hidden_fmaps": hidden_fmaps,
        "out_hidden_fmaps": out_hidden_fmaps,
        "hidden_layers": hidden_layers,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "max_norm": max_norm,
        "epoch_samples": epoch_samples,
        "cuda": cuda
    }

    device = torch.device("cuda" if torch.cuda.is_available() and cfg["cuda"] else "cpu")

    if model_path is None:
        # Training and saving model
        wandb.init(project="PixelCNN")
        try:
            wandb.config.update(cfg)
            torch.manual_seed(42)

            model = PixelCNN(cfg=cfg)
            model.to(device)

            # Loaders are built once at notebook level by get_custom_loader.
            train_loader, test_loader, HEIGHT, WIDTH = global_train_loader, global_test_loader, global_HEIGHT, global_WIDTH

            optimizer = optim.Adam(model.parameters(), lr=cfg["learning_rate"], weight_decay=cfg["weight_decay"])
            scheduler = optim.lr_scheduler.CyclicLR(optimizer, cfg["learning_rate"], 10 * cfg["learning_rate"], cycle_momentum=False)

            wandb.watch(model)

            losses = []
            params = []

            for epoch in range(cfg["epochs"]):
                train(cfg, model, device, train_loader, optimizer, scheduler, epoch)
                test_and_sample(cfg, model, device, test_loader, HEIGHT, WIDTH, losses, params, epoch)

            if not losses:
                # np.argmin would raise on an empty list (e.g. epochs=0).
                print("No epochs were run; no model saved.")
                return

            # Save model parameters of the epoch with the lowest test loss.
            os.makedirs(MODEL_PARAMS_OUTPUT_DIR, exist_ok=True)
            params_filename = '{}_cks{}hks{}cl{}hfm{}ohfm{}hl{}_params.pth'.format(
                cfg["dataset"], cfg["causal_ksize"], cfg["hidden_ksize"], cfg["color_levels"], cfg["hidden_fmaps"],
                cfg["out_hidden_fmaps"], cfg["hidden_layers"]
            )
            SAVED_MODEL_PATH = os.path.join(MODEL_PARAMS_OUTPUT_DIR, params_filename)
            best_epoch = int(np.argmin([float(loss) for loss in losses]))
            torch.save(params[best_epoch], SAVED_MODEL_PATH)
            print(f"Model saved to {SAVED_MODEL_PATH}")
        finally:
            # Close the W&B run (also on error/interrupt) so a later call starts a fresh run.
            wandb.finish()
    else:
        # Testing and loading the model
        model = PixelCNN(cfg=cfg)
        model.eval()
        model.to(device)

        # map_location lets a checkpoint saved from GPU tensors load on a CPU-only runtime.
        # weights_only=True restricts unpickling to tensors and plain containers.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

        label = None if label == -1 else label
        samples = model.sample((3, height, width), count, label=label, device=device)
        save_samples(samples, TRAIN_SAMPLES_DIR, output_fname)


# In a notebook kernel __name__ is '__main__', so running this cell trains with main()'s
# default arguments.
if __name__ == '__main__':
    main()


0,1
Test loss,█▁

0,1
Test loss,0.00239


Epoch 1/15: 100%|██████████| 15187/15187 [1:05:17<00:00,  3.88it/s]


Average test loss: 0.0025502557400614023


Epoch 2/15:  60%|██████    | 9167/15187 [39:27<25:57,  3.87it/s]

In [None]:
# Point the sampling cells at an existing checkpoint without re-running training. This is
# the path main() produces for its default config (dataset 'mnist', kernel sizes 7/7,
# 2 colour levels, 30/10 feature maps, 6 hidden layers).
SAVED_MODEL_PATH = "model/mnist_cks7hks7cl2hfm30ohfm10hl6_params.pth"

In [None]:
import torch


import argparse

# Directory that the sampling cells below write generated images into.
OUTPUT_DIRNAME = 'samples'


import torch

def main(
    causal_ksize=7,
    hidden_ksize=7,
    color_levels=2,
    hidden_fmaps=30,
    out_hidden_fmaps=10,
    hidden_layers=6,
    cuda=True,
    model_path=None,
    output_fname='samples.png',
    label=-1,
    count=64,
    height=28,
    width=168
):
    """Load a trained PixelCNN checkpoint and save ``count`` samples to ``OUTPUT_DIRNAME/output_fname``.

    Note: this redefines (shadows) the training ``main`` from the previous cell.

    Args:
        causal_ksize, hidden_ksize, color_levels, hidden_fmaps, out_hidden_fmaps, hidden_layers:
            architecture settings. They must match the ones used to train the checkpoint.
        cuda: use CUDA when available.
        model_path: path to a saved ``state_dict`` (required).
        output_fname: image file name for the sample grid.
        label: label id for every sample, or -1 for random labels.
        count, height, width: number and size of the 3-channel samples.

    Raises:
        ValueError: if ``model_path`` is None.
    """
    # Configuration dictionary
    cfg = {
        "causal_ksize": causal_ksize,
        "hidden_ksize": hidden_ksize,
        "color_levels": color_levels,
        "hidden_fmaps": hidden_fmaps,
        "out_hidden_fmaps": out_hidden_fmaps,
        "hidden_layers": hidden_layers,
        "cuda": cuda,
        "model_path": model_path,
        "output_fname": output_fname,
        "label": label,
        "count": count,
        "height": height,
        "width": width
    }

    if cfg["model_path"] is None:
        raise ValueError("model_path must be specified to load the model.")

    OUTPUT_FILENAME = cfg["output_fname"]

    model = PixelCNN(cfg=cfg)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() and cfg["cuda"] else "cpu")
    model.to(device)

    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only runtime.
    # weights_only=True is a safer default (only tensors/containers are unpickled).
    model.load_state_dict(torch.load(cfg["model_path"], map_location=device, weights_only=True))

    label = None if cfg["label"] == -1 else cfg["label"]
    samples = model.sample((3, cfg["height"], cfg["width"]), cfg["count"], label=label, device=device)
    save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)



# Sample from the checkpoint at SAVED_MODEL_PATH with the sampling main()'s defaults
# (64 samples of shape 3x28x168, random labels).
main(model_path=SAVED_MODEL_PATH)


In [None]:
# Show which checkpoint path the sampling cells will use.
print(SAVED_MODEL_PATH)

In [None]:
if __name__ == '__main__':
    # Ask the user for the label to condition sampling on.
    # NOTE(review): EquationDataset labels are equation results (2..18), so the 0-9 range
    # accepted here only partly overlaps the trained labels; confirm intent.
    digit_input = input("Enter a single digit (0-9) to generate its image: ")

    # Validate without exit(): inside a notebook kernel exit() does not reliably abort the
    # rest of the cell, so invalid input used to fall through to sampling with a bad or
    # undefined label. Instead, leave label as None and skip generation.
    label = None
    try:
        label = int(digit_input)
        if label < 0 or label > 9:
            raise ValueError("Please enter a single digit between 0 and 9.")
    except ValueError as e:
        print(e)
        label = None

    if label is not None and SAVED_MODEL_PATH is None:
        print("SAVED_MODEL_PATH is not set; train a model or set the checkpoint path first.")
        label = None

    if label is not None:
        print(SAVED_MODEL_PATH)
        # Configuration dictionary for loading model and generating samples
        cfg = {
            "causal_ksize": 7,
            "hidden_ksize": 7,
            "color_levels": 2,
            "hidden_fmaps": 30,
            "out_hidden_fmaps": 10,
            "hidden_layers": 6,
            "cuda": True,
            "model_path": SAVED_MODEL_PATH,
            "output_fname": f'digit_{label}_sample.png',
            "label": label,
            "count": 1,  # Generate a single image for the digit
            "height": 28,
            "width": 168
        }

        # Initialize the model
        model = PixelCNN(cfg=cfg)
        model.eval()

        device = torch.device("cuda" if torch.cuda.is_available() and cfg["cuda"] else "cpu")
        model.to(device)

        # Load the pre-trained model weights. map_location allows a GPU-saved checkpoint on a
        # CPU-only runtime; weights_only=True matches the loader in the previous cell.
        model.load_state_dict(torch.load(cfg["model_path"], map_location=device, weights_only=True))

        # Generate and save the sample for the specified digit
        samples = model.sample((3, cfg["height"], cfg["width"]), cfg["count"], label=label, device=device)
        save_samples(samples, OUTPUT_DIRNAME, cfg["output_fname"])

        print(f"Image of digit {label} generated and saved as {cfg['output_fname']}.")
