In [1]:
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray'
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm
import functools as ft

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [12]:
def load_sudoku_images(path, total, device, normalize=False, grid=8, tile=28):
    """Load ``total`` sudoku board PNGs named ``{path}/0.png`` .. ``{path}/{total-1}.png``.

    Each image goes through ``transforms.ToTensor()`` (values in [0, 1]). When
    ``normalize`` is True it is additionally passed through
    ``Normalize((0.5,), (0.5,))``, which maps values to [-1, 1].

    Args:
        path: Directory containing the numbered PNG files.
        total: Number of images to load.
        device: Device on which the output tensor is allocated.
        normalize: Whether to rescale pixel values from [0, 1] to [-1, 1].
        grid: Number of cells per board side (default 8).
        tile: Side length in pixels of one cell (default 28).

    Returns:
        Tensor of shape ``(total, grid * tile, grid * tile)`` on ``device``.

    Raises:
        RuntimeError: if a decoded image is not single-channel with side
            ``grid * tile`` (the tensor assignment cannot broadcast it).
    """
    side = grid * tile
    sudoku_img = torch.empty(total, side, side, device=device)
    if normalize:
        transform = transforms.Compose((
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))))
    else:
        transform = transforms.ToTensor()
    for i in tqdm(range(total), 'sudoku images'):
        # The context manager releases the file handle deterministically, even
        # if decoding or the shape-checked assignment below raises.
        with Image.open(f'{path}/{i}.png') as img:
            sudoku_img[i] = transform(img)
    return sudoku_img

In [13]:
def as_digit_images(sudoku_img, grid=8, tile=28):
    """Cut sudoku board images into their individual cell tiles.

    Args:
        sudoku_img: Tensor whose last two dims are ``(grid * tile, grid * tile)``.
            Any leading (batch) dims are preserved.
        grid: Number of cells per board side (default 8).
        tile: Side length in pixels of one cell (default 28).

    Returns:
        Contiguous tensor of shape ``(*leading, grid, grid, tile, tile)``.
        ``out[..., r, c, :, :]`` is the cell in row ``r`` and column ``c``.

    Raises:
        RuntimeError: if either of the last two dims is not exactly ``grid * tile``
            (``torch.split`` requires the sizes to sum exactly).
    """
    sizes = [tile] * grid
    # Cut horizontal strips: (..., H, W) -> (..., grid, tile, W).
    rows = torch.stack(torch.split(sudoku_img, sizes, dim=-2), dim=-3)
    # Cut each strip into cells: -> (..., grid, grid, tile, tile).
    return torch.stack(torch.split(rows, sizes, dim=-1), dim=-3)

In [21]:
#sudoku_img = load_sudoku_images('data/query', 10000, device, normalize=True)
#torch.save(sudoku_img, 'data/pt-cache/query_sudoku_img_norm.pt')
#sudoku_img = torch.load('data/pt-cache/query_sudoku_img_norm.pt')

In [22]:
#digit_img = as_digit_images(sudoku_img)
#torch.save(digit_img, 'data/pt-cache/query_digit_img_norm.pt')
# Cached digit tiles produced by the commented-out as_digit_images() call above.
digit_cache_path = 'data/pt-cache/query_digit_img_norm.pt'
digit_img = torch.load(digit_cache_path)

In [23]:
# One 1x28x28 sample per board cell.
full_X = digit_img.view(-1, 1, 28, 28)
full_y = torch.load('data/pt-cache/query_y.pt')
# Balance classes 0..8 by truncating each to the size of the rarest of them.
# minlength=9 turns a missing class into a 0 count instead of silently dropping
# it, and [:9] ignores any labels outside the 9 classes actually kept below.
class_size = int(torch.bincount(full_y, minlength=9)[:9].min())
assert class_size > 0, 'every class 0..8 needs at least one sample'
bal_X = torch.cat([full_X[full_y == i][:class_size] for i in range(9)])
bal_y = torch.cat([full_y[full_y == i][:class_size] for i in range(9)])
# Fixed random_state so the train/test split is reproducible across runs.
train_X, test_X, train_y, test_y = train_test_split(
    bal_X, bal_y, test_size=10000, stratify=bal_y, random_state=0)

