In [1]:
from models.waffleiron.segmenter import Segmenter
import torch
from datasets import LIST_DATASETS, Collate
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from collections import OrderedDict
import warnings
import copy
import random
import numpy as np
import os

import argparse
import wandb
from torchmetrics.classification import MulticlassJaccardIndex
import torchmetrics
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

import torchhd
from torchhd.models import Centroid
from torchhd import embeddings

Using torch.scatter_reduce for 3D to 2D projection.
Using torch.scatter_reduce for 3D to 2D projection.


In [203]:
class Feature_Extractor:
    def __init__(self, input_channels=5, feat_channels=768, depth=48, 
                 grid_shape=[[256, 256], [256, 32], [256, 32]], nb_class=16, layer_norm=True, 
                 device=torch.device("cpu"), early_exit = 48, **kwargs):
        self.model = Segmenter(
            input_channels=input_channels,
            feat_channels=feat_channels,
            depth=depth,
            grid_shape=grid_shape,
            nb_class=nb_class, # class for prediction
            #drop_path_prob=config["waffleiron"]["drop_path"],
            layer_norm=layer_norm,
        )

        classif = torch.nn.Conv1d(
            feat_channels, nb_class, 1 # So it fits 16 = nb_class but classifier is not used
        )
        torch.nn.init.constant_(classif.bias, 0)
        torch.nn.init.constant_(classif.weight, 0)
        self.model.classif = torch.nn.Sequential(
            torch.nn.BatchNorm1d(feat_channels),
            classif,
        )

        for p in self.model.parameters():
            p.requires_grad = False
        for p in self.model.classif.parameters():
            p.requires_grad = True

        def get_optimizer(parameters):
            return torch.optim.AdamW(
                parameters,
                lr=0.001,
                weight_decay=0.003,
            )

        optim = get_optimizer(self.model.parameters())
        self.device = device
        self.device_string = "cuda:0"
        self.num_classes = nb_class
        self.early_exit = early_exit
        self.kwargs = kwargs
    
    def load_pretrained(self, path):
        # Load pretrained model
        path_to_ckpt = path
        checkpoint = torch.load(path_to_ckpt,
            map_location=self.device_string)
        state_dict = checkpoint["net"]  # Adjust key as needed
        new_state_dict = OrderedDict()

        for k, v in state_dict.items():
            new_key = k.replace("module.", "")  # Remove "module." prefix
            new_state_dict[new_key] = v

        self.model.load_state_dict(new_state_dict)

        print(
            f"Checkpoint loaded on {self.device_string}: {path_to_ckpt}"
        )

        if self.device_string != 'cpu':
            torch.cuda.set_device(self.device_string) # cuda:0
            self.model = self.model.cuda(self.device_string) # cuda:0

        self.model.eval()

    def forward_model(self, it, batch):
        feat = batch["feat"]
        labels = batch["labels_orig"]
        cell_ind = batch["cell_ind"]
        occupied_cell = batch["occupied_cells"]
        neighbors_emb = batch["neighbors_emb"]
        if self.device_string != 'cpu':
            feat = feat.cuda(0, non_blocking=True)
            labels = labels.cuda(0, non_blocking=True)
            batch["upsample"] = [
                up.cuda(0, non_blocking=True) for up in batch["upsample"]
            ]
            cell_ind = cell_ind.cuda(0, non_blocking=True)
            occupied_cell = occupied_cell.cuda(0, non_blocking=True)
            neighbors_emb = neighbors_emb.cuda(0, non_blocking=True)
        net_inputs = (feat, cell_ind, occupied_cell, neighbors_emb)

        if self.device_string != 'cpu':
            with torch.autocast("cuda", enabled=True):
                # Logits
                with torch.no_grad():
                    out = self.model(*net_inputs, self.early_exit)
                    encode, tokens, out = out[0], out[1], out[2]
                    pred_label = out.max(1)[1]

                    # Only return samples that are not noise
                    #torch.cuda.synchronize(device=self.device)
                    where = labels != 255
                    #torch.cuda.synchronize(device=self.device)
        else:
            with torch.no_grad():
                out = self.model(*net_inputs, self.early_exit)
                encode, tokens, out = out[0], out[1], out[2]
                pred_label = out.max(1)[1]

                # Only return samples that are not noise
                where = labels != 255
        
        return tokens[0,:,where], labels[where], pred_label[0, where]

    def test(self, loader, total_voxels):        
        # Metric
        miou = MulticlassJaccardIndex(num_classes=self.num_classes, average=None).to(self.device, non_blocking=True)
        final_labels = torch.empty((total_voxels), device=self.device)
        final_pred = torch.empty((total_voxels), device=self.device)
        final_labels_sep = torch.empty((1, 90000), device=self.device)
        final_feat_sep = torch.empty((1, 768, 90000), device=self.device)
        final_soa_result_sep = torch.empty((1, 90000), device=self.device)
        num_voxels = []
        
        start_idx = 0
        for it, batch in tqdm(enumerate(loader), desc="SoA testing"):
            features, labels, soa_result = self.forward_model(it, batch)
            shape_sample = labels.shape[0]
            num_voxels.append(shape_sample)
            labels_tensor = torch.reshape(torch.Tensor(labels), (1,shape_sample)).to(self.device)
            feat_tensor = torch.reshape(torch.Tensor(features), (1,768,shape_sample)).to(self.device)
            soa_tensor = torch.reshape(torch.Tensor(soa_result), (1,shape_sample)).to(self.device)
            final_labels_sep = torch.concat((final_labels_sep, F.pad(input=labels_tensor, pad=(0, 90000 - shape_sample), mode='constant', value=0)))
            final_soa_result_sep = torch.concat((final_soa_result_sep, F.pad(input=soa_tensor, pad=(0, 90000 - shape_sample), mode='constant', value=0)))
            final_feat_sep = torch.concat((final_feat_sep, F.pad(input=feat_tensor, pad=(0, 90000 - shape_sample), mode='constant', value=0)))
            print(final_labels_sep.shape)
            print(final_soa_result_sep.shape)
            print(final_feat_sep.shape)

            labels = labels.to(dtype = torch.int64, device = self.device, non_blocking=True)
            soa_result = soa_result.to(device=self.device, non_blocking=True)
            final_labels[start_idx:start_idx+shape_sample] = labels

            final_pred[start_idx:start_idx+shape_sample] = soa_result

            start_idx += shape_sample

            if it == 50:
                break

        final_labels = final_labels[:start_idx]
        final_pred = final_pred[:start_idx]

        print("================================")

        print('Pred FE', final_pred, "\tShape: ", final_pred.shape)
        print('Label', final_labels, "\tShape: ", final_labels.shape)
        accuracy = miou(final_pred, final_labels)
        avg_acc = torch.mean(accuracy)
        print(f'accuracy: {accuracy}')
        print(f'avg acc: {avg_acc}')

        #cm = confusion_matrix(pred_hd, first_label, labels=torch.Tensor(range(0,15)))
        #print("Confusion matrix \n")
        #print(cm)

        print("================================")

        return final_labels_sep, final_soa_result_sep, final_feat_sep, num_voxels

