In [1]:
import typing as t
from scipy.stats import poisson, laplace
import numpy as np
from manim import *
config.media_width = "60%"

In [2]:
import typing as t

import torch
from scipy.signal import gaussian
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample, softmax, cross_entropy
from scipy.ndimage import zoom

# noinspection PyPep8Naming
def project_lengths_softmax(T: int, L: Tensor) -> Tensor:
    """
    Map unconstrained per-segment scores to lengths that share a total of ``T``.

    The scores are normalised with a softmax over dim 0 and then scaled by ``T``.

    :param T: 1:int, the total length to distribute
    :param L: [M]:float, the unnormalised scores
    :return: [M]:float, ``T * softmax(L)``
    """
    weights = softmax(L, dim=0)
    return T * weights


# noinspection PyPep8Naming
def create_masks(
    T: int,
    L: Tensor,
    overlap: float = 0.0,
    opts: t.Dict[str, t.Any] = None,
) -> Tensor:
    """
    Build one soft "plateau" mask per segment from consecutive segment lengths.

    Segment ``i`` starts where the previous segments end (exclusive cumulative
    sum of ``L``). Its mask is the product of two logistic functions, one for
    each edge of the segment, evaluated at the positions ``0 .. T-1``.

    :param T: The target size for the masks (number of positions).
    :param L: [M] the projected lengths.
    :param overlap: how much overlap the masks should have. When positive, the
        lengths are scaled by ``1 + 2 * overlap`` and every start is moved
        left by ``scaled_length * overlap / 2``.
    :param opts: optional overrides for the defaults below.
        "sharpness": steepness of the logistic edges (default 0.1).
        "template_width": present in the defaults but not read by this function.
    :return: [M x T] the attention maps.
    """
    # Defaults first, caller-supplied entries win; None / empty means defaults only.
    settings = {
        **{"sharpness": 0.1, "template_width": 100},
        **(opts or {}),
    }

    num_segments = L.size(0)
    positions = torch.arange(
        start=0, end=T, step=1.0, dtype=torch.float32, device=L.device
    )  # [T]

    # Exclusive cumulative sum: the left edge of every segment.  [M]
    starts = torch.cumsum(L, 0) - L

    if overlap > 0:
        L = L * (1.0 + 2 * overlap)
        starts = starts - (L * (overlap / 2))

    half_width = L / 2.0  # [M]
    centers = starts + half_width  # [M]

    # Column vectors so they broadcast against the [T] positions.
    # (Alternatives that were tried for `s`: sharpness / half_width, and
    #  min(sharpness, 500 / half_width).)
    c = centers.unsqueeze(1)  # [M x 1]
    w = half_width.unsqueeze(1)  # [M x 1]
    s = torch.full(
        (num_segments, 1),
        dtype=torch.float32,
        device=L.device,
        fill_value=settings["sharpness"],
    )  # [M x 1]

    # Log-space version of
    #     g = 1 / ((exp(s * (x - c - w)) + 1) * (exp(s * (-x + c - w)) + 1))
    # where log(exp(a) + 1) is computed as logsumexp over the pair (a, 0).
    past_right_edge = s * (positions - c - w)
    before_left_edge = s * (-positions + c - w)
    zeros_ = torch.zeros_like(past_right_edge)
    log_denom_right = torch.logsumexp(
        torch.stack((past_right_edge, zeros_), dim=2), dim=2
    )
    log_denom_left = torch.logsumexp(
        torch.stack((before_left_edge, zeros_), dim=2), dim=2
    )

    return torch.exp(-(log_denom_right + log_denom_left))  # [M x T], the masks

In [3]:
# Density of a Laplace distribution with loc=10 (scale left at scipy's default),
# evaluated at its own location, i.e. the peak of the curve.
# NOTE(review): presumably a sanity check for the 0.3 factor applied to
# laplace.pdf when the curve is plotted in the scene below — confirm.
laplace.pdf(10, loc=10)

0.5

In [4]:
import cv2

In [5]:
# cv2.imread does not raise when the file is missing or unreadable; it returns
# None, which would otherwise surface as an opaque TypeError on the slice below.
nll_image = cv2.imread("nll1.png")
if nll_image is None:
    raise FileNotFoundError("could not read image 'nll1.png'")
