In [None]:
# DSU

# Detect cycle in a graph using disjoint sets (union find) - UNDIRECTED GRAPH ONLY AND NO SELF-LOOPS
# assumes no self loops

# Note: while taking union, we have already checked that they had different parents.

def get_parent(x, parents):

    # the boundary is the parent always
    if parents[x] == -1:
        return x

    else:
        return get_parent(parents[x], parents)

def contains_cycle(graph):

    # start with singleton sets
    parents = [-1] * len(graph)

    for x in graph:
        for y in graph[x]:
            print("Parents: ", parents)
            print("Edge: ", x, y)

            # what if A returns 3 because 3 was its own parent and y's parent was 
            parent_x = get_parent(x, parents)
            parent_y = get_parent(y, parents)
            
            if parent_x == parent_y:
                return True

            # take union
            parents[parent_x] = parent_y

    return False

graph = {
    
    0 : [0, 1],
    1 : [0]
}

contains_cycle(graph)

In [None]:
# A union by rank and path compression based program to detect cycle in a graph

from collections import defaultdict

# a structure to represent a graph

class Graph:
	def __init__(self, v):
		self.v = v
		self.edges = defaultdict(list)

	def add_edge(self, u, v):
		self.edges[u].append(v)

class Subset:
	def __init__(self, parent, rank):
		self.parent = parent
		self.rank = rank

# A utility function to find set of an element
# node (uses path compression technique)
def find(subsets, node):
	if subsets[node].parent != node:

		# recursive call. such elegance. essentially, we are saying that 
		# all children, grand-children and stuff, get directly connected to the
		# great-grand grand father. (children are although arleady assigned but get re-assigned)
		subsets[node].parent = find(subsets, subsets[node].parent)
	return subsets[node].parent

# A function that does union of two sets
# of u and v(uses union by rank)

def union(subsets, u, v):

	# Attach smaller rank tree under root of high rank tree (Union by Rank)

	# no change to rank. if while doing union, one set has bigger rank, that becomes the parent.
	if subsets[u].rank > subsets[v].rank:
		subsets[v].parent = u
	elif subsets[v].rank > subsets[u].rank:
		subsets[u].parent = v

	# If ranks are same, then make one as root and increment its rank by one
	else:
		subsets[v].parent = u
		# this is the only case when rank is increased
		subsets[u].rank += 1

# The main function to check whether a given
# graph contains cycle or not
def isCycle(graph):

	# Allocate memory for creating sets
	subsets = []

	# singleton sets equal to number of vertices with 0 rank for init
	# and parents of themselves
	for u in range(graph.v):
		subsets.append(Subset(u, 0))

	# Iterate through all edges of graph,
	# find sets of both vertices of every
	# edge, if sets are same, then there
	# is cycle in graph.
	for u in graph.edges:
		u_parent = find(subsets, u)

		for v in graph.edges[u]:
			v_parent = find(subsets, v)

			if u_parent == v_parent:
				return True
			else:
				# not changing list but just changing attribute of objects within the list
				union(subsets, u_parent, v_parent)

# Driver Code
g = Graph(3)

g.add_edge(0, 1)
g.add_edge(1, 2)
g.add_edge(0, 2)

isCycle(g)

In [49]:
# graph

from collections import deque

# note that s
graph = {
    0: [1, 2, 3],
    1: [0, 2],
    2: [0, 1, 4],
    3: [0],
    4: [2]
}

def bfs(graph, root):
    visited = set([root])
    queue = deque([root])

    while queue:
        popped = queue.popleft()
        print(popped, end=' ')
        
        for neighbour in graph[popped]:
            if (neighbour not in visited):
                visited.add(neighbour)
                queue.append(neighbour)

def dfs(graph, root):

    def helper(node, visited):
        if node in visited:
            return
        
        print(node, end=' ')
        visited.add(node)

        for neib in graph[node]:
            if neib not in visited:
                helper(neib, visited)

    
    visited = set()
    helper(root, visited)


bfs(graph, 0)
print('------')
dfs(graph, 0)

0 1 2 3 4 --------
0 1 2 4 3 

In [91]:
# "overly" complicated class implementation of graphs and dfs/bfs
# just revise bfs and dfs from here

