In [1]:
# ✅ DGCNN Training Pipeline for Dental Segmentation
# ---------------------------------------------------
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
from monai.losses import DiceLoss, FocalLoss
from torch_points3d.applications.semseg import SemSegModel
import joblib

# -----------------------------------
# STEP 1: Dataset
# -----------------------------------
class PointCloudDataset(Dataset):
    """Map-style dataset over ``.npz`` point-cloud files.

    Each file must contain a ``pos`` array (point coordinates) and a ``y`` array
    (per-point class ids stored as 1, 2, 3). Each item is a dict with ``pos`` as a
    float32 tensor and ``y`` as an int64 tensor shifted to 0-based ids (0, 1, 2).
    """

    def __init__(self, file_list):
        """Store the list of ``.npz`` paths; files are only read in ``__getitem__``."""
        self.files = file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # np.load on an .npz returns a lazily-reading NpzFile that keeps the file open
        # until closed; the context manager closes it once the arrays have been copied
        # into tensors, so iterating the DataLoader does not accumulate open handles.
        with np.load(self.files[idx]) as data:
            pos = torch.tensor(data['pos'], dtype=torch.float32)
            # NOTE(review): assumes stored labels are exactly {1, 2, 3}; a stored 0 would become -1.
            y = torch.tensor(data['y'], dtype=torch.long) - 1  # [1,2,3] → [0,1,2]
        return {"pos": pos, "y": y}

def _cache_matches_folder(cached, folder):
    """True if ``cached`` is a non-empty list of existing files located directly in ``folder``."""
    if not isinstance(cached, list) or not cached:
        return False
    folder_abs = os.path.abspath(folder)
    return all(
        os.path.dirname(os.path.abspath(p)) == folder_abs and os.path.isfile(p)
        for p in cached
    )


def load_pointcloud_dataset(folder="pointclouds", cache_path="dgcnn_dataset_cache.pkl", force_reload=False):
    """Return the sorted list of ``.npz`` file paths in ``folder``, cached on disk with joblib.

    Args:
        folder: directory scanned for ``.npz`` point-cloud files.
        cache_path: joblib file holding a previously discovered path list.
        force_reload: rescan ``folder`` even when a usable cache exists.

    Returns:
        Sorted list of paths of the form ``os.path.join(folder, name)``.

    Raises:
        FileNotFoundError: if ``folder`` does not exist or contains no ``.npz`` files.

    A cache is only reused when every cached path still exists and lives directly in
    ``folder``; otherwise it is rebuilt. Files *added* to ``folder`` after the cache was
    written are not detected — pass ``force_reload=True`` to pick them up.
    """
    if not force_reload and os.path.exists(cache_path):
        print("🔄 Loading DGCNN dataset from cache...")
        # NOTE(review): joblib.load unpickles its input — only point cache_path at files you created.
        cached = joblib.load(cache_path)
        if _cache_matches_folder(cached, folder):
            return cached
        print("⚠️ Cache is stale or was built for another folder; rebuilding...")

    print("📥 Processing point cloud files...")
    files = sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".npz"))
    if not files:
        # Fail here with a clear message instead of caching an empty list that later
        # breaks train_test_split with a less obvious error.
        raise FileNotFoundError(f"No .npz files found in {folder!r}")
    joblib.dump(files, cache_path)
    print("✅ Dataset cached.")
    return files

# Fixed seed for the train/val split, the training-set shuffle order and torch's
# global RNG (used later for model weight initialisation), so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)

