# Setup

In [33]:
import itertools
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split

In [34]:
# define utility functions

def apply_field(img, field):
    """Applies field to a fixed image to estimate a target image.

    Args:
        img: float tensor of shape (B, C, H, W).
        field: displacement tensor with 2 channels, expressed in the
            normalised units of `F.grid_sample` (the image extent spans
            [-1, 1]). Channel 0 is the x (column) offset, channel 1 the
            y (row) offset. The two trailing axes are indexed
            (column, row): `field[b, :, i, j]` displaces output pixel
            (row j, column i). For the square images used in this notebook
            the (B, 2, H, W) network output has the right shape directly.

    Returns:
        Tensor shaped like `img`, sampled bilinearly at identity + field.
    """
    n_batch, _, height, width = img.shape

    # create identity coordinate grid
    # With align_corners=True the values -1 and +1 address the centres of the
    # first and last pixel, so pixel k of N sits at 2*k/(N-1) - 1, which is
    # what linspace yields. Normalising by N instead makes a zero field
    # resample a slightly zoomed image rather than reproduce the input.
    xs = torch.linspace(-1.0, 1.0, width, dtype=img.dtype, device=img.device)
    ys = torch.linspace(-1.0, 1.0, height, dtype=img.dtype, device=img.device)
    identity = torch.stack(
        (xs.view(1, width).expand(height, width),
         ys.view(height, 1).expand(height, width)),
        dim=-1,
    )  # (H, W, 2); last axis ordered (x, y) as grid_sample expects

    # add field to grid
    # channels last, and spatial axes swapped so the (column, row)-indexed
    # offsets line up with the (row, column) layout of the grid
    offsets = field.permute(0, 3, 2, 1).to(dtype=img.dtype, device=img.device)
    coords_adj = identity.unsqueeze(0).expand(n_batch, -1, -1, -1) + offsets

    # adjust image
    img_adj = F.grid_sample(img, coords_adj, mode="bilinear", align_corners=True)

    return img_adj

def loss(img1, img2, field, lmbda):
    """Calculates loss associated with image reconstruction and associated field.

    Args:
        img1: estimated image tensor.
        img2: target image tensor, same shape as `img1`.
        field: channels-first deformation field of shape (B, 2, D1, D2), as
            returned by the model.
        lmbda: weight of the smoothness penalty.

    Returns:
        Scalar tensor: sum of squared image differences plus `lmbda` times
        the sum of squared forward differences of the field.
    """
    # The field is channels-first, so each displacement component is selected
    # on axis 1. Indexing the last axis (a channels-last habit) would pick two
    # single columns of the field and difference across the channel axis.
    comp_0 = field[:, 0]  # (B, D1, D2)
    comp_1 = field[:, 1]  # (B, D1, D2)

    # approximate field gradient: component 0 along its first spatial axis,
    # component 1 along its second
    diff_x = torch.diff(comp_0, dim=1)
    diff_y = torch.diff(comp_1, dim=2)

    # calculate loss
    loss_sim = torch.sum((img1 - img2) ** 2)
    loss_smooth = torch.sum(diff_x ** 2) + torch.sum(diff_y ** 2)
    loss_total = loss_sim + lmbda * loss_smooth
    return loss_total