In [204]:
fe = Feature_Extractor(nb_class=19)
fe.load_pretrained('/root/main/ScaLR/saved_models/ckpt_last_kitti.pth')

  checkpoint = torch.load(path_to_ckpt,


Checkpoint loaded on cuda:0: /root/main/ScaLR/saved_models/ckpt_last_kitti.pth


In [205]:
kwargs = {
    "rootdir": '/root/main/dataset/semantickitti',
    "input_feat": ["intensity", "xyz", "radius"],
    "voxel_size": 0.1,
    "num_neighbors": 16,
    "dim_proj": [2, 1, 0],
    "grids_shape": [[256, 256], [256, 32], [256, 32]],
    "fov_xyz": [[-64, -64, -8], [64, 64, 8]], # Check here
}

# Get datatset
DATASET = LIST_DATASETS.get('semantic_kitti')

dataset_train = DATASET(
    phase="train",
    **kwargs,
)

# Validation dataset
dataset_val = DATASET(
    phase="val",
    **kwargs,
)

num_classes = 19

path_pretrained = '/root/main/ScaLR/saved_models/ckpt_last_kitti.pth'

Using original split
Using original split


In [206]:
device = torch.device("cuda")
print("Using {} device".format(device))

train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=1,
        pin_memory=True,
        drop_last=True,
        collate_fn=Collate(device=device),
        persistent_workers=False,
    )

