In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# The encoder network for single-view 3D reconstruction is a ResNet18 pretrained
# on ImageNet. Its 1000-way output layer is kept, and a separate Linear layer
# (fc_enc in OccupancyModel) projects those features to a 256-dimensional embedding, "c".
from torchvision.models.resnet import resnet18 as _resnet18
import numpy
from PIL import Image
import io

In [3]:
# Experiment configuration shared by the cells below.
import random

K = 10            # number of point samples grouped with each image draw (see DataSetClass)
batch_size = 20   # dataset items per training batch; each item carries K points
# Fall back to CPU when CUDA is unavailable instead of failing later.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs used below (weight init / torch.rand, and `random` for image
# selection) so runs are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
class Block(nn.Module):
    """Pre-activation residual block with Conditional Batch Normalization (CBN).

    The block consumes and produces a dict so it can be chained inside
    ``nn.Sequential``:

    * ``'ex'``  -- point features, processed by Conv1d(256, 256, kernel_size=1)
      layers (per-point fully-connected layers).
    * ``'enc'`` -- conditioning embedding; passed through unchanged and mapped
      to per-channel CBN scales (gamma) and shifts (beta).

    Returns ``{'ex': ex + residual, 'enc': enc}``.
    """

    def __init__(self):
        super().__init__()
        width = 256
        # Per-point fully-connected layers, written as 1x1 convolutions.
        self.fc1 = nn.Conv1d(width, width, kernel_size=1)
        self.fc2 = nn.Conv1d(width, width, kernel_size=1)
        # Non-affine batch norms: the affine part comes from the conditioning code.
        self.bn1 = nn.BatchNorm1d(width, affine=False, track_running_stats=True)
        self.bn2 = nn.BatchNorm1d(width, affine=False, track_running_stats=True)
        # Projections of the conditioning code to CBN scale (gamma) and shift (beta).
        self.gammaLayer1 = nn.Conv1d(width, width, kernel_size=1)
        self.gammaLayer2 = nn.Conv1d(width, width, kernel_size=1)
        self.betaLayer1 = nn.Conv1d(width, width, kernel_size=1)
        self.betaLayer2 = nn.Conv1d(width, width, kernel_size=1)

    def forward(self, y):
        features, code = y['ex'], y['enc']

        # Stage 1: CBN -> ReLU -> fully-connected.
        hidden = self.fc1(F.relu(
            self.gammaLayer1(code) * self.bn1(features) + self.betaLayer1(code)))
        # Stage 2: CBN -> ReLU -> fully-connected.
        residual = self.fc2(F.relu(
            self.gammaLayer2(code) * self.bn2(hidden) + self.betaLayer2(code)))

        # Skip connection back to the block input; the code is forwarded as-is.
        return {'ex': features + residual, 'enc': code}

