In [125]:
import os
from pathlib import Path
from zipfile import ZipFile
import urllib.request

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

# parameters
BATCH_SIZE = 256

# Input channels, in the order they are concatenated along the last axis
# of the feature tensor built further down.
FEATURES = [
    "elevation", "th", "vs", "tmmn", "tmmx", "sph",
    "pr", "pdsi", "NDVI", "population", "erc", "PrevFireMask",
]
# Target channel(s).
LABELS = ["FireMask"]

# Spatial grid each per-sample vector is reshaped into; ARR_SIZE is the
# corresponding flattened length (64 * 64 = 4096).
LENGTH = 64
WIDTH = 64
ARR_SIZE = LENGTH * WIDTH

In [103]:
# Resolve ./data against the current working directory and make sure it
# exists (missing parents are created; an existing directory is fine).
data_dir = os.path.abspath("data")
os.makedirs(data_dir, exist_ok=True)

In [104]:
# download data zip
# The archive is fetched only when it is not already present in data_dir.
data_zip = os.path.join(data_dir, "archive.zip")
if not os.path.exists(data_zip):
    url = "https://www.kaggle.com/api/v1/datasets/download/fantineh/next-day-wildfire-spread"
    # Download under a temporary name and rename only after the transfer
    # finished. Otherwise an interrupted download would leave a truncated
    # archive.zip behind, and the existence check above would skip the
    # re-download on every later run.
    partial_zip = data_zip + ".part"
    try:
        urllib.request.urlretrieve(url, partial_zip)
        os.replace(partial_zip, data_zip)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_zip):
            os.remove(partial_zip)

In [105]:
# extract files from zip
# `files` collects the shard names that are later substituted into the
# "{}.tfrecord" path pattern. Every archive member is still extracted (if not
# already on disk), but only ".tfrecord" members are recorded, so directory
# entries or auxiliary files in the archive never turn into non-existent
# "<name>.tfrecord" paths.
files = []
with ZipFile(data_zip, "r") as z:
    for member in z.namelist():
        if not os.path.exists(os.path.join(data_dir, member)):
            z.extract(member, data_dir)
        member_path = Path(member)
        if member_path.suffix == ".tfrecord":
            # Drop only the suffix and keep any sub-directory, so the name
            # points at the location extract() wrote the file to. For
            # top-level members this equals Path(member).stem.
            files.append(str(member_path.with_suffix("")))

In [106]:
# Build a single dataset over every extracted shard and wrap it in a batching
# DataLoader. Each name in `files` is substituted into the "{}" placeholder of
# the path pattern; all shards are given the same weight of 1.0.
tfrecord_path = os.path.join(data_dir, "{}.tfrecord")
dataset = MultiTFRecordDataset(
    tfrecord_path,
    None,  # second positional argument (index pattern) is not supplied
    splits=dict.fromkeys(files, 1.0),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

In [107]:
# Pull the first batch from the loader for inspection. The cells below index
# it by feature name (data[key]), so it is presumably a dict of batched
# tensors - TODO confirm against the tfrecord reader's output.
data = next(iter(loader))

In [133]:
# gather batch of features
# Each data[key] gets a trailing axis and the results are concatenated along
# that axis, so the channels appear in FEATURES order.
features = torch.cat([data[key][:, :, None] for key in FEATURES], dim=2)
# Let reshape infer the batch dimension (-1) instead of hard-coding
# BATCH_SIZE, so a batch holding fewer than BATCH_SIZE samples does not raise.
features = features.reshape(-1, LENGTH, WIDTH, len(FEATURES))
features.shape

torch.Size([256, 64, 64, 12])

In [134]:
# gather batch of label(s)
# Each data[key] gets a trailing axis and the results are concatenated along
# that axis, so the channels appear in LABELS order.
labels = torch.cat([data[key][:, :, None] for key in LABELS], dim=2)
# Let reshape infer the batch dimension (-1) instead of hard-coding
# BATCH_SIZE, so a batch holding fewer than BATCH_SIZE samples does not raise.
labels = labels.reshape(-1, LENGTH, WIDTH, len(LABELS))
labels.shape

torch.Size([256, 64, 64, 1])