A quick evaluation script adapted from the WayFASTER training code:

Notes:
I trained the model for 13 epochs. The best checkpoint by validation loss is from the 5th epoch, so that is the model used here.

How to use:

1. Change the path in `os.chdir("/home/richard/workspaces/VML/wayfaster")` to your local WayFASTER project directory.
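2. Select the datapoint to evaluate with `EVAL_BATCH_IDX` (default `0`). The loaders are unshuffled, so indices follow dataset order. With `batch_size=1`, each batch is a single datapoint: a sequence of images with sequence length 6. Note that the evaluation cell draws from the training split by default.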

Visualization outputs:
- Generated images are logged to wandb: check the local `wandb` folder or the online wandb dashboard.

  - Input visualization:
    - eval_depth_target is the input depth image but downsampled.
    - eval_pcloud is the input point cloud, transformed from the input depth image.
  - Output visualization: `eval_mu` and `eval_nu` are the linear and angular traction coefficients, respectively.

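  - `eval_depth_pred` and `eval_depth_mask` are for depth prediction (part of the architecture). `eval_images`, `eval_debug`, and `eval_path` show the camera inputs, the summed debug output, and the executed path projected onto the map.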

In [16]:
import torch
import pytorch_lightning as pl
import os

from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger

# Custom packages
from train.dataloader import Dataset
from train.train_configs import get_cfg
from train.trainer import TrainingModule
from models.traversability_net import TravNet
from train.utils import path_to_map

# Change working directory to project level so the relative config/dataset/checkpoint
# paths below resolve. Set the WAYFASTER_ROOT environment variable to override the
# default location instead of editing this cell.
PROJECT_ROOT = os.environ.get("WAYFASTER_ROOT", "/home/richard/workspaces/VML/wayfaster")
os.chdir(PROJECT_ROOT)
print(os.getcwd())
CONFIG_FILE_PATH = "configs/temporal_model.yaml"


def parse_config(config_path=None):
    """
    Load the experiment configuration via ``get_cfg``.

    Args:
        config_path (str, optional): Path to the YAML config file. When omitted,
            the module-level ``CONFIG_FILE_PATH`` is used (looked up at call time).

    Returns:
        The config object returned by ``get_cfg``.
    """
    if config_path is None:
        config_path = CONFIG_FILE_PATH
    return get_cfg(config_path)


configs = parse_config()
print("configs:\n", configs)
pl.seed_everything(configs.SEED, workers=True)

train_dataset = Dataset(configs, configs.DATASET.TRAIN_DATA)
# The validation set is built with the training set's `weights`
valid_dataset = Dataset(configs, configs.DATASET.VALID_DATA, weights=train_dataset.weights)


def make_eval_loader(dataset):
    """
    Build an unshuffled DataLoader with one sample per batch.

    With ``batch_size=1`` and ``shuffle=False``, batch index ``i`` holds ``dataset[i]``,
    which is what ``EVAL_BATCH_IDX`` relies on.
    """
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
    )


train_loader = make_eval_loader(train_dataset)
valid_loader = make_eval_loader(valid_dataset)

# Initialize model and logger (it's a pl.LightningModule; model.model is the actual traversability network)
model = TrainingModule(configs)
wandb_logger = WandbLogger(
    project="WayFASTER",
    log_model="all",
)

# Load a previously trained network
if configs.MODEL.LOAD_NETWORK is not None:
    print("Loading saved network from {}".format(configs.MODEL.LOAD_NETWORK))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    pretrained_dict = torch.load(configs.MODEL.LOAD_NETWORK, map_location="cpu")["state_dict"]
    model.load_state_dict(pretrained_dict)
else:
    # Otherwise the evaluation cell would silently run on randomly initialized weights
    print("WARNING: configs.MODEL.LOAD_NETWORK is not set; evaluating an untrained network.")

Seed set to 42


/home/richard/workspaces/VML/wayfaster
configs:
 {'TAG': 'temporal', 'TRAINING': {'EPOCHS': 20, 'BATCHSIZE': 4, 'WORKERS': 4, 'PRECISION': 16, 'DT': 0.1, 'HORIZON': 300, 'GAMMA': 1.0, 'DEPTH_WEIGHT': 0.1, 'VIS_INTERVAL': 500, 'VERBOSE': False}, 'MODEL': {'LOAD_NETWORK': 'checkpoints/checkpoint_epoch=5-valid_loss=0.2139.ckpt', 'DOWNSAMPLE': 8, 'LATENT_DIM': 64, 'TIME_LENGTH': 6, 'PREDICT_DEPTH': True, 'TRAIN_DEPTH': True, 'FUSE_PCLOUD': True, 'INPUT_SIZE': [320, 180], 'GRID_BOUNDS': {'xbound': [-2.0, 8.0, 0.1], 'ybound': [-5.0, 5.0, 0.1], 'zbound': [-1.0, 2.0, 0.2], 'dbound': [0.3, 8.0, 0.2]}}, 'OPTIMIZER': {'LR': 0.0001, 'WEIGHT_DECAY': 0.0001}, 'DATASET': {'TRAIN_DATA': ['dataset/zed2/data_train', 'dataset/realsense/data_train'], 'VALID_DATA': ['dataset/zed2/data_valid', 'dataset/realsense/data_valid'], 'CSV_FILE': 'rosbags.csv'}, 'AUGMENTATIONS': {'HORIZ_FLIP': 0.5, 'PCLOUD_DROPOUT': 0.3, 'MAX_TRANSLATION': 0.0, 'MAX_ROTATION': 0.0}, 'SEED': 42}
Initializing dataset...
Dataset initia



Loading saved network from checkpoints/checkpoint_epoch=5-valid_loss=0.2139.ckpt


Visualization code


In [17]:
def visualize_results(
    logger, model, image, pcloud, trav_map, pred_depth, depth_target, depth_mask, debug, executed_path, prefix="eval"
):
    """
    Log visualizations of the traversability network inputs and outputs.

    Every image is sent through ``logger.log_image`` under a key of the form
    ``prefix + "_<name>"``.

    Args:
        logger: Logger exposing ``log_image(key=..., images=[...])`` (e.g. a Lightning ``WandbLogger``).
        model: Training module providing ``grid_bounds["dbound"]``, ``predict_depth`` and ``eps``.
        image (torch.Tensor): Camera images; the first two dims are flattened into one before logging
            (presumably (batch, time, C, H, W) -- TODO confirm).
        pcloud (torch.Tensor): Input point clouds; averaged over dim 2, then flattened like ``image``.
        trav_map (torch.Tensor): Traversability map; channel 0 is logged as ``mu``, the rest as ``nu``.
        pred_depth (torch.Tensor): Predicted depth scores; argmax over dim 1 gives the depth bin.
        depth_target (torch.Tensor): Target depth.
        depth_mask (torch.Tensor): Depth mask; a channel dim is inserted at dim 1 for logging.
        debug (torch.Tensor): Debug output; summed over dim 1 and min-max normalized.
        executed_path (torch.Tensor): Executed path map (4-D); normalized per sample by its max.
        prefix (str): Prefix for the log keys.
    """
    # Visualize the camera inputs
    logger.log_image(key=prefix + "_images", images=[image.view(-1, *image.shape[2:])])

    # Visualize the input point cloud
    pcloud = torch.mean(pcloud, dim=2, keepdim=True)
    pcloud = pcloud.view(-1, *pcloud.shape[2:])
    logger.log_image(key=prefix + "_pcloud", images=[pcloud])

    # Visualize the traversability map
    logger.log_image(key=prefix + "_mu", images=[trav_map[:, :1]])

    logger.log_image(key=prefix + "_nu", images=[trav_map[:, 1:]])

    # Visualize the depth prediction. n_d is derived from dbound (start, stop, step)
    # and is used to scale depth-bin indices for display.
    n_d = (model.grid_bounds["dbound"][1] - model.grid_bounds["dbound"][0]) / model.grid_bounds["dbound"][2]
    depth_pred = torch.argmax(pred_depth, dim=1, keepdim=True) / (n_d - 1)
    logger.log_image(key=prefix + "_depth_pred", images=[depth_pred])

    # Visualize the depth target
    if model.predict_depth:
        depth_target = depth_target.unsqueeze(1) / (n_d - 1)
    else:
        depth_target = depth_target.argmax(1).unsqueeze(1) / (n_d - 1)

    logger.log_image(key=prefix + "_depth_target", images=[depth_target])

    # Visualize the depth mask
    logger.log_image(key=prefix + "_depth_mask", images=[depth_mask.unsqueeze(1)])

    # Visualize the debug output (min-max normalized). The range is clamped so a
    # constant debug map logs zeros instead of NaN from 0 / 0.
    debug_img = torch.sum(debug, dim=1, keepdim=True)
    debug_min = torch.min(debug_img)
    debug_range = torch.clamp(torch.max(debug_img) - debug_min, min=model.eps)
    debug_img = (debug_img - debug_min) / debug_range
    logger.log_image(key=prefix + "_debug", images=[debug_img])

    # Visualize the executed path; eps guards against an all-zero path
    executed_path = executed_path / (torch.amax(executed_path, (1, 2, 3), keepdim=True) + model.eps)
    logger.log_image(key=prefix + "_path", images=[executed_path])

Evaluation code:



In [None]:
# Evaluate traversability network model
trav_model: TravNet = model.model
EVAL_BATCH_IDX = 0
# Loader the evaluation sample is drawn from. NOTE: this is the *training* split;
# set it to `valid_loader` to evaluate on held-out data.
EVAL_LOADER = train_loader
eval_batch = None

for batch_idx, batch in enumerate(EVAL_LOADER):
    if batch_idx == EVAL_BATCH_IDX:
        print(f"Found batch {EVAL_BATCH_IDX}")
        eval_batch = batch
        break

if eval_batch is None:
    # Fail loudly here; otherwise the tuple unpacking below raises an opaque TypeError on None
    raise IndexError(
        f"EVAL_BATCH_IDX={EVAL_BATCH_IDX} is out of range for a loader with {len(EVAL_LOADER)} batches"
    )

with torch.no_grad():
    trav_model.eval()

    # Get data
    color_img, pcloud, inv_intrinsics, extrinsics, path, target_trav, trav_weights, depth_target, depth_mask = eval_batch  # fmt: skip

    # Forward pass
    print("Forward Pass.")
    trav_map, pred_depth, debug = trav_model(color_img, pcloud, inv_intrinsics, extrinsics, depth_target)

    # Project path to map
    print("Projecting path to map.")
    executed_path = path_to_map(
        path.unsqueeze(1),
        torch.ones_like(path[..., 0, 0]).unsqueeze(1),
        model.map_size,
        model.map_resolution,
        model.map_origin,
    )

    # Calculate traversability loss and error
    trav_loss, trav_error = model.trav_criterion(path, trav_map, target_trav, trav_weights)

    # Calculate depth classification loss; merge the first two dims of the depth tensors
    depth_target = depth_target.view(-1, *depth_target.shape[2:])
    depth_mask = depth_mask.view(-1, *depth_mask.shape[2:])
    depth_loss = model.depth_criterion(pred_depth, depth_target, depth_mask)

    # The depth term only counts toward the total loss when depth training is enabled
    if model.train_depth:
        loss = trav_loss + model.depth_weight * depth_loss
    else:
        loss = trav_loss

    # Visualize results
    print("Logging output to Wandb")
    visualize_results(
        wandb_logger,
        model,
        color_img,
        pcloud,
        trav_map,
        pred_depth,
        depth_target,
        depth_mask,
        debug,
        executed_path,
        prefix="eval",
    )

    # Print
    print("eval_loss ", loss.item())
    print("eval_trav_loss ", trav_loss.item())
    print("eval_trav_error ", trav_error.item())
    print("eval_depth_loss ", depth_loss.item())

Found batch 0
Forward Pass.
Projecting path to map.
Logging output to Wandb
eval_loss  tensor(0.0481)
eval_trav_loss  tensor(0.0187)
eval_trav_error  tensor(0.0258)
eval_depth_loss  tensor(0.2941)
