In [211]:
%load_ext line_profiler

### 1. Build Heap

In [134]:
input_list = [5,4,3,2,1]

def left_child(input_list, index):
    child_index = (2*index) + 1
    if child_index < len(input_list):
        return child_index 
    else:
        return index

def right_child(input_list, index):
    child_index = (2*index) + 2
    if child_index < len(input_list):
        return child_index 
    else:
        return index

def parent(input_list, index):
    return (index-1)//2

def sift_down(input_list, index):
    curr_index = index
    swaps = []

    while (input_list[curr_index] > input_list[left_child(input_list, curr_index)]) or (input_list[curr_index] > input_list[right_child(input_list, curr_index)]):  
        # print(f"sift_down: {curr_index=}, {input_list=}, {left_child(input_list, curr_index)=}, {right_child(input_list, curr_index)=}")
        swap_index = left_child(input_list, curr_index) if input_list[left_child(input_list, curr_index)] < input_list[right_child(input_list, curr_index)] else right_child(input_list, curr_index)

        input_list[curr_index], input_list[swap_index] = input_list[swap_index], input_list[curr_index]
        swaps.append((curr_index, swap_index))
        
        curr_index = swap_index
        # print(f"sift_down: {input_list=}")
    return swaps

def heapify(input_list):
    start_index = len(input_list)//2
    swaps = []
    for index in range(start_index, -1, -1):
        # print(f"Heapify: {index=}, {input_list=}")
        additional_swaps = sift_down(input_list, index)
        swaps.extend(additional_swaps)
        # print(f"Heapify: {index=}, {input_list=}")
    
    return swaps

heapify(input_list)

[(1, 4), (0, 1), (1, 3)]

### 2. Job Queue

- Design
    - Queue to contain all jobs 
    - Priority queue (MinHeap) to contain all threads
    - Each element of MinHeap will contain a thread ID and a next_time_available attribute
        - Heapify using next_time_available
        - Break ties using thread ID

    - For each job in queue
        - ChangePriority for the root node
            - Update next_time_available = next_time_available + job time
        - Heapify
        - Store thread ID and start time as a tuple

In [136]:
from collections import namedtuple
from dataclasses import dataclass
import math

@dataclass
class Thread:
    thread_id: int
    next_available_time: int

    def __repr__(self):
        return f'(id={self.thread_id}, {self.next_available_time})'

def left_child(index):
    child_index = (2*index) + 1
    return child_index

def right_child(index):
    child_index = (2*index) + 2
    if child_index < len(pqueue):
        return child_index
    return child_index

def sift_down(index, pqueue, verbose=False):
    curr_index = index

    while True:
        # if verbose:
        #     print('='*50)
        #     print('sift_down')
        #     print(f"{curr_index=}, {left_child_index=}, {right_child_index=}, {pqueue=}")

        left_child_index = left_child(curr_index) if left_child(curr_index) < len(pqueue) else None
        right_child_index = right_child(curr_index) if right_child(curr_index) < len(pqueue) else None

        if (left_child_index is None) and (right_child_index is None):
            return

        compare_index = [x for x in [curr_index, left_child_index, right_child_index] if x is not None]

        ordered_index = sorted(
            compare_index,
            key = lambda i: [pqueue[i].next_available_time, pqueue[i].thread_id]
        )
        
        # if verbose:
        #     print(f"{ordered_index=}")

        if ordered_index[0] == curr_index:
            return
        else:
            pqueue[curr_index], pqueue[ordered_index[0]] = pqueue[ordered_index[0]], pqueue[curr_index]
            curr_index = ordered_index[0]
        
def heapify(pqueue, verbose=False):
    start_index = (len(pqueue)-1)//2
    for index in range(start_index, -1, -1):
        sift_down(index, pqueue, verbose)
    return pqueue

n_workers = 4
jobs = [1]*20
pqueue = [Thread(i, 0) for i in range(n_workers)]

def assign_jobs(n_workers, jobs, pqueue, verbose=False):

    job_queue = jobs
    process_log = []
    heapify(pqueue, verbose)

    for job_time_taken in job_queue:
        if verbose:
            print('='*50)
            print('assign_jobs')
            print([(x.thread_id, x.next_available_time) for x in pqueue])
        
        assigned_thread = pqueue[0]
        process_log.append((assigned_thread.thread_id, assigned_thread.next_available_time))
        assigned_thread.next_available_time += job_time_taken
        
        if verbose:
            print([(x.thread_id, x.next_available_time) for x in pqueue])
        heapify(pqueue, verbose)
        if verbose:
            print([(x.thread_id, x.next_available_time) for x in pqueue])
    
    return process_log

