In [1]:
import argparse
import os
import sys
import logging
import numpy
from pathlib import Path
import torch
import torch.utils.data
import torchvision

sys.path.append(os.path.join(Path().resolve(), '../')) # add parent directory to path
import ptlk

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def options(outfile, dataset_path, categoryfile, dataset_type='modelnet', num_points=1024, mag=0.8,
            pointnet='tune', transfer_from='', dim_k=1024, symfn='max', max_iter=10, delta=1.0e-2,
            learn_delta=False, logfile='', workers=4, batch_size=32, epochs=200, start_epoch=0,
            optimizer='Adam', resume='', pretrained='', device='cuda:0'):
    """Bundle the run settings into an ``argparse.Namespace``.

    Every parameter becomes an attribute of the same name on the returned
    namespace, so later cells can read settings as ``args.<name>``
    (e.g. ``args.batch_size``, ``args.dataset_type``) exactly as they would
    from ``parser.parse_args()``.

    Returns:
        argparse.Namespace: one attribute per parameter, holding the value passed
        in (or its default).
    """
    # Snapshot the arguments before any other local name is bound, so the mapping
    # contains exactly the parameters listed in the signature and nothing else.
    settings = dict(locals())
    return argparse.Namespace(**settings)

In [3]:
# Run configuration: train/evaluate on the custom point-cloud dataset and write
# the output file and log under results/. Every setting not listed here keeps
# the default declared in options().
args = options(
    dataset_type='pointcloud',
    dataset_path='./dataset/PointCloudSet',
    categoryfile='./sampledata/pointcloud_half.txt',
    outfile='results/ex_pointlk_1110',
    logfile='results/ex_pointlk_1110.log',
)


In [4]:
def _load_category_info(categoryfile):
    """Read a category list file and return ``(categories, c_to_idx)``.

    The file holds one category name per line. Surrounding whitespace is
    stripped and blank lines are skipped, so a stray empty line cannot turn
    into a bogus ``''`` category. Names are sorted, so the index assigned to
    each category does not depend on line order in the file.
    """
    # `with` guarantees the file handle is closed once the names are read.
    with open(categoryfile) as fh:
        categories = sorted(name for name in (line.strip() for line in fh) if name)
    c_to_idx = {name: idx for idx, name in enumerate(categories)}
    return categories, c_to_idx


def _wrap_for_tracking(data, mag):
    """Wrap ``data`` in ptlk's ``CADset4tracking`` with a ``RandomTransformSE3(mag, True)`` perturbation."""
    # NOTE(review): mag_randomly=True presumably draws a random perturbation size
    # bounded by `mag` instead of always using `mag` exactly -- confirm in ptlk.
    mag_randomly = True
    return ptlk.data.datasets.CADset4tracking(
        data, ptlk.data.transforms.RandomTransformSE3(mag, mag_randomly))


def get_datasets(args):
    """Build the ``(trainset, testset)`` pair described by ``args``.

    Reads ``args.dataset_type``, ``args.categoryfile``, ``args.dataset_path``,
    ``args.num_points`` and ``args.mag``.

    * ``'modelnet'``: items of ``ptlk.data.datasets.ModelNet`` pass through
      ``Mesh2Points -> OnUnitCube -> Resampler(num_points)``.
    * ``'pointcloud'``: items of ``ptlk.data.datasets.PointCloud`` pass through
      ``OnUnitCube -> Resampler(num_points)``.

    If ``args.categoryfile`` is non-empty, its sorted category names and a
    name->index map are passed to both datasets as ``classinfo``; otherwise
    ``classinfo`` is ``None``. Both splits are then wrapped by
    ``_wrap_for_tracking`` (``CADset4tracking`` + ``RandomTransformSE3``).

    Returns:
        tuple: ``(trainset, testset)``, the wrapped train (``train=1``) and
        test (``train=0``) splits.

    Raises:
        ValueError: if ``args.dataset_type`` is not ``'modelnet'`` or ``'pointcloud'``.
        OSError: if ``args.categoryfile`` is set but cannot be opened.
    """
    if args.dataset_type not in ('modelnet', 'pointcloud'):
        # Fail fast with a clear message instead of falling through with no dataset built.
        raise ValueError(
            f"unknown dataset_type {args.dataset_type!r}; expected 'modelnet' or 'pointcloud'")

    cinfo = _load_category_info(args.categoryfile) if args.categoryfile else None

    if args.dataset_type == 'modelnet':
        # ModelNet items go through Mesh2Points first (presumably they load as meshes).
        steps = [ptlk.data.transforms.Mesh2Points()]
        dataset_cls = ptlk.data.datasets.ModelNet
    else:  # 'pointcloud'
        steps = []
        dataset_cls = ptlk.data.datasets.PointCloud
    steps += [
        ptlk.data.transforms.OnUnitCube(),
        ptlk.data.transforms.Resampler(args.num_points),
    ]
    transform = torchvision.transforms.Compose(steps)

    traindata = dataset_cls(args.dataset_path, train=1, transform=transform, classinfo=cinfo)
    testdata = dataset_cls(args.dataset_path, train=0, transform=transform, classinfo=cinfo)

    trainset = _wrap_for_tracking(traindata, args.mag)
    testset = _wrap_for_tracking(testdata, args.mag)
    return trainset, testset

In [5]:
# Build the train/test datasets described by `args` (see get_datasets above).
trainset, testset = get_datasets(args)

In [6]:
# Sanity check: the cell output is the number of samples in the test split.
len(testset)

5

In [7]:
# DataLoaders for both splits share the batch size and worker count from `args`;
# only the training split is shuffled. Built in (test, train) order.
testloader, trainloader = (
    torch.utils.data.DataLoader(
        split, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.workers)
    for split, shuffle in ((testset, False), (trainset, True))
)

In [13]:
# Display the test DataLoader object (its repr) to confirm it was created.
testloader

<torch.utils.data.dataloader.DataLoader at 0x171dc6d6040>