In [None]:
import time
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import random_split,Subset
import torch
import torch.nn as nn
import loralib as lora
import torch.nn.functional as F
from torchmetrics import JaccardIndex
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.typing import WITH_TORCH_CLUSTER

from pyg_pointnet2 import PyGPointNet2NoColor
from pc_dataset import H5PCDataset

# Fail fast when the optional 'torch-cluster' dependency is unavailable.
# An exception (rather than quit()) shows the message in the cell output and
# stops "Run All" at this cell instead of asking the kernel to exit.
if not WITH_TORCH_CLUSTER:
    raise ImportError("This example requires 'torch-cluster'")

In [3]:
# Empty the CUDA cache: release cached, currently unused GPU memory held by
# PyTorch's allocator before the dataset and model are loaded.
torch.cuda.empty_cache()

In [5]:
class SelectLast3Features:
    """Transform that keeps only the last three feature columns of ``data.x``.

    Used as a pre-transform to "take out colors" (per the original note);
    presumably the colour channels precede the kept columns -- verify against
    the dataset layout. Samples without features (``data.x is None``) are
    passed through untouched. The sample is modified in place and returned.
    """

    def __call__(self, data):
        features = data.x
        if features is None:
            # Nothing to slice; hand the sample back unchanged.
            return data
        data.x = features[:, -3:]
        return data

# transform and pre_transform

# Train-time augmentation: small random jitter of the point positions followed
# by a random rotation of up to 15 degrees about each of the three axes.
transform = T.Compose([
    T.RandomJitter(0.01),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2)
])

# Pre-transform handed to the dataset: keep only the last 3 feature columns.
pre_transform =  T.Compose([
    #T.NormalizeScale(),
    SelectLast3Features()
    ])

#h5_file_path = "../docs/sim_pc_dataset_moved.h5" # local file path
h5_file_path= "../docs/smartLab_sim_dataset.h5" # local file path simulated pc with noise
#h5_file_path ='/scratch/project_2013104/datasets/sim_pc_dataset.h5' # csc file path

full_dataset = H5PCDataset(file_path=h5_file_path, pre_transform=pre_transform)

# Define split sizes: 80% training and 20% held-out test
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size

# Randomly split the dataset with a fixed seed so the same samples land in the
# test split on every kernel restart. Without it, metrics are not comparable
# between runs and test samples can leak into training when a previously
# fine-tuned checkpoint is trained further.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_subset, test_subset = random_split(
    full_dataset, [train_size, test_size], generator=split_generator
)

In [6]:
#Wrap train_subset in AugmentedSubset
class AugmentedSubset(Subset):
    """``Subset`` view that applies ``transform`` to every sample it returns.

    Args:
        subset: An existing ``Subset`` (e.g. one half of ``random_split``);
            its underlying dataset and indices are reused.
        transform: Callable applied to each fetched sample; its return value
            is what the caller receives.
    """

    def __init__(self, subset, transform):
        super().__init__(subset.dataset, subset.indices)
        self.transform = transform

    def __getitem__(self, idx):
        return self.transform(super().__getitem__(idx))

    def __getitems__(self, indices):
        # Recent PyTorch versions give ``Subset`` a batched ``__getitems__``
        # and the DataLoader prefers it over ``__getitem__``. Without this
        # override the loader would fetch raw samples and silently skip the
        # augmentation. Route batched access through ``__getitem__``.
        return [self[i] for i in indices]

# Training samples pass through the random augmentation `transform`;
# the held-out split is used as-is, without augmentation.
train_dataset = AugmentedSubset(train_subset, transform)
test_dataset = test_subset 

In [5]:
# Sanity check: show the first sample and the class count for each split
# (training split first, then the held-out split).
for split in (train_dataset, test_dataset):
    print(split[0])
    print(split.dataset.num_classes)

Data(x=[4096, 3], y=[4096], pos=[4096, 3])
13
Data(x=[4096, 3], y=[4096], pos=[4096, 3])
13


In [7]:
batch_size = 32
num_workers = 0