val_loader = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=1,
    pin_memory=True,
    drop_last=True,
    collate_fn=Collate(device=device),
    persistent_workers=False,
)

Using cuda device


In [None]:
labels_train, soa_train, feat_train, vox_train = fe.test(train_loader, 5500000) # With the nuscenes hyperparameters...

SoA testing: 1it [00:03,  3.37s/it]

torch.Size([2, 90000])
torch.Size([2, 90000])
torch.Size([2, 768, 90000])


SoA testing: 2it [00:06,  3.20s/it]

torch.Size([3, 90000])
torch.Size([3, 90000])
torch.Size([3, 768, 90000])


SoA testing: 3it [00:10,  3.53s/it]

torch.Size([4, 90000])
torch.Size([4, 90000])
torch.Size([4, 768, 90000])


SoA testing: 4it [00:13,  3.38s/it]

torch.Size([5, 90000])
torch.Size([5, 90000])
torch.Size([5, 768, 90000])


SoA testing: 5it [00:17,  3.50s/it]

torch.Size([6, 90000])
torch.Size([6, 90000])
torch.Size([6, 768, 90000])


SoA testing: 6it [00:20,  3.49s/it]

torch.Size([7, 90000])
torch.Size([7, 90000])
torch.Size([7, 768, 90000])


SoA testing: 7it [00:25,  3.80s/it]

torch.Size([8, 90000])
torch.Size([8, 90000])
torch.Size([8, 768, 90000])


SoA testing: 8it [00:30,  4.25s/it]

torch.Size([9, 90000])
torch.Size([9, 90000])
torch.Size([9, 768, 90000])


SoA testing: 9it [00:34,  4.30s/it]

torch.Size([10, 90000])
torch.Size([10, 90000])
torch.Size([10, 768, 90000])


SoA testing: 10it [00:39,  4.29s/it]

torch.Size([11, 90000])
torch.Size([11, 90000])
torch.Size([11, 768, 90000])


SoA testing: 11it [00:45,  4.83s/it]

torch.Size([12, 90000])
torch.Size([12, 90000])
torch.Size([12, 768, 90000])


SoA testing: 12it [00:50,  4.95s/it]

torch.Size([13, 90000])
torch.Size([13, 90000])
torch.Size([13, 768, 90000])


SoA testing: 13it [00:55,  5.01s/it]

torch.Size([14, 90000])
torch.Size([14, 90000])
torch.Size([14, 768, 90000])


SoA testing: 14it [01:00,  5.05s/it]

torch.Size([15, 90000])
torch.Size([15, 90000])
torch.Size([15, 768, 90000])


SoA testing: 15it [01:05,  5.14s/it]

torch.Size([16, 90000])
torch.Size([16, 90000])
torch.Size([16, 768, 90000])


In [193]:
torch.tensor(vox_train[1:])

tensor([44142, 44003, 59890, 40398, 41625, 53900, 45817, 56911, 51037, 53170,
        59709, 43178, 53050, 62282, 44822, 53380, 45886, 47859, 57482, 57516,
        54655, 46799, 61926, 50007, 60078, 60736, 70450, 59455, 33977, 54562,
        43954, 41885, 58911, 51025, 64296, 47641, 61382, 61788, 42247, 52006,
        47080, 58451, 59847, 52650, 33660, 55842, 54676, 44737, 56141, 64482,
        52811, 55376, 67701, 48054, 51237, 54986, 43940, 54340, 54894, 27346,
        48815, 57825, 60356, 49891, 58465, 50979, 53075, 53087, 51195, 46261,
        53950, 54229, 52888, 52543, 46392, 53398, 61353, 59786, 55711, 58418,
        57230, 58920, 58534, 49536, 43975, 58928, 52022, 53749, 56769, 48545,
        58366, 48006, 56668, 61652, 47422, 46437, 54898, 48368, 62403, 49621])