files = load_pointcloud_dataset("pointclouds", "dgcnn_dataset_cache.pkl", force_reload=False)
train_files, val_files = train_test_split(files, test_size=0.2, random_state=SEED)
train_loader = DataLoader(
    PointCloudDataset(train_files),
    batch_size=1,  # one cloud per batch (presumably clouds differ in point count — TODO confirm)
    shuffle=True,
    # Dedicated generator: the shuffle order does not depend on how much of the
    # global RNG other code has consumed.
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(PointCloudDataset(val_files), batch_size=1)



  "cipher": algorithms.TripleDES,
  "class": algorithms.Blowfish,
  "class": algorithms.TripleDES,


ModuleNotFoundError: No module named 'torch_points3d'

In [None]:
# -----------------------------------
# STEP 2: Load DGCNN Model from YAML
# -----------------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE(review): SemSegModel.load_from_config is taken on trust from torch_points3d —
# the package failed to import in the recorded run, so this API is unverified here.
model = SemSegModel.load_from_config("dgcnn_dental.yaml")
model = model.to(device)

# Adam with a fixed learning rate of 1e-3 over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# -----------------------------------
# STEP 3: Loss Function (Dice + Focal)
# -----------------------------------
# Per-class weights indexed by the 0-based label (Teeth, Left Canal, Right Canal in
# the validation report): the two canal classes count four times as much as teeth.
CLASS_WEIGHTS = [1.0, 4.0, 4.0]

dice_loss = DiceLoss(
    include_background=True,
    to_onehot_y=True,  # integer labels are one-hot encoded by the loss itself
    softmax=True,      # softmax over the logits before the Dice computation
    weight=torch.tensor(CLASS_WEIGHTS, device=device),
)

# NOTE(review): no softmax option is passed here (unlike dice_loss), so FocalLoss uses
# its default activation — confirm that is intended for mutually exclusive classes.
focal_loss = FocalLoss(
    to_onehot_y=True,
    weight=torch.tensor(CLASS_WEIGHTS, device=device),
)

def loss_fn(pred, target):
    """Equal-weight blend (0.5 / 0.5) of the module-level ``dice_loss`` and ``focal_loss``.

    MONAI losses built with ``to_onehot_y=True`` expect ``pred`` as ``[B, C, ...]`` and
    ``target`` as ``[B, 1, ...]``. The training loop feeds per-point logits of shape
    ``[N, num_classes]`` and DataLoader labels of shape ``[1, N]``, so both are
    converted to that layout first.

    Args:
        pred: logits, either per-point ``[N, C]`` or already channel-first ``[B, C, ...]``.
        target: integer class ids; ``[1, N]`` / ``[N]`` for 2-D ``pred``, or ``[B, ...]``
            (channel axis missing) / ``[B, 1, ...]`` for channel-first ``pred``.

    Returns:
        Scalar loss tensor.
    """
    if pred.dim() == 2:
        # [N, C] -> [1, C, N]: channels first, a single batch entry.
        pred = pred.transpose(0, 1).unsqueeze(0)
        target = target.reshape(1, 1, -1)
    elif target.dim() == pred.dim() - 1:
        # [B, ...] -> [B, 1, ...]: add the channel axis MONAI expects on the target.
        target = target.unsqueeze(1)
    return 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)

# -----------------------------------
# STEP 4: Training Loop
# -----------------------------------
NUM_EPOCHS = 20  # number of epochs run by the loop below

# Per-epoch history; read by the final summary and the metrics / plot cells.
total_losses = []
dice_scores_class1 = []
dice_scores_class2 = []
dice_scores_class3 = []
avg_dice_scores = []
overall_accuracies = []

