In [8]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------------------------
# Column names this cell expects to find in `df`
# ---------------------------------------------------------
LABEL_COL = "subject_understood"   # change if the label column is named differently
VIDEO_COL = "video_id"

# `df` must already exist in the session; if it has not been loaded yet:
# df = pd.read_csv("data/raw/EEG_data.csv")

# Directory the figures are written to (no error if it already exists).
os.makedirs("figures", exist_ok=True)

# ---------------------------------------------------------
# Summary statistics consumed by the plots
# ---------------------------------------------------------
# Number of samples per label value, ordered by label value, and the same
# numbers expressed as a percentage of all counted samples.
label_counts = df[LABEL_COL].value_counts().sort_index()
label_percent = label_counts.div(label_counts.sum()).mul(100)

# Mean label per video as a two-column frame ordered by video id.
# NOTE(review): this mean is "fraction with label = 1" only if the labels
# are coded 0/1 — confirm against the data.
mean_label_by_video = df.groupby(VIDEO_COL)[LABEL_COL].mean()
video_stats = mean_label_by_video.reset_index().sort_values(VIDEO_COL)

# ---------------------------------------------------------
# One figure, two panels: (a) global, (b) per video
# ---------------------------------------------------------
# Reads `label_counts`, `label_percent` and `video_stats` computed earlier
# in this cell and writes a single PNG with both panels.

# Single source of truth for the output file, so the directory that gets
# created, the file that gets written and the message that gets printed
# cannot drift apart.
FIG_PATH = "figures/eda_labels_combined.png"
os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(6.0, 2.5))  # wide figure to fit two panels

try:
    # ---------- (a) Overall label distribution ----------
    ax = axes[0]

    bars = ax.bar(label_counts.index.astype(str), label_counts.values)

    ax.set_xlabel("Label (understanding)")
    ax.set_ylabel("Number of samples")

    # Sample count and percentage written in the middle of each bar.
    ax.bar_label(
        bars,
        labels=[
            f"n={count}\n({pct:.1f}%)"
            for count, pct in zip(label_counts.values, label_percent.values)
        ],
        label_type="center",
        fontsize=8,
        color="white",
    )

    ax.set_title("(a)", loc="left", fontsize=9)

    # ---------- (b) Fraction of label = 1 by video ----------
    ax = axes[1]

    # Video ids are converted to strings so every video gets its own
    # evenly spaced categorical bar regardless of the id's numeric value.
    x_labels = video_stats[VIDEO_COL].astype(str).values
    fractions = video_stats[LABEL_COL].values

    ax.bar(x_labels, fractions)

    ax.set_xlabel("Video ID")
    ax.set_ylabel("Fraction with label = 1")
    # Small margins around [0, 1] so bars at exactly 0 or 1 stay visible.
    ax.set_ylim(-0.02, 1.05)
    ax.tick_params(axis="x", labelsize=8)

    ax.set_title("(b)", loc="left", fontsize=9)

    fig.tight_layout()
    fig.savefig(FIG_PATH, dpi=300, bbox_inches="tight")
finally:
    # Always release the figure, even if drawing or saving failed, so it
    # does not linger in pyplot's figure registry.
    plt.close(fig)

print(f"Saved: {FIG_PATH}")


Saved: figures/eda_labels_combined.png
