In [1]:
import torch
import torch.nn as nn
import numpy as np
from BigGAN import Generator
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
from torch.autograd import Variable
import torch.optim as optim
from torchvision.transforms import ToTensor, Compose, Resize
from latentaug.truncated import TruncatedResNet18
from latentaug.encoder import ApproximateEncoder

In [2]:
# CIFAR-10 training split read from the current directory; ToTensor converts each image
# to a tensor. Downloading is intentionally off: pass download=True on a fresh machine.
dataset = CIFAR10(root='.', train=True, transform=Compose([ToTensor()]))

In [3]:
def img(out):
    """Reorder a single image tensor from (C, H, W) to (H, W, C) for ``plt.imshow``.

    The tensor is detached from the autograd graph and moved to the CPU first, so
    tensors that require grad or live on a GPU (e.g. ``recon[0]``) can be passed directly.
    """
    chw = out.detach().cpu()
    return chw.permute(1, 2, 0)

In [4]:
class InvertedGenerator(nn.Module):
    """Wrap a generator ``G`` so that its latent input can be optimized directly.

    The only tensor owned by this module is the latent code ``self.delta[0]``.
    ``forward`` pushes it through ``G.linear``, every block of ``G.blocks`` and
    ``G.output_layer``, bypassing ``G``'s own ``forward``.

    Args:
        G: generator exposing ``linear``, ``bottom_width``, ``blocks`` (an iterable
            of block lists; each block is called as ``block(h, y)``) and
            ``output_layer``.
        z_init: optional initial latent code. When given it is used as-is and
            ``device``/``batch_size`` are not used to build it. The trainable
            parameter shares storage with it, so optimizer steps update it in place.
        device: device on which a random latent is created when ``z_init`` is None.
        batch_size: number of random latent codes (each of size 128) to draw when
            ``z_init`` is None.
    """

    def __init__(self, G, z_init=None, device='cuda:1', batch_size=1):
        super().__init__()
        self.G = G
        self.batch_size = batch_size
        if z_init is None:
            z_init = torch.randn(batch_size, 128, dtype=torch.float32).to(device)
        self.z_init = z_init
        # A ParameterList (rather than a plain Python list) registers the latent with
        # this module, so `.to()`, `.parameters()` and `zero_grad()` all reach it.
        # Previously `zero_grad()` skipped it and its gradient accumulated across steps.
        # Extra per-layer offsets can be appended here as further nn.Parameters.
        self.delta = nn.ParameterList([nn.Parameter(self.z_init)])

    def _compute_penalty(self):
        """Return the squared L2 norm of each latent code in the batch, shape ``(batch,)``.

        Optional regularizer on the latent; it can be added to the training loss.
        """
        return torch.norm(self.delta[0].flatten(1), p=2, dim=1) ** 2

    def forward(self, y):
        """Render images from the current latent code.

        Args:
            y: conditioning input passed unchanged to every generator block.

        Returns:
            ``tanh`` of ``G.output_layer``'s output, so values lie in [-1, 1].
        """
        z = self.delta[0]
        h = self.G.linear(z)
        # Fold the linear output into a square spatial map of side `bottom_width`.
        h = h.view(h.size(0), -1, self.G.bottom_width, self.G.bottom_width)
        for blocklist in self.G.blocks:
            for block in blocklist:
                h = block(h, y)
        return torch.tanh(self.G.output_layer(h))
        

In [5]:
# Load the pretrained 10-class generator onto `device`, then freeze all of its
# parameters so its weights stay fixed during the latent optimization.
device = 'cuda:1'
trained_path = 'G_cur.pth'
G = Generator(n_classes=10, resolution=32, G_shared=False).to(device)
G.load_state_dict(torch.load(trained_path, map_location=device), strict=True)
G.requires_grad_(False)

Param count for Gs initialized parameters: 4303875


In [6]:
# Build the approximate encoder on `device` and run its project-specific `train`
# routine against the frozen generator G (presumably fitting E to invert G — confirm
# in latentaug.encoder).
# NOTE(review): if ApproximateEncoder subclasses nn.Module, this `train(G, ...)`
# replaces Module.train(mode); confirm train/eval switching still behaves as expected.
E = ApproximateEncoder()
E = E.to(device)
E.train(G, device=device)

In [7]:
# Inversion target: CIFAR-10 training image #7 as a batch of one on `device`,
# together with its class label as a (1,) long tensor for conditioning.
inpt, label = dataset[7]
inpt = inpt.to(device).view(1, 3, 32, 32)
y = torch.tensor([label], dtype=torch.long, device=device)

# Frozen ResNet feature extractor used for the feature-space terms of the loss.
rn = TruncatedResNet18(device)
for p in rn.parameters():
    p.requires_grad_(False)

In [8]:
# Start the latent from the encoder's estimate. Pass a clone so optimizing I
# does not overwrite z_init, which the plotting cell reuses for "Initialization".
z_init = E(inpt).detach()
I = InvertedGenerator(G, z_init=z_init.clone(), device=device).to(device)
I_optim = optim.Adam(I.delta, lr=0.01)
mse_loss = nn.MSELoss()
losses, vis_losses = [], []

In [None]:
# Loss weights for the pixel term and the two ResNet feature terms (n=0 and n=5).
Wp = [2, 1, 1]
# The target and its features are the same at every step, so compute them once.
# NOTE(review): assumes rn is deterministic for a fixed input (no active dropout) — confirm.
with torch.no_grad():
    Re = [inpt, rn(inpt, n=0), rn(inpt, n=5)]

for step in range(20000):
    # Zero gradients through the optimizer: it holds exactly the tensors being
    # optimized, whereas I.zero_grad() only reaches registered module parameters.
    I_optim.zero_grad()
    recon = (I(y) + 1) / 2  # map the generator's tanh output from [-1, 1] to [0, 1]
    Gf = [recon, rn(recon, n=0), rn(recon, n=5)]
    # Optionally add I._compute_penalty() to regularize the latent norm.
    loss = sum(w * mse_loss(gen, ref) for w, gen, ref in zip(Wp, Gf, Re))
    loss.backward()
    I_optim.step()
    losses.append(loss.item())
    vis_losses.append(mse_loss(recon, inpt).item())

In [None]:
# Recompute the reconstruction from the optimized latent. The `recon` left over
# from the training loop was produced before the final optimizer step.
with torch.no_grad():
    init_image = (G(z_init, y)[0] + 1) / 2
    final_image = (I(y)[0] + 1) / 2

fig = plt.figure(figsize=(15, 10))

ax_loss = fig.add_subplot(2, 3, (1, 3))
ax_loss.set_title('Inverter Losses')
ax_loss.plot(losses, label='Training Loss')
ax_loss.plot(vis_losses, label='Reconstruction Loss')
# Total training loss minus the unweighted pixel MSE.
ax_loss.plot(np.array(losses) - np.array(vis_losses), label='Residual Loss')
ax_loss.set_xlabel('Iterations')  # one loss value is recorded per optimizer step
ax_loss.set_ylabel('Loss')
ax_loss.legend()

panels = [
    (4, 'Initialization', init_image),
    (5, 'Optimization', final_image),
    (6, 'Target', inpt[0]),
]
for position, title, image in panels:
    ax = fig.add_subplot(2, 3, position)
    ax.set_title(title)
    ax.imshow(img(image))
    ax.axis('off')

fig.savefig('reconstruction')
plt.show()