In [None]:
import datetime

import matplotlib.dates
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import pandas as pd
import shapely
import torch
import torch.utils.data
import tqdm

import ssl4rs.data.parsers.disa
import ssl4rs.utils.config
import ssl4rs.utils.logging

# logger used to report info messages (dataset path, sample preview details, ...) in the cells below
logger = ssl4rs.utils.logging.setup_logging_for_analysis_script()

In [None]:
# EDIT THIS CELL IF NECESSARY (e.g. if using another local path)
# the deeplake copy of the dataset is expected in the `.deeplake` subdirectory of the dataset root
dataset_root = ssl4rs.utils.config.get_data_root_dir() / "ai4h-disa" / "india"
deeplake_dataset_path = dataset_root / ".deeplake"
assert dataset_root.is_dir(), f"bad dataset root dir: {dataset_root}"
assert deeplake_dataset_path.exists(), f"bad deeplake dataset path: {deeplake_dataset_path}"
logger.info(f"Ready to parse deeplake dataset at: {deeplake_dataset_path}")

In [None]:
def convert_data_to_float(batch: dict) -> dict:
    """Casts the batch's raster image data to float32 and returns the same (updated in place) batch.

    The raw `image_data` array is stored as uint16, which is not a pytorch-friendly dtype; every
    other entry of the batch is left untouched.
    """
    raw_image_data = batch["image_data"]
    batch.update(image_data=raw_image_data.astype(np.float32))
    return batch


# parse the dataset from its .deeplake dir (an already-opened deeplake object would also be accepted),
# running the parser's internal integrity checks first, and casting each batch's uint16 raster data
# to a pytorch-friendly float dtype via the batch transform defined above
dataset_parser = ssl4rs.data.parsers.disa.DeepLakeParser(
    batch_transforms=[convert_data_to_float],
    check_integrity=True,
    dataset_path_or_object=deeplake_dataset_path,
)
dataset_parser.summary()

In [None]:
# pick a random location and display its preview image, region of interest, and field mask side by side
sample_idx = np.random.randint(len(dataset_parser))
sample_data = dataset_parser[sample_idx]
logger.info(f"Displaying preview for {sample_data['location_id']}...")
# note: preview image is always based on the oldest in the image data stack
timestamp = sample_data["image_metadata"][0].item()["timestamp"]
logger.info(f"image timestamp: {timestamp}")

