In [1]:
import os
from tqdm import tqdm
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# (Empty cell.) A bare, undefined name here would raise NameError on a fresh kernel.

# Ash-color false-color images / normal labels (`human_pixel_masks`)

In [None]:
def normalize_range(data, bounds):
    """Linearly rescale ``data`` so ``bounds[0]`` maps to 0 and ``bounds[1]`` maps to 1.

    No clipping is applied: values outside ``bounds`` land outside [0, 1].
    """
    lower = bounds[0]
    span = bounds[1] - lower
    return (data - lower) / span


def get_false_color(record_data, time_index=4):
    """Render the "ash" false-color RGB image for one time step of a record.

    Each channel is rescaled with ``normalize_range`` using fixed bounds, and the
    stacked result is clipped to [0, 1]:

    * R: ``band_15 - band_14`` with bounds (-4, 2)
    * G: ``band_14 - band_11`` with bounds (-4, 5)
    * B: ``band_14`` with bounds (243, 303)

    Args:
        record_data: Mapping with same-shaped "band_11", "band_14" and "band_15"
            arrays whose LAST axis is indexed by ``time_index``. They must be at
            least 3-D, presumably (H, W, T) — TODO confirm against the dataset layout.
        time_index: Index into the last axis selecting the frame to render.
            Defaults to 4 (presumably the human-labeled frame — confirm).

    Returns:
        The clipped RGB array with the channel axis at position 2; for
        (H, W, T) inputs this is (H, W, 3).
    """
    tdiff_bounds = (-4, 2)
    cloud_top_tdiff_bounds = (-4, 5)
    t11_bounds = (243, 303)

    # Pick the requested frame *before* building the channels. This is
    # element-for-element identical to stacking every frame and indexing
    # afterwards, but avoids computing images for frames that are thrown away.
    band_11 = record_data["band_11"][..., time_index]
    band_14 = record_data["band_14"][..., time_index]
    band_15 = record_data["band_15"][..., time_index]

    r = normalize_range(band_15 - band_14, tdiff_bounds)
    g = normalize_range(band_14 - band_11, cloud_top_tdiff_bounds)
    b = normalize_range(band_14, t11_bounds)
    return np.clip(np.stack([r, g, b], axis=2), 0, 1)


def read_record(record_id, directory, mode):
    """Load the per-record ``.npy`` arrays needed for ``mode``.

    Reads ``<directory>/<record_id>/<name>.npy`` for "band_11", "band_14" and
    "band_15". For the labeled splits ("train", "validation") it also reads
    "human_pixel_masks".

    Args:
        record_id: Name of the record's sub-directory inside ``directory``.
        directory: Split directory containing one folder per record.
        mode: "train", "validation" or "test".

    Returns:
        dict mapping each file stem (e.g. "band_14") to the loaded array.

    Raises:
        ValueError: if ``mode`` is not one of the known splits. Previously this
            surfaced as an UnboundLocalError.
    """
    bands = ["band_11", "band_14", "band_15"]
    if mode in ("train", "validation"):
        names = bands + ["human_pixel_masks"]
    elif mode == "test":
        names = bands
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'train', 'validation' or 'test'"
        )

    record_dir = os.path.join(directory, record_id)
    return {name: np.load(os.path.join(record_dir, name + ".npy")) for name in names}


def create_dataset(data_dir, save_dir, mode):
    """Convert every record of one split into a float16 ``.npy`` array plus a CSV index.

    For each record directory in ``<data_dir>/<mode>``:

    * Build the ash false-color image with ``get_false_color``.
    * For the labeled splits ("train", "validation"), append the
      ``human_pixel_masks`` array along the last axis (``np.dstack``).
    * Save the result as ``<save_dir>/<record_id>.npy`` in float16.

    ``<save_dir>/<mode>_df.csv`` lists each ``record_id`` and its saved ``path``.
    It is written before the records are processed.

    Args:
        data_dir: Root of the competition data that contains the split folders.
        save_dir: Output directory, created if missing. A trailing separator
            is optional; all paths are built with ``os.path.join``.
        mode: "train", "validation" or "test".

    Raises:
        ValueError: if ``mode`` is not a known split. This is checked before
            any directory or file is created.
    """
    if mode not in ("train", "validation", "test"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'train', 'validation' or 'test'"
        )

    os.makedirs(save_dir, exist_ok=True)

    input_dir = os.path.join(data_dir, mode)
    # os.listdir order is arbitrary; sort so the CSV index is reproducible.
    ids = sorted(os.listdir(input_dir))

    def npy_path(record_id):
        """Output path for one record, shared by the CSV index and np.save."""
        return os.path.join(save_dir, f"{record_id}.npy")

    df = pd.DataFrame(ids, columns=['record_id'])
    df['path'] = [npy_path(record_id) for record_id in ids]
    df.to_csv(os.path.join(save_dir, f"{mode}_df.csv"), index=False)

    for record_id in tqdm(ids):
        data = read_record(record_id, input_dir, mode)
        images = get_false_color(data)
        if mode in ("train", "validation"):
            array = np.dstack([images, data['human_pixel_masks']])
        else:
            array = images
        np.save(npy_path(record_id), array.astype(np.float16))

# Competition input root and output directories for the ash-color arrays.
data_dir = '/kaggle/input/google-research-identify-contrails-reduce-global-warming/'
dataset_train = "/kaggle/working/dataset_train/ash_color/"
dataset_test = "/kaggle/working/dataset_test/ash_color/"

# Building a split runs `create_dataset` over every record, so it is opt-in:
# set a split's flag to True and re-run this cell. "train" and "validation"
# both write into `dataset_train`; their CSV indexes stay separate
# (`{mode}_df.csv`).
BUILD_SPLITS = {"train": False, "validation": False, "test": False}
SPLIT_SAVE_DIRS = {"train": dataset_train, "validation": dataset_train, "test": dataset_test}

for split_name, should_build in BUILD_SPLITS.items():
    if should_build:
        create_dataset(data_dir, SPLIT_SAVE_DIRS[split_name], split_name)