# -inf (not 0.0) so the first epoch always writes best_dgcnn_model.pth. With 0.0, a run
# whose average Dice never rose above zero saved no checkpoint, and the inference
# cell's torch.load would then fail.
best_dice = float("-inf")

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    total_loss = 0.0  # sum (not mean) of per-batch losses for this epoch
    train_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    for batch in train_iter:
        pos = batch["pos"].to(device)
        y = batch["y"].to(device)

        outputs = model.forward(pos=pos)["logits"]  # Output shape: [N, num_classes]
        loss = loss_fn(outputs, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        train_iter.set_postfix(loss=loss.item())

    print(f"\n📉 Epoch {epoch} | Total Loss: {total_loss:.4f}")

    # -----------------------------
    # 📊 Validation
    # -----------------------------
    model.eval()
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in val_loader:
            pos = batch["pos"].to(device)
            y = batch["y"].to(device)

            out = model.forward(pos=pos)["logits"]
            preds = out.argmax(dim=1)

            # Flatten both to one entry per point. `y` keeps the DataLoader's batch axis
            # ([1, N]): extending a list with it added ONE array per cloud, while `preds`
            # ([N]) added N scalars, so labels and predictions never lined up.
            pred_chunks.append(preds.reshape(-1).cpu().numpy())
            label_chunks.append(y.reshape(-1).cpu().numpy())

    # 🔁 Remap [0,1,2] → [1,2,3] for reporting
    all_preds = np.concatenate(pred_chunks) + 1
    all_labels = np.concatenate(label_chunks) + 1
    assert all_preds.shape == all_labels.shape, (
        f"{all_preds.shape[0]} predictions vs {all_labels.shape[0]} labels"
    )

    print("\n📊 Validation Report:")
    print(classification_report(
        all_labels,
        all_preds,
        labels=[1, 2, 3],
        target_names=["Teeth", "Left Canal", "Right Canal"],
        zero_division=0
    ))

    print("🎯 Dice per class:")
    dice_scores = []
    for cls in [1, 2, 3]:
        # Binary F1 on a one-vs-rest mask equals the Dice coefficient of that class.
        # zero_division=0 matches the report above and silences the ill-defined warning.
        cls_dice = float(f1_score(all_labels == cls, all_preds == cls, zero_division=0))
        dice_scores.append(cls_dice)
        print(f"Class {cls} Dice: {cls_dice:.3f}")

    class_1_dice, class_2_dice, class_3_dice = dice_scores
    avg_dice = float(np.mean(dice_scores))

    if avg_dice > best_dice:
        best_dice = avg_dice
        torch.save(model.state_dict(), "best_dgcnn_model.pth")
        print(f"💾 Best model saved with Avg Dice: {avg_dice:.4f}")

    acc = float(np.mean(all_preds == all_labels) * 100)
    print(f"✅ Overall Accuracy: {acc:.2f}%")

    total_losses.append(total_loss)
    dice_scores_class1.append(class_1_dice)
    dice_scores_class2.append(class_2_dice)
    dice_scores_class3.append(class_3_dice)
    avg_dice_scores.append(avg_dice)
    overall_accuracies.append(acc)

# -----------------------------------
# ✅ Final Summary
# -----------------------------------
# 0-based index of the epoch with the highest average Dice (earliest one on ties).
best_epoch = np.argmax(avg_dice_scores)

# Pull every tracked metric for that epoch up front, then report them.
best_stats = {
    key: history[best_epoch]
    for key, history in [
        ("loss", total_losses),
        ("teeth", dice_scores_class1),
        ("left", dice_scores_class2),
        ("right", dice_scores_class3),
        ("avg", avg_dice_scores),
        ("acc", overall_accuracies),
    ]
}

print("\n📈 Training Complete!")
print(f"🏆 Best Model at Epoch {best_epoch + 1}")
print(f"   • Total Loss      : {best_stats['loss']:.4f}")
print(f"   • Dice - Teeth    : {best_stats['teeth']:.4f}")
print(f"   • Dice - Left Canal : {best_stats['left']:.4f}")
print(f"   • Dice - Right Canal: {best_stats['right']:.4f}")
print(f"   • Avg Dice        : {best_stats['avg']:.4f}")
print(f"   • Accuracy        : {best_stats['acc']:.2f}%")


In [None]:
import torch
import numpy as np
from torch_points3d.applications.semseg import SemSegModel
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D

# Load model
# Build the network from the same config used in training (STEP 2). The checkpoint
# below is the DGCNN one; loading it into a model built from "pointnet2_dental.yaml"
# would presumably not match its parameter names/shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SemSegModel.load_from_config("dgcnn_dental.yaml")
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_dgcnn_model.pth", map_location=device))
model.eval().to(device)

# Load point cloud (context manager closes the .npz once the arrays are copied out)
with np.load("pointclouds/sample_case_01.npz") as data:
    points = torch.tensor(data['pos'], dtype=torch.float32).unsqueeze(0).to(device)  # [1, N, 3]
    true_labels = data['y'] - 1  # [1,2,3] → [0,1,2], same remap as PointCloudDataset

# Inference — same call and output key as the training loop. With [N, C] logits,
# argmax(dim=1) is already per point; reshape(-1) yields a flat [N] label array
# (the old trailing [0] reduced it to the label of a single point).
with torch.no_grad():
    logits = model.forward(pos=points)["logits"]
    preds = logits.argmax(dim=1).reshape(-1).cpu().numpy()

# NOTE(review): filename still says "pointnet" although these are DGCNN predictions;
# kept unchanged in case downstream code reads it.
np.save("predicted_labels_pointnet.npy", preds)

# Visualize
def visualize(points_np, pred_labels, true_labels=None, azimuths=(30, 60, 90)):
    """Scatter-plot a labelled point cloud in 3D.

    Args:
        points_np: array indexed as ``points_np[:, 0]``, ``[:, 1]``, ``[:, 2]`` (x, y, z per point).
        pred_labels: per-point predicted class ids in {0, 1, 2}.
        true_labels: optional per-point ground-truth ids in {0, 1, 2}; when given, a
            second figure with the ground truth is shown.
        azimuths: sequence of camera azimuths in degrees, one predicted-view subplot each.

    Raises:
        ValueError: if ``azimuths`` is empty.
        KeyError: if a label is outside {0, 1, 2}.
    """
    if len(azimuths) == 0:
        raise ValueError("azimuths must contain at least one angle")

    label_colors = {0: 'yellow', 1: 'green', 2: 'red'}
    label_names = {0: 'Teeth', 1: 'Left Canal', 2: 'Right Canal'}

    def make_legend():
        # Fresh proxy patches per figure so no artist is shared between figures.
        return [mpatches.Patch(color=label_colors[k], label=label_names[k]) for k in sorted(label_colors)]

    pred_colors = [label_colors[int(l)] for l in pred_labels]

    n_views = len(azimuths)
    fig = plt.figure(figsize=(6 * n_views, 6))
    for i, angle in enumerate(azimuths):
        ax = fig.add_subplot(1, n_views, i + 1, projection='3d')
        ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=pred_colors, s=3)
        ax.set_title(f"Predicted - View {angle}°")
        ax.view_init(elev=10, azim=angle)
        ax.set_axis_off()

    # Figure-level legend: plt.legend() attaches to the current (last) subplot only,
    # so its bbox_to_anchor placed the legend under the third view, not the figure.
    fig.legend(handles=make_legend(), loc='lower center', ncol=3)
    fig.tight_layout(rect=(0, 0.06, 1, 1))  # leave room at the bottom for the legend
    plt.show()

    if true_labels is not None:
        true_colors = [label_colors[int(l)] for l in true_labels]
        fig_gt = plt.figure()
        ax = fig_gt.add_subplot(111, projection='3d')
        ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=true_colors, s=3)
        ax.set_title("Ground Truth")
        ax.set_axis_off()
        fig_gt.legend(handles=make_legend(), loc='lower center', ncol=3)
        plt.show()