preview_panels = [  # (tensor name, panel title, colormap)
    ("location_preview_image", "Image Preview", None),
    ("location_preview_roi", "Region of Interest", plt.cm.gray),
    ("field_mask", "Field Mask", plt.cm.gray),
]
plt.figure(figsize=(16, 4))
for panel_idx, (panel_tensor_name, panel_title, panel_cmap) in enumerate(preview_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(sample_data[panel_tensor_name], cmap=panel_cmap)
    plt.title(panel_title)
    plt.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# collect stats and other interesting things related to the full dataset (these are displayed below)
metadata_lists = dict(
    images_per_location=[],
    height_per_location=[],
    width_per_location=[],
    scatter_per_location=[],
    fields_per_location=[],
    field_area_per_location=[],
    field_percent_coverage_per_location=[],
    timestamps=[],
    time_deltas=[],
    area_per_field=[],
    gsd_per_image=[],
    view_angle_per_image=[],
    valid_percent_coverage_per_image=[],
    bandwise_mean_per_image=[],
    bandwise_std_per_image=[],
)

band_count = dataset_parser.dataset_info["band_count"]
total_location_count = len(dataset_parser)
location_idxs = list(range(total_location_count))
acquisition_time_format = "%Y-%m-%dT%H:%M:%S.%fZ"  # format of the "acquired" image metadata property

for location_idx in tqdm.tqdm(location_idxs, desc="parsing location data"):
    location_data = dataset_parser[location_idx]
    image_count = int(location_data["image_count"].item())
    image_metadata = [location_data["image_metadata"][image_idx].item() for image_idx in range(image_count)]
    image_data = location_data["image_data"]
    # cast the ROI to bool so that it always acts as a pixel mask below; if it were stored as an
    # integer 0/1 array, indexing with it would trigger numpy's fancy indexing and select wrong pixels
    roi_data = np.asarray(location_data["image_roi"]).astype(bool)
    field_mask = location_data["field_mask"]
    field_geoms = location_data["field_geoms"]
    field_count = len(field_geoms)
    field_polygons = [shapely.geometry.Polygon(field_geoms[pidx]) for pidx in range(field_count)]
    field_areas = [p.area for p in field_polygons]

    # location-wise stats
    metadata_lists["images_per_location"].append(image_count)
    metadata_lists["height_per_location"].append(image_data.shape[-2])
    metadata_lists["width_per_location"].append(image_data.shape[-1])
    metadata_lists["scatter_per_location"].append(location_data["field_scatter"].item())
    metadata_lists["fields_per_location"].append(field_count)
    metadata_lists["field_area_per_location"].append(sum(field_areas))
    metadata_lists["field_percent_coverage_per_location"].append(
        (np.count_nonzero(field_mask) / np.prod(field_mask.shape)) * 100
    )

    # image-wise stats; time deltas are computed between consecutive images of this location only
    timestamps = [
        datetime.datetime.strptime(image_metadata[image_idx]["properties"]["acquired"], acquisition_time_format)
        for image_idx in range(image_count)
    ]
    metadata_lists["timestamps"].extend(timestamps)
    metadata_lists["time_deltas"].extend(
        [timestamps[image_idx] - timestamps[image_idx - 1] for image_idx in range(1, image_count)]
    )
    metadata_lists["area_per_field"].extend(field_areas)
    metadata_lists["gsd_per_image"].extend(
        [image_metadata[image_idx]["properties"]["gsd"] for image_idx in range(image_count)]
    )
    metadata_lists["view_angle_per_image"].extend(
        [image_metadata[image_idx]["properties"]["view_angle"] for image_idx in range(image_count)]
    )
    metadata_lists["valid_percent_coverage_per_image"].extend(
        [
            (np.count_nonzero(roi_data[image_idx]) / np.prod(roi_data.shape[1:])) * 100
            for image_idx in range(image_count)
        ]
    )

    # band-wise stats only consider valid (ROI) pixels; images without any valid pixel are skipped,
    # as their mean/std would be NaN and would poison the dataset-wide band averages
    for image_idx in range(image_count):
        roi_mask = roi_data[image_idx]
        if not roi_mask.any():
            continue
        image_bands = image_data[image_idx]
        metadata_lists["bandwise_mean_per_image"].append(
            [image_bands[band_idx][roi_mask].mean() for band_idx in range(band_count)]
        )
        metadata_lists["bandwise_std_per_image"].append(
            [image_bands[band_idx][roi_mask].std() for band_idx in range(band_count)]
        )

total_image_count = sum(metadata_lists["images_per_location"])
total_field_count = sum(metadata_lists["fields_per_location"])

In [None]:
# one minor tick per unit is great for small integer counts, but would generate way too many ticks
# (and trigger matplotlib's MAXTICKS warning) for arrays spanning a wide value range
MAX_RANGE_FOR_UNIT_MINOR_TICKS = 100

arrays_to_plot_as_hist = [
    "images_per_location",
    "height_per_location",
    "width_per_location",
    "scatter_per_location",
    "fields_per_location",
    "field_area_per_location",
    "field_percent_coverage_per_location",
    "area_per_field",
    "gsd_per_image",
    "view_angle_per_image",
    "valid_percent_coverage_per_image",
]


def _get_display_decimals(min_val: float, max_val: float, significant_digits: int = 4) -> int:
    """Returns how many decimals to print so that summary stats keep ~`significant_digits` digits.

    The precision is based on the order of magnitude of the value range. For constant data (zero range,
    where `log10` would give -inf), it falls back to the magnitude of the value itself, and to zero
    decimals if all values are zero (or non-finite).
    """
    value_range = float(max_val) - float(min_val)
    if not (np.isfinite(value_range) and value_range > 0):
        value_range = abs(float(max_val))
    if not (np.isfinite(value_range) and value_range > 0):
        return 0
    return max(0, significant_digits - int(np.floor(np.log10(value_range))))


for array_name in arrays_to_plot_as_hist:
    array_data = metadata_lists[array_name]
    if len(array_data) == 0:
        logger.warning(f"no data to plot for {array_name}, skipping its histogram")
        continue
    fig, ax = plt.subplots(figsize=(12, 4), dpi=300)
    counts, bin_edges = np.histogram(array_data, bins=30)
    bin_widths = np.diff(bin_edges)
    ax.bar(bin_edges[:-1], counts, width=bin_widths, edgecolor="black", align="edge")
    ax.set_xlabel("Values")
    ax.set_ylabel("Frequency")
    ax.set_title(f"{array_name} distribution")
    min_val = np.min(array_data)
    max_val = np.max(array_data)
    mean_val = np.mean(array_data)
    median_val = np.median(array_data)
    num_decimals = _get_display_decimals(min_val, max_val)
    subtitle = (
        f"min: {min_val:.{num_decimals}f}, "
        f"avg: {mean_val:.{num_decimals}f}, "
        f"median: {median_val:.{num_decimals}f}, "
        f"max: {max_val:.{num_decimals}f}"
    )
    ax.text(0.5, 0.9, subtitle, transform=ax.transAxes, fontsize="small", ha="center", va="bottom")
    ax.grid(True)
    if max_val - min_val <= MAX_RANGE_FOR_UNIT_MINOR_TICKS:
        ax.xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(1))
    else:
        ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
    fig.tight_layout()
    plt.show()

