In [1]:
from pathlib import Path
from types import SimpleNamespace
import tensorflow as tf

import wandb

from data_loader import DataLoader
from utils.callbacks import TensorBoard
from utils.util import *
from utils.args_loader import load_model_config

2022-09-09 13:28:48.072932: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-09-09 13:28:48.072970: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [23]:
# Directory holding the KITTI data used by this notebook; it is passed to
# DataLoader further down through `arg.data_path`.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_path = Path("/mnt/disks/KITTI/small/")
# Last expression of the cell: displays whether the directory exists here.
data_path.exists()

True

In [24]:
# Bundle the run settings in one namespace object; `model` and `config` are
# the names handed to load_model_config, `data_path` is reused by the data
# loaders in the next cell.
arg = SimpleNamespace(
    model="squeezesegv2",
    config="squeezesegv2kitti",
    data_path=data_path,
    train_dir="../output",
    epochs=10,
)

# Resolve the named model/config pair. `model` later supplies CLASSES and
# CLS_COLOR_MAP; `config` is passed to DataLoader.
config, model = load_model_config(arg.model, arg.config)

In [25]:
# Build the train and val pipelines with the same call chain, in that order:
# write the TFRecord for the split, then read it back as a dataset.
train_dl, val_dl = (
    DataLoader(split, arg.data_path, config)
    .write_tfrecord_dataset()
    .read_tfrecord_dataset()
    for split in ("train", "val")
)

TFRecord exists at /mnt/disks/KITTI/small/train.tfrecord. Skipping TFRecord writing.
TFRecord exists at /mnt/disks/KITTI/small/val.tfrecord. Skipping TFRecord writing.


In [26]:
# Artifact handle for the TFRecord version of the dataset. It is only created
# here; the cells that would add files to it and log it are commented out.
at = wandb.Artifact(
    "KITTI_nano_tfrecord",
    type="dataset",
    description="A nano version of KITTI with only 40/10 samples",
)

In [28]:
# at.add_dir(data_path)

In [29]:
# with wandb.init(project="small_kitti", entity="av-team", job_type="log_dataset"):
#     nano_kitti = wandb.use_artifact("KITTI_nano:v0")
#     wandb.log_artifact(at)

## Log validation dataset

In [30]:
from utils.wandb import _create_row

In [31]:
# Class metadata comes from the object returned by load_model_config.
classes = model.CLASSES
class_color_map = model.CLS_COLOR_MAP
# Class index -> class name, in the order the names appear in `classes`.
class_map = dict(enumerate(classes))

In [32]:
# Display the class index -> class name mapping built in the previous cell.
class_map

{0: 'None',
 1: 'car',
 2: 'bicycle',
 3: 'motorcycle',
 4: 'truck',
 5: 'other-vehicle',
 6: 'person',
 7: 'bicyclist',
 8: 'motorcyclist',
 9: 'road',
 10: 'parking',
 11: 'sidewalk',
 12: 'other-ground',
 13: 'building',
 14: 'fence',
 15: 'vegetation',
 16: 'trunk',
 17: 'terrain',
 18: 'pole',
 19: 'traffic-sign'}

In [33]:
# Preview pass over the validation set: run `_create_row` on the samples of the
# leading batch(es) and stop early. The variables assigned here stay in the
# notebook namespace after the loop ends.
total_examples_seen = 0
for i, ((lidar_inputs, lidar_masks), labels, _) in enumerate(val_dl):
    # Convert the batch tensors to NumPy. This happens before the stop check
    # below, so on the iteration that breaks, these three names already refer
    # to the batch that is NOT processed.
    lidar_inputs, lidar_masks, labels = lidar_inputs.numpy(), lidar_masks.numpy(), labels.numpy()
    # Stop at the first batch reached after more than one example was counted.
    if total_examples_seen > 1:
        print(f"Batch: {i}, examples: {total_examples_seen}")
        break
    # The first axis of `lidar_inputs` is used as the number of examples in the batch.
    total_examples_seen += lidar_inputs.shape[0]
    for lidar_input, lidar_mask, label in zip(lidar_inputs, lidar_masks, labels):
        # Each iteration overwrites the four results, so only the last sample's
        # values survive. `label` from that last sample is reused by the later
        # `get_pixel_count(label, class_map)` cell.
        label_image, depth_image, intensity_image, points_rgb = _create_row(lidar_input, 
                                                                            label,
                                                                            class_color_map)

Batch: 1, examples: 8


In [34]:
def get_pixel_count(mask_data, class_labels):
    """Count how many elements of ``mask_data`` carry each class id.

    Args:
        mask_data: Array-like of class ids (anything ``np.unique`` accepts).
        class_labels: Mapping from class id to class name.

    Returns:
        dict mapping every class name in ``class_labels`` (in the mapping's
        iteration order) to the number of matching elements, as a plain
        ``int``. Classes that do not occur get 0. Ids present in ``mask_data``
        but missing from ``class_labels`` are ignored.
    """
    class_ids, counts = np.unique(mask_data, return_counts=True)
    # `.tolist()` turns NumPy scalars into Python numbers, so every returned
    # count has the same type as the literal 0 used for absent classes. The
    # dict replaces the repeated linear `list.index` / `in list` scans.
    count_by_id = dict(zip(class_ids.tolist(), counts.tolist()))
    return {
        name: count_by_id.get(class_id, 0)
        for class_id, name in class_labels.items()
    }

