# 1. Introduction & intuition

## Title

In [50]:
from manim import *

In [51]:
%%manim -v WARNING -qk Title  
# -qk renders at 4K quality (3840x2160, 60 fps)

from manim import *

class Title(Scene):
    """Opening card: draws the deck title, holds it for a second, then fades it out.

    NOTE(review): naming this scene ``Title`` rebinds that name in the notebook
    namespace. Later cells call ``Title("...")`` with a string, presumably
    expecting manim's own title mobject; they re-run ``from manim import *``
    first, which appears to be what restores it -- confirm before removing
    those per-cell imports.
    """

    def construct(self):
        heading = Text("Árboles de Decisión", font_size=75)

        # Outline-then-fill entrance, one second hold, one second fade-out.
        self.play(DrawBorderThenFill(heading))
        self.wait(1)
        self.play(FadeOut(heading), run_time=1)

                                                                                                                       

## Quote 

In [52]:
%%manim -v WARNING -qk Quote

class Quote(Scene):
    """Fades in the decision-tree quote, then writes its attribution in yellow."""

    def construct(self):
        # Both texts are created already positioned: the quote near the top of
        # the frame, the attribution slightly above centre and to the right.
        quote_text = Text(
            "El árbol de decisión es una visualización \ndel proceso de pensamiento",
            font_size=54,
        ).shift(UP * 2.5)
        attribution = Text("-Garry Kasparov", font_size=54, color=YELLOW).shift(
            UP * 0.4, RIGHT * 2.7
        )

        self.play(FadeIn(quote_text), run_time=0.5)
        self.wait(2)
        self.play(Write(attribution))
        self.wait(1)

        # Remove both texts together.
        self.play(FadeOut(quote_text, attribution), run_time=0.5)

                                                                                                                       

In [53]:
%%manim -v WARNING -qk DecisionTreeAnimation  
# -qk renders at 4K quality (3840x2160, 60 fps)
from manim import *

class DecisionTreeAnimation(Scene):
    """Single-question decision tree that walks the True path, then the False path.

    The scene fades in a root question with two leaves, highlights
    root -> "True" -> left leaf, clears that highlight, and then repeats the
    walk for root -> "False" -> right leaf before clearing everything.
    """

    def construct(self):
        # Build and place the three nodes.
        root = self.create_node("¿La calle está\n oscura?", color=YELLOW)
        root.shift(UP * 1)

        yes_node = self.create_node("No me meto", color=GREEN)
        yes_node.shift(DOWN * 1.2 + LEFT * 3)

        no_node = self.create_node("Me meto", color=RED)
        no_node.shift(DOWN * 1.2 + RIGHT * 3)

        # Branch labels sit just below the root, pushed towards their branch.
        yes_label = Text("True", color=WHITE, font_size=18)
        yes_label.next_to(root, DOWN).shift(LEFT * 2)

        no_label = Text("False", color=WHITE, font_size=18)
        no_label.next_to(root, DOWN).shift(RIGHT * 2)

        # Connectors from the root to each leaf.
        yes_line = self._branch(root, yes_node)
        no_line = self._branch(root, no_node)

        # Show the whole tree at once.
        self.play(FadeIn(root, yes_line, yes_label, yes_node, no_line, no_label, no_node))
        self.wait(2)

        # 1. Shade the root yellow and turn the "True" label yellow.
        self.play(
            root[0].animate.set_fill(YELLOW, opacity=0.3),
            yes_label.animate.set_color(YELLOW),
        )
        self.wait(2)

        # 2. Lightly fill the left leaf in green.
        self.play(yes_node[0].animate.set_fill(GREEN, opacity=0.3))
        self.wait(2)

        # 3. Clear the whole True path. The root fill is cleared here as well;
        #    otherwise the root highlight in step 4 would animate nothing
        #    because the root would still be shaded from step 1.
        self.play(
            root[0].animate.set_fill(YELLOW, opacity=0),
            yes_label.animate.set_color(WHITE),
            yes_node[0].animate.set_fill(GREEN, opacity=0),
        )

        # 4. Shade the root yellow again and turn the "False" label yellow.
        self.play(
            root[0].animate.set_fill(YELLOW, opacity=0.3),
            no_label.animate.set_color(YELLOW),
        )
        self.wait(2)

        # 5. Lightly fill the right leaf in red.
        self.play(no_node[0].animate.set_fill(RED, opacity=0.3))
        self.wait(2)

        # 6. Final reset of every highlight.
        self.play(
            root[0].animate.set_fill(YELLOW, opacity=0),
            no_label.animate.set_color(WHITE),
            yes_node[0].animate.set_fill(GREEN, opacity=0),
            no_node[0].animate.set_fill(RED, opacity=0),
        )
        self.wait(1)

    def _branch(self, parent, child):
        """Return a white line from just below ``parent`` to just above ``child``."""
        return Line(
            parent.get_bottom() + DOWN * 0.1,
            child.get_top() + UP * 0.1,
            color=WHITE,
        )

    def create_node(self, text, color=WHITE):
        """Create a rectangular node with a text label.

        Args:
            text: Label shown inside the rectangle.
            color: Outline colour of the rectangle.

        Returns:
            A ``VGroup`` whose item 0 is the rectangle and item 1 the text.
        """
        node = Rectangle(
            width=3,
            height=1,
            color=color,
            stroke_width=2,
        )

        node_text = Text(text, font_size=24)
        # Shrink the label if needed so it stays inside the rectangle.
        node_text.set_max_width(node.width * 0.9)

        return VGroup(node, node_text)

                                                                                                                       

In [54]:
%%manim -v WARNING -qk TwoDecisionTrees
from manim import *