In [None]:
# acquisition timestamps (top) and time deltas between consecutive acquisitions of each location (bottom)
timestamps = metadata_lists["timestamps"]
time_deltas = metadata_lists["time_deltas"]
fig, axes = plt.subplots(nrows=2, figsize=(12, 9), dpi=300)
axes[0].hist(timestamps, bins=100, alpha=0.75, color="blue")
axes[0].xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d, %H:%M:%S"))
axes[0].xaxis.set_major_locator(matplotlib.dates.MonthLocator())
# rotate the long date labels in place: this keeps the DateFormatter in charge of the labels, whereas
# re-assigning them via set_xticklabels would freeze them into a FixedFormatter
for tick_label in axes[0].get_xticklabels():
    tick_label.set(rotation=45, horizontalalignment="right")
axes[0].grid(True)
axes[0].set_title("Acquisition timestamps")

timedeltas_in_days = [td.total_seconds() / (60 * 60 * 24) for td in time_deltas]
# a log-scale histogram can only show strictly positive deltas (duplicate timestamps give zero deltas)
positive_timedeltas_in_days = [td for td in timedeltas_in_days if td > 0]
if len(positive_timedeltas_in_days) < len(timedeltas_in_days):
    logger.warning(
        f"{len(timedeltas_in_days) - len(positive_timedeltas_in_days)} non-positive time deltas "
        "cannot be shown on the log-scale histogram"
    )
if positive_timedeltas_in_days:
    log_min = np.log10(min(positive_timedeltas_in_days))
    log_max = np.log10(max(positive_timedeltas_in_days))
    if log_max <= log_min:
        # all deltas are identical: widen the range so that the bin edges strictly increase
        log_min, log_max = log_min - 0.5, log_max + 0.5
    bins = np.logspace(log_min, log_max, num=50)
    axes[1].hist(positive_timedeltas_in_days, bins=bins, alpha=0.75, color="blue")
