In [81]:
import math
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
import kornia

In [None]:
# Load the ImageNet-pretrained ResNet-18. The `weights=` enum replaces the
# deprecated `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag
# used to select.
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

In [111]:
# MNIST train/test splits, downloaded into `_data` if missing and converted
# to tensors with ToTensor().
train_dataset = torchvision.datasets.MNIST(
    root='_data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='_data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Only the training stream is shuffled; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=128)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=128)

In [None]:
# Fine-tune ResNet-18 on MNIST: replace the classification head with a 10-way
# linear layer and train all parameters with Adam.
resnet18.fc = nn.Linear(512, 10)
resnet18.cuda()

optimizer = torch.optim.Adam(resnet18.parameters(), lr=1e-3)

for epoch in range(5):
    resnet18.train()
    for i, (x, y) in enumerate(train_loader):
        # MNIST images have one channel; tile to the three channels the
        # network's first convolution expects.
        x = x.cuda().repeat(1, 3, 1, 1)
        y = y.cuda()

        y_pred = resnet18(x)
        loss = nn.functional.cross_entropy(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f'epoch {epoch}, iter {i}, loss {loss.item()}')

    resnet18.eval()
    correct = 0
    total = 0
    # Evaluation never calls backward(), so skip building autograd graphs.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda().repeat(1, 3, 1, 1)
            y = y.cuda()

            y_pred = resnet18(x).argmax(dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)

    print(f'epoch {epoch}, accuracy {correct / total}')

In [114]:
# https://github.com/jw9730/lps/src/symmetry/groups/SO.py
def samples_from_haar_distribution(d, bsize, device, dtype) -> torch.Tensor:
    """Random orthogonal matrices with determinant 1 drawn from the SO(d) Haar distribution.

    Adapted from scipy.stats.special_ortho_group, which implements the algorithm described in
    Mezzadri, How to generate random matrices from the classical compact groups (2006).

    Args:
        d: Matrix dimension.
        bsize: Number of matrices to draw.
        device: Device on which all intermediate and output tensors are created.
        dtype: Dtype of all intermediate and output tensors.

    Returns:
        Tensor of shape (bsize, d, d), one sampled matrix per batch entry.
    """
    # H represents a (dim, dim) matrix, while D represents the diagonal of
    # a (dim, dim) diagonal matrix. The algorithm that follows is
    # broadcasted on the leading `bsize` axis to vectorize along
    # samples.
    H = torch.empty(bsize, d, d, device=device, dtype=dtype)
    H[..., :, :] = torch.eye(d, device=device, dtype=dtype)
    # inf is only a placeholder: entries 0..d-2 are overwritten inside the
    # loop and the last entry is set right after it.
    D = torch.empty(bsize, d, device=device, dtype=dtype).fill_(float('inf'))
    for n in range(d-1):
        # x is a vector with length dim-n, xrow and xcol are views of it as
        # a row vector and column vector respectively. It's important they
        # are views and not copies because we are going to modify x
        # in-place.
        x = torch.randn(bsize, d-n, device=device, dtype=dtype)
        xrow = x[..., None, :]
        xcol = x[..., :, None]

        # This is the squared norm of x, without vectorization it would be
        # dot(x, x), to have proper broadcasting we use matmul and squeeze
        # out (convert to scalar) the resulting 1x1 matrix
        norm2 = torch.matmul(xrow, xcol).squeeze((-2, -1))

        # Keep the original first component: the next lines overwrite it
        # in-place but the renormalization below still needs the old value.
        x0 = x[..., 0].clone()
        # Sign of the first component, with 0 mapped to +1.
        D[..., n] = torch.where(x0 != 0, torch.sign(x0), 1)
        x[..., 0] += D[..., n] * torch.sqrt(norm2)

        # In renormalizing x we have to append an additional axis with
        # [..., None] to broadcast the scalar against the vector x
        x /= torch.sqrt((norm2 - x0**2 + x[..., 0]**2) / 2.)[..., None]

        # Householder transformation, without vectorization the RHS can be
        # written as outer(H @ x, x) (apart from the slicing)
        H[..., :, n:] -= torch.matmul(H[..., :, n:], xcol) * xrow

    # The last diagonal entry is derived from the product of the others
    # rather than sampled.
    D[..., -1] = (-1)**(d-1) * D[..., :-1].prod(dim=-1)

    # Without vectorization this could be written as H = diag(D) @ H,
    # left-multiplication by a diagonal matrix amounts to multiplying each
    # row of H by an element of the diagonal, so we add a dummy axis for
    # the column index
    H *= D[..., :, None]
    return H

