In [1]:
import sys
sys.path.insert(1, 'data_utils')
sys.path.insert(1, 'models/')

from transform_functions import PCRNetTransform as transform
import transform_functions
from modelnet_reg_utils import ModelNet40Data, RegistrationData
from torch.utils.data import DataLoader
from pytorch3d.loss import chamfer_distance
import torch
from tqdm import tqdm
import numpy as np
import os
import open3d as o3d
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
import time
from models.feature_models import PointResNet, AttentionPointResNet, PointNet
from models.attention_pooling import AttentionPooling
from utils.load_model import load_model
from args import Args
# Experiment configuration object (from args.py). get_model reads flags such as
# load_pretrained_feature_extractor, feature_extractor_path, freeeze_feature_extractor
# and attention_pooling from it.
arger = Args()
# Output directory for the best-model snapshot. Keep the trailing slash: some cells
# build paths as SAVE_DIR + "best_model_snap.t7".
SAVE_DIR = 'checkpoints/pointresnet/cycle/'

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Feature extractor wrapper (`get_model`)

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class get_model(nn.Module):
    """Wrap a point-cloud feature extractor and reduce its output to one global feature.

    ``feature_extractor(args)`` must build a module whose forward returns a pair
    ``(x_ap, x_mp)``. Presumably these are average- and max-pooled features; TODO
    confirm against models/feature_models.py.

    Flags read from ``args``:
        load_pretrained_feature_extractor / feature_extractor_path -- restore the
            extractor's weights through ``load_model``.
        freeeze_feature_extractor -- turn off gradients for every extractor parameter.
        attention_pooling -- fuse both outputs with ``AttentionPooling``. When this flag
            is off, only the second output is returned.
    """

    def __init__(self, args, feature_extractor=PointResNet):
        super().__init__()
        self.args = args

        backbone = feature_extractor(args)
        if args.load_pretrained_feature_extractor:
            backbone = load_model(backbone, args.feature_extractor_path)
            print("Loaded pretrained feature extractor")
        self.feature_extractor = backbone

        if args.freeeze_feature_extractor:
            # Module.requires_grad_ flips requires_grad on every parameter it owns.
            self.feature_extractor.requires_grad_(False)
            print("Freezed feature extractor")

        self.attentional_pooling = AttentionPooling() if args.attention_pooling else None

    def forward(self, x):
        """Return the global feature for point cloud batch ``x``."""
        feat_ap, feat_mp = self.feature_extractor(x)
        if not self.args.attention_pooling:
            return feat_mp
        return self.attentional_pooling(feat_ap, feat_mp)

## PCRNET

In [3]:
class iPCRNet(nn.Module):
    """Iterative PCRNet: regress a rigid transform that aligns ``source`` to ``template``.

    The global features of both clouds are concatenated and passed through an MLP that
    outputs a 7-vector. ``transform.create_pose_7d`` turns that vector into a pose
    (presumably quaternion + translation; see transform_functions.PCRNetTransform). The
    pose is applied to the source, and the step repeats ``max_iteration`` times while the
    total rotation/translation is accumulated.

    Args:
        feature_model: module applied to ``cloud.permute(0, 2, 1)``. Its output is
            concatenated along dim 1 and fed to ``nn.Linear(2048, ...)``, so it presumably
            yields (B, 1024) features. TODO confirm.
        droput: dropout probability before the last layer; 0 disables it. The misspelt
            name is kept for compatibility with existing callers.
    """

    def __init__(self, feature_model, droput=0.0):
        super().__init__()
        self.feature_model = feature_model

        layers = [nn.Linear(1024 * 2, 1024), nn.ReLU(),
                  nn.Linear(1024, 1024), nn.ReLU(),
                  nn.Linear(1024, 512), nn.ReLU(),
                  nn.Linear(512, 512), nn.ReLU(),
                  nn.Linear(512, 256), nn.ReLU()]
        if droput > 0.0:
            layers.append(nn.Dropout(droput))
        layers.append(nn.Linear(256, 7))
        self.linear = nn.Sequential(*layers)

    # Single Pass Alignment Module (SPAM)
    def spam(self, template_features, source, est_R, est_t):
        """Run one alignment step and compose it with the running estimate.

        Stores this step's source features on ``self.source_features`` (``forward`` reads
        them for the ``'r'`` output).

        Returns:
            (est_R, est_t, source), where ``source`` has had this step's pose applied.
        """
        batch_size = source.size(0)
        self.source_features = self.feature_model(source.permute(0, 2, 1))
        y = torch.cat([template_features, self.source_features], dim=1)
        pose_7d = transform.create_pose_7d(self.linear(y))

        # Rotation and translation predicted by this step alone.
        identity = torch.eye(3).to(source).view(1, 3, 3).expand(batch_size, 3, 3).contiguous()
        est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
        est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)

        # Compose with the running estimate: t <- R_step @ t + t_step, then R <- R_step @ R.
        est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
        est_R = torch.bmm(est_R_temp, est_R)

        source = transform.quaternion_transform(source, pose_7d)  # Ps' = est_R*Ps + est_t
        return est_R, est_t, source

    def forward(self, template, source, max_iteration=3):
        """Align ``source`` to ``template`` using ``max_iteration`` SPAM steps.

        Returns:
            dict with 'est_R' (Bx3x3), 'est_t' (Bx1x3) and 'est_T' (all source -> template),
            'r' (template features minus the last step's source features) and
            'transformed_source'.

        Raises:
            ValueError: if ``max_iteration`` < 1. No step would run, so there would be
                no source features to build ``'r'`` from.
        """
        if max_iteration < 1:
            raise ValueError("max_iteration must be >= 1, got {}".format(max_iteration))

        batch_size = template.size(0)
        est_R = torch.eye(3).to(template).view(1, 3, 3).expand(batch_size, 3, 3).contiguous()       # (Bx3x3)
        est_t = torch.zeros(1, 3).to(template).view(1, 1, 3).expand(batch_size, 1, 3).contiguous()  # (Bx1x3)

        # The template never moves, so its features are computed once and reused by every step.
        template_features = self.feature_model(template.permute(0, 2, 1))
        for _ in range(max_iteration):
            est_R, est_t, source = self.spam(template_features, source, est_R, est_t)

        return {'est_R': est_R,                                          # source -> template
                'est_t': est_t,                                          # source -> template
                'est_T': transform.convert2transformation(est_R, est_t),  # source -> template
                'r': template_features - self.source_features,
                'transformed_source': source}

