In [20]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST
import torchvision.transforms as T
import matplotlib.pyplot as plt
from IPython import display

In [21]:
# Notebook-wide matplotlib defaults: large square figures, grayscale images.
plt.rcParams.update({
    'figure.figsize': (11, 11),
    'image.cmap': 'gray',
})

In [22]:
# Dataset: MNIST images converted to tensors and normalized with (x - mean) / std.
mean, std = 0.5, 0.5
external_drive = "D:/MNIST"
dataset = MNIST(
    external_drive,
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize((mean,), (std,)),
    ]),
    download=True,
)

In [23]:
from copy import deepcopy

class Scaled_Act(nn.Module):
    """Wrap an activation module and multiply its output by a constant gain.

    Args:
        act: activation module to wrap (e.g. ``nn.ReLU()``).
        scale: explicit gain. When ``None`` the gain is taken from
            ``torch.nn.init.calculate_gain`` for the wrapped activation, which
            raises ``ValueError`` for activations it does not know.
    """
    # nn.Module class name -> nonlinearity name understood by calculate_gain
    to_str = {'Sigmoid' : 'sigmoid', 'ReLU': 'relu', 'Tanh' : 'tanh', 'LeakyReLU': 'leaky_relu'}

    def __init__(self, act, scale = None):
        super().__init__()
        self.act = act
        if scale is None:
            # Compare against None (not truthiness) so an explicit scale of 0 is honoured.
            act_name = Scaled_Act.to_str.get(act._get_name(), act._get_name())
            # Only LeakyReLU-like activations carry a slope; others yield None.
            param = getattr(act, 'negative_slope', None)
            scale = torch.nn.init.calculate_gain(act_name, param)
        self.scale = scale

    def forward(self, input):
        """Return ``scale * act(input)``."""
        return self.scale*self.act(input)

class Equal_LR:
    """Re-parameterise weights so they are rescaled by ``1/sqrt(fan)`` at run time.

    For every submodule that owns a tensor attribute called ``name`` with at
    least two dimensions, the tensor is replaced by ``weight_orig`` (the value
    divided by the scale) and a forward pre-hook recomputes
    ``scale * weight_orig`` before each forward pass. Submodules without such an
    attribute are left untouched.
    """
    def __init__(self, name):
        self.name = name

    @staticmethod
    def compute_norm(module, weight):
        """Return the fan of ``weight``: fan_out for transposed modules, else fan_in."""
        mode = 'fan_in'
        if hasattr(module, 'transposed') and module.transposed:
            mode = 'fan_out'
        return torch.nn.init._calculate_correct_fan(weight, mode)

    def scale_weight(self, module, input):
        """Forward pre-hook: rebuild the effective weight from ``weight_orig``."""
        setattr(module, self.name, module.scale*module.weight_orig)

    def fn(self, module):
        """Convert a single module in place; silently skip modules that do not qualify."""
        weight = getattr(module, self.name, None)
        if not isinstance(weight, torch.Tensor):
            # containers, activations, disabled (None) weights ...
            return
        try:
            fan = Equal_LR.compute_norm(module, weight)
        except ValueError:
            # fan is undefined for tensors with fewer than 2 dimensions
            return

        module.scale = float(1/np.sqrt(fan))
        unscaled = weight.detach().clone()/module.scale
        if isinstance(weight, torch.nn.Parameter):
            # register new parameter -- unscaled weight
            module.weight_orig = nn.Parameter(unscaled)
            # delete old parameter; the hook re-creates it as a plain attribute
            module._parameters.pop(self.name, None)
        else:
            # register new buffer -- unscaled weight
            module.register_buffer('weight_orig', unscaled)
            # delete old buffer (tolerate tensors that were plain attributes)
            module._buffers.pop(self.name, None)
        module.equalize = module.register_forward_pre_hook(self.scale_weight)

    def __call__(self, module):
        """Return a converted deep copy of ``module``; the argument is not modified."""
        new_module = deepcopy(module)
        new_module.apply(self.fn)
        return new_module

def parameters_to_buffers(m):
    """Turn every direct parameter of ``m`` into a buffer holding the same data.

    Only the parameters registered on ``m`` itself are touched (not those of
    its children). Empty parameter slots (``None``, e.g. a disabled bias) are
    kept as ``None`` buffers so attribute access keeps working.
    """
    params = m._parameters.copy()
    m._parameters.clear()
    for n,p in params.items():
        # p is None for optional parameters that were switched off
        m.register_buffer(n, None if p is None else p.data)