axes[1].set_xscale("log")
axes[1].set_xlabel("Time delta (days, log-scale)")
axes[1].set_ylabel("Frequency")
axes[1].set_title("Time deltas between acquisitions")
axes[1].grid(True)
axes[1].minorticks_on()
axes[1].xaxis.set_major_locator(matplotlib.ticker.LogLocator(subs="all"))
axes[1].xaxis.set_minor_locator(matplotlib.ticker.LogLocator(subs="all"))
if timedeltas_in_days:  # summary stats cover all deltas, including the non-positive ones
    min_val = np.min(timedeltas_in_days)
    max_val = np.max(timedeltas_in_days)
    mean_val = np.mean(timedeltas_in_days)
    median_val = np.median(timedeltas_in_days)
    subtitle = f"min: {min_val:.2f}, avg: {mean_val:.2f}, median: {median_val:.2f}, max: {max_val:.2f}"
    axes[1].text(0.5, 0.9, subtitle, transform=axes[1].transAxes, fontsize="small", ha="center", va="bottom")
fig.tight_layout()
plt.show()

In [None]:
# per-band distributions of the image-wise mean/std values; their dataset-wide averages are also printed
# as python lists that can be copied elsewhere
if band_count == 3:
    colors = ["red", "green", "blue"]
    labels = dataset_parser.metadata.three_band_descriptions
elif band_count == 4:
    colors = ["blue", "green", "red", "darkred"]
    labels = dataset_parser.metadata.four_band_descriptions
else:
    raise NotImplementedError(f"unsupported band count: {band_count}")

# transposed to (band_count, image_count) so that each row holds one band's values
bandwise_mean_data = np.asarray(metadata_lists["bandwise_mean_per_image"]).T
bandwise_std_data = np.asarray(metadata_lists["bandwise_std_per_image"]).T

fig, axes = plt.subplots(nrows=2, figsize=(12, 9), dpi=300)
for stat_ax, stat_name, stat_data in (
    (axes[0], "mean", bandwise_mean_data),
    (axes[1], "std", bandwise_std_data),
):
    print(f"band_{stat_name}_values = [")
    for band_idx, band_data in enumerate(stat_data):
        stat_ax.hist(band_data, bins=50, histtype="step", color=colors[band_idx], label=labels[band_idx])
        # nan-aware average: images without any valid pixel can yield NaN stats
        print(f"\t{np.nanmean(band_data)},")
    print("]")
    stat_ax.set_ylabel("Frequency")
    stat_ax.grid(True)
    if len(stat_data):
        stat_ax.legend()  # the per-band labels given to `hist` are only displayed through a legend
axes[0].set_xlabel("Mean value")
axes[0].set_title("Image-wise band mean distribution")
axes[1].set_xlabel("Standard deviation")
axes[1].set_title("Image-wise band std distribution")
fig.tight_layout()
plt.show()

In [None]:
tensor_names = [  # these are the names of tensors we want to work with, and that will be batched
    "location_id",
    "field_geoms",
    "field_mask",
    "image_data",
    "image_roi",
    "image_udm2",
]
# the image tensors and the field mask are numpy arrays that pytorch's default collate can handle;
# anything else (e.g. the location ids and the per-location list of field geometries) is handled manually
# (note: the previous `... or n != "field_mask"` condition was always true, so EVERY tensor was manual)
tensor_names_to_collate_manually = [
    n for n in tensor_names if not n.startswith("image_") and n != "field_mask"
]


def custom_collate(batches: list[dict]) -> dict:
    """Merges a list of samples (dicts) into a single batch dict.

    Every key listed in `tensor_names_to_collate_manually` is gathered into a plain python list (one
    entry per sample, in order); all other keys go through pytorch's `default_collate`.
    """
    auto_collatable_samples = []
    for sample in batches:
        auto_collatable_samples.append(
            {key: val for key, val in sample.items() if key not in tensor_names_to_collate_manually}
        )
    output = torch.utils.data.default_collate(auto_collatable_samples)
    output.update({key: [sample[key] for sample in batches] for key in tensor_names_to_collate_manually})
    return output


# build the dataloader through deeplake itself, loading only the selected tensors, one location per batch
dataloader = dataset_parser.get_dataloader(
    tensors=tensor_names,
    collate_fn=custom_collate,
    batch_size=1,
)

In [None]:
# fetch a single batch from the dataloader to inspect what `custom_collate` produces
batch = next(iter(dataloader))
batch