In [None]:
import os
import argparse
import json
import torch

 
# The custom dataset file also includes scripts for GEO-Bench. If you don't need them, simply comment out those lines.
from mmearth_dataset import get_mmearth_dataloaders


from MODALITIES import * # this contains all the input and output bands u need for pretraining.

In [None]:
# Build the options namespace used by the rest of the notebook.
# An explicit empty argument list is passed on purpose: with no list, argparse
# reads sys.argv, which inside a notebook kernel holds the kernel's own launch
# arguments; a parser with no declared options rejects them and exits.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

# these 4 arguments need to be set manually
args.data_path = '/data/mmearth/data_1M_v001/' # path to h5 file 
args.data_dir = args.data_path # alias, so the dataset location is reachable under either name
args.random_crop = True # ensure that if the dataset image size is 128 x 128, the resulting image after cropping is 112 x 112.
args.random_crop_size = 112 # the size of the crop
args.batch_size = 1

# define the input and output bands for the dataset
args.inp_modalities = INP_MODALITIES
args.out_modalities = OUT_MODALITIES

# args.modalities is a dictionary of all the input and output bands. It is built
# from a copy so that INP_MODALITIES itself is not modified by the update.
args.modalities = args.inp_modalities.copy()
args.modalities.update(args.out_modalities)
args.modalities_full = MODALITIES_FULL # this is a dictionary of all the bands in the dataset.

args.no_ffcv = False # this flag allows you to load the ffcv dataloader or the h5 dataset.
args.processed_dir = None # default is automatically created in the data path. this is the dir where the beton file for ffcv is stored
args.num_workers = 4 # number of workers for the dataloader
args.distributed = False # if you are using distributed training, set this to True

# options read when the (non-ffcv) sampler and dataloader are built
args.seed = 0 # shuffling seed for the distributed sampler (0 is the DistributedSampler default)
args.pin_mem = False # forwarded to DataLoader(pin_memory=...); False is the DataLoader default


In [None]:
def collate_fn(batch, modalities=None): # only for non ffcv dataloader
    """Collate a list of per-sample dictionaries into one batch.

    For each modality, the samples of that modality are stacked together along a
    new leading dimension. The sample ids are returned alongside the batch; we
    keep track of the ids to differentiate between sentinel2_l1c and sentinel2_l2a.

    Args:
        batch: list of samples; each sample is a mapping that holds an 'id'
            entry and one tensor entry per requested modality.
        modalities: optional mapping (or other iterable) whose keys name the
            modalities to collate. When omitted, the notebook-level
            ``args.modalities`` is used, so ``DataLoader`` can keep calling
            this function with the batch alone.

    Returns:
        A tuple ``(ids, return_batch)`` where ``ids`` is the list of sample ids
        in batch order and ``return_batch`` maps each modality name to the
        stacked tensor.
    """
    if modalities is None:
        # looked up at call time so later changes to args.modalities are honoured
        modalities = args.modalities
    ids = [b['id'] for b in batch]
    return_batch = {modality: torch.stack([b[modality] for b in batch], dim=0) for modality in modalities}
    return ids, return_batch

# Build `train_dataloader`, either from the h5 dataset (args.no_ffcv is True) or
# through ffcv (args.no_ffcv is False).
# NOTE(review): both branches call get_mmearth_dataloaders with the same
# arguments -- confirm how the helper is told which mode to use.
if args.no_ffcv:
    # the following call creates a pytorch dataset object.
    dataset = get_mmearth_dataloaders(
        args.data_path,
        args.processed_dir, 
        args.modalities,
        num_workers=args.num_workers,
        batch_size_per_device=args.batch_size,
        distributed=args.distributed
    )[0] # non ffcv mode returns only the dataset object
    
    # define a sampler based on the number of tasks and the global rank. This is useful for distributed training.
    # When no process group has been initialised, fall back to a single task with rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        num_tasks = torch.distributed.get_world_size() # number of tasks in distributed training
        global_rank = torch.distributed.get_rank() # global rank of the current task
    else:
        num_tasks = 1
        global_rank = 0
    sampler_train = torch.utils.data.DistributedSampler(
        dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True,
        seed=getattr(args, 'seed', 0), # 0 is the sampler's own default seed
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=getattr(args, 'pin_mem', False), # False is the DataLoader default
        drop_last=True,
        collate_fn=collate_fn,
    )
else:
    # if ffcv, we return the dataloader object
    train_dataloader = get_mmearth_dataloaders(
        args.data_path,
        args.processed_dir,
        args.modalities,
        num_workers=args.num_workers,
        batch_size_per_device=args.batch_size,
        distributed=args.distributed,
    )[0]

In [None]:
# Each item returned by the dataloader is a dictionary with the modalities as keys
# and the corresponding data as values. The keys are similar to the ones in the
# args.modalities dictionary, or the MODALITIES.py file.