In [1]:
!pip install trimesh

Collecting trimesh
  Downloading trimesh-4.6.8-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.8-py3-none-any.whl (709 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/709.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m706.6/709.3 kB[0m [31m31.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m709.3/709.3 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.6.8


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import trimesh
from tqdm import tqdm
import random


In [4]:
class CachedModelNet40Dataset(Dataset):
    def __init__(self, root_dir, num_points=1024, split='train', cache_dir='cache'):
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.cache_dir = os.path.join(root_dir, cache_dir, split)
        os.makedirs(self.cache_dir, exist_ok=True)

        self.files = []
        self.classes = []

        for class_name in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_name)
            split_path = os.path.join(class_path, split)
            if not os.path.isdir(split_path):
                continue

            self.classes.append(class_name)

            for file in os.listdir(split_path):
                if file.endswith('.obj'):
                    obj_path = os.path.join(split_path, file)
                    npy_path = os.path.join(self.cache_dir, f"{class_name}_{file}.npy")
                    self.files.append((obj_path, npy_path, class_name))

        self.class_to_idx = {cls: idx for idx, cls in enumerate(sorted(set(self.classes)))}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        obj_path, npy_path, class_name = self.files[idx]
        label = self.class_to_idx[class_name]

        if os.path.exists(npy_path):
            points = np.load(npy_path)
        else:
            mesh = trimesh.load(obj_path, process=False)
            points = np.array(mesh.vertices)
            np.save(npy_path, points)

        # Приведение к num_points
        if points.shape[0] < self.num_points:
            diff = self.num_points - points.shape[0]
            points = np.concatenate([points, points[np.random.choice(points.shape[0], diff)]])
        else:
            choice = np.random.choice(points.shape[0], self.num_points, replace=False)
            points = points[choice]

        # Нормализация
        points = points - np.mean(points, axis=0)
        points = points / np.max(np.linalg.norm(points, axis=1))

        return torch.from_numpy(points).float(), label


In [5]:
class TNet(nn.Module):
    def __init__(self, k=3):
        super().__init__()
        self.k = k

        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size(0)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        x = torch.max(x, 2)[0]

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        identity = torch.eye(self.k, device=x.device).view(1, self.k * self.k).repeat(batchsize, 1)
        x = x + identity
        x = x.view(-1, self.k, self.k)
        return x


In [6]:
class PointNetClassifier(nn.Module):
    def __init__(self, k=40):  # 40 классов в ModelNet40
        super().__init__()
        self.input_transform = TNet(k=3)
        self.feature_transform = TNet(k=64)

        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)

        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        # x: (B, N, 3) → (B, 3, N)
        x = x.transpose(2, 1)
        trans = self.input_transform(x)
        x = torch.bmm(trans, x)

        x = F.relu(self.bn1(self.conv1(x)))

        trans_feat = self.feature_transform(x)
        x = torch.bmm(trans_feat, x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        x = torch.max(x, 2)[0]  # (B, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.dropout(self.fc2(x))))
        x = self.fc3(x)

        return F.log_softmax(x, dim=1), trans_feat


In [7]:
def feature_transform_regularizer(trans):
    d = trans.size(1)
    batchsize = trans.size(0)
    I = torch.eye(d, device=trans.device).unsqueeze(0).repeat(batchsize, 1, 1)
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss


In [8]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for points, labels in tqdm(loader):
        points, labels = points.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs, trans_feat = model(points)
        loss = criterion(outputs, labels)
        reg_loss = feature_transform_regularizer(trans_feat) * 0.001
        total_batch_loss = loss + reg_loss

        total_batch_loss.backward()
        optimizer.step()

        total_loss += total_batch_loss.item()
        pred = outputs.max(1)[1]
        correct += pred.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), correct / total


In [9]:
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for points, labels in loader:
            points, labels = points.to(device), labels.to(device)
            outputs, _ = model(points)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            pred = outputs.max(1)[1]
            correct += pred.eq(labels).sum().item()
            total += labels.size(0)

    return total_loss / len(loader), correct / total


In [10]:
# Настройки
BATCH_SIZE = 32
NUM_POINTS = 1024
EPOCHS = 15
LEARNING_RATE = 0.001

# Загрузка данных
train_dataset = CachedModelNet40Dataset(root_dir='/content/drive/MyDrive/PointNet/ModelNet40', num_points=NUM_POINTS, split='train')
test_dataset = CachedModelNet40Dataset(root_dir='/content/drive/MyDrive/PointNet/ModelNet40', num_points=NUM_POINTS, split='test')

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Устройство и модель
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PointNetClassifier(k=len(train_dataset.class_to_idx)).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.NLLLoss()


In [11]:
# Тренировка
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"  Test  Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")