# class Vertex:
#     def __init__(self, name):
#         self.name = name
#         self.neighbours = set()

#     def add_neighbour(self, v):
#         self.neighbours.add(v)
#             # why sort?
#             # self.neighbours.sort()

from collections import deque, defaultdict

class Graph:

    def __init__(self, v=None, adj={}):
        self.vertices = v
        self.adj = adj

    # def add_vertex(self, vertex):
    #     if isinstance(vertex, Vertex) and vertex.name not in self.vertices:
    #         self.vertices[vertex.name] = vertex

    def create_adj_from_list_of_edges(self, edges):

        #self.adj = {i : [] for i in range(n)}
        self.adj = defaultdict(list)

        for edge in edges:
            self.adj[edge[0]].append(edge[1])
            self.adj[edge[1]].append(edge[0])

    def add_edge(self, u, v):
#        if u in self.vertices and v in self.vertices:
        self.vertices[u].add_neighbour(v)
        self.vertices[v].add_neighbour(u)

    def find_bfs_disconnected(self, v):
        
        visited = set()
        queue = deque()

        for start in range(self.vertices):
            if start in visited:
                continue

            print('Fresh start...')

            visited.add(start)
            queue.append(start)

            while(queue):
                popped = queue.popleft()
                print(popped, visited)

                if popped == v:
                    return True

                for neighbour in self.adj[popped]:
                    if neighbour not in visited:
                        print('adding: ', neighbour)
                        visited.add(neighbour)
                        queue.append(neighbour)

        return False

    def find_bfs(self, v):
        
        root = 0
        visited = set([root])
        queue = deque([root])

        while(queue):
            popped = queue.popleft()
            print(popped, visited)

            if popped == v:
                return True

            for neighbour in self.adj[popped]:
                if neighbour not in visited:
                    print('adding: ', neighbour)
                    visited.add(neighbour)
                    queue.append(neighbour)

        return False

    def find_dfs_disconnected(self, v):

        visited = set()

        for start in range(self.vertices):
            if start not in visited:
                print('Fresh start...')
                if self.find_dfs_helper(start, v, visited):
                    return True

        return False

    def find_dfs(self, v):
        visited = set()
        return self.find_dfs_helper(0, v, visited)

    def find_dfs_helper(self, start, v, visited):

        if start in visited:
            print("Already visited.")
            return False

        if start == v:
            return True

        visited.add(start)
        print("Visiting ", start, " \t\t\t | Visited: ", visited)

        for neighbour in self.adj[start]:
            print("Neighbour: ", neighbour)
            if self.find_dfs_helper(neighbour, v, visited):
                return True

        print("Nothing Found In Vicinity of ", start)

        return False

    def find_dfs_stack_disconnected(self, v):

        stack = []
        visited = set()

        for start in range(self.vertices):
            
            if start in visited: continue

            print('Fresh start...')

            stack.append(start)

            while stack:
                popped = stack.pop()
                print("Popped: ", popped)

                if popped == v:
                    return True
                    
                for neighbour in self.adj[popped]:
                    if neighbour not in visited:
                        print("Neighbour: ", neighbour)
                        visited.add(neighbour)
                        stack.append(neighbour)

        return False


# in the form adjacency list. then may not need vertex class but looks like a good practice to have separate Vertex class
# overly complicated. just ignore.
# graph = Graph()
# graph.add_vertex(Vertex(0))
# graph.add_vertex(Vertex(1))
# graph.add_vertex(Vertex(2))
# graph.add_vertex(Vertex(3))
# graph.add_vertex(Vertex(4))

# edges = ['01', '02', '03', '12', '24']

# for edge in edges:
#     graph.add_edge(int(edge[0]), int(edge[1]))


print('--------------------------')
edges = [[1, 0], [2, 0], [1, 2]]
graph = Graph()
graph.create_adj_from_list_of_edges(edges)
print(graph.adj)

adj = {
    0 : [1, 2, 3],
    1 : [0, 2],
    2 : [0, 1, 4],
    3 : [0],
    4 : [2]
}

adj_disconnected = {
    0 : [1, 2, 3],
    1 : [0, 2],
    2 : [0, 1],
    3 : [0],
    4 : []
}