In [195]:
torch.save(labels_train[1:], 'labels_train_semkitti.pt')
torch.save(soa_train[1:], 'soa_train_semkitti.pt')
torch.save(feat_train[1:], 'feat_train_semkitti.pt')
torch.save(torch.tensor(vox_train[1:]), 'voxels_train_semkitti.pt')

In [197]:
labels, soa, feat, vox = fe.test(val_loader, 5500000) # With the nuscenes hyperparameters...

SoA testing: 1it [00:03,  3.46s/it]

torch.Size([2, 70000])
torch.Size([2, 70000])
torch.Size([2, 768, 70000])


SoA testing: 2it [00:06,  3.17s/it]

torch.Size([3, 70000])
torch.Size([3, 70000])
torch.Size([3, 768, 70000])


SoA testing: 3it [00:09,  3.24s/it]

torch.Size([4, 70000])
torch.Size([4, 70000])
torch.Size([4, 768, 70000])


SoA testing: 4it [00:13,  3.51s/it]

torch.Size([5, 70000])
torch.Size([5, 70000])
torch.Size([5, 768, 70000])


SoA testing: 5it [00:17,  3.55s/it]

torch.Size([6, 70000])
torch.Size([6, 70000])
torch.Size([6, 768, 70000])


SoA testing: 6it [00:21,  3.85s/it]

torch.Size([7, 70000])
torch.Size([7, 70000])
torch.Size([7, 768, 70000])


SoA testing: 7it [00:25,  3.97s/it]

torch.Size([8, 70000])
torch.Size([8, 70000])
torch.Size([8, 768, 70000])


SoA testing: 8it [00:30,  4.23s/it]

torch.Size([9, 70000])
torch.Size([9, 70000])
torch.Size([9, 768, 70000])


SoA testing: 9it [00:35,  4.37s/it]

torch.Size([10, 70000])
torch.Size([10, 70000])
torch.Size([10, 768, 70000])


SoA testing: 10it [00:39,  4.36s/it]

torch.Size([11, 70000])
torch.Size([11, 70000])
torch.Size([11, 768, 70000])


SoA testing: 11it [00:44,  4.35s/it]

torch.Size([12, 70000])
torch.Size([12, 70000])
torch.Size([12, 768, 70000])


SoA testing: 12it [00:49,  4.67s/it]

torch.Size([13, 70000])
torch.Size([13, 70000])
torch.Size([13, 768, 70000])


SoA testing: 13it [00:53,  4.49s/it]

torch.Size([14, 70000])
torch.Size([14, 70000])
torch.Size([14, 768, 70000])


SoA testing: 14it [00:58,  4.55s/it]

torch.Size([15, 70000])
torch.Size([15, 70000])
torch.Size([15, 768, 70000])


SoA testing: 15it [01:04,  4.99s/it]

torch.Size([16, 70000])
torch.Size([16, 70000])
torch.Size([16, 768, 70000])


SoA testing: 16it [01:08,  4.69s/it]

torch.Size([17, 70000])
torch.Size([17, 70000])
torch.Size([17, 768, 70000])


SoA testing: 17it [01:12,  4.60s/it]

torch.Size([18, 70000])
torch.Size([18, 70000])
torch.Size([18, 768, 70000])


SoA testing: 18it [01:18,  4.96s/it]

torch.Size([19, 70000])
torch.Size([19, 70000])
torch.Size([19, 768, 70000])


SoA testing: 19it [01:23,  4.97s/it]

