In [11]:
## Complete training and testing function for your 3D Voxel GAN and have fun making pottery art!

import numpy as np
import torch
from torch import optim
from torch.utils import data
from torch import nn
from utils.FragmentDataset import FragmentDataset
import click
from utils.model_utils import *
import argparse
from test import *

class Discriminator(torch.nn.Module):
    def __init__(self, n_out, resolution=64):
        super(Discriminator, self).__init__()
        # initialize superior inherited class, necessary hyperparams and modules
        # You may use torch.nn.Conv3d(), torch.nn.sequential(), torch.nn.BatchNorm3d() for blocks
        # You may try different activation functions such as ReLU or LeakyReLU.
        # REMEMBER YOU ARE WRITING A DISCRIMINATOR (binary classification) so Sigmoid
        self.scale = resolution // 32
        self.model = nn.Sequential(
            nn.Conv3d(1, 32, 5, 1, 2),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2),
            nn.Conv3d(32, 32, 3, 2, 1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2),
            nn.Conv3d(32, 64, 3, 2, 1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2),
            nn.Conv3d(64, 128, 3, 2, 1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2),
            nn.Conv3d(128, 256, 3, 2, 1),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(self.scale * self.scale * self.scale * 256, n_out),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Try to connect all modules to make the model operational!
        # Note that the shape of x may need adjustment
        # # Do not forget the batch size in x.dim
        # TODO
        out = self.model(x)
        return out


class Generator(torch.nn.Module):
    # TODO
    def __init__(self, n_labels, cube_len=64, z_latent_space=64, z_intern_space=64, device='cuda'):
        super(Generator, self).__init__()
        self.resolution = cube_len // 32
        self.scale = (cube_len // 32) ** 3 * 1024  # dimensions of the final convolution result
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 32, 5, 1, 2),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2),
            nn.Conv3d(32, 32, 3, 2, 1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2),
            nn.Conv3d(32, 64, 3, 2, 1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2),
            nn.Conv3d(64, 128, 3, 2, 1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2),
            nn.Conv3d(128, 256, 3, 2, 1),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2)
        )
        self.embedding = nn.Sequential(
            nn.Embedding(n_labels, 64),
            nn.Flatten(),
            nn.Linear(64 * n_labels, 1024),
            nn.LeakyReLU(0.2)
        )
        self.flatten = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.scale, z_latent_space)
        )
        self.cat = lambda x, y: torch.cat((x, y), dim=1)

        self.fc1 = nn.Linear(self.scale, z_latent_space)
        self.fc2 = nn.Linear(self.scale, z_latent_space)  # 1 and 2 for VI method
        self.restore = nn.Sequential(
            nn.Linear(z_latent_space + n_labels, z_latent_space),
            nn.Linear(z_latent_space, self.scale)
        )  # restoration of the mix layer ready for deconvolution
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(256, 128, 3, 1, 1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, 3, 2, 1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, 3, 2, 1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 32, 5, 2, 2),

        )
        self.device = device

    def reparameterize(self, mean, logvar):
        eps = torch.randn(mean.shape).to(self.device)
        z = mean + eps * torch.exp(logvar)
        return z

    def forward_encode(self, x):
        z = self.encoder(x) # 2*2*2*256 (for 64)
        mean = self.fc1(z.view(z.shape[0], -1))
        logvar = self.fc2(z.view(z.shape[0], -1))
        y = self.embedding(x)  # labels embedding layer
        mix = self.flatten(z)
        mix = self.cat(mix, y)
        mix = self.restore(mix).view(-1, self.resolution, self.resolution, self.resolution, 256)
        z = self.reparameterize(mean, logvar)
        z = self.cat(mix, z)
        return z, mean, logvar

    def forward_decode(self, x):
        out = self.decoder(x)
        return out

    def forward(self, x):
        z = self.forward_encode(x)
        out = self.forward_decode(z)
        return out


def CVAE_loss(z, x, mean, logstd, ratio):
    MSEcriterion = nn.MSELoss().to(available_device)
    mse = MSEcriterion(x, z)
    var = torch.pow(torch.exp(logstd), 2)
    kld = -0.5 * torch.sum(1 + torch.log(var) - torch.pow(mean, 2) - var)
    return mse + kld * ratio


