In [None]:
import argparse
import os
import torch
import datetime
import logging
from pathlib import Path
import sys
import importlib
import shutil
from tqdm import tqdm
import numpy as np
import time
from torch.nn.utils import clip_grad_norm_

import matplotlib.pyplot as plt


def save_metrics(training_mean_loss, training_accuracy, eval_mean_loss, eval_accuracy, file_path):
    """Append the latest value of each metric history to ``file_path`` as one CSV row.

    Only the last element of each sequence is written, in the order:
    training loss, training accuracy, eval loss, eval accuracy. The file is
    opened in append mode, so repeated calls accumulate one row per call.
    """
    with open(file_path, 'a') as handle:
        newest = (
            training_mean_loss[-1],
            training_accuracy[-1],
            eval_mean_loss[-1],
            eval_accuracy[-1],
        )
        handle.write(",".join(f"{value}" for value in newest) + "\n")

def _metric_to_float(value):
    """Return ``value`` as a float, unwrapping tensors / NumPy scalars via ``.item()``."""
    return float(value.item()) if hasattr(value, 'item') else float(value)


def _plot_metric(values, ylabel, title, label, color, out_path):
    """Draw one per-epoch metric curve and save it to ``out_path``."""
    plt.figure(figsize=(10, 6))
    plt.plot(values, label=label, color=color)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(out_path)
    plt.close()


def plot_and_save_metrics(training_mean_loss, training_accuracy, eval_mean_loss, eval_accuracy, experiment_dir):
    """Save one PNG line plot per metric history into ``experiment_dir``.

    Each history is a sequence with one entry per epoch. Entries may be plain
    Python numbers, NumPy scalars or single-element tensors (the latter two are
    unwrapped with ``.item()``). Writes ``training_mean_loss.png``,
    ``training_accuracy.png``, ``eval_mean_loss.png`` and ``eval_accuracy.png``.

    Usage:
        plot_and_save_metrics(training_mean_loss, training_accuracy,
                              eval_mean_loss, eval_accuracy, experiment_dir)
    """
    # (history, file name, y-axis label, title, legend label, colour)
    specs = [
        (training_mean_loss, 'training_mean_loss.png', 'Mean Loss',
         'Training Mean Loss per Epoch', 'Training Mean Loss', 'blue'),
        (training_accuracy, 'training_accuracy.png', 'Accuracy',
         'Training Accuracy per Epoch', 'Training Accuracy', 'green'),
        (eval_mean_loss, 'eval_mean_loss.png', 'Mean Loss',
         'Evaluation Mean Loss per Epoch', 'Evaluation Mean Loss', 'red'),
        (eval_accuracy, 'eval_accuracy.png', 'Accuracy',
         'Evaluation Accuracy per Epoch', 'Evaluation Accuracy', 'purple'),
    ]
    # Convert every history up front so a bad value fails before any file is written.
    converted = [[_metric_to_float(v) for v in spec[0]] for spec in specs]
    for values, (_, filename, ylabel, title, label, color) in zip(converted, specs):
        _plot_metric(values, ylabel, title, label, color,
                     os.path.join(str(experiment_dir), filename))



# Import the necessary modules here
from data_utils.S3DISDataLoader import S3DISDataset
import provider

# Use the current working directory as the project root (this runs as a
# notebook cell) and make ./models importable so the model module can be
# loaded by name with importlib.
BASE_DIR = os.getcwd()
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))

# The 13 S3DIS semantic classes; a class's position is its integer label.
classes = [
    'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter',
]
class2label = dict(zip(classes, range(len(classes))))
seg_classes = class2label
# Reverse lookup: integer label -> class name.
seg_label_to_cat = dict(enumerate(seg_classes))

def inplace_relu(m):
    """Put ReLU-family modules into in-place mode; meant for ``Module.apply``.

    The match is on the class name: any object whose class name contains
    ``'ReLU'`` gets ``inplace = True``; everything else is left untouched.
    """
    if 'ReLU' in type(m).__name__:
        m.inplace = True