def grid(array, ncols=8):
    """Tile a batch of images of shape (N, H, W, C) into one mosaic image.

    Each image gets a 1-pixel zero frame, missing cells of the last row are
    filled with zero images, and the finished mosaic gets one more 1-pixel
    zero frame. At most ``ncols`` images are placed per row.
    """
    # 1-pixel zero frame around every image
    framed = np.pad(array, [(0,0),(1,1),(1,1),(0,0)], 'constant')
    count, h, w, ch = framed.shape
    ncols = min(count, ncols)
    nrows = (count + ncols - 1)//ncols
    # number of empty cells needed to complete the last row
    missing = nrows*ncols - count
    blanks = [np.zeros([1, h, w, ch]) for _ in range(missing)]
    tiles = np.concatenate([framed, *blanks])
    # (nrows, ncols, h, w, ch) -> (nrows, h, ncols, w, ch) -> (h*nrows, w*ncols, ch)
    mosaic = tiles.reshape(nrows, ncols, h, w, ch)
    mosaic = mosaic.swapaxes(1, 2)
    mosaic = mosaic.reshape(h*nrows, w*ncols, ch)
    return np.pad(mosaic, [(1,1),(1,1),(0,0)], 'constant')

class NextDataLoader(torch.utils.data.DataLoader):
    """DataLoader that supports ``next(loader)`` and restarts when exhausted."""
    def __next__(self):
        """Return the next batch, starting a new pass once the current one ends.

        Only exhaustion (``StopIteration``) triggers a restart; any other
        exception raised while fetching a batch is propagated to the caller.
        """
        iterator = getattr(self, 'iterator', None)
        if iterator is None:
            # first call: no pass has been started yet
            iterator = self.iterator = self.__iter__()
        try:
            return next(iterator)
        except StopIteration:
            self.iterator = self.__iter__()
            return next(self.iterator)

def to_tensor(obj, device='cuda'):
    """Convert a channels-last numpy image (or batch) to a float tensor on ``device``.

    A trailing axis is appended unless the last axis already has size 1 or 3,
    a leading axis is added when fewer than 4 dimensions remain, and the last
    axis is then moved to position -3.
    """
    has_channel_axis = obj.shape[-1] in (1, 3)
    if not has_channel_axis:
        obj = np.expand_dims(obj, -1)
    if obj.ndim < 4:
        obj = np.expand_dims(obj, 0)
    channels_first = np.moveaxis(obj, -1, -3)
    return torch.tensor(channels_first, dtype=torch.float, device=device)

def to_img(obj):
    """Return the tensor's data as a numpy array with axis -3 moved to the end."""
    host_array = obj.data.cpu().numpy()
    return np.moveaxis(host_array, -3, -1)

