# Webdataset reader example in rocAL
This example demonstrates how to set up a simple webdataset reader pipeline that loads and decodes image data stored in webdataset format using rocAL.

<font size="12"> Common Code </font>

In [None]:
import numpy as np
from amd.rocal.plugin.generic import ROCALClassificationIterator
from amd.rocal.pipeline import pipeline_def
import amd.rocal.fn as fn
import amd.rocal.types as types
import matplotlib.pyplot as plt
import os
%matplotlib inline

## Using fn.readers.webdataset operator

Data stored in WebDataset format is read using the `fn.readers.webdataset` function, which takes the following arguments:

* `path`: the path to the webdataset tar files.
* `index_paths`: the path to the index files containing data about the tar files. If no path is provided, information is inferred from the tar files themselves. See `tar2idx` in the `tools` folder.
* `ext`: extensions of the types of data in the tar files. If there's more than one extension, the extensions must be separated by a semicolon (';').
* `missing_components_behavior`: specifies how the reader behaves when a sample in your dataset lacks a file that the reader expects to find, for example when you are reading image-caption pairs and a sample has an image but no caption file. This can be one of three options:
    * `MISSING_COMPONENT_EMPTY`: an empty tensor is returned for the missing component.
    * `MISSING_COMPONENT_SKIP`: the sample is skipped (e.g., both the image and the missing caption).
    * `MISSING_COMPONENT_ERROR`: an error is thrown. This is the default behaviour.

`fn.readers.webdataset` also accepts arguments common to all readers: random number generator seed, shuffling, sharding, and last batch policy.
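
The three `missing_components_behavior` options can be illustrated without rocAL. The following is a minimal sketch using only the Python standard library, assuming that files sharing a basename in a tar archive form one sample. The helpers `make_demo_tar` and `read_samples` and the string options `"empty"`, `"skip"` and `"error"` are hypothetical names chosen for this illustration; they are not part of the rocAL API and do not reflect how rocAL implements the reader.

```python
import io
import tarfile
from collections import defaultdict


def make_demo_tar(files):
    """Build an in-memory tar archive from a {name: bytes} mapping."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, payload in files.items():
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    return buf


def read_samples(tar_buf, ext, missing="error"):
    """Group tar members by basename and return one dict per sample.

    `missing` mimics the three behaviours described above:
    "empty" pads with empty bytes, "skip" drops the sample, "error" raises.
    """
    grouped = defaultdict(dict)
    with tarfile.open(fileobj=tar_buf, mode="r") as tar:
        for member in tar.getmembers():
            # "0001.JPEG" -> key "0001", extension "JPEG"
            key, _, extension = member.name.partition(".")
            grouped[key][extension] = tar.extractfile(member).read()

    samples = []
    for key, parts in grouped.items():
        absent = [e for e in ext if e not in parts]
        if absent and missing == "error":
            raise ValueError(f"sample {key!r} is missing {absent}")
        if absent and missing == "skip":
            continue
        samples.append({e: parts.get(e, b"") for e in ext})
    return samples


# Sample "0002" has an image but no class file.
demo = {"0001.JPEG": b"img-1", "0001.cls": b"3", "0002.JPEG": b"img-2"}
skipped = read_samples(make_demo_tar(demo), ["JPEG", "cls"], missing="skip")
padded = read_samples(make_demo_tar(demo), ["JPEG", "cls"], missing="empty")
print(len(skipped), len(padded))  # prints: 1 2
```

With `"skip"` the incomplete sample disappears from the output, with `"empty"` it is kept and its `cls` entry is an empty byte string, and with `"error"` the call raises as soon as the incomplete sample is found.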


<font size="12"> Configuring rocAL pipeline </font>

<div class="alert alert-block alert-warning">
<b>Note:</b> Set the ROCAL_DATA_PATH environment variable before running the notebook.
</div>

### Prepare dataset

The data must be organized in a separate folder for each reader under the `rocal_data` directory. For the webdataset reader, create a directory named `web_dataset`, place the data (`.tar` files) in a subfolder named `tar_file`, and place the metadata (`.idx` files) in a subfolder named `idx_file`.
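
The layout described above can be checked before the pipeline is built. The following is a minimal sketch, assuming the `rocal_data/web_dataset/tar_file` and `rocal_data/web_dataset/idx_file` folders sit under the directory named by `ROCAL_DATA_PATH`. The helpers `webdataset_dirs` and `list_shards` and the file name `train-000` are hypothetical and exist only for this illustration; a temporary directory stands in for the real data root so that the block runs on its own.

```python
import tempfile
from pathlib import Path


def webdataset_dirs(data_root):
    """Return the expected (tar_file, idx_file) directories under data_root."""
    base = Path(data_root) / "rocal_data" / "web_dataset"
    return base / "tar_file", base / "idx_file"


def list_shards(data_root):
    """List the .tar shards and, if present, the .idx files for the reader."""
    tar_dir, idx_dir = webdataset_dirs(data_root)
    if not tar_dir.is_dir():
        raise FileNotFoundError(f"expected a tar_file directory at {tar_dir}")
    tars = sorted(tar_dir.glob("*.tar"))
    # Index files are optional, so a missing idx_file folder is not an error.
    idxs = sorted(idx_dir.glob("*.idx")) if idx_dir.is_dir() else []
    return tars, idxs


# Build a throwaway layout to demonstrate the check.
with tempfile.TemporaryDirectory() as root:
    tar_dir, idx_dir = webdataset_dirs(root)
    tar_dir.mkdir(parents=True)
    idx_dir.mkdir(parents=True)
    (tar_dir / "train-000.tar").touch()
    (idx_dir / "train-000.idx").touch()

    tars, idxs = list_shards(root)
    shard_names = [p.name for p in tars]
    index_names = [p.name for p in idxs]

print(shard_names, index_names)
```

In the notebook itself, the value of `ROCAL_DATA_PATH` would be passed in place of the temporary directory.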


In [None]:
# Check if ROCAL_DATA_PATH is set
rocal_data_path = os.environ.get('ROCAL_DATA_PATH')
if rocal_data_path is None:
    raise EnvironmentError("ROCAL_DATA_PATH environment variable is not set. Please set it to the correct path.")
else:
    print(f"ROCAL_DATA_PATH IS SET TO: {rocal_data_path}")
wds_data_path = f"{rocal_data_path}/rocal_data/web_dataset/tar_file/"

<font size="12">Webdataset pipeline </font>

In this pipeline, the webdataset reader is followed by a decoder, and cascaded augmentations are then applied to the decoded images: a resize followed by crop mirror normalize.<br>The crop mirror normalize output is returned from the pipeline function as the pipeline output.

In [None]:
@pipeline_def(batch_size=1, num_threads=4, device_id=0, rocal_cpu=True)
def wds_pipeline(wds_data=wds_data_path):
    img_raw = fn.readers.webdataset(path=wds_data, ext=[{'JPEG', 'cls'}], missing_components_behavior=types.MISSING_COMPONENT_SKIP)
    img = fn.decoders.image(img_raw, file_root=wds_data, max_decoded_width=416, max_decoded_height=416)
    resize_outputs = fn.resize(img, resize_width=300, resize_height=300)
    output = fn.crop_mirror_normalize(resize_outputs,
                                        output_layout=types.NHWC,
                                        output_dtype=types.UINT8,
                                        crop=(224, 224),
                                        mean=[0.0, 0.0, 0.0],
                                        std=[1.0, 1.0, 1.0])
    return output
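
The effect of the crop mirror normalize step on a single image can be sketched in NumPy. This is an illustration under stated assumptions, not rocAL's implementation: the crop is assumed to be taken from the center of the image because the pipeline above does not specify a crop position, and `crop_mirror_normalize_sketch` is a hypothetical helper. A random array stands in for a decoded, resized 300x300 image.

```python
import numpy as np


def crop_mirror_normalize_sketch(image, crop, mean, std, mirror=False):
    """Center-crop an HWC image, optionally flip it horizontally, then normalize."""
    crop_h, crop_w = crop
    height, width = image.shape[:2]
    top = (height - crop_h) // 2    # assumption: centered crop window
    left = (width - crop_w) // 2
    out = image[top:top + crop_h, left:left + crop_w].astype(np.float32)
    if mirror:
        out = out[:, ::-1]          # horizontal flip along the width axis
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (out - mean) / std       # per-channel normalization


# Stand-in for the 300x300 output of the resize step.
resized = np.random.default_rng(0).integers(0, 256, size=(300, 300, 3), dtype=np.uint8)
cropped = crop_mirror_normalize_sketch(
    resized, crop=(224, 224), mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]
)
result = cropped.astype(np.uint8)
print(result.shape)  # (224, 224, 3)
```

With a mean of 0 and a standard deviation of 1, as in the pipeline above, the normalization leaves pixel values unchanged, so the step reduces to a 224x224 crop.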

In [None]:
webdataset_pipeline = wds_pipeline()
webdataset_pipeline.build()
data_loader = ROCALClassificationIterator(webdataset_pipeline)

<font size="12"> Visualizing outputs </font>

The augmented images are displayed using `imshow()`.

In [None]:
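# Show the augmented images on a 2 x 4 grid, filling each column top to bottom
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(15, 15))
n_rows, n_cols = axes.shape
cnt = 0
for it in data_loader:
    # it[0] holds the image outputs; it[1] would hold the labels
    for img in it[0]:
        # Take the first image of the batch and cast it to uint8 for display
        image = img[0].astype(np.uint8)
        row = cnt % n_rows
        col = (cnt // n_rows) % n_cols  # wraps back to the first column after 8 images
        axes[row, col].imshow(image)
        cnt += 1
# Rewind the iterator so the data can be read again
data_loader.reset()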