In [12]:
import os
import torch
import pytorch_lightning as L
import matplotlib.pyplot as plt
from omegaconf import DictConfig

%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'

%load_ext autoreload
%autoreload 2

# Report the compute resources visible to this kernel.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    # The per-device queries raise on a CPU-only machine, so only run them
    # when CUDA is actually present.
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_properties(0).total_memory)

# Get number of cores
print(os.cpu_count())

# Get number of threads
print(torch.get_num_threads())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
True
Tesla V100-SXM2-32GB
34072559616
72
36


In [15]:
import sklearn  # scikit-learn hack to fix the error on jetson

import torch
import hydra
import numpy as np
from PIL import Image
import albumentations as A
import pytorch_lightning as L
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from albumentations.pytorch import ToTensorV2
from pytorch_lightning.loggers import WandbLogger

from src import RoadDataModule, RoadModel, LogPredictionsCallback, val_checkpoint, regular_checkpoint, rgb_to_label


def _predict_single_image(cfg: DictConfig, device: torch.device) -> None:
    """Run the checkpointed model on one hard-coded sample and save a comparison plot.

    Loads ``RoadModel`` from ``cfg.ckpt_path``, runs it on a single RoboTour image,
    and writes a three-panel figure (input, label, prediction) to ``prediction.png``.

    Args:
        cfg: Composed config; reads ``ckpt_path``, ``ds.color_map``, ``ds.train_map``,
            ``ds.mean`` and ``ds.std``.
        device: Device the model and the input tensors are moved to.
    """
    # Load the trained model
    model = RoadModel.load_from_checkpoint(cfg.ckpt_path, cfg=cfg, device=device).to(device)
    model.eval()

    # Sample image and its annotation (a RUGD alternative would be
    # 'data/RUGD/Images/creek_00001.png' / 'data/RUGD/Annotations/creek_00001.png').
    image_path = 'data/RoboTour/Images/tradr_20_5_2422.png'
    label_path = 'data/RoboTour/Annotations/tradr_20_5_2422.png'

    # Scale the image to [0, 1]; A.Normalize below is told max_pixel_value=1.0.
    image = np.array(Image.open(image_path).convert('RGB')) / 255.0
    label = np.array(Image.open(label_path).convert('RGB'))

    # Process the label image: colors -> ids (presumably; rgb_to_label is defined in src),
    # then remap every id through the dataset's train_map.
    label = rgb_to_label(label, cfg.ds.color_map)
    train_map = OmegaConf.to_container(cfg.ds.train_map)
    label = np.vectorize(train_map.get)(label)

    # Intended to mirror the training-time transforms - TODO confirm against RoadDataModule.
    transform = A.Compose([
        A.Normalize(mean=cfg.ds.mean, std=cfg.ds.std, max_pixel_value=1.0),
        A.Resize(550, 688),
        ToTensorV2()
    ])
    sample = transform(image=image, mask=label)
    image = sample['image'].float().unsqueeze(0).to(device)
    label = sample['mask'].long().unsqueeze(0).to(device)

    # Predict the label image
    with torch.no_grad():
        logits = model(image)
    prediction = logits.argmax(1).squeeze(0).cpu().numpy()

    # Plot the image, label, and prediction side by side.
    # NOTE(review): the first panel shows the *normalized* tensor, so colors are only
    # faithful when ds.mean/ds.std leave the values inside [0, 1].
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        image[0].permute(1, 2, 0).cpu().numpy(),
        label[0].cpu().numpy(),
        prediction,
    )
    for ax, panel in zip(axes, panels):
        ax.imshow(panel)
        ax.axis('off')  # Remove axes

    fig.subplots_adjust(wspace=0, hspace=0)
    fig.tight_layout()

    # Save the plot
    fig.savefig('prediction.png')


def main(cfg: DictConfig) -> None:
    """Run the action selected by ``cfg.action``: ``train``, ``test`` or ``predict``.

    Args:
        cfg: Composed config; reads ``action``, ``run_name``, ``ckpt_path``,
            ``train.max_epochs`` and the ``ds`` group used by the predict path.

    Raises:
        ValueError: If ``cfg.action`` is not ``train``, ``test`` or ``predict``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RoadModel(cfg, device)
    datamodule = RoadDataModule(cfg)

    wandb_logger = WandbLogger(project="road-segmentation", name=cfg.run_name)

    trainer = L.Trainer(
        max_epochs=cfg.train.max_epochs,
        # Follow the device picked above so the CPU fallback works;
        # a hard-coded "gpu" fails on machines without CUDA.
        accelerator="gpu" if device.type == "cuda" else "cpu",
        devices=1,
        logger=wandb_logger,
        callbacks=[
            LogPredictionsCallback(),
            val_checkpoint,
            regular_checkpoint
        ],
    )

    if cfg.action == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
    elif cfg.action == "test":
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
    elif cfg.action == "predict":
        _predict_single_image(cfg, device)
    else:
        raise ValueError(f"Unknown action: {cfg.action}")

In [19]:
from hydra import compose, initialize

# Hydra overrides for this run: keep training on the RoboTour dataset config,
# passing the saved checkpoint path through to main().
run_overrides = [
    "action=train",
    "ds=robotour",
    "ckpt_path=checkpoints/e51-iou0.60.ckpt",
]

with initialize(version_base=None, config_path="conf"):
    # Make sure that the cfg.ds is 'robotour'
    config = compose(config_name="config", overrides=run_overrides)
    print(config)
    main(config)

{'action': 'train', 'ckpt_path': 'checkpoints/e51-iou0.60.ckpt', 'run_name': 'baseline', 'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.0001, 'weight_decay': 0.0001}, 'train': {'batch_size': 4, 'max_epochs': 100, 'num_workers': 4}, 'ds': {'name': 'rugd', 'path': 'data/RoboTour', 'mean': [0.0, 0.0, 0.0], 'std': [1.0, 1.0, 1.0], 'color_map': {'0,0,0': 0, '0,255,0': 1, '255,0,0': 2, '0,0,255': 3}, 'train_map': {0: 0, 1: 1, 2: 2, 3: 3}}, 'model': {'_target_': 'torchvision.models.segmentation.deeplabv3_resnet50', 'num_classes': 4}}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at checkpoints/e51-iou0.60.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | DeepLabV3              | 39.6 M
1 | criterion | CrossEntropyLoss       | 0     
2 | accuracy  | MulticlassAccuracy     | 0     
3 | jaccard   | MulticlassJaccardIndex | 0     
-----------------------------------------------------
39.6 M    Trainable params
0         Non-trainable params
39.6 M    Total params
158.538   Total estimated model params size (MB)
Restored all states from the checkpoint at checkpoints/e51-iou0.60.ckpt
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