In [24]:
# network functions
class Modulated_Conv2d(nn.Conv2d):
    """2-D convolution whose kernel is scaled per sample by a learned style.

    A linear layer maps the latent vector to one scale per input channel; the
    kernel is multiplied by it (modulation) and, optionally, renormalised per
    output channel (demodulation). All samples are processed by one grouped
    convolution. Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, in_channels, out_channels, kernel_size, latent_size,
                 demodulate=True, bias=True, stride=1, padding=0, dilation=1, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                        padding, dilation, groups=1,
                        bias=bias, padding_mode='zeros')
        self.demodulate = demodulate
        # style mapping
        self.style = nn.Linear(latent_size, in_channels)
        # required shape might be different in transposed conv
        self.s_broadcast_view = (-1,1,self.in_channels,1,1)
        self.in_channels_dim = 2

    def convolve(self,x,w,groups):
        """Grouped convolution without bias; the bias is added in ``forward``."""
        return F.conv2d(x, w, None, self.stride, self.padding, self.dilation, groups=groups)

    def forward(self, x, v):
        """Convolve ``x`` of shape (N, C_in, H, W) with kernels styled by ``v`` of shape (N, latent_size)."""
        N, _, H, W = x.shape

        # new minibatch dim: (ch dims, K, K) -> (1, ch dims, K, K)
        w = self.weight.unsqueeze(0)

        # compute styles: (N, C_in); the +1 keeps the kernel unchanged for a zero style output
        s = self.style(v) + 1

        # modulate: (N, ch dims, K, K)
        w = s.view(self.s_broadcast_view)*w

        # demodulate: normalise over input channels and kernel positions
        if self.demodulate:
            sigma = torch.sqrt((w**2).sum(dim=[self.in_channels_dim,3,4],keepdim=True) + 1e-8)
            w = w/sigma

        # reshape x: (N, C_in, H, W) -> (1, N*C_in, H, W)
        # reshape (not view) so non-contiguous inputs, e.g. after permute, are accepted
        x = x.reshape(1, -1, H, W)

        # reshape w: (N, C_out, C_in, K, K) -> (N*C_out, C_in, K, K) for common conv
        #            (N, C_in, C_out, K, K) -> (N*C_in, C_out, K, K) for transposed conv
        w = w.reshape(-1, w.shape[2], w.shape[3], w.shape[4])

        # use groups so that each sample in minibatch has its own conv,
        # conv weights are concatenated along dim=0
        out = self.convolve(x,w,N)

        # reshape back to minibatch.
        out = out.reshape(N,-1,out.shape[2],out.shape[3])

        # add bias
        if self.bias is not None:
            out = out + self.bias.view(1, self.bias.shape[0], 1, 1)

        return out


class Up_Mod_Conv(Modulated_Conv2d):
    """Modulated transposed convolution with stride ``factor``.

    Padding and output padding are derived from ``kernel_size`` and ``factor``;
    the kernel is stored transposed, in the layout ``F.conv_transpose2d`` expects.
    """
    def __init__(self, in_channels, out_channels, kernel_size, latent_size,
                demodulate=True, bias=True, factor=2):
        assert (kernel_size % 2 == 1)
        pad = (max(kernel_size - factor, 0) + 1)//2
        super().__init__(in_channels, out_channels, kernel_size, latent_size, demodulate, bias,
                        stride=factor, padding=pad)
        extra = 2*pad - kernel_size + factor
        self.output_padding = torch.nn.modules.utils._pair(extra)
        # swap the first two kernel axes, as expected in F.conv_transpose2d
        swapped = self.weight.transpose(0,1).contiguous()
        self.weight = nn.Parameter(swapped)
        self.transposed = True
        # with the swapped layout the input-channel axis comes first
        self.in_channels_dim = 1
        self.s_broadcast_view = (-1,self.in_channels,1,1,1)

    def convolve(self, x, w, groups):
        """Grouped transposed convolution without bias."""
        return F.conv_transpose2d(x, w, bias=None, stride=self.stride, padding=self.padding,
                                  output_padding=self.output_padding, groups=groups,
                                  dilation=self.dilation)


class Down_Mod_Conv(Modulated_Conv2d):
    """Modulated convolution with stride ``factor`` and padding ``kernel_size // 2``."""
    def __init__(self, in_channels, out_channels, kernel_size, latent_size,
                demodulate=True, bias=True, factor=2):
        assert (kernel_size % 2 == 1)
        half_kernel = kernel_size//2
        super().__init__(in_channels, out_channels, kernel_size, latent_size, demodulate, bias,
                        stride=factor, padding=half_kernel)

    def convolve(self, x, w, groups):
        """Grouped strided convolution without bias."""
        return F.conv2d(x, w, bias=None, stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=groups)


class Down_Conv2d(nn.Conv2d):
    """Plain convolution with stride ``factor`` and padding ``kernel_size // 2``."""
    def __init__(self, in_channels, out_channels, kernel_size,
                bias=True, factor=2):
        assert (kernel_size % 2 == 1)
        padding = kernel_size//2
        # forward the caller's bias flag (it used to be forced to True)
        super().__init__(in_channels, out_channels, kernel_size, factor, padding, bias=bias)

    def convolve(self, x):
        """Apply the strided convolution with this layer's kernel, without bias."""
        return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)


class Noise(nn.Module):
    """Add noise scaled by a single learnable strength, initialised to zero."""
    def __init__(self):
        super().__init__()
        self.noise_strength = nn.Parameter(torch.zeros(1))

    def forward(self, x, input_noise=None):
        """Return ``x + noise_strength * noise``.

        When ``input_noise`` is None, standard-normal noise with one channel
        and the batch and spatial size of ``x`` is drawn on ``x``'s device.
        """
        if input_noise is not None:
            return x + self.noise_strength*input_noise
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        sampled = torch.randn(batch, 1, height, width, device=x.device)
        return x + self.noise_strength*sampled

