In [1]:
import os
import time
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from pymongo import MongoClient
from tqdm import tqdm, trange

In [2]:
# MongoDB connection string. It carries credentials, so it is read from the
# MONGODB_URI environment variable instead of being stored in the notebook.
# SECURITY: a username/password pair used to be hardcoded on this line; treat
# that password as leaked and rotate it.
# The variable must be set before any cell fetches from MongoDB; it is an
# empty string when unset.
MONGODB_URI = os.environ.get("MONGODB_URI", "")
# Root directory holding the cached and derived Parquet files.
DATA_HOME = Path('/mnt/pollux/data/world_model')


def save_dataset(dataset, path, batch_size=10000):
    """Write ``dataset`` to a single Parquet file at ``path``, one batch at a time.

    Args:
        dataset: Sliceable collection of records supporting ``len()`` and
            ``dataset[i:j]``. Each slice is passed to ``pd.DataFrame`` and must
            yield an ``_id`` column, which is converted to ``str`` before
            writing.
        path: Destination Parquet file. Missing parent directories are created.
        batch_size: Number of rows converted and written per step.

    The Arrow schema is inferred from the first batch and then enforced on
    every later batch, so a later batch that does not fit that schema raises.
    Nothing is written (and nothing is printed) when ``dataset`` is empty.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    writer = None
    schema = None  # inferred from the first batch
    try:
        for start in trange(0, len(dataset), batch_size):
            df_batch = pd.DataFrame(dataset[start:start + batch_size])
            df_batch['_id'] = df_batch['_id'].map(str)

            if writer is None:
                # First batch: infer the schema and open the writer with it.
                table = pa.Table.from_pandas(df_batch, preserve_index=False)
                schema = table.schema
                writer = pq.ParquetWriter(path, schema)
            else:
                table = pa.Table.from_pandas(df_batch, schema=schema, preserve_index=False)
            writer.write_table(table)
    finally:
        # Always release the file handle, even if a batch failed to convert.
        if writer is not None:
            writer.close()

    if writer is not None:
        print("Parquet file created successfully!")


def get_dataset(collection, use_cache=True):
    """Load the documents of a MongoDB collection as a pandas DataFrame.

    Args:
        collection: Name of the collection inside the ``world_model`` database.
            It is also the stem of the cache file
            ``DATA_HOME / f'{collection}.parquet'``.
        use_cache: When true and the cache file exists, read it instead of
            querying MongoDB.

    Returns:
        The DataFrame read back from the Parquet cache file.

    Raises:
        RuntimeError: If MongoDB has to be queried but ``MONGODB_URI`` is empty.
    """
    cache_path = DATA_HOME / f'{collection}.parquet'
    if use_cache and cache_path.exists():
        print(f"Loading {collection} from cache")
        return pd.read_parquet(cache_path)

    print(f"Loading {collection} from MongoDB, this may take a while...")
    if not MONGODB_URI:
        raise RuntimeError("MONGODB_URI is not set; cannot load data from MongoDB")

    query = {}
    projection = {
        "_id": 1,
        "source_id": 1,
        "media_path": 1,
        "width": 1,
        "height": 1,
        "caption": 1,
        "source": 1,
    }

    client = MongoClient(MONGODB_URI)
    try:
        # Keep `collection` bound to the name: the cache path is derived from
        # it, so the pymongo Collection object gets its own variable.
        mongo_collection = client['world_model'][collection]
        cursor = mongo_collection.find(
            query,
            projection,
            batch_size=8192,
            no_cursor_timeout=True,
            max_time_ms=2400000,
        )
        try:
            dataset = list(tqdm(cursor))
        finally:
            # The cursor was opened with no_cursor_timeout=True, so close it
            # explicitly rather than relying on a timeout.
            cursor.close()
    finally:
        client.close()

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    save_dataset(dataset, cache_path)
    return pd.read_parquet(cache_path)


In [3]:
def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps, resume_step, seed):
    """Build the list of dataset indices that one rank consumes during training.

    Each epoch draws a permutation of ``range(len(dataset))`` from a generator
    seeded with ``seed + epoch_id``; the seed does not depend on ``rank``, so
    every rank sees the same permutations. A rank keeps every
    ``world_size``-th element of the permutation, and the starting offset is
    shifted after each epoch by ``len(dataset) % world_size`` so the striding
    carries on across the epoch boundary. Epochs are appended until
    ``max_steps * global_batch_size // world_size`` indices are collected.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        rank: Index of this rank, used as the stride start.
        world_size: Number of ranks, used as the stride step.
        global_batch_size: Samples per step summed over all ranks.
        max_steps: Number of steps to generate indices for.
        resume_step: Number of leading steps whose indices are dropped from
            the result.
        seed: Base seed for the per-epoch permutations.

    Returns:
        A Python list of integer indices for this rank.

    Raises:
        ValueError: If indices are requested from an empty dataset (the fill
            loop could otherwise never make progress).
    """
    num_samples = max_steps * global_batch_size // world_size
    if num_samples > 0 and len(dataset) == 0:
        raise ValueError("cannot sample indices from an empty dataset")

    sample_indices = torch.empty([num_samples], dtype=torch.long)
    epoch_id, fill_ptr, offs = 0, 0, 0
    while fill_ptr < sample_indices.size(0):
        g = torch.Generator()
        g.manual_seed(seed + epoch_id)
        epoch_sample_indices = torch.randperm(len(dataset), generator=g)
        epoch_id += 1
        # This rank's share of the epoch: every world_size-th element.
        epoch_sample_indices = epoch_sample_indices[(rank + offs) % world_size :: world_size]
        # Shift the stride start for the next epoch by the epoch's remainder.
        offs = (offs + world_size - len(dataset) % world_size) % world_size
        # Do not write past the end of the preallocated buffer.
        epoch_sample_indices = epoch_sample_indices[: sample_indices.size(0) - fill_ptr]
        sample_indices[fill_ptr : fill_ptr + epoch_sample_indices.size(0)] = epoch_sample_indices
        fill_ptr += epoch_sample_indices.size(0)
    return sample_indices[resume_step * global_batch_size // world_size :].tolist()

In [4]:
def shuffle_dataset(
    dataset: pd.DataFrame,
    global_batch_size: int,
    world_size: int,
    rank:int,
    seed:int,
    resume_step:int = 0,
):
    """Return the rows of ``dataset`` assigned to ``rank``, in sampled order.

    The number of steps is the largest count of whole per-rank batches
    (``global_batch_size // world_size`` rows each) that fits into
    ``len(dataset) // world_size`` rows. The row positions come from
    ``get_train_sampler`` and are applied with ``DataFrame.iloc``.
    """
    rows_per_rank = len(dataset) // world_size
    rank_batch_size = global_batch_size // world_size
    sampler_kwargs = dict(
        rank=rank,
        world_size=world_size,
        global_batch_size=global_batch_size,
        max_steps=rows_per_rank // rank_batch_size,
        resume_step=resume_step,
        seed=seed,
    )
    row_positions = get_train_sampler(dataset, **sampler_kwargs)
    return dataset.iloc[row_positions]

In [5]:
# Load the 'bucket-256-1' collection; with use_cache=True the Parquet file
# under DATA_HOME is read when it already exists.
dataset = get_dataset('bucket-256-1', use_cache=True)

Loading bucket-256-1 from cache


In [6]:
# Shuffle the whole dataset as a single rank (world_size=1) and write the
# result to DATA_HOME / "shuffled".
shuffled_path = DATA_HOME / "shuffled" / 'bucket-256-1.parquet'
shuffled_path.parent.mkdir(parents=True, exist_ok=True)

print("Shuffling all data")
# perf_counter is a monotonic clock, so the elapsed time cannot be skewed by
# wall-clock adjustments.
time_start = time.perf_counter()
dataset_rank = shuffle_dataset(
    dataset=dataset,
    global_batch_size=512,
    world_size=1,
    rank=0,
    seed=47,
    resume_step=0
)
time_end = time.perf_counter()
print(f"Shuffling rank 0 done in {time_end - time_start} seconds")
save_dataset(dataset_rank, shuffled_path)

Shuffling all data
Shuffling rank 0 done in 50.24880599975586 seconds


100%|██████████| 10549/10549 [10:05<00:00, 17.41it/s]


Parquet file created successfully!


In [7]:
# Display the path of the single-rank shuffled Parquet file.
DATA_HOME / "shuffled" / f'bucket-256-1.parquet'

PosixPath('/mnt/pollux/data/world_model/shuffled/bucket-256-1.parquet')

In [11]:
# Peek at the media_path value of the first row of the loaded dataset.
dataset.iloc[0].media_path

'https://ayon.blob.core.windows.net/flickr-images/13549916194.jpg'

In [12]:
# Number of rows in the loaded dataset.
len(dataset)

105480192

In [None]:
# Write one shuffled shard per rank. WORLD_SIZE drives both the loop and the
# world_size passed to shuffle_dataset so the two cannot drift apart.
WORLD_SIZE = 8
SHUFFLED_DIR = DATA_HOME / "shuffled"
SHUFFLED_DIR.mkdir(parents=True, exist_ok=True)

for rank in range(WORLD_SIZE):
    print(f"Shuffling data for rank {rank}")
    # perf_counter is a monotonic clock, suited to measuring intervals.
    time_start = time.perf_counter()
    dataset_rank = shuffle_dataset(
        dataset=dataset,
        global_batch_size=512,
        world_size=WORLD_SIZE,
        rank=rank,
        seed=47,
        resume_step=0
    )
    time_end = time.perf_counter()
    print(f"Shuffling rank {rank} done in {time_end - time_start} seconds")
    save_dataset(dataset_rank, SHUFFLED_DIR / f'bucket-256-1-rank-{rank}.parquet')
