In [None]:
import os
# import tqdm
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
from pyntcloud import PyntCloud
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
##
from src.autoencoder import AutoEncoder, PointcloudDatasetAE
from src.chamferloss import ChamferLoss_distance

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the saved list of point-cloud file paths and hold out 10% of it as a test set.
# random_state is fixed so the split is reproducible across runs.
DATA_DIR = "./data/shape_net_core_uniform_samples_2048/"
list_point_clouds = np.load('./data/filter/list_point_cloud_filepath.npy')
# One array is enough: the split depends only on the sample count, test_size and random_state.
X_train, X_test = train_test_split(list_point_clouds, test_size=0.1, random_state=42)
print(len(X_train))

In [None]:
# Wrap the train/test file lists in datasets and loaders.
# The training loader is shuffled so each epoch sees a different batch order.
# The shuffle order is seeded, so it is repeatable between runs.
# The test loader stays in file order, one cloud per batch.
train_dataset = PointcloudDatasetAE(DATA_DIR, X_train)
train_dataloader = DataLoader(
    train_dataset,
    num_workers=2,
    shuffle=True,
    batch_size=48,
    generator=torch.Generator().manual_seed(42),
)

test_dataset = PointcloudDatasetAE(DATA_DIR, X_test)
test_dataloader = DataLoader(test_dataset, num_workers=2, shuffle=False, batch_size=1)

# Sanity check: print the shape of one training batch after the same axis swap
# (last two axes exchanged) that the training loop applies.
sample_batch = next(iter(train_dataloader)).permute([0, 2, 1])
print(sample_batch.shape)

In [None]:
# Build the autoencoder and the Chamfer-distance loss and move both to `device`.
# NUM_POINTS is passed to both constructors. Presumably it is the number of points
# per cloud (the data folder name suggests 2048); TODO confirm in src/.
NUM_POINTS = 2048
autoencoder = AutoEncoder(NUM_POINTS).to(device)
chamfer_loss = ChamferLoss_distance(NUM_POINTS).to(device)

In [None]:
# Adam optimizer over all autoencoder parameters.
# `momentum` is used as Adam's beta1 (PyTorch's default is 0.9).
# beta2 stays at 0.999.
lr = 1.0e-4
momentum = 0.95
adam_betas = (momentum, 0.999)
optimizer_AE = torch.optim.Adam(
    params=autoencoder.parameters(),
    lr=lr,
    betas=adam_betas,
)

In [None]:
# Create a timestamped run directory with logs/, outputs/ and models/ subfolders,
# then point a TensorBoard SummaryWriter at logs/.
# Every *_DIR ends with '/' because later cells build file paths by string concatenation.
# NOTE(review): str(datetime.now()) contains spaces and ':'. The ':' is not valid in
# Windows paths; consider strftime('%Y-%m-%d_%H-%M-%S') if portability matters.
ROOT_DIR = './models/autoencoder/'
now = str(datetime.datetime.now())
RUN_DIR = ROOT_DIR + now + '/'

LOG_DIR = RUN_DIR + 'logs/'
OUTPUTS_DIR = RUN_DIR + 'outputs/'
MODEL_DIR = RUN_DIR + 'models/'

# exist_ok=True removes the exists()/makedirs() race and makes the cell safe to re-run.
# makedirs also creates ROOT_DIR and RUN_DIR as intermediate directories.
for directory in (LOG_DIR, OUTPUTS_DIR, MODEL_DIR):
    os.makedirs(directory, exist_ok=True)

summary_writer = SummaryWriter(LOG_DIR)

In [None]:
# Train the autoencoder with the Chamfer reconstruction loss for 1000 epochs.
# The weights are checkpointed to MODEL_DIR at the end of every epoch.
num_batches = len(train_dataloader)
for epoch in range(1000):
    autoencoder.train()
    for i, data in enumerate(train_dataloader):
        # Swap the last two axes before feeding the model. Presumably this turns
        # (batch, points, 3) into (batch, 3, points); TODO confirm against AutoEncoder.
        data = data.permute([0, 2, 1]).float().to(device)
        optimizer_AE.zero_grad()
        out_data, gfv = autoencoder(data)
        loss = chamfer_loss(out_data, data)
        loss.backward()
        optimizer_AE.step()
        loss_value = loss.item()
        print('Epoch: {}, Iteration: {}, Content Loss: {}'.format(epoch, i, loss_value))
        # Without an explicit global_step every point is logged at the same step and
        # the TensorBoard curve collapses, so pass a monotonically increasing step.
        summary_writer.add_scalar('Content Loss', loss_value, global_step=epoch * num_batches + i)
    torch.save(autoencoder.state_dict(), MODEL_DIR + '{}_ae_.pt'.format(epoch))

In [None]:
# Restore weights from a specific earlier training run. The path is run-specific; edit it as needed.
# map_location lets a checkpoint saved with GPU tensors load on a CPU-only machine.
CHECKPOINT_PATH = './models/autoencoder/2022-08-06 15:19:12.904709/models/14_ae_.pt'
autoencoder.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [None]:
# Evaluation figures go to a fixed folder under ROOT_DIR, not the per-run OUTPUTS_DIR.
# fig.savefig does not create missing folders, so create this one up front.
eval_output = os.path.join(ROOT_DIR, 'outputs', 'eval_output')
os.makedirs(eval_output, exist_ok=True)

In [None]:
# Run the loaded autoencoder on each held-out cloud.
# For every cloud, print the input shape and the Chamfer loss, then save PNG
# scatter plots of the input ('in') and the reconstruction ('out') into eval_output.


def save_cloud_plot(cloud, path):
    """Scatter-plot an (N, 3) point array inside a fixed [0, 1]^3 box and save it to `path`."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2])
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_zlim([0, 1])
    fig.savefig(path)
    # Close explicitly; otherwise one figure per cloud would accumulate in memory.
    plt.close(fig)


# Set eval mode once: nothing inside the loop switches the model back to train mode.
autoencoder.eval()
for i in range(X_test.shape[0]):
    points = np.array(PyntCloud.from_file(X_test[i]).points)
    # Min-max map an assumed [-0.5, 0.5] coordinate range onto [0, 1], matching the plot box.
    points = (points - (-0.5)) / (0.5 - (-0.5))
    # `np.float` was removed in NumPy 1.24. It aliased the builtin float, i.e. float64.
    points = points.astype(np.float64)
    points = torch.from_numpy(points).unsqueeze(0)
    points = points.permute([0, 2, 1]).float().to(device)
    print(points.shape)
    with torch.no_grad():
        out_data, gfv = autoencoder(points)
        loss = chamfer_loss(out_data, points)
    print(loss.item())
    # Outputs produced under no_grad carry no autograd history, so no detach() is needed.
    output = out_data[0].permute([1, 0]).cpu().numpy()
    inputt = points[0].permute([1, 0]).cpu().numpy()
    save_cloud_plot(output, '{}/{}_{}.png'.format(eval_output, i, 'out'))
    save_cloud_plot(inputt, '{}/{}_{}.png'.format(eval_output, i, 'in'))