In [None]:
import io
import json

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import open3d as o3d

from torch.utils.data import Dataset, DataLoader

In [None]:
import os
from torchvision.io import read_image

class MyLidarDataset(Dataset):
    """LiDAR dataset pairing JSON label files with point-cloud files.

    Directory layout, as built from ``data_dir`` in ``__init__``::

        data_dir/
            labels/<stem>.<ext>   # JSON label files; every entry is one sample
            pcg/<stem>.pcg        # point cloud sharing the label's stem

    Each item is ``(class_name, annotation, pos)``:

    * ``class_name`` -- list of the ``"class_name"`` values of the label's
      ``"Annotation"`` entries;
    * ``annotation`` -- list of the matching ``"data"`` values, same order;
    * ``pos`` -- NumPy array of point coordinates after voxel down-sampling.

    Args:
        data_dir: Root of one split (e.g. the training directory), i.e. the
            directory that *contains* ``labels/`` and ``pcg/``.
        voxel_size: Voxel edge length passed to ``voxel_down_sample``.
    """

    def __init__(self, data_dir, voxel_size=0.05):
        self.data_dir = data_dir
        self.pcg_dir = os.path.join(data_dir, 'pcg')
        self.labels_dir = os.path.join(data_dir, 'labels')
        self.voxel_size = voxel_size
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.label_list = sorted(os.listdir(self.labels_dir))

    def _pcg_path_for(self, label_file):
        """Return the point-cloud path whose stem matches ``label_file``."""
        # splitext/basename handle both '/' and '\\' separators and only strip
        # the final extension (names like 'scene.001.json' keep 'scene.001').
        stem = os.path.splitext(os.path.basename(label_file))[0]
        return os.path.join(self.pcg_dir, f'{stem}.pcg')

    @staticmethod
    def _read_annotations(label_path):
        """Parse a label file into parallel lists of class names and data."""
        # JSON is UTF-8; don't rely on the platform default (cp1252 on Windows).
        with open(label_path, 'r', encoding='utf-8') as f:
            label_json = json.load(f)
        entries = label_json["Annotation"]
        class_name = [entry["class_name"] for entry in entries]
        annotation = [entry["data"] for entry in entries]
        return class_name, annotation

    def __getitem__(self, index):
        """Load the labels and down-sampled point cloud for sample ``index``.

        Raises:
            FileNotFoundError: if the matching point-cloud file is missing
                (Open3D would otherwise just return an empty cloud).
        """
        label_file = self.label_list[index]
        # Label files were listed from labels/, so join onto labels_dir.
        label_path = os.path.join(self.labels_dir, label_file)
        class_name, annotation = self._read_annotations(label_path)

        pcg_path = self._pcg_path_for(label_file)
        if not os.path.isfile(pcg_path):
            raise FileNotFoundError(pcg_path)
        # NOTE(review): Open3D picks the reader from the file extension;
        # confirm '.pcg' is readable (or pass format=...) for this data.
        pcd = o3d.io.read_point_cloud(pcg_path)
        downpcd = pcd.voxel_down_sample(voxel_size=self.voxel_size)
        # read_point_cloud returns a legacy PointCloud: coordinates live in
        # `.points` (`.point.positions` is the tensor-API accessor). Copy into
        # an independent NumPy array.
        pos = np.array(downpcd.points)

        return class_name, annotation, pos

    def __len__(self):
        return len(self.label_list)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MyLidarDataset appends `labels/` and `pcg/` itself, so each split path must
# be the directory that contains them, not the labels directory.
DATA_ROOT = r"G:\set1"
BATCH_SIZE = 32

train_path = os.path.join(DATA_ROOT, "training")
test_path = os.path.join(DATA_ROOT, "testing")
train_set = MyLidarDataset(train_path)
test_set = MyLidarDataset(test_path)


def collate_lidar(batch):
    """Batch samples as parallel lists instead of stacked tensors.

    Per-sample annotation lists and point arrays can differ in length, which
    the default collate (it stacks / zips equal-size entries) cannot handle.

    Returns:
        (class_names, annotations, positions), each a list with one entry
        per sample in the batch.
    """
    class_names, annotations, positions = zip(*batch)
    return list(class_names), list(annotations), list(positions)


train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_lidar)
# Evaluation does not need random order; a fixed order keeps runs comparable.
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_lidar)




In [None]:
from torch import nn

class ContextEncoder(nn.Module):
    """Fully connected encoder that halves the width at every layer (512 -> 2).

    Input must have a trailing dimension of 512 (``fc1.in_features``); the
    output has a trailing dimension of 2. What the 512 input features encode
    is not established in this notebook -- TODO document.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as fc1..fc8 so state_dict keys stay stable.
        self.fc1 = nn.Linear(in_features=512, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=32)
        self.fc5 = nn.Linear(in_features=32, out_features=16)
        self.fc6 = nn.Linear(in_features=16, out_features=8)
        self.fc7 = nn.Linear(in_features=8, out_features=4)
        self.fc8 = nn.Linear(in_features=4, out_features=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1..fc7 each followed by ReLU, then fc8 with no activation.

        NOTE(review): the class previously had no forward(), so calling it
        raised NotImplementedError. This assumes a plain sequential MLP with a
        linear output head -- confirm against the intended architecture.
        """
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4,
                      self.fc5, self.fc6, self.fc7):
            x = self.relu(layer(x))
        return self.fc8(x)
        

class Prediction(nn.Module):
    