In [143]:
from manim import *
from scipy.stats import norm
import numpy as np
import random
import math
import pandas as pd

def gen_dataset(n: int = 100, seed: "int | None" = None) -> pd.DataFrame:
    """Generate a synthetic two-group dataset with true and perturbed labels.

    Every row gets:

    * ``grp``      -- "A" or "B", drawn with weights 0.3 / 0.7.
    * ``x``, ``y`` -- standard-normal features obtained by inverse-CDF
      sampling (``norm.ppf`` applied to uniform draws).
    * ``x_noise``, ``y_noise`` -- normal perturbations scaled by 0.2.
    * ``label``    -- 1 where ``x - y < 1``, otherwise -1.
    * ``sep_x``, ``sep_y`` -- the features with their noise added.
    * ``y_hat``    -- 1 where ``sep_x - sep_y < 1``, otherwise -1.

    Parameters
    ----------
    n:
        Number of rows to generate. ``n == 0`` yields an empty frame with
        the full set of columns.
    seed:
        When given, the draws come from private generators seeded with this
        value, so the result is reproducible and the global ``random`` /
        ``np.random`` state is left untouched. When ``None`` (the default)
        the global generators are used, in the order: groups, x, y,
        x_noise, y_noise.

    Returns
    -------
    pd.DataFrame
        One row per sample with the columns described above.
    """
    groups = ["A", "B"]

    # Pick the random sources: module-level state by default, private
    # generators exposing the same API when a seed is supplied.
    py_rng = random if seed is None else random.Random(seed)
    np_rng = np.random if seed is None else np.random.RandomState(seed)

    def std_normal(size: int) -> np.ndarray:
        # rand() samples [0, 1), so 0.0 is possible and ppf(0) is -inf;
        # clamp the uniforms from below before inverting the CDF.
        uniforms = np_rng.rand(size)
        return norm.ppf(np.maximum(1e-9, uniforms))

    grps = py_rng.choices(groups, weights=[0.3, 0.7], k=n)
    z_xs = std_normal(n)
    z_ys = std_normal(n)

    x_noise = 0.2 * std_normal(n)
    y_noise = 0.2 * std_normal(n)

    df = pd.DataFrame({"x": z_xs, "y": z_ys, "grp": grps, "x_noise": x_noise, "y_noise": y_noise})

    df['label'] = np.where(df['x'] - df['y'] < 1, 1, -1)

    df['sep_x'] = df['x'] + df['x_noise']
    df['sep_y'] = df['y'] + df['y_noise']

    # Disabled experiment kept from the original: a fixed offset for group A
    # instead of the noisy positions.
    # df.loc[df['grp'] == 'A', 'sep_x'] = df['x'] + 1
    # df.loc[df['grp'] == 'A', 'sep_y'] = df['y'] - .5

    df['y_hat'] = np.where(df['sep_x'] - df['sep_y'] < 1, 1, -1)

    return df


# gen_dataset() draws from both the stdlib `random` module and `np.random`.
# Seed both so the counts printed below and every scene rendered from `data`
# are the same after "Restart Kernel -> Run All".
# NOTE(review): any previously captured output was produced without a seed
# and will differ from what this cell now prints.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

data = gen_dataset()

# Sanity check: how many samples fall in each (group, label) combination.
print(data.groupby(by=["grp", "label"]).size())

grp  label
A    -1        7
      1       18
B    -1       15
      1       60
dtype: int64


In [146]:
%%manim -qm -v WARNING Introduction

# Buckets of Dot mobjects keyed by group; filled while the scene is built.
dots_grp = {group: [] for group in ("A", "B")}

# Buckets of the same Dot mobjects keyed by label (-1 / 1).
label_grp = {label: [] for label in (-1, 1)}

# Colour used for the dots and captions of each group.
grp_map = dict(A=RED, B=BLUE)

# Colour used for each label value.
labels_map = {-1: TEAL_A, 1: ORANGE}