def parse_args(argv=None):
    """Return the training configuration as an ``argparse.Namespace``.

    When ``argv`` is None and the code runs inside an IPython/Jupyter kernel
    (detected from ``sys.argv[0]``), the notebook defaults below are returned
    and the command line is ignored, since the kernel's argv is not meant for
    this parser. Otherwise ``argv`` (or ``sys.argv[1:]`` when None) is parsed.
    Both paths define every attribute that ``main`` reads.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line; passing a list also bypasses the notebook defaults.

    Returns:
        Namespace with model, batch_size, epoch, learning_rate, gpu, optimizer,
        log_dir, decay_rate, npoint, step_size, lr_decay and test_area.
    """
    launcher = sys.argv[0] if sys.argv else ''
    if argv is None and 'ipykernel' in launcher:
        # A Namespace (rather than an ad-hoc class) prints its values when logged.
        return argparse.Namespace(
            model='pointnet2_sem_seg_tran',
            batch_size=16,
            epoch=32,
            learning_rate=0.001,
            gpu='0',
            optimizer='Adam',
            log_dir='pointnet2_sem_seg_tran',
            decay_rate=1e-4,
            npoint=4096,
            step_size=10,
            lr_decay=0.7,
            test_area=2,
        )

    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--model', type=str, default='pointnet_sem_seg', help='model name [default: pointnet_sem_seg]')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch Size during training [default: 16]')
    parser.add_argument('--epoch', type=int, default=32, help='Epoch to run [default: 32]')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
    parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
    parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]')
    parser.add_argument('--log_dir', type=str, default=None, help='Log path [default: None, i.e. a timestamp]')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]')
    parser.add_argument('--npoint', type=int, default=4096, help='Point Number [default: 4096]')
    parser.add_argument('--step_size', type=int, default=10, help='Decay step for lr decay [default: every 10 epochs]')
    parser.add_argument('--lr_decay', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
    parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
    return parser.parse_args(argv)





def main(args):
    """Train ``args.model`` on S3DIS and evaluate it on the held-out area each epoch.

    Creates ``log/sem_seg/<args.log_dir or timestamp>/`` with ``checkpoints/``
    and ``logs/`` and copies the model source there. It resumes from
    ``checkpoints/best_model.pth`` when that file exists. Then, for every epoch:

    * trains with step-decayed learning rate and BN momentum;
    * saves ``model.pth`` every 5 epochs;
    * evaluates loss, accuracy and per-class IoU;
    * saves ``best_model.pth`` whenever mean IoU reaches or exceeds the best
      so far;
    * appends the epoch's metrics to ``metrics.csv`` and re-plots them.

    Args:
        args: configuration namespace (see ``parse_args``); reads gpu, log_dir,
            model, npoint, batch_size, test_area, optimizer, learning_rate,
            decay_rate, step_size, lr_decay and epoch.
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    if args.log_dir is None:
        run_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    else:
        run_name = args.log_dir
    experiment_dir = Path('./log/sem_seg').joinpath(run_name)
    checkpoints_dir = experiment_dir.joinpath('checkpoints')
    log_dir = experiment_dir.joinpath('logs')
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    log_dir.mkdir(parents=True, exist_ok=True)

    '''LOG'''
    # Use the ``args`` passed in; the command line is not re-parsed here.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    # Re-running this function (e.g. re-executing the notebook cell) would
    # otherwise attach another FileHandler and duplicate every log line.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = '/home/puoza/p2/Pointnet_Pointnet2_pytorch/new'
    NUM_CLASSES = 13
    NUM_POINT = args.npoint
    BATCH_SIZE = args.batch_size

    print("start loading training data ...")
    TRAIN_DATASET = S3DISDataset(split='train', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
    print("start loading test data ...")
    TEST_DATASET = S3DISDataset(split='test', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)

    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=10,
                                                  pin_memory=True, drop_last=True,
                                                  worker_init_fn=lambda x: np.random.seed(x + int(time.time())))
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False, num_workers=10,
                                                 pin_memory=True, drop_last=True)
    weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda()

    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))

    '''MODEL LOADING'''
    MODEL = importlib.import_module(args.model)
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet2_utils.py', str(experiment_dir))

    classifier = MODEL.get_model(NUM_CLASSES).cuda()
    criterion = MODEL.get_loss().cuda()
    classifier.apply(inplace_relu)

    def weights_init(m):
        """Xavier-normal weights and zero bias for Conv2d/Linear-named modules."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1 or classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    best_iou = 0
    try:
        checkpoint = torch.load(str(checkpoints_dir.joinpath('best_model.pth')))
    except FileNotFoundError:
        # Only a missing checkpoint means "fresh run"; a corrupt or mismatched
        # checkpoint now raises instead of being silently overwritten.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        classifier = classifier.apply(weights_init)
    else:
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        # Keep the stored best mIoU so a worse first epoch cannot replace best_model.pth.
        best_iou = checkpoint.get('class_avg_iou', 0)
        log_string('Use pretrain model')

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = args.step_size

    global_epoch = 0

    # Per-epoch metric histories. Entries are NumPy scalars so each supports
    # .item(), which plot_and_save_metrics uses to unwrap values.
    training_mean_loss, training_accuracy = [], []
    eval_mean_loss, eval_accuracy = [], []
    metrics_path = os.path.join(str(experiment_dir), 'metrics.csv')

    for epoch in range(start_epoch, args.epoch):
        '''Train on chopped scenes'''
        log_string('**** Epoch %d (%d/%s) ****' % (global_epoch + 1, epoch + 1, args.epoch))
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = max(MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)), 0.01)
        print('BN momentum updated to: %f' % momentum)
        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
        num_batches = len(trainDataLoader)
        total_correct = 0
        total_seen = 0
        loss_sum = 0.0
        classifier = classifier.train()

        for points, target in tqdm(trainDataLoader, total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()

            # Random rotation about z (augmentation) on the xyz channels only.
            points = points.data.numpy()
            points[:, :, :3] = provider.rotate_point_cloud_z(points[:, :, :3])
            points = torch.Tensor(points)
            points, target = points.float().cuda(), target.long().cuda()
            points = points.transpose(2, 1)

            seg_pred, trans_feat = classifier(points)
            seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)

            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            loss = criterion(seg_pred, target, trans_feat, weights)
            loss.backward()
            optimizer.step()

            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            total_correct += np.sum(pred_choice == batch_label)
            total_seen += batch_label.size
            # .item(): summing the tensor itself would keep every step's autograd history alive.
            loss_sum += loss.item()
        train_loss = loss_sum / num_batches
        train_acc = total_correct / float(total_seen)
        log_string('Training mean loss: %f' % train_loss)
        log_string('Training accuracy: %f' % train_acc)

        if epoch % 5 == 0:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

        '''Evaluate on chopped scenes'''
        with torch.no_grad():
            num_batches = len(testDataLoader)
            total_correct = 0
            total_seen = 0
            loss_sum = 0.0
            labelweights = np.zeros(NUM_CLASSES)
            total_seen_class = [0] * NUM_CLASSES
            total_correct_class = [0] * NUM_CLASSES
            total_iou_deno_class = [0] * NUM_CLASSES
            classifier = classifier.eval()

            log_string('---- EPOCH %03d EVALUATION ----' % (global_epoch + 1))
            for points, target in tqdm(testDataLoader, total=len(testDataLoader), smoothing=0.9):
                points, target = points.float().cuda(), target.long().cuda()
                points = points.transpose(2, 1)

                seg_pred, trans_feat = classifier(points)
                pred_val = np.argmax(seg_pred.contiguous().cpu().data.numpy(), 2)
                seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)

                batch_label = target.cpu().data.numpy()
                target = target.view(-1, 1)[:, 0]
                loss = criterion(seg_pred, target, trans_feat, weights)
                loss_sum += loss.item()
                total_correct += np.sum(pred_val == batch_label)
                total_seen += batch_label.size
                tmp, _ = np.histogram(batch_label, range(NUM_CLASSES + 1))
                labelweights += tmp

                for l in range(NUM_CLASSES):
                    total_seen_class[l] += np.sum((batch_label == l))
                    total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l))
                    total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)))

            labelweights = labelweights.astype(float) / np.sum(labelweights.astype(float))
            # Epsilon keeps classes absent from both labels and predictions at IoU 0 instead of NaN.
            iou_per_class = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=float) + 1e-6)
            mIoU = np.mean(iou_per_class)
            eval_loss = loss_sum / float(num_batches)
            eval_acc = total_correct / float(total_seen)
            log_string('eval mean loss: %f' % eval_loss)
            log_string('eval point avg class IoU: %f' % (mIoU))
            log_string('eval point accuracy: %f' % eval_acc)
            log_string('eval point avg class acc: %f' % (
                np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=float) + 1e-6))))

            iou_per_class_str = '------- IoU --------\n'
            for l in range(NUM_CLASSES):
                # Weight of class l itself (indexing with l - 1 printed the previous class's weight).
                iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % (
                    seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l],
                    iou_per_class[l])

            log_string(iou_per_class_str)
            log_string('Eval mean loss: %f' % eval_loss)
            log_string('Eval accuracy: %f' % eval_acc)

            if mIoU >= best_iou:
                best_iou = mIoU
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': epoch,
                    'class_avg_iou': mIoU,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
                log_string('Saving model....')
            log_string('Best mIoU: %f' % best_iou)
        global_epoch += 1

        training_mean_loss.append(np.float64(train_loss))
        training_accuracy.append(np.float64(train_acc))
        eval_mean_loss.append(np.float64(eval_loss))
        eval_accuracy.append(np.float64(eval_acc))
        # metrics.csv is appended to, one row per epoch (also across re-runs).
        save_metrics(training_mean_loss, training_accuracy, eval_mean_loss, eval_accuracy, metrics_path)
        plot_and_save_metrics(training_mean_loss, training_accuracy, eval_mean_loss, eval_accuracy, experiment_dir)
