In [None]:
import os
import shutil
import random
from pathlib import Path
import wandb

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Seed the global RNGs of Python's `random`, NumPy and PyTorch with the same
# fixed value so that code drawing from them starts from a known state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

In [3]:
src_root = "/kaggle/input/modelnet10-princeton-3d-object-dataset/ModelNet10"
dst_root = "/kaggle/working/ModelNet10"

os.makedirs(dst_root, exist_ok=True)
# Re-seed so the split does not depend on how many random numbers were drawn
# before this cell ran.
random.seed(42)

# Pool each class's original train/ and test/ meshes and re-split them
# 80% / 10% / remainder into train / val / test under dst_root.
#
# os.listdir() and Path.glob() return entries in arbitrary, filesystem-dependent
# order. Both listings are sorted so that, together with the fixed seed, the
# shuffle produces the same split on every run.
for class_name in sorted(os.listdir(src_root)):
    class_path = os.path.join(src_root, class_name)
    if not os.path.isdir(class_path):
        # Skip stray files that sit next to the class folders.
        continue

    train_files = sorted(Path(class_path, "train").glob("*.off"))
    test_files = sorted(Path(class_path, "test").glob("*.off"))
    all_files = train_files + test_files
    random.shuffle(all_files)

    total = len(all_files)
    train_split = int(0.8 * total)
    val_split = int(0.1 * total)

    # The test subset takes whatever is left, so rounding never drops a file.
    train_set = all_files[:train_split]
    val_set = all_files[train_split:train_split + val_split]
    test_set = all_files[train_split + val_split:]

    for subset, file_list in zip(["train", "val", "test"], [train_set, val_set, test_set]):
        subset_dir = os.path.join(dst_root, class_name, subset)
        os.makedirs(subset_dir, exist_ok=True)

        for file_path in file_list:
            shutil.copy(file_path, subset_dir)