def show_images(img, img_adj, img_goal):
    """Plots the fixed, estimated, and target images side by side.

    Each argument is a tensor that `plt.imshow` can draw once converted to
    numpy (the notebook passes single-channel 2-D slices).
    """
    panels = (
        (img, 'Original Image'),
        (img_adj, 'Estimated Image'),
        (img_goal, 'Goal Image'),
    )

    # convert pytorch to numpy before any drawing happens
    arrays = [(tensor.detach().numpy(), caption) for tensor, caption in panels]

    plt.figure(figsize=(12, 6))

    # one subplot per image, left to right
    for slot, (array, caption) in enumerate(arrays, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(array)
        plt.title(caption)

    plt.show()

In [36]:
# define model architecture

class DeformatioNet(nn.Module):
    """Convolutional encoder-decoder that maps an image pair to a deformation field.

    The two inputs are concatenated along the channel axis and must together
    provide 2 channels. The encoder halves the resolution twice with max
    pooling; the decoder restores it with two bicubic x2 upsamplings and ends
    in a Tanh, so every field value lies in [-1, 1]. For 28x28 inputs the
    output has shape (B, 2, 28, 28).

    Args:
        n_1: channel width of the outer (high-resolution) stages.
        n_2: channel width of the inner (low-resolution) stages.
    """

    def __init__(self, n_1, n_2):
        super().__init__()

        def conv_stage(c_in, c_out, pad):
            # conv -> batch norm -> leaky ReLU -> dropout
            return [
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=pad),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
                nn.Dropout(0.2),
            ]

        def deconv(c_in, c_out):
            # size-preserving transposed conv followed by batch norm
            return [
                nn.ConvTranspose2d(c_in, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
            ]

        def upsample():
            return nn.Upsample(scale_factor=2, mode="bicubic", align_corners=True)

        # Layers are created in the same order as they are stacked, so module
        # indices (and therefore state_dict keys) follow the list positions.
        encoder_layers = [
            *conv_stage(2, n_1, 2),
            nn.MaxPool2d(2),
            *conv_stage(n_1, n_2, 1),
            nn.MaxPool2d(2),
            *conv_stage(n_2, n_2, 1),
        ]
        decoder_layers = [
            *deconv(n_2, n_2),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            *deconv(n_2, n_1),
            nn.LeakyReLU(),
            upsample(),
            nn.Dropout(0.2),
            *deconv(n_1, 2),
            upsample(),
            nn.Dropout(0.2),
            nn.Tanh(),
        ]

        self.Encoder = nn.Sequential(*encoder_layers)
        self.Decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, y):
        """Returns the deformation field predicted for fixed image x and target y."""
        # stack the pair along the channel axis
        stacked = torch.cat((x, y), dim=1)

        # encode into latent space, then decode into the field
        # (the field is applied to the image outside the model)
        latent = self.Encoder(stacked)
        return self.Decoder(latent)

# Model Training

In [43]:
# set hyperparameters
# epochs: epochs trained per combination during tuning
# val_iter: validation batch pairs averaged after every training step
epochs, val_iter = 40, 3
# batch_size: samples per batch for the loaders built in train_model
batch_size = 179

In [106]:
transform = torchvision.transforms.ToTensor()


def _load_digit(digit, is_train):
    """Loads one MNIST split and keeps only the images labelled `digit`."""
    dataset = torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    keep = dataset.targets == digit
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]
    return dataset


def _full_batch_loaders(dataset):
    """Builds two independently shuffled loaders that each yield the whole dataset as one batch.

    The batch size is taken from the dataset itself: a hard-coded size
    silently truncates any digit class that has more images than that size.
    """
    first = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
    second = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
    return first, second


# load training data
train = _load_digit(7, True)

# split into training/validation sets
# (fixed generator seed: the same split is reproduced in every session, so a
# model resumed from a checkpoint is never validated on images it trained on)
n_train = int(0.8 * len(train))
train, val = random_split(train, [n_train, len(train) - n_train],
                          generator=torch.Generator().manual_seed(0))

# load test data
test = _load_digit(7, False)
test_loader, test_loader2 = _full_batch_loaders(test)

# load more test data
test4 = _load_digit(4, False)
test4_loader, test4_loader2 = _full_batch_loaders(test4)

test6 = _load_digit(6, False)
test6_loader, test6_loader2 = _full_batch_loaders(test6)

test1 = _load_digit(1, False)
test1_loader, test1_loader2 = _full_batch_loaders(test1)

In [45]:
# training loop

