<a href="https://colab.research.google.com/github/oscarwilkins1707/DeepLearningThesis/blob/main/PTV2_Autoencoder_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##1) Install Required Libraries

In [2]:
import torch

!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html #Use prebuilt wheels to make code faster
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git
!pip install einops

[0mLooking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_scatter-2.1.2%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m49.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt26cu124
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_cluster-1.6.3%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.3+pt26cu124
Collecting git+https://github.com/pyg-team/pytorch_geometric.git
  Cloning https://github.com/pyg-team/pytorch_ge

In [3]:
!set CUDA_LAUNCH_BLOCKING = 1
!set TORCH_USE_CUDA_DSA = 1

In [4]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']="1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/

from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import plotly.express as px
import natsort
import pickle
import torch.nn as nn
from torch_cluster import knn
import einops
from copy import deepcopy
import math
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr, composite
from torch.utils.data import Dataset, DataLoader
import gc
import time
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

Mounted at /content/gdrive
/content/gdrive/MyDrive


##2) Define Classes for PTV2 Algorithm

####PointTransformerV2 Set Up Classes

In [221]:
def offset2batch(offset):
    return torch.cat([torch.tensor([i] * (o - offset[i - 1])) if i > 0 else
                      torch.tensor([i] * o) for i, o in enumerate(offset)],
                     dim=0).long().to(offset.device)

def batch2offset(batch):
    return torch.cumsum(batch.bincount(), dim=0).long()

def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def inpdict_to_point(inp_dict):
  enc_ip = inp_dict['enc_inp']
  offset = inp_dict['inp_batch_ids']
  coords = enc_ip[:,:3]
  feat = enc_ip[:,3]
  return coords,feat,offset

def grouping(idx,
             feat,
             xyz,
             new_xyz=None,
             with_xyz=False):

    """
    Figure out whatever tf this function does.
    """

    #Added this to avoid errors. Is that ok?
    xyz = xyz.contiguous()
    feat = feat.contiguous()

    if new_xyz is None:
        new_xyz = xyz

    assert xyz.is_contiguous()
    assert feat.is_contiguous()

    m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1]
    xyz = torch.cat([xyz, torch.zeros([1, 3]).to(xyz.device)], dim=0)
    feat = torch.cat([feat, torch.zeros([1, c]).to(feat.device)], dim=0)
    grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c)  # (m, num_sample, c)

    if with_xyz:
        assert new_xyz.is_contiguous()
        mask = torch.sign(idx + 1)
        grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) - new_xyz.unsqueeze(1)  # (m, num_sample, 3)
        grouped_xyz = torch.einsum("n s c, n s -> n s c", grouped_xyz, mask)  # (m, num_sample, 3)

        return torch.cat((grouped_xyz, grouped_feat), -1)
    else:
        return grouped_feat

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """
    Zeros elements of the input tensor (x) with probability `drop_prob` during training.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob) #Increases magnitude of non-zeros elements to have same magnitude of points going into architecture
    return x * random_tensor


In [165]:
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """
        Returns Tensor with zeroed out elements based on drop_prob.
        """
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'

class FeatureAggregation(nn.Module):
  def __init__(self, in_channels, dmodel):
    super(FeatureAggregation, self).__init__()
    self.q = nn.Parameter(torch.randn(1, dmodel).float())
    self.v = nn.Conv1d(in_channels = in_channels,
                       out_channels = int(in_channels*dmodel),
                       kernel_size = 1,
                       groups = in_channels)
    self.k = nn.Conv1d(in_channels = in_channels,
                       out_channels = int(in_channels*dmodel),
                       kernel_size = 1,
                       groups = in_channels)

  def forward(self, points):
    """f
    replaces features in points with aggregated features (based on a feature attention operation).

    Inputs:
    points (list): [coords, feats, batch]

    Outputs:
    points (list): [coords, feats, batch]
    """
    coords, feats, batch = points
    BK, fc = feats.shape
    feat_v = self.v(feats.float().unsqueeze(2)).reshape(BK, fc, -1)
    feat_k = self.k(feats.float().unsqueeze(2)).reshape(BK, fc, -1)

    qk = torch.einsum('qd, bkd -> bqk', self.q, feat_k)
    attn = nn.Softmax(dim=-1)(qk)
    out = torch.einsum('bkd, bqk -> bqd', feat_v, attn)

    return [coords, out.squeeze(1),  batch]

class PointAggregation(nn.Module):
    def __init__(self, dmodel):
        super(PointAggregation, self).__init__()
        #self.configs = configs

        self.q = nn.Parameter(torch.randn(1, dmodel).float())
        self.v = nn.Linear(dmodel, dmodel)
        self.k = nn.Linear(dmodel, dmodel)

    def forward(self, feats, batch):
          offset = batch2offset(batch)
          offset = torch.cat([offset.new_zeros(1), offset])
          #print(feats.shape, batch.shape)
          feats_v = self.v(feats.float())
          feats_k = self.k(feats.float())

          qk = torch.einsum('qd,kd -> kq', self.q, feats_k)
          attn = composite.scatter_softmax(qk, batch.long(), dim = 0)
          out = segment_csr(torch.einsum('qk,qd -> qd', attn, feats_v), offset)

          return out

In [232]:
class GroupedLinear(nn.Module):
    __constants__ = ['in_features', 'out_features', "groups"]
    in_features: int
    out_features: int
    groups: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, groups: int,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(GroupedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        assert in_features & groups == 0
        assert out_features % groups == 0
        # for convenient, currently only support out_features == groups, one output
        assert out_features == groups
        self.weight = nn.Parameter(torch.empty((1, in_features), **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (input * self.weight).reshape(
            list(input.shape[:-1]) + [self.groups, input.shape[-1] // self.groups]).sum(-1)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class PointBatchNorm(nn.Module):
    """
    Batch Normalization for Point Clouds data in shape of [B*N, C], [B*N, L, C]
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            return self.norm(input.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
        elif input.dim() == 2:
            return self.norm(input)
        else:
            raise NotImplementedError

class GroupedVectorAttention(nn.Module):
    def __init__(self,
                 embed_channels,
                 groups,
                 attn_drop_rate=0.,
                 qkv_bias=True,
                 pe_multiplier=True,
                 pe_bias=True
                 ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True)
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True)
        )

        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )

        self.weight_encoding = nn.Sequential(
            GroupedLinear(embed_channels, groups, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups)
        )
        self.softmax = nn.Softmax(dim=1)
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        query, key, value = self.linear_q(feat), self.linear_k(feat), self.linear_v(feat)
        key = grouping(reference_index, key, coord, with_xyz=True)
        value = grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3],  key[:, :, 3:]
        relation_qk = key - query.unsqueeze(1)

        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = (value + peb)


        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))

        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat

