In [111]:
import numpy as np
from scipy.stats import poisson
from manim import *
config.media_width = "60%"

In [83]:
import typing as t

In [84]:
import cv2

In [85]:
# cv2.imread returns None (it does not raise) when the file is missing or unreadable,
# which would otherwise surface as an opaque "'NoneType' object is not subscriptable".
_nll_bgr = cv2.imread("nll1.png")
if _nll_bgr is None:
    raise FileNotFoundError("cv2.imread could not read 'nll1.png'")
# OpenCV loads colour images in BGR order; reversing the channel axis yields RGB,
# which is what the ImageMobject in Slides04Inference is given.
nll_image = _nll_bgr[:, :, ::-1]

In [137]:
%%manim -v WARNING -qm Slides04Inference

class Slides04Inference(Scene):
    """Manim slide: inference in action segmentation.

    ``construct`` presents the optimisation objective, overlays a candidate
    segmentation path on the image ``nll_image`` (4 rows, one per entry of
    ``y_labels``, and 27 columns labelled 1..27), then animates the E_o and
    E_l terms. Finally it morphs the segment lengths from ``length_values``
    to ``good_lengths``.
    """

    # Submobjects 0-15 shared by the two-line objectives of the long walkthrough:
    # "l = argmin { -log( p(l | x, c) ) }", then the line break and the second "&=".
    _NEG_LOG_POSTERIOR_LINE = (
        r"\ell_{1:N}",  # 0
        r" &= ",  # 1
        r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmin}}",  # 2
        r" \Big\{ ",  # 3
        r"-",  # 4
        r"\log",  # 5
        r"\bigg(",  # 6
        r"p(\hat{\ell}_{1:N} | ", "x_{1:T}", ", ", "c_{1:N}", ")",  # 7-11
        r"\bigg)",  # 12
        r" \Big\}",  # 13
        r"\\",  # 14
        r" &= ",  # 15
    )

    def create_y_labels(self, plane: NumberPlane):
        """Create one transcript label per row, to the left of ``plane``.

        Sets ``self.y_labels`` (transcript, read top to bottom), ``self.y_labels_color``
        (fill colour per action) and ``self.y_label_group`` (label mobjects, top to bottom).
        """
        self.y_labels = ["A", "C", "B", "A"]
        self.y_labels_color = {"A": PURPLE, "B": GOLD, "C": GREEN}

        n_rows = len(self.y_labels)
        # Segment ``i`` of the path is drawn on row ``n_rows - i - 0.5`` (see get_start_point),
        # so label ``i`` is placed on that same row. Reversing the list after positioning
        # (as before) did not move anything and left the B/C rows swapped.
        self.y_label_group = []
        for i, yl in enumerate(self.y_labels):
            text = Tex(yl).scale(0.5)
            text.next_to(plane.coords_to_point(0, n_rows - i - 0.5), LEFT * 0.5)
            self.y_label_group.append(text)

    def create_x_labels(self, plane: NumberPlane):
        """Create the column labels 1..27 below ``plane`` into ``self.x_label_group``."""
        x_labels = list(range(1, 28))  # one per plane column (x_range is 0..27)

        self.x_label_group = []
        for i, xl in enumerate(x_labels):
            text = Tex(str(xl)).scale(0.5)  # Tex expects TeX source strings
            text.next_to(plane.coords_to_point(i + 0.5, 0), DOWN * 0.5)
            self.x_label_group.append(text)

    def create_probabilities_and_plane(self):
        """Build the NLL image and a 27x4 grid of the same size on top of it.

        Sets ``self.nll``, ``self.plane`` and ``self.probabilities`` (both, shifted down).
        """
        self.nll = ImageMobject(nll_image, scale_to_resolution=700)
        # nearest-neighbour keeps the low-resolution cells crisp instead of blurred
        self.nll.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])

        self.plane = NumberPlane(
            x_range=[0, 27.01, 1],
            y_range=[0, 4.01, 1],
            x_length=self.nll.width,
            y_length=self.nll.height,
            axis_config={
                "stroke_color": BLUE,
                "stroke_width": 2,
            },
            background_line_style={
                "stroke_color": BLUE,
                "stroke_width": 2,
            }
        )

        self.probabilities = Group(self.nll, self.plane).shift(DOWN * 0.75)

    @staticmethod
    def _make_objective() -> MathTex:
        """``l = argmax { p(l | x, c) }``; index 2 = argmax, 5 = x_{1:T}, 7 = c_{1:N}."""
        return MathTex(
            r"\ell_{1:N}",  # 0
            r" = ",  # 1
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmax}}",  # 2
            r" \Big\{ ",  # 3
            r"p(\hat{\ell}_{1:N} | ", "x_{1:T}", ", ", "c_{1:N}", ")",  # 4-8
            r" \Big\}",  # 9
        )

    @staticmethod
    def _make_neg_log_objective() -> MathTex:
        """``l = argmin { -log( p(l | x, c) ) }`` on a single line."""
        return MathTex(
            r"\ell_{1:N}",
            r" &= ",
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmin}}",
            r" \Big\{ ",
            r"-",
            r"\log",
            r"\big(",
            r"p(\hat{\ell}_{1:N} | ", "x_{1:T}", ", ", "c_{1:N}", ")",
            r"\big)",
            r" \Big\}",
        )

    @staticmethod
    def _make_full_objective() -> MathTex:
        """Joint objective ``c, l = argmax_c argmax_l { p(l | x, c) }``."""
        return MathTex(
            r"c_{1:N}",
            r", ",
            r"\ell_{1:N}",
            r" = ",
            r"\underset{\hat{c}_{1:N}}{\mathrm{argmax}}",
            r" ~ ",
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmax}}",
            r" \Big\{ ",
            r"p(\hat{\ell}_{1:N} | ",
            r"x_{1:T}", ", ", r"\hat{c}_{1:N}", r") ",
            r"\Big\}",
        )

    def _make_product_objective(self, alpha_tex: str) -> MathTex:
        """Two-line objective whose second line is -log of a product over t and n.

        ``alpha_tex`` is the TeX of the alpha term (submobject 22).
        """
        return MathTex(
            *self._NEG_LOG_POSTERIOR_LINE,
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmin}}",  # 16
            r" \Big\{ ",  # 17
            r"-",  # 18
            r"\log",  # 19
            r"\bigg(",  # 20
            r"\prod_{t=1}^{T} p\big(",  # 21
            alpha_tex,  # 22
            r" | ",  # 23
            r"x_t",  # 24
            r"\big)",  # 25
            r" \cdot ",  # 26
            r"\prod_{n=1}^{N}",  # 27
            r" p\big( ",  # 28
            r"\ell_n",  # 29
            r" | ",  # 30
            r"c_n",  # 31
            r"\big)",  # 32
            r"\bigg)",  # 33
            r" \Big\}",  # 34
        )

    @staticmethod
    def _make_summed_objective(alpha_tex: str) -> MathTex:
        """Single-line objective ``argmin { sum_t -log p(alpha|x_t) + sum_n -log p(l_n|c_n) }``.

        Submobjects 4-11 form the sum over t (E_o), 13-20 the sum over n (E_l);
        ``alpha_tex`` is the alpha term (submobject 8).
        """
        return MathTex(
            r"\ell_{1:N}",  # 0
            r" &= ",  # 1
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmin}}",  # 2
            r" \Big\{ ",  # 3
            r"\sum_{t=1}^{T}",  # 4
            r"-",  # 5
            r"\log",  # 6
            r"p\big(",  # 7
            alpha_tex,  # 8
            r" | ",  # 9
            r"x_t",  # 10
            r"\big)",  # 11
            r" + ",  # 12
            r"\sum_{n=1}^{N}",  # 13
            r"-",  # 14
            r"\log",  # 15
            r" p\big( ",  # 16
            r"\ell_n",  # 17
            r" | ",  # 18
            r"c_n",  # 19
            r"\big)",  # 20
            r" \Big\}",  # 21
        )

    def _explain_and_generalize_objective(self):
        """Highlight x, c and argmax of ``self.the_objective``, morph it into
        ``self.full_objective`` (the c-terms shown in yellow), then bring the
        original objective back on screen.
        """
        # the_objective submobjects: 5 = x_{1:T}, 7 = c_{1:N}, 2 = argmax
        for part_index in (5, 7, 2):
            self.animate_part_of_math(self.the_objective[part_index])
            self.wait()

        self.play(TransformMatchingTex(self.the_objective, self.full_objective))
        self.full_objective.set_color_by_tex(r"c_{1:N}", YELLOW)
        self.full_objective.set_color_by_tex(r"\underset{\hat{c}_{1:N}}{\mathrm{argmax}}", YELLOW)
        self.wait()

        self.play(FadeOut(self.full_objective))
        self.play(FadeIn(self.the_objective))
        self.wait()

    def animate_optimization_problem(self):
        """Long walkthrough from the posterior argmax to the two-line summed -log objective.

        Not called by ``construct``, which uses ``animate_optimization_problem_short``.
        """
        self.the_objective = self._make_objective()
        objective_step_1 = self._make_neg_log_objective()
        objective_step_2 = self._make_product_objective(r"\alpha(t;c_{1:N}, \ell_{1:N})")
        objective_step_3 = self._make_product_objective(r"\alpha(t)")
        objective_step_4 = MathTex(
            *self._NEG_LOG_POSTERIOR_LINE,
            r"\underset{\hat{\ell}_{1:N}}{\mathrm{argmin}}",  # 16
            r" \Big\{ ",  # 17
            r"\bigg(",  # 18
            r"\sum_{t=1}^{T}",  # 19
            r"-",  # 20
            r"\log",  # 21
            r"p\big(",  # 22
            r"\alpha(t)",  # 23
            r" | ",  # 24
            r"x_t",  # 25
            r"\big)",  # 26
            r" + ",  # 27
            r"\sum_{n=1}^{N}",  # 28
            r"-",  # 29
            r"\log",  # 30
            r" p\big( ",  # 31
            r"\ell_n",  # 32
            r" | ",  # 33
            r"c_n",  # 34
            r"\big)",  # 35
            r"\bigg)",  # 36
            r" \Big\}",  # 37
        )
        self.full_objective = self._make_full_objective()

        self.play(Write(self.the_objective))
        self._explain_and_generalize_objective()

        self.play(TransformMatchingTex(self.the_objective, objective_step_1))
        self.wait()
        self.play(TransformMatchingTex(objective_step_1, objective_step_2))
        self.wait()

        # 22 = the alpha(t; c_{1:N}, l_{1:N}) term
        self.animate_part_of_math(objective_step_2[22])
        self.wait()

        # shorten the alpha
        self.play(TransformMatchingTex(objective_step_2, objective_step_3))
        self.wait()

        # 21-25: product over t; 27-32: product over n
        self.animate_part_of_math(objective_step_3[21:26])
        self.wait()
        self.animate_part_of_math(objective_step_3[27:33])
        self.wait()

        self.play(TransformMatchingTex(objective_step_3, objective_step_4))
        self.wait()

    def animate_optimization_problem_short(self):
        """Short walkthrough ending with ``self.objective_step_5`` moved under the title.

        Requires ``self.title``. Sets ``self.the_objective``, ``self.full_objective`` and
        ``self.objective_step_5`` (later used by the E_o / E_l animations).
        """
        self.the_objective = self._make_objective()
        objective_step_4 = self._make_summed_objective(r"\alpha(t;c_{1:N}, \ell_{1:N})")
        self.objective_step_5 = self._make_summed_objective(r"\alpha(t)")
        self.full_objective = self._make_full_objective()

        self.play(FadeIn(self.the_objective))
        self.wait()
        self._explain_and_generalize_objective()

        self.play(TransformMatchingTex(self.the_objective, objective_step_4.scale(0.9)))
        self.wait()

        self.animate_part_of_math(objective_step_4[4:12], r"E_o")
        self.wait()

        self.animate_part_of_math(objective_step_4[13:21], r"E_\ell")
        self.wait()

        self.play(TransformMatchingTex(objective_step_4, self.objective_step_5))

        self.play(self.objective_step_5.animate.move_to(self.title.get_bottom() + DOWN).scale(0.8))

    def animate_part_of_math(self, math_part, underline_text: t.Optional[str] = None):
        """Briefly highlight ``math_part``.

        Underlines it and turns it yellow, optionally writing ``underline_text``
        (TeX math) below the underline. After a wait, the annotations are removed
        and ``math_part`` is reset to WHITE.
        """
        prev_color = WHITE
        underline = Underline(math_part, buff=0.4)
        animations = [Create(underline)]
        label = None
        if underline_text:
            label = MathTex(underline_text).move_to(underline.get_bottom() + DOWN * 0.5)
            animations.append(Write(label))
        animations.append(math_part.animate.set_color(YELLOW))

        self.play(AnimationGroup(*animations))
        self.wait()
        self.remove(underline)
        if label is not None:
            self.remove(label)
        math_part.set_color(prev_color)

    def animate_probabilities_and_labels(self):
        """Fade in the image+grid, then the row and column labels."""
        self.create_y_labels(self.plane)
        self.create_x_labels(self.plane)
        self.play(FadeIn(self.probabilities))
        self.play(
            AnimationGroup(
                FadeIn(VGroup(*self.y_label_group), shift=RIGHT),
                FadeIn(VGroup(*self.x_label_group), shift=UP)
            )
        )
        self.wait()

    def show_the_path_start_and_end(self):
        """Mark the top-left and bottom-right grid cells with red dots and flash them."""
        self.start_point = self.plane.coords_to_point(0.5, 3.5)
        self.end_point = self.plane.coords_to_point(26.5, 0.5)

        self.start_point_dot = Dot(self.start_point, color=RED)
        self.end_point_dot = Dot(self.end_point, color=RED)

        self.add(self.start_point_dot, self.end_point_dot)
        self.wait(0.5)
        self.play(Flash(self.start_point_dot))
        self.wait(0.5)
        self.play(Flash(self.end_point_dot))

    def get_cumsum_value(self, index: int) -> float:
        """Sum of the current segment lengths ``length_values[0..index]`` (inclusive)."""
        return sum(v.get_value() for v in self.length_values[: index + 1])

    def get_start_point(self, index: int) -> np.ndarray:
        """Scene point at the centre of the first cell of segment ``index``."""
        y_value = len(self.length_values) - index - 0.5
        x_value = self.get_cumsum_value(index) - self.length_values[index].get_value() + 0.5

        return self.plane.coords_to_point(x_value, y_value)

    def get_end_point(self, index: int) -> np.ndarray:
        """Scene point at the centre of the last cell of segment ``index``."""
        y_value = len(self.length_values) - index - 0.5
        x_value = self.get_cumsum_value(index) - 0.5

        return self.plane.coords_to_point(x_value, y_value)

    def create_line(self, index: int) -> Line:
        """Thick red line over segment ``index``, rebuilt every frame from ``length_values``."""
        return always_redraw(
            lambda: Line(self.get_start_point(index), self.get_end_point(index), color=RED, stroke_width=5)
        )

    def create_transition_line(self, index: int) -> Line:
        """Thin red line from the end of segment ``index`` to the start of segment ``index + 1``."""
        return always_redraw(
            lambda: Line(self.get_end_point(index), self.get_start_point(index + 1), color=RED)
        )

    def draw_initial_lines(self):
        """Draw the path segment by segment, each followed by its transition line."""
        self.lines = [self.create_line(x) for x in range(len(self.length_values))]
        self.transition_lines = [self.create_transition_line(x) for x in range(len(self.length_values) - 1)]

        for ind, li in enumerate(self.lines):
            self.add(li)
            self.play(Create(li))
            if ind < len(self.transition_lines):
                tli = self.transition_lines[ind]
                self.add(tli)
                self.play(Create(tli), run_time=0.25)
        self.wait()

    def _segment_x_points(self, index: int) -> t.Tuple[np.ndarray, np.ndarray]:
        """Scene points (at plane y=0) of the left and right edges of segment ``index``."""
        length = self.length_values[index].get_value()
        start_coords = self.get_cumsum_value(index) - length
        end_coords = start_coords + length
        return self.plane.coords_to_point(start_coords, 0), self.plane.coords_to_point(end_coords, 0)

    def get_rectangle_width(self, index: int) -> float:
        """Scene width spanned by segment ``index``."""
        start_point, end_point = self._segment_x_points(index)
        return end_point[0] - start_point[0]

    def get_rectangle_middle_point(self, index: int) -> t.Tuple[float, float, float]:
        """Centre of a 0.5-high bar sitting on the plane's x axis under segment ``index``."""
        start_point, end_point = self._segment_x_points(index)
        width = end_point[0] - start_point[0]
        height = 0.5
        return (start_point[0] + width / 2, start_point[1] + height / 2, 0)

    def _build_rectangle(self, index: int) -> Rectangle:
        """Static segmentation bar for segment ``index`` at its current length."""
        height = 0.5
        return (
            Rectangle(width=self.get_rectangle_width(index), height=height)
            .set_fill(self.y_labels_color[self.y_labels[index]], opacity=0.7)
            .move_to(self.get_rectangle_middle_point(index))
            .shift(DOWN * 1.5)
        )

    def get_rectangle(self, index: int) -> Rectangle:
        """Segmentation bar for segment ``index``, rebuilt every frame from ``length_values``."""
        return always_redraw(lambda: self._build_rectangle(index))

    def get_segmentation_action_label(self, index: int, rectangle: Rectangle) -> Tex:
        """Action name of segment ``index``, kept centred on ``rectangle`` by an updater."""
        return Tex(self.y_labels[index]).scale(0.7).add_updater(lambda m: m.move_to(rectangle.get_center()))

    def draw_segmentation_output(self):
        """Add one coloured, labelled bar per segment; bars and labels follow ``length_values``."""
        self.segmentation_rectangles = [self.get_rectangle(i) for i in range(len(self.y_labels))]
        self.action_labels = [
            self.get_segmentation_action_label(i, rect)
            for i, rect in enumerate(self.segmentation_rectangles)
        ]

        self.add(Group(*self.segmentation_rectangles))
        self.add(Group(*self.action_labels))

    def animate_part_of_math_dont_remove(
        self, math_part, underline_text: str
    ) -> t.Tuple[VMobject, VMobject, VMobject]:
        """Like ``animate_part_of_math`` but leaves the highlight on screen.

        Returns ``(underline, underline_text_mobject, math_part)``, to be passed to
        ``remove_after_animate_part_of_math``.
        """
        underline = Underline(math_part, buff=0.1)
        label = MathTex(underline_text).move_to(underline.get_bottom() + DOWN * 0.4)
        self.play(
            AnimationGroup(
                Create(underline),
                Write(label),
                math_part.animate.set_color(YELLOW)
            )
        )

        return underline, label, math_part

    def remove_after_animate_part_of_math(self, things: t.Tuple[VMobject, VMobject, VMobject]):
        """Undo ``animate_part_of_math_dont_remove``: remove the annotations, reset colour to WHITE."""
        prev_color = WHITE

        underline, underline_text, math_part = things
        self.remove(underline)
        if underline_text is not None:
            self.remove(underline_text)
        math_part.set_color(prev_color)

    def get_background_rectangles_given_the_path(self) -> t.List[BackgroundRectangle]:
        """One yellow highlight box per grid cell the current path covers, left to right."""
        n_rows = len(self.y_labels)
        coords = []
        x_coord = 0.5
        for row_index in range(n_rows):
            y_coord = n_rows - row_index - 0.5
            # separate loop variable: the old code rebound the outer index here
            for _ in range(int(self.length_values[row_index].get_value())):
                coords.append([x_coord, y_coord])
                x_coord += 1

        dots = [Dot(self.plane.coords_to_point(x_c, y_c)) for (x_c, y_c) in coords]
        return [BackgroundRectangle(d, color=YELLOW, fill_opacity=0.5, buff=0.12) for d in dots]

    def animate_optimization_part_one(self):
        """Explain E_o: highlight it in ``objective_step_5`` and sweep over the path's cells."""
        # E_o = submobjects 4-11 of objective_step_5
        eo_things = self.animate_part_of_math_dont_remove(self.objective_step_5[4:12], r"E_o")
        self.wait()
        rects = VGroup(*self.get_background_rectangles_given_the_path())
        self.play(
            LaggedStart(*[ShowCreationThenFadeOut(r) for r in rects])
        )
        self.remove_after_animate_part_of_math(eo_things)

    def create_poisson_distribution(self, mean: float, color, label: str) -> VGroup:
        """Axes plus line graph of the Poisson pmf with ``mean`` over 0..20.

        Returns ``VGroup(axes, graph, x_label, y_label)``; the y label reads
        ``p(l_n | label)`` in ``color``.
        """
        axes = Axes(
            x_range=[-0.1, 20.9, 1],
            y_range=[-0.05, 0.34, 0.05],
            x_length=5,
            y_length=3,
            x_axis_config={"numbers_to_include": np.arange(0, 21, 2)},
            y_axis_config={"numbers_to_include": np.arange(0, 0.34, 0.05)},
            axis_config={"color": WHITE},
        )

        x_vals = np.arange(21)
        y_vals = poisson.pmf(x_vals, mu=mean)
        graph = axes.get_line_graph(x_values=x_vals, y_values=y_vals, line_color=color)

        x_label = axes.get_x_axis_label(MathTex(r"\ell_n"))
        y_label = axes.get_y_axis_label(
            MathTex(r"p(\ell_n|{})".format(label)), edge=LEFT, direction=LEFT
        ).set_color(color)

        return VGroup(axes, graph, x_label, y_label)

    def animate_optimization_part_two(self):
        """Explain E_l: one Poisson length distribution per action label.

        For each segment, the segment and its label's distribution are boxed, and
        the pmf of the segment's current length is marked on that distribution.
        """
        # E_l = submobjects 13-20 of objective_step_5
        el_things = self.animate_part_of_math_dont_remove(self.objective_step_5[13:21], r"E_\ell")
        self.wait()

        # Mean length per label, from the first segment with that label. Keying by label
        # (instead of by segment position) keeps the highlighted distribution's name and
        # colour in line with the segment's bar.
        label_means = {}
        for label, good_length in zip(self.y_labels, self.good_lengths):
            label_means.setdefault(label, good_length)

        graph_by_label = {
            label: self.create_poisson_distribution(label_means[label], color, label)
            for label, color in self.y_labels_color.items()
            if label in label_means
        }

        # create the graphs
        all_graphs = VGroup(*graph_by_label.values()).arrange(buff=1.0).scale(0.5)
        all_graphs.to_edge(DOWN, buff=0.1)
        self.add(all_graphs)
        self.wait()

        # for each line, highlight the value from the corresponding distribution
        for i, line in enumerate(self.lines):
            label = self.y_labels[i]
            graph = graph_by_label[label]
            axes = graph[0]

            x_val = int(self.length_values[i].get_value())
            # same mean as the drawn curve, so the marker lies on it
            y_val = poisson.pmf(x_val, mu=label_means[label])
            point = axes.c2p(x_val, y_val)

            vline = axes.get_vertical_line(point, color=RED)
            vdot = Dot(point, color=RED).scale(0.5)

            b1 = SurroundingRectangle(line, color=YELLOW, buff=0.2)
            b2 = SurroundingRectangle(graph, color=YELLOW, buff=0.1)

            self.play(AnimationGroup(Create(b1), Create(b2)))
            self.wait()

            self.play(AnimationGroup(Flash(vdot), Create(vline)))
            self.wait()
            self.remove(vline, vdot, b1, b2)

        self.remove_after_animate_part_of_math(el_things)
        self.remove(all_graphs)

    def construct(self):
        """Scene entry point: objective, grid, path, E_o, E_l, then morph to ``good_lengths``."""
        self.title = Title("Inference In Action Segmentation")
        self.add(self.title)

        self.animate_optimization_problem_short()

        self.create_probabilities_and_plane()

        self.animate_probabilities_and_labels()

        self.show_the_path_start_and_end()

        # current (deliberately poor) segment lengths and the target lengths they morph into
        self.length_values = [ValueTracker(3), ValueTracker(4), ValueTracker(5), ValueTracker(15)]
        self.good_lengths = [4, 8, 11, 4]

        self.draw_initial_lines()

        self.animate_optimization_part_one()
        self.wait()
        self.animate_optimization_part_two()
        self.wait()

        self.draw_segmentation_output()
        self.wait()

        self.play(AnimationGroup(
            *[v.animate.set_value(self.good_lengths[i]) for i, v in enumerate(self.length_values)]
        ))

        self.wait()

                                                             

In [139]:
!ffmpeg -i media/videos/presentation/720p30/Slides04Inference.mp4 /home/souri/Slides04Inference.gif

ffmpeg version 4.2.4-1ubuntu0.1 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 9 (Ubuntu 9.3.0-10ubuntu2)
  configuration: --prefix=/usr --extra-version=1ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-l