graph = Graph(5, adj)
# print(graph.find_bfs(4))
# print('---------------')

# print(graph.find_dfs(5))
# print('---------------')

graph = Graph(5, adj_disconnected)

print(graph.find_bfs_disconnected(4))
print('---------------')
print(graph.find_dfs_disconnected(4))
print('---------------')
print(graph.find_dfs_stack_disconnected(4))
print('---------------')

Fresh start...
0 {0}
adding:  1
adding:  2
adding:  3
1 {0, 1, 2, 3}
2 {0, 1, 2, 3}
3 {0, 1, 2, 3}
Fresh start...
4 {0, 1, 2, 3, 4}
True
---------------
Fresh start...
Visiting  0  			 | Visited:  {0}
Neighbour:  1
Visiting  1  			 | Visited:  {0, 1}
Neighbour:  0
Already visited.
Neighbour:  2
Visiting  2  			 | Visited:  {0, 1, 2}
Neighbour:  0
Already visited.
Neighbour:  1
Already visited.
Nothing Found In Vicinity of  2
Nothing Found In Vicinity of  1
Neighbour:  2
Already visited.
Neighbour:  3
Visiting  3  			 | Visited:  {0, 1, 2, 3}
Neighbour:  0
Already visited.
Nothing Found In Vicinity of  3
Nothing Found In Vicinity of  0
Fresh start...
True
---------------
Fresh start...
Popped:  0
Neighbour:  1
Neighbour:  2
Neighbour:  3
Popped:  3
Neighbour:  0
Popped:  0
Popped:  2
Popped:  1
Fresh start...
Popped:  4
True
---------------


In [94]:
# Detect cycle in an undirected graph using DFS

# graph = {
#     0 : [1, 2],
#     1 : [0, 3],
#     2 : [0, 3],
#     3 : [1, 2]
# }

# graph = {
#     0 : [1],
#     1 : [0, 1]
# }

graph = {
    0 : [1],
    1 : [0],
    2 : [3, 4],
    3 : [2, 4],
    4 : [3, 2]
}

def helper(v, visited, parent):
    print('v, parent', v, parent)
    if v in visited:
        return True

    visited.add(v)
    
    for neib in graph[v]:
        
        if neib != parent:
            print('neib', neib)
            if helper(neib, visited, v):
                return True
        else:
            print('neib', neib, 'is the parent so ignore.')

    return False
        
def detect_cycle_dfs(graph):
    visited = set()

    # I think this for loop is only needed for disconnected graphs
    for v in graph:
        print('--------------')
        print('new start: ', v)
        if v in visited:
            continue

        if helper(v, visited, -1):
            return True

    return False

print(detect_cycle_dfs(graph))

--------------
new start:  0
v, parent 0 -1
neib 1
v, parent 1 0
neib 0 is the parent so ignore.
--------------
new start:  1
--------------
new start:  2
v, parent 2 -1
neib 3
v, parent 3 2
neib 2 is the parent so ignore.
neib 4
v, parent 4 3
neib 3 is the parent so ignore.
neib 2
v, parent 2 4
True


In [None]:
# here a list of edges results in a simpler implementation

class Subset:
    def __init__(self, node, rank):
        self.parent = node
        self.rank = 0

