In [None]:
import os
import sys
# Add the project root directory to sys.path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_root)

import torch
import numpy as np
import random
import torch
import json

import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from src.prepare_data import load_dataset
from src.token import SpecialToken


%matplotlib inline

In [None]:
# Location of the prepared dataset; the relative path is resolved against the
# kernel's current working directory.
dataset_file = "../intermediate_data/prepared_dataset_using_arc.pth"

# load_dataset returns a 3-tuple; the third element is discarded here.
dataset, data_sources, _ = load_dataset(dataset_file)
print(f'dataset length: {len(dataset)}')

In [None]:
from tqdm.notebook import tqdm

def draw_histogram(seq_lengths, num_bins=20, xlabel='Sequence Length',
                   title='Histogram of Sequence Lengths'):
    """Plot a relative-frequency histogram with decile markers and print the deciles.

    Args:
        seq_lengths: Non-empty sized collection of numeric values to plot.
        num_bins: Number of histogram bins (passed to ``Axes.hist`` as ``bins``).
        xlabel: Label for the x axis. The default keeps the original
            sequence-length wording.
        title: Figure title. The default keeps the original wording.

    Returns:
        None. The figure is shown with ``plt.show()`` and the 10th..90th
        percentiles are printed to stdout.

    Raises:
        ValueError: If ``seq_lengths`` is empty (percentiles and the
            per-sample weights are undefined for zero samples).
    """
    sample_count = len(seq_lengths)
    if sample_count == 0:
        raise ValueError("draw_histogram needs at least one value to plot")

    percentile_levels = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    percentiles = np.percentile(seq_lengths, percentile_levels)

    # Explicit Axes instead of the implicit pyplot "current figure" state.
    _, ax = plt.subplots(figsize=(10, 6))

    # Equal weights that sum to 1 make each bar height the fraction of samples
    # falling in that bin rather than a raw count.
    sample_weights = np.ones(sample_count) / sample_count
    ax.hist(seq_lengths, bins=num_bins, weights=sample_weights, edgecolor='black')

    # Mark every decile with a dashed vertical line, labelled at the top edge.
    for level, value in zip(percentile_levels, percentiles):
        ax.axvline(value, color='r', linestyle='dashed', linewidth=1)
        ax.text(value, ax.get_ylim()[1], f'{level}th', rotation=90, va='top', ha='right')

    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    # Bar heights are fractions, so render the y ticks as percentages.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

    print("Percentiles:")
    for level, value in zip(percentile_levels, percentiles):
        print(f"{level}th percentile: {value:.2f}")

    plt.show()

# Collect one length value per item yielded by `dataloader`, then plot them.
seq_lengths = []

print('tqdm starts here')

# NOTE(review): `dataloader` is not defined in any cell visible in this notebook
# (only the `DataLoader` class is imported), so on a fresh kernel this loop
# would fail with NameError. TODO: add the cell that builds it, or confirm
# where it is expected to come from.
for index, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc="Processing batches"):
    # Records len(batch['task']) once per batch; presumably this is the token
    # length of a single task -- verify against the dataloader's batch layout.
    seq_lengths.extend([len(batch['task'])])

max_length = max(seq_lengths)
print("max_length:", max_length)

# Draw histogram with proper binning
draw_histogram(seq_lengths)

In [None]:
def extract_grid_sizes(sequence, sample_index):
    """Collect the height/width header tokens of every grid in a token sequence.

    The sequence is read as back-to-back grid records. Each record is a start
    marker (``SpecialToken.START_INPUT`` or ``SpecialToken.START_OUTPUT``),
    followed by two tokens taken as height and width, followed by
    ``height * width`` tokens that are skipped.

    A final record whose body is shorter than ``height * width`` is not
    detected: the index simply moves past the end and the loop stops.

    Args:
        sequence: Token sequence supporting ``len()``, indexing and slicing.
        sample_index: Index of the sample this sequence belongs to; used only
            to make error messages traceable.

    Returns:
        A flat list ``[h0, w0, h1, w1, ...]`` with the header tokens exactly
        as stored in ``sequence`` (they are not converted to ``int``).

    Raises:
        AssertionError: If the token at a record boundary is not a grid start
            marker.
        ValueError: If a start marker is not followed by two header tokens.
    """
    grid_edge_sizes = []

    token_index = 0
    while token_index < len(sequence):
        token = sequence[token_index]
        if token == SpecialToken.START_INPUT.value or token == SpecialToken.START_OUTPUT.value:
            header = sequence[token_index + 1:token_index + 3]
            if len(header) != 2:
                raise ValueError(
                    f"sample {sample_index}: grid header truncated after token index {token_index}"
                )
            h, w = header
            # Skip the marker, the two header tokens and the h*w grid cells.
            token_index += 1 + 2 + int(h) * int(w)
            grid_edge_sizes.extend([h, w])
        else:
            # Raised explicitly rather than via `assert False`: assert
            # statements are removed under `python -O`, and since this branch
            # never advances token_index the loop would then never terminate.
            raise AssertionError(
                f"sample {sample_index}: unexpected token {token!r} at index {token_index}; "
                "expected a START_INPUT or START_OUTPUT marker"
            )

    return grid_edge_sizes

def draw_histogram(grid_sizes, num_bins=20, xlabel='Grid Size',
                   title='Histogram of Grid Sizes'):
    """Plot a relative-frequency histogram with decile markers and print the deciles.

    Args:
        grid_sizes: Non-empty sized collection of numeric values to plot.
        num_bins: Number of histogram bins (passed to ``Axes.hist`` as ``bins``).
        xlabel: Label for the x axis. The default keeps the original
            grid-size wording.
        title: Figure title. The default keeps the original wording.

    Returns:
        None. The figure is shown with ``plt.show()`` and the 10th..90th
        percentiles are printed to stdout.

    Raises:
        ValueError: If ``grid_sizes`` is empty (percentiles and the
            per-sample weights are undefined for zero samples).
    """
    sample_count = len(grid_sizes)
    if sample_count == 0:
        raise ValueError("draw_histogram needs at least one value to plot")

    percentile_levels = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    percentiles = np.percentile(grid_sizes, percentile_levels)

    # Explicit Axes instead of the implicit pyplot "current figure" state.
    _, ax = plt.subplots(figsize=(10, 6))

    # Equal weights that sum to 1 make each bar height the fraction of samples
    # falling in that bin rather than a raw count.
    sample_weights = np.ones(sample_count) / sample_count
    ax.hist(grid_sizes, bins=num_bins, weights=sample_weights, edgecolor='black')

    # Mark every decile with a dashed vertical line, labelled at the top edge.
    for level, value in zip(percentile_levels, percentiles):
        ax.axvline(value, color='r', linestyle='dashed', linewidth=1)
        ax.text(value, ax.get_ylim()[1], f'{level}th', rotation=90, va='top', ha='right')

    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    # Bar heights are fractions, so render the y ticks as percentages.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

    print("Percentiles:")
    for level, value in zip(percentile_levels, percentiles):
        print(f"{level}th percentile: {value:.2f}")

    plt.show()
    
# Load the dataset again so this cell does not rely on the earlier load cell.
dataset, data_sources, _ = load_dataset(dataset_file)

# Flatten the height/width values of every grid in every sample into one list.
grid_sizes = []
for sample_index, sample in enumerate(dataset.data):
    grid_sizes.extend(extract_grid_sizes(sample, sample_index))

# Plot the distribution of all collected edge sizes.
draw_histogram(grid_sizes)

print('min, max:', min(grid_sizes), max(grid_sizes))