class TwoDecisionTrees(Scene):
    """Two one-question decision trees side by side, then a highlight of the left one.

    Layout convention for both trees: the "True" label and the yes-node are on
    the left branch, the "False" label and the no-node on the right branch.
    """

    def construct(self):
        # ===== LEFT TREE =====
        left_root = self.create_node("¿La calle está\n oscura?", color=YELLOW)
        left_root.shift(LEFT * 3.4 + UP * 0.5)

        left_yes_node = self.create_node("No me meto", color=WHITE)
        left_yes_node.shift(LEFT * 5.4 + DOWN * 1.5)

        left_no_node = self.create_node("Me meto", color=WHITE)
        left_no_node.shift(LEFT * 1.4 + DOWN * 1.5)

        # Left-tree branch labels
        left_yes_label = Text("True", color=WHITE, font_size=18)
        left_yes_label.next_to(left_root, DOWN).shift(LEFT * 1.5)

        left_no_label = Text("False", color=WHITE, font_size=18)
        left_no_label.next_to(left_root, DOWN).shift(RIGHT * 1.5)

        # Left-tree connectors
        left_yes_line = self._branch(left_root, left_yes_node)
        left_no_line = self._branch(left_root, left_no_node)

        # ===== RIGHT TREE =====
        right_root = self.create_node("¿El motor tiene más\n de 4 cilindros?", color=YELLOW)
        right_root.shift(RIGHT * 3.4 + UP * 0.5)

        # The yes-node goes on the LEFT of the right root (x = 1.4) so that it
        # sits under the "True" label; previously the two children were
        # mirrored and "True" pointed at the no-node.
        right_yes_node = self.create_node("Consumo entre 6\n y 10 km/litro", color=WHITE)
        right_yes_node.shift(RIGHT * 1.4 + DOWN * 1.5)

        right_no_node = self.create_node("Consumo >= 10\n km/litro", color=WHITE)
        right_no_node.shift(RIGHT * 5.4 + DOWN * 1.5)

        # Right-tree branch labels
        right_yes_label = Text("True", color=WHITE, font_size=18)
        right_yes_label.next_to(right_root, DOWN).shift(LEFT * 1.5)

        right_no_label = Text("False", color=WHITE, font_size=18)
        right_no_label.next_to(right_root, DOWN).shift(RIGHT * 1.5)

        # Right-tree connectors
        right_yes_line = self._branch(right_root, right_yes_node)
        right_no_line = self._branch(right_root, right_no_node)

        # Show the left tree
        self.play(FadeIn(left_root, left_yes_line, left_yes_label, left_yes_node,
                         left_no_line, left_no_label, left_no_node))
        self.wait(0.5)

        # Show the right tree
        self.play(FadeIn(right_root, right_yes_line, right_yes_label, right_yes_node,
                         right_no_line, right_no_label, right_no_node))
        self.wait(1)

        # Highlight every part of the left tree at once.
        self.play(
            left_root[0].animate.set_fill(YELLOW, opacity=0.3),
            left_no_label.animate.set_color(YELLOW),
            left_yes_label.animate.set_color(YELLOW),
            left_yes_node[0].animate.set_fill(GREEN, opacity=0.3),
            left_no_node[0].animate.set_fill(RED, opacity=0.3),
        )

        self.wait(2)

    def _branch(self, parent, child):
        """Return a white line from just below ``parent`` to just above ``child``."""
        return Line(
            parent.get_bottom() + DOWN * 0.1,
            child.get_top() + UP * 0.1,
            color=WHITE,
        )

    def create_node(self, text, color=WHITE):
        """Create a rectangular node with a text label.

        Args:
            text: Label shown inside the rectangle.
            color: Outline colour of the rectangle.

        Returns:
            A ``VGroup`` whose item 0 is the rectangle and item 1 the text.
        """
        node = Rectangle(
            width=2.3,
            height=1,
            color=color,
            stroke_width=2,
        )

        node_text = Text(text, font_size=18)
        # Shrink the label if needed so it stays inside the rectangle.
        node_text.set_max_width(node.width * 0.9)

        return VGroup(node, node_text)

                                                                                                                       

In [55]:
%%manim -v WARNING -qk ComplexDecisionTree
from manim import *