def train_model(model, losses,
                train, val,
                val_iter,
                lmbda, lmbda2, log_lr):
    """Trains model for a single epoch.

    Args:
        model: network called as `model(img, img_goal)` to produce a field.
        losses: list of `[iteration, log train loss, log val loss]` rows;
            one row per training batch is appended in place.
        train: training dataset.
        val: validation dataset.
        val_iter: number of validation batch pairs averaged after each
            training step.
        lmbda: smoothness weight passed to `loss`.
        lmbda2: weight decay passed to Adam.
        log_lr: log10 of the learning rate.

    Returns:
        `(model, losses)`. The model is left in eval mode.

    NOTE(review): the Adam optimizer is rebuilt on every call, so its moment
    estimates do not carry over from one epoch to the next.
    """

    base_batch = len(losses)

    # set up data loaders
    # (torch's global RNG is re-seeded from Python's `random` before each
    # loader is built, as in the original implementation)
    torch.manual_seed(random.randint(0, 100000))
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    torch.manual_seed(random.randint(0, 100000))
    train_loader2 = DataLoader(train, batch_size=batch_size, shuffle=True)

    torch.manual_seed(random.randint(0, 100000))
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
    torch.manual_seed(random.randint(0, 100000))
    val_loader2 = DataLoader(val, batch_size=batch_size, shuffle=True)

    # set optimizer
    optimizer = optim.Adam(model.parameters(), lr=10**log_lr, weight_decay=lmbda2)

    # run training loop
    # The two loaders are walked in lockstep: each step pairs a batch from one
    # shuffle with the batch at the same position of the other shuffle, so a
    # final partial batch has the same size on both sides.
    n_batches = len(train_loader)
    for batch_id, ((img, _), (img_goal, _)) in enumerate(zip(train_loader, train_loader2)):
        print(f"Batch {batch_id+1} of {n_batches}", end="\r")

        # training step
        model.train()
        optimizer.zero_grad()

        field = model(img, img_goal)

        img_adj = apply_field(img, field)

        loss_train = loss(img_adj, img_goal, field, lmbda)
        loss_train.backward()
        optimizer.step()

        # validation step
        # eval mode: dropout is disabled and BatchNorm uses its running
        # statistics instead of updating them from validation batches
        model.eval()
        loss_val = 0
        with torch.no_grad():
            for _ in range(val_iter):
                # first batch of a fresh shuffle = a random validation batch
                img_val, _ = next(iter(val_loader))
                img_val_goal, _ = next(iter(val_loader2))

                field_val = model(img_val, img_val_goal)
                img_val_adj = apply_field(img_val, field_val)
                loss_val += loss(img_val_adj, img_val_goal, field_val, lmbda)
            loss_val = loss_val/val_iter

        losses.append([base_batch + batch_id,
                       torch.log(loss_train).item(), torch.log(loss_val).item()])

    return model, losses

In [None]:
# hyperparameter tuning

lmbdas = [-1, -2, -3]    # log10 of the smoothness weight
lmbdas2 = [-1]           # log10 of the Adam weight decay
log_lrs = [-1, -2, -3]   # log10 of the learning rate
n_1s = [256]
n_2s = [256]

params_path = "hyperparameters.pth"
losses_path = "losses.pth"

# same visiting order as nesting the loops n_2 > n_1 > log_lr > lmbda > lmbda2
combos = list(itertools.product(n_2s, n_1s, log_lrs, lmbdas, lmbdas2))

# stays None when every combination is already on disk
losses_save = None

for i, (n_2, n_1, log_lr, lmbda, lmbda2) in enumerate(combos, start=1):

    # load hyperparameters tested so far (re-read each time because the
    # files grow as combinations finish)
    if os.path.exists(params_path):
        params_load = torch.load(params_path)
        losses_load = torch.load(losses_path)
    else:
        params_load = torch.empty(0, 5)
        losses_load = None

    # skip if already tested
    params_save = torch.Tensor([lmbda, lmbda2, log_lr, n_1, n_2]).unsqueeze(0)
    if torch.any(torch.all(params_load == params_save, dim=1)):
        continue

    # run training loop
    model = DeformatioNet(n_1, n_2)

    print(f"Combination {i} of {len(combos)}")
    print(f"Current parameters: lmbda={lmbda}; lmbda2={lmbda2}, log_lr={log_lr}; n_1={n_1}; n_2={n_2}")

    losses = []
    for epoch in range(epochs):
        print(f"Epoch {epoch+1} of {epochs}     ")
        model, losses = train_model(model, losses,
                                    train, val,
                                    val_iter,
                                    10**lmbda, 10**lmbda2, log_lr)

    # save tested hyperparameters
    # loss curves are stacked along a third axis, one slice per combination;
    # the first result needs no placeholder of a pre-guessed length
    losses_save = torch.Tensor(losses).unsqueeze(2)
    if losses_load is None:
        losses_out = losses_save
    else:
        losses_out = torch.cat((losses_load, losses_save), dim=2)
    params_out = torch.cat((params_load, params_save), dim=0)

    torch.save(losses_out, losses_path)
    torch.save(params_out, params_path)
    print("Test saved!                    \n")

# plot loss of the most recently trained combination; when nothing was trained
# in this run, fall back to the newest curve stored on disk
if losses_save is None and os.path.exists(losses_path):
    losses_save = torch.load(losses_path)[:, :, -1:]

if losses_save is not None:
    losses_out = losses_save
    plt.plot(losses_out[:,0], losses_out[:,1], label='Training')
    plt.plot(losses_out[:,0], losses_out[:,2], label='Validation')

    plt.xlabel('Iteration')
    plt.ylabel('Log-loss')
    plt.legend()

    plt.show()

