In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy
from PIL import Image
import io

In [3]:
# Number of points in each sample served by DataSetClass (one "set" of points).
K = 64
# Number of K-point sets per mini-batch handed to the DataLoaders.
batch_size = 64
# Prefer the GPU, but fall back to the CPU so that code using `device`
# (e.g. torch.load(..., map_location=device)) still works on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
class OM_Encoder(nn.Module):
    """Encode a batch of 3-D point sets into the parameters of a latent Gaussian.

    Every point goes through shared linear layers. After ``fc2`` and ``fc3``
    each per-point feature vector is concatenated with the max of that feature
    over all points of the same set, and after ``fc4`` a final max over the
    points collapses the set into a single 128-d vector. Two linear heads turn
    that vector into the two outputs.
    """

    def __init__(self):
        super(OM_Encoder,self).__init__()
        self.fc1 = nn.Linear(3,128)
        self.fc2 = nn.Linear(128,128)
        self.fc3 = nn.Linear(256,128)
        self.fc4 = nn.Linear(256,128)
        self.mean_fc = nn.Linear(128,128)
        self.logstddev_fc = nn.Linear(128,128)

    @staticmethod
    def _concat_global_max(x):
        """Append the per-set max feature to every point: (n, k, c) -> (n, k, 2c)."""
        n, k, c = x.size()
        channels_first = x.permute(0, 2, 1)
        # A pooling window of k spans all points, so one value per channel is
        # produced and then repeated (as a view) for each point.
        pooled = F.max_pool1d(channels_first, k).expand(channels_first.size())
        return torch.cat([channels_first, pooled], dim=1).permute(0, 2, 1)

    def forward(self,x):
        """Return ``(mean, log_param)``, each of shape ``(n, 128)``.

        Args:
            x: tensor of shape ``(n, 3, k, 1)`` or ``(n, 3, k)``.

        The second output comes from ``logstddev_fc``; ``OccupancyModel`` and
        ``train`` use it as a log-variance (``exp(0.5 * value)`` is the std).
        """
        # Only drop the trailing singleton axis. A bare squeeze() would also
        # remove the batch axis when n == 1 (e.g. the last, partial batch) and
        # break both the unpacking below and the shape of the outputs.
        if x.dim() == 4:
            x = x.squeeze(-1)
        n, c, k = x.size()
        x = x.permute(0,2,1)  # (n, k, 3): linear layers act on the last axis

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        x = F.relu(self._concat_global_max(x))
        x = self.fc3(x)

        x = F.relu(self._concat_global_max(x))
        x = self.fc4(x)

        # Final max over the points -> one 128-d vector per set.
        x = F.max_pool1d(x.permute(0,2,1), k).squeeze(-1)

        mean = self.mean_fc(x)
        stddev = self.logstddev_fc(x)

        return mean,stddev