class Graph:
    def __init__(self, v):
        self.v = v
        self.edges = []

    def add_edge(self, u, v, c):
        self.edges.append((u, v, c))

    def remove_self_loops(self):
        for u, v, c in self.edges:
            if u == v:
                self.edges.remove((u, v, c))

    def remove_parallel_paths(self):
        # based on cost
        visited = {}
        edges = self.edges[:]

        for u, v, c in edges:
            #print('u,v,c:', u,v,c)
            if (u, v) in visited:
                if c < visited[(u, v)]:
                    #print("Here: ", self.edges, 'removing: ', (u, v, visited[(u, v)]))
                    visited[(u, v)] = c
                    self.edges.remove((u, v, visited[(u, v)]))
                else:
                    #print("Here: ", self.edges, 'removing: ', (u, v, c))
                    self.edges.remove((u, v, c))
                    
            elif (v, u) in visited:
                if c < visited[(v, u)]:
                    # if had to reverse, then this means we'll need reverse again while removing
                    visited[(v, u)] = c
                    self.edges.remove((u, v, visited[(v, u)]))
                else:
                    #print("Here: ", self.edges, 'removing: ', (u, v, c))
                    self.edges.remove((u, v, c))

            else:
                visited[(u, v)] = c

    def find_parent(self, parents, node):
        if parents[node] == -1:
            return node
        else:
            return self.find_parent(parents, parents[node])

    def mst_kruskals(self):
        self.remove_self_loops()
        self.remove_parallel_paths()

        self.edges = sorted(self.edges, key=lambda x: x[2])
        
        results = []
        total_cost = 0
        
        parents = [-1] * self.v

        for u, v, c in self.edges:
            parent_u = self.find_parent(parents, u)
            parent_v = self.find_parent(parents, v)

            if parent_u == parent_v:
                #print('Avoiding cycle...')
                continue
            
            results.append((u, v))
            total_cost += c
            parents[parent_u] = parent_v
            #print('Parents: ', parents)
            
        return results, total_cost

    def find_path_compression(self, parents, x):
        # condition to essentially detect non-roots
        if parents[x].parent != x:
            parents[x].parent = self.find_path_compression(parents, parents[x].parent)

        # basically all the children will now directly deal with their gread-grand-grand father
        return parents[x].parent    

    def union_rank(self, parents, u, v):
        if parents[u].rank == parents[v].rank:
            parents[u].rank += 1
            parents[v].parent = u

        elif parents[u].rank > parents[v].rank:
            parents[v].parent = u
        else:
            parents[u].parent = v

    def mst_kruskals_optimized(self):
        # path compression and union by rank
        self.remove_self_loops()
        self.remove_parallel_paths()

        self.edges = sorted(self.edges, key=lambda x: x[2])
        
        results = []
        total_cost = 0
        parents = [Subset(i, 0) for i in range(self.v)]

        for u, v, c in self.edges:
            parent_u = self.find_path_compression(parents, u)
            parent_v = self.find_path_compression(parents, v)

            if parent_u == parent_v:
                #print('Avoiding cycle...')
                continue
            
            results.append((u, v))
            total_cost += c
            self.union_rank(parents, parent_u, parent_v)
            #print('Parents: ', parents)
            
        return results, total_cost

g = Graph(4)
g.add_edge(0, 1, 10)

# self-loop
g.add_edge(0, 0, 4)

g.add_edge(0, 2, 6)
g.add_edge(0, 3, 5)
g.add_edge(1, 3, 15)
g.add_edge(2, 3, 4)
# parallel path
g.add_edge(2, 3, 4)

#print(g.edges)
print(g.mst_kruskals())
print(g.mst_kruskals_optimized())

In [None]:
# Prim's using Heap

import heapq as hq
from heapq import heappush, heappop

def minimum_spanning_tree_cost(graph):
    """Return the sum of the costs of the edges in the minimum spanning
    tree for the given graph, which must be a mapping from nodes to an
    iterable of (neighbour, edge-cost) pairs.
    """
    total = 0                   # Total cost of edges in tree
    start = next(iter(graph))   # Arbitrary starting vertex # THE HIGHLIGHT OF THE CODE
    unexplored = [(0, 'start', start)]   # Unexplored edges ordered by cost
    explored = set()            # Set of vertices in tree    
    result = []

    # I think it's easier talking about what this chunk of code is doing overall.
    # It's a standard BFS so we visit each vertex and explore all edges (O(V + E)).
    # Everytime we explore an edge, we heappush which is log. Every time we heappop
    # while visiting a node, that is log. Since we push all edges in the heap
    # the log is gonna be log E. so O(V log E + E log E)
    # Space looks to be O(E) for the heap and O(V) for the explored set()

    while unexplored: # O(n)
        cost, src, dest = heappop(unexplored) # O(log(n))
        if dest not in explored:
            explored.add(dest)
            result.append((src, dest, cost))
            total += cost
            for neib, cost in graph[dest]: # O(V) if this is connected to every other node
                if neib not in explored:
                    heappush(unexplored, (cost, dest, neib)) #

    return result, total