class Mapping(nn.Module):
    """Stack of ``n_layers`` square linear layers, each followed by ``nonlinearity``.

    Args:
        n_layers: number of linear layers.
        latent_size: input and output width of every layer.
        nonlinearity: callable applied after each linear layer; the same
            object is shared by all layers.
        normalize: if True, inputs are divided by their root-mean-square over
            dim 1 before the first layer.
    """
    def __init__(self, n_layers, latent_size, nonlinearity, normalize=True):
        super().__init__()
        self.normalize = normalize
        self.layers = []
        for idx in range(n_layers):
            layer = nn.Linear(latent_size, latent_size)
            # linear layers are registered under their index
            self.add_module(str(idx), layer)
            self.layers.append(layer)
            self.layers.append(nonlinearity)
        # Register the shared activation as well (after the linear layers, so
        # their ordering is unchanged); otherwise a parameterised activation
        # such as PReLU would be invisible to .parameters(), .to() and state_dict().
        self.nonlinearity = nonlinearity

    def forward(self, input):
        """Normalise ``input`` if configured, then apply the layers in order."""
        if self.normalize:
            input = input/torch.sqrt((input**2).mean(dim=1, keepdim=True) + 1e-8)
        for module in self.layers:
            input = module(input)
        return input


class G_Block(nn.Module):
    """Generator stage: upsampling modulated conv, modulated conv, and an image-output branch.

    Noise is added after each of the two convolutions, followed by the
    activation. The image branch is a 1x1 modulated conv without demodulation
    whose result is added to the bilinearly upsampled incoming image.
    """
    def __init__(self, in_fmaps, out_fmaps, kernel_size, latent_size, nonlinearity, factor=2, img_channels=3):
        super().__init__()
        # width of the intermediate feature map: midpoint of input and output widths
        mid_fmaps = (in_fmaps + out_fmaps)//2
        self.upconv = Up_Mod_Conv(in_fmaps, mid_fmaps, kernel_size, latent_size,
                                    factor=factor)
        self.conv = Modulated_Conv2d(mid_fmaps, out_fmaps, kernel_size, latent_size,
                                    padding=kernel_size//2)
        self.noise = Noise()
        self.noise2 = Noise()
        self.to_channels = Modulated_Conv2d(out_fmaps, img_channels, kernel_size=1,
                                    latent_size=latent_size, demodulate = False)
        self.upsample = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=False)
        self.act = nonlinearity

    def forward(self, x, v, y=None, input_noises=None):
        """Return ``(features, image)`` for features ``x``, latent ``v`` and optional image ``y``.

        ``input_noises``, when given, supplies the two noise inputs as
        ``input_noises[:, 0]`` and ``input_noises[:, 1]``; otherwise fresh
        noise is drawn. NOTE(review): the expected shape of ``input_noises``
        is not visible here -- confirm against the caller.
        """
        fixed_noise = input_noises is not None
        x = self.upconv(x, v)
        x = self.act(self.noise(x, input_noises[:,0] if fixed_noise else None))
        x = self.conv(x, v)
        x = self.act(self.noise2(x, input_noises[:,1] if fixed_noise else None))
        # without an incoming image the block's own image output is used alone
        skip = 0 if y is None else self.upsample(y)
        y = skip + self.to_channels(x, v)
        return x, y