class ComplexDecisionTree(Scene):
    """Three-level decision tree ("do I make it to the meet-up?") with staged highlights.

    Layout convention: for every question node the child reached through the
    ``*_true`` connector is drawn on the left and the ``*_false`` child on the
    right, and the "True"/"False" labels follow the same sides.
    """

    def construct(self):
        # ===== NODES =====
        root = self.create_node("¿Tengo examen\n mañana?", color=YELLOW)
        root.shift(UP * 2.5)

        # Level 1 - left (True)
        node1_true = self.create_node("¿Curso filtro?", color=BLUE)
        node1_true.shift(LEFT * 4 + UP * 0.5)

        # Level 2 - left-left (True-True)
        node11_true = self.create_node("No llego", color=GREEN, _width = 1.6)
        node11_true.shift(LEFT * 6 + DOWN * 1.5)

        # Level 2 - left-right (True-False)
        node12_false = self.create_node("¿la reu empieza\nantes de 10pm?", color=BLUE)
        node12_false.shift(LEFT * 2 + DOWN * 1.5)

        # Level 3 - children of node12_false
        node121_true = self.create_node("Llego", color=GREEN, _width = 1.6)
        node121_true.shift(LEFT * 3 + DOWN * 3.5)

        node122_false = self.create_node("No llego", color=GREEN, _width = 1.6)
        node122_false.shift(LEFT * 1 + DOWN * 3.5)

        # Level 1 - right (False)
        node2_false = self.create_node("¿la reu está a\nmás de 5km?", color=BLUE)
        node2_false.shift(RIGHT * 4 + UP * 0.5)

        # Level 2 - children of node2_false
        node21_true = self.create_node("¿la reu empieza\nantes de 10pm?", color=BLUE)
        node21_true.shift(RIGHT * 2 + DOWN * 1.5)

        node22_false = self.create_node("Llego", color=GREEN, _width = 1.6)
        node22_false.shift(RIGHT * 6 + DOWN * 1.5)

        # Level 3 - children of node21_true
        node211_true = self.create_node("Llego", color=GREEN, _width = 1.6)
        node211_true.shift(RIGHT * 1 + DOWN * 3.5)

        node212_false = self.create_node("No llego", color=GREEN, _width = 1.6)
        node212_false.shift(RIGHT * 3 + DOWN * 3.5)

        # ===== BRANCH LABELS =====
        # "True" always goes on the left and "False" on the right. The right
        # subtree previously had them mirrored, which put "True" next to the
        # connectors that lead to the *_false children.
        true_label_root, false_label_root = self._branch_labels(root, 24, 2)
        true_label_node1, false_label_node1 = self._branch_labels(node1_true, 20, 1.5)
        true_label_node12, false_label_node12 = self._branch_labels(node12_false, 18, 1)
        true_label_node2, false_label_node2 = self._branch_labels(node2_false, 20, 1.5)
        true_label_node21, false_label_node21 = self._branch_labels(node21_true, 18, 1)

        # ===== CONNECTORS =====
        line_root_true = self._branch(root, node1_true)
        line_root_false = self._branch(root, node2_false)

        line_node1_true = self._branch(node1_true, node11_true)
        line_node1_false = self._branch(node1_true, node12_false)

        line_node12_true = self._branch(node12_false, node121_true)
        line_node12_false = self._branch(node12_false, node122_false)

        line_node2_true = self._branch(node2_false, node21_true)
        line_node2_false = self._branch(node2_false, node22_false)

        line_node21_true = self._branch(node21_true, node211_true)
        line_node21_false = self._branch(node21_true, node212_false)

        # ===== SHOW THE WHOLE TREE =====
        self.play(FadeIn(
            root,
            line_root_true,
            line_root_false,
            true_label_root,
            false_label_root,
            node1_true,
            line_node1_true,
            line_node1_false,
            true_label_node1,
            false_label_node1,
            node11_true,
            node12_false,
            line_node12_true,
            line_node12_false,
            true_label_node12,
            false_label_node12,
            node121_true,
            node122_false,
            node2_false,
            line_node2_true,
            line_node2_false,
            true_label_node2,
            false_label_node2,
            node21_true,
            node22_false,
            line_node21_true,
            line_node21_false,
            true_label_node21,
            false_label_node21,
            node211_true,
            node212_false
        ))
        self.wait(2)

        # ===== HIGHLIGHT SEQUENCE =====
        on, off = 0.3, 0

        # Root and its True child, both shaded blue.
        top_pair = [(root, BLUE), (node1_true, BLUE)]
        self._fill(on, *top_pair)
        self.wait(1)
        self._fill(off, *top_pair)

        # The three remaining question nodes.
        lower_questions = [(node12_false, BLUE), (node21_true, BLUE), (node2_false, BLUE)]
        self._fill(on, *lower_questions)
        self.wait(1)
        self._fill(off, *lower_questions)
        self.wait(2)

        # Root alone, in yellow.
        self._fill(on, (root, YELLOW))
        self.wait(1)
        self._fill(off, (root, YELLOW))
        self.wait(1)

        # Every question node below the root.
        inner_questions = lower_questions + [(node1_true, BLUE)]
        self._fill(on, *inner_questions)
        self.wait(4)
        self._fill(off, *inner_questions)
        self.wait(1)

        # Every leaf.
        leaves = [
            (node212_false, GREEN),
            (node211_true, GREEN),
            (node22_false, GREEN),
            (node122_false, GREEN),
            (node121_true, GREEN),
            (node11_true, GREEN),
        ]
        self._fill(on, *leaves)
        self.wait(3)
        self._fill(off, *leaves)
        self.wait(2)

        # Walk one path step by step: root -> node1 -> node12 -> node121.
        self._fill(on, (root, YELLOW))
        self.wait(1)
        self._fill(on, (node1_true, BLUE))
        self.wait(1)
        self._fill(on, (node12_false, BLUE))
        self.wait(1)
        self._fill(on, (node121_true, GREEN))
        self.wait(3)
        self._fill(
            off,
            (node121_true, GREEN),
            (node12_false, BLUE),
            (node1_true, BLUE),
            (root, YELLOW),
        )
        self.wait(5)

    def _branch(self, parent, child):
        """Return a white line from just below ``parent`` to just above ``child``."""
        return Line(
            parent.get_bottom() + DOWN * 0.1,
            child.get_top() + UP * 0.1,
            color=WHITE,
        )

    def _branch_labels(self, parent, font_size, offset):
        """Return the ("True", "False") labels placed under ``parent``.

        "True" is shifted ``offset`` units to the left and "False" the same
        distance to the right of the point directly below ``parent``.
        """
        true_label = Text("True", color=WHITE, font_size=font_size)
        true_label.next_to(parent, DOWN).shift(LEFT * offset)

        false_label = Text("False", color=WHITE, font_size=font_size)
        false_label.next_to(parent, DOWN).shift(RIGHT * offset)

        return true_label, false_label

    def _fill(self, opacity, *targets):
        """Animate, in one ``play`` call, the rectangle fill of several nodes.

        Args:
            opacity: Fill opacity to animate to (0 clears the highlight).
            *targets: ``(node, color)`` pairs; ``node[0]`` is the rectangle
                created by ``create_node``.
        """
        self.play(*[
            node[0].animate.set_fill(color, opacity=opacity)
            for node, color in targets
        ])

    def create_node(self, text, color=WHITE, _width =2.2):
        """Create a rectangular node with a text label.

        Args:
            text: Label shown inside the rectangle.
            color: Outline colour of the rectangle.
            _width: Rectangle width.

        Returns:
            A ``VGroup`` whose item 0 is the rectangle and item 1 the text.
        """
        node = Rectangle(
            width=_width,
            height=0.9,
            color=color,
            stroke_width=2,
        )

        node_text = Text(text, font_size=20)
        # Shrink the label if needed so it stays inside the rectangle.
        node_text.set_max_width(node.width * 0.85)

        return VGroup(node, node_text)

                                                                                                                       

# Decision Tree en Dataset Real

In [None]:
%%manim -v WARNING -qk Division
from manim import *

class Division(Scene):
    """Section divider that introduces the real-dataset part of the presentation."""

    def construct(self):
        heading = Text("Decision Tree en Dataset Real", font_size=50, color=YELLOW)
        heading.to_edge(UP)

        blurb = Paragraph(
            "Vamos a ver la aplicación del Decision Tree",
            "en el Dataset de Wine:",
            font_size=38,
        )
        # Rescale the paragraph so it spans 80% of the frame width. Note this
        # sets the width rather than capping it, so the text may also grow.
        blurb.scale_to_fit_width(config.frame_width * 0.8)
        # Place it below the heading.
        blurb.next_to(heading, DOWN, buff=0.8)

        self.play(FadeIn(heading, blurb))
        self.wait(4)
        self.play(FadeOut(heading, blurb))


