In [1]:
import torch
import torchvision as tv
import os
from transformer_flow import Model
import utils
import pathlib

os.environ["CUDA_VISIBLE_DEVICES"] = "4" # set GPU

utils.set_random_seed(100)
notebook_output_path = pathlib.Path('runs/notebook')

In [None]:
dataset = 'mnist'
num_classes = 10
img_size = 28
channel_size = 1

# we use a small model for fast demonstration, increase the model size for better results
patch_size = 4
channels = 128
blocks = 4
layers_per_block = 4
# try different noise levels to see its effect
noise_std = 0.1

batch_size = 256
lr = 5e-4
# increase epochs for better results
epochs = 100
sample_freq = 5

if torch.cuda.is_available():
    device = 'cuda' 
elif torch.backends.mps.is_available():
    device = 'mps' # if on mac
else:
    device = 'cpu' # if mps not available
print(f'using device {device}')

fixed_noise = torch.randn(num_classes * 10, (img_size // patch_size)**2, channel_size * patch_size ** 2, device=device)
# fixed_noise = torch.randn(num_classes * 10, (img_size // patch_size)**2, channels, device=device)
fixed_y = torch.arange(num_classes, device=device).view(-1, 1).repeat(1, 10).flatten()

transform = tv.transforms.Compose([
    tv.transforms.Resize((img_size, img_size)),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5,), (0.5,))
])
data = tv.datasets.MNIST('.', transform=transform, train=True, download=True)
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)

model = Model(in_channels=channel_size, img_size=img_size, patch_size=patch_size, 
              channels=channels, num_blocks=blocks, layers_per_block=layers_per_block,
              num_classes=num_classes).to(device)

optimizer = torch.optim.AdamW(model.parameters(), betas=(0.9, 0.95), lr=lr, weight_decay=1e-4)
lr_schedule = utils.CosineLRSchedule(optimizer, len(data_loader), epochs * len(data_loader), 1e-6, lr)

model_name = f'{patch_size}_{channels}_{blocks}_{layers_per_block}_{noise_std:.2f}_linears_with_residual'
sample_dir = notebook_output_path / f'{dataset}_samples_{model_name}'
ckpt_file = notebook_output_path / f'{dataset}_model_{model_name}.pth'
sample_dir.mkdir(exist_ok=True, parents=True)

for epoch in range(epochs):
    losses = 0
    for x, y in data_loader:
        x = x.to(device)
        eps = noise_std * torch.randn_like(x)
        x = x + eps
        y = y.to(device)
        optimizer.zero_grad()
        z, outputs, logdets = model(x, y)
        loss = model.get_loss(z, logdets)
        loss.backward()
        optimizer.step()
        lr_schedule.step()
        losses += loss.item()

    print(f"epoch {epoch} lr {optimizer.param_groups[0]['lr']:.6f} loss {losses / len(data_loader):.4f}")
    print('layer norms', ' '.join([f'{z.pow(2).mean():.4f}' for z in outputs]))
    print(f'logdet: {logdets.mean():.4f}, prior p: {0.5 * z.pow(2).mean():.4f}')
    if (epoch + 1) % sample_freq == 0:
        with torch.no_grad():
            samples = model.reverse(fixed_noise, fixed_y)
        tv.utils.save_image(samples, sample_dir / f'samples_{epoch:03d}.png', normalize=True, nrow=10)
        latents = model.unpatchify(z[:100])
        tv.utils.save_image(latents, sample_dir / f'latent_{epoch:03d}.png', normalize=True, nrow=10)
        print(f'sampling complete. Sample mean: {samples.mean():.4f}, std: {samples.std():.4f}')
        print(f'latent mean: {latents.mean():.4f}, std: {latents.std():.4f}')
    print('\n')
torch.save(model.state_dict(), ckpt_file)

using device cuda
Unitary logdet: 0.00
Unitary logdet: -0.00
Unitary logdet: -0.00
Unitary logdet: 0.00
Number of parameters: 3.24M
epoch 0 lr 0.000500 loss -0.8450
layer norms 1.2346 1.2148 0.9755 0.9531
logdet: 1.7785, prior p: 0.4766
epoch 1 lr 0.000500 loss -1.3738
layer norms 1.2000 0.8658 0.6879 0.9517
logdet: 1.8950, prior p: 0.4759
epoch 2 lr 0.000499 loss -1.4457
layer norms 0.9725 0.8269 0.6558 0.9906
logdet: 1.9576, prior p: 0.4953
epoch 3 lr 0.000499 loss -1.4771
layer norms 0.9077 0.7974 0.6192 1.0341
logdet: 2.0135, prior p: 0.5171
epoch 4 lr 0.000498 loss -1.4974
layer norms 0.8630 0.7712 0.5666 1.0048
logdet: 2.0075, prior p: 0.5024
sampling complete
epoch 5 lr 0.000497 loss -1.5136
layer norms 0.8490 0.7687 0.5534 1.0387
logdet: 2.0406, prior p: 0.5194
epoch 6 lr 0.000495 loss -1.5280
layer norms 0.8320 0.7304 0.4761 0.9781
logdet: 2.0216, prior p: 0.4890
epoch 7 lr 0.000494 loss -1.5400
layer norms 0.8306 0.7381 0.4755 0.9805
logdet: 2.0233, prior p: 0.4902
epoch 8 lr