class D_Block(nn.Module):
    """Discriminator stage: conv, downsampling conv, and a 1x1 strided skip branch."""
    def __init__(self, in_fmaps, out_fmaps, kernel_size, nonlinearity, factor=2):
        super().__init__()
        # width of the intermediate feature map: midpoint of input and output widths
        mid_fmaps = (in_fmaps + out_fmaps)//2
        self.conv = nn.Conv2d(in_fmaps, mid_fmaps, kernel_size, padding=kernel_size//2)
        self.downconv = Down_Conv2d(mid_fmaps, out_fmaps, kernel_size, factor=factor)
        self.skip = Down_Conv2d(in_fmaps, out_fmaps, kernel_size=1, factor=factor)
        self.act = nonlinearity

    def forward(self, x):
        """Return ``(main(x) + skip(x)) / sqrt(2)``."""
        main = self.act(self.conv(x))
        main = self.act(self.downconv(main))
        shortcut = self.skip(x)
        return (main + shortcut)/ np.sqrt(2)


class Minibatch_Stddev(nn.Module):
    """Append one feature map holding the average standard deviation across a group of samples.

    The batch is split into ``group_size`` interleaved slices; the standard
    deviation across the slices is averaged over channels and pixels and
    broadcast to a constant extra channel. The batch size must be divisible
    by the group size actually used.
    """
    def __init__(self, group_size=4):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        """Return ``x`` with one extra channel: (N, C, H, W) -> (N, C+1, H, W)."""
        n, c, h, w = x.shape
        # never use a group larger than the batch, so small batches still work
        group = min(self.group_size, n)
        t = x.reshape(group, -1, c, h, w)                # [G, N/G, C, H, W]
        t = t - t.mean(dim=0, keepdim=True)
        t = torch.sqrt((t**2).mean(dim=0) + 1e-8)        # [N/G, C, H, W]
        t = t.mean(dim=[1,2,3], keepdim=True)            # [N/G, 1, 1, 1]
        t = t.repeat(group,1,1,1).expand(n,1,h,w)        # [N, 1, H, W]
        return torch.cat((x,t),dim=1)
def G_logistic_ns(fake_logits):
    """Non-saturating logistic generator loss: mean of -log(sigmoid(fake_logits))."""
    log_d_fake = F.logsigmoid(fake_logits)  # log(D(G(z)))
    return torch.mean(-log_d_fake)


def D_logistic(real_logits, fake_logits):
    """Logistic discriminator loss: mean of -log(D(x)) - log(1 - D(G(z)))."""
    real_term = -F.logsigmoid(real_logits)   # -log(D(x))
    fake_term = F.softplus(fake_logits)      # -log(1 - D(G(z)))
    return (real_term + fake_term).mean()

def R1_reg(real_imgs, real_logits):
    """Mean over the batch of the squared gradient norm of the logits w.r.t. the real images.

    ``real_imgs`` must require grad and ``real_logits`` must have been computed
    from it; the graph is kept so the penalty itself can be back-propagated.
    """
    (grads,) = torch.autograd.grad(real_logits.sum(), real_imgs, create_graph=True)
    per_sample = grads.pow(2).sum(dim=[1,2,3])
    return per_sample.mean()

class Path_length_loss(nn.Module):
    """Path-length regulariser that keeps a running average of the measured lengths.

    ``avg`` starts at 0 and is updated on every call as an exponential moving
    average with weight ``decay`` on the newest batch.
    """
    def __init__(self, decay=0.01):
        super().__init__()
        self.decay = decay
        self.avg = 0

    def forward(self, dlatent, gen_out):
        """Return the mean squared deviation of the path lengths from the running average."""
        # random image-shaped probe, scaled by 1/sqrt(number of pixels)  [N,Channels,H,W]
        pixels = np.prod(gen_out.shape[2:])
        probe = torch.randn(gen_out.shape, device=gen_out.device)/np.sqrt(pixels)
        # gradient of the probed output w.r.t. the latents  [N, num_layers, dlatent_size]
        projection = (gen_out * probe).sum()
        (jac,) = torch.autograd.grad(projection, dlatent, create_graph=True)
        lengths = jac.pow(2).mean(2).sum(1).sqrt()  # [N]
        # the running average is updated from detached lengths
        batch_mean = torch.mean(lengths.detach())
        self.avg = self.decay*batch_mean + (1-self.decay)*self.avg
        return torch.mean((lengths - self.avg)**2)


def Noise_reg(noise_maps, min_res=8):
    """Penalise correlation between neighbouring pixels of noise maps at several scales.

    For every map, and for each 2x-downsampled version while the width is
    still greater than ``min_res``, the squared mean of the product of the map
    with itself shifted by one pixel (horizontally and vertically) is added.

    Args:
        noise_maps: iterable of tensors whose last two dimensions are spatial.
        min_res: pyramid levels with width <= ``min_res`` are not evaluated.

    Returns:
        The summed penalty (the int 0 if no level was evaluated).
    """
    loss = 0
    for nmap in noise_maps:
        # Flatten all leading dims into the batch axis: (M, 1, H, W). The loss is a
        # sum over those dims, so the value is unchanged, while pooling now also
        # works for a batch of one (squeeze() used to collapse it to 2-D).
        nmap = nmap.reshape(-1, 1, nmap.shape[-2], nmap.shape[-1])
        res = nmap.shape[-1]
        while res > min_res:
            loss += ( torch.mean(nmap * nmap.roll(shifts=1, dims=-1), dim=[-1,-2])**2
                    + torch.mean(nmap * nmap.roll(shifts=1, dims=-2), dim=[-1,-2])**2 ).sum()
            nmap = F.avg_pool2d(nmap, 2)
            res = res//2
    return loss