In [None]:
%%manim -v WARNING -qm WineTreeStructure

from manim import *
import numpy as np
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier


class WineTreeStructure(Scene):
    """Draws the node/edge structure of a decision tree fitted on two Wine features.

    The classifier is trained with ``max_depth=3`` on the first two columns of
    the Wine data (labelled "Alcohol" and "Malic Acid" here). Each tree node
    becomes a captioned rectangle; parent/child links are labelled "T" (left
    child) and "F" (right child).
    """

    # Uniform scale applied to the finished tree before it is shown.
    SCALE = 0.75

    def construct(self):
        self.load_and_train()
        self.build_tree_visual()

    def load_and_train(self):
        """Load the Wine data, keep two features and fit the classifier on ``self``."""
        dataset = load_wine()

        self.X = dataset.data[:, [0, 1]]
        self.y = dataset.target
        self.feature_names = ["Alcohol", "Malic Acid"]
        self.class_names = dataset.target_names

        self.model = DecisionTreeClassifier(max_depth=3, random_state=42)
        self.model.fit(self.X, self.y)

    def build_tree_visual(self):
        """Build the node boxes and labelled edges, scale them, and animate them in."""
        tree = self.model.tree_

        shown = self.get_nodes_by_depth(tree, max_depth=3)
        positions = self.compute_layout(tree, shown)

        # --- Node boxes ---
        nodes = {}
        for node_id in shown:
            impurity = tree.impurity[node_id]
            # A node counts as pure when its Gini impurity is (numerically) zero.
            is_pure = np.isclose(impurity, 0.0)
            n_samples = tree.n_node_samples[node_id]

            if tree.feature[node_id] == -2:
                # feature == -2 is handled as a leaf: caption shows the
                # majority class taken from the node's value row.
                winner = np.argmax(tree.value[node_id][0])
                headline = f"Clase: {self.class_names[winner]}\n"
                base_color = GREEN
            else:
                # Split node: caption shows the tested feature and threshold.
                split_feature = self.feature_names[tree.feature[node_id]]
                headline = f"{split_feature} ≤ {tree.threshold[node_id]:.2f}\n"
                base_color = BLUE

            caption = headline + f"Gini = {impurity:.2f}\n" + f"Samples = {n_samples}"
            box = self.create_node(caption, color=GOLD if is_pure else base_color)

            if is_pure:
                # Pure nodes also get a light gold fill.
                box[0].set_fill(GOLD, opacity=0.25)

            box.move_to(positions[node_id])
            nodes[node_id] = box

        # --- Edges with their T/F tags ---
        edges = VGroup()
        for node_id in shown:
            branches = (
                (tree.children_left[node_id], "T", LEFT),
                (tree.children_right[node_id], "F", RIGHT),
            )
            for child, tag, side in branches:
                if child not in nodes:
                    continue
                link = Line(
                    nodes[node_id].get_bottom(),
                    nodes[child].get_top(),
                    stroke_width=2,
                )
                tag_text = Text(tag, font_size=18)
                # Tag sits beside the midpoint of the link, on the branch's side.
                tag_text.next_to(link.point_from_proportion(0.5), side, buff=0.05)
                edges.add(link, tag_text)

        # Scale boxes and edges together so they keep their relative layout.
        VGroup(*nodes.values(), edges).scale(self.SCALE)

        self.play(FadeIn(VGroup(*nodes.values())))
        self.play(Create(edges))
        self.wait(2)

    def get_nodes_by_depth(self, tree, max_depth):
        """Return node ids reachable from the root within ``max_depth``, in pre-order.

        Children equal to -1 are treated as absent.
        """
        selected = []
        pending = [(0, 0)]

        while pending:
            node, depth = pending.pop()
            if node == -1 or depth > max_depth:
                continue
            selected.append(node)
            # Push right first so the left child is visited first.
            pending.append((tree.children_right[node], depth + 1))
            pending.append((tree.children_left[node], depth + 1))

        return selected

    def compute_layout(self, tree, allowed_nodes):
        """Return ``{node_id: position}`` with one centred, evenly spaced row per depth."""
        levels = {}
        pending = [(0, 0)]

        # Group the allowed nodes by depth, left to right.
        while pending:
            node, depth = pending.pop()
            if node == -1 or node not in allowed_nodes:
                continue
            levels.setdefault(depth, []).append(node)
            pending.append((tree.children_right[node], depth + 1))
            pending.append((tree.children_left[node], depth + 1))

        positions = {}
        for depth, row in levels.items():
            width = len(row)
            for i, node in enumerate(row):
                # 2.2 units between neighbours, 1.6 units between rows,
                # top row at y = 3.
                positions[node] = np.array(
                    [(i - (width - 1) / 2) * 2.2, 3 - depth * 1.6, 0]
                )

        return positions

    def create_node(self, text, color=WHITE):
        """Return a ``VGroup`` of (rectangle, caption) for one tree node."""
        frame = Rectangle(width=2.2, height=1.0, color=color, stroke_width=2)

        caption = Text(text, font_size=16)
        # Shrink the caption if needed so it stays inside the rectangle.
        caption.set_max_width(frame.width * 0.9)

        return VGroup(frame, caption)

In [None]:
%%manim -v WARNING -qm WineDecisionTreeAnimation
from manim import *
import numpy as np
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

