In [1]:
import torch
import pickle
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d as BN
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
from torch_geometric.data import DataLoader, Batch, Data
from BSA_model import BSANet
from chamfer_distance import ChamferDistance
import open3d as o3d

  points = torch.tensor(np.meshgrid(x, y), dtype=torch.float32)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# Load data

In [2]:
class Dataset(torch.utils.data.Dataset):
    """Paired image / input point cloud / target point cloud samples.

    The three ``.npy`` files are read in full with ``np.load`` and indexed
    along their first axis, so entry ``i`` of each file forms one sample.

    The base class is referenced by its fully qualified name on purpose:
    this class reuses the name ``Dataset``, and inheriting from the bare
    name would make a re-run of this cell subclass its own previous
    definition. The class name itself is kept because pickled objects
    (e.g. ``data.pickle``) presumably look the class up by this name —
    confirm before renaming.
    """

    def __init__(self, pc_in_file, pc_out_file, img_file, transform=None):
        """Load the arrays backing the dataset.

        Args:
            pc_in_file: File accepted by ``np.load`` holding the input clouds.
            pc_out_file: File accepted by ``np.load`` holding the target clouds.
            img_file: File accepted by ``np.load`` holding the images.
            transform: Optional callable applied to each sample dict.
        """
        self.imgs = np.load(img_file)
        self.pcs_in = np.load(pc_in_file)
        self.pcs_out = np.load(pc_out_file)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (first dimension of the input clouds)."""
        return self.pcs_in.shape[0]

    @staticmethod
    def _as_float_tensor(array):
        """Convert a NumPy array to a float32 tensor with an explicit dtype."""
        return torch.as_tensor(array, dtype=torch.float32)

    def __getitem__(self, idx):
        """Build the sample at ``idx``.

        Returns:
            dict with keys ``"pc_in"`` and ``"pc_out"`` (``Data`` objects whose
            ``pos`` holds the cloud) and ``'img'`` (image tensor), after the
            optional ``transform`` has been applied.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Moves axis 2 to the front; assumes images are stored channels-last
        # (H, W, C) — TODO confirm against the .npy layout.
        img = self._as_float_tensor(self.imgs[idx]).permute(2, 0, 1)

        pc_in = Data(pos=self._as_float_tensor(self.pcs_in[idx]))
        pc_out = Data(pos=self._as_float_tensor(self.pcs_out[idx]))

        sample = {"pc_in": pc_in, "pc_out": pc_out, 'img': img}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [3]:
# NOTE(review): pickle.load can execute arbitrary code stored in the file,
# so only open a data.pickle that comes from a trusted source.
with open('data.pickle', mode='rb') as pickle_file:
    dataset = pickle.load(pickle_file)

In [4]:
# Evaluation split and its loader.
# Shuffling is disabled: the averaged metric does not need a random batch
# order, and a fixed order keeps evaluation runs repeatable.
test_set = dataset['test']
batch_size = 4
test_dataloader = DataLoader(test_set, batch_size=batch_size,
                    shuffle=False)



# Select device (CUDA if available)

In [5]:
# Drop cached CUDA allocations first, then choose the compute device:
# the GPU when one is available, otherwise the CPU.
torch.cuda.empty_cache()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load pre-trained model

In [9]:
# Build the network and restore the pre-trained weights.
# map_location=device remaps the stored tensors onto the selected device, so
# a checkpoint written from a GPU still loads when `device` fell back to CPU.
model = BSANet()
model.load_state_dict(
    torch.load("trained/BSA-Net_2000_5.7170_5.6984.pt", map_location=device)
)
model = model.to(device)
# Loss used for evaluation; called as criterion(pred, target) -> (dist1, dist2).
criterion = ChamferDistance()

# Evaluation

In [10]:
def evaluation(dataloader=None, num_points=2048):
    """Return the mean per-sample loss of ``model`` over a data loader.

    Relies on the notebook-level ``model``, ``criterion`` and ``device``.
    The model is switched to eval mode and left in that mode.

    Args:
        dataloader: Iterable of batches. Each batch is indexed with
            ``'pc_out'`` to get the target and is passed whole to ``model``.
            Defaults to the notebook-level ``test_dataloader``.
        num_points: Number of points per cloud, used to reshape both the
            prediction and the target to ``(-1, num_points, 3)``.

    Returns:
        float: Sum over batches of ``batch_loss * graphs_in_batch`` divided
        by the total number of graphs seen, where ``batch_loss`` is
        ``mean(dist1) + mean(dist2)`` from ``criterion``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    if dataloader is None:
        dataloader = test_dataloader

    model.eval()
    total_loss = 0
    total_graphs = 0

    # No gradients are needed anywhere in evaluation, so disable them once.
    with torch.no_grad():
        for data in dataloader:
            pc_out = data['pc_out'].to(device)

            decoded = model(data)
            dist1, dist2 = criterion(
                decoded.reshape(-1, num_points, 3),
                pc_out.pos.reshape(-1, num_points, 3),
            )
            loss = torch.mean(dist1) + torch.mean(dist2)

            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += loss.item() * pc_out.num_graphs
            total_graphs += pc_out.num_graphs

    # Normalise by the samples actually evaluated rather than a global length,
    # so the average stays correct for any loader passed in.
    return total_loss / total_graphs

In [11]:
evaluation()

5.7249289258321125