In [5]:
class Block(nn.Module):
    """Pre-activation residual block built from two conditional batch-norm stages.

    The block maps ``{'ex': features, 'enc': encoding}`` to a dict with the same
    keys so that several blocks can be chained in an ``nn.Sequential``.
    Each stage is: conditional batch norm -> ReLU -> 1x1 convolution. The
    block input is added to the result of the second stage.
    """

    def __init__(self):
        super(Block,self).__init__()
        self.fc1 = nn.Conv2d(256,256,kernel_size=1)
        self.fc2 = nn.Conv2d(256,256,kernel_size=1)
        self.bn1 = nn.BatchNorm2d(256, affine=False, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(256, affine=False, track_running_stats=True)
        self.gammaLayer1 = nn.Conv1d(128,256,kernel_size=1)
        self.gammaLayer2 = nn.Conv1d(128,256,kernel_size=1)
        self.betaLayer1 = nn.Conv1d(128,256,kernel_size=1)
        self.betaLayer2 = nn.Conv1d(128,256,kernel_size=1)

    @staticmethod
    def _cbn(bn, gamma_layer, beta_layer, features, encoding):
        """Conditional batch norm: ``gamma(encoding) * bn(features) + beta(encoding)``."""
        # gamma/beta are (n, 256, 1). Inserting an axis at dim 2 lets
        # broadcasting apply the same scale/shift to every point of a mesh,
        # without materialising one copy per point via torch.stack.
        gamma = gamma_layer(encoding).unsqueeze(2)
        beta = beta_layer(encoding).unsqueeze(2)
        return gamma * bn(features) + beta

    def forward(self,y):
        """Apply the residual block.

        Args:
            y: dict with ``'ex'`` (4-D feature tensor with 256 channels) and
               ``'enc'`` (latent code with 128 channels fed to the Conv1d
               gamma/beta layers).

        Returns:
            dict with the updated ``'ex'`` and the unchanged ``'enc'``.
        """
        x = y['ex']
        encoding = y['enc']

        #First CBN stage, then ReLU and the first 1x1 "fully connected" layer
        out = self._cbn(self.bn1, self.gammaLayer1, self.betaLayer1, x, encoding)
        out = self.fc1(F.relu(out))

        #Second CBN stage, then ReLU and the second 1x1 layer
        out = self._cbn(self.bn2, self.gammaLayer2, self.betaLayer2, out, encoding)
        out = self.fc2(F.relu(out))

        #Add to the input of the ResNet Block
        out = x + out

        return {'ex':out, 'enc':encoding}

In [17]:
class OccupancyModel(nn.Module):
    """Latent-conditioned occupancy decoder with its point-set encoder.

    In training mode the latent code is produced by ``OM_Encoder`` from the
    input points and sampled with the reparameterisation trick. In eval mode
    the latent code is supplied by the caller (or drawn from a standard normal
    when omitted). The decoder is: 1x1 conv -> 5 ``Block``s -> conditional
    batch norm -> ReLU -> 1x1 conv producing one logit per point.
    """

    def __init__(self):
        super(OccupancyModel,self).__init__()
        self.blocks = self.makeBlocks()
        self.encoderModel = OM_Encoder()
        self.gammaLayer = nn.Conv1d(128,256,kernel_size=1)
        self.betaLayer = nn.Conv1d(128,256,kernel_size=1)
        self.cbn = nn.BatchNorm2d(256, affine=False, track_running_stats=True)
        self.fc1 = nn.Conv2d(3,256,kernel_size=1)
        self.fc2 = nn.Conv2d(256,1,kernel_size=1)

    def makeBlocks(self):
        """Return the chain of five residual ``Block``s as an ``nn.Sequential``."""
        return nn.Sequential(*[Block() for _ in range(5)])

    def sampleFromZDist(self, z):
        """Draw ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Args:
            z: pair ``(mean, logvar)`` of equally shaped tensors.

        Returns:
            A new tensor of the same shape; the inputs are not modified.
        """
        mean, logstddev = z
        std = torch.exp(0.5 * logstddev)
        # The noise is a constant with respect to the parameters, so it must
        # not track gradients; they flow through mean and std only.
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self,x, z_eval=None):
        """Predict one occupancy logit per input point.

        Args:
            x: points, 4-D with 3 channels (``fc1`` is ``Conv2d(3, 256)``).
            z_eval: latent code used in eval mode only, shape ``(n, 128, 1)``
                or ``(n, 128)``. When ``None`` in eval mode a code is sampled
                from a standard normal. Ignored in training mode.

        Returns:
            ``(logits, z_dist)``. ``logits`` is the raw ``fc2`` output (no
            sigmoid applied). ``z_dist`` is the encoder output
            ``(mean, logvar)`` in training mode and the placeholder ``(0, 1)``
            in eval mode.
        """
        if self.training:
            z_dist = self.encoderModel(x)
            z = self.sampleFromZDist(z_dist).unsqueeze(-1)
        else:
            z_dist = (0,1)
            if z_eval is None:
                # No code supplied: sample one per input set from the prior.
                z = torch.randn(x.size(0), self.gammaLayer.in_channels, 1,
                                device=x.device, dtype=x.dtype)
            elif z_eval.dim() == 2:
                # Accept (n, latent) as well as (n, latent, 1).
                z = z_eval.unsqueeze(-1)
            else:
                z = z_eval
        x = self.fc1(x)
        #5 pre-activation ResNet-blocks
        x = self.blocks({'enc':z, 'ex':x })['ex']

        #CBN: gamma/beta are (n, 256, 1); the extra axis broadcasts them over
        #all points of a set instead of stacking one copy per point
        gamma = self.gammaLayer(z).unsqueeze(2)
        beta = self.betaLayer(z).unsqueeze(2)

        x = F.relu(gamma * self.cbn(x) + beta)
        x = self.fc2(x)
        return x, z_dist