if __name__ == '__main__':
    # Entry point for both script and notebook-cell execution:
    # build the configuration, then run training.
    main(parse_args())



PARAMETER ...
<__main__.parse_args.<locals>.Args object at 0x155550142d30>
start loading training data ...


100%|██████████| 30/30 [00:02<00:00, 13.19it/s]


[1.0872371 1.1862054 1.        1.7367345 2.126663  2.0220127 1.788688
 1.8918622 2.0941825 4.048591  1.7390136 2.4426463 1.2522498]
Totally 7363 samples in train set.
start loading test data ...


100%|██████████| 8/8 [00:00<00:00, 21.11it/s]
  self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)


[1.3699402 1.4154122 1.              inf 7.784936        inf 1.9372996
 2.1385393 2.7646968       inf 2.5357857       inf 1.3914901]
Totally 1180 samples in test set.
The number of training data is: 7363
The number of test data is: 1180
Use pretrain model
**** Epoch 1 (1/32) ****
Learning rate:0.001000
BN momentum updated to: 0.100000


100%|██████████| 460/460 [04:50<00:00,  1.58it/s]

Training mean loss: 0.731074
Training accuracy: 0.773577
Saving at log/sem_seg/pointnet2_sem_seg_tran/checkpoints/model.pth
Saving model....
---- EPOCH 001 EVALUATION ----



