In [154]:
import json
import torch
import wandb
import numpy as np 
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
from dataloader import ShapeNetDataset
from utils import load_data, get_loss 
from visualization import visualize
import torch.nn.init as init
from easydict import EasyDict
from collections import Counter
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import StratifiedKFold

from model import PointNet

In [155]:
config = EasyDict({
    'TRAIN_DATA_PATH' : 'shapenetpart_hdf5_2048/train0.h5',
    'VALID_DATA_PATH' : 'shapenetpart_hdf5_2048/val0.h5',
    'TEST_DATA_PATH' : 'shapenetpart_hdf5_2048/test0.h5',
    'DEVICE' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'BATCH_SIZE' : 32,
    'EPOCHS' : 200,
    'LEARNING_RATE' : 0.001,
    'DROP_OUT' : 0.5,
    'SEG_NUM_ALL' : 50,
    'WEIGHT_DECAY' : 0.0001,
    'K' : 40
})

In [156]:
### 전체 17000개 데이터, 각각 2048개의 point로 이루어져 있음
shapenetpart_cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, 
                       'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, 
                       'motorbike': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15}

shapenetpart_seg_num= [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] # 전체 50개의 segmentation class

shapenetpart_seg_start_index = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]

In [157]:
with open('shapenetpart_hdf5_2048/train0_id2file.json') as json_file:
    train_id = json.load(json_file)
with open('shapenetpart_hdf5_2048/val0_id2file.json') as json_file:
    valid_id = json.load(json_file)
    
with open('shapenetpart_hdf5_2048/train0_id2name.json') as json_file:
    train_name = json.load(json_file)
with open('shapenetpart_hdf5_2048/val0_id2name.json') as json_file:
    valid_name = json.load(json_file)

In [158]:
train_df = pd.DataFrame({'path' : train_id, 'label': train_name})
train_df['segmentation_part_num'] = train_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])

valid_df = pd.DataFrame({'path' : valid_id, 'label': valid_name})
valid_df['segmentation_part_num'] = valid_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])

print('#'*30)
print(f"Label distribution:\n{train_df['label'].value_counts()}\n")
print('#'*30)
print(f'Total data: {train_df.shape[0]}')
print('#'*30)
train_df.head()

##############################
Label distribution:
table         615
chair         454
airplane      335
lamp          198
car           120
guitar         88
laptop         57
knife          41
pistol         37
mug            27
skateboard     20
motorbike      19
bag            12
rocket         11
cap             9
earphone        5
Name: label, dtype: int64

##############################
Total data: 2048
##############################


Unnamed: 0,path,label,segmentation_part_num
0,02691156/points/d4d61a35e8b568fb7f1f82f6fc8747...,airplane,4
1,03636649/points/eee7062babab62aa8930422448288e...,lamp,4
2,04379243/points/90992c45f7b2ee7d71a48b5339c6e0...,table,3
3,02691156/points/a3c928995562fca8ca8607f540cc62...,airplane,4
4,03636649/points/85335cc8e6ac212a3834555ce6c51f...,lamp,4


In [159]:
# Load Data

train_data, train_label, train_seg = load_data(config['TRAIN_DATA_PATH'])
valid_data, valid_label, valid_seg = load_data(config['VALID_DATA_PATH'])

In [160]:
# For quick experiment, get 20% data of the total valid set. (Total valid set has 1870 data.)
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for i, (_, valid_index) in enumerate(folds.split(valid_label,valid_label.reshape(-1,))):
    valid_part_data = valid_data[valid_index]
    valid_part_label = valid_label[valid_index]
    valid_part_seg = valid_seg[valid_index]
    valid_part_df = valid_df.iloc[valid_index].reset_index(drop=True)
    break
print(f'Total valid data: {valid_data.shape[0]}')
print(f'Valid data to use: {valid_part_data.shape[0]}')

Total valid data: 1870
Valid data to use: 374


### Visualization

In [118]:
# visualize_index = 0
# visualize(train_data[visualize_index],
#           train_seg[visualize_index],
#           train_df.loc[visualize_index, 'label'],
#           train_df.loc[visualize_index, 'segmentation_part_num'],
#           train_df.loc[visualize_index, 'label'])

### Metrics