In [35]:
# `label` is the variable left over from the preview loop above (its last
# processed sample); this cell only works after that loop has run.
pixel_count = get_pixel_count(label, class_map)

In [36]:
# Display the per-class counts computed for that single sample.
pixel_count

{'None': 15281,
 'car': 760,
 'bicycle': 0,
 'motorcycle': 402,
 'truck': 0,
 'other-vehicle': 0,
 'person': 16,
 'bicyclist': 0,
 'motorcyclist': 0,
 'road': 7040,
 'parking': 0,
 'sidewalk': 4434,
 'other-ground': 0,
 'building': 823,
 'fence': 106,
 'vegetation': 30754,
 'trunk': 425,
 'terrain': 5397,
 'pole': 85,
 'traffic-sign': 13}

In [37]:
import pandas as pd
def compute_pixel_count(dataset):
    """Compute per-class statistics over every sample of ``dataset``.

    Args:
        dataset: Iterable of batches shaped like
            ``((lidar_inputs, lidar_masks), labels, _)``, where ``labels``
            exposes ``.numpy()`` and iterating the result yields one label map
            per sample. Only ``labels`` is read.

    Returns:
        A ``(freq, pixel_df)`` tuple:

        * ``freq`` -- dict mapping each class name in the notebook-level
          ``class_map`` to the number of samples in which that class occurs
          at least once.
        * ``pixel_df`` -- DataFrame with one row per sample and one column per
          class name, holding that sample's per-class counts as returned by
          ``get_pixel_count``.
    """
    class_names = list(class_map.values())
    freq = dict.fromkeys(class_names, 0)
    rows = []
    for (_lidar_inputs, _lidar_masks), labels, _ in dataset:
        for label in labels.numpy():
            pixel_count = get_pixel_count(label, class_map)
            # Record this sample's own counts. The previous version appended
            # the running `freq` totals here, so the frame never contained
            # pixel counts.
            rows.append(pixel_count)
            for name, count in pixel_count.items():
                if count:
                    freq[name] += 1
    # Build the frame once from the collected rows: `DataFrame.append` was
    # removed in pandas 2.0, and growing a frame row by row is quadratic.
    pixel_df = pd.DataFrame(rows, columns=class_names)
    return freq, pixel_df

In [38]:
# Full pass over each split: per-class sample frequencies plus the per-split
# DataFrame returned by compute_pixel_count.
val_freq, val_pixel_df = compute_pixel_count(val_dl)
train_freq, train_pixel_df = compute_pixel_count(train_dl)

In [39]:
# Display, per class, the number of validation samples containing that class.
val_freq

{'None': 200,
 'car': 200,
 'bicycle': 91,
 'motorcycle': 40,
 'truck': 14,
 'other-vehicle': 100,
 'person': 92,
 'bicyclist': 47,
 'motorcyclist': 8,
 'road': 200,
 'parking': 59,
 'sidewalk': 200,
 'other-ground': 31,
 'building': 200,
 'fence': 184,
 'vegetation': 200,
 'trunk': 190,
 'terrain': 200,
 'pole': 200,
 'traffic-sign': 136}

In [40]:
def freq_plot(freq, title="Frequencies"):
    """Build a W&B bar chart of per-class frequencies.

    Args:
        freq: Mapping from class name to frequency.
        title: Chart title.

    Returns:
        The object produced by ``wandb.plot.bar`` for a two-column table
        ("Detection-Classes", "Frequencies"). Rows follow the order of the
        notebook-level ``class_map``; any extra keys of ``freq`` come after.
        Classes missing from ``freq`` are plotted as 0.
    """
    class_names = list(class_map.values())
    # Keep keys that are not in class_map so nothing passed in is dropped.
    class_names += [name for name in freq if name not in class_names]
    # Build the two-column frame directly; the old route through
    # `DataFrame.append` + transpose fails on pandas >= 2.0, where `append`
    # no longer exists.
    df = pd.DataFrame(
        {
            "Detection-Classes": class_names,
            "Frequencies": [freq.get(name, 0) for name in class_names],
        }
    )
    table = wandb.Table(dataframe=df)
    return wandb.plot.bar(table, "Detection-Classes", "Frequencies", title=title)

In [41]:
# One W&B run for the validation chart: declare the dataset artifact as an
# input, then log the bar chart. The run is finished when the block exits.
with wandb.init(project="small_kitti", entity="av-team", job_type="data_viz") as run:
    run.use_artifact("KITTI_nano_tfrecord:v0")
    val_chart = freq_plot(val_freq, title="Val class frequencies")
    run.log({"Class Frequencies": val_chart})

VBox(children=(Label(value='0.001 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.780271…

In [42]:
# Separate W&B run for the training chart: declare the dataset artifact as an
# input, then log the bar chart. The run is finished when the block exits.
with wandb.init(project="small_kitti", entity="av-team", job_type="data_viz") as run:
    run.use_artifact("KITTI_nano_tfrecord:v0")
    train_chart = freq_plot(train_freq, title="Train class frequencies")
    run.log({"Class Frequencies": train_chart})

VBox(children=(Label(value='0.001 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.774820…