# Settings shared by both loaders; only the shuffling differs.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [8]:
# Initialize model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PyGPointNet2NoColor(num_classes=13).to(device)

# Checkpoint whose weights initialise the model before fine-tuning
# (presumably S3DIS-pretrained, judging by the file name -- confirm).
model_file_path = "checkpoints/pointnet2_s3dis_transform_seg_x3_45_checkpoint.pth"
# Load the checkpoint dictionary
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(model_file_path, map_location=device)
# Extract the model state dictionary
model_state_dict = checkpoint['model_state_dict']

# strict=True requires the checkpoint keys to match the model's keys exactly;
# as the last expression of the cell, the match report is displayed.
model.load_state_dict(model_state_dict, strict=True)  

<All keys matched successfully>

In [9]:
# Switch the model to evaluation mode; being the last expression of the cell,
# the returned module is displayed, which prints the layer structure.
model.eval()

PyGPointNet2NoColor(
  (sa1_module): SAModule(
    (conv): PointNetConv(local_nn=MLP(6, 64, 64, 128), global_nn=None)
  )
  (sa2_module): SAModule(
    (conv): PointNetConv(local_nn=MLP(131, 128, 128, 256), global_nn=None)
  )
  (sa3_module): GlobalSAModule(
    (nn): MLP(259, 256, 512, 1024)
  )
  (fp3_module): FPModule(
    (nn): MLP(1280, 256, 256)
  )
  (fp2_module): FPModule(
    (nn): MLP(384, 256, 128)
  )
  (fp1_module): FPModule(
    (nn): MLP(131, 128, 128, 128)
  )
  (mlp): MLP(128, 128, 128, 13)
  (lin1): Linear(in_features=128, out_features=128, bias=True)
  (lin2): Linear(in_features=128, out_features=128, bias=True)
  (lin3): Linear(in_features=128, out_features=13, bias=True)
)

In [10]:
# Adam over every model parameter (full fine-tuning, nothing is frozen).
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 penalty added to
# the gradients; if decoupled weight decay was intended, that is what
# torch.optim.AdamW provides -- confirm which one is wanted.
optimizer = torch.optim.Adam(
    model.parameters(),  # All parameters are trainable
    lr=1e-4,
    weight_decay=0.01
)

In [11]:
criterion = torch.nn.CrossEntropyLoss()

