In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import binvox_rw

In [45]:
device = "cpu"
# if torch.backends.mps.is_available():
#     device = torch.device("mps")
device

'cpu'

In [48]:
class Generator(nn.Module):
    def __init__(self,cube_length = 32,latent_dimension = 200):
        super(Generator,self).__init__()
        self.cube_len = cube_length
        self.latent_dim = latent_dimension
        self.init_size = cube_length//16

        padd = (1,1,1) 
        self.lin=nn.Linear(in_features=self.latent_dim,out_features=256*self.init_size*self.init_size*self.init_size )

        self.layer1 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len*8, self.cube_len*4, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len*4),
            nn.ReLU()
        )
        # self.layer2 = nn.Sequential(
        #     nn.ConvTranspose3d(self.cube_len*8, self.cube_len*4, kernel_size=4, stride=2, bias=False, padding=padd),
        #     nn.BatchNorm3d(self.cube_len*4),
        #     nn.ReLU()
        # )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len*4, self.cube_len*2, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len*2),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len*2, self.cube_len, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len),
            nn.ReLU()
        )
        self.layer4 = nn.Sequential(
            nn.ConvTranspose3d(self.cube_len, self.cube_len//2, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len//2),
            nn.ReLU()
        )
        self.layer5 = nn.Sequential(
            # nn.ConvTranspose3d(self.cube_len, 1, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.Conv3d(self.cube_len//2, 1, kernel_size=3, stride=1,padding=1),
            nn.Tanh()
        )
    def forward(self, x):
        out = self.lin(x)
        # print("layer 1 before view",out.shape)

        obj=out.view(x.shape[0],256,self.init_size,self.init_size,self.init_size)
        # print("layer 1 after liner and view",obj.shape)
        # [batch,latent,1,1,1]
        obj = self.layer1(obj)
        # print("after layer 1",obj.shape)

        # [batch,256,2,2,2]
        obj = self.layer2(obj)
        # print("layer 2 after liner and view",obj.shape)

        # [batch,128,4,4,4]
        obj = self.layer3(obj)
        # print("layer 3 after liner and view",obj.shape)

        # [batch,64,8,8,8]
        obj = self.layer4(obj)
        # print("layer 4 after liner and view",obj.shape)

        # [batch,32,16,16,16]
        obj = self.layer5(obj)
        # print("layer 5 after liner and view",obj.shape)

        # [batch,1,32,32,32]

        return obj
    
class Discriminator(nn.Module):
    def __init__(self,cube_length = 32):
        super(Discriminator,self).__init__()
        self.cube_len = cube_length

        padd = (1,1,1)

        self.layer1 = nn.Sequential(
            nn.Conv3d(1, self.cube_len, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len),
            nn.LeakyReLU(0.01)
        )
        #16
        self.layer2 = nn.Sequential(
            nn.Conv3d(self.cube_len, self.cube_len*2, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len*2),
            nn.LeakyReLU(0.01)
        )
        #8
        self.layer3 = nn.Sequential(
            nn.Conv3d(self.cube_len*2, self.cube_len*4, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len*4),
            nn.LeakyReLU(0.01)
        )
        #4
        self.layer4 = nn.Sequential(
            nn.Conv3d(self.cube_len*4, self.cube_len*8, kernel_size=4, stride=2, bias=False, padding=padd),
            nn.BatchNorm3d(self.cube_len*8),
            nn.LeakyReLU(0.01)
        )
        #2
        # self.layer5 = nn.Sequential(
        #     nn.Conv3d(self.cube_len*8, 1, kernel_size=4, stride=2, bias=False, padding=padd),
        #     nn.Sigmoid()
        # )
        self.out=nn.Sequential(nn.Linear(256*((self.cube_len//(2**4))**3),2))

    def forward(self, x):
        # print(x.shape,"x")
        out = x.view(-1, 1, self.cube_len, self.cube_len,self.cube_len)
        # [batch,1,32,32,32]
        # print(out.size(),"disc l1 inp")

        out = self.layer1(out)
        # [batch,32,16,16,16]
        # print(out.size(),"disc l1 out")

        out = self.layer2(out)
        # [batch,64,8,8,8]
        # print(out.size(),"disc l2 out")

        out = self.layer3(out)
        # [batch,128,4,4,4]
        # print(out.size(),"disc l3 out")

        out = self.layer4(out)
        # [batch,256,2,2,2]
        # print(out.size(),"disc l4 out")

        # out = self.layer5(out)
        # print(out.size(),"disc l5 out")
        out=out.view(out.shape[0],-1)

        # [batch,200,1,1,1]
        pred=self.out(out)


        return pred


In [50]:
def readBinvox3DObject(path):
    with open(path, 'rb') as file:
        data = np.float32(binvox_rw.read_as_3d_array(file).data)
    return data


transform=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5],[0.5])
    ]
)

class ShapeNetDataset(Dataset):
    """Custom Dataset compatible with torch.utils.data.DataLoader"""

    def __init__(self, data_path,transform):
        """Set the path for Data.

        Args:
            root: image directory.
            transform: Tensor transformer.
        """
        self.data_path = data_path
        self.listdir = os.listdir(self.data_path)
        self.transform = transform


    def __getitem__(self, index):
        
        model_3d_file = [name for name in self.listdir if name.endswith('.' + "binvox")][index]
        # print(self.data_path + model_3d_file)
        volume = readBinvox3DObject(self.data_path + model_3d_file)
        vol = self.transform(volume)
        # return torch.FloatTensor(volume)
        return torch.FloatTensor(vol)

    def __len__(self):
        return len([name for name in self.listdir if name.endswith('.' + "binvox")])

In [51]:
dsets_path = "data/train/"
print(dsets_path)
train_dataset = ShapeNetDataset(dsets_path,transform)
print(train_dataset[0].shape,type(train_dataset[0]))
torch.unique(train_dataset[1])

data/train/
torch.Size([32, 32, 32]) <class 'torch.Tensor'>


tensor([-1.,  1.])

In [52]:
train_dl = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

In [53]:
beta1,beta2 = 0.9,0.99
d_lr = float(0.0002)
g_lr = float(0.0002)
n_epochs = 1000
batch_size = 32
d_thresh = 0.75

In [54]:
disc = Discriminator(cube_length=32).to(device)
genr = Generator(cube_length=32,latent_dimension=200).to(device)

criterion = nn.CrossEntropyLoss().to(device)

D_optim = optim.Adam(disc.parameters(), lr=d_lr, betas=[beta1,beta2])
G_optim = optim.Adam(genr.parameters(), lr=g_lr, betas=[beta1,beta2])

In [55]:
import torch

def generateZ(z_dis,latent_dim,batch_size):
    if z_dis == "norm":
        Z = torch.normal(mean=0, std=0.33, size=(batch_size, latent_dim))
    elif z_dis == "uni":
        Z = torch.randn(size=(batch_size, latent_dim))  
    elif z_dis == "nothing":
        Z = torch.randn(size=(batch_size, latent_dim))
    else:
        print("z_dis is not normal or uniform")
        Z = None  #

    return Z.to(device)

In [56]:
import matplotlib.pyplot as plt
import os
import torch
import matplotlib.gridspec as gridspec

def plot_voxel(ax, voxels, title):
    # Move tensor to CPU and detach for plotting
    
    voxels = torch.Tensor(voxels)
    # Rotate to make the chair upright and face the viewer
    voxels = torch.rot90(voxels, k=1, dims=[0, 2])  # Adjust rotation based on orientation
    voxels = torch.rot90(voxels, k=-1, dims=[0, 1])  # Rotate to face the viewer

    # Apply threshold to get a boolean tensor
    voxels = (voxels >= 0)

    # Plot the voxels
    ax.voxels(voxels, edgecolor='k')
    ax.set_title(title)
    # ax.set_xticks([])
    # ax.set_yticks([])
    # ax.set_zticks([])

def SavePloat_Voxels(original_data, predicted_data, epoch):
    save_dir = 'binvox_images'
    os.makedirs(save_dir, exist_ok=True)

    # predicted_data = predicted_data.__ge__(0.5)
    # original_data = original_data.__ge__(0.5)

    fig = plt.figure(figsize=(12, 6))  # Adjusted figure size for two plots
    gs = gridspec.GridSpec(1, 2)  # 1 row, 2 columns for side-by-side plots
    gs.update(wspace=0.05, hspace=0.05)

    # Plot original data on the left
    ax1 = fig.add_subplot(gs[0], projection='3d')
    plot_voxel(ax1, original_data, title="Original")

    # Plot predicted data on the right
    ax2 = fig.add_subplot(gs[1], projection='3d')
    plot_voxel(ax2, predicted_data, title="Predicted")

    # Save the figure
    plt.savefig(os.path.join(save_dir, f'epoch_{epoch}_comparison.png'), bbox_inches='tight')
    plt.close()

# SavePloat_Voxels(torch.randn((1,32,32,32)),torch.randn((1,32,32,32)),1)
# SavePloat_Voxels(train_dataset[0],train_dataset[1],1)

In [63]:

def train_one_epoch(e,D_model,G_model,data_loader,criterion,D_optim,G_optim):
    D_model.train()
    G_model.train()
    G_loss = 0
    D_loss = 0
    for i,X in enumerate(data_loader):
        X = X.to(device)

        if X.size()[0] != int(batch_size):
            continue        # This is for last batch where number of image is less

        Z = generateZ("nothing",200,batch_size=batch_size).to(device)

        real_labels = torch.ones(batch_size).to(device).type(torch.int64)
        fake_labels = torch.zeros(batch_size).to(device).type(torch.int64)

# --------------------------------- Train the generator ---------------------------------#

        fake = G_model(Z)

        d_fake_pred = D_model(fake)
        
        G_loss = criterion(d_fake_pred, real_labels)

        G_optim.zero_grad()
        G_loss.backward()
        G_optim.step()
        
# --------------------------------- Train the discriminator ------------------------------#
        # [20,32,32,32]
        d_real_pred = D_model(X)
        d_real_loss = criterion(d_real_pred, real_labels)

        # fake = G_model(Z)
        d_fake_pred = D_model(fake.detach())
        d_fake_loss = criterion(d_fake_pred, fake_labels)
        
        D_loss = (d_real_loss + d_fake_loss)/2

        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()
       

        if (e +1)% 10 == 0 and i == 0:  # Save for the first batch of each epoch
            # Convert tensors to binvox compatible format
            original_image = X[0].cpu().detach()  # Original image (first example)
            predicted_image = fake[0][0].cpu().detach() # Predicted image (first example)
            # print(type(predicted_image),predicted_image)
            # predicted_image = np.where(predicted_image > 0, 1, 0)
            # Save the binvox images side by side
            SavePloat_Voxels(original_image, predicted_image, e+1)

    return D_loss,G_loss


In [64]:
for e in range(n_epochs):
    train_d_loss,train_g_loss = train_one_epoch(e,disc,genr,train_dl,criterion,D_optim,G_optim)
    print(f"--------Epoch {e+1} -------- \n -------- discriminator_loss : {train_d_loss} --------\n Generator_loss : {train_g_loss} --------")

--------Epoch 1 -------- 
 -------- discriminator_loss : 0.31481003761291504 --------
 Generator_loss : 0.8487815856933594 --------
--------Epoch 2 -------- 
 -------- discriminator_loss : 0.12524794042110443 --------
 Generator_loss : 6.759894847869873 --------
--------Epoch 3 -------- 
 -------- discriminator_loss : 0.04349260777235031 --------
 Generator_loss : 2.7754642963409424 --------
--------Epoch 4 -------- 
 -------- discriminator_loss : 0.011318733915686607 --------
 Generator_loss : 4.474286079406738 --------
--------Epoch 5 -------- 
 -------- discriminator_loss : 0.006232726853340864 --------
 Generator_loss : 7.586093425750732 --------
--------Epoch 6 -------- 
 -------- discriminator_loss : 0.0076972669921815395 --------
 Generator_loss : 6.3617706298828125 --------
--------Epoch 7 -------- 
 -------- discriminator_loss : 0.026813244447112083 --------
 Generator_loss : 3.4138689041137695 --------
--------Epoch 8 -------- 
 -------- discriminator_loss : 0.004392439499497