In [125]:
def get_energy(x, y):
    """Classifier energy of a batch: cross-entropy of `resnet18(x)` against `y`.

    Uses the notebook-level `resnet18` model. No `reduction` argument is
    passed, so the library's default reduction applies.
    """
    return nn.functional.cross_entropy(resnet18(x), y)

def so2_action(g, x):
    """Rotate a batch of images about the image centre.

    Args:
        g: (B, 2, 2) batch of matrices. The translation below is built from
            the first row only, which matches "rotate about the centre" when
            `g` has the form [[a, b], [-b, a]] (as SO(2) elements do).
        x: Image batch whose last two dimensions are (height, width); the
            callers in this notebook pass (B, C, H, W).

    Returns:
        The result of `kornia.geometry.affine` applied to `x` with the
        (B, 2, 3) matrices [g | t], where t = c - g @ c for centre c.
    """
    alpha = g[:, 0, 0]
    beta = g[:, 0, 1]
    # Take the centre from the input instead of hard-coding MNIST's 28x28,
    # so non-28 and non-square images rotate about their own centre.
    height, width = x.shape[-2:]
    cx = width / 2
    cy = height / 2
    translation = torch.stack(
        [(1 - alpha) * cx - beta * cy, beta * cx + (1 - alpha) * cy],
        dim=1,
    )
    affine_matrices = torch.cat([g, translation.unsqueeze(-1)], dim=-1)
    return kornia.geometry.affine(x, affine_matrices)

def so2_lie_algebra_element(alpha):
    """Build the 2x2 skew-symmetric matrix [[0, -a], [a, 0]] for each coordinate a.

    Args:
        alpha: (B, 1) tensor; column 0 holds one scalar per batch entry.

    Returns:
        (B, 2, 2) tensor on the same device and with the same dtype as
        `alpha`. Gradients flow back to `alpha` through the two
        off-diagonal entries.
    """
    bsize = alpha.size(0)
    # new_zeros inherits alpha's dtype and device; a bare torch.zeros would
    # silently down-cast float64 coordinates to the default dtype.
    A = alpha.new_zeros(bsize, 2, 2)
    A[:, 0, 1] = -alpha[:, 0]
    A[:, 1, 0] = alpha[:, 0]
    return A

def so2_retraction(A):
    return torch.linalg.matrix_exp(A)

def so2_lie_algebra_gradient(alpha, x, y):
    """Gradient of the classifier energy with respect to the rotation coordinate.

    The coordinate is mapped alpha -> Lie-algebra matrix -> group element,
    the group element is applied to `x`, and `get_energy` is evaluated on the
    transformed batch against labels `y`.

    Args:
        alpha: Rotation coordinates; a detached copy is differentiated, so
            the caller's tensor and graph are left untouched.
        x: Image batch that gets transformed.
        y: Labels passed to `get_energy`.

    Returns:
        Tensor with the same shape as `alpha` holding d(energy)/d(alpha).
    """
    # Fresh leaf tensor so autograd differentiates w.r.t. this copy only.
    theta = alpha.detach().clone().requires_grad_(True)
    group_element = so2_retraction(so2_lie_algebra_element(theta))
    energy = get_energy(so2_action(group_element, x), y)
    (grad,) = torch.autograd.grad(energy, theta, only_inputs=True)
    return grad

def annealed_langevin_dynamics(alpha, x, y, L, T, eps, sigma_1, decay_rate, grad_scale):
    """Annealed Langevin dynamics on the rotation coordinate, using the classifier energy.

    Args:
        alpha: Initial coordinates; becomes the first entry of the result.
        x: Image batch handed to `so2_lie_algebra_gradient`.
        y: Labels handed to `so2_lie_algebra_gradient`.
        L: Number of noise levels.
        T: Number of update steps per noise level.
        eps: Step size used at the last noise level.
        sigma_1: Noise scale of the first level.
        decay_rate: Multiplicative factor between consecutive noise scales.
        grad_scale: Extra multiplier applied to the gradient term.

    Returns:
        List of length 1 + L * T: the initial `alpha` followed by every iterate.
    """
    sigma_list = [sigma_1 * decay_rate ** level for level in range(L)]
    # Step size of each level scales with (sigma_level / sigma_last)^2.
    step_sizes = [eps * (sigma / sigma_list[L - 1]) ** 2 for sigma in sigma_list]

    alpha_list = [alpha]
    for step in step_sizes:
        for _ in range(T):
            current = alpha_list[-1]
            grad = so2_lie_algebra_gradient(current, x, y)
            noise = torch.randn_like(current)
            # Descend the energy and inject Gaussian noise.
            alpha_list.append(current - (step / 2) * grad_scale * grad + math.sqrt(step) * noise)
    return alpha_list

