In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from mmd_pointnet import MMDWAE

import matplotlib.pyplot as plt

In [2]:
import h5py
import torch.utils.data as data

class H5Dataset(data.Dataset):

    def __init__(self, file_path, sample_size = 1024):
        super(H5Dataset, self).__init__()
        self.sample_size = sample_size
        random_index = np.random.choice(np.arange(10000), self.sample_size)
        h5_file = h5py.File(file_path)
        self.data = h5_file.get('data')
        self.target = h5_file.get('data')
        

    def __getitem__(self, index):            
        return (torch.from_numpy(self.data[index]).float(),
                torch.from_numpy(self.target[index]).float())

    def __len__(self):
        return self.data.shape[0]

  from ._conv import register_converters as _register_converters


In [3]:
data = H5Dataset('modelnet_train.hdf5')

In [4]:
batch_size = 10

In [5]:
train_loader = torch.utils.data.DataLoader(data,
    batch_size=batch_size, shuffle=True
)

In [6]:
for X, y in train_loader:
    print(X.shape)
    break

torch.Size([10, 1024, 3])


In [7]:
n = 1024
z_dim = 50

class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


encoder = nn.Sequential(
      nn.Conv1d(3, 64, 1),
      nn.ReLU(),
      nn.BatchNorm1d(64),
      nn.Conv1d(64, 128, 1),
      nn.ReLU(),
      nn.BatchNorm1d(128),
      nn.Conv1d(128, 256, 1),
      nn.ReLU(),
      nn.BatchNorm1d(256),
      nn.Conv1d(256, 512, 1),
      nn.ReLU(),
      nn.BatchNorm1d(512),
      nn.MaxPool1d(1024, 1),
      View((-1,512)),
      nn.Linear(512,50)
    )

decoder = nn.Sequential(
      nn.Linear(50,512),
#       View((-1,1024,512)),
      nn.ReLU(),
      nn.Linear(512,1024),
      nn.Linear(1024, 2048),
      nn.ReLU(),
      nn.Linear(2048, 3*1024),
      View((-1, 1024, 3))
)

In [8]:
decoder(encoder(X[0:1].transpose(1,2))).shape

torch.Size([1, 1024, 3])

In [9]:
def sample_latent_prior(batch_size):
    return torch.normal(torch.zeros(batch_size, z_dim), torch.ones(batch_size, z_dim)).to(device)

In [10]:
def cost(Y_true, Y_pred):
    return torch.nn.functional.mse_loss(Y_pred, Y_true)

In [11]:
def kernel(z1, z2):
    z11 = z1.unsqueeze(1).repeat(1, z2.size(0), 1)
    z22 = z2.unsqueeze(0).repeat(z1.size(0), 1, 1)
    
    C = 1

    kernel_matrix = C/(1e-9+C+(z11-z22).pow(2).sum(2))
    kernel_sum = kernel_matrix.sum()

    return kernel_sum.to(device)

In [12]:
device = torch.device("cuda:"+str(1) if torch.cuda.is_available() else "cpu")

In [13]:
encoder.to(device)
decoder.to(device)

Sequential(
  (0): Linear(in_features=50, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=1024, bias=True)
  (3): Linear(in_features=1024, out_features=2048, bias=True)
  (4): ReLU()
  (5): Linear(in_features=2048, out_features=3072, bias=True)
  (6): View()
)

In [14]:
optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(decoder.parameters()),
            lr=1e-4
        )

In [15]:
mmd = MMDWAE(cost=cost, decoder=decoder, encoder=encoder, device=device, lamda_coeff=0.1, kernel=kernel,
       sample_latent_prior=sample_latent_prior, trainloader=train_loader, optimizer=optimizer)

In [23]:
mmd.train(500)


  0%|          | 0/500 [00:00<?, ?it/s][A
Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/Albert.Matveev/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/Albert.Matveev/anaconda3/lib/python3.6/site-packages/tqdm/_monitor.py", line 63, in run
    for instance in self.tqdm_cls._instances:
  File "/home/Albert.Matveev/anaconda3/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration

100%|██████████| 500/500 [1:58:05<00:00, 14.17s/it]


In [36]:
sampled = decoder(encoder(X[0:1].transpose(1,2).to(device)))

In [27]:
sampled = sampled.cpu().detach().numpy()

In [32]:
np.savetxt('out_pointnet.csv',sampled[0])