In [None]:
import numpy as np
import torch.nn as nn
from s2cnn.nn.soft.so3_conv import SO3Convolution
from s2cnn.nn.soft.s2_conv import S2Convolution
from s2cnn.nn.soft.so3_integrate import so3_integrate
from s2cnn.ops.so3_localft import near_identity_grid as so3_near_identity_grid
from s2cnn.ops.s2_localft import near_identity_grid as s2_near_identity_grid
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip, pickle
import numpy as np
from torch.autograd import Variable
from torch.distributions import Normal

import matplotlib.pyplot as plt


In [None]:
class MLP(nn.Module):
    """Plain fully connected feed-forward network.

    Args:
        H: layer widths ``[in, hidden_1, ..., out]``; needs at least two entries.
            Every hidden layer is ``Linear`` followed by one ``activation()``.
        activation: activation *class*; a fresh instance is created per use.
        end_activation: if True, an activation is also appended after the
            output ``Linear``.

    The layers live in ``self.module`` (an ``nn.Sequential``).
    """

    def __init__(self, H=[1, 10, 1], activation=nn.ReLU, end_activation=False):
        super(MLP, self).__init__()
        self.H = H
        self.activation = activation

        layers = []
        # hidden layers: H[k] -> H[k+1] for every k except the last transition
        for k in range(len(H) - 2):
            layers += [nn.Linear(H[k], H[k + 1]), activation()]
        # output layer (no activation unless requested)
        layers.append(nn.Linear(H[-2], H[-1]))
        if end_activation:
            layers.append(activation())

        self.module = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dim must equal ``H[0]``)."""
        return self.module(x)
    
class Nreparametrize(nn.Module):
    """Diagonal-Gaussian reparametrization layer.

    Two linear heads map the input to ``mu`` and ``sigma = softplus(.)`` (so
    ``sigma > 0``), and ``n`` samples ``z = mu + eps * sigma`` are drawn with
    ``eps ~ N(0, 1)``.

    ``forward`` caches ``self.mu``, ``self.sigma`` and ``self.z``; ``kl``,
    ``log_posterior`` and ``log_prior`` read those caches and therefore must
    be called after ``forward``.
    """

    def __init__(self, input_dim, z_dim):
        super(Nreparametrize, self).__init__()

        self.input_dim = input_dim
        self.z_dim = z_dim

        # creation order (sigma first) is kept so parameter init and state_dict order are unchanged
        self.sigma_linear = nn.Linear(input_dim, z_dim)
        self.mu_linear = nn.Linear(input_dim, z_dim)

    def forward(self, x, n=1):
        """Return ``n`` reparametrized samples; shape is ``(n, *mu.shape)``."""
        self.mu = self.mu_linear(x)
        self.sigma = F.softplus(self.sigma_linear(x))
        self.z = self.nsample(self.mu, self.sigma, n=n)
        return self.z

    def kl(self):
        """Closed-form ``KL(N(mu, sigma^2) || N(0, 1))`` summed over the last dim.

        Uses the statistics cached by the most recent ``forward`` call.
        """
        return -0.5 * torch.sum(1 + 2 * self.sigma.log() - self.mu.pow(2) - self.sigma ** 2, -1)

    def log_posterior(self):
        """Element-wise ``log q(z | x)`` of the cached samples (no reduction)."""
        return Normal(self.mu, self.sigma).log_prob(self.z)

    def log_prior(self):
        """Element-wise standard-normal ``log p(z)`` of the cached samples (no reduction)."""
        return Normal(torch.zeros_like(self.mu), torch.ones_like(self.sigma)).log_prob(self.z)

    @staticmethod
    def nsample(mu, sigma, n=1):
        """Draw ``n`` samples ``mu + eps * sigma`` with ``eps ~ N(0, 1)``.

        Gradients flow through ``mu`` and ``sigma``; ``eps`` is a constant.
        """
        # sample((n,)) is the documented replacement for the deprecated sample_n(n)
        eps = Normal(torch.zeros_like(mu), torch.ones_like(mu)).sample((n,))
        return mu + eps * sigma

In [None]:
class VAE(nn.Module):
    """Abstract variational auto-encoder skeleton.

    Subclasses must set:
        encoder: callable mapping the input batch to a hidden representation.
        decoder: callable mapping the concatenated latent samples to a reconstruction.
        reparametrize: list of reparametrization layers; each is called as
            ``r(h_i, n)`` and must expose ``kl()``.
        r_callback: list (parallel to ``reparametrize``) of callables turning
            the encoder output into the input of the matching reparametrizer.
    and implement ``recon_loss`` (and optionally ``log_likelihood``).

    NOTE: ``reparametrize`` is a plain Python list, so layers stored only there
    are NOT registered as submodules; subclasses should also assign them as
    attributes (or use ``nn.ModuleList``) so their parameters are trained.
    """

    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = None
        self.decoder = None
        self.reparametrize = []
        self.r_callback = []

    def encode(self, x, n=1):
        """Encode ``x`` and return one list entry of ``n`` samples per reparametrizer."""
        h = self.encoder(x)
        return [r(f(h), n) for r, f in zip(self.reparametrize, self.r_callback)]

    def kl(self):
        """Return the list of KL terms, one per reparametrizer.

        Must be called after ``encode``/``forward``: the reparametrizers compute
        KL from statistics cached during their last forward pass.
        TODO: make calling this before ``encode`` impossible.
        """
        return [r.kl() for r in self.reparametrize]

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, n=1):
        """Encode, flatten and concatenate all latents, then decode.

        Each latent is reshaped to ``(n, batch, -1)`` before concatenation
        along the last dim, so the decoder receives ``(n, batch, total_z)``.
        """
        z = self.encode(x, n=n)
        batch = x.size(0)
        z_cat = torch.cat([v.view(n, batch, -1) for v in z], -1)
        return self.decode(z_cat)

    def recon_loss(self, x, x_recon):
        raise NotImplementedError("subclasses must implement recon_loss")

    def elbo(self, x):
        """Return ``(recon_loss, kl_summed)`` for a single latent sample.

        ``recon_loss`` is whatever the subclass's ``recon_loss`` returns;
        ``kl_summed`` adds the KL terms of all reparametrizers.
        """
        x_recon = self.forward(x)[0]  # first (and only, n=1) sample
        kl = self.kl()
        # TODO maybe sum directly without stacking
        kl_summed = torch.sum(torch.stack(kl, -1), -1)
        recon_loss = self.recon_loss(x, x_recon)
        return recon_loss, kl_summed

    def log_likelihood(self, x, n=1):
        raise NotImplementedError("subclasses must implement log_likelihood")



In [None]:
class S2ConvNet(nn.Module):
    """Spherical CNN encoder: one ``S2Convolution`` then a stack of ``SO3Convolution``s.

    Stage ``i -> i+1`` maps ``f_list[i]`` features at bandwidth ``b_list[i]``
    to ``f_list[i+1]`` features at bandwidth ``b_list[i+1]``; every
    convolution is followed by a fresh ``activation()`` instance. The first
    stage uses the S2 near-identity grid, all later stages the SO(3) one.
    ``forward`` returns the feature map of the last stage unchanged (no
    integration or pooling is applied).
    """

    def __init__(self, f_list = [1,10,10], b_list = [30,10,6], activation = nn.ReLU):
        super(S2ConvNet, self).__init__()

        # TODO make boolean for integrate
        s2_grid = s2_near_identity_grid()
        so3_grid = so3_near_identity_grid()

        self.f_list = f_list
        self.b_list = b_list
        self.activation = activation

        # first stage: sphere convolution; its output feeds the SO(3) stages
        layers = [
            S2Convolution(
                nfeature_in=f_list[0],
                nfeature_out=f_list[1],
                b_in=b_list[0],
                b_out=b_list[1],
                grid=s2_grid),
            activation(),
        ]

        # remaining stages: SO(3) -> SO(3)
        stages = zip(f_list[1:-1], f_list[2:], b_list[1:-1], b_list[2:])
        for feat_in, feat_out, band_in, band_out in stages:
            layers.append(SO3Convolution(
                nfeature_in=feat_in,
                nfeature_out=feat_out,
                b_in=band_in,
                b_out=band_out,
                grid=so3_grid))
            layers.append(activation())

        self.conv_module = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_module(x)

In [None]:
class S2DeconvNet(nn.Module):
    """Spherical CNN decoder: latent vector -> MLP -> stack of ``SO3Convolution``s.

    The MLP maps the latent (last dim ``mlp_dim[0]``) to
    ``f_list[0] * (2*b_list[0])**3`` values, which are reshaped into an SO(3)
    feature map. Each following stage applies ``activation()``, upsamples
    (nearest) when the bandwidth grows, and convolves
    ``f_list[i] -> f_list[i+1]`` features at bandwidth ``b_list[i+1]``. The last
    convolution has no activation.

    Args:
        f_list: feature counts per stage.
        b_list: bandwidths per stage (same length as ``f_list``).
        mlp_dim: MLP widths before the final reshape layer; any sequence.
        activation: activation class.
    """

    def __init__(self, f_list = [10,10,1], b_list = [5,10,30], mlp_dim = [10], activation = nn.ReLU):
        super(S2DeconvNet, self).__init__()
        # TODO make boolean for integrate
        grid_so3 = so3_near_identity_grid()

        self.f_list = f_list
        self.b_list = b_list
        # list(...) copies (never mutates the caller's / default list) and also accepts tuples
        self.mlp_dim = list(mlp_dim)
        # final MLP width = flattened size of the first SO(3) feature map
        self.mlp_dim.append(f_list[0] * (b_list[0] * 2) ** 3)
        self.activation = activation

        self.mlp_module = MLP(H = self.mlp_dim, activation = self.activation)

        modules = []
        for f_in, f_out, b_in, b_out in zip(f_list[:-1], f_list[1:], b_list[:-1], b_list[1:]):
            modules.append(self.activation())

            # NOTE(review): only upsampling is handled; with b_in > b_out the input
            # (resolution 2*b_in) would not match the conv below (built with b_in=b_out).
            if b_in < b_out:
                modules.append(torch.nn.Upsample(size=b_out * 2, mode='nearest'))

            modules.append(SO3Convolution(
                nfeature_in=f_in,
                nfeature_out=f_out,
                b_in=b_out,
                b_out=b_out,
                grid=grid_so3))

        self.conv_module = nn.Sequential(*modules)

    def forward(self, x):
        """Decode ``x``; all leading dims of ``x`` except the last are preserved.

        Returns a tensor of shape ``(*lead, f_list[-1], 2*b, 2*b)`` with
        ``b = b_list[-1]``, obtained by averaging the last SO(3) axis.
        """
        x = self.mlp_module(x)
        lead_shape = x.size()[:-1]
        side_in = self.b_list[0] * 2
        x = x.view(-1, self.f_list[0], side_in, side_in, side_in)

        x = self.conv_module(x)

        side_out = self.b_list[-1] * 2
        x = x.view(*lead_shape, self.f_list[-1], side_out, side_out, side_out)

        # TODO better reduce gamma channel
        return x.mean(-1)
    

In [None]:
class NS2VAE(VAE):
    """VAE on spherical images with one diagonal-Gaussian latent.

    Encoder: ``S2ConvNet`` -> flatten -> ``MLP`` (ending in an activation)
    -> ``Nreparametrize``. Decoder: ``S2DeconvNet`` from ``z_dim`` latents.
    The decoder output is treated as Bernoulli logits by ``recon_loss``.
    """

    def __init__(self, z_dim = 10,
                       encoder_f=[1,10,10],
                       decoder_f=[10,10,1],
                       encoder_b=[30,20,6],
                       decoder_b=[5,15,30],
                       mlp_h=[100]):
        super(NS2VAE, self).__init__()

        self.encoder = S2ConvNet(f_list = encoder_f, b_list = encoder_b)
        self.decoder = S2DeconvNet(f_list = decoder_f, b_list = decoder_b, mlp_dim = [z_dim])

        # MLP input width = flattened size of the encoder's last SO(3) feature map
        self.mlp_h = list(mlp_h)
        self.mlp_h.insert(0, encoder_f[-1]*(encoder_b[-1]*2)**3)
        self.mlp = MLP(H = self.mlp_h, end_activation = True)

        # assigned as an attribute so its parameters are registered (the list below is not)
        self.repar1 = Nreparametrize(mlp_h[-1], z_dim)
        self.reparametrize = [self.repar1]
        self.r_callback = [lambda x: self.mlp(x.view(-1, self.mlp_h[0]))]

        # reduction='sum' is the documented equivalent of the deprecated size_average=False
        self.bce = nn.BCELoss(reduction='sum')

    def recon_loss(self, x, x_recon):
        """Element-wise Bernoulli negative log-likelihood of ``x`` given logits ``x_recon``.

        Mathematically ``-(x*log(sigmoid(l)) + (1-x)*log(1-sigmoid(l)))``, but
        computed via ``logsigmoid`` (``log(1-sigmoid(l)) == logsigmoid(-l)``) so
        it stays finite for saturated logits; the naive form yields
        ``-inf * 0 = nan`` once ``sigmoid`` rounds to exactly 0 or 1.
        ``x`` and ``x_recon`` broadcast against each other; no reduction is applied.
        """
        return -(x * F.logsigmoid(x_recon) + (1 - x) * F.logsigmoid(-x_recon))

In [None]:
#vae = NS2VAE()

In [None]:
def load_data(path, batch_size):
    """Load a gzipped, pickled image dataset and wrap it in DataLoaders.

    The pickle must be a dict with ``"train"`` and ``"test"`` entries, each
    holding ``"images"`` (a 3-D array indexed ``[sample, row, col]``) and
    ``"labels"``.

    Args:
        path: path to the gzip file.
        batch_size: batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader, train_dataset, test_dataset)``. Images are
        float32 tensors with a singleton channel axis inserted at position 1;
        labels are int64. Both loaders shuffle.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with gzip.open(path, 'rb') as f:
        dataset = pickle.load(f)

    def _to_tensor_dataset(split):
        # one split -> TensorDataset(images[N, 1, rows, cols] float32, labels int64)
        part = dataset[split]
        images = torch.from_numpy(part["images"][:, None, :, :].astype(np.float32))
        labels = torch.from_numpy(part["labels"].astype(np.int64))
        return data_utils.TensorDataset(images, labels)

    train_dataset = _to_tensor_dataset("train")
    test_dataset = _to_tensor_dataset("test")

    train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # NOTE(review): shuffling the test split is harmless but unnecessary for evaluation.
    test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader, train_dataset, test_dataset

