In [1]:
import os

import numpy as np
import pandas as pd

# ---- Configuration ----
RANDOM_SEED = 42
VAL_FRAC = 0.05  # fraction of every task held out for validation (5%)
DATA_PATH = "../data/deep_chal_multitask_dataset.parquet"  # full dataset

# The splits are written into the same directory as the source file.
out_dir = os.path.dirname(os.path.abspath(DATA_PATH))
TRAIN_OUT, VAL_OUT = (
    os.path.join(out_dir, fname) for fname in ("train.parquet", "val.parquet")
)

# ---- Load ----
df = pd.read_parquet(DATA_PATH, engine="fastparquet")
# The stratified split groups rows by this column, so it must exist.
assert "task" in df.columns, "'task' 컬럼이 없습니다."

# ---- Stratified split: hold out VAL_FRAC of the rows of every task ----
# Row selection and removal below work by index *label*. Rebuild a unique
# RangeIndex first so that a duplicated index (e.g. left over from a concat
# before the parquet was written) cannot make `.loc` pull extra rows into val
# or `.drop` remove extra rows from train. For an input whose index is already
# unique the same rows are chosen, because `sample` picks by position within
# each group.
df_unique = df.reset_index(drop=True)

picked_per_task = []
# NOTE: groupby drops NaN keys by default, so rows with a missing `task` are
# never sampled and all end up in train.
for task, g in df_unique.groupby("task"):
    n_total = len(g)
    # Ratio-based size, with at least one validation row per task.
    n_val = max(1, int(round(n_total * VAL_FRAC)))
    # Don't move an entire task into val: leave at least one row for train.
    # A single-row task cannot satisfy both minimums; it goes entirely to val.
    if n_val >= n_total and n_total > 1:
        n_val = n_total - 1
    picked_per_task.append(g.sample(n=n_val, random_state=RANDOM_SEED).index)

val_indices = (
    np.concatenate(picked_per_task) if picked_per_task else np.array([], dtype=int)
)

val_df   = df_unique.loc[val_indices].copy()
train_df = df_unique.drop(index=val_indices).copy()

# Every row must land in exactly one split.
assert len(train_df) + len(val_df) == len(df), "train/val split lost or duplicated rows"

# Shuffle both splits (optional step) with a fixed seed so the row order is
# reproducible, and reset each to a clean 0..n-1 index.
val_df, train_df = (
    frame.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)
    for frame in (val_df, train_df)
)

# ---- Save ----
# Write val first, then train; the index is not stored in the files.
for frame, out_path in ((val_df, VAL_OUT), (train_df, TRAIN_OUT)):
    frame.to_parquet(out_path, engine="fastparquet", index=False)

# ---- Summary ----
# Overall sizes, then per-task row counts for each split, then output paths.
print(f"Total: {len(df):,}  ->  train: {len(train_df):,},  val: {len(val_df):,}")
for header, frame in (("\n[Train per task]", train_df), ("\n[Val per task]", val_df)):
    print(header)
    print(frame["task"].value_counts().sort_index())
print(f"\nSaved:\n  {TRAIN_OUT}\n  {VAL_OUT}")


Total: 44,672  ->  train: 42,438,  val: 2,234

[Train per task]
task
captioning        9500
math_reasoning    7099
summarization     9500
text_qa           6839
vqa               9500
Name: count, dtype: int64

[Val per task]
task
captioning        500
math_reasoning    374
summarization     500
text_qa           360
vqa               500
Name: count, dtype: int64

Saved:
  /storage/lyh/2025_deep_challenge/data/train.parquet
  /storage/lyh/2025_deep_challenge/data/val.parquet


In [2]:
# Peek at the first rows of the shuffled training split.
train_df.head(n=5)

Unnamed: 0,input_type,task,input,output,question
0,text,summarization,SECTION 1. SHORT TITLE.\n\n This Act may be...,Prohibits Federal agencies from offsetting fun...,
1,text,text_qa,"Guyana (pronounced or ), officially the Co-ope...",{'input_text': ['Anglo Caribbean countries and...,"['What does CARICOM stand for?', 'What is a so..."
2,image,captioning,https://pulpcovers.com/wp-content/uploads/2014...,"\nThe image is the cover of a book titled ""Mar...",
3,image,captioning,https://pulpcovers.com/wp-content/uploads/2011...,\nThe image is a colorful illustration of a ma...,
4,text,math_reasoning,Megan is making food for a party. She has to ...,"First, we need to determine the total amount o...",