class Introduction(Scene):
    """Intro animation: groups, labels, per-group counts, independence card.

    Reads the module-level ``data`` frame (columns ``x``, ``y``, ``grp``,
    ``label``) and the colour maps ``grp_map`` / ``labels_map``. While
    building the scene it fills the module-level ``dots_grp`` and
    ``label_grp`` buckets with the created ``Dot`` mobjects.
    """

    def construct(self):
        # Start from empty buckets so rendering this scene again in the same
        # kernel does not stack a second copy of every dot.
        for bucket in (*dots_grp.values(), *label_grp.values()):
            bucket.clear()

        ax = Axes(
            x_range=[-5, 5, 1],
            y_range=[-5, 5, 1],
            x_length=6.5,
            y_length=6.5,
            tips=False,
            axis_config={"include_numbers": False},
        )

        # One dot per sample, initially coloured by group; each dot is filed
        # under both its group and its label for later recolouring.
        for _, row in data.iterrows():
            x, y, g, lbl = row['x'], row['y'], row['grp'], row['label']
            d = Dot(ax.c2p(x, y, 0), radius=0.04, color=grp_map[g])
            dots_grp[g].append(d)
            label_grp[lbl].append(d)

        group_a_tex = Tex(r"$\mathbb{A}$", font_size=80, color=grp_map["A"]).to_edge(UR)
        group_b_tex = Tex(r"$\mathbb{B}$", font_size=80, color=grp_map["B"]).next_to(group_a_tex, DOWN)

        group_frameboxes = [
            SurroundingRectangle(group_a_tex, buff=.1),
            SurroundingRectangle(group_b_tex, buff=.1),
        ]

        positive_tex = Tex(r"Label:  1", color=labels_map[1], font_size=60).to_edge(UL)
        negative_tex = Tex(r"Label: -1", color=labels_map[-1], font_size=60).next_to(positive_tex, DOWN)

        label_frameboxes = [
            SurroundingRectangle(positive_tex, buff=.1),
            SurroundingRectangle(negative_tex, buff=.1),
        ]

        # The line y = x - 1, i.e. the boundary of the rule `x - y < 1`.
        graph = ax.plot(lambda x: x - 1, x_range=[-5, 5], use_smoothing=False, color=YELLOW)
        graph_label = ax.get_graph_label(graph, r"\text{Decision Boundary}", x_val=5, direction=UR)

        # Counters: overall per label, then per (label, group).
        label_neg_count = DecimalNumber(number=0, num_decimal_places=1, color=labels_map[-1]).next_to(graph, UL)
        label_pos_count = DecimalNumber(number=0, num_decimal_places=1, color=labels_map[1]).next_to(label_neg_count, DOWN)

        label_neg_a_count = DecimalNumber(number=0, num_decimal_places=1, color=grp_map["A"]).next_to(label_neg_count, LEFT, buff=0.75)
        label_pos_a_count = DecimalNumber(number=0, num_decimal_places=1, color=grp_map["A"]).next_to(label_pos_count, LEFT, buff=0.75)

        label_neg_b_count = DecimalNumber(number=0, num_decimal_places=1, color=grp_map["B"]).next_to(label_neg_a_count, LEFT, buff=0.75)
        label_pos_b_count = DecimalNumber(number=0, num_decimal_places=1, color=grp_map["B"]).next_to(label_pos_a_count, LEFT, buff=0.75)

        def positive_rate(group):
            # Share of the group's rows whose `label` column equals 1.
            in_group = data["grp"] == group
            return ((data["label"] == 1) & in_group).sum() / in_group.sum()

        pos_given_a_c = positive_rate("A")
        pos_given_a = Tex(
            r"P$(\widehat{Y} = 1 | G =$", r"$\mathbb{A}$", f"$) = {pos_given_a_c: 0.2f}$",
            font_size=32,
        ).next_to(graph, UR)
        pos_given_a.set_color_by_tex("A", grp_map["A"])

        pos_given_b_c = positive_rate("B")
        pos_given_b = Tex(
            r"P$(\widehat{Y} = 1 | G =$", r"$\mathbb{B}$", f"$) = {pos_given_b_c: 0.2f}$",
            font_size=32,
        ).next_to(pos_given_a, DOWN)
        pos_given_b.set_color_by_tex("B", grp_map["B"])

        # introduce the A and B groups
        self.play(Write(VGroup(ax, group_a_tex, group_b_tex)), run_time=0.5)
        self.wait(4)

        self.play(Write(group_frameboxes[0]))
        self.play(Create(VGroup(*dots_grp["A"])))
        self.wait()

        self.play(ReplacementTransform(group_frameboxes[0], group_frameboxes[1]), Create(VGroup(*dots_grp["B"])))
        self.wait()

        self.play(FadeOut(group_frameboxes[1], group_a_tex, group_b_tex))

        # introduce classifier labels
        self.play(Write(VGroup(positive_tex, negative_tex)))
        self.play(Write(label_frameboxes[0]))
        self.play(FadeToColor(VGroup(*label_grp[1]), color=labels_map[1]))

        self.play(
            ReplacementTransform(label_frameboxes[0], label_frameboxes[1]),
            FadeToColor(VGroup(*label_grp[-1]), color=labels_map[-1]),
        )
        self.play(FadeOut(label_frameboxes[1], positive_tex, negative_tex))
        self.wait(2)

        self.play(Create(graph))
        self.play(Write(graph_label))

        self.wait()

        self.play(FadeOut(graph_label, graph))
        self.play(
            Write(label_neg_count),
            ChangeDecimalToValue(label_neg_count, data[data["label"] == -1].shape[0]),
            Write(label_pos_count),
            ChangeDecimalToValue(label_pos_count, data[data["label"] == 1].shape[0]),
            rate_func=linear,
            run_time=0.25,
        )

        self.wait(3)

        self.play(
            Write(label_neg_a_count),
            ChangeDecimalToValue(label_neg_a_count, data[(data["label"] == -1) & (data["grp"] == "A")].shape[0]),
            Write(label_pos_a_count),
            ChangeDecimalToValue(label_pos_a_count, data[(data["label"] == 1) & (data["grp"] == "A")].shape[0]),
            Write(label_neg_b_count),
            ChangeDecimalToValue(label_neg_b_count, data[(data["label"] == -1) & (data["grp"] == "B")].shape[0]),
            Write(label_pos_b_count),
            ChangeDecimalToValue(label_pos_b_count, data[(data["label"] == 1) & (data["grp"] == "B")].shape[0]),
            rate_func=linear,
            run_time=0.25,
        )

        self.wait(1)

        self.play(Write(pos_given_a), Write(pos_given_b), rate_func=linear)

        self.wait(4)

        # Enlarge and centre the two probabilities while everything else fades.
        group = VGroup(pos_given_a, pos_given_b)
        self.play(
            group.animate.scale(2.5).center(),
            FadeOut(
                ax,
                label_neg_a_count,
                label_pos_a_count,
                label_neg_b_count,
                label_pos_b_count,
                label_neg_count,
                label_pos_count,
                *label_grp[-1],
                *label_grp[1],
            ),
            rate_func=smooth,
            run_time=1,
        )

        self.wait(1)

        title = Text('Criterion #1: Independence').scale(0.85).next_to(pos_given_b, DOWN)
        desc = Text(
            'The rate of positive classification is equal across groups.',
            t2w={'equal': BOLD},
            t2s={"rate of positive classification": ITALIC},
        ).next_to(title, DOWN).scale(0.55)
        self.play(
            group.animate.shift(UP),
            Write(VGroup(title, desc)),
            # Was `run_times=3`: an unrecognised keyword that left the
            # animation at its default duration.
            run_time=3,
        )

        self.wait(1)
      




                                                                                                                                      

