[![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/tcapelle/aws_smsl_demo/blob/main/01_data_processing.ipynb)

In [1]:
import wandb
from fastai.vision.all import *

In [2]:
# Resolve the local CamVid dataset directory through fastai and list its images.
path = untar_data(URLs.CAMVID)
fnames = get_image_files(path / "images")

# codes.txt is read as an array of strings; a code's position in that array
# is used as its integer class id.
codes = np.loadtxt(path / "codes.txt", dtype=str)
class_labels = dict(enumerate(codes))

In [3]:
# W&B destination shared by every run started in this notebook.
PROJECT = "CamVid"
ENTITY = "av-demo"

# Reference to the uploaded dataset artifact ("camvid-dataset", latest alias).
# Built from ENTITY and PROJECT so that changing either of them cannot leave
# this reference pointing at a different entity/project than the runs use.
ARTIFACT_ID = f"{ENTITY}/{PROJECT}/camvid-dataset:latest"

In [4]:
# Publish the local CamVid directory as a W&B artifact of type "dataset".
# The run is opened as a context manager so it is finished when the block exits.
with wandb.init(
    entity=ENTITY,
    project=PROJECT,
    job_type="upload",
    name="upload_camvid_final",
):
    artifact = wandb.Artifact(
        "camvid-dataset",
        type="dataset",
        # Keep the source URL and the id -> class-name mapping with the data.
        metadata=dict(url=URLs.CAMVID, class_labels=class_labels),
        description=(
            "The Cambridge-driving Labeled Video Database (CamVid) is the first "
            "collection of videos with object class semantic labels, complete with "
            "metadata. The database provides ground truth labels that associate "
            "each pixel with one of 32 semantic classes."
        ),
    )
    # Add everything under the dataset root, then log the artifact.
    artifact.add_dir(path)
    wandb.log_artifact(artifact)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Adding directory to artifact (/home/paperspace/.fastai/data/camvid)... Done. 3.5s


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [5]:
def label_func(fn):
    """Return the path of the label mask that belongs to image file ``fn``.

    The mask is looked up in a ``labels`` directory that sits next to the
    image's own directory, under the image's file name with ``_P`` inserted
    before the extension (``images/x.png`` -> ``labels/x_P.png``).
    """
    dataset_root = fn.parent.parent
    mask_name = f"{fn.stem}_P{fn.suffix}"
    return dataset_root / "labels" / mask_name

In [6]:
def get_frequency_distribution(mask_data, labels=None):
    """Count how many entries of ``mask_data`` belong to each labelled class.

    Args:
        mask_data: Array-like of class ids (anything ``np.unique`` accepts).
        labels: Optional mapping of class id -> class name. When omitted, the
            notebook-level ``class_labels`` mapping is used, which keeps
            existing single-argument calls working.

    Returns:
        dict mapping each class name in ``labels`` (in the mapping's order) to
        its count as a plain ``int``. Classes that do not occur in
        ``mask_data`` get 0; ids in ``mask_data`` that are missing from
        ``labels`` are ignored.
    """
    if labels is None:
        labels = class_labels

    values, counts = np.unique(mask_data, return_counts=True)
    # .tolist() yields native Python scalars, so every count in the result has
    # the same type whether or not the class is present in the mask.
    count_by_id = dict(zip(values.tolist(), counts.tolist()))
    return {
        class_name: count_by_id.get(class_id, 0)
        for class_id, class_name in labels.items()
    }

In [7]:
def log_dataset():
    """Log the CamVid artifact's images and masks to W&B as one table.

    Opens a ``data_viz`` run, downloads the artifact referenced by
    ``ARTIFACT_ID`` and builds one table row per image file found under the
    artifact's ``images`` directory. Each row holds the file name, the plain
    image, the same image with its label mask attached, and one count column
    per class name. The finished table is logged under ``"CamVid_Dataset"``.

    Reads the notebook-level ``PROJECT``, ``ENTITY``, ``ARTIFACT_ID`` and
    ``class_labels``; returns nothing.
    """
    with wandb.init(
        entity=ENTITY,
        project=PROJECT,
        job_type="data_viz",
        name="visualize_camvid",
    ):
        dataset_artifact = wandb.use_artifact(ARTIFACT_ID, type='dataset')
        download_root = Path(dataset_artifact.download())

        image_paths = get_image_files(download_root / "images")
        # Class names double as the per-class column headers of the table.
        class_names = [str(name) for name in class_labels.values()]
        rows = []

        print("Creating Table...")
        for image_path in progress_bar(image_paths):
            pixels = np.array(Image.open(image_path))
            mask = np.array(Image.open(label_func(image_path)))
            class_counts = get_frequency_distribution(mask)

            # The mask comes from the label file located by label_func; it is
            # attached to the image under the mask key "predictions".
            overlay = {
                "predictions": {
                    "mask_data": mask,
                    "class_labels": class_labels,
                }
            }
            row = [
                str(image_path.name),
                wandb.Image(pixels),
                wandb.Image(pixels, masks=overlay),
            ]
            row.extend(class_counts[name] for name in class_names)
            rows.append(row)

        header = ["File_Name", "Images", "Segmentation_Masks"] + class_names
        dataset_table = wandb.Table(data=rows, columns=header)
        wandb.log({"CamVid_Dataset": dataset_table})

In [8]:
# Start the "visualize_camvid" run, which builds and logs the "CamVid_Dataset" table.
log_dataset()

[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Downloading large artifact camvid-dataset:latest, 1210.65MB. 1409 files... Done. 0:0:0


Creating Table...


VBox(children=(Label(value=' 570.27MB of 570.27MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…