In [4]:
class PointNetDataset(Dataset):
    """In-memory dataset of fixed-size, normalised point clouds read from OFF files.

    Expects the layout ``<base_path>/<class_name>/<mode>/*.off``. Every mesh file
    is parsed once at construction time; its vertices are centred, scaled by the
    largest vertex distance from the centroid, and randomly sampled to exactly
    ``n_points`` points (with replacement only when the mesh has fewer vertices).

    Args:
        base_path: Root directory containing one sub-directory per class.
        mode: Name of the split sub-directory to load. When it equals ``'train'``
            the point order of a sample is re-shuffled on every access.
        n_points: Number of points kept per cloud.
    """

    def __init__(self, base_path='/kaggle/working/ModelNet10', mode='train', n_points=1024):
        super().__init__()
        self.n_points = n_points
        self.base_path = base_path
        self.mode = mode
        # Fixed class-name -> integer label mapping; folders not listed here are ignored.
        self.class_idx = {'bathtub': 0, 'bed': 1, 'chair': 2, 'desk': 3, 'dresser': 4,
                          'monitor': 5,'night_stand': 6, 'sofa': 7, 'table': 8, 'toilet': 9}

        # List of (points ndarray of shape (n_points, dims), class id) tuples.
        self.samples = []
        self._prepare_dataset()

    def _prepare_dataset(self):
        """Parse every ``.off`` file of the selected split and cache the sampled clouds.

        Raises:
            ValueError: If a mesh file declares or contains no vertices.
        """
        # os.listdir() returns entries in arbitrary order; sorting keeps both the
        # sample order and the sequence of RNG draws used for sampling reproducible.
        for class_name in sorted(os.listdir(self.base_path)):
            if class_name not in self.class_idx:
                continue

            class_id = self.class_idx[class_name]

            class_path = os.path.join(self.base_path, class_name, self.mode)
            for file_name in sorted(os.listdir(class_path)):
                # Ignore anything that is not a mesh file (hidden files, sub-folders, ...).
                if not file_name.lower().endswith('.off'):
                    continue
                file_path = os.path.join(class_path, file_name)

                verts = self._read_off_vertices(file_path)
                if not verts:
                    raise ValueError(f"{file_path}: no vertices found")

                verts = self._normalize_point_cloud(verts)
                verts = self._sample_point_cloud(verts, self.n_points)

                self.samples.append((verts, class_id))

    @staticmethod
    def _read_off_vertices(file_path):
        """Return the vertex rows of an OFF file as a list of float lists.

        Supports the regular header (keyword on the first line, counts on the
        second) as well as the variant where the counts follow the ``OFF`` keyword
        on the same line (e.g. ``OFF490 518 0``), in which case the vertices start
        one line earlier.
        """
        with open(file_path, 'r') as f:
            lines = f.readlines()

        header = lines[0].strip()
        inline_counts = header[3:].split() if header.startswith('OFF') else []
        if inline_counts:
            counts, first_vertex = inline_counts, 1
        else:
            counts, first_vertex = lines[1].strip().split(), 2

        num_verts = int(counts[0])
        return [list(map(float, lines[i].strip().split()))
                for i in range(first_vertex, first_vertex + num_verts)]

    def _normalize_point_cloud(self, verts):
        """Centre the points on their centroid and scale them into the unit sphere."""
        verts = np.array(verts)
        centroid = np.mean(verts, axis=0)
        verts = verts - centroid
        furthest_distance = np.max(np.linalg.norm(verts, axis=1))
        # A degenerate cloud (all points identical) has radius 0; dividing would
        # fill the sample with NaNs, so leave it as the all-zero cloud instead.
        if furthest_distance > 0:
            verts = verts / furthest_distance
        return verts

    def _sample_point_cloud(self, verts, n_points):
        """Randomly pick ``n_points`` rows, with replacement only if there are too few."""
        return verts[np.random.choice(len(verts), n_points, replace=(len(verts) < n_points))]

    def __len__(self):
        """Number of cached samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(points, label)`` as a float32 tensor and a long scalar tensor."""
        verts, label = self.samples[idx]

        if self.mode == 'train':
            # Shuffle the point order on each access of a training sample.
            verts = verts[np.random.permutation(len(verts))]

        verts = torch.tensor(verts, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        return verts, label

In [5]:
# Build the three splits in train -> val -> test order.
train_set, val_set, test_set = (
    PointNetDataset(mode=split_name) for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; the evaluation loaders keep a fixed order.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(eval_set, batch_size=32, shuffle=False, num_workers=2)
    for eval_set in (val_set, test_set)
)

In [None]:
class InputTNet(nn.Module):
    """T-Net that regresses one 3x3 transform per input point cloud.

    Input:  ``(B, N, 3)`` points.
    Output: ``(B, 3, 3)`` transform, equal to the identity at initialisation.
    """

    def __init__(self):
        super().__init__()
        # Point-wise feature extractor (kernel size 1 => same weights for every point).
        self.shared_mlp = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        # Regression head. It stays a Sequential of three bare Linear layers so
        # the parameter names (fc.0 / fc.1 / fc.2) do not change; the
        # activations between the layers are applied in forward().
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Linear(512, 256),
            nn.Linear(256, 9)
        )

        # Zero weights + identity bias: the network starts out predicting the
        # identity transform regardless of its input.
        nn.init.constant_(self.fc[2].weight, 0)
        with torch.no_grad():
            self.fc[2].bias.copy_(torch.eye(3).view(9))

    def forward(self, x):
        batch_sz = x.shape[0]
        out = x.transpose(1, 2) # (B, 3, N)
        out = self.shared_mlp(out) # (B, 1024, N)
        out = torch.max(out, dim=2)[0] # (B, 1024)
        # Without a non-linearity the three Linear layers collapse into a single
        # affine map, so apply ReLU after the two hidden layers.
        out = F.relu(self.fc[0](out)) # (B, 512)
        out = F.relu(self.fc[1](out)) # (B, 256)
        out = self.fc[2](out) # (B, 9)
        out = out.reshape(batch_sz, 3, 3) # (B, 3, 3)
        return out

