In [1]:
import sys
sys.path.insert(1, 'data_utils')
sys.path.insert(1, 'models/')

from transform_functions import PCRNetTransform as transform
import transform_functions
from modelnet_reg_utils import ModelNet40Data, RegistrationData
from torch.utils.data import DataLoader
from losses.mse import loss_function
from pytorch3d.loss import chamfer_distance
import torch
from tqdm import tqdm
import numpy as np
import os
import open3d as o3d
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
import time
from models.feature_models import PointResNet, AttentionPointResNet, PointNet, AttentionPointResNetV2
from models.attention_pooling import AttentionPooling
from utils.load_model import load_model
from args import Args
# Project configuration object; it is passed to the feature-extractor wrapper when the model is built.
arger = Args()
# Path prefix under which the best-model checkpoint is written during training and looked up before it.
SAVEDIR = 'checkpoints/pointresnetv2/mse/'

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## PointNet

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class get_model(nn.Module):
    """Wrap a point-cloud feature extractor and apply the options found in ``args``.

    Args:
        args: configuration object. The attributes read here are
            ``load_pretrained_feature_extractor``, ``feature_extractor_path``
            and ``freeeze_feature_extractor`` (spelling as used by the config).
        feature_extractor: class (or factory) called as ``feature_extractor(args)``
            to build the backbone.
    """

    def __init__(self, args, feature_extractor=PointResNet):
        super().__init__()
        self.args = args

        backbone = feature_extractor(self.args)
        if self.args.load_pretrained_feature_extractor:
            # Replace the freshly initialised backbone with the pretrained one.
            backbone = load_model(backbone, self.args.feature_extractor_path)
            print("Loaded pretrained feature extractor")
        self.feature_extractor = backbone

        if self.args.freeeze_feature_extractor:
            # Exclude every backbone weight from gradient updates.
            for weight in self.feature_extractor.parameters():
                weight.requires_grad_(False)
            print("Freezed feature extractor")

    def forward(self, x):
        """Return the backbone's features for ``x`` unchanged (no extra pooling is applied)."""
        return self.feature_extractor(x)

## PCRNet

In [3]:
class iPCRNet(nn.Module):
	"""Iterative pose regressor that aligns a source cloud to a template cloud.

	A shared ``feature_model`` embeds both clouds; the two embeddings are
	concatenated and passed through a fully connected head that outputs a
	7-value pose, which ``transform.create_pose_7d`` then post-processes.
	The pose is applied to the source and the procedure is repeated
	``max_iteration`` times.

	Args:
		feature_model: module called on a ``(B, 3, N)`` tensor. The head's first
			layer takes ``1024 * 2`` inputs, so each embedding is expected to be
			``(B, 1024)``.
		droput: dropout probability inserted before the final linear layer
			when greater than zero (parameter name kept for compatibility).
	"""

	def __init__(self, feature_model, droput=0.0):
		super().__init__()
		self.feature_model = feature_model

		# Build the head in a local list so ``self.linear`` is only ever an nn.Module.
		layers = [nn.Linear(1024 * 2, 1024), nn.ReLU(),
				  nn.Linear(1024, 1024), nn.ReLU(),
				  nn.Linear(1024, 512), nn.ReLU(),
				  nn.Linear(512, 512), nn.ReLU(),
				  nn.Linear(512, 256), nn.ReLU()]
		if droput > 0.0:
			layers.append(nn.Dropout(droput))
		layers.append(nn.Linear(256, 7))

		self.linear = nn.Sequential(*layers)

	# Single Pass Alignment Module (SPAM)
	def spam(self, template_features, source, est_R, est_t):
		"""Run one alignment step.

		Args:
			template_features: embedding of the template cloud.
			source: current source cloud, ``(B, N, 3)``.
			est_R: accumulated rotation so far, ``(B, 3, 3)``.
			est_t: accumulated translation so far, ``(B, 1, 3)``.

		Returns:
			Tuple ``(est_R, est_t, source)`` with the accumulated pose updated by
			this step and the source cloud transformed by this step's pose.
		"""
		batch_size = source.size(0)
		# Kept on the instance because ``forward`` reads it for the residual 'r'.
		self.source_features = self.feature_model(source.permute(0, 2, 1))
		y = torch.cat([template_features, self.source_features], dim=1)
		pose_7d = self.linear(y)
		pose_7d = transform.create_pose_7d(pose_7d)

		# Find current rotation and translation: rotating the identity matrix
		# yields this step's rotation matrix.
		identity = torch.eye(3).to(source).view(1, 3, 3).expand(batch_size, 3, 3).contiguous()
		est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
		est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)

		# Compose with the accumulated pose: t <- R_step * t + t_step, R <- R_step * R.
		est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
		est_R = torch.bmm(est_R_temp, est_R)

		source = transform.quaternion_transform(source, pose_7d)      # Ps' = est_R*Ps + est_t
		return est_R, est_t, source

	def forward(self, template, source, max_iteration=3):
		"""Estimate the source -> template transform.

		Args:
			template: template cloud, ``(B, N, 3)``.
			source: source cloud, ``(B, N, 3)``.
			max_iteration: number of alignment steps; must be at least 1.

		Returns:
			dict with ``est_R`` (B, 3, 3), ``est_t`` (B, 1, 3), ``est_T`` (built by
			``transform.convert2transformation``), the feature residual ``r`` and
			``transformed_source``.

		Raises:
			ValueError: if ``max_iteration`` is smaller than 1. Without at least
				one step there are no source features for the residual.
		"""
		if max_iteration < 1:
			raise ValueError("max_iteration must be >= 1, got {}".format(max_iteration))

		est_R = torch.eye(3).to(template).view(1, 3, 3).expand(template.size(0), 3, 3).contiguous()         # (Bx3x3)
		est_t = torch.zeros(1, 3).to(template).view(1, 1, 3).expand(template.size(0), 1, 3).contiguous()     # (Bx1x3)
		template_features = self.feature_model(template.permute(0, 2, 1))

		# The template embedding is computed once; only the source is re-embedded per step.
		for _ in range(max_iteration):
			est_R, est_t, source = self.spam(template_features, source, est_R, est_t)

		result = {'est_R': est_R,				# source -> template
				  'est_t': est_t,				# source -> template
				  'est_T': transform.convert2transformation(est_R, est_t),			# source -> template
				  'r': template_features - self.source_features,
				  'transformed_source': source}
		return result