graph = {
    'S': [('C', 8), ('A', 7)],
    'A': [('S', 7), ('C', 3), ('B', 6), ('B', 99)],
    'C': [('S', 8), ('B', 4), ('A', 3), ('D', 3)],
    'B': [('A', 6), ('C', 4), ('D', 2), ('T', 5), ('A', 99)],
    'D': [('C', 3), ('B', 2), ('T', 2)],
    'T': [('B', 5), ('D', 2)]
}

#print(prim_heap(graph))
minimum_spanning_tree_cost(graph)

In [20]:
# djiktra's shortest path (can work on both directed and undirected graphs)

from queue import PriorityQueue
import heapq as hq
from heapq import heappush, heappop
from collections import defaultdict

class Graph:

    def __init__(self, v):
        self.v = v
        self.graph = defaultdict(list)

    def addEdge(self, u, v, w):
        self.graph[u].append((v, w))
        self.graph[v].append((u, w))

    def dijkstra_adj_mat(self, source):
        visited = set()
        unvisited = set(range(self.v))

        distances = [float('inf')] * self.v
        distances[source] = 0
        prev = [-1] * self.v
        
        q = PriorityQueue()
        q.put((0, source))

        while unvisited: # O(V)
            _, node = q.get() # O(log(heap size))

            if node in visited:
                continue

            unvisited.remove(node)
            visited.add(node)
 
            # benefit of 2d adj mat was that we could enumerate easily
            for i, neib_dist in enumerate(self.graph[node]):
                # == 0 condition coz of adj. matrxi structure
                if neib_dist == 0 or i in visited:
                    #print('Continue...')
                    continue

                if neib_dist + distances[node] < distances[i]:
                    distances[i] = neib_dist + distances[node]
                    prev[i] = node

                    # add "potential" candidate in the queue
                    q.put((distances[i], i)) # O(log(heap size))
            
            # O((E+1) log(heap size))
            # Heap size -> VE -> V^2
            # O(Elog(V))

        return list(zip(distances, prev))

    def djikstras(self, source):
        # there are some implemetations that do decrease operation so their
        # heap never has duplicates. so the complexity is determined accordingly.
        q = [(0, source)] 
        visited = set()
        # (cost, parent) tuple dict
        distances = {i : (float('inf'), 'Undefined') for i in range(self.v)}
        distances[source] = (0, -1)

        while len(visited) != self.v:
            _, node = heappop(q)

            if node in visited:
                continue
            
            visited.add(node)

            for neib, cost in self.graph[node]:
                if neib not in visited:
                    if distances[neib][0] > cost + distances[node][0]:
                        distances[neib] = (cost + distances[node][0], node)
                        heappush(q, (distances[neib], neib))

        return distances

graph = Graph(9)
graph.addEdge(0, 1, 4)
graph.addEdge(0, 7, 8)
graph.addEdge(1, 2, 8)
graph.addEdge(1, 7, 11)
graph.addEdge(2, 3, 7)
graph.addEdge(2, 8, 2)
graph.addEdge(2, 5, 4)
graph.addEdge(3, 4, 9)
graph.addEdge(3, 5, 14)
graph.addEdge(4, 5, 10)
graph.addEdge(5, 6, 2)
graph.addEdge(6, 7, 1)
graph.addEdge(6, 8, 6)
graph.addEdge(7, 8, 7)
#print(graph.graph)
print(graph.djikstras(0))

{0: (0, -1), 1: (4, 0), 2: (12, 1), 3: (19, 2), 4: (21, 5), 5: (11, 6), 6: (9, 7), 7: (8, 0), 8: (14, 2)}


In [2]:
# binary search tree

class Node:
    def __init__(self, data):
        self.data = data
        self.right = None
        self.left = None

# might just be better creating a separate class