class WineDecisionTreeAnimation(Scene):
    """Scatter plot of two Wine features with the tree's decision regions and splits.

    A ``max_depth=3`` decision tree is fitted on the first two Wine columns
    (Alcohol and Malic Acid). The scene shows the data points, then a grid of
    translucent cells coloured by predicted class, then one line per split
    threshold, and finally a title.
    """

    def construct(self):
        # ------------------------------------------------
        # 1. Load the dataset and train the tree
        # ------------------------------------------------
        wine = load_wine()
        X = wine.data[:, [0, 1]]   # Alcohol and Malic Acid
        y = wine.target
        model = DecisionTreeClassifier(max_depth=3, random_state=42)
        model.fit(X, y)

        # Observed ranges of both features
        x_min, x_max = X[:, 0].min(), X[:, 0].max()
        y_min, y_max = X[:, 1].min(), X[:, 1].max()

        # ------------------------------------------------
        # 2. Axes (padded by 0.5 on each side)
        # ------------------------------------------------
        axes = Axes(
            x_range=[x_min - 0.5, x_max + 0.5, 1],
            y_range=[y_min - 0.5, y_max + 0.5, 1],
            x_length=9,
            y_length=6,
            axis_config={"include_numbers": True},
        )
        labels = axes.get_axis_labels("Alcohol", "Malic Acid")
        self.play(Create(axes), Write(labels))

        # ------------------------------------------------
        # 3. Data points, coloured by true class
        # ------------------------------------------------
        colors = [RED, GREEN, YELLOW]
        dots = VGroup()
        for xi, yi, label in zip(X[:, 0], X[:, 1], y):
            dots.add(
                Dot(
                    axes.coords_to_point(xi, yi),
                    color=colors[label],
                    radius=0.05,
                )
            )
        self.play(FadeIn(dots))
        self.wait(1)

        # ------------------------------------------------
        # 4. Decision regions (background)
        # ------------------------------------------------
        grid_n = 120  # samples per axis
        xx, yy = np.meshgrid(
            np.linspace(x_min, x_max, grid_n),
            np.linspace(y_min, y_max, grid_n),
        )
        Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

        # linspace includes both endpoints, so neighbouring samples are
        # (max - min) / (grid_n - 1) apart. Cells are sized to that pitch so
        # they tile without gaps (dividing by grid_n made every cell slightly
        # narrower than the spacing between cell centres).
        step_x = (x_max - x_min) / (grid_n - 1)
        step_y = (y_max - y_min) / (grid_n - 1)
        cell_width = axes.x_axis.unit_size * step_x
        cell_height = axes.y_axis.unit_size * step_y

        region_group = VGroup()
        for i in range(grid_n):
            for j in range(grid_n):
                cell = Rectangle(
                    width=cell_width,
                    height=cell_height,
                    fill_color=colors[Z[i, j]],
                    fill_opacity=0.2,
                    stroke_width=0,
                )
                cell.move_to(axes.coords_to_point(xx[i, j], yy[i, j]))
                region_group.add(cell)

        # ------------------------------------------------
        # 5. Split thresholds of the fitted tree
        # ------------------------------------------------
        # NOTE(review): every split is drawn from the axis up to the data
        # maximum, i.e. across the whole plot, not only inside the region of
        # the node it belongs to.
        tree = model.tree_
        split_lines = VGroup()

        for i in range(tree.node_count):
            feature = tree.feature[i]
            threshold = tree.threshold[i]

            if feature == 0:    # split on the x feature -> vertical line
                line = axes.get_vertical_line(
                    axes.coords_to_point(threshold, y_max),
                    color=WHITE,
                )
            elif feature == 1:  # split on the y feature -> horizontal line
                line = axes.get_horizontal_line(
                    axes.coords_to_point(x_max, threshold),
                    color=WHITE,
                )
            else:
                # Any other feature value (e.g. leaf markers) has no split to draw.
                continue

            line.set_stroke(width=3)
            split_lines.add(line)

        # ------------------------------------------------
        # 6. Reveal order: regions, then splits, then points on top
        # ------------------------------------------------
        self.play(FadeIn(region_group), run_time=2)
        self.wait(1)

        self.play(Create(split_lines), run_time=1.5)
        self.wait(1)

        # The dots were added first, so lift them above regions and lines.
        self.bring_to_front(dots)
        self.wait(1)

        # Title
        title = Title("Wine Decision Tree – Regiones de Clasificación")
        self.play(Write(title))
        self.wait(3)

In [None]:
%%manim -v WARNING -qm WineMetricsAndOverfitting
from manim import *
import numpy as np
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

class WineMetricsAndOverfitting(Scene):
    """Plots train vs. test accuracy against tree depth, then shows depth-3 accuracy.

    Uses the first two Wine columns with a 70/30 train/test split
    (``random_state=42``) and decision trees of ``max_depth`` 1 through 12.
    """

    def construct(self):

        # -------------------------------------------
        # 1. Load the dataset and split train/test
        # -------------------------------------------
        wine = load_wine()
        X = wine.data[:, [0, 1]]
        y = wine.target

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42
        )

        # -------------------------------------------
        # 2. Accuracy for each tree depth
        # -------------------------------------------
        depths = list(range(1, 13))
        train_acc = []
        test_acc = []

        for d in depths:
            model = DecisionTreeClassifier(max_depth=d, random_state=42)
            model.fit(X_train, y_train)

            train_acc.append(model.score(X_train, y_train))
            test_acc.append(model.score(X_test, y_test))

        # -------------------------------------------
        # 3. Axes for the overfitting chart
        # -------------------------------------------
        axes = Axes(
            x_range=[1, 12, 1],
            y_range=[0, 1.05, 0.1],
            x_length=9,
            y_length=5,
            axis_config={"include_numbers": True},
        ).to_edge(DOWN)

        labels = axes.get_axis_labels("Depth", "Accuracy")

        self.play(Create(axes), Write(labels))

        # -------------------------------------------
        # 4. Train accuracy curve
        # -------------------------------------------
        train_graph = axes.plot_line_graph(
            x_values=depths,
            y_values=train_acc,
            line_color=GREEN,
            add_vertex_dots=True,
        )

        train_label = Text("Train Accuracy", font_size=28, color=GREEN).to_corner(UL)

        self.play(Create(train_graph), FadeIn(train_label))
        self.wait(1)

        # -------------------------------------------
        # 5. Test accuracy curve
        # -------------------------------------------
        test_graph = axes.plot_line_graph(
            x_values=depths,
            y_values=test_acc,
            line_color=RED,
            add_vertex_dots=True,
        )

        test_label = Text("Test Accuracy", font_size=28, color=RED).next_to(train_label, DOWN)

        self.play(Create(test_graph), FadeIn(test_label))
        self.wait(2)

        # -------------------------------------------
        # 6. Visual explanation of overfitting
        # -------------------------------------------
        # The three captions are placed at fixed chart coordinates; they do
        # not depend on the computed accuracies.
        underfit_text = Text("Underfitting", font_size=32).move_to(axes.coords_to_point(2, 0.3))
        goodfit_text = Text("Good Fit", font_size=32).move_to(axes.coords_to_point(5, 0.8))
        overfit_text = Text("Overfitting", font_size=32).move_to(axes.coords_to_point(10, 0.6))

        self.play(Write(underfit_text))
        self.wait(1)

        self.play(Write(goodfit_text))
        self.wait(1)

        self.play(Write(overfit_text))
        self.wait(2)

        # -------------------------------------------
        # 7. Test accuracy at a fixed depth
        # -------------------------------------------
        best_model = DecisionTreeClassifier(max_depth=3, random_state=42)
        best_model.fit(X_train, y_train)

        # Only the accuracy is displayed; the confusion matrix that used to be
        # computed here was never used in this scene.
        acc = accuracy_score(y_test, best_model.predict(X_test))

        # -------------------------------------------
        # 8. Show the accuracy on screen
        # -------------------------------------------
        acc_text = Text(f"Accuracy (Depth=3): {acc:.2f}", font_size=36)
        acc_text.to_edge(UP)

        self.play(Write(acc_text))
        self.wait(2)


