In [21]:
import torch
import torch.nn as nn
import torch.profiler
from pyg_pointnet2 import PyGPointNet2NoColor
import loralib as lora
import open3d as o3d
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import time
from pc_label_map import color_map

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# File path of pretrained model
pretrained_path = "checkpoints/pointnet2_s3dis_transform_seg_x3_45_checkpoint.pth"

# File path of LoRA weights
lora_path = "checkpoints/smartlab_lora_weights_x3_50_20250831.pth"

# File path of point cloud to be segmented
pcd_path = "C:/Users/yanpe/OneDrive - Metropolia Ammattikorkeakoulu Oy/Research/data/smartlab/SmartLab_2024_E57_Single_5mm.pcd"

In [23]:
def apply_lora(module, r=8, alpha=16, verbose=False):
    """
    Recursively replaces Linear layers with LoRA-enabled layers.
    Handles custom modules like MLP, SAModule, and FPModule.
    """
    # Special handling for MLP modules which likely contain multiple linear layers
    if hasattr(module, '__class__') and module.__class__.__name__ == 'MLP':
        if verbose:
            print(f"Processing MLP module: {module}")
        # Handle linear layers inside the MLP
        if hasattr(module, 'lins'):
            for i, lin in enumerate(module.lins):
                # Check if the layer has the necessary attributes of a Linear layer
                if hasattr(lin, 'in_channels') and hasattr(lin, 'out_channels') and hasattr(lin, 'weight'):
                    lora_layer = lora.Linear(
                        in_features=lin.in_channels,
                        out_features=lin.out_channels,
                        r=r,
                        lora_alpha=alpha
                    )
                    lora_layer.weight.data = lin.weight.data.clone()
                    if hasattr(lin, 'bias') and lin.bias is not None:
                        lora_layer.bias.data = lin.bias.data.clone()
                    module.lins[i] = lora_layer
                    if verbose:
                        print(f"Replaced MLP.lins[{i}] with LoRA ({lin.__class__.__name__})")

    # Process all named children modules
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # Replace this Linear layer with LoRA
            lora_layer = lora.Linear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=alpha
            )
            
            # Copy original weights and biases
            lora_layer.weight.data = child.weight.data.clone()
            if child.bias is not None:
                lora_layer.bias.data = child.bias.data.clone()
            
            # Replace the layer
            setattr(module, name, lora_layer)
            if verbose:
                print(f"Replaced {name} with LoRA")
        elif isinstance(child, nn.Sequential):
            # Special handling for Sequential containers
            for idx, layer in enumerate(child):
                if isinstance(layer, nn.Linear):
                    lora_layer = lora.Linear(
                        in_features=layer.in_features,
                        out_features=layer.out_features,
                        r=r,
                        lora_alpha=alpha
                    )
                    lora_layer.weight.data = layer.weight.data.clone()
                    if layer.bias is not None:
                        lora_layer.bias.data = layer.bias.data.clone()
                    child[idx] = lora_layer
                    if verbose:
                        print(f"Replaced {name}[{idx}] with LoRA")
                else:
                    # Recursively apply to layers within Sequential that aren't Linear
                    apply_lora(layer, r, alpha, verbose)
        else:
            # Recursively apply to all other nested submodules
            apply_lora(child, r, alpha, verbose)

In [24]:
# Initialize model
model = PyGPointNet2NoColor(num_classes=13).to(device)
checkpoint = torch.load(pretrained_path, map_location=device)
# Extract the model state dictionary
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict, strict=False)
# Apply LoRA layers
apply_lora(model, r=8, alpha=16, verbose=True)

Processing MLP module: MLP(6, 64, 64, 128)
Replaced MLP.lins[0] with LoRA (Linear)
Replaced MLP.lins[1] with LoRA (Linear)
Replaced MLP.lins[2] with LoRA (Linear)
Replaced 0 with LoRA
Replaced 1 with LoRA
Replaced 2 with LoRA
Processing MLP module: MLP(131, 128, 128, 256)
Replaced MLP.lins[0] with LoRA (Linear)
Replaced MLP.lins[1] with LoRA (Linear)
Replaced MLP.lins[2] with LoRA (Linear)
Replaced 0 with LoRA
Replaced 1 with LoRA
Replaced 2 with LoRA
Processing MLP module: MLP(259, 256, 512, 1024)
Replaced MLP.lins[0] with LoRA (Linear)
Replaced MLP.lins[1] with LoRA (Linear)
Replaced MLP.lins[2] with LoRA (Linear)
Replaced 0 with LoRA
Replaced 1 with LoRA
Replaced 2 with LoRA
Processing MLP module: MLP(1280, 256, 256)
Replaced MLP.lins[0] with LoRA (Linear)
Replaced MLP.lins[1] with LoRA (Linear)
Replaced 0 with LoRA
Replaced 1 with LoRA
Processing MLP module: MLP(384, 256, 128)
Replaced MLP.lins[0] with LoRA (Linear)
Replaced MLP.lins[1] with LoRA (Linear)
Replaced 0 with LoRA
Repla

In [25]:
# Load LoRA weights
model.load_state_dict(torch.load(lora_path), strict=False)
model.to(device)
model.eval()

PyGPointNet2NoColor(
  (sa1_module): SAModule(
    (conv): PointNetConv(local_nn=MLP(6, 64, 64, 128), global_nn=None)
  )
  (sa2_module): SAModule(
    (conv): PointNetConv(local_nn=MLP(131, 128, 128, 256), global_nn=None)
  )
  (sa3_module): GlobalSAModule(
    (nn): MLP(259, 256, 512, 1024)
  )
  (fp3_module): FPModule(
    (nn): MLP(1280, 256, 256)
  )
  (fp2_module): FPModule(
    (nn): MLP(384, 256, 128)
  )
  (fp1_module): FPModule(
    (nn): MLP(131, 128, 128, 128)
  )
  (mlp): MLP(128, 128, 128, 13)
  (lin1): Linear(in_features=128, out_features=128, bias=True)
  (lin2): Linear(in_features=128, out_features=128, bias=True)
  (lin3): Linear(in_features=128, out_features=13, bias=True)
)