In [None]:
# load in best hyperparameters

params = torch.load("hyperparameters.pth")
losses = torch.load("losses.pth")

# validation log-loss curves, one column per tested combination
val_curves = losses[:, 2, :]
# torch.min propagates NaN, so a run whose loss became NaN would be picked as
# the "minimum"; map NaN to +inf so it can never win
val_curves = torch.where(torch.isnan(val_curves),
                         torch.full_like(val_curves, float("inf")),
                         val_curves)

mins, _ = torch.min(val_curves, dim=0)
min_all, ind = torch.min(mins, dim=0)
ind = int(ind)

# plain Python numbers, so later `10**value` expressions give floats rather
# than 0-dim tensors
lmbda_best, lmbda2_best, log_lr_best, n_1_best, n_2_best = params[ind].tolist()
log_lr_best = int(log_lr_best)
n_1_best = int(n_1_best)
n_2_best = int(n_2_best)

print(f"Best model: lmbda={lmbda_best}, lmbda2={lmbda2_best}, log_lr={log_lr_best}, n_1={n_1_best}, n_2={n_2_best}")
# 1-based position so it reads naturally against the total
print(f"Combination {ind + 1} of {len(params)}")

In [None]:
# train tuned model

## bested model parameters
#lmbda_best = -3
#lmbda2_best = -1
#n_1_best = 256
#n_2_best = 256
#log_lr_best = -1
#epochs = 20

model = DeformatioNet(n_1_best, n_2_best)
test_epochs = 40

# batches per epoch = rows appended to `losses` by each train_model call
steps_per_epoch = len(DataLoader(train, batch_size=batch_size))

best_path = "best_state256.pth"
best_losses_path = "best_losses256.pth"
if os.path.exists(best_path):
    # resume from the last checkpoint
    model.load_state_dict(torch.load(best_path))
    losses = torch.load(best_losses_path)
    # the last stored iteration index tells how many epochs are already done
    min_epoch = int((losses[-1][0] + 1) / steps_per_epoch)
    losses = losses.tolist()
else:
    min_epoch = 0
    losses = []

for epoch in range(min_epoch, test_epochs):
    print(f"Epoch {epoch+1} of {test_epochs}     ")
    model, losses = train_model(model, losses,
                                train, val,
                                val_iter,
                                10**lmbda_best, 10**lmbda2_best, log_lr_best)
    torch.save(model.state_dict(), best_path)
    torch.save(torch.Tensor(losses), best_losses_path)

    # early stopping: stop once the mean validation log-loss of the epoch
    # just finished exceeds that of the epoch before it (both windows span
    # exactly one epoch)
    if epoch > 1:
        val_curve = torch.Tensor(losses)[:, -1]
        mean_val_loss = torch.mean(val_curve[-steps_per_epoch:])
        mean_val_loss_prev = torch.mean(val_curve[-2 * steps_per_epoch:-steps_per_epoch])
        if mean_val_loss > mean_val_loss_prev:
            break

torch.save(model.state_dict(), best_path)
torch.save(torch.Tensor(losses), best_losses_path)

losses_out = torch.Tensor(losses)
plt.plot(losses_out[:,0], losses_out[:,1], label='Training')
plt.plot(losses_out[:,0], losses_out[:,2], label='Validation')

plt.xlabel('Iteration')
plt.ylabel('Log-loss')
plt.legend()

plt.show()

# Results

In [None]:
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
train_loader2 = DataLoader(train, batch_size=batch_size, shuffle=True)

# set up demo images
img, _ = next(iter(test_loader))
img_goal, _ = next(iter(test_loader2))

img_train, _ = next(iter(train_loader))
img_goal_train, _ = next(iter(train_loader2))

# generate deformation field
# Inference only: eval mode disables dropout and stops BatchNorm from updating
# its running statistics on these batches; no_grad avoids building an autograd
# graph for the large test batch.
model.eval()
with torch.no_grad():
    field = model(img, img_goal)
    field_train = model(img_train, img_goal_train)

    img_adj = apply_field(img, field)
    img_adj_train = apply_field(img_train, field_train)


# display original/transformed images
# test dataset
field_out = field[0,:,:,:]
field_out = field_out.detach().numpy()

x, y = torch.meshgrid(torch.arange(28), torch.arange(28))
plt.quiver(x, y, field_out[0], field_out[1], scale=30)
plt.show()

show_images(img[0, 0, :, :], img_adj[0, 0, :, :], img_goal[0, 0, :, :])