def train(log_every=10):
    """Run one training pass over ``train_loader`` with periodic progress logs.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. After every ``log_every`` batches, and
    once more for any shorter final window, it prints the mean loss and the
    point-wise accuracy accumulated since the previous print.

    Args:
        log_every: Positive number of batches per logging window
            (default 10, the previously hard-coded interval).

    Returns:
        None. Progress is reported via ``print`` only.
    """
    model.train()

    num_batches = len(train_loader)
    total_loss = correct_nodes = total_nodes = 0
    batches_in_window = 0
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # Flatten to (points, classes) using the model's own output width
        # instead of a hard-coded class count.
        loss = criterion(out.view(-1, out.size(-1)), data.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes
        batches_in_window += 1

        # Flush on a full window, or on the last batch so a shorter trailing
        # window is still reported (averaged over its actual batch count).
        if batches_in_window == log_every or i + 1 == num_batches:
            print(f'[{i+1}/{num_batches}] Loss: {total_loss / batches_in_window:.4f} '
                  f'Train Acc: {correct_nodes / total_nodes:.4f}')
            total_loss = correct_nodes = total_nodes = 0
            batches_in_window = 0

In [12]:
def train_one_epoch():
    """Train for one full pass over ``train_loader`` and summarise it.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where ``epoch_loss`` is the mean of
        the per-batch losses and ``epoch_acc`` is the fraction of points whose
        arg-max prediction equals the label.
    """
    model.train()
    running_loss = correct = total = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # cross_entropy matches the criterion used by train(). It gives the
        # same value as nll_loss when `out` already holds log-probabilities
        # (log-softmax is idempotent) and is also correct if the model
        # returns raw logits, where nll_loss would be wrong.
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        correct      += out.argmax(1).eq(data.y).sum().item()
        total        += data.num_nodes

    # Average loss & accuracy for this epoch
    epoch_loss = running_loss / len(train_loader)
    epoch_acc  = correct / total

    return epoch_loss, epoch_acc

In [13]:

@torch.no_grad()
def test(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` and return its IoU.

    The class count is read from ``loader.dataset.dataset.num_classes``, so
    the loader is expected to wrap a subset of a dataset exposing that
    attribute. Predictions are the arg-max over the last output dimension.

    Returns:
        float: Multiclass Jaccard index accumulated over all batches.
    """
    model.eval()
    n_classes = loader.dataset.dataset.num_classes
    iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass").to(device)

    for batch in loader:
        batch = batch.to(device)
        predictions = model(batch).argmax(dim=-1)
        iou_metric.update(predictions, batch.y)

    return iou_metric.compute().item()


In [14]:
# Per-epoch metric logs: training loss, training accuracy, test IoU and
# wall-clock seconds. Each is an independent, initially empty list.
loss_history, acc_history, iou_history, time_history = [], [], [], []

In [15]:
for epoch in range(1, 51):
    # The timer spans both the training pass and the evaluation below.
    start_time = time.perf_counter()

    # Optimise for one epoch and log the training metrics.
    loss, acc = train_one_epoch()
    loss_history += [loss]
    acc_history += [acc]

    # Score the held-out split after this epoch's updates.
    iou = test(test_loader)
    iou_history += [iou]

    epoch_time = time.perf_counter() - start_time
    time_history += [epoch_time]

    # One summary line per epoch.
    summary = (f"Epoch {epoch:02d} | "
               f"Loss: {loss:.4f} | "
               f"Acc: {acc:.4f} | "
               f"IoU: {iou:.4f} | "
               f"Time: {epoch_time:.2f}s")
    print(summary)

Epoch 01 | Loss: 2.0434 | Acc: 0.6557 | IoU: 0.2067 | Time: 9.94s
Epoch 02 | Loss: 1.4047 | Acc: 0.7165 | IoU: 0.2130 | Time: 3.35s
Epoch 03 | Loss: 1.1189 | Acc: 0.7526 | IoU: 0.2171 | Time: 3.35s
Epoch 04 | Loss: 1.0157 | Acc: 0.7649 | IoU: 0.2265 | Time: 3.37s
Epoch 05 | Loss: 0.8979 | Acc: 0.7863 | IoU: 0.2214 | Time: 3.35s
Epoch 06 | Loss: 0.8626 | Acc: 0.7931 | IoU: 0.2268 | Time: 3.49s
Epoch 07 | Loss: 0.8247 | Acc: 0.7965 | IoU: 0.2202 | Time: 3.38s
Epoch 08 | Loss: 0.7027 | Acc: 0.8037 | IoU: 0.2356 | Time: 3.42s
Epoch 09 | Loss: 0.8251 | Acc: 0.8072 | IoU: 0.2270 | Time: 3.40s
Epoch 10 | Loss: 0.7522 | Acc: 0.8038 | IoU: 0.2249 | Time: 3.42s
Epoch 11 | Loss: 0.6342 | Acc: 0.8113 | IoU: 0.2227 | Time: 3.42s
Epoch 12 | Loss: 0.6527 | Acc: 0.8096 | IoU: 0.2207 | Time: 3.40s
Epoch 13 | Loss: 0.6660 | Acc: 0.8138 | IoU: 0.2224 | Time: 3.43s
Epoch 14 | Loss: 0.6269 | Acc: 0.8149 | IoU: 0.2258 | Time: 3.41s
Epoch 15 | Loss: 0.6267 | Acc: 0.8155 | IoU: 0.2179 | Time: 3.43s
Epoch 16 |

In [13]:
# Train
# Alternative loop built on train(), which prints its own running loss and
# accuracy; only the test IoU and timings are reported here. `time` is already
# imported in the notebook's first cell, so it is not re-imported.
begin_time = time.perf_counter()
for epoch in range(1, 101):
    start_time = time.perf_counter()
    train()
    iou = test(test_loader)
    # Per-epoch duration covers both training and evaluation.
    epoch_time = time.perf_counter() - start_time
    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}, Time: {epoch_time:.2f}s')
total_time = time.perf_counter() - begin_time
print(f'Training time: {total_time/60:.2f}m')


[10/13] Loss: 0.3839 Train Acc: 0.8706
[13/13] Loss: 0.5363 Train Acc: 0.8651
Epoch: 01, Test IoU: 0.4071, Time: 3.63s
[10/13] Loss: 0.3737 Train Acc: 0.8748
[13/13] Loss: 0.9227 Train Acc: 0.8271
Epoch: 02, Test IoU: 0.3919, Time: 3.31s
[10/13] Loss: 0.3804 Train Acc: 0.8725
[13/13] Loss: 0.7386 Train Acc: 0.8519
Epoch: 03, Test IoU: 0.4086, Time: 3.31s
[10/13] Loss: 0.4086 Train Acc: 0.8672
[13/13] Loss: 0.7433 Train Acc: 0.8758
Epoch: 04, Test IoU: 0.4099, Time: 3.24s
[10/13] Loss: 0.3814 Train Acc: 0.8712
[13/13] Loss: 0.7290 Train Acc: 0.8723
Epoch: 05, Test IoU: 0.4081, Time: 3.23s
[10/13] Loss: 0.3729 Train Acc: 0.8754
[13/13] Loss: 0.3335 Train Acc: 0.8728
Epoch: 06, Test IoU: 0.4165, Time: 3.37s
[10/13] Loss: 0.3956 Train Acc: 0.8661
[13/13] Loss: 0.8483 Train Acc: 0.8787
Epoch: 07, Test IoU: 0.4136, Time: 3.25s
[10/13] Loss: 0.3663 Train Acc: 0.8766
[13/13] Loss: 0.4353 Train Acc: 0.8532
Epoch: 08, Test IoU: 0.4157, Time: 3.27s
[10/13] Loss: 0.3604 Train Acc: 0.8755
[13/13] L

In [None]:
# Visualize with seaborn

# 1) Build DataFrame
# The history lists can differ in length when the training loop was
# interrupted mid-epoch, and pd.DataFrame rejects columns of unequal length.
# Keep only the epochs for which every metric was recorded.
num_epochs_logged = min(
    len(loss_history), len(acc_history), len(iou_history), len(time_history)
)
df = pd.DataFrame({
    'epoch': range(1, num_epochs_logged + 1),
    'Loss': loss_history[:num_epochs_logged],
    'Accuracy': acc_history[:num_epochs_logged],
    'IoU': iou_history[:num_epochs_logged],
    'Time (s)': time_history[:num_epochs_logged]
})

# 2) Melt to long form for seaborn (one row per epoch/metric pair)
df_long = df.melt(id_vars='epoch',
                  var_name='Metric',
                  value_name='Value')

# 3) Plot all metrics in one figure, using an explicit Axes so the figure
#    does not depend on pyplot's implicit "current axes" state.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(data=df_long, x='epoch', y='Value', hue='Metric', ax=ax)
ax.set_title('Training Metrics per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Metric Value')
ax.legend(title='Metric', loc='best')
fig.tight_layout()
plt.show()

In [23]:
# Persist the per-epoch metrics table as CSV; index=False leaves out the
# DataFrame index column.
df.to_csv("../docs/finetune_train_metrics.csv", index=False)

In [16]:
checkpoint_path = "checkpoints/smartlab_fine_tuning_transform_x3_50_20250831.pth"

# Bundle what is needed to resume: weights and optimizer state.
# NOTE(review): the epoch is recorded as a fixed 50 -- confirm it matches the
# number of epochs actually run before relying on it.
checkpoint_state = {
    'epoch': 50,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint_state, checkpoint_path)

print("Checkpoint saved successfully!")

Checkpoint saved successfully!


In [14]:
# Remove the notebook-level `model` name; its memory can only be reclaimed
# once no other reference to the module remains.
del model