In [67]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, one-hot label) to a 1x28x28 image in [-1, 1].

    Args:
        z_dim: Length of the noise vector ``Z`` (default 64).
        n_classes: Length of the one-hot label vector ``Y`` (default 9).
    """

    def __init__(self, z_dim=64, n_classes=9):
        super().__init__()
        make_block = lambda i, o, k, s: nn.Sequential(
            nn.ConvTranspose2d(i, o, k, s),
            nn.BatchNorm2d(o),
            nn.ReLU()
        )
        # Spatial size grows 1 -> 3 -> 6 -> 13 -> 28, since out = (in - 1) * stride + kernel.
        self.gen = nn.Sequential(
            make_block(z_dim + n_classes, 256, 3, 2),
            make_block(256, 128, 4, 1),
            make_block(128, 64, 3, 2),
            nn.ConvTranspose2d(64, 1, 4, 2),
            nn.Tanh()
        )

    def forward(self, Z, Y):
        """Generate images from noise ``Z`` (batch, z_dim) and labels ``Y`` (batch, n_classes).

        Returns a (batch, 1, 28, 28) tensor with values in [-1, 1] (Tanh output).
        """
        # Add the 1x1 spatial dims explicitly rather than view(-1, C, 1, 1), so a
        # wrongly-sized Z or Y raises instead of being silently re-batched.
        return self.gen(torch.cat((Z, Y), dim=1)[:, :, None, None])

In [68]:
class Discriminator(nn.Module):
    """Conditional discriminator returning a real/fake logit for each image.

    The one-hot label is broadcast to constant feature maps and concatenated
    with the image along the channel axis.

    Args:
        n_classes: Length of the one-hot label vector ``Y`` (default 9).
        img_channels: Number of channels of the input images ``X`` (default 1).
    """

    def __init__(self, n_classes=9, img_channels=1):
        super().__init__()
        make_block = lambda i, o, k, s: nn.Sequential(
            nn.Conv2d(i, o, k, s),
            nn.BatchNorm2d(o),
            nn.LeakyReLU(0.2)
        )
        # For 28x28 inputs the spatial size shrinks 28 -> 13 -> 5 -> 1,
        # since out = (in - kernel) // stride + 1.
        self.disc = nn.Sequential(
            make_block(img_channels + n_classes, 64, 4, 2),
            make_block(64, 128, 4, 2),
            nn.Conv2d(128, 1, 4, 2)
        )

    def forward(self, X, Y):
        """Score images ``X`` (batch, C, H, W) given one-hot labels ``Y`` (batch, n_classes).

        Returns raw logits of shape (batch, 1, 1, 1) for 28x28 inputs.
        """
        # expand() broadcasts labels to the image's own H x W without an extra
        # copy; torch.cat performs the single materialisation.
        label_maps = Y[:, :, None, None].expand(-1, -1, *X.shape[-2:])
        return self.disc(torch.cat((X, label_maps), dim=1))

In [69]:
def viz_images(img):
    """Draw a batch of images on the current matplotlib axes as a grid, 10 per row.

    ``img`` is scaled to [-1, 1] (Tanh output or Normalize((0.5,), (0.5,)) data)
    and is mapped back to [0, 1] before display. Returns nothing.
    """
    unit_range = (img / 2 + 0.5).detach().cpu()
    grid = utils.make_grid(unit_range, nrow=10)
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for imshow

In [70]:
# Networks are built generator-first, then discriminator; each gets its own Adam.
gen = Generator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=2e-4)

disc = Discriminator().to(device)
disc_opt = optim.Adam(disc.parameters(), lr=2e-4)

# The whole training split is moved to `device` once, so batches arrive there already.
train_loader = DataLoader(
    TensorDataset(train_X.to(device), train_y.to(device)),
    batch_size=128,
    shuffle=True,
)
criterion = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits

In [71]:
def weights_init(m):
    """DCGAN-style initialisation, intended for ``module.apply(weights_init)``.

    * Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
    * BatchNorm2d scale (gamma) ~ N(1, 0.02), shift (beta) = 0.

    Conv biases and every other module type keep their PyTorch defaults.

    Args:
        m: Module to re-initialise in place. ``apply`` calls this once per submodule.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Gamma must be centred on 1. With mean 0, every BN layer would start by
        # scaling its normalised activations to ~0 and flipping the sign of
        # roughly half of the channels.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
