# Get metadata

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

def visualize_sent_len(df: pd.DataFrame):
    """Plot a histogram of the values in the ``sent_len`` column.

    One bar is drawn per integer value, and the figure is shown immediately.

    Args:
        df: Frame with a ``sent_len`` column; the values are expected to be
            integer word counts, one per example.
    """
    sequence_lengths = df["sent_len"].tolist()
    plt.figure(figsize=(10, 6))
    # discrete=True gives unit-width bins centred on each integer value.
    # Using the set of observed values as bin edges is fragile: a set has no
    # guaranteed order (bin edges must increase monotonically), missing
    # lengths produce uneven bins, and the two largest lengths end up sharing
    # the final bin.
    sns.histplot(sequence_lengths, discrete=True, kde=False)
    plt.xlabel('Number of Words in an Example')
    plt.ylabel('Frequency')
    plt.show()

# The data directory is resolved relative to the current working directory
# (it is expected to sit next to the directory the notebook runs from).
data_dir = Path.cwd().parent / "data"
metadata_df = pd.read_parquet(data_dir / "processed/vasr/metadata.parquet")
# Force "id" and "shard" to strings instead of letting read_json infer them
# (presumably so they match the key dtypes in the metadata -- confirm).
mapping_df = pd.read_json(data_dir / "processed/vasr/mapping.json", dtype={"id": str, "shard": str})
df = metadata_df.merge(mapping_df, on=["id", "shard", "split"])
# Take an explicit copy of the filtered rows: assigning a new column to a
# slice of `df` triggers SettingWithCopyWarning.
train_df = df[df["split"] == "train"].copy()
# Whitespace-separated token count of each transcript.
train_df["sent_len"] = train_df["transcript"].apply(lambda x: len(x.split()))
len(train_df)

In [None]:
# Sentence-length distribution of the train split.
visualize_sent_len(train_df)

# Stratification

In [None]:
# Keep the same fraction of rows from every channel so the subset preserves
# the per-channel proportions of the data it is drawn from.
frac = 0.1
# Fixed seed so the subset (and the histogram drawn from it) is reproducible
# across notebook runs.
sample_seed = 42
# GroupBy.sample draws `frac` of each group directly; it replaces the
# hand-rolled groupby(...).apply(lambda x: x.sample(...)) and keeps the
# original row index.
train_df = train_df.groupby("channel").sample(frac=frac, random_state=sample_seed)
len(train_df)

In [None]:
# Sentence-length distribution after the per-channel subsampling.
visualize_sent_len(train_df)