## Training

In [4]:
def test_one_epoch(device, model, test_loader):
	model.eval()
	test_loss = 0.0
	pred  = 0.0
	count = 0
	with torch.no_grad():
		for i, data in enumerate(tqdm(test_loader)):
			template, source, _,_,_ = data

			template = template.to(device)
			source = source.to(device)

			# mean substraction
			source = source - torch.mean(source, dim=1, keepdim=True)
			template = template - torch.mean(template, dim=1, keepdim=True)

			output = model(template, source)
			loss_val = chamfer_distance(template, output['transformed_source'])[0]
			

			test_loss += loss_val.item()
			count += 1

	test_loss = float(test_loss)/count
	return test_loss

def train_one_epoch(device, model, train_loader, optimizer):
    """Run one cycle-consistent training epoch and return the mean per-batch loss.

    For every batch the model aligns source -> template and template -> source. The loss
    is the sum of the two Chamfer distances, each measured against the cloud that was
    not moved.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for template, source, _, _, _ in tqdm(train_loader):
        # Move to the target device first so that centring runs there (same order as test_one_epoch).
        template = template.to(device)
        source = source.to(device)

        # Mean subtraction: centre both clouds.
        source = source - torch.mean(source, dim=1, keepdim=True)
        template = template - torch.mean(template, dim=1, keepdim=True)

        # Forward in both directions (cycle training).
        output_ab = model(template, source)  # moves source towards template
        output_ba = model(source, template)  # moves template towards source
        loss = (chamfer_distance(template, output_ab['transformed_source'])[0]
                + chamfer_distance(source, output_ba['transformed_source'])[0])

        # Backward + optimise.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

In [5]:
def train(model, device, train_loader, test_loader, start_epoch, max_epoch=400, lr=0.00001):
    """Train ``model`` with Adam and keep the snapshot with the lowest test loss.

    Runs epochs ``start_epoch`` .. ``max_epoch - 1``. Each time the test loss improves,
    the function writes a dict with keys 'epoch' (the next epoch to run), 'model',
    'min_loss' and 'optimizer' to ``SAVE_DIR/best_model_snap.t7``. The directory is
    created if it is missing.

    When resuming (``start_epoch > 0``) and that snapshot already exists, its 'min_loss'
    seeds the best-so-far value. Without this, the first resumed epoch would always
    overwrite the snapshot, even with a worse model.

    Args:
        model: network to train; only parameters with requires_grad=True are optimised.
        device: passed through to train_one_epoch / test_one_epoch.
        train_loader: loader for the training phase.
        test_loader: loader for the evaluation phase.
        start_epoch: first epoch index to run (e.g. 'epoch' from a loaded checkpoint).
        max_epoch: exclusive upper bound on the epoch index (was hard-coded to 400).
        lr: Adam learning rate (was hard-coded to 1e-5).
    """
    learnable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(learnable_params, lr=lr)

    os.makedirs(SAVE_DIR, exist_ok=True)
    snapshot_path = os.path.join(SAVE_DIR, 'best_model_snap.t7')

    best_test_loss = np.inf
    if start_epoch > 0 and os.path.isfile(snapshot_path):
        # Resume: do not let the first new epoch clobber a better existing snapshot.
        previous = torch.load(snapshot_path, map_location='cpu')
        best_test_loss = previous.get('min_loss', np.inf)

    for epoch in range(start_epoch, max_epoch):
        train_one_epoch(device, model, train_loader, optimizer)
        test_loss = test_one_epoch(device, model, test_loader)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            snap = {'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'min_loss': best_test_loss,
                    'optimizer': optimizer.state_dict()}
            torch.save(snap, snapshot_path)
            print("Test loss after epoch # {} is : {}".format(epoch, best_test_loss))
            print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))

In [6]:
# ModelNet40 registration pairs. Each batch unpacks into 5 items (template, source, ...),
# as consumed by train_one_epoch / test_one_epoch.
BATCH_SIZE = 16
NUM_WORKERS = 2
trainset = RegistrationData('PCRNet', ModelNet40Data(train=True, download=True))
testset = RegistrationData('PCRNet', ModelNet40Data(train=False, download=True))
# drop_last is only used for training, where it keeps optimiser batches a constant size.
# The test loader keeps the final partial batch so that every test sample is evaluated.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=NUM_WORKERS)

## train

In [13]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the registration model on top of an AttentionPointResNet feature extractor.
ptnet = get_model(arger, feature_extractor=AttentionPointResNet)
model = iPCRNet(feature_model=ptnet)

In [14]:
# Resume from the best snapshot in SAVE_DIR if it exists. Set LOAD = False to force a fresh start.
best_model_path = SAVE_DIR + "best_model_snap.t7"
LOAD = True
if LOAD and os.path.isfile(best_model_path):
    print("Found checkpoint, loading weights")
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(best_model_path, map_location=device)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
else:
    print("start from scratch")
    start_epoch = 0
model.to(device)

Found checkpoint, loading weights


In [None]:
# Train from `start_epoch` (0, or the 'epoch' stored in the loaded checkpoint).
train(model, device, train_loader, test_loader, start_epoch=start_epoch)

## Inference

In [10]:
# Iterator over the test loader (built with shuffle=False); the next cell pulls one batch from it.
it = iter(test_loader)

In [11]:
# Fetch one batch through the public iterator protocol. `_next_data` is a private
# DataLoader-iterator method that bypasses the bookkeeping done in `__next__`.
data = next(it)

In [12]:
# Rebuild the model with a plain PointNet extractor for inference. This rebinds `model`,
# replacing the AttentionPointResNet model created for training.
ptnet = get_model(args=arger, feature_extractor=PointNet)
model = iPCRNet(feature_model=ptnet)

In [27]:
# Load the PointNet cycle-training checkpoint for inference. SAVE_DIR points at the
# pointresnet run, which is presumably why this path is hard-coded instead.
# best_model_path = SAVE_DIR + "best_model_snap.t7"
best_model_path = "checkpoints/pointnet/cycle/best_model_snap.t7"
LOAD = True
if LOAD and os.path.isfile(best_model_path):
    print("Found checkpoint, loading weights")
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(best_model_path, map_location=device)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
else:
    print("start from scratch")
    start_epoch = 0
model.to(device)

Found checkpoint, loading weights


In [15]:
# Move the first three batch entries (template, source, igt) to the device, then centre
# the two clouds the same way the training/evaluation loops do.
template, source, igt = (item.to(device) for item in data[:3])
source = source - source.mean(dim=1, keepdim=True)
template = template - template.mean(dim=1, keepdim=True)

In [16]:
# Forward alignment: warp `source` onto `template` with 3 SPAM iterations and report the
# Chamfer distance. chamfer_distance returns a (loss, normals_loss) pair.
model.eval()
with torch.no_grad():
    output = model(template, source, max_iteration=3)
    loss_val = chamfer_distance(template, output['transformed_source'])
print(loss_val)

(tensor(0.0227, device='cuda:0'), None)


In [17]:
# Reverse alignment: warp `template` onto `source`. The result is stored under its own
# name so that `output` keeps the forward (source -> template) result, which the
# visualisation cells below read as the predicted source.
model.eval()
with torch.no_grad():
    output_reverse = model(source, template, max_iteration=3)
    loss_val = chamfer_distance(source, output_reverse['transformed_source'])
    print(loss_val)

(tensor(0.0307, device='cuda:0'), None)


In [18]:
# CPU copies for Open3D: the template, the model's 'transformed_source' output and the source.
temp_cpu, source_pred_cpu, source_cpu = (
    tensor.cpu() for tensor in (template, output['transformed_source'], source)
)

In [19]:
# Show up to NUM_VIS examples from the batch: template (blue), source_pred_cpu (red) and
# source (green). The count is capped at the batch size so that a smaller batch does not
# raise IndexError.
NUM_VIS = 8
template_np = temp_cpu.detach().numpy()
source_pred_np = source_pred_cpu.detach().numpy()
source_np = source_cpu.detach().numpy()

for i in range(min(NUM_VIS, template_np.shape[0])):
    clouds = []
    for points, rgb in ((template_np[i], [0.0, 0.0, 1.0]),     # B
                        (source_pred_np[i], [1.0, 0.0, 0.0]),  # R
                        (source_np[i], [0.0, 1.0, 0.0])):      # G
        cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
        cloud.paint_uniform_color(np.array(rgb))
        clouds.append(cloud)
    o3d.visualization.draw_geometries(clouds)