In [None]:
# Configuration, data loading, model and optimizer construction.
DEVICE_ID = 0
MNIST_PATH = "./mnist_example/s2_mnist.gz"
NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 5e-3
SEED = 0

# seeds weight init and DataLoader shuffling so runs are reproducible
torch.manual_seed(SEED)

train_loader, test_loader, train_dataset, _ = load_data(
        MNIST_PATH, BATCH_SIZE)

# set_device must only be called when CUDA exists; unguarded it crashes on CPU-only machines
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    torch.cuda.set_device(DEVICE_ID)

vae = NS2VAE(z_dim = 30)

print("#params", sum(p.numel() for p in vae.parameters()))

if USE_CUDA:
    vae.cuda(DEVICE_ID)

optimizer = torch.optim.Adam(
    vae.parameters(),
    lr=LEARNING_RATE)

In [None]:
# Train by minimising the (KL-weighted) negative ELBO; NUM_EPOCHS controls run length.
KL_WEIGHT = 0.1  # down-weights the KL term relative to the reconstruction term

for epoch in range(NUM_EPOCHS):
    for images, _ in train_loader:  # labels are not used by the VAE
        images = images / 255  # scale pixels to [0, 1] for the Bernoulli likelihood
        if torch.cuda.is_available():
            images = images.cuda(DEVICE_ID)

        rec, kl = vae.elbo(images)

        # rec is element-wise; sum the last three dims to get one value per example
        loss = (rec.sum(dim=(-3, -2, -1)) + KL_WEIGHT * kl).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print("\r elbo: %.4f" % loss.item(), end="")

    

In [None]:
# Compare one input from the last training batch with its reconstruction.
with torch.no_grad():
    # vae(...) returns decoder logits shaped (n_samples, batch, channels, H, W);
    # sigmoid turns them into pixel probabilities
    recon_probs = torch.sigmoid(vae(images))

example_idx = min(7, images.size(0) - 1)  # a short final batch may hold fewer than 8 images

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(8, 4))
ax_in.imshow(images[example_idx, 0].cpu().numpy(), cmap="gray")
ax_in.set_title("input")
ax_out.imshow(recon_probs[0, example_idx, 0].cpu().numpy(), cmap="gray")
ax_out.set_title("reconstruction (probabilities)")
for ax in (ax_in, ax_out):
    ax.axis("off")
plt.show()

In [None]:
# Spatial shape of one reconstructed channel (sample 0, batch item 0, channel 0).
reconstruction = vae(images)
reconstruction[0, 0, 0].detach().cpu().numpy().shape