In [26]:
# Freeze all parameters except LoRA
for param in model.parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if "lora_" in name:
        param.requires_grad = True

In [27]:
# Verify
trainable_params = [n for n, p in model.named_parameters() if p.requires_grad]
from pprint import pprint
print(f"Trainable LoRA parameters: {len(trainable_params)}")
pprint(trainable_params)

Trainable LoRA parameters: 44
['sa1_module.conv.local_nn.lins.0.lora_A',
 'sa1_module.conv.local_nn.lins.0.lora_B',
 'sa1_module.conv.local_nn.lins.1.lora_A',
 'sa1_module.conv.local_nn.lins.1.lora_B',
 'sa1_module.conv.local_nn.lins.2.lora_A',
 'sa1_module.conv.local_nn.lins.2.lora_B',
 'sa2_module.conv.local_nn.lins.0.lora_A',
 'sa2_module.conv.local_nn.lins.0.lora_B',
 'sa2_module.conv.local_nn.lins.1.lora_A',
 'sa2_module.conv.local_nn.lins.1.lora_B',
 'sa2_module.conv.local_nn.lins.2.lora_A',
 'sa2_module.conv.local_nn.lins.2.lora_B',
 'sa3_module.nn.lins.0.lora_A',
 'sa3_module.nn.lins.0.lora_B',
 'sa3_module.nn.lins.1.lora_A',
 'sa3_module.nn.lins.1.lora_B',
 'sa3_module.nn.lins.2.lora_A',
 'sa3_module.nn.lins.2.lora_B',
 'fp3_module.nn.lins.0.lora_A',
 'fp3_module.nn.lins.0.lora_B',
 'fp3_module.nn.lins.1.lora_A',
 'fp3_module.nn.lins.1.lora_B',
 'fp2_module.nn.lins.0.lora_A',
 'fp2_module.nn.lins.0.lora_B',
 'fp2_module.nn.lins.1.lora_A',
 'fp2_module.nn.lins.1.lora_B',
 'fp1_

In [30]:
# Load point cloud
pcd = o3d.io.read_point_cloud(pcd_path)
# Move the point cloud to its min(x,y,z) corner 
def move_to_corner(points):     
    min_xyz = points.min(axis=0)    
    moved_points = points - min_xyz    
    return moved_points

moved_points = move_to_corner(np.array(pcd.points))
pcd.points = o3d.utility.Vector3dVector(moved_points)

# Downsample the point cloud with a voxel of 0.03
downpcd = pcd.voxel_down_sample(voxel_size=0.03)
# Check size of point cloud
print(len(downpcd.points))

866900


In [31]:
# Normalized coordinates as x features
def normalize_points_corner(points):    
    min_vals = np.min(points, axis=0)
    shifted_points = points - min_vals    
    max_vals = np.max(shifted_points, axis=0)
    scale = max_vals.copy()    
    scale[scale == 0] = 1    
    normalized_points = shifted_points / scale
    return normalized_points

normalized = normalize_points_corner(np.array(downpcd.points))

In [32]:
# Extract coordinates and colors from the point cloud
down_points = torch.tensor(np.array(downpcd.points), dtype=torch.float32)  
down_colors = torch.tensor(np.array(downpcd.colors), dtype=torch.float32)
down_normalized = torch.tensor(normalized, dtype=torch.float32)

In [33]:
# Create a Data object with x (3 features) and pos (coordinates)
data = Data(x=down_normalized, pos=down_points)
data = data.to(device)

In [None]:
# Load dataset for inference
dataset = [data]  
num_workers = 0
batch_size = 32
# Create a DataLoader (batch_size can be adjusted as needed)
custom_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers) #, pin_memory=True

In [35]:
# Segmentation
model.eval()

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:

    with torch.no_grad():
        start_time = time.time()
        for data in custom_loader:
            data = data.to(device)
            with torch.amp.autocast("cuda"):
                predictions = model(data)
            labels = predictions.argmax(dim=-1)
            unique_labels, label_counts = torch.unique(labels, return_counts=True)
            result_labels = torch.stack((unique_labels, label_counts), dim=1).cpu()
            print("Label counts:")
            print(result_labels)
        end_time = time.time()
        print(f"Total inference time: {end_time - start_time:.4f} seconds")  

# Prediction results    
print(prof.key_averages().table(sort_by="cuda_time_total"))

Label counts:
tensor([[     0, 110516],
        [     1, 123230],
        [     2,  75714],
        [     3,   3887],
        [     4, 201839],
        [     6,   9118],
        [     7,   9390],
        [     8,   2160],
        [     9,      2],
        [    10, 248825],
        [    12,  82219]])
Total inference time: 242.2589 seconds
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                     torch_cluster::fps         0

In [None]:
# Assign predicted colors to the point cloud
predicted_colors = color_map[labels.cpu().numpy()]  # Shape: [num_points, 3]
downpcd.colors = o3d.utility.Vector3dVector(predicted_colors)

In [42]:
# Visualize the point cloud with colored labels
o3d.visualization.draw_geometries([downpcd])

In [None]:
# Save the point cloud to a file
save_path = "C:/Users/yanpe/OneDrive - Metropolia Ammattikorkeakoulu Oy/Research/data/smartlab/labelled/Smartlab_pcd_lora_label_pointnet2_x3_0.03_20250831.ply"
o3d.io.write_point_cloud(save_path, downpcd)

True