class Tree:
    def __init__(self):
        self.root = None

    def insert_recur(self, data):

        if not self.root:
            self.root = Node(data)
            return self.root

        self.insert_recur_helper(self.root, data)

    def insert_recur_helper(self, node, data):

        if not node:
            return Node(data)

        if data > node.data:
            node.right = self.insert_recur(node.right, data)

        elif data < node.data:
            node.left = self.insert_recur(node.left, data)

        return node
        
    def insert(self, data):
        if not self.root:
            #print("Initializing tree...")
            self.root = Node(data)
            return self.root

        current = self.root

        while current:
            #print("data: ", data, "current: ", current.data)
            if current.data > data:
                if current.left:
                    current = current.left
                else:
                    #print("Saved data to the left of", current.data)
                    current.left = Node(data)
                    break
            else:
                if current.right:
                    current = current.right
                else:
                    ##print("Saved data to the right of", current.data)
                    current.right = Node(data)
                    break

    def contains(self, root, data):
        if root.data == data:
            return True

        else:
            if data < root.data:
                if root.left:
                    return self.contains(root.left, data)
                else:
                    return False

            else:
                if root.right:
                    return self.contains(root.right, data)
                else:
                    return False

    def find(self, root, data):
        if root.data == data:
            return root

        else:
            if data < root.data:
                if root.left:
                    return self.find(root.left, data)
                else:
                    return None

            else:
                if root.right:
                    return self.find(root.right, data)
                else:
                    return None

    def in_order_print(self, root):
        if root:
            self.in_order_print(root.left)
            print(root.data)
            self.in_order_print(root.right)

    def traverse(self, mode, root):

        if root:
            if mode == 1:
                self.traverse(mode, root.left)
                print(root.data)
                self.traverse(mode, root.right)
            elif mode == 2:
                print(root.data)
                self.traverse(mode, root.left)
                self.traverse(mode, root.right)

            elif mode == 3:
                self.traverse(mode, root.left)
                self.traverse(mode, root.right)
                print(root.data)

            else:
                raise("Invalid mode selected.")

    # this was when I wasn't able to write a recursive function and wanted to do recursive.
    # the "obvious" method to do this without recursion is stack.
    # def in_order_print(self):
    #     pass
        # traverse in order

        # 1) Create an empty stack S.
        # 2) Initialize current node as root
        # 3) Push the current node to S and set current = current->left until current is NULL
        # 4) If current is NULL and stack is not empty then
        #      a) Pop the top item from stack.
        #      b) Print the popped item, set current = popped_item->right
        #      c) Go to step 3.
        # 5) If current is NULL and stack is empty then we are done.


    def minValueNode(self, node):
        current = node

        # loop down to find the leftmost leaf
        while current.left:
            current = current.left

        return current

    def deleteNode(self, root, data):
        if not root:
            return None

        if data < root.data:
            root.left = self.deleteNode(root.left, data)

        elif data > root.data:
            root.right = self.deleteNode(root.right, data)

        else:
            # Node with only right child or no child
            if root.left is None:
                temp = root.right
                root = None  # i guess just to demonstrate
                return temp  # deleting is essentially returning child

            # node with only left child
            elif root.right is None:
                temp = root.left
                root = None  # i guess just to demonstrate
                return temp  # deleting is essentially returning child

            # Node with two children:
            # Get the inorder successor
            # (smallest in the right subtree)
            temp = self.minValueNode(root.right)
            root.data = temp.data  # deleting is essentially returning successor
            root.right = self.deleteNode(root.right, temp.data)

        return root

tree = Tree()
root = tree.insert(5)
tree.insert(4)
tree.insert(3)
tree.insert(2)
tree.insert(1)

print("----------------------------")
tree.traverse(mode=1, root=tree.root)
print("----------------------------")
tree.traverse(mode=2, root=tree.root)
print("----------------------------")
tree.traverse(mode=3, root=tree.root)

print(tree.find(root, 3))
tree.deleteNode(root, 3)

print("----------------------------")
tree.traverse(mode=1, root=tree.root)

print(tree.contains(root, 3))

In [None]:
# Python code to insert a node in AVL tree

""""

"You may have to rebalance during insertions, that much is clear: inserting a sorted sequence of values would otherwise lead to a degenerate tree.

By looking at the setup graphics of the four types of rotation, you can easily see that the height of the affected subtree after the insertion is the same as before. Thus, no node outside of this subtree can be imbalanced.

Thus, we never need more than one rotation when inserting. We can stop moving up the tree after we've rotated once, and efficient (non-recursive) implementations do exactly that."

"""

# Generic tree node class

class AVLNode(object):
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.height = 1