100%|██████████| 73/73 [00:32<00:00,  2.24it/s]

eval mean loss: 1.691470
eval point avg class IoU: 0.247397
eval point accuracy: 0.568435
eval point avg class acc: 0.334346
------- IoU --------
class ceiling        weight: 0.164, IoU: 0.504 
class floor          weight: 0.170, IoU: 0.814 
class wall           weight: 0.149, IoU: 0.528 
class beam           weight: 0.367, IoU: 0.000 
class column         weight: 0.000, IoU: 0.000 
class window         weight: 0.001, IoU: 0.000 
class door           weight: 0.000, IoU: 0.479 
class table          weight: 0.056, IoU: 0.340 
class chair          weight: 0.047, IoU: 0.238 
class sofa           weight: 0.020, IoU: 0.000 
class bookcase       weight: 0.000, IoU: 0.111 
class board          weight: 0.026, IoU: 0.000 
class clutter        weight: 0.000, IoU: 0.204 

Eval mean loss: 1.691470
Eval accuracy: 0.568435
Saving at log/sem_seg/pointnet2_sem_seg_tran/checkpoints/best_model.pth
Saving model....
Best mIoU: 0.247397
**** Epoch 2 (2/32) ****
Learning rate:0.001000
BN momentum updated to:


100%|██████████| 460/460 [04:51<00:00,  1.58it/s]