In [144]:
%%manim -qm -v WARNING IndTable

# Buckets of Dot mobjects keyed by group; filled while the scene is built.
dots_grp = {group: [] for group in ("A", "B")}

# Buckets of Dot mobjects keyed by the `label` column (-1 / 1).
label_grp = {label: [] for label in (-1, 1)}

# Buckets of Dot mobjects keyed by the `y_hat` column (-1 / 1).
y_hats = {prediction: [] for prediction in (-1, 1)}

# Dots whose `label` and `y_hat` disagree.
mislabels = list()

class IndTable(Scene):
    """Animation of misclassifications and per-group error rates.

    Reads the module-level ``data`` frame (columns ``x``, ``y``, ``sep_x``,
    ``sep_y``, ``grp``, ``label``, ``y_hat``) and the colour maps
    ``grp_map`` / ``labels_map``. While building the scene it fills the
    module-level ``dots_grp``, ``label_grp``, ``y_hats`` and ``mislabels``
    containers with the created ``Dot`` mobjects.
    """

    def construct(self):
        # Start from empty containers so rendering this scene again in the
        # same kernel does not double-count dots or misclassifications.
        for bucket in (*dots_grp.values(), *label_grp.values(), *y_hats.values(), mislabels):
            bucket.clear()

        ax = Axes(
            x_range=[-5, 5, 1],
            y_range=[-5, 5, 1],
            x_length=6.5,
            y_length=6.5,
            tips=False,
            axis_config={"include_numbers": False},
        )

        shift_animations = []
        all_dots = []
        for _, row in data.iterrows():
            x, y, g, lbl = row['x'], row['y'], row['grp'], row['label']
            y_hat = row['y_hat']

            d = Dot(ax.c2p(x, y, 0), radius=0.04, color=grp_map[g])
            all_dots.append(d)
            # Each dot later slides from (x, y) to its (sep_x, sep_y) position.
            shift_animations.append(d.animate.move_to(ax.c2p(row['sep_x'], row['sep_y'], 0)))
            dots_grp[g].append(d)
            label_grp[lbl].append(d)

            if lbl != y_hat:
                mislabels.append(d)

            y_hats[y_hat].append(d)

        is_error = data["y_hat"] != data["label"]

        def group_errors(group):
            # (number of misclassified rows, error rate in percent) for a group.
            in_group = data["grp"] == group
            count = (is_error & in_group).sum()
            return count, count * 100 / in_group.sum()

        a_error_cnt, a_error_pct = group_errors("A")
        b_error_cnt, b_error_pct = group_errors("B")

        error_title = Text(f"Misclassifications: {len(mislabels)}", color=YELLOW).scale(0.75).to_edge(UL)

        a_errors = Tex(
            r"$\mathbb{A}$:", f"  {a_error_cnt}"
        ).next_to(error_title, DOWN).set_color_by_tex("A", RED)

        b_errors = Tex(
            r"$\mathbb{B}$:", f"  {b_error_cnt}"
        ).next_to(a_errors, DOWN).set_color_by_tex("B", BLUE)

        rate_title = Text(f"Error Rate: {len(mislabels) * 100 / data.shape[0]: .2f}%", color=YELLOW).scale(0.75).to_edge(UR)
        a_err_rt = Tex(
            r"$\mathbb{A}$:", rf"  {a_error_pct:.2f}\%"
        ).next_to(rate_title, DOWN).set_color_by_tex("A", RED)

        b_err_rt = Tex(
            r"$\mathbb{B}$:", rf"  {b_error_pct:.2f}\%"
        ).next_to(a_err_rt, DOWN).set_color_by_tex("B", BLUE)

        self.add(ax, *all_dots)

        self.wait(2)

        # shift
        self.play(*shift_animations)

        # then recolor to just ground truth label
        self.wait(2)
        self.play(
            FadeToColor(VGroup(*label_grp[-1]), color=labels_map[-1]),
            FadeToColor(VGroup(*label_grp[1]), color=labels_map[1]),
        )

        # highlight the misclassified dots and show the error counts
        self.wait(2)
        self.play(
            FadeToColor(VGroup(*mislabels), color=YELLOW),
            Write(error_title),
            Write(a_errors),
            Write(b_errors),
        )
        self.wait(2)

        # back to label colours
        self.play(
            FadeToColor(VGroup(*label_grp[-1]), color=labels_map[-1]),
            FadeToColor(VGroup(*label_grp[1]), color=labels_map[1]),
        )

        self.wait(2)
        self.play(Write(rate_title), Write(a_err_rt), Write(b_err_rt))

        self.wait(2)

        # end point expressions: copies laid out side by side at the origin
        # serve as move targets for the two error-rate captions.
        target_group = VGroup(a_err_rt.copy(), b_err_rt.copy()).arrange(RIGHT, buff=0.5)
        target_group.move_to(ORIGIN)

        self.play(
            a_err_rt.animate.move_to(target_group[0]),
            b_err_rt.animate.move_to(target_group[1]),
            FadeOut(
                ax,
                rate_title,
                *dots_grp['A'],
                *dots_grp['B'],
                error_title,
                a_errors,
                b_errors,
            ),
            run_time=2,
        )

        title = Text('Criterion #2: Separation').scale(0.85).next_to(target_group, DOWN)
        desc = Text(
            'Error rates are equal across groups',
            t2w={'equal': BOLD},
            t2s={'Error rates': ITALIC},
        ).next_to(title, DOWN).scale(0.55)
        self.play(
            target_group.animate.shift(UP),
            a_err_rt.animate.shift(UP),
            b_err_rt.animate.shift(UP),
            Write(VGroup(title, desc)),
            # Was `run_times=3`: an unrecognised keyword that left the
            # animation at its default duration.
            run_time=3,
        )

        # TODO(review): the original ends with a bare "Show Errors" comment
        # and no code after it -- confirm whether a further step was planned.
            
            



                                                                                                                     