assign_jobs(n_workers, jobs, pqueue, verbose=False)

# # heapify([])
# # thread_prio_queue
# process_log
# pqueue

[(0, 0),
 (1, 0),
 (2, 0),
 (3, 0),
 (0, 1),
 (1, 1),
 (2, 1),
 (3, 1),
 (0, 2),
 (1, 2),
 (2, 2),
 (3, 2),
 (0, 3),
 (1, 3),
 (2, 3),
 (3, 3),
 (0, 4),
 (1, 4),
 (2, 4),
 (3, 4)]

### 3. Merging tables

- Each table x has some count of rows x.rows
- There are $n$ such tables
- We do $m$ operations to merge these tables
- Return the size of the largest table at the end of each operation

In [309]:
from dataclasses import dataclass

class Database:
    def __init__(self, row_counts):
        self.row_counts = row_counts
        self.max_row_count = max(row_counts)
        n_tables = len(row_counts)
        self.ranks = [1] * n_tables
        self.parents = list(range(n_tables))

    def merge(self, src, dst, verbose=False):
        # if verbose:
        #     print('='*50)
        #     print(f"{(src, dst)=}")
        #     print(f"{self.row_counts=}")
        #     print(f"{self.ranks=}")
        #     print(f"{self.parents=}")

        src_parent = self.get_parent(src)
        dst_parent = self.get_parent(dst)

        if src_parent == dst_parent:
            return False

        if self.ranks[src_parent] > self.ranks[dst_parent]:
            self.parents[dst_parent] = src_parent
            self.row_counts[src_parent] += self.row_counts[dst_parent]
            self.row_counts[dst_parent] = 0
            if self.max_row_count < self.row_counts[src_parent]:
                self.max_row_count = self.row_counts[src_parent]
        else:
            self.parents[src_parent] = dst_parent
            self.row_counts[dst_parent] += self.row_counts[src_parent]
            self.row_counts[src_parent] = 0
            if self.max_row_count < self.row_counts[dst_parent]:
                self.max_row_count = self.row_counts[dst_parent]
            if self.ranks[src_parent] == self.ranks[dst_parent]:
                self.ranks[dst_parent] += 1    
        
        # if verbose:
        #     print(f"{(src, dst)=}")
        #     print(f"{self.row_counts=}")
        #     print(f"{self.ranks=}")
        #     print(f"{self.parents=}")

        return True

    def get_parent(self, table):
        # find parent and compress path
        curr_table = table
        path = []

        while curr_table != self.parents[curr_table]:
            path.append(curr_table)
            curr_table = self.parents[curr_table]
        
        for node in path:
            self.parents[node] = curr_table

        return curr_table

In [310]:
# del Database

In [312]:
# n_tables, n_queries = 6,4
# counts = [10,0,5,0,3,3]
# # ops = [(3,5), (2,4), (1,4), (5,4), (5,3)] 
# ops = [(6,6),(6,5),(5,4),(4,3)] 

import numpy as np
n_tables = np.random.randint(10_000, 10_001, 1)[0]
n_queries = np.random.randint(10_000, 10_001, 1)[0]
counts = list(np.random.randint(1, 10_000, n_tables))
ops = [tuple(np.random.randint(1, n_tables, 2)) for _ in range(n_queries)]
print(f'{n_tables}, {n_queries}')

assert len(counts) == n_tables
db = Database(counts)

for i in range(n_queries):
    dst, src = ops[i]
    
    %lprun -f Database.merge db.merge(dst - 1, src - 1, verbose=False)
    
    # db.merge(dst - 1, src - 1, verbose=False)
    # print(db.max_row_count)


10000, 10000


Timer unit: 1e-09 s

Total time: 6e-06 s
File: /var/folders/sz/cgf6qmyj36bcgkz0rn5dphyw0000gr/T/ipykernel_78827/3898043228.py
Function: merge at line 11

Line #      Hits         Time  Per Hit   % Time  Line Contents
    11                                               def merge(self, src, dst, verbose=False):
    12                                                   # if verbose:
    13                                                   #     print('='*50)
    14                                                   #     print(f"{(src, dst)=}")
    15                                                   #     print(f"{self.row_counts=}")
    16                                                   #     print(f"{self.ranks=}")
    17                                                   #     print(f"{self.parents=}")
    18                                           
    19         1       4000.0   4000.0     66.7          src_parent = self.get_parent(src)
    20         1       2000.0   2000.0     3