In [1]:
import typing as t

import numpy as np
from manim import *
# Width at which rendered media is displayed in the notebook output
# (manim's `media_width` config option).
config.media_width = "60%"

In [2]:
import cv2

In [3]:
%%manim -v WARNING Slides06ApproximateInference
# %%manim -v WARNING -qm Slides06ApproximateInference

class Slides06ApproximateInference(Scene):
    """Slide animating iterative gradient-based minimisation of E.

    The scene shows the argmin objective as the slide title and then walks
    through two update steps: parameters -> energy value -> gradient ->
    next parameters, followed by an ellipsis for the remaining iterations.
    """

    def create_citation(self, title: str) -> Text:
        """Return ``title`` as a small gray ``Text`` (scaled to 0.3)."""
        return Text(title, color=GRAY).scale(0.3)

    def create_text_group(self, texts: t.List[Tex]) -> VGroup:
        """Stack ``texts`` vertically, left-aligned, below the slide title.

        Reads ``self.title``, so it must be called after ``construct`` has
        created the title.
        """
        return VGroup(*texts).arrange(
            DOWN,
            center=False,
            aligned_edge=LEFT
        ).next_to(self.title, DOWN * 1.5).to_edge(LEFT, buff=1)

    def write_the_objective(self):
        """Add the argmin objective to the scene, then animate it to the top edge.

        NOTE: this rebinds ``self.the_objective`` to a ``MathTex``, whereas
        ``construct`` stores the raw LaTeX string under the same attribute.
        """
        self.the_objective = MathTex(
            r"\ell^*_{1:N}",
            r" = ",
            r"\underset{\ell_{1:N}}{\mathrm{argmin}}",
            r"~E(\ell_{1:N} | x_{1:T}, c_{1:N})"
        )

        self.add(self.the_objective)
        self.wait()
        self.play(
            self.the_objective.animate.to_edge(UP)
        )

    def create_parameters(self, count: int) -> MathTex:
        """Return the label for the parameters at iteration ``count``."""
        return MathTex(r"\ell^{(" + str(count) + r")}_{1:N}")

    def create_values(self, count: int) -> MathTex:
        """Return the label for the energy evaluated at iteration ``count``."""
        return MathTex(r"E(\ell^{(" + str(count) + r")}_{1:N})")

    def create_grads(self, count: int) -> MathTex:
        """Return the label for the gradient of E at iteration ``count``."""
        return MathTex(r"\frac{\partial E}{\partial \ell} (\ell^{(" + str(count) + r")}_{1:N})")

    def _play_update_step(
        self,
        value_arrow: Arrow,
        value: MathTex,
        grad_label: MathTex,
        grad_arrow: Arrow,
        next_params: MathTex,
        step_arrow: Arrow,
    ) -> None:
        """Animate one update: evaluate E, show the gradient, reveal the next parameters."""
        self.play(Create(value_arrow))
        self.play(FadeIn(value))
        self.wait()
        self.play(AnimationGroup(FadeIn(grad_label), Create(grad_arrow)))
        self.wait()
        self.play(AnimationGroup(FadeIn(next_params), Create(step_arrow)))
        self.wait()

    def animate_gradient_base_optimization(self):
        """Animate two gradient update steps followed by an ellipsis.

        Reads ``self.title`` to position the diagram, so ``construct`` must
        have created the title first.
        """
        l0 = self.create_parameters(0)
        l1 = self.create_parameters(1)
        l2 = self.create_parameters(2)

        g1 = self.create_grads(0).scale(0.75)
        g2 = self.create_grads(1).scale(0.75)

        v0 = self.create_values(0)
        v1 = self.create_values(1)
        v_dot = MathTex(r"\cdots")

        # Two-column grid: parameters on the left, energy values on the right.
        # The group only exists to lay the labels out; it is never added to
        # the scene as a whole.
        VGroup(
            l0, v0,
            l1, v1,
            l2, v_dot,
        ).arrange_in_grid(
            cols=2, col_widths=[5, 5],
            row_heights=[2, 2, 2, 2],
            col_alignments=["c", "c"],
            buff=1
        ).next_to(self.title, DOWN, buff=1)

        # a*: parameters -> value, b*: value -> next parameters (gradient),
        # d*: parameters -> next parameters.
        a1 = Arrow(l0.get_center() + 0.5 * RIGHT, v0.get_center() - 0.75 * RIGHT)
        b1 = Arrow(v0, l1.get_center() + 0.25 * UP + 0.25 * RIGHT)
        d1 = Arrow(l0.get_center() + 0.25 * DOWN, l1.get_center() - 0.25 * DOWN)

        a2 = Arrow(l1.get_center() + 0.5 * RIGHT, v1.get_center() - 0.75 * RIGHT)
        b2 = Arrow(v1, l2.get_center() + 0.25 * UP + 0.25 * RIGHT)
        d2 = Arrow(l1.get_center() + 0.25 * DOWN, l2.get_center() - 0.25 * DOWN)

        a3 = Arrow(l2.get_center() + 0.5 * RIGHT, v_dot.get_center() - 0.75 * RIGHT)

        # Place each gradient label along its diagonal arrow.
        g1.move_to(b1).shift(0.33 * UP + 0.5 * LEFT).rotate(PI/6)
        g2.move_to(b2).shift(0.33 * UP + 0.5 * LEFT).rotate(PI/6)

        # step 0
        self.play(FadeIn(l0))
        self.wait()
        self._play_update_step(a1, v0, g1, b1, l1, d1)

        # clear step 0 and slide l1 into the position step 0 started from
        self.remove(l0, v0, g1, a1, b1, d1)
        shift_amount = a1.get_center() - a2.get_center()
        self.play(
            l1.animate.shift(shift_amount)
        )
        self.wait()

        # Everything still to be shown was laid out relative to the old grid
        # rows; move it by the same offset before it appears on screen.
        for pending in (a2, v1, g2, b2, l2, d2, a3, v_dot):
            pending.shift(shift_amount)

        # step 1
        self._play_update_step(a2, v1, g2, b2, l2, d2)

        # step dots
        self.play(
            AnimationGroup(Create(a3), FadeIn(v_dot))
        )
        self.wait()

    def construct(self):
        """Build the slide: objective as the title, then the update animation."""
        self.the_objective = r"\ell^*_{1:N} = \underset{\ell_{1:N}}{\mathrm{argmin}}~E(\ell_{1:N} | x_{1:T}, c_{1:N})"
        self.title = Title(f"${self.the_objective}$")
        self.add(self.title)

        # NOTE: write_the_objective() is currently not called here; the
        # objective is shown through the title above.

        self.animate_gradient_base_optimization()

        self.wait()

                                                                              