In [2]:
from manim import *
import numpy as np
from itertools import combinations
import random 

# Notebook-wide Manim settings: rendered media is embedded at 75% width and
# log output is limited to the WARNING level.
config.media_width = "75%"
config.verbosity = "WARNING"

In [21]:
%%manim -qm CheckDots

def cartesian_to_spherical(vec):
    """Convert a 3D Cartesian vector to spherical coordinates.

    ``phi`` is the polar angle measured from the +z axis and ``theta`` is the
    azimuthal angle in the x-y plane measured from the +x axis. Both angles
    are returned in radians.

    Parameters
    ----------
    vec : array-like of three numbers
        The ``(x, y, z)`` components.

    Returns
    -------
    tuple
        ``(r, phi, theta)`` with ``r >= 0``, ``phi`` in ``[0, pi]`` and
        ``theta`` in ``[-pi, pi]``. The zero vector has no direction, so it
        maps to ``(0.0, 0.0, 0.0)`` instead of producing NaN.
    """
    x, y, z = vec
    r = np.linalg.norm(vec)
    if r == 0:
        # 0 / 0 below would yield NaN (plus a RuntimeWarning); pick a fixed,
        # well-defined answer for the degenerate case instead.
        return 0.0, 0.0, 0.0
    # Clamp defensively so rounding can never push the ratio outside the
    # domain of arccos.
    phi = np.arccos(np.clip(z / r, -1.0, 1.0))
    theta = np.arctan2(y, x)
    return r, phi, theta

def spherical_to_cartesian(spherical_coords):
    """Convert spherical coordinates to a 3D Cartesian vector.

    Parameters
    ----------
    spherical_coords : sequence of three numbers
        ``(r, phi, theta)`` where ``phi`` is the polar angle from the +z axis
        and ``theta`` is the azimuthal angle in the x-y plane, both in radians.

    Returns
    -------
    numpy.ndarray
        The ``[x, y, z]`` components.
    """
    radius, polar, azimuth = spherical_coords
    # Length of the vector's shadow on the x-y plane.
    planar = radius * np.sin(polar)
    return np.array(
        [
            planar * np.cos(azimuth),
            planar * np.sin(azimuth),
            radius * np.cos(polar),
        ]
    )

class CheckDots(ThreeDScene):
    """Show a random 3D point cluster, then re-orient the camera toward it.

    The scene draws ten seeded random points offset by ``(3, 2, 2)``, computes
    the vector from the camera position to the cluster centroid, and sets the
    camera angles to the spherical angles of that vector.
    """

    def construct(self):
        """Build and play the animation."""
        axes = ThreeDAxes()

        # Reposition the light source before anything is drawn.
        self.renderer.camera.light_source.move_to(3 * IN)
        self.set_camera_orientation(phi=75 * DEGREES, theta=60 * DEGREES)

        self.add(axes)

        # Generate a cluster of random points in 3D space (seeded so every
        # render produces the same points).
        np.random.seed(0)
        cluster1 = np.random.rand(10, 3) + np.array([3, 2, 2])

        # Create Mobjects for the cluster of dots
        cluster1_dots = VGroup(*[Dot3D(coords, color=BLUE) for coords in cluster1])

        # Display the cluster
        self.play(Create(cluster1_dots))
        self.wait()

        # Read the *current* camera state through the getters. The plain
        # ``phi`` / ``theta`` / ``focal_distance`` attributes only hold the
        # values the camera was constructed with and are not updated by
        # ``set_camera_orientation``.
        camera = self.renderer.camera
        current_phi = camera.get_phi()
        current_theta = camera.get_theta()
        focal_distance = camera.get_focal_distance()

        # Calculate the camera position in Cartesian coordinates
        camera_position = spherical_to_cartesian(
            (focal_distance, current_phi, current_theta)
        )

        # Vector pointing from the camera to the centroid of the cluster
        centroid1 = np.mean(cluster1, axis=0)
        camera_to_centroid1 = centroid1 - camera_position

        # Spherical angles (radians) of that vector
        _, phi, theta = cartesian_to_spherical(camera_to_centroid1)

        # The angles are already in radians, so they must NOT be scaled by
        # DEGREES again.
        # NOTE(review): this uses the angles of the camera->centroid vector as
        # the camera's own orientation angles; confirm this is the intended
        # way to "face" the cluster (moving the frame center to the centroid
        # may be closer to the goal).
        self.set_camera_orientation(phi=phi, theta=theta)
        self.wait()


                                                                                                                                            