Training mean loss: 0.549866
Training accuracy: 0.824319
---- EPOCH 002 EVALUATION ----



100%|██████████| 73/73 [00:32<00:00,  2.24it/s]

eval mean loss: 1.725555
eval point avg class IoU: 0.262722
eval point accuracy: 0.591020
eval point avg class acc: 0.348289
------- IoU --------
class ceiling        weight: 0.157, IoU: 0.520 
class floor          weight: 0.165, IoU: 0.802 
class wall           weight: 0.142, IoU: 0.524 
class beam           weight: 0.385, IoU: 0.000 
class column         weight: 0.000, IoU: 0.000 
class window         weight: 0.001, IoU: 0.000 
class door           weight: 0.000, IoU: 0.480 
class table          weight: 0.055, IoU: 0.420 
class chair          weight: 0.043, IoU: 0.257 
class sofa           weight: 0.022, IoU: 0.000 
class bookcase       weight: 0.000, IoU: 0.145 
class board          weight: 0.029, IoU: 0.000 
class clutter        weight: 0.000, IoU: 0.267 

Eval mean loss: 1.725555
Eval accuracy: 0.591020
Saving at log/sem_seg/pointnet2_sem_seg_tran/checkpoints/best_model.pth
Saving model....
Best mIoU: 0.262722
**** Epoch 3 (3/32) ****
Learning rate:0.001000
BN momentum updated to:


100%|██████████| 460/460 [04:50<00:00,  1.59it/s]

Training mean loss: 0.452486
Training accuracy: 0.848800
---- EPOCH 003 EVALUATION ----



100%|██████████| 73/73 [00:32<00:00,  2.23it/s]

eval mean loss: 1.756842
eval point avg class IoU: 0.239607
eval point accuracy: 0.553503
eval point avg class acc: 0.325495
------- IoU --------
class ceiling        weight: 0.158, IoU: 0.277 
class floor          weight: 0.162, IoU: 0.828 
class wall           weight: 0.144, IoU: 0.575 
class beam           weight: 0.386, IoU: 0.000 
class column         weight: 0.000, IoU: 0.000 
class window         weight: 0.001, IoU: 0.000 
class door           weight: 0.000, IoU: 0.496 
class table          weight: 0.058, IoU: 0.330 
class chair          weight: 0.044, IoU: 0.255 
class sofa           weight: 0.019, IoU: 0.000 
class bookcase       weight: 0.000, IoU: 0.148 
class board          weight: 0.029, IoU: 0.000 
class clutter        weight: 0.000, IoU: 0.206 

Eval mean loss: 1.756842
Eval accuracy: 0.553503
Best mIoU: 0.262722
**** Epoch 4 (4/32) ****
Learning rate:0.001000
BN momentum updated to: 0.100000



 28%|██▊       | 131/460 [01:27<03:13,  1.70it/s]

In [None]:
pip install tqdm

In [None]:
import argparse
import os
from pathlib import Path
import torch
import logging
import sys
import importlib
from tqdm import tqdm
import numpy as np

# Import the necessary modules here
from data_utils.S3DISDataLoader import ScannetDatasetWholeScene
from data_utils.indoor3d_util import g_label2color
import provider

# Project root = current working directory; make ./models importable so the
# evaluated model module can be loaded by name with importlib.
BASE_DIR = os.getcwd()
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))

# The 13 S3DIS semantic classes in label order (index == integer label).
classes = [
    'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter',
]
class2label = {name: label for label, name in enumerate(classes)}
seg_classes = class2label
# Integer label -> class name, used when printing per-class results.
seg_label_to_cat = {label: name for name, label in seg_classes.items()}