In [126]:
resnet18.eval()

# Take the first test image and its label and tile them across a full batch,
# so every batch entry starts from the same digit. (To use the whole batch
# instead: x.cuda().repeat(1, 3, 1, 1) and y.cuda().)
x, y = next(iter(test_loader))
x = x[:1].cuda().repeat(x.shape[0], 3, 1, 1)
y = y[:1].cuda().repeat(x.shape[0])

# Apply an independent random rotation to each copy.
random_g = samples_from_haar_distribution(d=2, bsize=x.shape[0], device=x.device, dtype=x.dtype)
x0 = so2_action(random_g, x)

# following https://arxiv.org/pdf/1907.05600
L = 10
T = 100
sigma_1 = 1
decay_rate = 10 ** (math.log10(0.01) / 10)
eps = 2e-5

grad_scale = 100

# Random initial coordinate per batch entry, then run the sampler.
alpha = torch.randn(x0.shape[0], 1).cuda()
alpha_list = annealed_langevin_dynamics(
    alpha, x0, y,
    L=L, T=T, eps=eps, sigma_1=sigma_1, decay_rate=decay_rate, grad_scale=grad_scale,
)

In [None]:
# Turn the last sampler iterate into a rotation and apply it to the randomly
# rotated inputs.
final_alpha = alpha_list[-1]
final_A = so2_lie_algebra_element(final_alpha)
final_g = so2_retraction(final_A)
final_x = so2_action(final_g, x0)

def accuracy(x, y):
    """Fraction of samples in `x` whose `resnet18` argmax prediction equals `y`."""
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        y_pred = resnet18(x).argmax(dim=1)
    return (y_pred == y).float().mean().item()

# Raw logits for the three image sets, kept for interactive inspection.
with torch.no_grad():
    y_pred = resnet18(x)
    y0_pred = resnet18(x0)
    final_y_pred = resnet18(final_x)

print(f'accuracy(x) = {accuracy(x, y)}')
print(f'accuracy(x0) = {accuracy(x0, y)}')
print(f'accuracy(final_x) = {accuracy(final_x, y)}')

# visualize first 20 images
x_numpy = x.cpu().detach().numpy()
x0_numpy = x0.cpu().detach().numpy()
final_x_numpy = final_x.cpu().detach().numpy()

# Rows: original, randomly rotated, after Langevin correction.
fig, axs = plt.subplots(3, 20, figsize=(20, 6))
for i in range(20):
    # Index with +i: the previous -i showed sample 0 followed by the LAST 19
    # samples of the batch, not the first 20.
    axs[0, i].imshow(x_numpy[i, 0], cmap='gray')
    axs[1, i].imshow(x0_numpy[i, 0], cmap='gray')
    axs[2, i].imshow(final_x_numpy[i, 0], cmap='gray')