# OpenCV returns the channels in BGR order; reverse the last axis to get RGB.
nll_image = nll_image[:, :, ::-1]

In [6]:
# Inspect the dimensions of the loaded image array.
nll_image.shape

(143, 967, 3)

In [7]:
# 143 image rows split over 4 bands does not give an integer height.
# NOTE(review): presumably the reason for the uneven 36/35 pixel bands used
# in the following cells — confirm.
143/4

35.75

In [8]:
# Check that three 36-pixel bands plus one 35-pixel band add up to 143.
36+36+36+35

143

In [9]:
# Scale four relative segment lengths (out of 27 columns) to rounded pixel widths.
# NOTE(review): the multiplier here is 976, while nll_image.shape reports a width
# of 967 — possible digit transposition; confirm before reusing these numbers.
(np.array([4, 8-0.1, 11-0.01, 4+0.1]) / 27 * 976).round()

array([145., 286., 397., 148.])

In [10]:
# (top, bottom) pixel rows of the four stacked bands; each band starts where the
# previous one ends. Heights are 36, 35, 36 and 36 pixels.
r1 = (0, 36)
r2 = (r1[1], r1[1] + 35)
r3 = (r2[1], r2[1] + 36)
r4 = (r3[1], r3[1] + 36)

# Pixel width of each consecutive segment; 3 pixels are moved by hand from the
# third segment to the fourth.
# NOTE(review): the widths sum to 976 while nll_image.shape is reported as
# (143, 967, 3), so the slice for the last segment is clipped at the right
# border — confirm this is intended.
absolute_lengths = [145, 286, 397 - 3, 148 + 3]

# Hard 0/255 mask: band k is white exactly over the columns of segment k.
exact_m_image = np.zeros_like(nll_image)
start_length = 0
for r, al in zip((r1, r2, r3, r4), absolute_lengths):
    exact_m_image[slice(*r), slice(start_length, start_length + al)] = 255
    start_length += al

# The original author observed a "very strange caching effect" without this
# copy; it is kept as the workaround.
exact_m_image = exact_m_image.copy()

In [11]:
# Soft masks from create_masks (one row per segment), scaled to the 0-255 pixel
# range. The mask length is read from the image width so that the rows below
# always fit the target array.
approx_m = create_masks(
    exact_m_image.shape[1], torch.tensor(absolute_lengths)
).numpy() * 255
approx_m_image = np.zeros_like(exact_m_image).copy()
for i, r in enumerate([r1, r2, r3, r4]):
    # Broadcast the single mask row over every pixel row of the band and over
    # all three colour channels at once.
    approx_m_image[r[0]:r[1], :, :] = approx_m[i, :, np.newaxis]

In [12]:
# Inspect the dimensions of the soft-mask image built in the previous cell.
approx_m_image.shape

(143, 967, 3)

In [13]:
# Inspect the dimensions of the hard-mask image for comparison.
exact_m_image.shape

(143, 967, 3)

In [14]:
%%manim -v WARNING Slides08ApproximateDifferentiableEnergy
# %%manim -v WARNING -qm --flush_cache Slides08ApproximateDifferentiableEnergy

