In [1]:
import os
import random
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import PIL
from PIL import Image, ImageFilter
from torch.utils.data import DataLoader
import numpy as np
from tqdm.auto import tqdm

# seed everything
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


class PrintedDigitsDataset(Dataset):
    def __init__(self, root_dir, transform=None, random_dilation=False, meta='font_list.txt'):
        self.root_dir = root_dir
        self.transform = transform
        self.random_dilation = random_dilation
        self.samples = []
        with open(os.path.join(root_dir, meta)) as f:
            font_list = f.read().splitlines()

        for font_folder in font_list:
            font_path = os.path.join(root_dir, font_folder)
            for label in range(10):
                image_path = os.path.join(font_path, f'{label}/{label}.png')
                self.samples.append((image_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        image = PIL.ImageOps.invert(Image.open(image_path).convert("RGB"))
        if self.random_dilation:
            dilation_width = random.randrange(1, 11, 2)
            image = image.resize((128, 128)).filter(ImageFilter.MaxFilter(dilation_width))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

data_path = "./printed_digits"
train_dataset = PrintedDigitsDataset(root_dir=data_path, transform=transform, random_dilation=True)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

print(f"Total samples: {len(train_dataset)}")


Total samples: 50


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DigitClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super(DigitClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(4096, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DigitClassifier(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-4)

def train_model(model, train_loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        print('\r', end='')
        total_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Accuracy: {accuracy:.2f}%", end='')
    print()
train_model(model, train_loader, optimizer, criterion, epochs=70)


Epoch 70/70, Loss: 0.4332, Accuracy: 98.00%%


In [4]:
torch.save(model.state_dict(), "digit_classifier_printed.pth")

In [5]:
@torch.no_grad()
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")


In [6]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_dataset = PrintedDigitsDataset(root_dir=data_path, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
test_model(model, test_loader)

Test Accuracy: 98.00%


In [7]:
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

mnist_transform = transforms.Compose([
    transforms.Lambda(lambda img : img.convert('RGB')),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=mnist_transform
)

# mnist_test_dataset = torch.utils.data.Subset(mnist_test_dataset, list(range(128)))

mnist_test_loader = DataLoader(mnist_test_dataset, batch_size=256, shuffle=False)

# def show_mnist_images(data_loader):
#     data_iter = iter(data_loader)
#     images, labels = next(data_iter)
#     images = images.squeeze(1)

#     fig, axes = plt.subplots(1, len(images), figsize=(15, 5))
#     for idx, (img, label) in enumerate(zip(images, labels)):
#         axes[idx].imshow(img.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
#         axes[idx].set_title(f"Label: {label.item()}")
#         axes[idx].axis("off")
#     plt.show()

# show_mnist_images(mnist_test_loader)

In [8]:
model = DigitClassifier(num_classes=10).to(device)
model.load_state_dict(torch.load("digit_classifier_printed.pth", weights_only=True))
test_model(model, mnist_test_loader)

Test Accuracy: 37.34%


In [9]:
mnist_transform = transforms.Compose([
    transforms.Lambda(lambda img : img.convert('RGB')),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=mnist_transform
)

mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=256, shuffle=True)

model = DigitClassifier(num_classes=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
train_model(model, mnist_train_loader, optimizer, criterion, epochs=10)
test_model(model, mnist_test_loader)
torch.save(model.state_dict(), "digit_classifier_mnist.pth")

Epoch 10/10, Loss: 4.8183, Accuracy: 99.37%
Test Accuracy: 98.90%


In [10]:
from diffusers import UNet2DModel
from pipeline_ddpm_custom import DDPMPipelineCustom
from scheduling_rectflow import RectFlowScheduler, RectFlowInverseScheduler

def get_beautifier(model_dir, classifier, device='cpu', num_inverse_step=200, num_denoise_step=50, printed_digits_dir="./printed_digits"):
    printed_digits = {}
    total_fonts = 0
    total_images = 0
    for font in os.listdir(printed_digits_dir):
        font_path = os.path.join(printed_digits_dir, font)
        if os.path.isdir(font_path):
            total_fonts += 1  # Count fonts
            for digit in range(10):
                digit_folder = os.path.join(font_path, str(digit))
                if os.path.exists(digit_folder):
                    images = [
                        transforms.functional.to_tensor(
                            PIL.ImageOps.invert(Image.open(os.path.join(digit_folder, img_path)).resize((32, 32)).convert("RGB"))
                        ) * 2 - 1
                        for img_path in os.listdir(digit_folder)
                        if img_path.endswith(".png")
                    ]
                    total_images += len(images)
                    printed_digits.setdefault(digit, []).extend(images)
    

    unet = UNet2DModel.from_pretrained(f"{model_dir}/unet")
    scheduler_config_path = os.path.join(model_dir, "scheduler", "scheduler_config.json")
    scheduler = RectFlowScheduler.from_config(scheduler_config_path)
    scheduler_inv = RectFlowInverseScheduler.from_config(scheduler_config_path)
    unet.to(device)
    classifier.to(device)
    
    pipeline = DDPMPipelineCustom(
        unet=unet,
        scheduler=scheduler,
    )
    pipeline_inv = DDPMPipelineCustom(
        unet=unet,
        scheduler=scheduler_inv,
    )
    pipeline.set_progress_bar_config(disable=True)
    pipeline_inv.set_progress_bar_config(disable=True)
    class_conditioning = unet.class_embedding is not None
    
    @torch.no_grad()
    def beautifier(images, alpha):
        classifier.eval()
        _, labels = torch.max(classifier(images), 1)
        printed = torch.stack([
            random.choice(printed_digits[label.item()])
            for (image, label) in zip(images, labels)
        ]).to(images.device)
        
        # Run noise inversion
        images = torch.cat([images, printed])
        labels = torch.cat([labels] * 2) if class_conditioning else None
        semantic_noise = pipeline_inv(labels, init_noise=images, clamp_output=False, num_inference_steps=num_inverse_step, output_type='pt').images
        z_images, z_printed = semantic_noise.chunk(2)

        # interpolation
        interp_z = alpha * z_printed + (1 - alpha) * z_images
        
        # denoise from interpolated semantic noise
        labels = labels.chunk(2)[0] if class_conditioning else None
        interp = pipeline(labels, init_noise=interp_z * 2 - 1, num_inference_steps=num_denoise_step, output_type='pt').images
        return interp
    
    return beautifier

In [11]:
@torch.no_grad()
def evaluate_handwriting_beautification(model, test_loader, beautifier, alpha):
    model.eval()
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        images = beautifier(images, alpha) * 2 - 1
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [12]:
classifier = DigitClassifier(num_classes=10).to(device)
classifier.load_state_dict(torch.load("digit_classifier_mnist.pth", weights_only=True))
model = DigitClassifier(num_classes=10).to(device)
model.load_state_dict(torch.load("digit_classifier_printed.pth", weights_only=True))

for config in ['cls_cnd_aug_printed', 'uncnd_aug_printed']:
    for alpha in [0.05, 0.1, 0.15, 0.2]:
        print(f'Running {config} with alpha {alpha}')
        beautifier = get_beautifier(f'./output/{config}/', classifier, device)
        acc = evaluate_handwriting_beautification(model, mnist_test_loader, beautifier, alpha)
        with open(f'nb_logs/{alpha}_{config}.txt', 'w') as f:
            f.write(str(acc))

Running cls_cnd_aug_printed with alpha 0.05


  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 45.72%
Running cls_cnd_aug_printed with alpha 0.1


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 63.46%
Running cls_cnd_aug_printed with alpha 0.15


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 76.12%
Running cls_cnd_aug_printed with alpha 0.2


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 84.84%
Running uncnd_aug_printed with alpha 0.05


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 35.90%
Running uncnd_aug_printed with alpha 0.1


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 44.09%
Running uncnd_aug_printed with alpha 0.15


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 53.61%
Running uncnd_aug_printed with alpha 0.2


  0%|          | 0/40 [00:00<?, ?it/s]

Test Accuracy: 64.25%