In [None]:
%%manim -v WARNING -qm WineConfusionMatrixDetailed
from manim import *
import numpy as np
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

class WineConfusionMatrixDetailed(Scene):
    """Walk through a 3-class confusion matrix using a one-vs-rest reading.

    A depth-3 decision tree is fitted on the first two Wine features. The scene
    shows the resulting confusion matrix, highlights the TP / FN / FP / TN
    cells for class 0 (displaying the matching counts), and finally lists
    precision, recall and F1-score for class 1.
    """

    @staticmethod
    def _one_vs_rest_counts(cm, k):
        """Return ``(TP, FN, FP, TN)`` for class ``k`` of a square confusion matrix.

        ``cm`` is indexed as ``cm[actual, predicted]``, which is how the scene
        labels the matrix ("Actual Class" on the rows, "Predicted Class" on the
        columns).
        """
        tp = cm[k, k]
        fn = cm[k, :].sum() - tp   # actual k, predicted as something else
        fp = cm[:, k].sum() - tp   # actual something else, predicted as k
        tn = cm.sum() - tp - fn - fp
        return tp, fn, fp, tn

    def construct(self):

        # --------------------------------------------------
        # 1. Dataset and model
        # --------------------------------------------------
        wine = load_wine()
        X = wine.data[:, [0, 1]]
        y = wine.target

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42
        )

        model = DecisionTreeClassifier(max_depth=3, random_state=42)
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)
        n_classes = cm.shape[1]

        # Counts for the class highlighted on screen (class 0). These must come
        # from row/column 0 so the numbers agree with the boxed cells.
        tp_value, FN0, FP0, TN0 = self._one_vs_rest_counts(cm, 0)
        # Counts for class 1, used only by the metrics panel at the end.
        TP1, FN1, FP1, _ = self._one_vs_rest_counts(cm, 1)

        # --------------------------------------------------
        # 2. Show the base matrix
        # --------------------------------------------------
        title = Title("Confusion Matrix – Wine Dataset")
        self.play(Write(title))

        matrix = Matrix(cm).scale(1.2)
        self.play(FadeIn(matrix))
        self.wait(1)

        # --------------------------------------------------
        # 3. Axis labels
        # --------------------------------------------------
        actual_label = Text("Actual Class", font_size=28).next_to(matrix, LEFT)
        predicted_label = Text("Predicted Class", font_size=28).next_to(matrix, UP)

        self.play(Write(actual_label), Write(predicted_label))
        self.wait(1)

        # --------------------------------------------------
        # 4. Analyse ONLY class 0 (visual walkthrough)
        # --------------------------------------------------
        explanation = Text(
            "Analizando Clase 0 (One-vs-Rest)",
            font_size=32,
            color=YELLOW
        ).to_edge(DOWN)

        self.play(Write(explanation))
        self.wait(2)

        entries = matrix.get_entries()

        def cell(row, col):
            # Entries are stored as a flat, row-major group.
            return entries[row * n_classes + col]

        # --------------------------------------------------
        # 5. TP
        # --------------------------------------------------
        tp = cell(0, 0)

        tp_box = SurroundingRectangle(tp, color=GREEN, buff=0.1)

        self.play(Create(tp_box))
        self.wait(1)

        tp_number = Text(
            f"TP Clase 0 = {tp_value}",
            font_size=32,
            color=GREEN
        ).to_edge(DOWN)

        self.play(Transform(explanation, tp_number))
        self.wait(2)

        # --------------------------------------------------
        # 6. FN: rest of row 0
        # --------------------------------------------------
        fn_group = VGroup(cell(0, 1), cell(0, 2))
        fn_box = SurroundingRectangle(fn_group, color=ORANGE, buff=0.1)

        fn_text = Text(
            f"FN (Clase 0 predicha como otra clase) = {FN0}",
            font_size=28,
            color=ORANGE
        ).to_edge(DOWN)

        self.play(Create(fn_box), Transform(explanation, fn_text))
        self.wait(2)

        # --------------------------------------------------
        # 7. FP: rest of column 0
        # --------------------------------------------------
        fp_group = VGroup(cell(1, 0), cell(2, 0))
        fp_box = SurroundingRectangle(fp_group, color=RED, buff=0.1)

        fp_text = Text(
            f"FP (Otra clase predicha como Clase 0) = {FP0}",
            font_size=28,
            color=RED
        ).to_edge(DOWN)

        self.play(Create(fp_box), Transform(explanation, fp_text))
        self.wait(2)

        # --------------------------------------------------
        # 8. TN: everything outside row 0 and column 0
        # --------------------------------------------------
        tn_group = VGroup(cell(1, 1), cell(1, 2), cell(2, 1), cell(2, 2))
        tn_box = SurroundingRectangle(tn_group, color=BLUE, buff=0.1)

        tn_text = Text(
            f"TN (No es Clase 0 y fue correcto) = {TN0}",
            font_size=28,
            color=BLUE
        ).to_edge(DOWN)

        self.play(Create(tn_box), Transform(explanation, tn_text))
        self.wait(3)

        # --------------------------------------------------
        # 9. Metrics for CLASS 1 (one-vs-rest)
        # --------------------------------------------------
        # Guard the denominators so a class that is never predicted (or never
        # present) shows 0.00 instead of nan.
        precision = TP1 / (TP1 + FP1) if (TP1 + FP1) else 0.0
        recall = TP1 / (TP1 + FN1) if (TP1 + FN1) else 0.0
        f1 = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall)
            else 0.0
        )

        metrics = VGroup(
            Text("Métricas Clase 1", font_size=34, color=YELLOW),
            Text(f"Precision = {precision:.2f}", font_size=30),
            Text(f"Recall = {recall:.2f}", font_size=30),
            Text(f"F1-score = {f1:.2f}", font_size=30),
        ).arrange(DOWN).to_edge(RIGHT)

        self.play(FadeIn(metrics))
        self.wait(4)