class GridPool(nn.Module):
    """
    Partition-based Pooling (Grid Pooling)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 grid_size,
                 bias=False):
        super(GridPool, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size

        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        coord, feat, batch = points
        offset = batch2offset(batch)

        feat = self.act(self.norm(self.fc(feat)))

        start = segment_csr(coord, torch.cat([batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)]),
                            reduce="min") if start is None else start

        cluster = voxel_grid(pos=coord - start[batch], size=self.grid_size, batch=batch, start=0)

        unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)
        _, sorted_cluster_indices = torch.sort(cluster)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        coord = segment_csr(coord[sorted_cluster_indices], idx_ptr, reduce="mean")
        feat = segment_csr(feat[sorted_cluster_indices], idx_ptr, reduce="max")
        batch = batch[idx_ptr[:-1]]
        offset = batch2offset(batch)
        return [coord, feat, batch], cluster

class Block(nn.Module):
    def __init__(self,
                 embed_channels,
                 groups,
                 qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 enable_checkpoint=False
                 ):

        super(Block, self).__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, points, reference_index):
        coord, feat, batch = points
        offset = batch2offset(batch)
        identity = feat
        feat = self.act(self.norm1(self.fc1(feat)))
        feat = self.attn(feat, coord, reference_index) #\
            #if not self.enable_checkpoint else checkpoint(self.attn, feat, coord, time, reference_index)
        feat = self.act(self.norm2(feat))
        feat = self.norm3(self.fc3(feat))
        feat = identity + self.drop_path(feat)
        feat = self.act(feat)
        return [coord, feat, batch]

class BlockSequence(nn.Module):
    def __init__(self,
                 depth,
                 embed_channels,
                 groups,
                 neighbours=16,
                 qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 enable_checkpoint=False
                 ):
        super(BlockSequence, self).__init__()

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, float):
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
        else:
            drop_path_rates = [0. for _ in range(depth)]
        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                embed_channels=embed_channels,
                groups=groups,
                qkv_bias=qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                enable_checkpoint=enable_checkpoint
            )
            self.blocks.append(block)

    def forward(self, points):
        coord, feat, batch = points
        # reference index query of neighbourhood attention
        # for windows attention, modify reference index query method
        #print(coord.shape)
        #print(batch.shape)
        #print("PASSED")
        reference_index = knn(coord.contiguous(), coord.contiguous(), self.neighbours, batch.contiguous(), batch.contiguous())
        reference_index = reference_index[1,:].reshape(len(coord), self.neighbours)
        for block in self.blocks:
            points = block(points, reference_index)
        #print(points[0].shape)
        #print(points[1].shape)
        #print(points[2].shape)
        return points

class Encoder(nn.Module):
    def __init__(self,
                 depth,
                 in_channels,
                 embed_channels,
                 groups,
                 grid_size=None,
                 neighbours=16,
                 qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=None,
                 drop_path_rate=None,
                 enable_checkpoint=False,
                 ):
        super(Encoder, self).__init__()

        self.down = GridPool(
            in_channels=in_channels,
            out_channels=embed_channels,
            grid_size=grid_size,
        )

        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate if attn_drop_rate is not None else 0.,
            drop_path_rate=drop_path_rate if drop_path_rate is not None else 0.,
            enable_checkpoint=enable_checkpoint
        )

    def forward(self, points):
        #print(f"PRE DOWNSAMPLE SHAPE: {points[0].shape} {points[1].shape} {points[2].shape}")
        points, cluster = self.down(points)
        #print(f"POST DOWNSAMPLE SHAPE: {points[0].shape} {points[1].shape} {points[2].shape}")
        #print(f"PRE BLOCK SHAPE: {points[0].shape} {points[1].shape} {points[2].shape}")
        check = self.blocks(points)
        #print(f"POST BLOCK SHAPE: {check[0].shape} {check[1].shape} {check[2].shape}")
        return self.blocks(points), cluster



class GVAPatchEmbed(nn.Module):
    def __init__(self,
                 depth,
                 in_channels,
                 embed_channels,
                 groups,
                 neighbours=8,
                 qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 enable_checkpoint=False
                 ):
        super(GVAPatchEmbed, self).__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True)
        )
        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint
        )

    def forward(self, points):
        coord, feat, batch = points
        offset = batch2offset(batch)
        feat = self.proj(feat)
        return self.blocks([coord, feat, batch])

###2.2) Point Transformer Class

In [223]:
class PointTransformerV2(nn.Module):
    def __init__(self,
                 in_channels,
                 patch_embed_depth=1,
                 patch_embed_channels=16,
                 patch_embed_groups= 4 ,
                 patch_embed_neighbours=16,
                 enc_depths=(2, 2, 6),
                 enc_channels=(32, 64, 128),
                 enc_groups=(8, 16, 32),
                 enc_neighbours=(16, 16, 16),
                 grid_sizes=(0.06, 0.12, 0.25),
                 attn_qkv_bias=True,
                 pe_multiplier=True,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0,
                 enable_checkpoint=False,
                 unpool_backend="map"
                 ):

        super(PointTransformerV2, self).__init__()

        self.in_channels = in_channels
        self.num_stages = len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_groups)
        assert self.num_stages == len(enc_neighbours)
        assert self.num_stages == len(grid_sizes)

        self.feature_aggr = FeatureAggregation(in_channels, patch_embed_channels)

        self.patch_embed = GVAPatchEmbed(
            in_channels=patch_embed_channels,
            embed_channels=patch_embed_channels,
            groups=patch_embed_groups,
            depth=patch_embed_depth,
            neighbours=patch_embed_neighbours,
            qkv_bias=attn_qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            enable_checkpoint=enable_checkpoint
        )

        enc_dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(enc_depths))]

        enc_channels = [patch_embed_channels] + list(enc_channels)

        self.enc_stages = nn.ModuleList()

        for i in range(self.num_stages):
            enc = Encoder(
                depth=enc_depths[i],
                in_channels=enc_channels[i],
                embed_channels=enc_channels[i + 1],
                groups=enc_groups[i],
                grid_size=grid_sizes[i],
                neighbours=enc_neighbours[i],
                qkv_bias=attn_qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=enc_dp_rates[sum(enc_depths[:i]):sum(enc_depths[:i + 1])],
                enable_checkpoint=enable_checkpoint
            )

            self.enc_stages.append(enc)

        self.pt_aggr = PointAggregation(enc_channels[-1])
        self.reg_head = nn.Sequential(
            nn.BatchNorm1d(enc_channels[-1]),
            nn.ReLU(inplace=True),
            nn.Linear(enc_channels[-1], enc_channels[-1])
        )

    def forward(self, data_dict):
        coords, feat, batch = inpdict_to_point(data_dict)
        offset = batch2offset(batch)
        points = [coords.float(), feat.float(), batch.int()]

        points = self.feature_aggr(points)
        points = self.patch_embed(points)
        skips = [[points]]

        for i in range(self.num_stages):
            #print(f"STAGE {i}")
            #print(f"PRE ENCODER SHAPE: {points[0].shape} {points[1].shape} {points[2].shape}")
            points, cluster = self.enc_stages[i](points)
            #print(f"POST ENCODER SHAPE: {points[0].shape} {points[1].shape} {points[2].shape}")
            skips[-1].append(cluster)  # record grid cluster of pooling
            skips.append([points])  # record points info of current stage

        points = skips.pop(-1)[0]
        coord, feat, batch = points
        seg_logits = self.reg_head(self.pt_aggr(feat, batch))

        return seg_logits

##3) Define AutoEncoder

###3.1) Define Decoder

In [224]:
class PointCloudDecoder(nn.Module):
    def __init__(self,
                 dec_in_dim,
                 dec_out_dim,
                 num_points,
                 num_layers = 1
                 ):
        super(PointCloudDecoder, self).__init__()

        self.num_points = num_points
        self.dec_out_dim = dec_out_dim

        self.decoder = nn.ModuleList()
        layer_dims = list(np.linspace(dec_in_dim,dec_out_dim*num_points,num=num_layers+1, dtype = int))
        for i in range(num_layers):
          layer = nn.Sequential(
              nn.Linear(layer_dims[i], layer_dims[i+1]),
              nn.BatchNorm1d(layer_dims[i+1]),
              nn.ReLU(inplace=True)
          )
          self.decoder.append(layer)

    def forward(self, points):
        x = points
        for layer in self.decoder:
          x = layer(x)
        output = x.view(x.size(0), self.num_points, self.dec_out_dim)
        return output


'''
class PointCloudDecoder(nn.Module):
    def __init__(self,
                latent_dim,
                out_feat_dim,
                hidden_dims=(256, 128, 64),
                num_points_template=1024  # Default template size
                ):
        super(PointCloudDecoder, self).__init__()

        self.latent_dim = latent_dim
        self.out_feat_dim = out_feat_dim
        self.hidden_dims = hidden_dims
        self.num_points_template = num_points_template

        # Create a learnable point template
        self.point_template = nn.Parameter(torch.randn(1, num_points_template, 3) * 0.1)

        # MLP to process latent vector
        self.latent_mlp = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[0]),
            nn.LayerNorm(hidden_dims[0]),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dims[0], hidden_dims[0]),
            nn.LayerNorm(hidden_dims[0]),
            nn.ReLU(inplace=True)
        )

        # Deformation MLP - combines point template with latent features
        deform_layers = []
        deform_input_dim = 3 + hidden_dims[0]  # Concat of point coords + latent
        last_dim = deform_input_dim

        for h in hidden_dims:
            deform_layers.extend([
                nn.Linear(last_dim, h),
                nn.LayerNorm(h),
                nn.ReLU(inplace=True)
            ])
            last_dim = h

        self.deform_mlp = nn.Sequential(*deform_layers)

        # Output heads
        self.coord_head = nn.Linear(last_dim, 3)
        if out_feat_dim > 0:
            self.feat_head = nn.Linear(last_dim, out_feat_dim)

    def forward(self, x, num_points_list):
        """
        x: [B, latent_dim] - batch of latent vectors
        num_points_list: list or tensor of length B - number of points to generate for each batch item
        """
        coords_list = []
        feats_list = []
        B = x.size(0)

        # Process latent vector
        latent_features = self.latent_mlp(x)  # [B, hidden_dims[0]]

        for i in range(B):
            num_points = num_points_list[i]

            # Get latent for this sample and expand
            latent_feat = latent_features[i].unsqueeze(0).expand(num_points, -1)  # [num_points, hidden_dims[0]]

            # Sample points from template
            # If num_points > template size, we need to interpolate
            if num_points <= self.num_points_template:
                # Take first num_points from template
                template_points = self.point_template[0, :num_points, :]
            else:
                # Interpolate (sample with repetition)
                indices = torch.linspace(0, self.num_points_template-1, num_points).long()
                template_points = self.point_template[0, indices, :]

            # Combine point coords with latent features
            point_features = torch.cat([template_points, latent_feat], dim=-1)  # [num_points, 3+hidden_dims[0]]

            # Apply deformation network
            deformed_features = self.deform_mlp(point_features)  # [num_points, last_hidden_dim]

            # Generate output coordinates (as offsets to template)
            coord_offsets = self.coord_head(deformed_features)  # [num_points, 3]
            coords = template_points + coord_offsets  # Apply offset to template
            coords_list.append(coords)

            # Generate output features if needed
            if self.out_feat_dim > 0:
                feats = self.feat_head(deformed_features)  # [num_points, feat_dim]
                feats_list.append(feats)

        # Concatenate results from batch
        coords_out = torch.cat(coords_list, dim=0)  # [sum(num_points), 3]

        if self.out_feat_dim > 0:
            feats_out = torch.cat(feats_list, dim=0)  # [sum(num_points), feat_dim]
            return coords_out, feats_out
        else:
            return coords_out

'''
print()




###3.2) Define AutoEncoder Class

In [225]:
class AutoEncoder(nn.Module):
  def __init__(self, config_dict):
    super(AutoEncoder, self).__init__()
    self.config_dict = config_dict
    self.point_transformer = PointTransformerV2(in_channels = config_dict['enc_ip_dim'])
    self.point_cloud_decoder = PointCloudDecoder(dec_in_dim = config_dict['enc_out_dim'],
                                                 dec_out_dim = config_dict['enc_ip_dim'],
                                                 num_points = config_dict['num_points'])
    self.enc_ip_dim = config_dict['enc_ip_dim']

  def forward(self, points):
    feat = self.point_transformer(points)
    output = self.point_cloud_decoder(feat)
    output = output.reshape(-1,self.enc_ip_dim)
    return feat, output

'''
class AutoEncoder(nn.Module):
  def __init__(self, config_dict):
    super(AutoEncoder, self).__init__()
    self.config_dict = config_dict
    self.point_transformer = PointTransformerV2(in_channels = config_dict['enc_ip_dim'])
    self.point_cloud_decoder = PointCloudDecoder(latent_dim = config_dict['latent_dim'],
                                                 out_feat_dim = config_dict['enc_ip_dim'])

  def forward(self, points):
    num_points_list = points['num_points']
    feat = self.point_transformer(points)
    output = self.point_cloud_decoder(feat, num_points_list)


    return feat, output
'''
print()




##4) Define Configs

In [226]:
config_dict = {'enc_ip_dim':1,
                'batch_size':8,
                'val_batch_size':1,
                'dmodel':128,
                'n_head':8,
                'num_enc':2,
                'num_points': 4096,
                'dropout':0.1,
                'enc_out_dim':128,
                'latent_dim': 128,
                'total_samples': 2000,
                'num_epochs': 5}


##5) Training

###5.1) Example Code

In [227]:
#TEST CODE
example_pc = np.load('2024-10-28_0.npy')
example_pc = example_pc[:config_dict['num_points'],:]
offset = torch.tensor([0, len(example_pc)]).long()
batch = (offset2batch(offset)*0).contiguous()
enc_inp = torch.from_numpy(example_pc[:, :]).float().contiguous()
data_dict = {'enc_inp':enc_inp,'inp_batch_ids': batch, 'num_points': [len(example_pc)]}
print("enc_inp is contiguous:", enc_inp.is_contiguous())
print("inp_batch_ids is contiguous:", batch.is_contiguous())


example_point_transformer = PointTransformerV2(in_channels = config_dict['enc_ip_dim'])
example_point_transformer.eval()
example_autoencoder = AutoEncoder(config_dict)
example_autoencoder.eval()
features, output = example_autoencoder(data_dict)

enc_inp is contiguous: True
inp_batch_ids is contiguous: True


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x4096 and 16x16)

###5.2) Define Class to Format Data

In [228]:
class PC_dataset_individuals(object):
  def __init__(self, root_dir, pc_dir_list, config_dict):
    self.pc_dir_list = pc_dir_list
    self.root_dir = root_dir
    self.num_points = config_dict['num_points']


  def get_sample(self, idx, count):

    enc_inp = []
    inp_batch_ids = []

    pc_df = []
    pc_df_good_points = 0
    while pc_df_good_points < 5:
      if os.path.exists(os.path.join(self.root_dir, self.pc_dir_list[idx])):
          #print(os.path.join(self.root_dir, self.pc_dir_list[idx]))
          pc_df = np.load(os.path.join(self.root_dir, self.pc_dir_list[idx]))
          if self.num_points < len(pc_df):
            pc_df = pc_df[:self.num_points,:]
          pc_df_good_points = len(pc_df[(pc_df[:,3] >= -5) & (pc_df[:,3] <= 5)])

          if idx == len(self.pc_dir_list)-1:
            idx = 0
          else:
            idx += 1
      if np.isnan(pc_df).any():
          print('Nan in file: ')
          pc_df = []

    enc_inp.append(pc_df)
    inp_batch_ids.extend((count + np.zeros_like(pc_df[:,0])).tolist())


    enc_inp = torch.from_numpy(np.concatenate(enc_inp, axis = 0))
    inp_batch_ids = torch.from_numpy(np.asarray(inp_batch_ids)).int()

    #enc_inp[:,3] = (enc_inp[:,3] + 30)/45


    data_dict = {'enc_inp': enc_inp,
                 'inp_batch_ids': inp_batch_ids,
                 'num_points': torch.tensor([len(enc_inp)])}


    return data_dict

  def data_len(self):
    return len(self.pc_dir_list)

  def collate_fn(self, batch):
    batch_dict = dict.fromkeys(list(batch[0].keys()))
    for key in list(batch_dict.keys()):
      bkey = [b[key] for b in batch]
      batch_dict[key] = torch.cat(bkey, dim = 0)

    return batch_dict

  def get_batch_single_pc(self, batch_size): # Two functions internally : get_sample, collate_fn

    idxs = np.arange(self.data_len()).astype(int)
    np.random.shuffle(idxs)
    random_idxs = idxs[:batch_size]

    databatch = []
    for count,b in enumerate(random_idxs):
      databatch.append(self.get_sample(b,count))

    return self.collate_fn(databatch)

###5.3) Define Loss Criterions

In [229]:
#DON'T NEED THESE LOSSES. GO BACK TO HUBERLOSSES
'''
def chamfer_distance(x, y):
    """
    Calculate the Chamfer Distance between two point clouds
    x: [N, D] first point cloud (e.g., predicted)
    y: [M, D] second point cloud (e.g., target)
    Returns the Chamfer Distance between the two point clouds
    """
    # Reshape to [1, N, D] and [1, M, D] if inputs are not batched
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if y.dim() == 2:
        y = y.unsqueeze(0)

    # Get batch size
    batch_size = x.size(0)

    # Compute pairwise distances
    xx = torch.sum(x**2, dim=2, keepdim=True)       # [B, N, 1]
    yy = torch.sum(y**2, dim=2, keepdim=True)       # [B, M, 1]

    # Compute all pairwise distances using matrix multiplication
    inner = -2 * torch.matmul(x, y.transpose(1, 2))   # [B, N, M]
    distances = xx + inner + yy.transpose(1, 2)       # [B, N, M]

    # Get min distance for each point in x to any point in y
    mins_x, _ = torch.min(distances, dim=2)  # [B, N]

    # Get min distance for each point in y to any point in x
    mins_y, _ = torch.min(distances, dim=1)  # [B, M]

    # Compute the mean over points and add both directions
    chamfer_dist = torch.mean(mins_x, dim=1) + torch.mean(mins_y, dim=1)  # [B]

    # Return the mean over the batch
    return torch.mean(chamfer_dist)

class PointCloudLoss(nn.Module):
    """Combined loss for point cloud reconstruction"""
    def __init__(self, chamfer_weight=1.0, feature_weight=0.1):
        super(PointCloudLoss, self).__init__()
        self.chamfer_weight = chamfer_weight
        self.feature_weight = feature_weight
        self.feature_criterion = nn.HuberLoss()

    def forward(self, pred_coords, target_coords, pred_feats=None, target_feats=None):
        # Chamfer distance for coordinates
        chamfer_loss = chamfer_distance(pred_coords, target_coords)

        # Feature loss (if features are provided)
        feature_loss = 0.0
        if pred_feats is not None and target_feats is not None:
            feature_loss = self.feature_criterion(pred_feats, target_feats)

        # Combined loss
        total_loss = self.chamfer_weight * chamfer_loss + self.feature_weight * feature_loss

        return total_loss, chamfer_loss, feature_loss
'''
print()




###5.4) Define Train and Val Functions

In [230]:
def visualise_point_cloud(coords, original, reconstructed, save_path):
    """
    Visualize original and reconstructed point clouds side by side

    Args:
        original: original point cloud coordinates [N, 3]
        reconstructed: reconstructed point cloud coordinates [N, 3]
        save_path: path to save the visualization
    """
    # Convert to numpy from tensors
    if isinstance(original, torch.Tensor):
        original = original.detach().cpu().numpy()
    if isinstance(reconstructed, torch.Tensor):
        reconstructed = reconstructed.detach().cpu().numpy()
    if isinstance(coords, torch.Tensor):
        coords = coords.detach().cpu().numpy()

    #add back coordinates
    original = np.hstack((coords,original))
    reconstructed = np.hstack((coords,reconstructed))
    #print(f'original: {original[:3,:]}')
    #print(f'\nreconstructed {reconstructed[:3,:]}')
    #Determine points in correct CNR band
    original = original[(original[:,3] >= -5) & (original[:,3] <= 5)]
    reconstructed = reconstructed[(reconstructed[:,3] >= -5) & (reconstructed[:,3] <= 5)]
    #print(f'new original: {original[:3,:]}')
    #print(f'\nnew reconstructed {reconstructed[:3,:]}')

    fig = plt.figure(figsize=(12, 6))

    # Original point cloud
    ax1 = fig.add_subplot(121, projection='3d')
    ax1.scatter(original[:, 0], original[:, 1], original[:, 2], s=1, c=original[:, 2], cmap='viridis')
    ax1.set_title('Original Point Cloud')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Reconstructed point cloud
    ax2 = fig.add_subplot(122, projection='3d')
    ax2.scatter(reconstructed[:, 0], reconstructed[:, 1], reconstructed[:, 2], s=1, c=reconstructed[:, 2], cmap='viridis')
    ax2.set_title('Reconstructed Point Cloud')
    ax2.set_xlabel('X')
    ax2.set_ylabel('Y')
    ax2.set_zlabel('Z')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def inspect_outputs(target, reconstructed, num_samples=5):
    """Inspect target and reconstructed point clouds"""
    print(f"\nOutput Diagnostics (showing {num_samples} samples):")

    # Check if all reconstructed points are the same
    recon_all_same = torch.allclose(
        reconstructed[0].unsqueeze(0).expand_as(reconstructed),
        reconstructed,
        rtol=1e-4, atol=1e-4
    )

    if recon_all_same:
        print(f"CRITICAL ISSUE: All reconstructed points are identical!")

    # Print statistics
    print(f"Target points - Mean: {target.mean(dim=0)}, Std: {target.std(dim=0)}")
    print(f"Reconstructed points - Mean: {reconstructed.mean(dim=0)}, Std: {reconstructed.std(dim=0)}")

    # Show some examples
    print("\nSample points (target vs reconstructed):")
    for i in range(min(num_samples, len(target))):
        print(f"Point {i}:")
        print(f"  Target:        {target[i]}")
        print(f"  Reconstructed: {reconstructed[i]}")

    # Check variance across dimensions
    print(f"\nVariance per dimension:")
    print(f"  Target:        {torch.var(target, dim=0)}")
    print(f"  Reconstructed: {torch.var(reconstructed, dim=0)}")

def train(model, train_dataset, batch_size, optimizer, scheduler, criterion, epoch, device, config_dict, vis_dir=None):
    model.train()
    epoch_start = time.time()
    train_feat_loss = 0
    num_steps_per_epoch = (train_dataset.data_len() // batch_size) + 1
    count_good_steps = 0

    for step in range(num_steps_per_epoch):
        step_start = time.time()

        # Make batch
        data_dict = train_dataset.get_batch_single_pc(batch_size)
        data_dict = {k: v.to(device) for k, v in data_dict.items()}

        optimizer.zero_grad()

        # Forward pass
        embedded_features, recon_feats = model(data_dict)
        coords, feats, batch = inpdict_to_point(data_dict)
        '''
        print(f'embedded_features: {embedded_features.shape}')
        print(f'recon_feats: {recon_feats.shape}')
        print(f'coords: {coords.shape}')
        print(f'feats: {feats.shape}')
        print(f'batch: {batch.shape}')
        print(data_dict['num_points'])
        '''
        coords = coords.float()
        target_feats = feats.float()

        # Visualisation
        if vis_dir and step == 0 and epoch % 1 == 0:
            os.makedirs(vis_dir, exist_ok=True)
            # Only visualize first point cloud in batch
            first_pc_len = data_dict['num_points'][0].item()
            visualise_point_cloud(
                coords[:first_pc_len],
                target_feats[:first_pc_len],
                recon_feats[:first_pc_len],
                os.path.join(vis_dir, f'epoch_{epoch}_train.png')
            )

        # === Sanity checks ===
        if torch.isnan(recon_feats).any():
            print("NaN in model output. Skipping step.")
            continue

        # Loss criterion (only care about losses for features as coordinates should be same)
        feat_loss = criterion(recon_feats.squeeze(-1).float(), target_feats.squeeze(-1).float())

        # Backwards and optimise
        feat_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()


        train_feat_loss += feat_loss.item()
        count_good_steps += 1

        step_end = time.time()
        print(f"Epoch {epoch}, Step {step}/{num_steps_per_epoch}, "
              f"Feat Loss: {feat_loss.item():.4f}, "
              f"Time: {step_end - step_start:.2f}s")

        gc.collect()
        del data_dict

    epoch_duration = time.time() - epoch_start
    print(f"Epoch {epoch} completed in {epoch_duration:.2f}s")

    # Return average losses
    avg_feat_loss = train_feat_loss / max(count_good_steps, 1)

    return avg_feat_loss

def val(model, val_dataset, batch_size, criterion, epoch, device, config_dict, vis_dir=None):
    model.eval()
    epoch_start = time.time()
    val_feat_loss = 0
    num_steps_per_epoch = (val_dataset.data_len() // batch_size) + 1

    for step in range(num_steps_per_epoch):
        step_start = time.time()

        data_dict = val_dataset.get_batch_single_pc(batch_size)
        data_dict = {k: v.to(device) for k, v in data_dict.items()}

        # Forward pass
        embedded_features, recon_feats = model(data_dict)
        coords, feats, batch = inpdict_to_point(data_dict)
        coords = coords.float()
        target_feats = feats.float()

        # Visualisation
        if vis_dir and step == 0 and epoch % 1 == 0:
            os.makedirs(vis_dir, exist_ok=True)
            # Only visualize first point cloud in batch
            first_pc_len = data_dict['num_points'][0].item()
            visualise_point_cloud(
                coords[:first_pc_len],
                target_feats[:first_pc_len],
                recon_feats[:first_pc_len],
                os.path.join(vis_dir, f'epoch_{epoch}_train.png')
            )

            # Debug output inspection
            inspect_outputs(target_feats[:first_pc_len], recon_feats[:first_pc_len])

        # Loss criterion (only care about losses for features as coordinates should be same)
        feat_loss = criterion(recon_feats.squeeze(-1).float(), target_feats.squeeze(-1).float())

        step_end = time.time()
        print(f"Epoch: {epoch}, Step: {step}/{num_steps_per_epoch}, "
              f"Feat Loss: {feat_loss.item():.4f}, "
              f"Time: {step_end-step_start:.2f}s")

        val_feat_loss += feat_loss.item()

        gc.collect()
        del data_dict

    print(f'Time to complete validation: {time.time()-epoch_start:.2f}s')

    # Return average losses
    avg_feat_loss = val_feat_loss / num_steps_per_epoch

    return avg_feat_loss

###5.5) Training Script

In [231]:
#Setting up device stuff
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
torch.backends.cudnn.benchmark = False

#Finding where the data is and making save and visualisation directories
root = os.getcwd()
lidar_dir = os.path.join(root, 'LiDAR')
pcdata_dir = os.path.join(lidar_dir, 'pcloud_norm')
pc_files = natsort.natsorted(os.listdir(pcdata_dir))
pc_paths = [os.path.join(pcdata_dir, pcf) for pcf in pc_files]
if config_dict['total_samples'] < len(pc_paths):
  pc_paths = pc_paths[:config_dict['total_samples']]
save_dir = os.path.join(lidar_dir, os.path.join('saved_models'))
if not os.path.exists(save_dir):
  os.makedirs(save_dir)
vis_dir = os.path.join(lidar_dir, 'visualisations')
os.makedirs(vis_dir, exist_ok=True)


#Creating datasets
train_val_pcs = pc_paths[:(int(len(pc_paths)*0.7))]
test_pcs = pc_paths[(int(len(pc_paths)*0.7)):]
train_idx = np.arange(len(train_val_pcs)).astype(int)[:int(0.9*len(train_val_pcs))]
val_idx = np.arange(len(train_val_pcs)).astype(int)[int(0.9*len(train_val_pcs)):]
train_pcs = np.asarray(train_val_pcs)[train_idx].tolist()
val_pcs = np.asarray(train_val_pcs)[val_idx].tolist()
train_dataset = PC_dataset_individuals(lidar_dir, train_pcs, config_dict)
val_dataset = PC_dataset_individuals(lidar_dir, val_pcs, config_dict)
test_dataset = PC_dataset_individuals(lidar_dir, test_pcs, config_dict)
num_steps_per_epoch = (train_dataset.data_len() // config_dict['batch_size']) + 1

#Initialising model stuff
model = AutoEncoder(config_dict).to(device)
criterion = nn.HuberLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0005,
    weight_decay=0.01
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr = 0.001,
                                                epochs=config_dict['num_epochs'],
                                                steps_per_epoch=num_steps_per_epoch)


# Training loop
for epoch in range(config_dict['num_epochs']):
    # Train
    train_feat_loss = train(
        model, train_dataset, config_dict['batch_size'],
        optimizer, scheduler, criterion, epoch, device, config_dict,
        vis_dir=os.path.join(vis_dir, 'train')
    )

    # Validate
    with torch.no_grad():
        val_feat_loss = val(
            model, val_dataset, config_dict['val_batch_size'],
            criterion, epoch, device, config_dict,
            vis_dir=os.path.join(vis_dir, 'val')
        )

    # Update learning rate based on validation loss
    scheduler.step(val_feat_loss)

    print(f'Epoch {epoch} summary:')
    print(f'  Train - Total: {train_feat_loss:.4f}')
    print(f'  Val   - Total: {val_feat_loss:.4f}')

    # Save model checkpoint
    if epoch % 5 == 0 or epoch == config_dict['num_epochs'] - 1:
        torch.save(model.state_dict(), os.path.join(save_dir, f'model_epoch_{epoch}.pth'))

    # Save best model
    if epoch == 0 or val_feat_loss < min_val_loss:
        min_val_loss = val_feat_loss
        torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
        print(f'Saved new best model with val_loss: {val_feat_loss:.4f}')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x32768 and 16x16)

##6) Testing

###6.1) Define Testing Function

In [None]:
def test(model, test_dataset, batch_size, criterion, device, config_dict, vis_dir=None):
    model.eval()
    epoch_start = time.time()
    test_loss = 0
    test_coord_loss = 0
    test_feat_loss = 0
    num_steps_per_epoch = (test_dataset.data_len() // batch_size) + 1 * (test_dataset.data_len() % batch_size != 0)

    recon_coords_list = []
    recon_feats_list = []
    target_coords_list = []
    target_feats_list = []

    for step in range(num_steps_per_epoch):
        step_start = time.time()

        data_dict = test_dataset.get_batch_single_pc(batch_size)
        data_dict = {k: v.to(device) for k, v in data_dict.items()}

        # === Forward pass ===
        _, output = model(data_dict)
        recon_coords = output[0]
        recon_feats = output[1]

        coords, feats, _, batch = inpdict_to_point(data_dict)
        target_coords = coords.float()
        target_feats = feats.float()

        # === Store results for visualization ===
        recon_coords_list.append(recon_coords.detach().cpu())
        recon_feats_list.append(recon_feats.detach().cpu())
        target_coords_list.append(target_coords.detach().cpu())
        target_feats_list.append(target_feats.detach().cpu())

        # === Loss computation ===
        total_loss, coord_loss, feat_loss = criterion(
            recon_coords, target_coords, recon_feats, target_feats
        )

        step_end = time.time()
        print(f"Step: {step}/{num_steps_per_epoch}, "
              f"Total Loss: {total_loss.item():.4f}, "
              f"Coord Loss: {coord_loss.item():.4f}, "
              f"Feat Loss: {feat_loss.item():.4f}, "
              f"Time: {step_end-step_start:.2f}s")

        test_loss += total_loss.item()
        test_coord_loss += coord_loss.item()
        test_feat_loss += feat_loss.item()

        # === Visualize test results ===
        if vis_dir and step < 5:  # Visualize first 5 test samples
            os.makedirs(vis_dir, exist_ok=True)
            # Only visualize first point cloud in batch
            first_pc_len = data_dict['num_points'][0].item()
            visualize_point_cloud(
                target_coords[:first_pc_len],
                recon_coords[:first_pc_len],
                os.path.join(vis_dir, f'test_sample_{step}.png')
            )

            # Debug output inspection
            inspect_outputs(target_coords[:first_pc_len], recon_coords[:first_pc_len])

        gc.collect()
        del data_dict

    print(f'Time to complete testing: {time.time()-epoch_start:.2f}s')

    # Return average losses and result lists
    avg_loss = test_loss / num_steps_per_epoch
    avg_coord_loss = test_coord_loss / num_steps_per_epoch
    avg_feat_loss = test_feat_loss / num_steps_per_epoch

    return (avg_loss, avg_coord_loss, avg_feat_loss,
            recon_coords_list, recon_feats_list,
            target_coords_list, target_feats_list)

###6.2) Testing Script

In [None]:
# Test the best model
model.load_state_dict(torch.load(os.path.join(save_dir, 'best_model.pth')))
model.eval()

with torch.no_grad():
    test_results = test(
        model, test_dataset, batch_size=1, criterion=criterion,
        device=device, config_dict=config_dict,
        vis_dir=os.path.join(vis_dir, 'test')
    )

print(f'Final test loss: {test_results[0]:.4f}')

TypeError: 'PC_dataset_individuals' object is not subscriptable