# How to Run This

## Creating the Dataset

In [1]:
import load_data_tf as load_data

"""
    Ensure that 'patches_with_cloud_and_shadow.csv' and 'patches_with_seasonal_snow.csv' are in the locations specified
    (download from GCP)
"""
filter_files = ['../patches_with_cloud_and_shadow.csv', '../patches_with_seasonal_snow.csv'] # replace with your path

data_dir = "../SmallEarthNet" # put the path to the dataset here

"""
    This will create the dataset and create a split file (if it doesn't exist) at ./smallearthnet.pkl and label count
    data at ./label_counts.pkl (takes ~10-15s).
    
    RECOMMENDED: download ./smallearthnet.pkl and ./label_counts.pkl from the GCP bucket.
"""
support_size = 8
label_subset_size = 3
meta_dataset = load_data.MetaBigEarthNetTaskDataset(data_dir=data_dir, filter_files=filter_files, 
                                                    support_size=support_size, label_subset_size=label_subset_size,
                                                    split_save_path="smallearthnet.pkl", 
                                                    split_file="smallearthnet.pkl")

  1%|          | 113/9607 [00:00<00:08, 1125.30it/s]

Reloading train-val-test split cache from smallearthnet.pkl
File smallearthnet.pkl not found. Creating new split instead with file name smallearthnet.pkl
Building new train-val-test split and saving to smallearthnet.pkl


100%|██████████| 9607/9607 [00:10<00:00, 894.38it/s] 


## Creating the Dataset + Getting Data

In [2]:
"""
    This loads the data.
"""

X, y, y_debug = meta_dataset.sample_batch(batch_size=16, split='train') # split needs to be 'train', 'val', or 'test'


"""
    You should get something of the form
    
    Data shape: (batch_size, support_set_size, channels, width, height)
    Label shape: (batch_size, support_set_size, label_subset_size)
"""
print("Data shape:", X.shape)
print("Label shape:", y.shape)

Data shape: (16, 8, 3, 120, 120)
Label shape: (16, 8, 3)


In [19]:
"""
    Let's investigate how the labels are encoded.
    
    You should see a (support_set_size, label_subset_size)-shaped array of 0s and 1s.
    Each row represents the labels corresponding to an image (X[0, row]).
    Each column denotes which classes of the subset are present (class 1, class 2, ... class N).
    
    So [1. 1. 0.] means that the corresponding image has two land cover types present within the label subset, say
    class 1 and class 2.
    
    Then [0. 1. 0.] means that that image is labeled with land cover type class 2 (same class 2 as above).
    
    Each primary index (batch dimension) represents a new task, i.e. a new subset of labels.
"""
import numpy as np
print(y) # print first support set
single_labels = (np.packbits(y.astype(int), 2, 'little') - 1).reshape((len(y), -1))
print(single_labels)
num_classes = 7
one_hot = np.eye(num_classes)[single_labels]
print(one_hot)

