In [2]:
import os
import sys
import tqdm
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
from pyntcloud import PyntCloud
from tensorboardX import SummaryWriter
from torch import autograd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
##
from src.autoencoder import AutoEncoder, PointcloudDatasetAE
from src.chamferloss import ChamferLoss_distance
from src.gan import GenSAGAN, DiscSAGAN

ModuleNotFoundError: No module named 'tqdm'

In [None]:
# Use CUDA when PyTorch can see a GPU; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
BATCH_SIZE = 20
LAMBDA = 1e1
use_cuda = torch.cuda.is_available()
def calc_gradient_penalty(netD, real_data, fake_data, weight=LAMBDA):
    """WGAN-GP gradient penalty on random interpolates between real and fake samples.

    Computes ``weight * mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)`` where
    ``x_hat = alpha * real + (1 - alpha) * fake`` with one ``alpha ~ U[0, 1)``
    drawn per sample.

    Args:
        netD: critic; called as ``netD(x)`` and must return a tuple whose first
            element is the critic score (the second element is ignored).
        real_data: batch of real samples; dimension 0 is the batch.
        fake_data: batch of generated samples, same shape as ``real_data``.
        weight: penalty coefficient (defaults to ``LAMBDA``).

    Returns:
        Scalar tensor with the weighted penalty, or ``None`` if computing it
        raised an exception (the error is printed). Callers should check
        ``is None`` rather than truthiness, since a zero tensor is falsy.
    """
    try:
        # Use the actual batch size, not BATCH_SIZE, so the final (possibly
        # smaller) batch of an epoch works instead of raising in expand().
        batch_size = real_data.size(0)
        # One mixing coefficient per sample, broadcast over every other dim.
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)
        interpolates = alpha * real_data + (1 - alpha) * fake_data
        # Fresh leaf tensor (what the legacy Variable(..., requires_grad=True)
        # produced): the penalty trains only the critic.
        interpolates = interpolates.detach().requires_grad_(True)
        disc_interpolates, _ = netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        # Per-sample L2 norm over all non-batch dims (same as dim=1 for 2-D input).
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * weight
    except Exception as err:
        # Keep the best-effort contract (return None) but report the cause and
        # never swallow KeyboardInterrupt / SystemExit like a bare except would.
        print("Err: {}".format(err))
        return None

In [None]:
DATA_DIR = "./data/shape_net_core_uniform_samples_2048/"
# List of point-cloud file entries; presumably paths relative to DATA_DIR
# (they are passed to PointcloudDatasetAE together with it) — confirm.
list_point_clouds = np.load('./data/filter/list_point_cloud_filepath.npy')
# Reproducible 90/10 split. A single array is enough: the shuffled indices
# depend only on the number of samples and random_state.
X_train, X_test = train_test_split(list_point_clouds, test_size=0.1, random_state=42)
print(len(X_train))

In [None]:
# Training loader: reshuffle every epoch so batch composition varies between
# epochs (with shuffle=False the GAN saw identical batches in identical order).
train_dataset = PointcloudDatasetAE(DATA_DIR, X_train)
train_dataloader = DataLoader(train_dataset, num_workers=2, shuffle=True, batch_size=BATCH_SIZE)

# Test loader: order is irrelevant for evaluation, keep it deterministic.
test_dataset = PointcloudDatasetAE(DATA_DIR, X_test)
test_dataloader = DataLoader(test_dataset, num_workers=2, shuffle=False, batch_size=1)

# Sanity check: peek at one training batch after swapping dims 1 and 2, the
# same permute the training loop applies before encoding.
data = next(iter(train_dataloader)).permute([0, 2, 1])
print(data.shape)

In [None]:
# Length of the Gaussian noise vector fed to the generator (also used when
# sampling z in the training loop and in test_model).
z_dim = 5

# GAN pair: the generator's output is decoded by the autoencoder, and the
# discriminator scores both generator outputs and encoded real clouds.
generator = GenSAGAN(z_dim=z_dim)
generator = generator.to(device)
discriminator = DiscSAGAN()
discriminator = discriminator.to(device)

# Autoencoder (weights loaded from a checkpoint later) and Chamfer distance.
# The 2048 argument is presumably the number of points per cloud (cf. the
# *_2048 data directory) — confirm against src.autoencoder / src.chamferloss.
autoencoder = AutoEncoder(2048)
autoencoder = autoencoder.to(device)
chamfer_loss = ChamferLoss_distance(2048)
chamfer_loss = chamfer_loss.to(device)

In [None]:
# Learning rates for the generator, the discriminator and the autoencoder.
g_lr = d_lr = lr = 1.0e-4
# NOTE(review): d_gp_weight is not referenced elsewhere in this notebook; the
# gradient-penalty weight actually applied is LAMBDA.
d_gp_weight = 1e1
# Adam beta1 for the autoencoder optimizer.
momentum = 0.95

# NOTE(review): optimizer_AE is never stepped in this notebook (the
# autoencoder is loaded pretrained and kept fixed).
optimizer_AE = torch.optim.Adam(params=autoencoder.parameters(), betas=(momentum, 0.999), lr=lr)
g_optim = torch.optim.Adam(params=generator.parameters(), lr=g_lr)
d_optim = torch.optim.Adam(params=discriminator.parameters(), lr=d_lr)

In [None]:
ROOT_DIR = './models/gan/'
# Run id: current timestamp plus latent size, e.g. '2022-08-06 15:19:12.904709z5'
# (same naming scheme as the autoencoder checkpoint directory loaded below).
now = str(datetime.datetime.now()) + 'z' + str(z_dim)

