In [1]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from datasets import load_dataset, load_from_disk
import torch
from torch.utils.data import DataLoader
import numpy as np
import pickle
import yaml
import json
from collections import defaultdict
root_path = Path(".").resolve().parent
print("Root path:", root_path)
sys.path.append(str(root_path))

import importlib
from src.model_meta.dataset import CSVDataModule
print("point_num_dict" in dir(CSVDataModule))

  from .autonotebook import tqdm as notebook_tqdm


Root path: /Users/takeruito/work/PrfSR
False


# Load from CSV and evaluate

In [None]:
dataset_path = "/Users/takeruito/work/PrfSR/data/training/superfib_r1_dataset.csv"
dataset = load_dataset("csv", data_files=dataset_path, split="train")

def formatter(sample):
    sample["source"] = eval(sample["source"])
    sample["target"] = eval(sample["target"])
    return sample

eval_dataset = dataset.map(formatter, batched=False, num_proc=10, load_from_cache_file=False)

# Load from hf data directory

In [2]:
eval_dataset = load_from_disk("/Users/takeruito/work/PrfSR/data/training/dir_hf_superfib_eval_dataset")

  table = cls._concat_blocks(blocks, axis=0)


# Load config path

In [6]:
train_config_path = root_path / "src/model_meta/training_config.yaml"
assert Path(train_config_path).exists(), FileNotFoundError(f"Train config file not found: {train_config_path}")
with open(train_config_path, 'r') as f:
    config = yaml.safe_load(f)
print(f"Metadata path: {config["metadata_path"]:}")
print(f"max_epoch: {config['max_epoch']}")
print(f"max_value: {config['max_value']}")
print(f"min_n_tokens_in_batch: {config['min_n_tokens_in_batch']}")
print(f"test_ratio: {config['test_ratio']}")
print(f"val_ratio: {config['val_ratio']}")
print(f"num_workers: {config['num_workers']}")
print(f"token_embed_dim: {config['token_embed_dim']}")
print(f"emb_expansion_factor: {config['emb_expansion_factor']}")
print(f"learning_rate: {config['learning_rate']}")
print(f"nhead: {config['transformer']['nhead']}")
print(f"num_encoder_layers: {config['transformer']['num_encoder_layers']}")
print(f"num_decoder_layers: {config["transformer"]["num_decoder_layers"]}")
print(f"dim_feedforward: {config["transformer"]["dim_feedforward"]}")
print(f"dropout: {config["transformer"]["dropout"]}")


Metadata path: data/training/superfib_r1_metadata.pickle
max_epoch: 3
max_value: 2000
min_n_tokens_in_batch: 2000
test_ratio: 0.5
val_ratio: 0.25
num_workers: 16
token_embed_dim: 16
emb_expansion_factor: 1
learning_rate: 3*10**(-4)
nhead: 16
num_encoder_layers: 4
num_decoder_layers: 6
dim_feedforward: 512
dropout: 0.1


# Load Metadata

In [7]:
# Load metadata from file (supports YAML, JSON, or pickle)
print(f"\n📋 Loading metadata from: {config["metadata_path"]}")
metadata_path = Path("..")/ config['metadata_path']

if not Path(metadata_path).exists():
    raise FileNotFoundError(f"Metadata file not found: {metadata_path}")

# Determine file format from extension
file_ext = Path(metadata_path).suffix.lower()


if file_ext in ['.yaml', '.yml']:
    with open(metadata_path, 'r') as f:
        metadata = yaml.safe_load(f)
    
elif file_ext == '.json':
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
elif file_ext in ['.pkl', '.pickle']:
    with open(metadata_path, 'rb') as f:
        metadata = pickle.load(f)
    
else:
    raise ValueError(f"Unsupported metadata file format: {file_ext}. Supported formats: .yaml, .yml, .json, .pkl, .pickle")

print(f"📊 Available metadata keys: {list(metadata.keys())}")


print(f"📚 Vocabulary sizes from metadata:")
print(f"  - Source vocab: {len(metadata['src_vocab_list'])}")
print(f"  - Target vocab: {len(metadata['tgt_vocab_list'])}")
print(f"max_point_dim: {metadata['max_point_dim']}")
print(f"max_src_points: {metadata['max_src_points']}")
print(f"max_tgt_length: {metadata['max_tgt_length']}")


📋 Loading metadata from: data/training/superfib_r1_metadata.pickle
📊 Available metadata keys: ['max_tgt_length', 'max_src_points', 'max_point_dim', 'src_vocab_list', 'tgt_vocab_list', 'point_num_dist']
📚 Vocabulary sizes from metadata:
  - Source vocab: 1002
  - Target vocab: 15
max_point_dim: 4
max_src_points: 80
max_tgt_length: 859


In [8]:
def source_len(sample):
    return {"source_len": len(sample["source"])}

shuffled_dataset = eval_dataset.shuffle(seed=42)


point_num_dist = metadata.get("point_num_dist")

mini_data_num = 100

for i in range(mini_data_num):
    mini_dataset = shuffled_dataset.shard(num_shards=mini_data_num, index=i)
    #mini_dataset = mini_dataset.map(source_len, num_proc=1)
    #mini_dataset = mini_dataset.sort("source_len")
    point_num_dist = defaultdict(list)
    for idx, sample in enumerate(mini_dataset):
        point_num_dist[len(sample["source"])].append(idx)
    print(f"Processing shard {i+1}/{mini_data_num} with size {len(mini_dataset)}")
    dataloader = CSVDataModule(
        data_path = None,
        dataset=mini_dataset,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        train_val_split=1 - config["test_ratio"],
        seed=42,
        batching_strategy="length_aware_token",
        min_tokens_per_batch=config['min_n_tokens_in_batch'],
        max_batch_size=config['batch_size'],
        point_num_dist=point_num_dist,
    )
    
    dataloader.setup()
    batch = next(iter(dataloader.train_dataloader()))
    print(f"MiniData {i+1} - Source shape: {batch['source'].shape}, Target shape: {batch['target'].shape}")

Processing shard 1/100 with size 71648
Train dataset: 35824 samples
Validation dataset: 35824 samples
Creating length-based groups...


Creating orig to subset idx mapping: 35824it [00:00, 10848862.56it/s]
Mapping point_num_dist to subset indices: 100%|██████████| 80/80 [00:00<00:00, 32131.03it/s]

Mapped to 80 groups for Subset:
  Length 4: 5554 samples (subset indices: [35362, 24771, 31447, 13090, 20282, 7229, 30471, 30077, 14707, 27052]...)
  Length 6: 5529 samples (subset indices: [16040, 22130, 21116, 11860, 24499, 15040, 33752, 9624, 32173, 3992]...)
  Length 5: 5468 samples (subset indices: [15603, 8047, 29130, 33021, 29597, 35815, 10229, 14297, 8499, 23734]...)
  Length 3: 2991 samples (subset indices: [20137, 28665, 3979, 12256, 30255, 7017, 32149, 15509, 31265, 25200]...)
  Length 11: 305 samples (subset indices: [7706, 27400, 27265, 12401, 34931, 1628, 19363, 575, 35772, 8595]...)



  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/takeruito/work/PrfSR/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/takeruito/work/PrfSR/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/takeruito/work/PrfSR/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/takeruito/work/PrfSR/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 172, in collate
    key: collate(
         ^^^^^^^^
  File "/Users/takeruito/work/PrfSR/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 207, in collate
    raise RuntimeError("each element in list of batch should be of equal size")
RuntimeError: each element in list of batch should be of equal size