def parse_args(argv=None):
    '''PARAMETERS'''
    """Parse evaluation options from ``argv`` (or ``sys.argv[1:]`` when None).

    Visualization is on by default. ``--visual`` is kept for backward
    compatibility, and ``--no_visual`` turns it off.
    """
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--num_point', type=int, default=4096, help='point number [default: 4096]')
    parser.add_argument('--log_dir', type=str, default='pointnet2_sem_seg_msg', help='experiment root')
    # A store_true flag whose default is True can never be switched off on its
    # own, so --no_visual writes False to the same destination.
    parser.add_argument('--visual', action='store_true', default=True, help='visualize result [default: True]')
    parser.add_argument('--no_visual', dest='visual', action='store_false', help='do not write .obj visualizations')
    parser.add_argument('--test_area', type=int, default=5, help='area for testing, option: 1-6 [default: 5]')
    parser.add_argument('--num_votes', type=int, default=3, help='aggregate segmentation scores with voting [default: 3]')
    # parse_args(None) already falls back to sys.argv[1:].
    return parser.parse_args(argv)


def add_vote(vote_label_pool, point_idx, pred_label, weight):
    """Add one vote per valid point to ``vote_label_pool`` (modified in place).

    For every position (b, n) of the 2-D arrays ``point_idx``, ``pred_label``
    and ``weight`` (same shape), if ``weight[b, n]`` is non-zero and not
    infinite, ``vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])]``
    is incremented by 1.

    Returns:
        ``vote_label_pool`` (the same array object, for call-site chaining).
    """
    weight = np.asarray(weight)
    valid = (weight != 0) & ~np.isinf(weight)
    # astype truncates toward zero, matching int() on the float index arrays.
    rows = np.asarray(point_idx)[valid].astype(np.int64)
    cols = np.asarray(pred_label)[valid].astype(np.int64)
    # np.add.at is unbuffered: a point repeated within the batch receives every vote,
    # exactly like the original per-element Python loop.
    np.add.at(vote_label_pool, (rows, cols), 1)
    return vote_label_pool