# training dataset
field_out_train = field_train[0,:,:,:]
field_out_train = field_out_train.detach().numpy()

plt.quiver(x, y, field_out_train[0], field_out_train[1], scale=30)
plt.show()

show_images(img_train[0, 0, :, :], img_adj_train[0, 0, :, :], img_goal_train[0, 0, :, :])

In [None]:
# generalization: pairs of 4s (a digit the model was not trained on)
img4, _ = next(iter(test4_loader))
img4_goal, _ = next(iter(test4_loader2))

# generate deformation field
# inference only: no dropout, no BatchNorm statistic updates, no autograd graph
model.eval()
with torch.no_grad():
    field4 = model(img4, img4_goal)

    img4_adj = apply_field(img4, field4)

# display original/transformed images

# test dataset
field_out = field4[0,:,:,:]
field_out = field_out.detach().numpy()

x, y = torch.meshgrid(torch.arange(28), torch.arange(28))
plt.quiver(x, y, field_out[0], field_out[1], scale=30)
plt.show()

In [None]:
# generalization: pairs of 6s (a digit the model was not trained on)
img6, _ = next(iter(test6_loader))
img6_goal, _ = next(iter(test6_loader2))

# generate deformation field
# inference only: no dropout, no BatchNorm statistic updates, no autograd graph
model.eval()
with torch.no_grad():
    field6 = model(img6, img6_goal)

    img6_adj = apply_field(img6, field6)

# display original/transformed images

# test dataset
field_out = field6[0,:,:,:]
field_out = field_out.detach().numpy()

x, y = torch.meshgrid(torch.arange(28), torch.arange(28))
plt.quiver(x, y, field_out[0], field_out[1], scale=30)
plt.show()

In [None]:
# generalization: pairs of 1s (a digit the model was not trained on)
img1, _ = next(iter(test1_loader))
img1_goal, _ = next(iter(test1_loader2))

# generate deformation field
# The target must be the 1s batch drawn above (img1_goal); the 7s target from
# the earlier demo cell has a different batch size and is the wrong digit.
# inference only: no dropout, no BatchNorm statistic updates, no autograd graph
model.eval()
with torch.no_grad():
    field1 = model(img1, img1_goal)

    img1_adj = apply_field(img1, field1)

# display original/transformed images

# test dataset
field_out = field1[0,:,:,:]
field_out = field_out.detach().numpy()

x, y = torch.meshgrid(torch.arange(28), torch.arange(28))
plt.quiver(x, y, field_out[0], field_out[1], scale=30)
plt.show()

show_images(img1[0, 0, :, :], img1_adj[0, 0, :, :], img1_goal[0, 0, :, :])
show_images(img4[0, 0, :, :], img4_adj[0, 0, :, :], img4_goal[0, 0, :, :])
show_images(img6[0, 0, :, :], img6_adj[0, 0, :, :], img6_goal[0, 0, :, :])

In [None]:
# generate test error

def _mean_test_loss(loader, loader2, lmbda_weight):
    """Returns the per-image loss for one batch pair drawn from the two loaders.

    The summed loss is divided by the number of images actually evaluated
    (the batch size), not by the dataset length, so the mean stays correct
    even when a loader yields fewer images than the dataset holds.
    """
    img, _ = next(iter(loader))
    img_goal, _ = next(iter(loader2))
    with torch.no_grad():
        field = model(img, img_goal)
        img_adj = apply_field(img, field)
        return loss(img_adj, img_goal, field, lmbda_weight) / img.shape[0]


# inference only: disable dropout and BatchNorm statistic updates
model.eval()

# smoothness weight of the tuned model (the stored value is a log10 exponent)
lmbda_eval = 10**lmbda_best

# mean loss: 7s
loss7 = _mean_test_loss(test_loader, test_loader2, lmbda_eval)

# mean loss: 4s
loss4 = _mean_test_loss(test4_loader, test4_loader2, lmbda_eval)

# mean loss: 6s
loss6 = _mean_test_loss(test6_loader, test6_loader2, lmbda_eval)

# mean loss: 1s
loss1 = _mean_test_loss(test1_loader, test1_loader2, lmbda_eval)

bars = plt.bar([0,1,2,3], [loss7.item(), loss1.item(), loss4.item(), loss6.item()])
bars[0].set_color('orange')
plt.ylabel('Mean Loss')
plt.xlabel('Number')
plt.xticks([0, 1, 2, 3], ['7', '1', '4', '6'])
plt.show()