In [4]:
def visualize_with_trans(source, template, trans):
    """Show each source/template pair of a batch after transforming the source.

    Args:
        source: batch of source point clouds (torch tensor, indexed along dim 0).
        template: batch of template point clouds (torch tensor).
        trans: batch of transforms applied to the matching source cloud via
            ``PointCloud.transform`` (torch tensor).

    One Open3D window is opened per batch element; the source is drawn in
    orange and the template in blue.
    """
    source_np, template_np, trans_np = (
        tensor.detach().cpu().numpy() for tensor in (source, template, trans)
    )

    orange = [1, 0.706, 0]
    blue = [0, 0.651, 0.929]
    for idx in range(source_np.shape[0]):
        source_points = o3d.utility.Vector3dVector(source_np[idx])
        source_cloud = o3d.geometry.PointCloud(source_points).paint_uniform_color(orange)
        template_points = o3d.utility.Vector3dVector(template_np[idx])
        template_cloud = o3d.geometry.PointCloud(template_points).paint_uniform_color(blue)
        # Move the source into the template frame before drawing.
        source_cloud.transform(trans_np[idx])
        o3d.visualization.draw_geometries([source_cloud, template_cloud])

In [5]:
def batch_inverse(T):
    """
    Invert a batch of 4x4 transformation matrices.

    The whole batch is inverted in a single call instead of looping over
    samples in Python, which avoids per-sample kernel launches and in-place
    writes into a preallocated tensor.

    Args:
        T: A torch tensor of shape (B, 4, 4), where B is the batch size.
    Returns:
        inv_T: A torch tensor of shape (B, 4, 4), where each 4x4 matrix
               is the inverse of the corresponding input matrix.
    """
    # torch.inverse operates on the last two dimensions and broadcasts over the batch.
    return torch.inverse(T)

## Training