class AVLTree(object):
    # Recursive function to insert key in
    # subtree rooted with root and returns
    # new root of subtree.
    def insert(self, root, key):

        # Step 1 - Perform normal BST
        if not root:
            return AVLNode(key)

        elif key < root.val:
            root.left = self.insert(root.left, key)

        else:
            root.right = self.insert(root.right, key)

        # Step 2 - Update the height of the ancestor node
        root.height = 1 + max(self.getHeight(root.left),
                              self.getHeight(root.right))

        # Step 3 - Get the balance factor
        balance = self.getBalance(root)

        # Since we are going upwards after inserting 
        # the new key, if the tree is become imbalanced,
        # we'll eventually find the first subtree that 
        # is imbalanced and fixing to its previous
        # balanced state will balance up the whole tree. 

        # Step 4 - If the node is unbalanced,
        # then try out the 4 cases

        # + ive balance clearly means the left subtree's height is more. 
        # the question then becomes, which case is it and that can be determined
        # by getting a sense of direction that where was the new node inserted.

        # Case 1 - Left Left
        if balance > 1 and key < root.left.val:
            return self.rightRotate(root)

        # Case 2 - Right Right
        if balance < -1 and key > root.right.val:
            return self.leftRotate(root)

        # Case 3 - Left Right
        if balance > 1 and key > root.left.val:
            root.left = self.leftRotate(root.left)
            return self.rightRotate(root)

        # Case 4 - Right Left
        if balance < -1 and key < root.right.val:
            root.right = self.rightRotate(root.right)
            return self.leftRotate(root)

        return root

    def leftRotate(self, z):
        y = z.right
        T2 = y.left

        # Perform rotation
        y.left = z
        z.right = T2

        # Update heights
        z.height = 1 + max(self.getHeight(z.left),
                           self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left),
                           self.getHeight(y.right))

        # Return the new root
        return y

    def rightRotate(self, z):
        y = z.left
        T3 = y.right

        # Perform rotation
        y.right = z
        z.left = T3

        # Update heights
        z.height = 1 + max(self.getHeight(z.left),
                           self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left),
                           self.getHeight(y.right))

        # Return the new root
        return y

    def getHeight(self, root):
        if not root:
            return 0

        return root.height

    def getBalance(self, root):
        if not root:
            return 0

        return self.getHeight(root.left) - self.getHeight(root.right)

    def preOrder(self, root):
        if not root:
            return

        print("{0} ".format(root.val), end="")
        self.preOrder(root.left)
        self.preOrder(root.right)


myTree = AVLTree()
# root = None
root = myTree.insert(None, 10)
root = myTree.insert(root, 20)
root = myTree.insert(root, 30)
root = myTree.insert(root, 40)
root = myTree.insert(root, 50)
root = myTree.insert(root, 25)
 
"""The constructed AVL Tree would be
            30
           /  \
         20   40
        /  \     \
       10  25    50"""

myTree.preOrder(root)

# This code is contributed by Ajitesh Pathak

In [None]:
# red black trees

# review binary search tree: ordered/sorted binary tree. left is small and right is large. balanced guarantrees log(n) access.

# red black trees (self-balancing binary search tree):
# - a node is either red/black
# - root is black
# - root and leaves (NIL) are black
# - if node is red, children black
# - all paths from a node to its NIL descendants contain same number of black nodes

In [6]:
# height of a tree (general not just binary search tree)

def get_height(root):
    if not root:
        return -1

    return 1 + max(get_height(root.left), get_height(root.right))

root_1 = Node(8)
root_1.left = Node(3)
root_1.right = Node(10)
root_1.left.left = Node(1)
root_1.left.right = Node(6)
root_1.left.right.left = Node(4)
root_1.left.right.right = Node(7)

print(get_height(root_1))


3


In [None]:
# non-binary tree with multiple chidlren

class NodeMultiple:
    def __init__(self, data):
        self.data = data
        self.children = []

def find(root, target):

    if root.data == target:
        return root

    for child in root.children:
        # see why the False condition should be catched. Because, we wanna backtrack and explore other options as well
        if find(child, target):
            return child

    return False

root_1 = NodeMultiple(1)
root_1.children.append(NodeMultiple(2))
root_1.children.append(NodeMultiple(3))

subtree_root = find(root_1, 3)
print(subtree_root.data)