In [141]:
def calculate_shape_IoU(pred_np, seg_np, label):
    
    label = label.squeeze()
    shape_ious = []
    
    for shape_idx in range(seg_np.shape[0]): 
        # class별 segmentation index가 다르므로, 해당 class의 index 범위로 설정해줘야 한다.
        start_index = shapenetpart_seg_start_index[label[shape_idx]]
        num = shapenetpart_seg_num[label[shape_idx]]
        parts = range(start_index, start_index + num)
        part_ious = []
        
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            if U == 0:
                iou = 1  # If the union of groundtruth and prediction points is empty, then count part IoU as 1
            else:
                iou = I / float(U)
            part_ious.append(iou)
        shape_ious.append(np.mean(part_ious))
    return shape_ious

### Train

In [78]:
def train(config, wandb):
    
    wandb_init(config, wandb)
    
    train_dataset = ShapeNetDataset(train_df, train_data, train_label, train_seg)
    valid_dataset = ShapeNetDataset(valid_part_df, valid_part_data, valid_part_label, valid_part_seg)

    train_loader  = DataLoader(train_dataset, batch_size=config['BATCH_SIZE'], shuffle=True, drop_last=True, num_workers=4)
    valid_loader  = DataLoader(valid_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, drop_last=True, num_workers=4)
    
    
    model = PointNet().to(config['DEVICE'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['LEARNING_RATE'], weight_decay=config['WEIGHT_DECAY'])
    criterion = get_loss()
    
    device = config['DEVICE']
    best_valid_iou = 0.0

    print(f'{"#"*30} Start Training {"#"*30}')
    for epoch in range(config['EPOCHS']):
        ############
        ## TRAIN  ##
        ############
        running_loss = 0.0
        model.train()
        total_step = len(train_loader)

        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []

        for i, (point, label, label_one_hot, seg, c_name) in enumerate(train_loader):

            point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
            # point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

            seg_pred, trans_feat = model(point, label_one_hot) # [32, 2048, 50], [32, 128, 128]
            loss = criterion(seg_pred.contiguous().view(-1, config['SEG_NUM_ALL']),  # [65536, 50]
                 seg.view(-1, 1)[:, 0], # [65536] (batchsizex2048)
                 trans_feat)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            pred = seg_pred.max(dim=2)[1]
            seg_np = seg.cpu().numpy()  # (batch_size, num_points)
            pred_np = pred.detach().cpu().numpy() # (batch_size, num_points)

            # For Accuracy computation
            train_true_cls.append(seg_np.reshape(-1)) # (batch_size * num_points, )
            train_pred_cls.append(pred_np.reshape(-1)) # (batch_size * num_points, )

            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            train_label_seg.append(label.reshape(-1))


        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        train_accuracy = sum(train_pred_cls == train_true_cls) / len(train_pred_cls) # Accuracy

        train_true_seg = np.concatenate(train_true_seg, axis=0)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        train_label_seg = np.concatenate(train_label_seg)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg, train_label_seg) # IoU
        torch.cuda.empty_cache()

        ############
        ## VALID  ##
        ############    
        valid_loss = 0.0
        model.eval()
        valid_true_cls = []
        valid_pred_cls = []
        valid_true_seg = []
        valid_pred_seg = []
        valid_label_seg = []

        with torch.no_grad():
            for i, (point, label, label_one_hot, seg, c_name) in enumerate(valid_loader):
                
                point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
                point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
                
                seg_pred, trans_feat = model(point, label_one_hot) ###            
                loss = criterion(seg_pred.contiguous().view(-1, config['SEG_NUM_ALL']), # [65536, 50]
                                 seg.view(-1, 1)[:, 0], # [65536] (batchsizex2048)
                                 trans_feat) # [32, 128, 128]
                valid_loss += loss.item()
                
                pred = seg_pred.max(dim=2)[1] # 65536

                seg_np = seg.cpu().numpy()
                pred_np = pred.detach().cpu().numpy()

                valid_true_cls.append(seg_np.reshape(-1))
                valid_pred_cls.append(pred_np.reshape(-1))

                valid_true_seg.append(seg_np)
                valid_pred_seg.append(pred_np)
                valid_label_seg.append(label.reshape(-1))


            valid_true_cls = np.concatenate(valid_true_cls)
            valid_pred_cls = np.concatenate(valid_pred_cls)
            valid_accuracy = sum(valid_pred_cls == valid_true_cls) / len(valid_pred_cls) # Accuracy

            valid_true_seg = np.concatenate(valid_true_seg, axis=0)
            valid_pred_seg = np.concatenate(valid_pred_seg, axis=0)
            valid_label_seg = np.concatenate(valid_label_seg)
            valid_ious = calculate_shape_IoU(valid_pred_seg, valid_true_seg, valid_label_seg) # IoU

            print("Epoch: {}/{}.. ".format(epoch + 1, config['EPOCHS']) +
                                  "Loss: {:.5f}.. ".format(running_loss / total_step) +
                                  "IoU: {:.5f}.. ".format(np.mean(train_ious)) + 
                                  "Accuracy: {:.5f}.. ".format(train_accuracy) + 
                                  "Valid Loss: {:.5f}.. ".format(valid_loss / len(valid_loader)) +
                                  "Valid Accuracy: {:.5f}.. ".format(valid_accuracy) +
                                  "Valid IoU: {:.5f}..".format(np.mean(valid_ious)))
            wandb.log({
                'Train/Loss': (running_loss / total_step), 
                'Train/IoU': np.mean(train_ious), 
                'Train/Accuracy': train_accuracy,
                'Valid/Loss': (valid_loss / len(valid_loader)), 
                'Valid/IoU': valid_accuracy, 
                'Valid/Accuracy': np.mean(valid_ious),                
            })

        # Early Stopping
        if np.mean(valid_ious) >= best_valid_iou:
            print('haha')
            best_valid_iou = np.mean(valid_ious)
            torch.save(model.state_dict(), f'checkpoints_pointnet/epoch{str(epoch + 1).zfill(3)}_seg.tar')  
        torch.cuda.empty_cache()