In [3]:
# now we can also evaluate the model by turning it into a classifier with Bayes rule, p(y|x) = p(y)p(x|y)/p(x)
data = tv.datasets.MNIST('.', transform=transform, train=False, download=False)
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False)
num_correct = 0
num_examples = 0
for x, y in data_loader:
    x = x.to(device)
    y = y.to(device)
    eps = noise_std * torch.randn_like(x)
    x = x.repeat(num_classes, 1, 1, 1)
    y_ = torch.arange(num_classes, device=device).view(-1, 1).repeat(1, y.size(0)).flatten()
    with torch.no_grad():
        z, outputs, logdets = model(x, y_)
        losses = 0.5 * z.pow(2).mean(dim=[1, 2]) - logdets # keep the batch dimension
        pred = losses.reshape(num_classes, y.size(0)).argmin(dim=0)
    num_correct += (pred == y).sum()
    num_examples += y.size(0)
print(f'Accuracy %{100 * num_correct / num_examples:.2f}')

Accuracy %98.56


In [5]:
import torch
import os
from transformer_flow import Model
import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "4" # set GPU
utils.set_random_seed(100)

num_classes = 10
img_size = 28
channel_size = 1

# we use a small model for fast demonstration, increase the model size for better results
patch_size = 4
channels = 128
blocks = 4
layers_per_block = 4
# try different noise levels to see its effect
noise_std = 0.1

if torch.cuda.is_available():
    device = 'cuda' 
elif torch.backends.mps.is_available():
    device = 'mps' # if on mac
else:
    device = 'cpu' # if mps not available
print(f'using device {device}')

fixed_noise = torch.randn(num_classes * 10, (img_size // patch_size)**2, channel_size * patch_size ** 2, device=device)
# fixed_noise = torch.randn(num_classes * 10, (img_size // patch_size)**2, channels, device=device)
fixed_y = torch.arange(num_classes, device=device).view(-1, 1).repeat(1, 10).flatten()

# load the model
model = Model(in_channels=channel_size, img_size=img_size, patch_size=patch_size, 
              channels=channels, num_blocks=blocks, layers_per_block=layers_per_block,
              num_classes=num_classes).to(device)
model.load_state_dict(torch.load("runs/notebook/mnist_model_4_128_4_4_0.10_linears_with_residual.pth"))
model.eval()

# code for printing the unitaries
unitaries = model.unitaries
for i, unitary in enumerate(unitaries):
    print(f'unitary {i}:')
    W = unitary.weight
    # calculate the eigenvalues
    eigenvalues, eigenvectors = torch.linalg.eig(W)
    # print the eigenvalues
    print(f'eigenvalues: {eigenvalues}\n')
    # normalize the eigenvectors
    eigenvectors = eigenvectors / torch.norm(eigenvectors, dim=1, keepdim=True)
    # print the max value of the eigenvectors
    max_eigenvector = torch.abs(eigenvectors).max(dim=1).values
    assert max_eigenvector.ndim == 1
    print(f'max eigenvector: {max_eigenvector}\n')


# with torch.no_grad():
#     samples = model.reverse(fixed_noise, fixed_y)
#     # print the mean and std of the samples
#     mean = samples.mean(dim=[0, 2, 3])
#     std = samples.std(dim=[0, 2, 3])
#     print(f'mean: {mean}, std: {std}')

using device cuda
Unitary logdet: 0.00
Unitary logdet: -0.00
Unitary logdet: -0.00
Unitary logdet: 0.00
Number of parameters: 3.24M
unitary 0:
eigenvalues: tensor([1.1871+0.0000j, 1.0933+0.0000j, 1.0603+0.0000j, 1.0039+0.0000j,
        1.0081+0.0000j, 1.0333+0.0000j, 1.0243+0.0082j, 1.0243-0.0082j,
        1.0309+0.0000j, 1.0295+0.0000j, 1.0274+0.0000j, 1.0263+0.0005j,
        1.0263-0.0005j, 1.0099+0.0001j, 1.0099-0.0001j, 1.0216+0.0044j,
        1.0216-0.0044j, 1.0113+0.0000j, 1.0122+0.0008j, 1.0122-0.0008j,
        1.0248+0.0000j, 1.0244+0.0000j, 1.0133+0.0000j, 1.0137+0.0000j,
        1.0164+0.0024j, 1.0164-0.0024j, 1.0146+0.0000j, 1.0154+0.0015j,
        1.0154-0.0015j, 1.0212+0.0016j, 1.0212-0.0016j, 1.0225+0.0006j,
        1.0225-0.0006j, 1.0218+0.0009j, 1.0218-0.0009j, 1.0218+0.0000j,
        1.0178+0.0021j, 1.0178-0.0021j, 1.0157+0.0007j, 1.0157-0.0007j,
        1.0157+0.0000j, 1.0198+0.0011j, 1.0198-0.0011j, 1.0170+0.0000j,
        1.0179+0.0006j, 1.0179-0.0006j, 1.0188+0.000

  model.load_state_dict(torch.load("runs/notebook/mnist_model_4_128_4_4_0.10_linears_with_residual.pth"))