torch.Size([20, 70000])
torch.Size([20, 70000])
torch.Size([20, 768, 70000])


SoA testing: 20it [01:28,  5.13s/it]

torch.Size([21, 70000])
torch.Size([21, 70000])
torch.Size([21, 768, 70000])


SoA testing: 21it [01:34,  5.36s/it]

torch.Size([22, 70000])
torch.Size([22, 70000])
torch.Size([22, 768, 70000])


SoA testing: 22it [01:40,  5.32s/it]

torch.Size([23, 70000])
torch.Size([23, 70000])
torch.Size([23, 768, 70000])


SoA testing: 23it [01:46,  5.61s/it]

torch.Size([24, 70000])
torch.Size([24, 70000])
torch.Size([24, 768, 70000])


SoA testing: 24it [01:53,  6.12s/it]

torch.Size([25, 70000])
torch.Size([25, 70000])
torch.Size([25, 768, 70000])


SoA testing: 25it [01:59,  6.02s/it]

torch.Size([26, 70000])
torch.Size([26, 70000])
torch.Size([26, 768, 70000])


SoA testing: 26it [02:06,  6.47s/it]

torch.Size([27, 70000])
torch.Size([27, 70000])
torch.Size([27, 768, 70000])


SoA testing: 27it [02:13,  6.63s/it]

torch.Size([28, 70000])
torch.Size([28, 70000])
torch.Size([28, 768, 70000])


SoA testing: 28it [02:22,  7.30s/it]

torch.Size([29, 70000])
torch.Size([29, 70000])
torch.Size([29, 768, 70000])


SoA testing: 29it [02:29,  6.98s/it]

torch.Size([30, 70000])
torch.Size([30, 70000])
torch.Size([30, 768, 70000])


SoA testing: 30it [02:37,  7.28s/it]

torch.Size([31, 70000])
torch.Size([31, 70000])
torch.Size([31, 768, 70000])


SoA testing: 31it [02:43,  7.07s/it]

torch.Size([32, 70000])
torch.Size([32, 70000])
torch.Size([32, 768, 70000])


SoA testing: 32it [02:51,  7.35s/it]

torch.Size([33, 70000])
torch.Size([33, 70000])
torch.Size([33, 768, 70000])


SoA testing: 33it [03:00,  7.84s/it]

torch.Size([34, 70000])
torch.Size([34, 70000])
torch.Size([34, 768, 70000])


SoA testing: 34it [03:07,  7.41s/it]

torch.Size([35, 70000])
torch.Size([35, 70000])
torch.Size([35, 768, 70000])


SoA testing: 35it [03:16,  8.01s/it]

torch.Size([36, 70000])
torch.Size([36, 70000])
torch.Size([36, 768, 70000])


SoA testing: 36it [03:25,  8.38s/it]

torch.Size([37, 70000])
torch.Size([37, 70000])
torch.Size([37, 768, 70000])


SoA testing: 37it [03:32,  7.90s/it]

torch.Size([38, 70000])
torch.Size([38, 70000])
torch.Size([38, 768, 70000])


SoA testing: 38it [03:39,  7.56s/it]

torch.Size([39, 70000])
torch.Size([39, 70000])
torch.Size([39, 768, 70000])


SoA testing: 39it [03:47,  7.84s/it]

torch.Size([40, 70000])
torch.Size([40, 70000])
torch.Size([40, 768, 70000])


SoA testing: 40it [03:55,  7.92s/it]

torch.Size([41, 70000])
torch.Size([41, 70000])
torch.Size([41, 768, 70000])


SoA testing: 41it [04:08,  9.23s/it]

torch.Size([42, 70000])
torch.Size([42, 70000])
torch.Size([42, 768, 70000])


SoA testing: 42it [04:19,  9.76s/it]

torch.Size([43, 70000])
torch.Size([43, 70000])
torch.Size([43, 768, 70000])


SoA testing: 43it [04:28,  9.60s/it]

torch.Size([44, 70000])
torch.Size([44, 70000])
torch.Size([44, 768, 70000])