# Re-initialise both networks in place; Module.apply visits every submodule
# and returns the same module object, so no rebinding is needed.
for _net in (gen, disc):
    _net.apply(weights_init)

In [None]:
n_epochs = 100
disc_losses = []
gen_losses = []
# Uncomment to silence every tqdm progress bar (e.g. for unattended runs):
# tqdm.__init__ = ft.partialmethod(tqdm.__init__, disable=True)
ctr = 0


def _show_progress(fake_X, real_X, epoch, gen_loss, disc_loss, disc_history, gen_history):
    """Show generated vs. real samples side by side, then the loss curves so far."""
    plt.figure(figsize=(12, 5))
    plt.subplot(121)
    viz_images(fake_X[:100])
    plt.subplot(122)
    viz_images(real_X[:100])
    plt.suptitle(f'epoch: {epoch}    gen loss: {gen_loss:.5f}    disc loss: {disc_loss:.5f}')
    plt.show()
    plt.close()

    fig, ax = plt.subplots()
    ax.plot(disc_history, label='discriminator')
    ax.plot(gen_history, label='generator')
    ax.legend()
    ax.set_xlabel('batches')
    ax.set_ylabel('loss')
    ax.set_title('Loss Curves')
    plt.show()
    plt.close(fig)


for epoch in tqdm(range(n_epochs), 'epochs'):
    # leave=False: otherwise one finished 'batches' bar per epoch piles up.
    for real_X, real_y in tqdm(train_loader, 'batches', leave=False):
        ctr += 1
        real_Y = F.one_hot(real_y.long(), num_classes=9).float()

        # --- Discriminator step: real -> 1, fake -> 0 ---
        real_yhat = disc(real_X, real_Y)
        real_loss = criterion(real_yhat, torch.ones_like(real_yhat))

        # NOTE(review): noise is uniform on [0, 1) (torch.rand), whereas DCGAN-style
        # setups usually draw torch.randn. Whichever is used, sampling from a
        # saved generator must use the same distribution — confirm intent.
        Z = torch.rand(real_X.shape[0], 64, device=device)
        fake_X = gen(Z, real_Y)
        # detach(): the discriminator update must not backprop into the generator.
        fake_yhat = disc(fake_X.detach(), real_Y)
        fake_loss = criterion(fake_yhat, torch.zeros_like(fake_yhat))

        disc_loss = (real_loss + fake_loss) / 2
        disc_opt.zero_grad()
        # No retain_graph: this graph is never reused (the generator step runs a
        # fresh discriminator forward pass), so retaining it only wastes memory.
        disc_loss.backward()
        nn.utils.clip_grad_norm_(disc.parameters(), 5)
        disc_opt.step()
        disc_losses.append(disc_loss.item())

        # --- Generator step: push the updated discriminator to label fakes as real ---
        fake_yhat = disc(fake_X, real_Y)
        gen_loss = criterion(fake_yhat, torch.ones_like(fake_yhat))
        gen_opt.zero_grad()
        gen_loss.backward()
        nn.utils.clip_grad_norm_(gen.parameters(), 5)
        gen_opt.step()
        gen_losses.append(gen_loss.item())

        if ctr % 500 == 499:
            _show_progress(fake_X, real_X, epoch, gen_loss, disc_loss, disc_losses, gen_losses)

    # Checkpoint both networks after every epoch.
    for net_name, net in (('gen', gen), ('disc', disc)):
        ckpt_path = f'data/pt-cache/nice_{net_name}.{epoch}.pt'
        torch.save(net.state_dict(), ckpt_path)
        print(f'{net_name}.state_dict() saved in: {ckpt_path}')

HBox(children=(HTML(value='epochs'), FloatProgress(value=0.0), HTML(value='')))

HBox(children=(HTML(value='batches'), FloatProgress(value=0.0, max=2752.0), HTML(value='')))