LOG_DIR = ROOT_DIR + now + '/logs/'
OUTPUTS_DIR = ROOT_DIR + now + '/outputs/'
MODEL_DIR = ROOT_DIR + now + '/models/'

# os.makedirs creates ROOT_DIR and the run directory as intermediate parents;
# exist_ok=True replaces the separate os.path.exists() checks (and their
# check-then-create race).
for run_subdir in (LOG_DIR, OUTPUTS_DIR, MODEL_DIR):
    os.makedirs(run_subdir, exist_ok=True)

summary_writer = SummaryWriter(LOG_DIR)

In [None]:
# Pretrained autoencoder checkpoint; its encoder/decoder define the latent space
# the GAN is trained in.
AE_CHECKPOINT = './models/autoencoder/2022-08-06 15:19:12.904709/models/14_ae_.pt'
# map_location=device lets a checkpoint saved on GPU load on a CPU-only machine
# (device falls back to 'cpu' when CUDA is unavailable).
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
autoencoder.load_state_dict(torch.load(AE_CHECKPOINT, map_location=device))

In [None]:
def test_model(generator, autoencoder, epoch, num_samples=5):
    """Decode random generator samples and save each as a 3-D scatter PNG.

    For each sample, draws ``z ~ N(0, I)`` of shape ``(1, z_dim)``, runs
    ``generator(z)`` and ``autoencoder.decode`` under ``torch.no_grad()``, and
    plots the decoded points inside the fixed box [0, 1]^3.

    Args:
        generator: module called as ``generator(z)``; the first returned element
            is passed to ``autoencoder.decode``.
        autoencoder: module providing ``decode``.
        epoch: used only in the output file names.
        num_samples: number of figures to produce (default 5, as before).

    Side effects:
        Writes ``{epoch}_{i}_out.png`` into ``OUTPUTS_DIR``. Both modules are
        switched to eval mode while sampling and restored to their previous
        train/eval mode afterwards.
    """
    prev_gen_training = generator.training
    prev_ae_training = autoencoder.training
    generator.eval()
    autoencoder.eval()
    try:
        for i in tqdm.trange(num_samples):
            z = torch.randn(1, z_dim, device=device)
            with torch.no_grad():
                gen_out, _ = generator(z)
                out_data = autoencoder.decode(gen_out)
            # Drop the batch dim and transpose so columns 0..2 are the x/y/z
            # coordinates used below; assumes decode returns (1, C>=3, N) — confirm.
            points = out_data[0].permute([1, 0]).cpu().numpy()
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(points[:, 0], points[:, 1], points[:, 2])
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            ax.set_zlim([0, 1])
            fig.savefig(os.path.join(OUTPUTS_DIR, '{}_{}_{}.png'.format(epoch, i, 'out')))
            # This runs every few epochs; without close() pyplot keeps every
            # figure alive and memory grows for the whole training run.
            plt.close(fig)
    finally:
        generator.train(prev_gen_training)
        autoencoder.train(prev_ae_training)

In [None]:
# WGAN-GP training in the autoencoder's latent space: the critic compares
# encoded real point clouds (gfv) with generator outputs. The pretrained
# autoencoder is never optimized here, so it stays in eval mode throughout.
global_step = 0  # x-axis for the TensorBoard scalars
autoencoder.eval()
for epoch in range(1000):
    # test_model() may leave the generator in eval mode; re-enable training.
    generator.train()
    discriminator.train()
    for i, data in enumerate(train_dataloader):
        data = data.permute([0, 2, 1]).float().to(device)
        with torch.no_grad():
            gfv = autoencoder.encode(data)
        z = torch.randn(data.shape[0], z_dim, device=device)

        # ---- Critic step ----
        g_optim.zero_grad()
        d_optim.zero_grad()
        # No generator graph needed: the critic update must not backprop into G.
        with torch.no_grad():
            fake_out, _ = generator(z)
        d_fake, _ = discriminator(fake_out)
        d_real, _ = discriminator(gfv)
        d_loss = -(torch.mean(d_real) - torch.mean(d_fake))
        d_grad_penalty = calc_gradient_penalty(discriminator, gfv, fake_out)
        # `is None`, not truthiness: a zero-valued penalty tensor is falsy and
        # would otherwise skip a perfectly valid step.
        if d_grad_penalty is None:
            continue
        total_d_loss = d_loss + d_grad_penalty
        total_d_loss.backward()
        d_optim.step()

        # ---- Generator step ----
        g_optim.zero_grad()
        d_optim.zero_grad()
        g_out, _ = generator(z)
        d_fake, _ = discriminator(g_out)
        gen_loss = -torch.mean(d_fake)
        gen_loss.backward()
        g_optim.step()

        print('Epoch: {}, Iteration: {},  G Loss: {:.4f} D Loss: {:.4f} '.format(epoch, i, gen_loss.item(), total_d_loss.item()))
        # Without global_step every point is logged at the same (None) step and
        # the TensorBoard curves collapse.
        summary_writer.add_scalar('G Loss', gen_loss.item(), global_step)
        summary_writer.add_scalar('GP  Loss', d_grad_penalty.item(), global_step)
        summary_writer.add_scalar('D Loss', d_loss.item(), global_step)
        summary_writer.add_scalar('Total D Loss', total_d_loss.item(), global_step)
        global_step += 1
    if epoch % 20 == 0:
        torch.save(generator.state_dict(), MODEL_DIR+'{}_gen_.pt'.format(epoch))
        torch.save(discriminator.state_dict(), MODEL_DIR+'{}_disc_.pt'.format(epoch))
    if epoch % 5 == 0:
        test_model(generator, autoencoder, epoch)