SoA testing: 44it [04:38,  9.67s/it]

torch.Size([45, 70000])
torch.Size([45, 70000])
torch.Size([45, 768, 70000])


SoA testing: 45it [04:49, 10.02s/it]

torch.Size([46, 70000])
torch.Size([46, 70000])
torch.Size([46, 768, 70000])


SoA testing: 46it [04:58,  9.90s/it]

torch.Size([47, 70000])
torch.Size([47, 70000])
torch.Size([47, 768, 70000])


SoA testing: 47it [05:13, 11.25s/it]

torch.Size([48, 70000])
torch.Size([48, 70000])
torch.Size([48, 768, 70000])


SoA testing: 48it [05:26, 11.80s/it]

torch.Size([49, 70000])
torch.Size([49, 70000])
torch.Size([49, 768, 70000])


SoA testing: 49it [05:38, 11.92s/it]

torch.Size([50, 70000])
torch.Size([50, 70000])
torch.Size([50, 768, 70000])


SoA testing: 50it [05:47, 11.18s/it]

torch.Size([51, 70000])
torch.Size([51, 70000])
torch.Size([51, 768, 70000])


SoA testing: 51it [05:58, 11.17s/it]

torch.Size([52, 70000])
torch.Size([52, 70000])
torch.Size([52, 768, 70000])


SoA testing: 52it [06:09, 10.85s/it]

torch.Size([53, 70000])
torch.Size([53, 70000])
torch.Size([53, 768, 70000])


SoA testing: 53it [06:21, 11.23s/it]

torch.Size([54, 70000])
torch.Size([54, 70000])
torch.Size([54, 768, 70000])


SoA testing: 54it [06:31, 10.95s/it]

torch.Size([55, 70000])
torch.Size([55, 70000])
torch.Size([55, 768, 70000])


SoA testing: 55it [06:42, 11.00s/it]

torch.Size([56, 70000])
torch.Size([56, 70000])
torch.Size([56, 768, 70000])


SoA testing: 56it [06:52, 10.68s/it]

torch.Size([57, 70000])
torch.Size([57, 70000])
torch.Size([57, 768, 70000])


SoA testing: 57it [07:03, 10.91s/it]

torch.Size([58, 70000])
torch.Size([58, 70000])
torch.Size([58, 768, 70000])


SoA testing: 58it [07:14, 10.91s/it]

torch.Size([59, 70000])
torch.Size([59, 70000])
torch.Size([59, 768, 70000])


SoA testing: 59it [07:26, 11.27s/it]

torch.Size([60, 70000])
torch.Size([60, 70000])
torch.Size([60, 768, 70000])


SoA testing: 60it [07:37, 11.18s/it]

torch.Size([61, 70000])
torch.Size([61, 70000])
torch.Size([61, 768, 70000])


SoA testing: 61it [07:50, 11.64s/it]

torch.Size([62, 70000])
torch.Size([62, 70000])
torch.Size([62, 768, 70000])


SoA testing: 62it [08:01, 11.33s/it]

torch.Size([63, 70000])
torch.Size([63, 70000])
torch.Size([63, 768, 70000])


SoA testing: 63it [08:14, 11.82s/it]

torch.Size([64, 70000])
torch.Size([64, 70000])
torch.Size([64, 768, 70000])


SoA testing: 64it [08:25, 11.78s/it]

torch.Size([65, 70000])
torch.Size([65, 70000])
torch.Size([65, 768, 70000])


SoA testing: 65it [08:39, 12.38s/it]

torch.Size([66, 70000])
torch.Size([66, 70000])
torch.Size([66, 768, 70000])


SoA testing: 66it [08:51, 12.12s/it]

torch.Size([67, 70000])
torch.Size([67, 70000])
torch.Size([67, 768, 70000])


SoA testing: 67it [09:04, 12.49s/it]

torch.Size([68, 70000])
torch.Size([68, 70000])
torch.Size([68, 768, 70000])


