In [None]:
# default_exp layers

In [None]:
#export
import torch
from torch import nn
import scipy.special
import numpy as np
import random

# Layers

## Pooling

In [None]:
#export

#From fastai library
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling along the channel axis.

    For an input with C channels the output has 2*C channels: the first C come
    from `AdaptiveMaxPool2d`, the last C from `AdaptiveAvgPool2d`.
    """
    def __init__(self, sz=None):
        """`sz` is the spatial output size given to both pools; a falsy `sz` (e.g. None) means 1."""
        super().__init__()
        size = sz if sz else 1
        self.output_size = size
        self.ap = nn.AdaptiveAvgPool2d(size)
        self.mp = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        # Max-pooled features first, then average-pooled, stacked on dim 1 (channels).
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, dim=1)

## Upsample

In [None]:
# #export
# def icnr(x, scale=2, init=nn.init.kaiming_normal_):
#     "ICNR init of `x`, with `scale` and `init` function."
#     ni,nf,h,w = x.shape
#     ni2 = int(ni/(scale**2))
#     k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
#     k = k.contiguous().view(ni2, nf, -1)
#     k = k.repeat(1, 1, scale**2)
#     k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
#     x.data.copy_(k)

# class PixelShuffle_ICNR(nn.Module):
#     "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
#     def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, norm_type=NormType.Weight, leaky:float=None):
#         nf = ifnone(nf, ni)
#         self.conv = conv_layer(ni, nf*(scale**2), ks=1, norm_type=norm_type, use_activ=False)
#         icnr(self.conv[0].weight)
#         self.shuf = nn.PixelShuffle(scale)
#         # Blurring over (h*w) kernel
#         # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
#         # - https://arxiv.org/abs/1806.02658
#         self.pad = nn.ReplicationPad2d((1,0,1,0))
#         self.blur = nn.AvgPool2d(2, stride=1)
#         self.do_blur = blur
#         self.relu = relu(True, leaky=leaky)

#     def forward(self,x):
#         x = self.shuf(self.relu(self.conv(x)))
#         return self.blur(self.pad(x)) if self.do_blur else x

## Padding

In [None]:
#export
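# Circular padding is applied explicitly with F.pad(mode='circular') because the PyTorch version this was written against had a circular-padding bug (presumably in the built-in conv `padding_mode='circular'` path).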

class CircularPad2d(nn.Module):
    """Wrap-around padding of the last two dimensions by `pad_size` on every side."""
    def __init__(self, pad_size):
        super().__init__()
        self.pad_size = pad_size

    def forward(self, x):
        p = self.pad_size
        # F.pad takes (left, right, top, bottom) for the last two dims; all equal here.
        return torch.nn.functional.pad(x, (p,) * 4, mode='circular')
    

# Init

In [None]:
#export

#from fastai
def init_cnn(m):
    """Initialise `m` and, recursively, all of its children in place.

    Any module with a non-None `bias` gets it zeroed; `nn.Conv2d` and
    `nn.Linear` weights get Kaiming-normal initialisation.
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)

# Loss

In [None]:
#export


def generate_gaussian_points(N):
    """Return `N` deterministic standard-normal quantiles, sorted ascending.

    The probabilities are the midpoints of `N` equal-width bins of [0, 1],
    U_i = (i + 0.5) / N for i = 0..N-1 (evenly spaced, NOT random draws), and
    each point is the inverse standard-normal CDF of U_i:
        G_i = sqrt(2) * erfinv(2 * U_i - 1)

    Returns a float64 numpy array of shape (N,), or an empty array when N == 0.
    A negative N still raises (from `np.linspace`).
    """
    if N == 0:
        # 0.5 / N below would raise ZeroDivisionError.
        return np.empty(0)
    U = np.linspace(0.5 / N, 1 - 0.5 / N, N)
    # Exact sqrt(2) rather than a truncated decimal literal.
    return np.sqrt(2.0) * scipy.special.erfinv(2 * U - 1)

class OTGauss(nn.Module):
    """Sliced optimal-transport loss between a generated batch and a fixed target batch.

    Both batches are projected onto `nb_dim` random unit directions in
    R^output_size. For each direction the projections are sorted (the 1D optimal
    matching) and compared with a mean squared difference. The loss is the mean
    of that quantity over all directions.

    Args:
        nb_dim: number of random projection directions.
        output_size: dimensionality of each generated sample.
        batch_size: number of samples per batch; `forward` asserts this exact size.
        device: device holding the directions and the target. Defaults to 'cuda',
            which is what the class always used before this argument existed.
    """
    def __init__(self, nb_dim, output_size, batch_size, device='cuda'):
        super().__init__()
        self.nb_dim = nb_dim
        self.batch_size = batch_size
        self.output_size = output_size
        self.device = device
        self.current_iter = 0  # not read anywhere in this class

        self.generate_dims()

    def generate_dims(self):
        """(Re)draw the projection directions and re-project the target onto them.

        The target is built on the first call. Later calls reuse it and only
        re-sample `self.dims` and recompute `self.projeted_target`.
        """
        # NOTE(review): torch.rand samples from [0, 1), so every direction lies in the
        # positive orthant instead of being uniform on the unit sphere (a normalised
        # torch.randn would be). Left unchanged; the notebook's `> 10` threshold for
        # uniform batches appears to depend on it.
        dims = torch.rand((self.nb_dim, self.output_size)).to(self.device)
        # Normalise each row (one direction) to unit L2 norm.
        self.dims = dims / torch.norm(dims, dim=1, keepdim=True)

        # The target depends only on output_size / batch_size, so build it once.
        if not hasattr(self, 'target'):
            self.generate_target()

        projected = self.dims @ self.target
        assert projected.shape == (self.nb_dim, self.batch_size), projected.shape
        # Sort along the batch axis so forward can pair order statistics.
        self.projeted_target = torch.sort(projected, dim=-1)[0]

    def generate_target(self):
        """Build `self.target`, an (output_size, batch_size) float32 tensor.

        Every column is the same vector of `output_size` Gaussian quantiles from
        `generate_gaussian_points`.
        """
        # NOTE(review): all columns are identical, so the target batch is one point
        # repeated batch_size times rather than batch_size independent Gaussian
        # samples. Confirm this is the intended target distribution.
        points = torch.tensor(generate_gaussian_points(self.output_size))
        target = points.unsqueeze(-1).repeat((1, self.batch_size))
        self.target = target.float().to(self.device)

    def forward(self, generated):
        """Return the scalar sliced-OT loss for `generated`.

        `generated` must be a (batch_size, output_size) tensor on `self.device`.
        """
        proj_gen = self.dims @ generated.T
        assert proj_gen.shape == (self.nb_dim, self.batch_size), proj_gen.shape

        proj_gen = torch.sort(proj_gen, dim=-1)[0]

        # Mean squared gap between matched order statistics, per direction.
        dist = torch.mean((proj_gen - self.projeted_target) ** 2, dim=-1)
        assert len(dist) == self.nb_dim

        return torch.mean(dist)

        

# Tests

In [None]:
# `x` rather than `input`, which would shadow the Python builtin.
x = torch.randn((64, 512, 5, 5))
ad_concat_pool = AdaptiveConcatPool2d(1)
r = ad_concat_pool(x)
# Channels double (max + avg pooled); spatial dims collapse to sz=1.
assert r.shape == (64, 1024, 1, 1)

In [None]:
# Seed every RNG the OT tests use (projection directions, random batches and the
# column shuffle in the next cell) so the thresholds are checked on reproducible data.
torch.manual_seed(0)
random.seed(0)

ot = OTGauss(5000, 128, 64)

# Standard-normal batches of shape (batch_size, output_size).
rand1 = torch.randn((64, 128)).cuda()
rand2 = torch.randn((64, 128)).cuda()

# Uniform [0, 1) batches of the same shape.
randu1 = torch.rand((64, 128)).cuda()
randu2 = torch.rand((64, 128)).cuda()

In [None]:
# Standard-normal batches stay below 2.
assert ot(rand1) < 2
assert ot(rand2) < 2

# Uniform [0, 1) batches are well above 10.
assert ot(randu1) > 10
assert ot(randu2) > 10

# A batch made of exactly the target quantile points has (numerically) zero loss.
# Use a tolerance: the transposed vs contiguous matmul operands may round differently.
g = torch.tensor(generate_gaussian_points(128), dtype=torch.float).view(1, 128).cuda()
g = g.repeat((64, 1))
assert ot(g) < 1e-6

# Permuting the coordinates of that batch keeps the loss below 0.8.
col_idxs = list(range(128))
random.shuffle(col_idxs)
g_shuffled = g[:, torch.tensor(col_idxs)]
assert ot(g_shuffled) < 0.8

print(ot(randu1), ot(rand1), ot(g_shuffled))

tensor(24.5740, device='cuda:0') tensor(1.2106, device='cuda:0') tensor(0.4753, device='cuda:0')


In [None]:
# Circular padding: with pad 2 on a 5x5 map, padded rows/cols 0..8 map to
# original indices 3,4,0,1,2,3,4,0,1.
r = torch.randn((3, 2, 5, 5))
p = CircularPad2d(2)
padded = p(r)

assert padded.shape == (3, 2, 9, 9)
# Top padding row 1 wraps the last original row...
assert torch.equal(padded[1, 1, 1, 2:-2], r[1, 1, -1, :])
# ...and bottom padding row -2 wraps the first original row.
assert torch.equal(padded[1, 1, -2, 2:-2], r[1, 1, 0, :])
# Left and right copies of original column 0 agree.
assert torch.equal(padded[1, 1, 2:-2, 2], padded[1, 1, 2:-2, 7])

In [None]:
# Export all `#export` cells of the project notebooks to the library modules.
from nbdev.export import notebook2script
notebook2script()


Converted 00_core.ipynb.
Converted 01_utils.ipynb.
Converted 02_training.ipynb.
Converted 03_layers.ipynb.
Converted index.ipynb.