def main():
    # Here is a simple demonstration argparse, you may customize your own implementations, and
    # your hyperparam list MAY INCLUDE:
    # 1. Z_latent_space
    # 2. G_lr
    # 3. D_lr  (learning rate for Discriminator)
    # 4. betas if you are going to use Adam optimizer
    # 5. Resolution for input data
    # 6. Training Epochs
    # 7. Test per epoch
    # 8. Batch Size
    # 9. Dataset Dir
    # 10. Load / Save model Device
    # 11. test result save dir
    # 12. device!
    # .... (maybe there exists more hyperparams to be appointed)
    epochs = 100
    G_lr = 2e-3
    D_lr = 2e-4
    C_lr = 2e-4
    optimizer = 'ADAM'
    beta1 = 0.9
    beta2 = 0.999
    batch_size = 64  # modify according to device capability
    n_labels = 11
    resolution = 32
    z_latent_space = 1024
    log_interval = 100
    vi_ratio = 1

    dirdataset = "../VoxPottery"
    available_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, help='training/testing')
    parser.add_argument('-r', type=int, help='resolution')
    args = parser.parse_known_args()[0]

    ### Initialize train and test dataset
    dtrain = FragmentDataset(dirdataset, 'train', resolution)
    dtest = FragmentDataset(dirdataset, 'test', resolution)
    print("Data initialized")

    ### Initialize Generator and Discriminator to specific device
    G = Generator(n_labels, resolution, z_latent_space).to(available_device)
    D = Discriminator(1, resolution).to(available_device)
    C = Discriminator(n_labels, resolution).to(available_device)
    optimG = optim.Adam(G.parameters(), G_lr, (beta1, beta2))
    optimD = optim.Adam(D.parameters(), D_lr, (beta1, beta2))
    optimC = optim.Adam(C.parameters(), C_lr, (beta1, beta2))
    print("VAE initialized")

    ### Call dataloader for train and test dataset
    trainloader = torch.utils.data.DataLoader(dtrain, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(dtest, batch_size=batch_size, shuffle=False, num_workers=2)

    ### Implement GAN Loss!!
    # TODO
    criterion = nn.BCELoss().to(available_device)  # BCE loss
    # loss_function = 'BCE'

    ### Training Loop implementation
    ### You can refer to other papers / github repos for training a GAN
    # TODO
    print("Start training")
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            frg, vox, label = data
            vox = vox.to(available_device)
            frg = frg.to(available_device)
            whole = vox + frg
            label_onehot = torch.zeros((vox.shape[0], n_labels)).to(available_device)
            label_onehot[torch.arange(vox.shape[0]), label] = 1
            # train classifier on 11 types of ceramics (prepare for conditional GAN)
            out = C(whole)
            truth = label_onehot.to(available_device)
            lossC = criterion(out, truth)
            C.zero_grad()
            lossC.backward()
            optimC.step()
            # train Discriminator
            out = D(whole)
            real_label = torch.ones(batch_size).to(available_device)  # real pieces labelled 1
            fake_label = torch.zeros(batch_size).to(available_device)  # fake pieces labelled 0
            lossD_real = criterion(out, real_label)

            z = torch.randn(batch_size, z_latent_space + n_labels).to(available_device)
            fake_data = G.forward_decode(z)+vox
            out = D(fake_data)
            lossD_fake = criterion(out, fake_label)

            lossD = lossD_real + lossD_fake
            D.zero_grad()
            lossD.backward()
            optimD.step()
            # train Generator
            z, mean, logstd = G.forward_encode(vox)
            recon_data = G.forward_decode(z)
            lossG_var_completion = CVAE_loss(recon_data, vox, mean, logstd, vi_ratio)
            out = D(recon_data+vox)
            truth = torch.ones(batch_size).to(available_device)
            lossG_dis = criterion(out, truth)
            out = C(recon_data+vox)
            truth = label_onehot
            lossG_condition = criterion(out, truth)
            G.zero_grad()
            lossG = lossG_var_completion + lossG_dis + lossG_condition
            lossG.backward()
            optimG.step()
            if i % log_interval == 0:
                print("i =", i)
                # test()


if __name__ == "__main__":
    main()


Data initialized
VAE initialized
Start training


RuntimeError: Given groups=1, weight of size [32, 1, 5, 5, 5], expected input[1, 64, 32, 32, 32] to have 1 channels, but got 64 channels instead

In [21]:
import torch
import torch.nn.functional as F

# Define tensor a
a = torch.tensor([0.4835, 0.5407, 0.4840, 0.5245, 0.5569, 0.6447, 0.5125, 0.4295, 0.5018,
        0.6360, 0.5395, 0.5377, 0.5086, 0.5101, 0.4061, 0.5847])

# Define tensor b
b = torch.zeros(16)  # Create a tensor of zeros with dimension 16

# Convert tensor b to one-hot encoding
b_one_hot = F.one_hot(b.to(torch.int64), num_classes=2)  # Assuming 2 classes, change num_classes accordingly

# Calculate cross-entropy loss
loss = F.binary_cross_entropy(torch.sigmoid(a), b_one_hot.float())

print(loss.item())


ValueError: Using a target size (torch.Size([16, 2])) that is different to the input size (torch.Size([16])) is deprecated. Please ensure they have the same size.