In [1]:
import os
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
from glob import glob
from tqdm.auto import tqdm

import wandb

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, PointConv, fps, global_max_pool, radius

In [2]:
wandb.init(project="pyg-point-cloud", entity="geekyrakshit", job_type="test/train")

config = wandb.config
config.sample_points = 1024

categories = sorted([
    x.split(os.sep)[-2]
    for x in glob(os.path.join("ModelNet10", "raw", '*', ''))
])
categories.pop(7)
categories.pop(7)
config.categories = categories

config.batch_size = 32
config.num_workers = 6

config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(config.device)

config.learning_rate = 1e-4
config.epochs = 20

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
pre_transform = T.NormalizeScale()
transform = T.SamplePoints(config.sample_points)


train_dataset = ModelNet(
    root="ModelNet10",
    name='10',
    train=True,
    transform=transform,
    pre_transform=pre_transform
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers
)


val_dataset = ModelNet(
    root="ModelNet10",
    name='10',
    train=False,
    transform=transform,
    pre_transform=pre_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers
)

In [4]:
class SetAbstraction(torch.nn.Module):
    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = PointConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        idx = fps(pos, batch, ratio=self.ratio)
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                          max_num_neighbors=64)
        edge_index = torch.stack([col, row], dim=0)
        x_dst = None if x is None else x[idx]
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch

In [5]:
class GlobalSetAbstraction(torch.nn.Module):
    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        x = self.nn(torch.cat([x, pos], dim=1))
        x = global_max_pool(x, batch)
        pos = pos.new_zeros((x.size(0), 3))
        batch = torch.arange(x.size(0), device=batch.device)
        return x, pos, batch

In [6]:
class PointNet2(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SetAbstraction(0.5, 0.2, MLP([3, 64, 64, 128]))
        self.sa2_module = SetAbstraction(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSetAbstraction(MLP([256 + 3, 256, 512, 1024]))

        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None)

    def forward(self, data):
        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)
        x, pos, batch = sa3_out

        return self.mlp(x).log_softmax(dim=-1)

In [7]:
model = PointNet2().to(device)
optimizer = torch.optim.Adam(
    model.parameters(), lr=config.learning_rate
)

In [9]:
def train_step(epoch):
    model.train()
    epoch_loss, correct = 0, 0
    num_train_examples = len(train_loader)
    
    progress_bar = tqdm(
        range(num_train_examples),
        desc=f"Training Epoch {epoch}/{config.epochs}"
    )
    for batch_idx in progress_bar:
        data = next(iter(train_loader)).to(device)
        
        optimizer.zero_grad()
        prediction = model(data)
        loss = F.nll_loss(prediction, data.y)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        correct += prediction.max(1)[1].eq(data.y).sum().item()
    
    epoch_loss = epoch_loss / num_train_examples
    epoch_accuracy = correct / len(train_loader.dataset)
    
    wandb.log({
        "Train/Loss": epoch_loss,
        "Train/Accuracy": epoch_accuracy
    })


def val_step(epoch):
    model.eval()
    epoch_loss, correct = 0, 0
    num_val_examples = len(val_loader)
    
    progress_bar = tqdm(
        range(num_val_examples),
        desc=f"Validation Epoch {epoch}/{config.epochs}"
    )
    for batch_idx in progress_bar:
        data = next(iter(val_loader)).to(device)
        
        with torch.no_grad():
            prediction = model(data)
        
        loss = F.nll_loss(prediction, data.y)
        epoch_loss += loss.item()
        correct += prediction.max(1)[1].eq(data.y).sum().item()
    
    epoch_loss = epoch_loss / num_val_examples
    epoch_accuracy = correct / len(val_loader.dataset)
    
    wandb.log({
        "Validation/Loss": epoch_loss,
        "Validation/Accuracy": epoch_accuracy
    })

In [10]:
for epoch in range(1, config.epochs + 1):
    train_step(epoch)
    val_step(epoch)

Training Epoch 1/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 1/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 2/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 2/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 3/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 3/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 4/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 4/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 5/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 5/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 6/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 6/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 7/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 7/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 8/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 8/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 9/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 9/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 10/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 10/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 11/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 11/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 12/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 12/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 13/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 13/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 14/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 14/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 15/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 15/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 16/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 16/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 17/20:   0%|          | 0/125 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/opt/conda/lib/python3.7/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/opt/conda/lib/python3.7/shutil.py", line 498, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/opt/conda/lib/python3.7/shutil.py", line 496, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-2eezg_ow'


Validation Epoch 17/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 18/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 18/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 19/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 19/20:   0%|          | 0/29 [00:00<?, ?it/s]

Training Epoch 20/20:   0%|          | 0/125 [00:00<?, ?it/s]

Validation Epoch 20/20:   0%|          | 0/29 [00:00<?, ?it/s]

In [11]:
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.019 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.088014…

0,1
Train/Accuracy,▁▅▆▆▇▇▇▇████████████
Train/Loss,█▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Validation/Accuracy,▁▁▃▆▆▇▇▇▇██████▇▇█▇█
Validation/Loss,█▅▄▃▃▂▂▂▂▁▁▂▂▁▂▂▂▁▂▂

0,1
Train/Accuracy,0.94838
Train/Loss,0.14137
Validation/Accuracy,0.88436
Validation/Loss,0.53636