In [24]:
%%manim -qm ThreeDots

class ThreeDots(ThreeDScene):
    """Animate collapsing a 3D point cluster along its centroid direction.

    The scene shows a seeded random cluster, its centroid, and the vector
    between centroid and origin; it then projects the points onto the plane
    orthogonal to that vector and finally normalizes them onto the unit
    circle in that plane. Full-screen captions are shown between steps.
    """

    # Viewing angles used for the 3D content; ``play_text`` returns the camera
    # to these after showing a caption.
    SCENE_PHI = 75 * DEGREES
    SCENE_THETA = 60 * DEGREES

    # Projected points whose norm is below this are treated as sitting at the
    # origin and are left there instead of being divided by ~0.
    ZERO_NORM_TOL = 1e-12

    def add_objects(self, objects):
        """Add every mobject in ``objects`` to the scene, in order."""
        self.add(*objects)

    def remove_objects(self):
        """Remove all mobjects currently in the scene.

        Returns
        -------
        list
            The mobjects that were removed, in their original order, so they
            can be restored later with ``add_objects``.
        """
        # Iterate over a snapshot: ``self.mobjects`` changes as we remove.
        removed_objects = list(self.mobjects)
        for mobject in removed_objects:
            self.remove(mobject)

        return removed_objects

    def play_text(self, text='', wait_time=3):
        """Temporarily clear the scene and show a caption.

        Parameters
        ----------
        text : str
            Caption to display.
        wait_time : float
            Seconds to keep the caption on screen between fade-in and
            fade-out.
        """
        removed_objects = self.remove_objects()

        # Look straight down the z axis so the caption is shown flat.
        self.set_camera_orientation(phi=0 * DEGREES, theta=-90 * DEGREES)

        # Create and display the text
        caption = Text(text, font_size=50)
        self.play(FadeIn(caption))

        self.wait(wait_time)  # keep the caption visible for `wait_time` seconds

        self.play(FadeOut(caption))

        # Return to the 3D view and put the scene content back.
        self.set_camera_orientation(phi=self.SCENE_PHI, theta=self.SCENE_THETA)

        self.add_objects(removed_objects)

    @staticmethod
    def _orthogonal_complement_projector(direction):
        """Return the matrix projecting onto the subspace orthogonal to ``direction``.

        An orthonormal basis whose first vector is parallel to ``direction``
        is built with a QR decomposition; the component along that first
        vector is then zeroed out.
        """
        basis = np.eye(direction.shape[0])
        basis[:, 0] = direction
        q, _ = np.linalg.qr(basis)
        keep = np.eye(q.shape[0])
        keep[0, 0] = 0  # drop the component along `direction`

        return q @ keep @ q.T

    @staticmethod
    def _collapse(points, proj_mat):
        """Apply ``proj_mat`` to every point; returns an ``(n, dim)`` array."""
        return np.array([proj_mat @ point for point in points])

    @classmethod
    def _normalize_to_unit_circle(cls, embeds):
        """Scale (near-)planar points to unit length within their plane.

        The direction of least variance of ``embeds`` (last right-singular
        vector) is taken as the plane normal and projected out, then each
        point is divided by its norm. Points whose norm is below
        ``ZERO_NORM_TOL`` have no direction and are returned as the origin
        rather than as NaN / rounding noise.
        """
        embeds = np.atleast_2d(np.asarray(embeds, dtype=float))

        # Find the normal vector of the plane the points lie in.
        _, _, vh = np.linalg.svd(embeds)
        normal = vh[-1]

        # Remove any residual component along the normal.
        in_plane = embeds - embeds @ np.outer(normal, normal)

        # Normalize to the unit circle, skipping degenerate points.
        norms = np.linalg.norm(in_plane, axis=1, keepdims=True)
        has_direction = norms >= cls.ZERO_NORM_TOL
        normalized = np.zeros_like(in_plane)
        np.divide(in_plane, norms, out=normalized, where=has_direction)

        return normalized

    def construct(self):
        """Build and play the animation."""
        axes = ThreeDAxes()

        # Reposition the light source before anything is drawn.
        self.renderer.camera.light_source.move_to(3 * IN)
        self.set_camera_orientation(phi=self.SCENE_PHI, theta=self.SCENE_THETA)

        self.add(axes)

        # Generate one cluster of random points in 3D space (seeded so every
        # render produces the same points).
        np.random.seed(0)
        cluster1 = np.random.rand(10, 3) + np.array([3, 2, 2])
        origin = np.array([0, 0, 0])

        # Create Mobjects for the cluster of dots
        cluster1_dots = VGroup(*[Dot3D(coords, color=BLUE) for coords in cluster1])

        # Centroid of the cluster and its marker
        centroid1 = np.mean(cluster1, axis=0)
        centroid1_dot = Dot3D(centroid1, color=YELLOW, radius=0.1)

        # Unit vector 'd' pointing from the origin to the centroid
        d = centroid1 / np.linalg.norm(centroid1)

        # Collapse the points onto the plane orthogonal to 'd'.
        proj_mat = self._orthogonal_complement_projector(d)

        projected_cluster1 = self._collapse(cluster1, proj_mat)
        # The centroid lies along 'd', so its projection is (numerically) the
        # origin; the zero-norm guard keeps it there when normalizing.
        centroid1_proj = self._collapse([centroid1], proj_mat)[0]

        normalized_cluster1 = self._normalize_to_unit_circle(projected_cluster1)
        normalized_centroid1 = self._normalize_to_unit_circle([centroid1_proj])[0]

        projected_cluster1_dots = VGroup(
            *[Dot3D(coords, color=BLUE) for coords in projected_cluster1]
        )
        centroid1_dot_proj = Dot3D(centroid1_proj, color=YELLOW, radius=0.1)

        normalized_cluster1_dots = VGroup(
            *[Dot3D(coords, color=BLUE) for coords in normalized_cluster1]
        )
        normalized_centroid1_dot = Dot3D(normalized_centroid1, color=YELLOW, radius=0.1)

        d_vector = Arrow3D(centroid1, origin, color=WHITE)

        ##################################
        # ANIMATION SECTION
        ##################################

        self.play_text('Visualizing network embeddings in 3D')
        # Display the original cluster
        self.play(
            Create(cluster1_dots),
        )
        self.wait()

        self.play_text('Calculating centroid of the blue cluster')
        # Show the centroid
        self.play(
            Create(centroid1_dot),
        )
        self.wait()

        self.play_text("Obtain vector connecting origin and centroid")
        self.play(
            Create(d_vector),
        )
        self.wait()

        self.play_text("Collapse the embeddings along this vector")
        self.play(
            Transform(cluster1_dots, projected_cluster1_dots),
            FadeOut(d_vector),
            Transform(centroid1_dot, centroid1_dot_proj),
        )
        self.wait()

        self.play_text("Normalize embeddings along unit circle")
        self.play(
            FadeOut(cluster1_dots),
            Transform(projected_cluster1_dots, normalized_cluster1_dots),
            FadeOut(centroid1_dot),
            Transform(centroid1_dot_proj, normalized_centroid1_dot),
        )
        self.wait(5)
        
        

                                                                                                                                            