class Slides08ApproximateDifferentiableEnergy(Scene):
    """
    Slide titled "Approximate Differentiable Energy".

    ``construct`` shows the objective ``E = E_o + E_l`` and then rewrites it in
    two stages: first the length term (``show_animation_length_term``), then the
    observation term (``show_animation_observation_term``), and finally
    transforms the result into the starred objective.

    The scene reads the notebook-level arrays ``nll_image``, ``exact_m_image``
    and ``approx_m_image``; they must exist before the scene is rendered.
    """

    def create_citation(self, title: str) -> Text:
        """Return ``title`` as a small gray ``Text`` (not called by ``construct``)."""
        return Text(title, color=GRAY).scale(0.3)

    def create_text_group(self, texts: t.List[Tex]):
        """
        Stack ``texts`` vertically, left-aligned, below ``self.title`` at the
        left edge of the frame (not called by ``construct``).

        Requires ``self.title`` to be set.
        """
        return VGroup(*texts).arrange(
            DOWN,
            center=False,
            aligned_edge=LEFT
        ).next_to(self.title, DOWN * 1.5).to_edge(LEFT, buff=1)

    def animate_part_of_math(self, math_part, underline_text: t.Optional[str] = None):
        """
        Briefly highlight ``math_part``: underline it, colour it yellow and,
        when ``underline_text`` is given, write that text (as ``MathTex``)
        below the underline. After one wait the decorations are removed.

        The part is reset to WHITE afterwards, regardless of the colour it had
        before the call. Not called by ``construct``.
        """
        prev_color = WHITE
        underline = Underline(math_part, buff=0.4)

        animations = [Create(underline)]
        underline_label = None
        if underline_text:
            underline_label = MathTex(underline_text).move_to(underline.get_bottom() + DOWN * 0.5)
            animations.append(Write(underline_label))
        animations.append(math_part.animate.set_color(YELLOW))

        self.play(AnimationGroup(*animations))
        self.wait()

        self.remove(underline)
        # Decide by identity, not by the truthiness of the mobject.
        if underline_label is not None:
            self.remove(underline_label)
        math_part.set_color(prev_color)

    def _below_title(self, mobject):
        """Place ``mobject`` at the formula position below ``self.title`` and return it."""
        return mobject.next_to(self.title, DOWN * 1.5)

    def show_animation_length_term(self):
        """
        Rewrite the length term of the objective step by step.

        Sequence: expand ``E_l`` into a sum of negative log-probabilities, show
        the Poisson plot, swap in ``p^star`` together with the Laplace plot,
        replace the sum by the beta-weighted absolute-difference sum, and
        finally collapse it to ``beta * E^star_l``.

        Requires ``self.title`` and ``self.objective``; leaves
        ``self.objective_after_length_update`` on screen for the next step.
        """
        # Left-hand side shared by all four formulas of this step.
        lhs = (
            r"E(\ell'_{1:N})",
            r" = ",
            r"E_o(\ell'_{1:N})",
            r" + ",
        )

        self.objective_length_expanded = self._below_title(MathTex(
            *lhs,
            r"\sum_{n=1}^{N}",
            r"-",
            r"\log ",
            r"p\big(\ell_n | c_n\big)",
        ))

        self.objective_length_expanded_replacement = self._below_title(MathTex(
            *lhs,
            r"\sum_{n=1}^{N}",
            r"-",
            r"\log ",
            r"p^\star\big(\ell_n | c_n\big)",
        ))

        self.objective_length_expanded_final = self._below_title(MathTex(
            *lhs,
            r"\beta",
            r"\sum_{n=1}^{N}",
            r"|\ell_n - \lambda^{\ell}_{c_n}|"
        ))

        self.objective_after_length_update = self._below_title(MathTex(
            *lhs,
            r"\beta",
            r"E^\star_\ell(\ell'_{1:N})"
        ))

        self.play(TransformMatchingTex(self.objective, self.objective_length_expanded))
        self.wait()

        poisson_distribution = self.create_poisson_distribution()
        laplace_distribution = self.create_laplace_distribution()

        # The temporary group only serves to lay the two plots out side by
        # side; arrange/shift move the member mobjects themselves.
        VGroup(poisson_distribution, laplace_distribution).arrange(RIGHT, buff=1.0).shift(DOWN)

        self.add(poisson_distribution)
        self.objective_length_expanded.set_color_by_tex(r"p\big(\ell_n | c_n\big)", YELLOW)

        self.wait()

        self.objective_length_expanded_replacement.set_color_by_tex(r"p^\star\big(\ell_n | c_n\big)", BLUE)
        self.play(
            AnimationGroup(
                FadeIn(laplace_distribution, shift=LEFT),
                TransformMatchingTex(self.objective_length_expanded, self.objective_length_expanded_replacement)
            )
        )

        self.wait()
        for tex in (r"\beta", r"\sum_{n=1}^{N}", r"|\ell_n - \lambda^{\ell}_{c_n}|"):
            self.objective_length_expanded_final.set_color_by_tex(tex, BLUE)
        self.play(
            AnimationGroup(
                FadeOut(laplace_distribution, shift=LEFT),
                FadeOut(poisson_distribution, shift=LEFT),
                TransformMatchingTex(self.objective_length_expanded_replacement, self.objective_length_expanded_final)
            )
        )

        self.wait()

        self.play(TransformMatchingTex(self.objective_length_expanded_final, self.objective_after_length_update))
        self.wait()

    def _create_distribution_axes(self) -> Axes:
        """Return a fresh set of axes shared by the Poisson and Laplace plots."""
        return Axes(
            x_range=[-0.1, 20.9, 1],
            y_range=[-0.04, 0.19, 0.05],
            x_length=5,
            y_length=3,
            x_axis_config={"numbers_to_include": np.arange(0, 21, 2)},
            axis_config={"color": WHITE},
        )

    def _assemble_distribution_plot(self, axes, graph, y_label_tex: str, color) -> VGroup:
        """Add the axis labels to ``axes``/``graph`` and return everything scaled to 75%."""
        x_label = axes.get_x_axis_label(MathTex(r"\ell_n"), edge=DOWN, direction=DOWN)
        y_label = axes.get_y_axis_label(MathTex(y_label_tex), edge=UP, direction=UP).set_color(color)
        return VGroup(axes, graph, x_label, y_label).scale(0.75)

    def create_poisson_distribution(self, mean: int = 10, color = YELLOW, label: str = "c_n") -> VGroup:
        """
        Build a line plot of the Poisson pmf with the given ``mean`` at the
        integers 0..20.

        :param mean: the ``mu`` passed to ``poisson.pmf``.
        :param color: colour of the curve and of the y-axis label.
        :param label: conditioning symbol inserted into the y-axis label.
        :return: group of axes, graph, x label and y label, scaled to 75%.
        """
        axes = self._create_distribution_axes()

        x_vals = np.arange(21)
        y_vals = poisson.pmf(x_vals, mu=mean)
        graph = axes.get_line_graph(x_values=x_vals, y_values=y_vals, line_color=color)

        return self._assemble_distribution_plot(axes, graph, r"p(\ell_n|{})".format(label), color)

    def create_laplace_distribution(self, mean: int = 10, color = BLUE, label: str = "c_n") -> VGroup:
        """
        Build a plot of ``0.3 * laplace.pdf(x, loc=mean)`` on the same axes
        layout as the Poisson plot.

        :param mean: the ``loc`` passed to ``laplace.pdf``.
        :param color: colour of the curve and of the y-axis label.
        :param label: conditioning symbol inserted into the y-axis label.
        :return: group of axes, graph, x label and y label, scaled to 75%.
        """
        axes = self._create_distribution_axes()

        # The density is scaled by 0.3 before plotting.
        graph = axes.get_graph(lambda x: 0.3 * laplace.pdf(x, loc=mean), color=color)

        return self._assemble_distribution_plot(axes, graph, r"p^\star(\ell_n|{})".format(label), color)

    def _create_matrix_group(self, image, label: str, grid_color, resampling: str, shift_by) -> Group:
        """
        Wrap ``image`` in a labelled 27 x 4 grid.

        :param image: array handed to ``ImageMobject``.
        :param label: text of the label placed at the left of the grid.
        :param grid_color: stroke colour of the axes and background lines.
        :param resampling: key into ``RESAMPLING_ALGORITHMS``.
        :param shift_by: vector the finished group is shifted by.
        :return: ``Group(image mobject, number plane, label)``.
        """
        matrix = ImageMobject(image, scale_to_resolution=700)
        matrix.set_resampling_algorithm(RESAMPLING_ALGORITHMS[resampling])

        # The grid is sized to the image so that its cells overlay the pixels.
        plane = NumberPlane(
            x_range=[0, 27.01, 1],
            y_range=[0, 4.01, 1],
            x_length=matrix.width,
            y_length=matrix.height,
            axis_config={
                "stroke_color": grid_color,
                "stroke_width": 2,
            },
            background_line_style={
                "stroke_color": grid_color,
                "stroke_width": 2,
            }
        )
        plane_label = plane.get_y_axis_label(label, direction=LEFT, edge=LEFT)
        return Group(matrix, plane, plane_label).shift(shift_by)

    def create_P_and_M(self):
        """
        Create ``self.p_matrix_group`` (from ``nll_image``, blue grid) and
        ``self.m_matrix_group`` (from ``exact_m_image``, red grid). Neither is
        added to the scene here.
        """
        self.p_matrix_group = self._create_matrix_group(
            nll_image, "P", grid_color=BLUE, resampling="nearest", shift_by=DOWN * 0.25
        )
        self.m_matrix_group = self._create_matrix_group(
            exact_m_image, "M", grid_color=RED, resampling="nearest", shift_by=DOWN * 2.5
        )

    def create_M_star(self):
        """
        Create ``self.am_matrix_group`` from ``approx_m_image`` (red grid,
        linear resampling) at the same position as the M group. It is not added
        to the scene here.
        """
        self.am_matrix_group = self._create_matrix_group(
            approx_m_image, r"M^\star", grid_color=RED, resampling="linear", shift_by=DOWN * 2.5
        )

    def show_animation_observation_term(self):
        """
        Rewrite the observation term of the objective step by step.

        Sequence: expand ``E_o`` into a sum of negative log-probabilities,
        rewrite it as a double sum over ``P`` element-wise times ``M``, fade in
        the P and M matrices, then replace M by ``M^star`` in both the picture
        and the formula.

        Requires ``self.objective_after_length_update``; leaves
        ``self.objective_observation_mul_star``, ``self.p_matrix_group`` and
        ``self.am_matrix_group`` on screen.
        """
        # Parts shared by all three formulas of this step.
        head = (
            r"E(\ell'_{1:N})",
            r" = ",
        )
        tail = (
            r" + ",
            r"\beta",
            r"E^\star_\ell(\ell'_{1:N})",
        )

        self.objective_observation_expanded = self._below_title(MathTex(
            *head,
            r"\sum_{t=1}^{T}",
            r"-",
            r"\log",
            r"p\big(",
            r"\alpha(t)",
            r" | ",
            r"x_t",
            r"\big)",
            *tail
        ))

        self.play(TransformMatchingTex(self.objective_after_length_update, self.objective_observation_expanded))
        self.wait()

        self.create_P_and_M()

        self.objective_observation_mul = self._below_title(MathTex(
            *head,
            r"\sum_{t=1}^{T}",
            r"\sum_{n=1}^{N}",
            r"\big(P \odot M\big)[n, t]",
            *tail
        ))
        self.objective_observation_mul.set_color_by_tex(r"\big(P \odot M\big)[n, t]", BLUE)

        self.objective_observation_mul_star = self._below_title(MathTex(
            *head,
            r"\sum_{t=1}^{T}",
            r"\sum_{n=1}^{N}",
            r"\big(P \odot M^\star\big)[n, t]",
            *tail
        ))
        self.objective_observation_mul_star.set_color_by_tex(r"\big(P \odot M^\star\big)[n, t]", BLUE)

        self.play(TransformMatchingTex(self.objective_observation_expanded, self.objective_observation_mul))
        self.wait()

        self.play(
            AnimationGroup(
                FadeIn(self.p_matrix_group, shift=LEFT),
                FadeIn(self.m_matrix_group, shift=LEFT),
            )
        )
        self.wait()

        self.create_M_star()

        self.play(
            AnimationGroup(
                FadeOut(self.m_matrix_group),
                FadeIn(self.am_matrix_group),
                TransformMatchingTex(self.objective_observation_mul, self.objective_observation_mul_star)
            )
        )
        self.wait()

    def construct(self):
        """Entry point of the scene: title, initial objective, both rewrites, final formula."""
        # Not read anywhere else in this class; kept as instance attributes.
        self.length_values = [ValueTracker(3), ValueTracker(4), ValueTracker(5), ValueTracker(15)]
        self.good_lengths = [4, 8, 11, 4]

        self.title = Title("Approximate Differentiable Energy")
        self.add(self.title)

        self.objective = self._below_title(MathTex(
            r"E(\ell'_{1:N})",
            r" = ",
            r"E_o(\ell'_{1:N})",
            r" + ",
            r"E_{\ell}(\ell'_{1:N})",
        ))

        # The closing formula is left at its default position and enlarged.
        self.the_final_objective = MathTex(
            r"E^\star(\ell'_{1:N})",
            r" = ",
            r"E^\star_o(\ell'_{1:N})",
            r" + ",
            r"\beta",
            r"E^\star_{\ell}(\ell'_{1:N})",
        ).scale(1.25)

        self.add(self.objective)
        self.wait()

        self.show_animation_length_term()

        self.show_animation_observation_term()

        self.wait(2)

        self.play(
            AnimationGroup(
                FadeOut(self.am_matrix_group),
                FadeOut(self.p_matrix_group),
                TransformMatchingTex(self.objective_observation_mul_star, self.the_final_objective)
            )
        )
        self.wait()

                                                                              