filters `coteach_visuals_with_categories.csv` dataset to create a test set containing the specific geometric diagram types:
- 2D shapes: triangle, rectangle, circle
- 3D shapes: rectangular prism, cube

In [1]:
import pandas as pd

In [2]:
INPUT_CSV = "../data/coteach_visuals_with_categories.csv"


OUTPUT_CSV = "../data/geometric_shapes_test_set.csv"

df = pd.read_csv(INPUT_CSV)

print(f"Original dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
df.head()

Original dataset shape: (6000, 5)
Columns: ['prompt', 'tikz', 'image', 'main_category', 'subcategory']


Unnamed: 0,prompt,tikz,image,main_category,subcategory
0,"coordinate plane, -5 to 5 on both axes, plot p...",\documentclass{IM}\n\usepackage{tikz}\n\begin{...,https://2xavun1dsa0sayar.public.blob.vercel-st...,coordinate plane,
1,"coordinate plane, -5 to 5 on both axes, plot p...",\documentclass{IM}\n\usepackage{tikz}\n\begin{...,https://2xavun1dsa0sayar.public.blob.vercel-st...,coordinate plane,
2,"histogram, study hours with bars: 0-2 hours (5...",\documentclass{IM}\n\usepackage{tikz}\n\begin{...,https://2xavun1dsa0sayar.public.blob.vercel-st...,statistics,histogram
3,"histogram, study hours with bars: 0-2 hours (5...",\documentclass{IM}\n\usepackage{tikz}\n\begin{...,https://2xavun1dsa0sayar.public.blob.vercel-st...,statistics,histogram
4,"box plot, test scores with min=65, Q1=75, medi...",\documentclass{IM}\n\usepackage{tikz}\n\begin{...,https://2xavun1dsa0sayar.public.blob.vercel-st...,statistics,box plot


In [3]:
print(df['main_category'].value_counts())

print("\nSubcategory distribution (top 20):")
print(df['subcategory'].value_counts().head(20))

main_category
2d shapes                                                                                                                                                                            1346
3d shapes                                                                                                                                                                            1120
coordinate plane                                                                                                                                                                      574
statistics                                                                                                                                                                            490
number line                                                                                                                                                                           296
                                                        

In [4]:
TARGET_SHAPES = [
    'triangle',
    'rectangle', 
    'circle',
    'rectangular prism',
    'cube'
]

In [5]:
filtered_df = df[df['subcategory'].isin(TARGET_SHAPES)].copy()

print(f"Filtered dataset shape: {filtered_df.shape}")

Filtered dataset shape: (1739, 5)


In [6]:
def get_dataset_makeup(df):
    print(f"Dataset shape: {df.shape}")
    
    shape_counts = df['subcategory'].value_counts()
    for shape, count in shape_counts.items():
        print(f"  {shape:18}: {count:3d} rows")
    print(f"Total: {shape_counts.sum()} rows")

    main_cat_counts = df['main_category'].value_counts()
    for cat, count in main_cat_counts.items():
        print(f"  {cat}: {count} rows")

In [7]:
get_dataset_makeup(filtered_df)

Dataset shape: (1739, 5)
  rectangular prism : 680 rows
  triangle          : 471 rows
  rectangle         : 322 rows
  cube              : 152 rows
  circle            : 114 rows
Total: 1739 rows
  2d shapes: 907 rows
  3d shapes: 832 rows


In [8]:
# Deduplicate
deduplicated_df = filtered_df.drop_duplicates(subset=['prompt', 'tikz'])

get_dataset_makeup(deduplicated_df)
filtered_df = deduplicated_df

Dataset shape: (870, 5)
  rectangular prism : 340 rows
  triangle          : 236 rows
  rectangle         : 161 rows
  cube              :  76 rows
  circle            :  57 rows
Total: 870 rows
  2d shapes: 454 rows
  3d shapes: 416 rows


In [9]:
def sample(df, target_size=400):
      sample_fraction = target_size / len(df)
      sampled_dfs = []

      for shape in TARGET_SHAPES:
          shape_df = df[df['subcategory'] == shape]
          n_sample = int(len(shape_df) * sample_fraction)
          if n_sample > 0:
              sampled_dfs.append(shape_df.sample(n=n_sample,
  random_state=42))

      return pd.concat(sampled_dfs, ignore_index=True)

In [10]:
# randomly sample 400 via stratified sampling
rater_study_df = sample(filtered_df, target_size=400)

# Show final distribution
print("\nFinal rater study distribution:")
get_dataset_makeup(rater_study_df)


Final rater study distribution:
Dataset shape: (398, 5)
  rectangular prism : 156 rows
  triangle          : 108 rows
  rectangle         :  74 rows
  cube              :  34 rows
  circle            :  26 rows
Total: 398 rows
  2d shapes: 208 rows
  3d shapes: 190 rows


In [12]:
rater_study_df.to_csv(OUTPUT_CSV, index=False)
print(f"Filtered dataset exported to: {OUTPUT_CSV}")

Filtered dataset exported to: ../data/geometric_shapes_test_set.csv