SoA testing: 68it [09:17, 12.58s/it]

torch.Size([69, 70000])
torch.Size([69, 70000])
torch.Size([69, 768, 70000])


SoA testing: 69it [09:31, 13.13s/it]

torch.Size([70, 70000])
torch.Size([70, 70000])
torch.Size([70, 768, 70000])


SoA testing: 70it [09:44, 13.12s/it]

torch.Size([71, 70000])
torch.Size([71, 70000])
torch.Size([71, 768, 70000])


SoA testing: 71it [09:58, 13.39s/it]

torch.Size([72, 70000])
torch.Size([72, 70000])
torch.Size([72, 768, 70000])


SoA testing: 72it [10:11, 13.02s/it]

torch.Size([73, 70000])
torch.Size([73, 70000])
torch.Size([73, 768, 70000])


SoA testing: 73it [10:24, 13.09s/it]

torch.Size([74, 70000])
torch.Size([74, 70000])
torch.Size([74, 768, 70000])


SoA testing: 74it [10:38, 13.35s/it]

torch.Size([75, 70000])
torch.Size([75, 70000])
torch.Size([75, 768, 70000])


SoA testing: 75it [10:53, 13.87s/it]

torch.Size([76, 70000])
torch.Size([76, 70000])
torch.Size([76, 768, 70000])


SoA testing: 76it [11:06, 13.68s/it]

torch.Size([77, 70000])
torch.Size([77, 70000])
torch.Size([77, 768, 70000])


SoA testing: 77it [11:21, 14.00s/it]

torch.Size([78, 70000])
torch.Size([78, 70000])
torch.Size([78, 768, 70000])


SoA testing: 78it [11:34, 13.68s/it]

torch.Size([79, 70000])
torch.Size([79, 70000])
torch.Size([79, 768, 70000])


SoA testing: 79it [11:49, 14.11s/it]

torch.Size([80, 70000])
torch.Size([80, 70000])
torch.Size([80, 768, 70000])


SoA testing: 80it [12:02, 13.72s/it]

torch.Size([81, 70000])
torch.Size([81, 70000])
torch.Size([81, 768, 70000])


SoA testing: 81it [12:18, 14.50s/it]

torch.Size([82, 70000])
torch.Size([82, 70000])
torch.Size([82, 768, 70000])


SoA testing: 82it [12:31, 14.16s/it]

torch.Size([83, 70000])
torch.Size([83, 70000])
torch.Size([83, 768, 70000])


SoA testing: 83it [12:46, 14.33s/it]

torch.Size([84, 70000])
torch.Size([84, 70000])
torch.Size([84, 768, 70000])


SoA testing: 84it [13:00, 14.31s/it]

torch.Size([85, 70000])
torch.Size([85, 70000])
torch.Size([85, 768, 70000])


SoA testing: 85it [13:16, 14.85s/it]

torch.Size([86, 70000])
torch.Size([86, 70000])
torch.Size([86, 768, 70000])


SoA testing: 86it [13:31, 14.62s/it]

torch.Size([87, 70000])
torch.Size([87, 70000])
torch.Size([87, 768, 70000])


SoA testing: 87it [13:46, 14.87s/it]

torch.Size([88, 70000])
torch.Size([88, 70000])
torch.Size([88, 768, 70000])


SoA testing: 88it [14:00, 14.62s/it]

torch.Size([89, 70000])
torch.Size([89, 70000])
torch.Size([89, 768, 70000])


SoA testing: 89it [14:15, 14.74s/it]

torch.Size([90, 70000])
torch.Size([90, 70000])
torch.Size([90, 768, 70000])


SoA testing: 90it [14:29, 14.43s/it]

torch.Size([91, 70000])
torch.Size([91, 70000])
torch.Size([91, 768, 70000])


SoA testing: 91it [14:44, 14.60s/it]

torch.Size([92, 70000])
torch.Size([92, 70000])
torch.Size([92, 768, 70000])


SoA testing: 92it [14:59, 14.78s/it]