In [6]:
def test_one_epoch(device, model, test_loader):
	"""Evaluate ``model`` on ``test_loader`` and return the mean pose loss.

	Args:
		device: torch device (or device string) the batches are moved to.
		model: registration network called as ``model(template, source)``; its
			output dict must provide 'est_R', 'est_t' and 'transformed_source'.
		test_loader: iterable yielding ``(template, source, _, igt_R, igt_t)``.

	Returns:
		Mean over batches of the rotation MSE plus the translation MSE.
		The mean Chamfer distance between the template and the aligned source
		is printed but not returned.
	"""
	model.eval()
	test_loss = 0.0
	chamfer_loss = 0.0
	count = 0
	with torch.no_grad():
		for data in tqdm(test_loader):
			template, source, _, igt_R, igt_t = data
			batch_size = template.size(0)
			source = source.to(device)
			template = template.to(device)
			igt_R = igt_R.to(device)
			igt_t = igt_t.to(device)

			# Mean subtraction: centre both clouds and shift the ground-truth
			# translation by the source centroid.
			igt_t = igt_t - torch.mean(source, dim=1).unsqueeze(1)
			source = source - torch.mean(source, dim=1, keepdim=True)
			template = template - torch.mean(template, dim=1, keepdim=True)

			output = model(template, source)
			# The loss compares the prediction against the inverse of the stored transform.
			igt = batch_inverse(transform.convert2transformation(igt_R, igt_t).to(device))
			# Accumulate as a Python float so no tensors are kept alive across batches.
			chamfer_loss += chamfer_distance(template, output['transformed_source'])[0].item()

			# Build the identity on the requested device rather than hard-coding CUDA,
			# so evaluation also works on CPU-only machines.
			identity = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
			# Rotation term: est_R^T @ R_gt should be the identity when they agree.
			rotation_term = F.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), igt[:, 0:3, 0:3]), identity)
			translation_term = F.mse_loss(output['est_t'].squeeze(1), igt[:, 0:3, 3])
			total_loss = rotation_term + translation_term

			test_loss += total_loss.item()
			count += 1
	test_loss = float(test_loss)/count
	print("Test loss is {}".format(test_loss))
	print("Chamfer loss is {}".format(chamfer_loss/count))
	return test_loss

def train_one_epoch(device, model, train_loader, optimizer, cycle_loss = True):
	"""Run one optimisation pass over ``train_loader`` and return the mean loss.

	Args:
		device: torch device (or device string) the batches are moved to.
		model: registration network called as ``model(template, source)``; its
			output dict must provide 'est_R' and 'est_t'.
		train_loader: iterable yielding ``(template, source, _, igt_R, igt_t)``.
		optimizer: optimizer stepping the model's trainable parameters.
		cycle_loss: when true, add a cycle-consistency penalty between the
			source->template and template->source predictions (weight 0.1).

	Returns:
		Mean of the per-batch total loss.
	"""
	model.train()
	train_loss = 0.0
	count = 0
	for data in tqdm(train_loader):
		template, source, _, igt_R, igt_t = data
		batch_size = template.size(0)
		source = source.to(device)
		template = template.to(device)
		igt_R = igt_R.to(device)
		igt_t = igt_t.to(device)

		# Mean subtraction: centre both clouds and shift the ground-truth
		# translation by the source centroid.
		igt_t = igt_t - torch.mean(source, dim=1).unsqueeze(1)
		source = source - torch.mean(source, dim=1, keepdim=True)
		template = template - torch.mean(template, dim=1, keepdim=True)

		outputab = model(template, source)
		# Build the identity on the requested device rather than hard-coding CUDA.
		identity = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
		igt = batch_inverse(transform.convert2transformation(igt_R, igt_t).to(device))

		# Supervised pose loss; the rotation term is weighted 1.5x.
		rotation_term = F.mse_loss(torch.matmul(outputab['est_R'].transpose(2, 1), igt[:, 0:3, 0:3]), identity)
		translation_term = F.mse_loss(outputab['est_t'].squeeze(1), igt[:, 0:3, 3])
		total_loss = 1.5 * rotation_term + translation_term

		if cycle_loss:
			# Predict the reverse direction; composing both directions should give the identity.
			outputba = model(source, template)
			rotation_loss = F.mse_loss(torch.matmul(outputba['est_R'], outputab['est_R']), identity)
			# NOTE(review): the rotation term uses R_ba while this term uses R_ba^T;
			# confirm which convention is intended. Kept as in the original.
			translation_residual = torch.matmul(
				outputba['est_R'].transpose(2, 1),
				outputab['est_t'].view(batch_size, 3, 1),
			).view(batch_size, 3) + outputba['est_t'].view(batch_size, 3)
			translation_loss = torch.mean(translation_residual ** 2, dim=[0, 1])
			# Use a separate name: rebinding ``cycle_loss`` here would replace the
			# boolean flag with a tensor for every following iteration.
			cycle_term = 1.5 * rotation_loss + translation_loss
			total_loss = total_loss + cycle_term * 0.1

		# forward + backward + optimize
		optimizer.zero_grad()
		total_loss.backward()
		optimizer.step()

		train_loss += total_loss.item()
		count += 1

	train_loss = float(train_loss)/count
	return train_loss
    