In [5]:
class OccupancyModel(nn.Module):
    """Occupancy network: maps 3-D points plus an image to occupancy probabilities.

    forward(x, img):
      x   -- point coordinates consumed by Conv1d(3, 256, kernel_size=1);
             presumably shaped (N, 3, 1) as in the demo cells -- TODO confirm.
      img -- image batch for the ResNet18 encoder; the cells in this notebook
             pass one image per point. The 256-d embedding is reshaped to
             (-1, 256, 1) and broadcast against the point features.
    Returns a tensor viewed as (-1, 1) holding sigmoid probabilities.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation consumes
        # the RNG in the same sequence.
        # Five conditioned residual blocks.
        self.blocks = self.makeBlocks()
        # ImageNet-pretrained ResNet18 (1000 outputs) plus a projection to the
        # 256-d conditioning embedding.
        self.encoderModel = _resnet18(pretrained=True)
        self.fc_enc = nn.Linear(1000, 256)
        # Final CBN: non-affine BN modulated by gamma/beta computed from the embedding.
        self.gammaLayer = nn.Conv1d(256, 256, kernel_size=1)
        self.betaLayer = nn.Conv1d(256, 256, kernel_size=1)
        self.cbn = nn.BatchNorm1d(256, affine=False, track_running_stats=True)
        # Lift 3-D coordinates to 256 channels; reduce 256 channels to one logit.
        self.fc1 = nn.Conv1d(3, 256, kernel_size=1)
        self.fc2 = nn.Conv1d(256, 1, kernel_size=1)

    def makeBlocks(self):
        """Return the five residual CBN blocks chained in an nn.Sequential."""
        return nn.Sequential(*[Block() for _ in range(5)])

    def forward(self, x, img):
        # Image -> ResNet18 outputs -> 256-d embedding, shaped for Conv1d.
        code = self.fc_enc(self.encoderModel(img)).view(-1, 256, 1)

        # Lift the points to 256 channels and run the 5 pre-activation ResNet blocks.
        features = self.blocks({'enc': code, 'ex': self.fc1(x)})['ex']

        # Final CBN -> ReLU -> fully-connected -> sigmoid.
        features = self.gammaLayer(code) * self.cbn(features) + self.betaLayer(code)
        logits = self.fc2(F.relu(features)).view(-1, 1)
        return torch.sigmoid(logits)

In [None]:
# Smoke test: the model takes a batch of 3-D coordinates plus one image per
# coordinate. The images go through the ResNet18 encoder inside the model.
model = OccupancyModel()
coords = torch.rand(20, 3, 1)
image = torch.rand(20, 3, 137, 137)
model.eval()

# Inference only: no autograd graph is needed for this check.
with torch.no_grad():
    p = model(coords, image)

# Expect one occupancy probability per coordinate.
assert p.shape == (coords.size(0), 1)
assert ((p >= 0) & (p <= 1)).all()


In [None]:
# Loss and optimiser for a single sanity-check step, run in the next cell.
# That step is meant as evidence that the encoder-side fully-connected
# layers actually receive weight updates.
model.train()
modelCriterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [None]:
# One optimisation step against random soft targets. First snapshot the
# encoder-side fully-connected weights so we can check that the step changed them.
encoder_fc_before = model.encoderModel.fc.weight.detach().clone()
fc_enc_before = model.fc_enc.weight.detach().clone()

# Clear gradients left over from a previous run of this cell; otherwise they accumulate.
optimizer.zero_grad()
output = model(coords, image)
print(f"out: {output}")
loss = modelCriterion(output, torch.rand(output.size()))
print(f"loss: {loss}")
loss.backward()
optimizer.step()

# The actual evidence: both fully-connected layers moved.
assert not torch.equal(encoder_fc_before, model.encoderModel.fc.weight), "encoder fc did not update"
assert not torch.equal(fc_enc_before, model.fc_enc.weight), "fc_enc did not update"

In [25]:
# Load a single ShapeNet sample for quick experiments.
# The .npz contains "points", "occupancies", "loc", "scale". The occupancies are
# bit-packed, so they are unpacked and truncated to one value per point.
sample_dir = ("/home/andrea/Documents/GradSchool/OccupancyNetworks/occupancy_networks/data/ShapeNet/"
              "02691156/fd528602cbde6f11bbf3143b1cb6076a")
with numpy.load(f"{sample_dir}/points.npz") as data:
    pts = torch.tensor(data["points"], dtype=torch.float)
    occupancies = torch.tensor(numpy.unpackbits(data["occupancies"])[:pts.size()[0]], dtype=torch.float)

# Image.open is lazy and keeps the file open; the context manager closes it.
with Image.open(f"{sample_dir}/img_choy2016/015.jpg") as img:
    image = numpy.array(img)
# At least for this image directory, the jpgs come in as (137, 137, 3): HWC -> CHW, plus a batch dim.
image = torch.tensor(image, dtype=torch.float).permute(2, 0, 1)
image = image.view(1, 3, 137, 137)

# NOTE: this loader yields (point, occupancy) pairs, not the
# (images, pts, occupancies) triples that train() unpacks. The dataset cell
# further down replaces train_loader.
train_loader = torch.utils.data.DataLoader(list(zip(pts, occupancies)), batch_size=batch_size)


In [6]:
def train(epoch, model, trainloader, optimizer, log_interval=10, save_interval=100,
          checkpoint_path="model1.pth"):
    """Run one training epoch of the occupancy model with binary cross-entropy.

    Args:
        epoch: epoch number; used only in the progress log.
        model: module called as ``model(pts, images)``; expected to return one
            probability per point, shape (N, 1).
        trainloader: DataLoader yielding ``(images, pts, occupancies)`` batches
            (as produced by DataSetClass): images (B, K, C, H, W), pts (B, K, 3),
            occupancies (B, K).
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the loss every this many batches.
        save_interval: save ``model.state_dict()`` every this many batches.
        checkpoint_path: file the periodic checkpoint is written to.
    """
    criterion = nn.BCELoss()
    model.train()
    # Move batches to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device
    n_batches = len(trainloader)
    n_items = len(trainloader.dataset)
    items_seen = 0
    for batch_idx, (images, pts, occupancies) in enumerate(trainloader):
        batch_items = images.size(0)
        # Each dataset item holds K points (and K copies of its image). Collapse
        # the (batch, K) leading dims into one. Using -1 rather than
        # batch_size*K keeps the final, possibly smaller, batch working.
        images = images.reshape(-1, *images.shape[2:]).to(device)
        pts = pts.reshape(-1, 3, 1).to(device)
        occupancies = occupancies.reshape(-1, 1).to(device)

        optimizer.zero_grad()
        output = model(pts, images)  # a probability for each point
        loss = criterion(output, occupancies)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, items_seen, n_items, 100. * batch_idx / n_batches, loss.item()))
        if batch_idx % save_interval == 0:
            print(f"Saving to {checkpoint_path}")
            torch.save(model.state_dict(), checkpoint_path)
        items_seen += batch_items

In [7]:
def validation(model, val_loader, threshold=0.2):
    """Evaluate the occupancy model on a validation loader.

    Args:
        model: module called as ``model(pts, images)``; expected to return one
            probability per point, shape (N, 1).
        val_loader: DataLoader yielding ``(images, pts, occupancies)`` batches,
            shaped like those consumed by ``train``.
        threshold: probability at or above which a point is predicted occupied.

    Returns:
        (average_loss, accuracy): per-point mean BCE loss, and the fraction of
        points whose thresholded prediction matches the label. Both are NaN
        (and nothing is printed) if the loader yields no points.
    """
    criterion = nn.BCELoss()
    model.eval()
    device = next(model.parameters()).device
    total_loss = 0.0
    correct = 0
    total_points = 0
    with torch.no_grad():
        for images, pts, occupancies in val_loader:
            # Collapse (batch, K) into points. -1 copes with any loader batch
            # size (val_loader uses 5 items per batch, not batch_size).
            images = images.reshape(-1, *images.shape[2:]).to(device)
            pts = pts.reshape(-1, 3, 1).to(device)
            occupancies = occupancies.reshape(-1, 1).to(device)

            output = model(pts, images)
            n_points = occupancies.size(0)
            # BCELoss averages over the batch; weight by batch size so the
            # final figure is a true per-point mean.
            total_loss += criterion(output, occupancies).item() * n_points
            # Threshold element-wise. A Python `if` on a multi-element tensor raises.
            predicted = (output >= threshold).float()
            correct += int(predicted.eq(occupancies).sum().item())
            total_points += n_points

    if total_points == 0:
        return float('nan'), float('nan')
    average_loss = total_loss / total_points
    accuracy = correct / total_points
    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        average_loss, correct, total_points, 100. * accuracy))
    return average_loss, accuracy

In [8]:
#choose a category and load all of the available data:
import random
topdir = "/home/andrea/Documents/GradSchool/OccupancyNetworks/occupancy_networks"
# Rendered views used per object: 000.jpg-020.jpg plus 023.jpg (22 files).
# NOTE(review): 021.jpg and 022.jpg are not listed -- confirm whether that is
# deliberate (e.g. missing/bad renders) or an accidental omission.
imageFiles = [f"{view:03d}.jpg" for view in (*range(21), 23)]

# One DataSetClass per object directory in a category. Each item is K point
# samples plus one image drawn at random from `imageFiles`, repeated K times.
class DataSetClass(torch.utils.data.Dataset):
    """Dataset over one ShapeNet object directory.

    Args:
        d: object directory containing ``points.npz`` and ``img_choy2016/``.
        k: points per item; defaults to the notebook-level ``K``.

    Item ``idx`` is ``(images, pts, occupancies)``:
      images      -- (k, 3, H, W) float, k identical copies of one random view;
      pts         -- rows ``idx*k`` .. ``idx*k + k - 1`` of the stored points,
                     presumably (k, 3) coordinates -- TODO confirm npz layout;
      occupancies -- (k,) float, unpacked 0/1 bits.
    """

    def __init__(self, d, k=None):
        self.dir = d
        with numpy.load(f"{d}/points.npz") as data:
            self.pts = torch.tensor(data["points"], dtype=torch.float)
            # Occupancies are bit-packed; unpack and drop the padding bits.
            n_points = self.pts.size(0)
            self.occupancies = torch.tensor(numpy.unpackbits(data["occupancies"])[:n_points],
                                            dtype=torch.float)
        self.K = K if k is None else k  # TODO how many sample points should come in?
        # Trailing points that do not fill a whole group of K are never served.
        self.length = self.occupancies.size(0) // self.K

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Pick a view at random as the observation for this group of K points.
        imageFile = random.choice(imageFiles)
        with Image.open(f"{self.dir}/img_choy2016/{imageFile}") as img:
            # Normalise every mode (greyscale, RGBA, palette, ...) to 3-channel RGB.
            # Greyscale becomes three identical channels, as stacking it did before.
            image = torch.tensor(numpy.array(img.convert("RGB")), dtype=torch.float)
        image = image.permute(2, 0, 1)  # HWC -> CHW
        # NOTE(review): pixels stay in [0, 255] with no ImageNet mean/std
        # normalisation, although the encoder is ImageNet-pretrained -- confirm intended.
        # One copy of the image per point, so images and points line up once
        # the (batch, K) dims are collapsed during training.
        images = image.unsqueeze(0).repeat(self.K, 1, 1, 1)
        # Points are served in file order, in consecutive groups of K.
        start, stop = idx * self.K, (idx + 1) * self.K
        return images, self.pts[start:stop], self.occupancies[start:stop]

       
# Catalogue the object directories of the chosen category for each split and
# wrap them in shuffled DataLoaders.
# NOTE(review): the variable is named tablesDirectory, but 02828884 is, I
# believe, ShapeNet's "bench" synset rather than "table" -- confirm the category.
tablesDirectory = f"{topdir}/data/ShapeNet/02828884"


def _split_dirs(list_name):
    """Return the object directories listed one per line in ``tablesDirectory/list_name``.

    Blank lines are skipped. Otherwise they would yield the bogus path
    ``tablesDirectory + "/"``, and DataSetClass would fail on it.
    """
    with open(f"{tablesDirectory}/{list_name}") as listing:
        return [f"{tablesDirectory}/{name.strip()}" for name in listing if name.strip()]


def _concat_split(dirs):
    """Build one DataSetClass per directory and concatenate them into one dataset."""
    return torch.utils.data.ConcatDataset([DataSetClass(dir_) for dir_ in dirs])


trainingDirs = _split_dirs("train.lst")
data = _concat_split(trainingDirs)
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

# Validation split, with a smaller batch of 5 items.
valDirs = _split_dirs("val.lst")
val_data = _concat_split(valDirs)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=5, shuffle=True)

In [9]:
# Fresh model on the GPU, trained for a single epoch and validated after each epoch.
n_epochs = 1
model = OccupancyModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(n_epochs):
    train(epoch, model, train_loader, optimizer)
    validation(model, val_loader)

Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth
Saving to model1.pth


KeyboardInterrupt: 

In [None]:
# Persist the current weights by hand, e.g. after interrupting the training run.
checkpoint_state = model.state_dict()
torch.save(checkpoint_state, "model1.pth")

In [None]:
# TODO: switch to the cabinet category -- a cube-like object should give more occupied (1) labels.
# TODO: switch the loss to BCEWithLogitsLoss, remove the sigmoid from forward(), and fix validation
#       so that it applies a sigmoid to the outputs to get probabilities.
# Observation: increasing batch_size to 20 with shuffle=True seems to keep enough occupied (1) points in each batch?