In [None]:
%load_ext autoreload
%autoreload 2
from oxeo.water.datamodules.samplers import RandomSampler
from oxeo.water.datamodules.datasets import TileDataset
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
from skimage.exposure import rescale_intensity
import numpy as np
from oxeo.core.models.tile import TilePath, tile_from_id

In [None]:
# Quick sanity check: wrap a tile id in a TilePath and display the id of the
# tile it holds (presumably expected to echo "43_P_10000_65_137").
example_tile = tile_from_id("43_P_10000_65_137")
TilePath(example_tile, "landsat-5").tile.id

In [None]:
# One TilePath per (tile id, constellation) pair; tile ids vary slowest,
# constellations fastest.
tile_ids = ["43_P_10000_65_137", "43_P_10000_65_138"]
constellations = ["landsat-5", "landsat-7", "landsat-8", "sentinel-2"]
tiles = [
    TilePath(tile_from_id(tile_id), constellation)
    for tile_id in tile_ids
    for constellation in constellations
]

# Training dataset over every TilePath above, requesting the listed bands and
# the "pekel" / "cloud_mask" masks, with an on-disk cache under ./cache.
ds_train = TileDataset(
    tiles,
    masks=("pekel", "cloud_mask"),
    target_size=1000,
    bands=["nir", "red", "green", "blue", "swir1", "swir2"],
    cache_dir="cache",
    cache_bytes=1e8,
)

In [None]:

# Random sampler over ds_train.
# NOTE(review): the positional 250 and 1000 are presumably a chip size and a
# target size (cf. target_size=1000 on the dataset) — confirm against the
# RandomSampler signature and consider passing them by keyword.
sampler = RandomSampler(ds_train, 250, 1000, revisits_per_epoch=500, samples_per_revisit=200)
# Alternative sampler (GridSampler is not imported above):
# sampler = GridSampler(ds_train, 1000, 1000)


In [None]:
def worker_init_fn(worker_id):
    """Per-worker setup hook passed to ``DataLoader(worker_init_fn=...)``.

    Each DataLoader worker process holds its own copy of the dataset; this
    calls ``per_worker_init()`` on that copy. ``worker_id`` is not used.

    When called outside a worker process (``get_worker_info()`` returns
    ``None``) it only prints a warning message. Always returns ``None``.
    """
    info = torch.utils.data.get_worker_info()
    if info is not None:
        # info.dataset is this worker process's own copy of the dataset.
        info.dataset.per_worker_init()
        return
    print("worker_info is None!")


# Batches of one sample, with indices drawn by `sampler` and loaded in 3 worker
# processes; each worker is set up by worker_init_fn.
dataloader = DataLoader(
    ds_train,
    sampler=sampler,
    batch_size=1,
    num_workers=3,
    worker_init_fn=worker_init_fn,
)

# Number of batches in one pass over the sampler.
len(dataloader)

In [None]:
def plot_imgs_in_row(imgs, labels=("img", "water", "clouds"), figsize=(8, 5)):
    """Plot images side by side in a single row of subplots.

    Args:
        imgs: sequence of arrays accepted by ``Axes.imshow``; one subplot each.
        labels: one title per image, matched by position. Must have at least
            ``len(imgs)`` entries (an ``IndexError`` is raised otherwise).
        figsize: figure size passed to ``plt.subplots``.

    Returns:
        The created matplotlib ``Figure``.

    Images are drawn with ``vmin=0.0, vmax=1.0``, so inputs are presumably
    already scaled to [0, 1]; each axis frame is hidden.
    """
    cols = len(imgs)
    # squeeze=False keeps `axes` two-dimensional ([row, col]) even for a
    # single image; without it plt.subplots(1, 1) returns a bare Axes and
    # indexing it fails.
    fig, axes = plt.subplots(1, cols, figsize=figsize, squeeze=False)

    for i, img in enumerate(imgs):
        ax = axes[0, i]
        ax.set_title(labels[i])
        ax.imshow(img, vmin=0.0, vmax=1.0, interpolation=None)
        ax.axis("off")
    return fig

In [None]:

from itertools import islice

# Optional cap on how many samples to plot; None plots every batch the
# dataloader yields.
MAX_SAMPLES_TO_PLOT = None

for sample in islice(dataloader, MAX_SAMPLES_TO_PLOT):
    constellation = sample["constellation"][0]

    # [0] drops the batch dimension (batch_size=1). Channels 1, 2, 3 are
    # presumably red, green, blue if channels follow the `bands` order given to
    # TileDataset. Assumes (C, H, W) -> transpose to (H, W, C) for imshow.
    img = sample["image"][0][[1, 2, 3]].numpy().transpose(1, 2, 0)
    pekel = sample["pekel"].numpy()[0].squeeze()
    mask = sample["cloud_mask"].numpy()[0].squeeze()
    print("pekel min", pekel.min(), "cloud_min", mask.min())

    # 2nd-98th percentile contrast stretch into [0, 1] for display.
    vmin, vmax = np.percentile(img, q=(2, 98))
    img = rescale_intensity(img, in_range=(vmin, vmax), out_range=(0, 1))

    plot_imgs_in_row(
        [img, pekel, mask],
        labels=[
            f"{constellation}_img",
            f"{constellation}_water",
            f"{constellation}_cloud",
        ],
    )
    plt.show()
    # Release the figure so a long loop does not accumulate open figures.
    plt.close("all")