In [7]:
def train(model, device, train_loader, test_loader, start_epoch, num_epochs=300):
	"""Train ``model`` and checkpoint it whenever the test loss improves.

	Args:
		model: network to optimise; only parameters with ``requires_grad`` are trained.
		device: torch device passed through to the per-epoch helpers.
		train_loader: loader used for optimisation.
		test_loader: loader used for evaluation after every epoch.
		start_epoch: index of the first epoch to run.
		num_epochs: epoch index at which training stops (exclusive). Defaults to
			300, the previously hard-coded value.

	Side effects:
		Writes ``SAVEDIR + 'best_model_snap.t7'`` each time a new best test loss
		is reached; the snapshot holds the epoch, model/optimizer state and loss.
	"""
	learnable_params = filter(lambda p: p.requires_grad, model.parameters())
	optimizer = torch.optim.Adam(learnable_params, lr = 0.001)
	# Learning rate is divided by 10 after 50, 100 and 150 scheduler steps.
	scheduler = MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.1)

	# Make sure the checkpoint directory exists so the first torch.save cannot
	# fail on a fresh checkout.
	os.makedirs(SAVEDIR, exist_ok=True)

	best_test_loss = np.inf

	for epoch in range(start_epoch, num_epochs):
		# The mean training loss is returned by the helper but not used here.
		train_one_epoch(device, model, train_loader, optimizer)
		test_loss = test_one_epoch(device, model, test_loader)

		if test_loss < best_test_loss:
			best_test_loss = test_loss
			snap = {'epoch': epoch + 1,
					'model': model.state_dict(),
					'min_loss': best_test_loss,
					'optimizer' : optimizer.state_dict(),}
			torch.save(snap, SAVEDIR + 'best_model_snap.t7')
			print("Test loss after epoch # {} is : {}".format(epoch, best_test_loss))
		scheduler.step()

In [8]:
BATCH_SIZE = 16

# Registration pairs built on top of the ModelNet40 train / test splits.
trainset = RegistrationData(
    'PCRNet',
    ModelNet40Data(train=True, download=True),
)
testset = RegistrationData(
    'PCRNet',
    ModelNet40Data(train=False, download=True),
)

# Both loaders drop the last incomplete batch; only the training loader shuffles.
loader_options = dict(batch_size=BATCH_SIZE, drop_last=True, num_workers=2)
train_loader = DataLoader(trainset, shuffle=True, **loader_options)
test_loader = DataLoader(testset, shuffle=False, **loader_options)

## Train

In [9]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the feature extractor (AttentionPointResNetV2 backbone) and wrap it in the pose regressor.
ptnet = get_model(arger, feature_extractor=AttentionPointResNetV2)
model = iPCRNet(feature_model=ptnet)

In [10]:
best_model_path = SAVEDIR + "best_model_snap.t7"
# Set to True to warm-start from the best checkpoint written by a previous run.
LOAD = False
# Training always restarts its epoch counter, whether or not weights are restored.
start_epoch = 0

if LOAD and os.path.isfile(best_model_path):
    print("Found checkpoint, loading weights")
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
else:
    print("start from scratch")

model.to(device)

start from scratch


In [11]:
# Start training; each epoch prints tqdm progress bars followed by the test and Chamfer losses.
train(model, device, train_loader, test_loader, start_epoch)

100%|██████████| 615/615 [03:17<00:00,  3.11it/s]
100%|██████████| 154/154 [00:18<00:00,  8.47it/s]


Test loss is 0.12121910645396679
Chamfer loss is 0.05363178998231888
Test loss after epoch # 0 is : 0.12121910645396679


100%|██████████| 615/615 [03:18<00:00,  3.10it/s]
100%|██████████| 154/154 [00:17<00:00,  8.67it/s]


Test loss is 0.09915316063765582
Chamfer loss is 0.04061540216207504
Test loss after epoch # 1 is : 0.09915316063765582


100%|██████████| 615/615 [03:20<00:00,  3.07it/s]
100%|██████████| 154/154 [00:17<00:00,  8.71it/s]


Test loss is 0.0768584128372468
Chamfer loss is 0.029568037018179893
Test loss after epoch # 2 is : 0.0768584128372468


100%|██████████| 615/615 [03:20<00:00,  3.06it/s]
100%|██████████| 154/154 [00:17<00:00,  8.63it/s]


Test loss is 0.07084640057443024
Chamfer loss is 0.02712406776845455
Test loss after epoch # 3 is : 0.07084640057443024


100%|██████████| 615/615 [03:20<00:00,  3.06it/s]
100%|██████████| 154/154 [00:17<00:00,  8.71it/s]