torch.Size([93, 70000])
torch.Size([93, 70000])
torch.Size([93, 768, 70000])


SoA testing: 93it [15:16, 15.42s/it]

torch.Size([94, 70000])
torch.Size([94, 70000])
torch.Size([94, 768, 70000])


SoA testing: 94it [15:31, 15.29s/it]

torch.Size([95, 70000])
torch.Size([95, 70000])
torch.Size([95, 768, 70000])


SoA testing: 95it [15:48, 15.80s/it]

torch.Size([96, 70000])
torch.Size([96, 70000])
torch.Size([96, 768, 70000])


SoA testing: 96it [16:03, 15.51s/it]

torch.Size([97, 70000])
torch.Size([97, 70000])
torch.Size([97, 768, 70000])


SoA testing: 97it [16:19, 15.68s/it]

torch.Size([98, 70000])
torch.Size([98, 70000])
torch.Size([98, 768, 70000])


SoA testing: 98it [16:34, 15.68s/it]

torch.Size([99, 70000])
torch.Size([99, 70000])
torch.Size([99, 768, 70000])


SoA testing: 99it [16:50, 15.80s/it]

torch.Size([100, 70000])
torch.Size([100, 70000])
torch.Size([100, 768, 70000])


SoA testing: 100it [17:06, 15.85s/it]

torch.Size([101, 70000])
torch.Size([101, 70000])
torch.Size([101, 768, 70000])


SoA testing: 100it [17:25, 10.45s/it]

torch.Size([102, 70000])
torch.Size([102, 70000])
torch.Size([102, 768, 70000])
Pred FE tensor([14., 14., 14.,  ..., 14., 14., 14.]) 	Shape:  torch.Size([5416264])
Label tensor([14., 14., 14.,  ..., 14., 14., 14.]) 	Shape:  torch.Size([5416264])





accuracy: tensor([0.9212, 0.3168, 0.3117, 0.8122, 0.3878, 0.5392, 0.6542, 0.0000, 0.8718,
        0.2367, 0.6637, 0.0035, 0.8606, 0.3714, 0.8657, 0.5617, 0.6887, 0.5033,
        0.4813])
avg acc: 0.5290325284004211


# Check if the two models have the same weights?

In [71]:
# SemanticKitti
fe.model

Segmenter(
  (embed): Embedding(
    (norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): Conv1d(5, 768, kernel_size=(1,), stride=(1,))
    (conv2): Sequential(
      (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv2d(5, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (final): Conv1d(1536, 768, kernel_size=(1,), stride=(1,))
  )
  (waffleiron): WaffleIron(
    (channel_mix): ModuleList(
      (0-47): 48 x ChannelMix(
        (norm): myLayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Conv1d(768, 768, kernel_size=(1,), stride=(1,))
          (1): ReLU(inplace=True)
          (2): Conv1d(768, 768, kernel_size=(1,), stride=(1,))
        )
 

In [53]:
fe_2 = Feature_Extractor(nb_class=16)
fe_2.load_pretrained('/root/main/ScaLR/saved_models/ckpt_last_scalr.pth')

  checkpoint = torch.load(path_to_ckpt,


Checkpoint loaded on cuda:0: /root/main/ScaLR/saved_models/ckpt_last_scalr.pth


In [72]:
import torch
from torch.nn import Module

# Assuming `model1` and `model2` are your two models
def compare_model_weights(model1: Module, model2: Module):
    for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()):
        if name1 != name2:
            print(f"Parameter names differ: {name1} vs {name2}")
            return False
        if not torch.equal(param1.data, param2.data):
            print(f"Weights differ in {name1}")
            return False
    print("All weights are didentical!")
    return True

# Call the function with your two models
result = compare_model_weights(fe.model, fe_2.model)

Weights differ in classif.0.weight


In [96]:
x = torch.ones((5,))
F.pad(input=x, pad=(0, 5), mode='constant', value=0)

tensor([1., 1., 1., 1., 1., 0., 0., 0., 0., 0.])