In [None]:
class SmallUNet(nn.Module):
    """Two-level convolutional U-Net that maps an image batch to a same-sized output.

    One stride-2 downsampling stage, one transposed-convolution upsampling
    stage, and a single skip connection at full resolution.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        mid_ch: Channel width at full resolution (doubled at half resolution).
    """

    def __init__(self, in_ch=3, out_ch=3, mid_ch=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 5, padding=2)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, 5, padding=2)
        self.down  = nn.Conv2d(mid_ch, mid_ch*2, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(mid_ch*2, mid_ch*2, 3, padding=1)
        self.up    = nn.ConvTranspose2d(mid_ch*2, mid_ch, 2, stride=2)
        # conv4 consumes the concatenation of the upsampled and skip features.
        self.conv4 = nn.Conv2d(mid_ch*2, mid_ch, 5, padding=2)
        self.conv5 = nn.Conv2d(mid_ch, out_ch, 5, padding=2)

    def forward(self, x):
        """
        x: (B,C,H,W)
        returns: (B,out_ch,H,W)
        """
        d1 = torch.relu(self.conv1(x))    # (B, mid_ch, H, W)
        d1 = torch.relu(self.conv2(d1))   # (B, mid_ch, H, W)
        d2 = torch.relu(self.down(d1))    # (B, mid_ch*2, ceil(H/2), ceil(W/2))
        d2 = torch.relu(self.conv3(d2))   # (B, mid_ch*2, ceil(H/2), ceil(W/2))

        u1 = self.up(d2)                  # (B, mid_ch, 2*ceil(H/2), 2*ceil(W/2))
        # For odd H or W the upsampled map is one pixel larger than the skip
        # features; crop it so the concatenation below is well-defined.
        # This is a no-op for even sizes.
        u1 = u1[..., :d1.shape[-2], :d1.shape[-1]]
        cat1 = torch.cat([u1, d1], dim=1) # (B, mid_ch+mid_ch, H, W)
        u1 = torch.relu(self.conv4(cat1)) # (B, mid_ch, H, W)
        out = self.conv5(u1)              # (B, out_ch, H, W)
        return out

# Instantiate the U-Net with its default channel sizes and move it to the GPU.
unet = SmallUNet().cuda()
# Switch to eval mode; as the last expression of the cell this also echoes
# the module returned by eval().
unet.eval()


In [155]:
def get_energy_unet(gx):
    """Scalar surrogate energy: sum of `gx` times the U-Net output on a detached `gx`.

    The network input is detached, so gradients with respect to `gx` flow
    only through the explicit `gx` factor, not through the network's input.
    Uses the notebook-level `unet` model.
    """
    field = unet(gx.detach())
    return (gx * field).sum()

def so2_lie_algebra_gradient_unet(alpha, x, create_graph=False):
    """Gradient of the U-Net surrogate energy with respect to the rotation coordinate.

    Args:
        alpha: Rotation coordinates; a detached copy is differentiated, so
            the caller's tensor is left untouched.
        x: Image batch that gets transformed.
        create_graph: If True, the returned gradient stays attached to the
            autograd graph so a loss on it can be back-propagated (used when
            training the U-Net). Defaults to False so that sampling and
            evaluation callers may omit it.

    Returns:
        Tensor with the same shape as `alpha` holding d(energy)/d(alpha).
    """
    alpha = alpha.clone().detach()
    alpha.requires_grad_(True)
    A = so2_lie_algebra_element(alpha)
    g = so2_retraction(A)
    gx = so2_action(g, x)
    energy = get_energy_unet(gx)
    grad, = torch.autograd.grad(energy, alpha, create_graph=create_graph)
    return grad

def annealed_langevin_dynamics_unet(alpha, x, L, T, eps, sigma_1, decay_rate, grad_scale):
    """Annealed Langevin dynamics on the rotation coordinate, using the U-Net energy.

    Args:
        alpha: Initial coordinates; becomes the first entry of the result.
        x: Image batch handed to `so2_lie_algebra_gradient_unet`.
        L: Number of noise levels.
        T: Number of update steps per noise level.
        eps: Step size used at the last noise level.
        sigma_1: Noise scale of the first level.
        decay_rate: Multiplicative factor between consecutive noise scales.
        grad_scale: Extra multiplier applied to the gradient term.

    Returns:
        List of length 1 + L * T: the initial `alpha` followed by every iterate.
    """
    sigma_list = [sigma_1 * decay_rate ** level for level in range(L)]
    # Step size of each level scales with (sigma_level / sigma_last)^2.
    step_sizes = [eps * (sigma / sigma_list[L - 1]) ** 2 for sigma in sigma_list]

    alpha_list = [alpha]
    for step in step_sizes:
        for _ in range(T):
            current = alpha_list[-1]
            # Sampling does not differentiate through the gradient itself.
            grad = so2_lie_algebra_gradient_unet(current, x, create_graph=False)
            noise = torch.randn_like(current)
            alpha_list.append(current - (step / 2) * grad_scale * grad + math.sqrt(step) * noise)
    return alpha_list

In [156]:
resnet18.eval()
unet.eval()

# Take the first test image and its label and tile them across a full batch,
# so every batch entry starts from the same digit.
x, y = next(iter(test_loader))
x = x[:1].cuda().repeat(x.shape[0], 3, 1, 1)
y = y[:1].cuda().repeat(x.shape[0])

# Apply an independent random rotation to each copy.
random_g = samples_from_haar_distribution(d=2, bsize=x.shape[0], device=x.device, dtype=x.dtype)
x0 = so2_action(random_g, x)

# following https://arxiv.org/pdf/1907.05600
L = 10
T = 100
sigma_1 = 1
decay_rate = 10 ** (math.log10(0.01) / 10)
eps = 2e-5

grad_scale = 100

# Random initial coordinate per batch entry, then run the U-Net-driven sampler.
alpha = torch.randn(x0.shape[0], 1).cuda()
alpha_list = annealed_langevin_dynamics_unet(
    alpha, x0,
    L=L, T=T, eps=eps, sigma_1=sigma_1, decay_rate=decay_rate, grad_scale=grad_scale,
)

In [None]:
# Turn the last sampler iterate into a rotation and apply it to the randomly
# rotated inputs.
final_alpha = alpha_list[-1]
final_A = so2_lie_algebra_element(final_alpha)
final_g = so2_retraction(final_A)
final_x = so2_action(final_g, x0)

def accuracy(x, y):
    """Fraction of samples in `x` whose `resnet18` argmax prediction equals `y`."""
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        y_pred = resnet18(x).argmax(dim=1)
    return (y_pred == y).float().mean().item()

# Raw logits for the three image sets, kept for interactive inspection.
with torch.no_grad():
    y_pred = resnet18(x)
    y0_pred = resnet18(x0)
    final_y_pred = resnet18(final_x)

print(f'accuracy(x) = {accuracy(x, y)}')
print(f'accuracy(x0) = {accuracy(x0, y)}')
print(f'accuracy(final_x) = {accuracy(final_x, y)}')

# visualize first 20 images
x_numpy = x.cpu().detach().numpy()
x0_numpy = x0.cpu().detach().numpy()
final_x_numpy = final_x.cpu().detach().numpy()

# Rows: original, randomly rotated, after Langevin correction.
fig, axs = plt.subplots(3, 20, figsize=(20, 6))
for i in range(20):
    # Index with +i: the previous -i showed sample 0 followed by the LAST 19
    # samples of the batch, not the first 20.
    axs[0, i].imshow(x_numpy[i, 0], cmap='gray')
    axs[1, i].imshow(x0_numpy[i, 0], cmap='gray')
    axs[2, i].imshow(final_x_numpy[i, 0], cmap='gray')

In [None]:
# Distil the classifier's rotation gradient into the U-Net: the U-Net is
# trained so that the gradient of its surrogate energy w.r.t. the rotation
# coordinate matches the gradient of the classifier's cross-entropy.
resnet18.eval()
# The classifier is a fixed teacher here; freeze its weights.
for p in resnet18.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.Adam(unet.parameters(), lr=1e-3)

for epoch in range(10):
    unet.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.cuda().repeat(1, 3, 1, 1).detach()
        y = y.cuda()

        # Randomly rotated inputs and a random coordinate at which both
        # gradients are evaluated.
        g = samples_from_haar_distribution(2, x.size(0), x.device, x.dtype)
        gx = so2_action(g, x).detach()
        alpha = torch.randn(x.size(0), 1).cuda()
        # The target is a constant; only the prediction keeps its graph so the
        # loss can back-propagate into the U-Net parameters.
        grad_target = so2_lie_algebra_gradient(alpha, gx, y).detach()
        grad_pred = so2_lie_algebra_gradient_unet(alpha, gx, create_graph=True)
        loss = nn.functional.mse_loss(grad_pred, grad_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f'epoch {epoch}, iter {i}, loss {loss.item()}')

    unet.eval()
    test_loss = 0
    total = 0
    # No torch.no_grad() here: both gradient helpers call autograd.grad.
    for x, y in test_loader:
        x = x.cuda().repeat(1, 3, 1, 1)
        y = y.cuda()

        g = samples_from_haar_distribution(2, x.size(0), x.device, x.dtype)
        gx = so2_action(g, x)
        alpha = torch.randn(x.size(0), 1).cuda()
        grad_target = so2_lie_algebra_gradient(alpha, gx, y)
        # create_graph must be passed explicitly (the helper requires it);
        # evaluation needs no second-order graph.
        grad_pred = so2_lie_algebra_gradient_unet(alpha, gx, create_graph=False)

        # mse_loss returns the batch mean; weight it by the batch size so that
        # dividing by `total` gives a per-sample average.
        test_loss += nn.functional.mse_loss(grad_pred, grad_target).item() * y.size(0)
        total += y.size(0)

    print(f'epoch {epoch}, test loss {test_loss / total}')