In [3]:
# This file will run through an entire dataset to report Accuracy and IoU.

# Import Libraries
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from matplotlib.patches import Patch

from Lseg.lseg_trainer import LSegModule
from Lseg.lseg_net import LSegNet

from Lseg.data.util import get_labels, get_dataset
import torchvision.transforms as transforms
from PIL import Image
from torchmetrics import Accuracy, JaccardIndex
import tqdm

# METRICS
NUM_CLASSES = 195
accuracy_fn = Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device="cuda")
iou_fn = JaccardIndex(task="multiclass", num_classes=NUM_CLASSES).to(device="cuda")

# Labels
labels = get_labels()

config = {
    "batch_size": 2,  # 6
    "base_lr": 0.04,
    "max_epochs": 50,
    "num_features": 512,
}

net = LSegNet(
    labels=labels,
    features=config["num_features"],
)

# Load Model - replace with actual
load_checkpoint_path = r"checkpoints/checkpoint_epoch=0-val_loss=4.7304.ckpt"
model = LSegModule.load_from_checkpoint(
    load_checkpoint_path,
    max_epochs=config["max_epochs"],
    model=net,
    num_classes=len(labels),
    batch_size=config["batch_size"],
    base_lr=config["base_lr"],
)
model = model.to(device="cuda").float()

# Load ADE20K validation dataset
test_dataset = get_dataset(dataset_name="ade20k", get_train=False)
test_dataloaders = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=8)

# Evaluate accuracy and IoU
mean_accuracy = 0.0
mean_iou = 0.0

model.eval()
with torch.no_grad():
    for X, y in tqdm.tqdm(test_dataloaders):
        X = X.to(device="cuda").float()
        y = y.to(device="cuda").float()
        output = model(X)
        prediction = torch.argmax(output, dim=1)
        mean_accuracy += accuracy_fn(prediction, y).item()
        mean_iou += iou_fn(prediction, y).item()

mean_accuracy = mean_accuracy / len(test_dataloaders)
mean_iou = mean_iou / len(test_dataloaders)
print(f"Number of examples: {len(test_dataset)}")
print(f"Accuracy: {mean_accuracy}")
print(f"IoU: {mean_iou}")




	Mapping ade20k-150 -> universal
	Mapping bdd -> universal
	Mapping cityscapes-19 -> universal
	Mapping coco-panoptic-133 -> universal
	Mapping idd-39 -> universal
	Mapping mapillary-public65 -> universal
	Mapping sunrgbd-37 -> universal
	Mapping ade20k-150-relabeled -> universal
	Mapping bdd-relabeled -> universal
	Mapping cityscapes-19-relabeled -> universal
	Mapping cityscapes-34-relabeled -> universal
	Mapping coco-panoptic-133-relabeled -> universal
	Mapping idd-39-relabeled -> universal
	Mapping mapillary-public65-relabeled -> universal
	Mapping sunrgbd-37-relabeled -> universal

	Creating 1x1 conv for test datasets...
Totally 20210 samples in val set.
Checking image&label pair val list done!
image folder path: data/mseg_dataset/ADE20K/
text path: mseg-api/mseg/dataset_lists/ade20k-150-relabeled/list/train.txt


100%|██████████| 10105/10105 [43:44<00:00,  3.85it/s]

Number of examples: 20210
Accuracy: 0.5836145441590743
IoU: 0.18291134191472844