Test loss is 0.06808276268859188
Chamfer loss is 0.026009587571024895
Test loss after epoch # 4 is : 0.06808276268859188


100%|██████████| 615/615 [03:21<00:00,  3.06it/s]
100%|██████████| 154/154 [00:17<00:00,  8.69it/s]


Test loss is 0.07175333754389317
Chamfer loss is 0.027904780581593513


100%|██████████| 615/615 [03:19<00:00,  3.08it/s]
100%|██████████| 154/154 [00:17<00:00,  8.64it/s]


Test loss is 0.07996066870143662
Chamfer loss is 0.029882607981562614


100%|██████████| 615/615 [03:18<00:00,  3.09it/s]
100%|██████████| 154/154 [00:17<00:00,  8.72it/s]


Test loss is 0.07008411447432908
Chamfer loss is 0.026963986456394196


100%|██████████| 615/615 [03:19<00:00,  3.08it/s]
100%|██████████| 154/154 [00:17<00:00,  8.73it/s]


Test loss is 0.06795093171637166
Chamfer loss is 0.025990353897213936
Test loss after epoch # 8 is : 0.06795093171637166


100%|██████████| 615/615 [03:19<00:00,  3.09it/s]
100%|██████████| 154/154 [00:17<00:00,  8.68it/s]


Test loss is 0.06496807374060154
Chamfer loss is 0.025673776865005493
Test loss after epoch # 9 is : 0.06496807374060154


100%|██████████| 615/615 [03:18<00:00,  3.10it/s]
100%|██████████| 154/154 [00:17<00:00,  8.69it/s]


Test loss is 0.07054179984253722
Chamfer loss is 0.027509871870279312


100%|██████████| 615/615 [03:27<00:00,  2.97it/s]
100%|██████████| 154/154 [00:18<00:00,  8.50it/s]


Test loss is 0.06399054906343098
Chamfer loss is 0.024847226217389107
Test loss after epoch # 11 is : 0.06399054906343098


100%|██████████| 615/615 [03:20<00:00,  3.07it/s]
100%|██████████| 154/154 [00:18<00:00,  8.19it/s]


Test loss is 0.0668097310397145
Chamfer loss is 0.025504272431135178


100%|██████████| 615/615 [03:25<00:00,  2.99it/s]
 41%|████      | 63/154 [00:12<00:05, 15.22it/s]

## Inference

In [12]:
# Create an iterator over the test loader so single batches can be pulled for inspection.
it = iter(test_loader)

In [13]:
# Fetch one batch through the public iterator protocol instead of the
# private ``_next_data`` hook, which skips the iterator's own bookkeeping.
data = next(it)

In [14]:
# The first three batch entries are the template, the source and the ground-truth transform.
template, source, igt = (entry.to(device) for entry in data[:3])

# Centre both clouds on their own centroid, as done during training.
source = source - source.mean(dim=1, keepdim=True)
template = template - template.mean(dim=1, keepdim=True)

In [15]:
model.eval()
# Inference only: disable autograd so no graph is built and the outputs carry no grad_fn.
with torch.no_grad():
    output = model(template, source, max_iteration=3)
    # chamfer_distance returns a tuple; its first entry is the distance value.
    loss_val = chamfer_distance(template, output['transformed_source'])
print(loss_val)

(tensor(0.0018, device='cuda:0', grad_fn=<AddBackward0>), None)


In [16]:
# Bring the template, the aligned source and the original source back to host memory for plotting.
temp_cpu, source_pred_cpu, source_cpu = (
    tensor.cpu() for tensor in (template, output['transformed_source'], source)
)

In [17]:
# Upper bound on the number of examples shown; the loop never indexes past the batch.
MAX_EXAMPLES = 8

# Convert each tensor to NumPy once instead of on every loop iteration.
template_np = temp_cpu.detach().numpy()
source_pred_np = source_pred_cpu.detach().numpy()
source_np = source_cpu.detach().numpy()

for i in range(min(MAX_EXAMPLES, template_np.shape[0])):
    template_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(template_np[i, :, :]))
    template_cloud.paint_uniform_color(np.array([0, 0, 1]))  # blue: template
    source_pred_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(source_pred_np[i, :, :]))
    source_pred_cloud.paint_uniform_color(np.array([1, 0, 0]))  # red: aligned source
    source_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(source_np[i, :, :]))
    source_cloud.paint_uniform_color(np.array([0, 1, 0]))  # green: original source
    o3d.visualization.draw_geometries([template_cloud, source_pred_cloud, source_cloud])