In [1]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset

class PreprocessedScrewDataset(Dataset):
    def __init__(self, root_dir, augment=False):
        self.files = sorted(glob.glob(os.path.join(root_dir, "*.npz")))
        self.augment = augment

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        data = np.load(self.files[idx])
        points = data["points"]  # [8192, 6]
        labels = data["labels"]  # [8192]

        if self.augment:
            points = self._augment(points)

        return torch.from_numpy(points).float(), torch.from_numpy(labels).long()

    def _augment(self, points):
        xyz = points[:, :3]
        normals = points[:, 3:]

        # Jitter
        xyz += np.random.normal(0, 0.005, xyz.shape)

        # Rotation around Z-axis
        theta = np.random.uniform(0, 2*np.pi)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot = np.array([
            [cos_t, -sin_t, 0],
            [sin_t,  cos_t, 0],
            [0,      0,     1]
        ])
        xyz = xyz @ rot.T
        normals = normals @ rot.T

        return np.hstack((xyz, normals))



In [2]:
from torch.utils.data import DataLoader

train_dataset = PreprocessedScrewDataset("preprocessed_data", augment=True)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)


In [3]:
import sys
sys.path.append("PointNet_Pointnet2")

from models.pointnet2_sem_seg import get_model

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate model: 6 input channels (xyz + normals), 2 output classes (screw vs bg)
model = get_model(num_classes=2).to(device)


In [4]:
criterion = nn.NLLLoss()  # since model uses log_softmax
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [5]:
from tqdm import tqdm
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)  # use 0 for debugging


num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = total = 0

    for points, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        points = points.to(device)         # [B, N, 6]
        labels = labels.to(device)         # [B, N]
        points = points.transpose(1, 2)     # [B, 6, N]

        optimizer.zero_grad()
        outputs, _ = model(points)         # [B, N, 2]
        loss = criterion(outputs.reshape(-1, 2), labels.reshape(-1))  # ✅ Fixed using reshape

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = outputs.argmax(dim=2)      # [B, N]
        correct += (preds == labels).sum().item()
        total += labels.numel()

    acc = correct / total * 100
    print(f" Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f}, Accuracy: {acc:.2f}%")

 


Epoch 1/20: 100%|██████████| 38/38 [03:40<00:00,  5.80s/it]


 Epoch 1/20 - Loss: 12.9848, Accuracy: 83.58%


Epoch 2/20: 100%|██████████| 38/38 [03:48<00:00,  6.02s/it]


 Epoch 2/20 - Loss: 7.3571, Accuracy: 91.63%


Epoch 3/20: 100%|██████████| 38/38 [03:46<00:00,  5.97s/it]


 Epoch 3/20 - Loss: 6.3225, Accuracy: 92.96%


Epoch 4/20: 100%|██████████| 38/38 [03:49<00:00,  6.03s/it]


 Epoch 4/20 - Loss: 6.3566, Accuracy: 93.26%


Epoch 5/20: 100%|██████████| 38/38 [03:48<00:00,  6.03s/it]


 Epoch 5/20 - Loss: 5.2347, Accuracy: 94.47%


Epoch 6/20: 100%|██████████| 38/38 [03:49<00:00,  6.04s/it]


 Epoch 6/20 - Loss: 4.6914, Accuracy: 94.96%


Epoch 7/20: 100%|██████████| 38/38 [03:46<00:00,  5.95s/it]


 Epoch 7/20 - Loss: 4.1274, Accuracy: 95.70%


Epoch 8/20: 100%|██████████| 38/38 [03:45<00:00,  5.95s/it]


 Epoch 8/20 - Loss: 3.8067, Accuracy: 95.98%


Epoch 9/20: 100%|██████████| 38/38 [03:38<00:00,  5.74s/it]


 Epoch 9/20 - Loss: 3.5093, Accuracy: 96.24%


Epoch 10/20: 100%|██████████| 38/38 [03:49<00:00,  6.03s/it]


 Epoch 10/20 - Loss: 3.3685, Accuracy: 96.59%


Epoch 11/20: 100%|██████████| 38/38 [03:45<00:00,  5.94s/it]


 Epoch 11/20 - Loss: 3.0374, Accuracy: 97.02%


Epoch 12/20: 100%|██████████| 38/38 [03:48<00:00,  6.01s/it]


 Epoch 12/20 - Loss: 2.5820, Accuracy: 97.33%


Epoch 13/20: 100%|██████████| 38/38 [03:47<00:00,  6.00s/it]


 Epoch 13/20 - Loss: 2.3222, Accuracy: 97.64%


Epoch 14/20: 100%|██████████| 38/38 [03:41<00:00,  5.82s/it]


 Epoch 14/20 - Loss: 2.1488, Accuracy: 97.84%


Epoch 15/20: 100%|██████████| 38/38 [03:37<00:00,  5.72s/it]


 Epoch 15/20 - Loss: 1.7954, Accuracy: 98.18%


Epoch 16/20: 100%|██████████| 38/38 [03:47<00:00,  5.98s/it]


 Epoch 16/20 - Loss: 1.5774, Accuracy: 98.46%


Epoch 17/20: 100%|██████████| 38/38 [03:35<00:00,  5.67s/it]


 Epoch 17/20 - Loss: 1.5683, Accuracy: 98.45%


Epoch 18/20: 100%|██████████| 38/38 [03:45<00:00,  5.95s/it]


 Epoch 18/20 - Loss: 1.3714, Accuracy: 98.69%


Epoch 19/20: 100%|██████████| 38/38 [03:32<00:00,  5.59s/it]


 Epoch 19/20 - Loss: 1.1620, Accuracy: 98.89%


Epoch 20/20: 100%|██████████| 38/38 [03:45<00:00,  5.94s/it]

 Epoch 20/20 - Loss: 1.0275, Accuracy: 99.01%





In [6]:
torch.save(model.state_dict(), "pointnet2_screw_segmentation_final.pth")