# Random Forest

In [56]:
%%manim -v WARNING -qk SquareDecisionTree
from manim import *
import numpy as np

class SquareDecisionTree(Scene):
    """Contrast a tiny decision tree with two randomly grown deep trees.

    The scene first shows a three-node tree, swaps it for a deep tree built
    from one random seed, and then morphs that tree into another one built
    from a different seed.
    """

    def construct(self):
        # -------------------------------
        # 1) SIMPLE TREE
        # -------------------------------
        large = dict(fs=24, _w=2, h=1)
        root = self.create_node("?", YELLOW, **large)
        left = self.create_node("Verde", GREEN, **large)
        right = self.create_node("Rojo", RED, **large)

        root.shift(UP * 1)
        left.shift(DOWN * 1.5 + LEFT * 3)
        right.shift(DOWN * 1.5 + RIGHT * 3)

        branches = [self.connect(root, leaf) for leaf in (left, right)]
        simple_tree = VGroup(root, left, right, *branches)

        self.play(FadeIn(simple_tree))
        self.wait(3)

        # -------------------------------
        # 2) DEEP TREE (UNSTABLE A)
        # -------------------------------
        np.random.seed(67)
        tree_a = self.unstable_deep_tree(structure_bias=0.15)

        self.play(FadeOut(simple_tree), FadeIn(tree_a))
        self.wait(3)

        # -------------------------------
        # 3) DIFFERENT SEED -> UNSTABLE TREE B
        # -------------------------------
        np.random.seed(666)
        tree_b = self.unstable_deep_tree(structure_bias=0.10)

        self.play(Transform(tree_a, tree_b), run_time=1)
        self.wait(4)

    # --------------------------------------------------
    # HELPERS
    # --------------------------------------------------

    def create_node(self, text, color, fs =8,_w = 0.3, h = 0.2):
        """Return a VGroup holding a rectangle and a text label centred on it."""
        frame = Rectangle(width=_w, height=h, color=color, stroke_width=2)
        caption = Text(text, font_size=fs).move_to(frame.get_center())
        return VGroup(frame, caption)

    def connect(self, parent, child):
        """Return a white line from the bottom of ``parent`` to the top of ``child``."""
        start_point = parent.get_bottom()
        end_point = child.get_top()
        return Line(start_point, end_point, color=WHITE)

    # --------------------------------------------------
    # UNSTABLE TREE (NO FIXED DEPTH)
    # --------------------------------------------------

    def unstable_deep_tree(self, structure_bias=0.3):
        """Build a random binary tree and return it as one VGroup.

        Each branch becomes a leaf with probability ``structure_bias``;
        branches that reach the last level always become leaves. Leaf colour
        (green/red) is drawn at random. The returned group contains all nodes
        followed by all connecting lines.
        """
        depth_limit = 5
        level_gap = 1.2
        root_spread = 3.2

        boxes = []
        edges = []

        top = self.create_node("?", YELLOW)
        top.move_to([0, 3, 0])
        boxes.append(top)

        def grow(parent, level, spread):
            if level >= depth_limit:
                return

            last_level = level == depth_limit - 1

            for direction in (-1, 1):
                # The stop draw happens for every branch, even on the last
                # level, so the random sequence stays the same.
                ends_here = np.random.rand() < structure_bias

                position = [
                    parent.get_x() + direction * spread,
                    parent.get_y() - level_gap,
                    0,
                ]

                is_leaf = ends_here or last_level
                if is_leaf:
                    if np.random.rand() > 0.5:
                        caption, tint = "Verde", GREEN
                    else:
                        caption, tint = "Rojo", RED
                else:
                    caption, tint = "?", YELLOW

                # Leaves and internal nodes share the default node size.
                child = self.create_node(caption, tint)
                child.move_to(position)
                boxes.append(child)
                edges.append(self.connect(parent, child))

                if not is_leaf:
                    grow(child, level + 1, spread * 0.5)

        grow(top, 0, root_spread)

        return VGroup(*boxes, *edges)


                                                                                                                       

In [57]:
%%manim -v WARNING -qk Gene
from manim import *

class Gene(Scene):
    """Show the term "Generalización" with its definition, then fade both out."""

    def construct(self):
        heading = Text("Generalización", font_size=50, color=YELLOW)
        heading.shift(UP * 1.3, LEFT * 4.3)

        definition = Text(
            "La capacidad de un modelo para hacer buenas\n"
            " predicciones sobre datos nuevos que no ha visto antes.",
            font_size=38,
            color=WHITE,
        )

        shown = (heading, definition)
        self.play(FadeIn(*shown))
        self.wait(4)
        self.play(FadeOut(*shown))

                                                                                                                       

In [58]:
%%manim -v WARNING -qk RandomForestScene
from manim import *
import numpy as np
import random