In [79]:
def wandb_init(config, wandb):
    wandb.init(project="pointcloud-segmentation", entity="sseunghyun", name = f"PointNet")
    wandb.config.update({
    "Epochs": config["EPOCHS"],
    "learning_rate": config["LEARNING_RATE"],
    "batch_size": config["BATCH_SIZE"],
    "Weight decay": config["WEIGHT_DECAY"],
    "K": config["K"],
    })

In [1]:
train(config, wandb)

### Test

In [148]:
def test(config):
    with open('shapenetpart_hdf5_2048/test0_id2file.json') as json_file:
        test_id = json.load(json_file)
    with open('shapenetpart_hdf5_2048/test0_id2name.json') as json_file:
        test_name = json.load(json_file)
        
    test_df = pd.DataFrame({'path' : test_id, 'label': test_name})
    test_df['segmentation_part_num'] = test_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])
    test_data, test_label, test_seg = load_data(config['TEST_DATA_PATH'])
    
    test_dataset = ShapeNetDataset(test_df, test_data, test_label, test_seg)
    test_loader  = DataLoader(test_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=4)

    model = PointNet().to(config['DEVICE'])
    model.load_state_dict(torch.load('checkpoints_pointnet/epoch192_seg.tar', map_location=config['DEVICE']))
    model.eval()
    
    device = config['DEVICE']

    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    test_point = []
    with torch.no_grad():
        for i, (point, label, label_one_hot, seg, c_name) in tqdm(enumerate(test_loader), total=len(test_loader)):
            point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
            point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
            seg_prediction, _ = model(point, label_one_hot) 
            # seg_prediction = seg_prediction.permute(0, 2, 1).contiguous()

            pred = seg_prediction.max(dim=2)[1]
            seg_np = seg.cpu().numpy()
            pred_np = pred.detach().cpu().numpy()

            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))

            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1))

            test_point.append(point.permute(0,2,1).detach().cpu().numpy())

        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_accuracy = sum(test_pred_cls == test_true_cls) / len(test_pred_cls) # Accuracy

        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg) # IoU

        test_point = np.concatenate(test_point)
    print('Inference fin... ')
    # print("Test IoU: {:.5f}.. ".format(np.mean(test_ious)) + 
    #       "Test Accuracy: {:.5f}.. ".format(test_accuracy))   
    
    return test_true_seg, test_pred_seg, test_point,  test_df, test_ious

In [149]:
test_true_seg, test_pred_seg, test_point, test_df, test_ious = test(config)

100%|██████████| 64/64 [00:04<00:00, 15.66it/s]


Inference fin... 


### Prediction result

In [153]:
idx = 10
visualize(test_point[idx], 
          test_pred_seg[idx],
          test_df.loc[idx, 'label'],
          test_df.loc[idx, 'segmentation_part_num'],
          'infer_vis/_pred')

In [129]:
# visualize(test_point[idx], 
#           test_true_seg[idx],
#           test_df.loc[idx, 'label'],
#           test_df.loc[idx, 'segmentation_part_num'],
#           'infer_vis/chair_true')
# print(f'Test IoU: {test_ious[idx]}')