In [None]:
# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed over
# the parameters in their final location.
om = OccupancyModel().to(device)
optimizer = optim.Adam(om.parameters(), lr=0.001)
om.train()

In [8]:
#choose a category and load all of the available data:
import os
import random
# Root of the occupancy_networks checkout. Set the OCCUPANCY_NETWORKS_ROOT
# environment variable to run on another machine; the original path is the default.
topdir = os.environ.get(
    "OCCUPANCY_NETWORKS_ROOT",
    "/home/andrea/Documents/GradSchool/OccupancyNetworks/occupancy_networks",
)

#One DataSetClass per subdirectory in a category; each item is a set of "K"
#consecutive points together with their occupancy labels
class DataSetClass(torch.utils.data.Dataset):
    """Serve fixed-size chunks of the points stored in ``<d>/points.npz``.

    The archive must contain ``points`` and bit-packed ``occupancies``. The
    points are cut into consecutive chunks of ``K`` points; a trailing
    remainder shorter than ``K`` is dropped.

    Args:
        d: directory containing ``points.npz``.
        k: points per item. Defaults to the notebook-level constant ``K``.
    """

    def __init__(self, d, k=None):
        self.dir = d
        with numpy.load(f"{d}/points.npz") as data:
            self.pts = torch.tensor(data["points"], dtype=torch.float)
            # unpackbits yields 8 values per stored byte, so trim the result
            # to one label per point.
            bits = numpy.unpackbits(data["occupancies"])[:self.pts.size()[0]]
            self.occupancies = torch.tensor(bits, dtype=torch.float)
        self.K = K if k is None else k
        if self.K <= 0:
            raise ValueError(f"points per item must be positive, got {self.K}")
        self.length = self.occupancies.size()[0] // self.K

    def __len__(self):
        return self.length

    def __getitem__(self,idx):
        """Return ``(points, occupancies)`` for chunk ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Without this, slicing past the end silently returns empty
                tensors and plain iteration over the dataset never stops.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} out of range for {self.length} items")
        start = idx * self.K
        stop = start + self.K
        return self.pts[start:stop], self.occupancies[start:stop]

       
#catalogue all of the directories with the chosen category
couchesDirectory=f"{topdir}/data/ShapeNet/04256520"
#couchesDirectory=f"{topdir}/data/ShapeNet/02828884"

def _read_split_dirs(list_name):
    """Return the directories listed in ``<couchesDirectory>/<list_name>``.

    Blank lines (for example a trailing newline at the end of the list) are
    skipped; otherwise they would become a bogus directory path.
    """
    with io.open(f"{couchesDirectory}/{list_name}") as listing:
        names = (line.strip() for line in listing)
        return [f"{couchesDirectory}/{name}" for name in names if name]

def _load_split(list_name):
    """Return ``(directories, per-directory datasets, concatenated dataset)`` for a split."""
    dirs = _read_split_dirs(list_name)
    datasets = [DataSetClass(split_dir) for split_dir in dirs]
    return dirs, datasets, torch.utils.data.ConcatDataset(datasets)

trainingDirs, dataSets, data = _load_split("train.lst")
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
#Get the validation data
valDirs, dataSets, val_data = _load_split("val.lst")
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)

In [9]:
def train(epoch, model, trainloader, optimizer, checkpoint_path="unconditional_model3.pth"):
    """Run one training epoch over ``trainloader``.

    The per-batch loss is the closed-form KL divergence between the encoder's
    Gaussian and a standard normal plus the summed binary cross-entropy of the
    predicted logits averaged over the points of a set. The total is averaged
    over the sets in the batch.

    Args:
        epoch: epoch number, used only in the progress message.
        model: module called as ``model(pts)`` returning
            ``(logits, (mu, log_var))``.
        trainloader: iterable of ``(pts, occupancies)`` batches with a
            ``len()``. Assumes ``pts`` is ``(B, K, 3)`` and ``occupancies`` is
            ``(B, K)``, which is what ``DataSetClass`` items collate to;
            confirm this if another dataset is used.
        optimizer: optimizer over ``model``'s parameters.
        checkpoint_path: file the ``state_dict`` is written to every 100
            batches (including batch 0).
    """
    decoderLoss = nn.BCEWithLogitsLoss(reduction='sum')
    # Send each batch to wherever the model lives instead of assuming CUDA.
    model_device = next(model.parameters()).device

    model.train()
    n_batches = len(trainloader)
    # Iterate the loader that was passed in, not the notebook-level global.
    for batch_idx, (pts, occupancies) in enumerate(trainloader):
        #Each batch contains B sets of "K" points: (B, K, 3) -> (B, 3, K, 1)
        pts = pts.unsqueeze(-1).permute(0,2,1,3).to(model_device)
        occupancies = occupancies.unsqueeze(-1).to(model_device)
        n_sets, _, n_points, _ = pts.size()
        optimizer.zero_grad()

        pred, (mu, log_var) = model(pts) #a logit for each point, and the parameters of the latent distribution
        pred = pred.permute(0,2,1,3).squeeze(-1)

        # KL( N(mu, exp(log_var)) || N(0, I) ), summed over the batch.
        encloss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        decloss = decoderLoss(pred,occupancies)

        # Normalise by the actual batch size so that a short final batch is
        # weighted correctly.
        loss = (encloss + decloss/n_points)/n_sets
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, n_batches, 100. * batch_idx / n_batches,
                loss.item()))
        if batch_idx % 100 == 0:
            print(f"Saving to {checkpoint_path}")
            torch.save(model.state_dict(), checkpoint_path)

In [None]:
# First epoch (numbered 0) on the training loader.
train(epoch=0, model=om, trainloader=train_loader, optimizer=optimizer)

In [None]:
# Continue training for epochs 1 through 9 with the same model and optimizer.
for ep in range(1, 10):
    train(epoch=ep, model=om, trainloader=train_loader, optimizer=optimizer)

In [None]:
# Rebuild the network and restore the weights saved during training.
model = OccupancyModel()
model.load_state_dict(torch.load("unconditional_model3.pth", map_location=device))
# Place the model on the same device the checkpoint was mapped to, rather than
# hard-coding CUDA after loading with map_location=device.
model.to(device)
model.eval()

In [45]:
from torch.autograd import Variable
def validation(model, val_loader, threshold=0.6, latent_dim=128):
    """Evaluate ``model`` on ``val_loader`` with latent codes drawn from N(0, I).

    Args:
        model: module called as ``model(pts, z)`` returning ``(logits, _)``.
            It is switched to eval mode and left there.
        val_loader: iterable of ``(pts, occupancies)`` batches. Assumes
            ``pts`` is ``(B, K, 3)`` and ``occupancies`` is ``(B, K)``, which
            is what ``DataSetClass`` items collate to; confirm this if another
            dataset is used.
        threshold: probability above which a point is predicted occupied.
        latent_dim: size of the sampled latent code.

    Returns:
        ``(average_loss, accuracy)``: the mean per-batch BCE loss and the
        fraction of correctly classified points over the whole loader. Both
        are also printed once.
    """
    model.eval()
    decoderLoss = nn.BCEWithLogitsLoss(reduction='mean')
    # Send each batch to wherever the model lives instead of assuming CUDA.
    model_device = next(model.parameters()).device

    validation_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    with torch.no_grad():
        for pts, occupancies in val_loader:
            # (B, K, 3) -> (B, 3, K, 1)
            pts = pts.unsqueeze(-1).permute(0,2,1,3).to(model_device)
            occupancies = occupancies.unsqueeze(-1).to(model_device)
            z = torch.randn(pts.size(0), latent_dim, 1, device=model_device)
            logits, _ = model(pts, z)
            logits = logits.permute(0,2,1,3).squeeze(-1)

            # BCEWithLogitsLoss applies the sigmoid itself, so it must be fed
            # the raw logits; probabilities are only needed for thresholding.
            validation_loss += decoderLoss(logits, occupancies).item()

            predicted = (torch.sigmoid(logits) > threshold).to(occupancies.dtype)
            correct += int(predicted.eq(occupancies).sum().item())
            total += occupancies.numel()
            n_batches += 1

    # Average once, after the loop: each term is already a per-batch mean.
    validation_loss = validation_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        validation_loss, correct, total, 100. * accuracy))
    return validation_loss, accuracy

In [None]:
# Evaluate the restored model on the validation loader.
validation(model=model, val_loader=val_loader)