class RandomForestScene(Scene):
    """Illustrate ensembling, bagging, feature sub-sampling and majority voting.

    Five random trees are laid out side by side; each one then receives a
    small stack of dataset rows, the rows are reduced to a random subset of
    feature columns, and finally one coloured vote per tree is shown.
    """

    def construct(self):
        # Reproducibility: the trees use numpy's RNG, row counts use `random`.
        np.random.seed(0)
        random.seed(0)

        # Horizontal slots for the 5 trees.
        self.x_positions = np.linspace(-6, 6, 5)

        def place(tree, x):
            # Shrink the tree and centre it on its slot, slightly above middle.
            return tree.scale(0.50).move_to([x, 0.5, 0])

        def banner(caption):
            return Text(caption, font_size=40).to_edge(UP)

        # 1) A single tree on the far left; it stays in place.
        tree0 = place(self.unstable_deep_tree(structure_bias=0.4), self.x_positions[0])
        self.play(FadeIn(tree0))
        self.wait(1)

        # 2) Ensembling: four more trees appear, for 5 in total.
        trees = [tree0]
        for slot_x in self.x_positions[1:]:
            trees.append(place(self.unstable_deep_tree(structure_bias=0.35), slot_x))

        newcomers = [FadeIn(tree) for tree in trees[1:]]
        self.play(LaggedStart(*newcomers, lag_ratio=0.18))
        self.wait(1)

        subtitle = banner("Ensembling")
        self.play(FadeIn(subtitle))
        self.wait(3)

        # 3) Bagging: 2-5 rows per tree, every stack starting at the same height.
        self.play(Transform(subtitle, banner("Bagging")))
        self.wait(0.4)

        top_y = -2            # height of the first row under every tree
        row_spacing = 0.33    # vertical gap between rows of one tree

        tree_rows = []   # one VGroup of rows per tree, reused in step 4
        all_rows = []    # flat list, used for the appearance animation

        for slot_x, _tree in zip(self.x_positions, trees):
            n_rows = random.randint(2, 5)
            stack = VGroup(*[
                self.dataset_row().move_to([slot_x, top_y - j * row_spacing, 0])
                for j in range(n_rows)
            ])
            all_rows.extend(stack)
            tree_rows.append(stack)

        self.play(LaggedStart(*[FadeIn(row) for row in all_rows], lag_ratio=0.02))
        self.wait(1)

        # 4) Random Forest: each tree keeps only a random subset of features.
        self.play(Transform(subtitle, banner("Random Forest")))
        self.wait(5)

        for stack in tree_rows:
            # One subset (2 or 3 of the 4 columns) shared by all rows of the tree.
            k = random.choice([2, 3])
            selected = sorted(list(np.random.choice(4, k, replace=False)))

            for row in stack:
                kept = [row[i].copy() for i in selected]
                reduced = VGroup(*kept)
                reduced.arrange(RIGHT, buff=0.05)
                # Keep the reduced row centred where the full row was.
                reduced.move_to(row.get_center())
                self.play(Transform(row, reduced), run_time=0.1)

        self.wait(3)

        # 5) Voting: first three trees vote green, last two vote red.
        vote_y = -1.4
        votes = VGroup()
        for i, slot_x in enumerate(self.x_positions):
            tint = GREEN if i < 3 else RED
            votes.add(Dot(point=[slot_x, vote_y, 0], radius=0.25, color=tint))

        self.play(FadeIn(votes))
        self.wait(0.4)

        final_text = Text("Majority vote: Verde", font_size=28)
        final_text.to_edge(DOWN).shift(UP * 0.3)  # just above the bottom edge
        self.play(Write(final_text))
        self.wait(2)

    # ------------------------
    # Helpers
    # ------------------------
    def create_node(self, text, color, fs=3, w=0.5, h=0.5):
        """Return a VGroup holding a rectangle and a text label centred on it."""
        frame = Rectangle(width=w, height=h, color=color, stroke_width=2)
        caption = Text(text, font_size=fs).move_to(frame.get_center())
        return VGroup(frame, caption)

    def connect(self, parent, child):
        """Return a white line from the bottom of ``parent`` to the top of ``child``."""
        start_point = parent.get_bottom()
        end_point = child.get_top()
        return Line(start_point, end_point, color=WHITE)

    def unstable_deep_tree(self, structure_bias=0.3):
        """
        Build an "unstable" random binary tree and return it as one VGroup.

        ``structure_bias`` is the probability that a branch ends early as a
        leaf; branches on the last level always end as leaves. The returned
        group contains all nodes followed by all connecting lines.
        """
        depth_limit = 5
        level_gap = 0.8
        root_spread = 1.0
        small = dict(fs=4, w=0.2, h=0.1)   # size shared by every non-root node

        boxes = []
        edges = []

        top = self.create_node("?", YELLOW, fs=26)
        top.move_to([0, 2.6, 0])
        boxes.append(top)

        def grow(parent, level, spread):
            if level >= depth_limit:
                return

            last_level = level == depth_limit - 1

            for direction in (-1, 1):
                # The stop draw happens for every branch, even on the last
                # level, so the random sequence stays the same.
                ends_here = np.random.rand() < structure_bias

                position = [
                    parent.get_x() + direction * spread,
                    parent.get_y() - level_gap,
                    0,
                ]

                is_leaf = ends_here or last_level
                if is_leaf:
                    if np.random.rand() > 0.5:
                        caption, tint = "Verde", GREEN
                    else:
                        caption, tint = "Rojo", RED
                else:
                    caption, tint = "?", YELLOW

                child = self.create_node(caption, tint, **small)
                child.move_to(position)
                boxes.append(child)
                edges.append(self.connect(parent, child))

                if not is_leaf:
                    grow(child, level + 1, spread * 0.5)

        grow(top, 0, root_spread)

        return VGroup(*boxes, *edges)

    def dataset_row(self):
        """Return a VGroup of 4 filled rectangles (features), one colour each."""
        shape = dict(width=0.35, height=0.18, fill_opacity=0.95, stroke_width=0)
        features = VGroup()
        for tint in (BLUE, ORANGE, PURPLE, GREEN):
            features.add(Rectangle(fill_color=tint, **shape))
        features.arrange(RIGHT, buff=0.05)
        return features


                                                                                                                       