# Drop the leading batch axis and bring the coordinates to host memory for matplotlib.
cloud_xyz = points.squeeze(0).cpu().numpy()
visualize(cloud_xyz, preds, true_labels)


In [None]:
import json
import matplotlib.pyplot as plt

# 📁 Save metrics — one list per per-epoch series collected by the training loop.
metrics = dict(
    losses=total_losses,
    dice_teeth=dice_scores_class1,
    dice_left=dice_scores_class2,
    dice_right=dice_scores_class3,
    avg_dice=avg_dice_scores,
    accuracy=overall_accuracies,
)
metrics_path = "metrics_dgcnn.json"  # 🔁 Change filename for each model
with open(metrics_path, "w") as fh:
    json.dump(metrics, fh)

# 📊 Plot — average Dice lies in [0, 1] while accuracy is a percentage (0–100), so each
# gets its own y-axis; on one shared axis the Dice curve was flattened near zero.
epochs = range(1, len(avg_dice_scores) + 1)
fig, ax_dice = plt.subplots(figsize=(12, 6))
ax_acc = ax_dice.twinx()

line_dice, = ax_dice.plot(epochs, avg_dice_scores, label='Avg Dice', color='tab:blue')
line_acc, = ax_acc.plot(epochs, overall_accuracies, label='Accuracy', color='tab:orange')

ax_dice.set_title("DGCNN: Dice & Accuracy")
ax_dice.set_xlabel("Epoch")
ax_dice.set_ylabel("Avg Dice")
ax_acc.set_ylabel("Accuracy (%)")
# One legend listing the lines from both axes.
ax_dice.legend(handles=[line_dice, line_acc])
ax_dice.grid(True)
plt.show()
