In [20]:
import os
import h5py
import numpy as np
from torch.utils.data import Dataset, random_split
import torch
import torch.nn as nn
import loralib as lora
import torch.nn.functional as F
from torchmetrics import JaccardIndex
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.typing import WITH_TORCH_CLUSTER
from torch_geometric.data import Data

from pyg_pointnet2 import PyGPointNet2NoColorLoRa
from pc_dataset import H5PCDataset


# Stop early when the optional 'torch-cluster' extension is unavailable.
# An exception is raised instead of calling quit(): inside a Jupyter kernel,
# quit is IPython's exit autocall and its positional argument is interpreted
# as the keep-kernel flag, so the message would never reach the reader.
if not WITH_TORCH_CLUSTER:
    raise ImportError("This example requires 'torch-cluster'")

In [2]:
# take out colors
class SelectLast3Features:
    """Callable transform that trims ``data.x`` to its three trailing columns.

    Objects whose ``x`` attribute is ``None`` are passed through untouched.
    The input object is modified in place and also returned.
    """

    def __call__(self, data):
        features = data.x
        # Nothing to trim when the sample carries no feature matrix.
        if features is None:
            return data
        data.x = features[:, -3:]
        return data

# Per-sample transform: random jitter, a random rotation of up to 15 degrees
# around each axis, then the colour-stripping step defined above.
transform = T.Compose([
    T.RandomJitter(0.01),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2),
    SelectLast3Features()
])

full_dataset = H5PCDataset(file_path='../docs/sim_pc_dataset.h5', transform=transform)

# Define split sizes (80% training and 20% held out for testing)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size

# Randomly split the dataset. The generator is seeded so that the same samples
# end up in the held-out subset after every kernel restart; without it the
# saved LoRA weights could not be re-evaluated on the split they were tested on.
SPLIT_SEED = 42
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): both subsets wrap the same full_dataset object, so the test
# subset presumably goes through the random jitter/rotations as well -- confirm
# against H5PCDataset and consider an augmentation-free transform for testing.

In [3]:
# Show the first training sample, then the class count exposed by the dataset
# wrapped by the training subset (one value per output line).
print(train_dataset[0], train_dataset.dataset.num_classes, sep="\n")

Data(x=[4096, 3], y=[4096], pos=[4096, 3])
13


In [23]:
batch_size = 64
num_workers = 0

# Options common to both loaders; only the shuffling flag differs between them.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [5]:
# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the 13-class network and move it onto the selected device.
model = PyGPointNet2NoColorLoRa(num_classes=13)
model = model.to(device)

In [16]:
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PyGPointNet2NoColorLoRa(num_classes=13).to(device)

# Load pretrained weights. strict=False tolerates keys absent from the
# checkpoint (presumably the newly added LoRA matrices -- confirm), but it also
# hides every other mismatch, so the returned key report is checked below.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True where the installed torch allows).
load_report = model.load_state_dict(
    torch.load("checkpoints/pointnet2_s3dis_colorless_seg_x3_45_checkpoint.pth", map_location=device),
    strict=False,
)

# A non-LoRA parameter missing from the checkpoint would stay at its random
# initialisation and then be frozen below, so treat that as an error.
missing_base_keys = [key for key in load_report.missing_keys if "lora_" not in key]
if missing_base_keys:
    raise RuntimeError(
        f"Checkpoint is missing non-LoRA parameters: {missing_base_keys}"
    )
if load_report.unexpected_keys:
    print("Warning: checkpoint entries not used by the model:",
          list(load_report.unexpected_keys))

# Freeze all parameters except LoRA: only parameters whose name contains
# "lora_" (e.g. "lora_A" / "lora_B") remain trainable.
for name, param in model.named_parameters():
    param.requires_grad = "lora_" in name

In [17]:
# Put the model in evaluation mode. eval() returns the module itself, so as the
# last expression of the cell its layer summary is rendered as the cell output.
model.eval()

