In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data

import ssl4rs.data.parsers.disa
import ssl4rs.utils.config
import ssl4rs.utils.logging

# logger shared by the cells below (used for status messages), obtained from the project's
# logging setup helper for analysis scripts
logger = ssl4rs.utils.logging.setup_logging_for_analysis_script()

In [None]:
# EDIT THIS CELL IF NECESSARY (i.e. if using another local path)
# The dataset is looked up under "<data root>/ai4h-disa/india", where the data root is provided by
# the project's configuration helper. Both assertions fail fast and report the offending path.
dataset_root = ssl4rs.utils.config.get_data_root_dir() / "ai4h-disa" / "india"
assert dataset_root.is_dir(), f"bad dataset root dir: {dataset_root}"
# the deeplake version of the dataset is expected as a ".deeplake" entry inside the dataset root
deeplake_dataset_path = dataset_root / ".deeplake"
assert deeplake_dataset_path.exists(), f"bad deeplake dataset path: {deeplake_dataset_path}"
logger.info(f"Ready to parse deeplake dataset at: {deeplake_dataset_path}")

In [None]:
def convert_data_to_float(batch: dict) -> dict:
    """Casts the image data of a loaded batch to 32-bit floating point values.

    The given dictionary is updated in place: its `image_data` entry is replaced by a `float32`
    copy created with `astype`, and every other entry is left untouched.

    Args:
        batch: batch dictionary holding an `image_data` entry that exposes `astype`.

    Returns:
        The same dictionary object that was passed in, now holding the converted `image_data`.
    """
    raster = batch["image_data"]
    batch["image_data"] = raster.astype(np.float32)
    return batch


# Open the deeplake dataset through the project's DISA parser; `convert_data_to_float` is
# registered as a batch transform on the parser.
dataset_parser = ssl4rs.data.parsers.disa.DeepLakeParser(
    dataset_path_or_object=deeplake_dataset_path,  # already-opened object or path to the .deeplake dir
    check_integrity=True,  # will run internal checks to make sure the data is clean/good
    batch_transforms=[convert_data_to_float],  # converts uint16 raster data to a pytorch-friendly dtype
)
# presumably reports an overview of the parsed dataset -- confirm against the parser documentation
dataset_parser.summary()

In [None]:
# pick one sample at random (no seed is set here, so the preview differs between runs) and load it
sample_idx = np.random.randint(len(dataset_parser))
sample_data = dataset_parser[sample_idx]
logger.info(f"Displaying preview for {sample_data['location_id']}...")

# one entry per panel: (sample key to display, panel title, colormap); a `None` colormap leaves
# matplotlib's default in place, exactly as if the argument had been omitted
preview_panels = (
    ("location_preview_image", "Image Preview", None),
    ("location_preview_roi", "Region of Interest", plt.cm.gray),
    ("field_mask", "Field Mask", plt.cm.gray),
)

plt.figure(figsize=(16, 4))
for panel_idx, (panel_key, panel_title, panel_cmap) in enumerate(preview_panels, start=1):
    # panels are laid out side by side on a single row, without axis decorations
    plt.subplot(1, len(preview_panels), panel_idx)
    plt.imshow(sample_data[panel_key], cmap=panel_cmap)
    plt.title(panel_title)
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
tensor_names = [  # these are the names of tensors we want to work with, and that will be batched
    "location_id",
    "field_geoms",
    "field_mask",
    "image_data",
    "image_roi",
    "image_udm2",
]
# Anything that is not a numpy array will be manually handled, i.e. every tensor that is neither
# an "image_*" tensor nor the field mask. The two exclusions have to be combined so that a name
# matching EITHER of them is left out: with `not startswith(...) or != "field_mask"`, the test
# holds for every name (no name both starts with "image_" and equals "field_mask"), which would
# send all tensors down the manual path and leave nothing for the default collate function.
tensor_names_to_collate_manually = [
    n for n in tensor_names if not (n.startswith("image_") or n == "field_mask")
]


def custom_collate(batches: list[dict]) -> dict:
    """Merges a list of per-sample dictionaries into a single batch dictionary.

    Entries whose key is listed in `tensor_names_to_collate_manually` are gathered as plain
    Python lists (one item per sample, in sample order). All remaining entries are handed to
    torch's default collate function.

    Args:
        batches: the per-sample dictionaries to merge; each one must contain every key listed
            in `tensor_names_to_collate_manually`.

    Returns:
        The dictionary produced by the default collate function, extended with one list per
        manually collated key.
    """
    manual_keys = tensor_names_to_collate_manually
    # strip the manually handled entries before relying on the default collate function
    auto_collatable = []
    for sample in batches:
        auto_collatable.append({name: value for name, value in sample.items() if name not in manual_keys})
    collated = torch.utils.data.default_collate(auto_collatable)
    # manually handled entries are kept as-is, one list item per sample
    for name in manual_keys:
        collated[name] = [sample[name] for sample in batches]
    return collated


dataloader = dataset_parser.get_dataloader(  # we will create a dataloader using deeplake directly
    batch_size=1,
    collate_fn=custom_collate,  # default collation mixed with per-sample lists for selected tensors
    tensors=tensor_names,  # presumably restricts loading to the listed tensors -- confirm in parser
)

In [None]:
# fetch a single batch from a new iterator over the dataloader
batch = next(iter(dataloader))
batch  # bare last expression: the notebook shows the batch content as the cell output