In [None]:
class FeatureTNet(nn.Module):
    """T-Net that regresses one 64x64 transform for per-point 64-d features.

    Input:  ``(B, N, 64)`` point features.
    Output: ``(B, 64, 64)`` transform, equal to the identity at initialisation.
    """

    def __init__(self):
        super().__init__()
        # Point-wise feature extractor (kernel size 1 => same weights for every point).
        self.shared_mlp = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        # Regression head. It stays a Sequential of three bare Linear layers so
        # the parameter names (fc.0 / fc.1 / fc.2) do not change; the
        # activations between the layers are applied in forward().
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Linear(512, 256),
            nn.Linear(256, 4096)
        )

        # Zero weights + identity bias: the network starts out predicting the
        # identity transform regardless of its input.
        nn.init.constant_(self.fc[2].weight, 0)
        with torch.no_grad():
            self.fc[2].bias.copy_(torch.eye(64).view(4096))

    def forward(self, x):
        batch_sz = x.shape[0]
        out = x.transpose(1, 2) # (B, 64, N)
        out = self.shared_mlp(out) # (B, 1024, N)
        out = torch.max(out, dim=2)[0] # (B, 1024)
        # Without a non-linearity the three Linear layers collapse into a single
        # affine map, so apply ReLU after the two hidden layers.
        out = F.relu(self.fc[0](out)) # (B, 512)
        out = F.relu(self.fc[1](out)) # (B, 256)
        out = self.fc[2](out) # (B, 4096)
        out = out.reshape(batch_sz, 64, 64) # (B, 64, 64)
        return out

In [6]:
class PointNetClassifier(nn.Module):
    """PointNet classification network.

    ``forward`` maps a ``(B, N, 3)`` batch of point clouds to a tuple
    ``(logits, input_trans, feature_trans)`` of shapes ``(B, num_classes)``,
    ``(B, 3, 3)`` and ``(B, 64, 64)``. The two transforms are returned so the
    caller can regularise them.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Sub-modules are built in the same order, under the same names and
        # Sequential indices, as the explicit layer-by-layer definition.
        self.input_tnet = InputTNet()
        self.shared_mlp = nn.Sequential(*self._pointwise_layers(3, 64, 64))
        self.feature_tnet = FeatureTNet()
        self.shared_mlp_2 = nn.Sequential(*self._pointwise_layers(64, 64, 128, 1024))
        self.fc = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _pointwise_layers(*widths):
        """Return Conv1d(k=1) -> BatchNorm1d -> ReLU triples for consecutive widths."""
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv1d(c_in, c_out, 1), nn.BatchNorm1d(c_out), nn.ReLU()]
        return layers

    def forward(self, x):
        # Align the raw points with the predicted 3x3 transform.
        input_trans = self.input_tnet(x)  # (B, 3, 3)
        aligned = torch.bmm(x, input_trans)  # (B, N, 3)

        # Per-point 64-d features; Conv1d wants channels first, bmm wants them last.
        local = self.shared_mlp(aligned.transpose(1, 2))  # (B, 64, N)
        local = local.transpose(1, 2)  # (B, N, 64)

        # Align the features with the predicted 64x64 transform.
        feature_trans = self.feature_tnet(local)  # (B, 64, 64)
        local = torch.bmm(local, feature_trans)  # (B, N, 64)

        # Lift to 1024-d and max-pool over the points to get one vector per cloud.
        per_point = self.shared_mlp_2(local.transpose(1, 2))  # (B, 1024, N)
        global_feat, _ = per_point.max(dim=2)  # (B, 1024)

        logits = self.fc(global_feat)  # (B, num_classes)
        return logits, input_trans, feature_trans

In [None]:
class PointNetSegmentation(nn.Module):
    """PointNet per-point segmentation network.

    ``forward`` maps a ``(B, N, 3)`` batch of point clouds to a tuple
    ``(scores, input_trans, feature_trans)`` of shapes ``(B, N, num_classes)``,
    ``(B, 3, 3)`` and ``(B, 64, 64)``. Each point is scored from its own 64-d
    local feature concatenated with a 1024-d global feature of the whole cloud.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.input_tnet = InputTNet()
        # Per-point 3 -> 64 feature extractor.
        self.shared_mlp = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.feature_tnet = FeatureTNet()
        # Per-point 64 -> 1024 lift that feeds the global max-pool.
        self.shared_mlp_2 = nn.Sequential(
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
        )
        # Per-point head on [local 64 | global 1024] = 1088 channels.
        self.seg_mlp = nn.Sequential(
            nn.Conv1d(1088, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, num_classes, 1)
        )

    def forward(self, x):
        input_trans = self.input_tnet(x) # (B, 3, 3)
        out = torch.bmm(x, input_trans) # (B, N, 3) * (B, 3, 3) = (B, N, 3)
        out = out.transpose(1, 2) # (B, 3, N)
        out = self.shared_mlp(out) # (B, 64, N)
        out = out.transpose(1, 2) # (B, N, 64)
        feature_trans = self.feature_tnet(out) # (B, 64, 64)
        feature_out = torch.bmm(out, feature_trans) # (B, N, 64) * (B, 64, 64) = (B, N, 64)
        feature_out = feature_out.transpose(1, 2) # (B, 64, N)
        out = self.shared_mlp_2(feature_out) # (B, 1024, N)
        # Max-pool over the points to get one global descriptor per cloud, then
        # broadcast it back to every point. Concatenating the un-pooled per-point
        # features instead would give the head no information about other points.
        global_feat = torch.max(out, dim=2, keepdim=True)[0] # (B, 1024, 1)
        global_feat = global_feat.expand(-1, -1, feature_out.shape[2]) # (B, 1024, N)
        out = torch.cat((feature_out, global_feat), dim=1) # (B, 1088, N)
        out = self.seg_mlp(out) # (B, num_classes, N)
        out = out.transpose(1, 2) # (B, N, num_classes)
        return out, input_trans, feature_trans