def main(args):
    """Evaluate the model stored under ``log/sem_seg/<args.log_dir>`` on whole test scenes.

    The model module name is taken from the first file listed in ``logs/``, and
    weights come from ``checkpoints/best_model.pth``. For each test scene,
    ``args.num_votes`` passes each re-read ``TEST_DATASET_WHOLE_SCENE[idx]``,
    predict every block, and add one vote per point (``add_vote``). The final
    per-point label is the arg-max of the votes.

    Outputs:
    * per-scene and overall IoU / accuracy, logged to ``eval.txt``;
    * predicted labels, written to ``visual/<scene>.txt``;
    * when ``args.visual`` is set, coloured ``.obj`` point clouds.
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    experiment_dir = 'log/sem_seg/' + args.log_dir
    visual_dir = Path(experiment_dir + '/visual/')
    visual_dir.mkdir(exist_ok=True)

    '''LOG'''
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    # Drop handlers from earlier runs in the same kernel (including the training
    # cell's, which uses the same "Model" logger) so lines are not duplicated
    # into other log files.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    NUM_CLASSES = 13
    BATCH_SIZE = args.batch_size
    NUM_POINT = args.num_point

    root = '/home/puoza/p2/Pointnet_Pointnet2_pytorch/stanford_indoor3d/'

    TEST_DATASET_WHOLE_SCENE = ScannetDatasetWholeScene(root, split='test', test_area=args.test_area, block_points=NUM_POINT)
    log_string("The number of test data is: %d" % len(TEST_DATASET_WHOLE_SCENE))

    transformer_input_dim = 1024  # Example value, adjust based on your model's architecture
    transformer_num_heads = 8    # Example value, typically a power of 2
    transformer_hidden_dim = 1024

    '''MODEL LOADING'''
    # NOTE(review): os.listdir order is arbitrary; this assumes logs/ holds only
    # the training log named after the model module.
    model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
    MODEL = importlib.import_module(model_name)
    classifier = MODEL.get_model(NUM_CLASSES, transformer_input_dim, transformer_num_heads,
                                 transformer_hidden_dim).cuda()
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    classifier = classifier.eval()

    with torch.no_grad():
        scene_id = [x[:-4] for x in TEST_DATASET_WHOLE_SCENE.file_list]
        num_batches = len(TEST_DATASET_WHOLE_SCENE)

        total_seen_class = [0] * NUM_CLASSES
        total_correct_class = [0] * NUM_CLASSES
        total_iou_deno_class = [0] * NUM_CLASSES

        log_string('---- EVALUATION WHOLE SCENE----')

        for batch_idx in range(num_batches):
            print("Inference [%d/%d] %s ..." % (batch_idx + 1, num_batches, scene_id[batch_idx]))
            total_seen_class_tmp = [0] * NUM_CLASSES
            total_correct_class_tmp = [0] * NUM_CLASSES
            total_iou_deno_class_tmp = [0] * NUM_CLASSES

            whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[batch_idx]
            whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx]
            vote_label_pool = np.zeros((whole_scene_label.shape[0], NUM_CLASSES))
            for _ in tqdm(range(args.num_votes), total=args.num_votes):
                scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx]
                num_blocks = scene_data.shape[0]
                s_batch_num = (num_blocks + BATCH_SIZE - 1) // BATCH_SIZE
                batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 9))
                batch_point_index = np.zeros((BATCH_SIZE, NUM_POINT))
                batch_smpw = np.zeros((BATCH_SIZE, NUM_POINT))

                for sbatch in range(s_batch_num):
                    start_idx = sbatch * BATCH_SIZE
                    end_idx = min((sbatch + 1) * BATCH_SIZE, num_blocks)
                    real_batch_size = end_idx - start_idx
                    # Rows past real_batch_size keep stale data in the last batch; they are
                    # still fed to the model but excluded from voting below.
                    batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...]
                    batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...]
                    batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...]

                    torch_data = torch.Tensor(batch_data).float().cuda().transpose(2, 1)
                    seg_pred, _ = classifier(torch_data)
                    batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy()

                    vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...],
                                               batch_pred_label[0:real_batch_size, ...],
                                               batch_smpw[0:real_batch_size, ...])

            pred_label = np.argmax(vote_label_pool, 1)

            for l in range(NUM_CLASSES):
                total_seen_class_tmp[l] += np.sum((whole_scene_label == l))
                total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l))
                total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l)))
                total_seen_class[l] += total_seen_class_tmp[l]
                total_correct_class[l] += total_correct_class_tmp[l]
                total_iou_deno_class[l] += total_iou_deno_class_tmp[l]

            iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=float) + 1e-6)
            print(iou_map)
            arr = np.array(total_seen_class_tmp)
            # Scene mean IoU only over classes that actually occur in this scene.
            tmp_iou = np.mean(iou_map[arr != 0])
            log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou))
            print('----------------------------')

            filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt')
            with open(filename, 'w') as pl_save:
                pl_save.writelines('%d\n' % int(i) for i in pred_label)

            if args.visual:
                pred_path = os.path.join(visual_dir, scene_id[batch_idx] + '_pred.obj')
                gt_path = os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj')
                # Context managers guarantee both files are closed even if a write fails.
                with open(pred_path, 'w') as fout, open(gt_path, 'w') as fout_gt:
                    for i in range(whole_scene_label.shape[0]):
                        x, y, z = whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2]
                        color = g_label2color[pred_label[i]]
                        color_gt = g_label2color[whole_scene_label[i]]
                        fout.write('v %f %f %f %d %d %d\n' % (x, y, z, color[0], color[1], color[2]))
                        fout_gt.write('v %f %f %f %d %d %d\n' % (x, y, z, color_gt[0], color_gt[1], color_gt[2]))

        # Epsilon keeps classes absent from both labels and predictions at IoU 0 instead of NaN.
        IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=float) + 1e-6)
        iou_per_class_str = '------- IoU --------\n'
        for l in range(NUM_CLASSES):
            iou_per_class_str += 'class %s, IoU: %.3f \n' % (
                seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])),
                IoU[l])
        log_string(iou_per_class_str)
        log_string('eval point avg class IoU: %f' % np.mean(IoU))
        log_string('eval whole scene point avg class acc: %f' % (
            np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=float) + 1e-6))))
        log_string('eval whole scene point accuracy: %f' % (
                np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6)))

        print("Done!")


if __name__ == '__main__':
    # Inside a Jupyter/IPython kernel the process command line belongs to the
    # kernel, so parse an empty list (all defaults); otherwise read sys.argv.
    cli_argv = [] if 'ipykernel' in sys.modules else None
    main(parse_args(cli_argv))