100%|██████████| 307/307 [36:13<00:00,  7.08s/it]


Epoch 1/15
  Train Loss: 2.2278, Accuracy: 0.4338
  Test  Loss: 1.9113, Accuracy: 0.5000


100%|██████████| 307/307 [00:55<00:00,  5.57it/s]


Epoch 2/15
  Train Loss: 1.5686, Accuracy: 0.5759
  Test  Loss: 1.3521, Accuracy: 0.6305


100%|██████████| 307/307 [00:41<00:00,  7.42it/s]


Epoch 3/15
  Train Loss: 1.2553, Accuracy: 0.6559
  Test  Loss: 1.1461, Accuracy: 0.6754


100%|██████████| 307/307 [00:40<00:00,  7.55it/s]


Epoch 4/15
  Train Loss: 1.0931, Accuracy: 0.6910
  Test  Loss: 1.0523, Accuracy: 0.7075


100%|██████████| 307/307 [00:40<00:00,  7.52it/s]


Epoch 5/15
  Train Loss: 0.9336, Accuracy: 0.7276
  Test  Loss: 0.9059, Accuracy: 0.7277


100%|██████████| 307/307 [00:40<00:00,  7.52it/s]


Epoch 6/15
  Train Loss: 0.8217, Accuracy: 0.7607
  Test  Loss: 0.7957, Accuracy: 0.7622


100%|██████████| 307/307 [00:40<00:00,  7.57it/s]


Epoch 7/15
  Train Loss: 0.7536, Accuracy: 0.7812
  Test  Loss: 0.7548, Accuracy: 0.7699


100%|██████████| 307/307 [00:41<00:00,  7.42it/s]


Epoch 8/15
  Train Loss: 0.7209, Accuracy: 0.7843
  Test  Loss: 0.7714, Accuracy: 0.7739


100%|██████████| 307/307 [00:40<00:00,  7.62it/s]


Epoch 9/15
  Train Loss: 0.6510, Accuracy: 0.8042
  Test  Loss: 0.6946, Accuracy: 0.7954


100%|██████████| 307/307 [00:40<00:00,  7.51it/s]


Epoch 10/15
  Train Loss: 0.6407, Accuracy: 0.8071
  Test  Loss: 0.7730, Accuracy: 0.7759


100%|██████████| 307/307 [00:40<00:00,  7.59it/s]


Epoch 11/15
  Train Loss: 0.6104, Accuracy: 0.8119
  Test  Loss: 0.7184, Accuracy: 0.7857


100%|██████████| 307/307 [00:40<00:00,  7.56it/s]


Epoch 12/15
  Train Loss: 0.5572, Accuracy: 0.8304
  Test  Loss: 0.6639, Accuracy: 0.8031


100%|██████████| 307/307 [00:40<00:00,  7.57it/s]


Epoch 13/15
  Train Loss: 0.5382, Accuracy: 0.8329
  Test  Loss: 0.6644, Accuracy: 0.7954


100%|██████████| 307/307 [00:40<00:00,  7.57it/s]


Epoch 14/15
  Train Loss: 0.5506, Accuracy: 0.8270
  Test  Loss: 0.6697, Accuracy: 0.8019


100%|██████████| 307/307 [00:40<00:00,  7.60it/s]


Epoch 15/15
  Train Loss: 0.4961, Accuracy: 0.8433
  Test  Loss: 0.6871, Accuracy: 0.8023


In [12]:
torch.save(model.state_dict(), 'last_pointnet_model.pth')
print("📦 Сохранены последние веса модели.")

📦 Сохранены последние веса модели.