In [7]:
def orthogonality_loss(trans):
    """Penalise a batch of square matrices for deviating from orthogonality.

    Args:
        trans: Tensor of shape ``(B, k, k)``.

    Returns:
        Scalar tensor: the batch mean of the Frobenius norm of ``A @ A^T - I``.
    """
    _, k, _ = trans.size()
    identity = torch.eye(k, device=trans.device)
    gram = torch.bmm(trans, trans.transpose(1, 2))  # (B, k, k)
    residual = gram - identity  # identity broadcasts over the batch dimension
    per_matrix_norm = torch.norm(residual, dim=(1, 2))  # (B,)
    return per_matrix_norm.mean()

In [None]:
import wandb
from tqdm.notebook import tqdm

# The loop below moves the model and every batch to `device`; define it here so
# the cell also runs on a fresh kernel. Falls back to CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

wandb.init(project="pointnet-classification")

model = PointNetClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Early-stopping state: stop after `patience` consecutive epochs without a new
# best validation loss.
best_val_loss = float("inf")
patience = 3
patience_counter = 0
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    train_loss, train_total = 0.0, 0

    # Training loop with tqdm progress bar
    train_loader_tqdm = tqdm(train_loader, desc=f"[Epoch {epoch+1}] Training")
    for verts, labels in train_loader_tqdm:
        verts, labels = verts.to(device), labels.to(device)
        out, input_trans, feature_trans = model(verts)

        # Classification loss plus a small orthogonality penalty on both transforms.
        cls_loss = F.cross_entropy(out, labels)
        inp_trans_loss = orthogonality_loss(input_trans)
        feat_trans_loss = orthogonality_loss(feature_trans)
        loss = cls_loss + 0.001 * (inp_trans_loss + feat_trans_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per-sample even when the
        # last batch is smaller.
        batch_size = verts.size(0)
        train_loss += loss.item() * batch_size
        train_total += batch_size

        # Update tqdm bar with current batch loss
        train_loader_tqdm.set_postfix(batch_loss=loss.item())

    avg_train_loss = train_loss / train_total

    model.eval()
    val_loss, val_total = 0.0, 0

    # Validation loop with tqdm
    val_loader_tqdm = tqdm(val_loader, desc=f"[Epoch {epoch+1}] Validation")
    with torch.no_grad():
        for verts, labels in val_loader_tqdm:
            verts, labels = verts.to(device), labels.to(device)
            out, input_trans, feature_trans = model(verts)

            # Same objective as in training so the two curves are comparable.
            cls_loss = F.cross_entropy(out, labels)
            inp_trans_loss = orthogonality_loss(input_trans)
            feat_trans_loss = orthogonality_loss(feature_trans)
            loss = cls_loss + 0.001 * (inp_trans_loss + feat_trans_loss)

            batch_size = verts.size(0)
            val_loss += loss.item() * batch_size
            val_total += batch_size

            val_loader_tqdm.set_postfix(batch_loss=loss.item())

    avg_val_loss = val_loss / val_total

    # Log to Weights & Biases
    wandb.log({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss
    })

    # Print epoch summary
    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Early stopping logic: checkpoint on every new best validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "pointnet_classifier.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break