PyGPointNet2NoColorLoRa(
  (sa1_module): SAModule(
    (conv): PointNetConv(local_nn=Sequential(
      (0): Linear(in_features=6, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=128, bias=True)
    ), global_nn=None)
  )
  (sa2_module): SAModule(
    (conv): PointNetConv(local_nn=Sequential(
      (0): Linear(in_features=131, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=256, bias=True)
    ), global_nn=None)
  )
  (sa3_module): GlobalSAModule(
    (nn): Sequential(
      (0): Linear(in_features=259, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=1024, bias=True)
    )
  )
  (fp3_module): FPModule(
    (nn)

In [18]:
# List the names of every parameter that still requires gradients, as a sanity
# check that only the intended tensors will be updated.
trainable_params = [
    entry[0] for entry in model.named_parameters() if entry[1].requires_grad
]
print("Trainable parameters:", trainable_params)

# Hand the optimizer only the tensors that still require gradients.
optimizer = torch.optim.Adam(
    filter(lambda tensor: tensor.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=0.01,
)

Trainable parameters: ['sa1_module.conv.local_nn.0.lora_A', 'sa1_module.conv.local_nn.0.lora_B', 'sa1_module.conv.local_nn.2.lora_A', 'sa1_module.conv.local_nn.2.lora_B', 'sa1_module.conv.local_nn.4.lora_A', 'sa1_module.conv.local_nn.4.lora_B', 'sa2_module.conv.local_nn.0.lora_A', 'sa2_module.conv.local_nn.0.lora_B', 'sa2_module.conv.local_nn.2.lora_A', 'sa2_module.conv.local_nn.2.lora_B', 'sa2_module.conv.local_nn.4.lora_A', 'sa2_module.conv.local_nn.4.lora_B', 'sa3_module.nn.0.lora_A', 'sa3_module.nn.0.lora_B', 'sa3_module.nn.2.lora_A', 'sa3_module.nn.2.lora_B', 'sa3_module.nn.4.lora_A', 'sa3_module.nn.4.lora_B', 'fp3_module.nn.0.lora_A', 'fp3_module.nn.0.lora_B', 'fp3_module.nn.2.lora_A', 'fp3_module.nn.2.lora_B', 'fp2_module.nn.0.lora_A', 'fp2_module.nn.0.lora_B', 'fp2_module.nn.2.lora_A', 'fp2_module.nn.2.lora_B', 'fp1_module.nn.0.lora_A', 'fp1_module.nn.0.lora_B', 'fp1_module.nn.2.lora_A', 'fp1_module.nn.2.lora_B', 'fp1_module.nn.4.lora_A', 'fp1_module.nn.4.lora_B', 'mlp.0.lora_A

In [19]:
criterion = torch.nn.CrossEntropyLoss()

def train(log_every=10):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Prints the mean loss and the per-node
    accuracy of each window of ``log_every`` batches, plus one final line for
    a trailing, shorter window.

    Args:
        log_every: Number of batches per progress line. Defaults to 10, the
            interval that was previously hard-coded. Values below 1 disable
            the intermediate lines, leaving only the final summary.

    Returns:
        None. Progress is reported through ``print``.
    """
    model.train()

    total_loss = 0.0
    correct_nodes = total_nodes = window_batches = 0
    num_batches = len(train_loader)
    step = 0
    for step, data in enumerate(train_loader, start=1):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # Read the class count from the logits instead of hard-coding 13, so
        # the loop keeps working if the model's number of classes changes.
        loss = criterion(out.view(-1, out.size(-1)), data.y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes
        window_batches += 1

        if window_batches == log_every:
            print(f'[{step}/{num_batches}] Loss: {total_loss / window_batches:.4f} '
                  f'Train Acc: {correct_nodes / total_nodes:.4f}')
            total_loss = 0.0
            correct_nodes = total_nodes = window_batches = 0

    # Report the trailing batches that did not fill a complete window.
    if total_nodes > 0:
        print(f'[{step}/{num_batches}] Loss: {total_loss / window_batches:.4f} '
              f'Train Acc: {correct_nodes / total_nodes:.4f}')

In [21]:

@torch.no_grad()
def test(loader):
    """Return the multiclass Jaccard index of the notebook-level ``model``.

    Args:
        loader: Data loader whose ``dataset`` attribute wraps another dataset
            exposing ``num_classes`` (read as ``loader.dataset.dataset``).

    Returns:
        The accumulated Jaccard index over all batches, as a Python float.
    """
    model.eval()

    n_classes = loader.dataset.dataset.num_classes
    metric = JaccardIndex(task="multiclass", num_classes=n_classes)
    metric = metric.to(device)

    for batch in loader:
        batch = batch.to(device)
        # Highest-scoring class per point.
        predictions = model(batch).argmax(dim=-1)
        metric.update(predictions, batch.y)

    score = metric.compute()
    return score.item()


In [24]:
# Train
import time

NUM_EPOCHS = 45

begin_time = time.perf_counter()
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.perf_counter()
    train()
    iou = test(test_loader)
    epoch_time = time.perf_counter() - start_time
    # Report the epoch duration on every epoch. It used to be printed only by
    # a statement after the loop, which also repeated the last epoch's line.
    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}, Time: {epoch_time:.2f}s')
total_time = time.perf_counter() - begin_time
print(f'Training time: {(total_time)/60:.2f}m')



[7/7] Loss: 2.5402 Train Acc: 0.0728
Epoch: 01, Test IoU: 0.0054
[7/7] Loss: 2.5244 Train Acc: 0.1614
Epoch: 02, Test IoU: 0.0393
[7/7] Loss: 2.4909 Train Acc: 0.3302
Epoch: 03, Test IoU: 0.0393
[7/7] Loss: 2.4235 Train Acc: 0.3876
Epoch: 04, Test IoU: 0.0393
[7/7] Loss: 2.2363 Train Acc: 0.3902
Epoch: 05, Test IoU: 0.0393
[7/7] Loss: 1.8473 Train Acc: 0.3914
Epoch: 06, Test IoU: 0.0393
[7/7] Loss: 1.8038 Train Acc: 0.3870
Epoch: 07, Test IoU: 0.0393
[7/7] Loss: 1.7454 Train Acc: 0.3503
Epoch: 08, Test IoU: 0.0393
[7/7] Loss: 1.7332 Train Acc: 0.3261
Epoch: 09, Test IoU: 0.0393
[7/7] Loss: 1.6532 Train Acc: 0.3479
Epoch: 10, Test IoU: 0.0393
[7/7] Loss: 1.6404 Train Acc: 0.3800
Epoch: 11, Test IoU: 0.0393
[7/7] Loss: 1.6674 Train Acc: 0.3845
Epoch: 12, Test IoU: 0.0393
[7/7] Loss: 1.6904 Train Acc: 0.3805
Epoch: 13, Test IoU: 0.0393
[7/7] Loss: 1.6617 Train Acc: 0.3729
Epoch: 14, Test IoU: 0.0393
[7/7] Loss: 1.6064 Train Acc: 0.3692
Epoch: 15, Test IoU: 0.0393
[7/7] Loss: 1.6412 Train 

In [25]:
# Extract the LoRA state dict from the model and write it to disk.
lora_weights = lora.lora_state_dict(model)
torch.save(lora_weights, "checkpoints/smartlab_lora_weights_x3_45_20250416.pth")