[[[1. 1. 0.]
  [0. 0. 1.]
  [0. 0. 1.]
  [1. 0. 0.]
  [1. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [1. 1. 0.]]

 [[1. 0. 0.]
  [1. 0. 0.]
  [0. 1. 1.]
  [0. 1. 1.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]]

 [[1. 0. 0.]
  [0. 1. 0.]
  [1. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 1. 0.]
  [1. 0. 0.]
  [1. 0. 0.]]

 [[1. 0. 0.]
  [0. 1. 0.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 1. 1.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 1. 0.]]

 [[1. 0. 1.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 1. 0.]
  [1. 0. 0.]
  [0. 0. 1.]]

 [[1. 0. 0.]
  [0. 1. 1.]
  [0. 1. 0.]
  [1. 0. 0.]
  [1. 1. 0.]
  [0. 1. 0.]
  [0. 1. 0.]
  [1. 0. 0.]]

 [[1. 0. 0.]
  [1. 0. 0.]
  [1. 1. 0.]
  [1. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [1. 1. 0.]
  [0. 1. 0.]]

 [[0. 0. 1.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 1. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [0. 0. 1.]]

 [[1. 1. 1.]
  [1. 0. 1.]
  [0. 1. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]
  [1. 0. 0.]]

 [[1. 0. 1.]
  [1. 1. 0.]
  [1. 0. 0.]
  [1. 0

## Other useful tips

In [73]:
# see what classes are in the training dataset; val_keys and test_keys work as well
print(meta_dataset.train_keys) 

# see the indices of examples in the training set, which can be accessed via meta_dataset.dataset[idx]
print(meta_dataset.train_indices) 

# See the number of times a label appears in the data
print(meta_dataset.counts)

{'Sea and ocean', 'Sparsely vegetated areas', 'Discontinuous urban fabric', 'Burnt areas', 'Green urban areas', 'Beaches, dunes, sands', 'Estuaries', 'Transitional woodland/shrub', 'Continuous urban fabric', 'Construction sites', 'Peatbogs', 'Vineyards', 'Airports', 'Salt marshes', 'Non-irrigated arable land', 'Natural grassland', 'Moors and heathland', 'Mineral extraction sites', 'Dump sites', 'Fruit trees and berry plantations', 'Port areas', 'Sport and leisure facilities'}
[1, 2, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 21, 23, 25, 26, 27, 29, 33, 36, 42, 43, 44, 47, 50, 51, 52, 53, 55, 56, 57, 58, 65, 66, 68, 69, 73, 76, 79, 81, 82, 83, 84, 85, 86, 87, 90, 91, 94, 95, 98, 99, 111, 113, 114, 115, 119, 120, 121, 122, 125, 126, 127, 129, 130, 131, 133, 136, 137, 141, 143, 148, 152, 156, 159, 160, 161, 162, 164, 167, 169, 172, 175, 176, 179, 180, 181, 182, 184, 186, 187, 188, 194, 196, 198, 200, 201, 202, 203, 205, 206, 207, 208, 209, 213, 214, 219, 220, 221, 223, 227, 239, 240, 242, 

In [None]:
m

## Prototyping (STOP READING HERE)

In [12]:
import os
from torch.utils.data import DataLoader

data_dir = "../SmallEarthNet/"
cloud_shadow_file = "../patches_with_cloud_and_shadow.csv"
snow_file = "../patches_with_seasonal_snow.csv"
patches = os.listdir(data_dir)
patches.sort()
print(len(patches))

%load_ext autoreload
%autoreload 2

import load_data_tf as load_data

10000
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
data = load_data.BigEarthNetDataset(split_file="splits.pkl")

In [45]:
print(meta_dataset.counts)
print(meta_dataset.dataset.idx_to_label)
sum(meta_dataset.counts.values())

Counter({'Coniferous forest': 6480, 'Mixed forest': 5802, 'Transitional woodland/shrub': 3665, 'Non-irrigated arable land': 2456, 'Water bodies': 2229, 'Land principally occupied by agriculture, with significant areas of natural vegetation': 1908, 'Peatbogs': 1106, 'Broad-leaved forest': 871, 'Sea and ocean': 634, 'Discontinuous urban fabric': 545, 'Pastures': 452, 'Complex cultivation patterns': 342, 'Moors and heathland': 208, 'Vineyards': 202, 'Inland marshes': 193, 'Industrial or commercial units': 98, 'Bare rock': 85, 'Natural grassland': 81, 'Mineral extraction sites': 71, 'Fruit trees and berry plantations': 53, 'Sport and leisure facilities': 49, 'Dump sites': 32, 'Intertidal flats': 25, 'Beaches, dunes, sands': 15, 'Water courses': 12, 'Estuaries': 9, 'Road and rail networks and associated land': 9, 'Continuous urban fabric': 8, 'Green urban areas': 7, 'Sparsely vegetated areas': 7, 'Airports': 6, 'Burnt areas': 4, 'Salt marshes': 4, 'Construction sites': 2, 'Port areas': 1})


27671

In [44]:
meta_dataset = load_data.MetaBigEarthNetTaskDataset(data_dir="../SmallEarthNet", split_save_path="smallearthnet.pkl", split_file="smallearthnet_splits.pkl")

 17%|█▋        | 1634/9607 [00:00<00:00, 16336.74it/s]

Label count cache at ./label_counts.pkl not found; rebuilding cache.


100%|██████████| 9607/9607 [00:00<00:00, 17043.96it/s]
 17%|█▋        | 1651/9607 [00:00<00:00, 16505.37it/s]

Reloading train-val-test split cache from smallearthnet_splits.pkl
File smallearthnet_splits.pkl not found. Creating new split instead with file name smallearthnet.pkl
Building new train-val-test split and saving to smallearthnet.pkl


100%|██████████| 9607/9607 [00:00<00:00, 17272.65it/s]


In [54]:
X, y, yt = meta_dataset.sample_batch(batch_size=16, split='train')
y.shape

[2, 4, 5, 10, 11, 12, 13, 15, 16, 18, 19, 20, 21, 24, 26, 27, 28, 29, 30, 32, 33, 34] 5340


(16, 8, 3)

In [44]:
train = DataLoader(data, batch_size=16)
img_batch, label_batch = next(iter(train))
img_batch.size(), label_batch.size()

(torch.Size([16, 3, 120, 120]), torch.Size([16, 43]))

In [75]:
train, _, _ = load_data.get_dataloaders(split_file="splits.pkl")

Reloading train-val-test split cache from splits.pkl
Reloading train-val-test split cache from splits.pkl
Reloading train-val-test split cache from splits.pkl


In [76]:
img_batch, label_batch, raw_labels = next(iter(train))
img_batch.size(), label_batch.size()

(torch.Size([8, 8, 3, 120, 120]), torch.Size([8, 8, 3]))

In [77]:
meta_dataset.key_indices, raw_labels

([4,
  5,
  6,
  7,
  12,
  13,
  14,
  15,
  16,
  17,
  19,
  20,
  22,
  24,
  25,
  28,
  30,
  32,
  33,
  34,
  35,
  36,
  37,
  39,
  40,
  41,
  42],
 tensor([[[22, -1, -1],
          [ 4, 17, -1],
          [ 4, -1, -1],
          [ 4, -1, -1],
          [ 4, 22, -1],
          [17, 22, -1],
          [ 4, 22, -1],
          [ 4, 22, -1]],
 
         [[ 4, -1, -1],
          [ 4, 22, -1],
          [ 4, 37, -1],
          [ 4, -1, -1],
          [ 4, 22, -1],
          [ 4, -1, -1],
          [22, -1, -1],
          [ 4, 22, -1]],
 
         [[ 4,  5, 22],
          [ 4,  5, -1],
          [ 4,  5, 22],
          [ 4, -1, -1],
          [ 4, -1, -1],
          [ 4, 22, -1],
          [ 4,  5, 22],
          [22, -1, -1]],
 
         [[ 4, -1, -1],
          [ 7, -1, -1],
          [14, -1, -1],
          [ 4, -1, -1],
          [ 4, 14, -1],
          [ 4, -1, -1],
          [ 4,  7, -1],
          [ 4, 14, -1]],
 
         [[ 4, 22, -1],
          [ 4, 22, -1],
          [ 4

In [None]:
import csv
import gdal
import rasterio

In [None]:

elimination_patch_list = []  
for file_path in [cloud_shadow_file, snow_file]:
    if not os.path.exists(file_path):
        print('ERROR: file located at', file_path, 'does not exist')
        exit()
    with open(file_path, 'r') as f:
        csv_reader = csv.reader(f, delimiter=',')
        for row in csv_reader:
            elimination_patch_list.append(row[0])
#print('INFO:', len(elimination_patch_list), 'number of patches will be eliminated')
elimination_patch_list = set(elimination_patch_list)


In [None]:
filtered_patches = [patch for patch in patches if patch not in elimination_patch_list]
len(filtered_patches)

In [None]:
os.listdir(os.path.join(data_dir, patches[2]))

In [None]:
import random
idx = random.randint(0, len(patches))
all_bands = ['B01', 'B02', 'B03', 'B04', 'B05',
              'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
rgb_bands = ['B04', 'B03', 'B02']

band_stack = []
for bands in rgb_bands:
    band_path = os.path.join(data_dir, patches[idx], "{}_{}.tif".format(patches[idx], bands))
    print("Loading", band_path)
    assert os.path.isfile(band_path)

    band_ds = gdal.Open(band_path,  gdal.GA_ReadOnly)
    raster_band = band_ds.GetRasterBand(1)
    band_data = raster_band.ReadAsArray()
    band_stack.append(band_data)



In [None]:
import matplotlib.pyplot as plt

_OPTICAL_MAX_VALUE = 2000. # magic number from some google guys
img = np.stack(band_stack) / _OPTICAL_MAX_VALUE # (C, W, H)
img = np.clip(img, 0, 1)
plt.imshow(np.transpose(img))


In [None]:
import json

In [None]:
with open(os.path.join(data_dir, patches[idx], "{}_labels_metadata.json".format(patches[idx])), 'r') as f:
    metadata = json.load(f)

metadata

In [None]:
from collections import Counter

c = Counter()
c.update(metadata['labels'])
c

In [None]:
from collections import Counter
from tqdm.notebook import tqdm
import pickle

label_counts = Counter()
for patch in tqdm(patches):
    with open(os.path.join(data_dir, patch, "{}_labels_metadata.json".format(patch)), 'r') as f:
        metadata = json.load(f)
        label_counts.update(metadata['labels'])
print(label_counts)

with open("label_counts_cache.pkl", "wb") as f:
    pickle.dump(label_counts, f)

In [None]:
random.choice(label_counts.keys(), 2)