<a href="https://colab.research.google.com/github/tahsina13/Puzzle-Book/blob/main/PrincePrincess/PrincePrincess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prince and Princess Problem

Problem Statement:

> Long, long ago there lived a beautiful and intelligent princess who was obsessed with mathematics. When the time came, she decided to marry one of the two most intelligent and good-natured princes: A and B. The two princes were called to the court and were asked to stand in front of the princess.

> A natural number was written on the crown of each prince. No prince could see the number on his own crown, but he could see the number written on the other crown. The princess wrote two distinct natural numbers on a board and announced that one of the board numbers is the sum of the numbers written on their crowns. The princess would ask prince A: "Do you know the number written on your crown?" If A's answer was "no", the princess would ask prince B the same question. If B's answer was "no", the princess would ask prince A the same question. This process would continue in the cyclic order A, B, A, B, … until a prince answers "yes". The prince who would correctly guess the number written on his crown would win the opportunity to marry the princess.

> The assumptions in the problem are: the two princes were equally intelligent and truthful, no other form of communication was allowed, and each prince was able to hear the answer of the other prince. What happened eventually? Did the princess get married? If no, how? If yes, to whom?

> The counterintuitive result is that the princess eventually got married to one of the two princes! How is this possible?

-- [Puzzle Book](https://www.google.com/books/edition/Mathematical_and_Algorithmic_Puzzles/C7q6EAAAQBAJ?hl=en&gbpv=0&kptab=getbook), pg. 100

In [None]:
# Generate all possible partitions of x with n princes
def partitions(n, x):
    """Return every ordered way to write ``x`` as a sum of ``n`` natural numbers.

    Each partition is a tuple of length ``n`` with entries >= 1 summing to
    ``x``; entry ``k`` is the number on prince ``k``'s crown.  Results are
    always tuples (also for ``n == 1``) so they can be stored in sets.  An
    empty list is returned when no partition exists (``n < 1`` or ``x < n``).

    >>> partitions(2, 3)
    [(2, 1), (1, 2)]
    """
    if n < 1 or x < n:
        return []
    if n == 1:
        return [(x,)]
    result = []
    # The last prince's number i ranges over 1..x-(n-1), leaving at least 1
    # for each of the other n-1 princes.
    for i in range(1, x - n + 2):
        result.extend(sub + (i,) for sub in partitions(n - 1, x - i))
    return result

# Generate all possible partitions for a set of sums xs
def all_partitions(n, xs):
    """Return the partitions of every sum in ``xs`` into ``n`` parts, concatenated.

    Partitions appear grouped by sum, in the order of ``xs``.  Duplicate sums
    produce duplicate partitions; callers needing uniqueness wrap the result
    in ``set``.
    """
    return [p for x in xs for p in partitions(n, x)]

## Method 1: Decision Tree

In [None]:
import itertools
from collections import namedtuple, deque

In [None]:
#@title Two Princes Decision Tree
def two_princes_decision_tree(a, b, s1, s2, stats=None):
    """Return the number of the question at which a prince first says "yes".

    Questions are numbered from 1 and alternate between the princes: odd
    numbers go to prince A, even numbers to prince B.  The loop walks the
    decision tree one round ``k`` at a time; after every "no" the listening
    prince narrows, in steps of ``d = s2 - s1``, the interval in which the
    other prince's number must lie.

    Args:
        a: number on prince A's crown.
        b: number on prince B's crown.
        s1, s2: the two numbers on the board, in either order.
        stats: optional dict; ``stats['ops']`` is overwritten with the number
            of comparisons performed (used by the performance experiments).

    Returns:
        The 1-based question number at which a prince answers "yes".

    Raises:
        ValueError: if ``s1 == s2``; with ``d == 0`` the bounds never move
            and the loop may never terminate.
    """
    # Fresh dict per call instead of a shared mutable default argument.
    if stats is None:
        stats = {}
    stats['ops'] = 0
    if s1 > s2:
        # Tuple swap (the former `s1 = s2, s2 = s1` raised TypeError).
        s1, s2 = s2, s1
    d = s2 - s1
    if d == 0:
        raise ValueError('the two board numbers must be distinct')
    questions = 1
    for k in itertools.count(start=1):
        stats['ops'] += 1

        # Elimination of Prince A
        stats['ops'] += 1
        if b <= (k-1)*d:
            # A computes a = s1-b
            break
        stats['ops'] += 1
        if b >= s1 - (k-1)*d:
            # A computes a = s2-b
            break
        # B concludes b in ((k-1)d, s1-(k-1)d)
        questions += 1

        # Elimination of Prince B
        stats['ops'] += 1
        if a <= k*d:
            # B computes b = s1-a
            break
        stats['ops'] += 1
        if a >= s1 - (k-1)*d:
            # B computes b = s2-a
            break
        # A concludes a in (kd, s1-(k-1)d)
        questions += 1
    return questions

In [None]:
#@title All Two Princes Decision Tree
def all_two_princes_decision_tree(s1, s2):
    """Return ``{(a, b): question}`` for every crown pair, via the decision tree.

    Covers every pair of natural numbers with ``a + b`` equal to ``s1`` or
    ``s2`` and maps it to the 1-based question number at which a prince
    first answers "yes" (odd numbers: prince A, even numbers: prince B).
    Keys are ``Princes(a, b)`` namedtuples, which compare and hash like plain
    tuples.

    Raises:
        ValueError: if ``s1 == s2``; with ``d == 0`` the elimination bounds
            never move and the loop cannot make progress.
    """
    if s1 > s2:
        # Tuple swap (the former `s1 = s2, s2 = s1` raised TypeError).
        s1, s2 = s2, s1
    d = s2 - s1
    if d == 0:
        raise ValueError('the two board numbers must be distinct')
    Princes = namedtuple('Princes', ['a', 'b'])
    # Candidates are ordered by increasing a (hence decreasing b); every
    # elimination step below only removes candidates from the two ends.
    s1_sums = deque([Princes(i, s1-i) for i in range(1, s1)])
    s2_sums = deque([Princes(i, s2-i) for i in range(1, s2)])
    all_questions = {}
    questions = 1
    for k in itertools.count(start=1):
        # Elimination of Prince A
        while s1_sums and s1_sums[-1].b <= (k-1)*d:
            all_questions[s1_sums.pop()] = questions
        while s2_sums and s2_sums[-1].b <= (k-1)*d:
            all_questions[s2_sums.pop()] = questions
        while s1_sums and s1_sums[0].b >= s1 - (k-1)*d:
            all_questions[s1_sums.popleft()] = questions
        while s2_sums and s2_sums[0].b >= s1 - (k-1)*d:
            all_questions[s2_sums.popleft()] = questions
        questions += 1

        # Elimination of Prince B
        while s1_sums and s1_sums[0].a <= k*d:
            all_questions[s1_sums.popleft()] = questions
        while s2_sums and s2_sums[0].a <= k*d:
            all_questions[s2_sums.popleft()] = questions
        while s1_sums and s1_sums[-1].a >= s1 - (k-1)*d:
            all_questions[s1_sums.pop()] = questions
        while s2_sums and s2_sums[-1].a >= s1 - (k-1)*d:
            all_questions[s2_sums.pop()] = questions
        questions += 1

        if not s1_sums and not s2_sums:
            break
    return all_questions

## Method 2: Questions

In [None]:
import math

In [None]:
#@title Two Princes Questions
def two_princes_questions(a, b, s1, s2, stats=None):
    """Return the question at which a prince first says "yes", in closed form.

    Closed-form counterpart of the decision-tree walk.  Prince A is asked on
    odd questions (``2k - 1`` in round ``k``) and prince B on even ones
    (``2k``), so the answer is the earlier of the two princes' first
    possible rounds.

    Args:
        a: number on prince A's crown.
        b: number on prince B's crown.
        s1, s2: the two numbers on the board, in either order.
        stats: optional dict; ``stats['ops']`` is overwritten with the number
            of operations counted (used by the performance experiments).

    Returns:
        The 1-based question number at which a prince answers "yes".

    Raises:
        ValueError: if ``s1 == s2`` (the formula divides by ``s2 - s1``).
    """
    # Fresh dict per call instead of a shared mutable default argument.
    if stats is None:
        stats = {}
    stats['ops'] = 0
    if s1 > s2:
        # Tuple swap (the former `s1 = s2, s2 = s1` raised TypeError).
        s1, s2 = s2, s1
    d = s2 - s1
    if d == 0:
        raise ValueError('the two board numbers must be distinct')

    def ceil_div(x, y):
        # Exact integer ceiling; math.ceil(x / y) goes through float and
        # loses precision once the operands exceed 2**53.
        return -(-x // y)

    stats['ops'] += 1
    if b >= s1:
        # s1 - b <= 0 is not natural, so A knows a = s2 - b immediately.
        return 1
    stats['ops'] += 1
    if a >= s1:
        # Likewise B knows b = s2 - a the first time he is asked.
        return 2
    stats['ops'] += 3
    # First round in which A (ka) or B (kb) can name his number.
    ka = min(ceil_div(b, d) + 1, ceil_div(s1 - b, d) + 1)
    kb = min(ceil_div(a, d), ceil_div(s1 - a, d) + 1)
    return min(2*ka - 1, 2*kb)

In [None]:
#@title All Two Princes Questions
def all_two_princes_questions(s1, s2):
    """Map every crown pair (a, b) with a + b in {s1, s2} to its closed-form question number."""
    answers = {}
    for pair in all_partitions(2, [s1, s2]):
        a, b = pair
        answers[pair] = two_princes_questions(a, b, s1, s2)
    return answers

## Method 3: Conway–Paterson Elimination

### Two Princes

In [None]:
#@title Two Princes Elimination
def two_princes_elimination(a, b, s1, s2, stats=None):
    """Simulate the questioning by explicitly eliminating candidate pairs.

    ``all_pairs`` starts as every pair (p, q) of natural numbers with
    ``p + q`` in {s1, s2}.  On his turn a prince can name his number when,
    of the two values his own number could take given what he sees, only
    one is still a candidate.  All pairs for which that holds are removed
    before the next question.  The loop stops as soon as the true pair
    ``(a, b)`` is among those resolved.

    Args:
        a: number on prince A's crown.
        b: number on prince B's crown.
        s1, s2: the two numbers on the board, in either order.
        stats: optional dict; ``stats['ops']`` is overwritten with the number
            of operations counted (used by the performance experiments).

    Returns:
        The 1-based question number (odd: A, even: B) at which a prince
        first answers "yes".

    Raises:
        ValueError: if ``(a, b)`` is not a pair summing to s1 or s2.  Such a
            pair is never eliminated, so the loop would not terminate.
    """
    # Fresh dict per call instead of a shared mutable default argument.
    if stats is None:
        stats = {}
    stats['ops'] = 0
    all_pairs = set(all_partitions(2, [s1, s2]))
    if (a, b) not in all_pairs:
        raise ValueError(f'({a}, {b}) does not sum to {s1} or {s2}')
    questions = 1
    while True:
        stats['ops'] += 1

        # Elimination of Prince A: A sees q; his number is s1-q or s2-q.
        elim_pairs_a = set()
        for (p, q) in all_pairs:
            stats['ops'] += 1

            stats['ops'] += 1 + (3 if p + q == s1 else 0)
            if p + q == s1 and (p == s2-q or (s2-q, q) not in all_pairs):
                stats['ops'] += 2
                elim_pairs_a.add((p, q))
                stats['ops'] += 1 + (1 if p == a else 0)
                if (p, q) == (a, b):
                    break

            stats['ops'] += 1 + (3 if p + q == s2 else 0)
            if p + q == s2 and (p == s1-q or (s1-q, q) not in all_pairs):
                stats['ops'] += 2
                elim_pairs_a.add((p, q))
                stats['ops'] += 1 + (1 if p == a else 0)
                if (p, q) == (a, b):
                    break

        stats['ops'] += 3
        if (a, b) in elim_pairs_a:
            break
        stats['ops'] += 2*len(elim_pairs_a)
        all_pairs -= elim_pairs_a
        questions += 1

        # Elimination of Prince B: B sees p; his number is s1-p or s2-p.
        elim_pairs_b = set()
        for (p, q) in all_pairs:
            stats['ops'] += 1

            stats['ops'] += 1 + (3 if p + q == s1 else 0)
            if p + q == s1 and (q == s2-p or (p, s2-p) not in all_pairs):
                stats['ops'] += 2
                elim_pairs_b.add((p, q))
                stats['ops'] += 1 + (1 if p == a else 0)
                if (p, q) == (a, b):
                    break

            stats['ops'] += 1 + (3 if p + q == s2 else 0)
            if p + q == s2 and (q == s1-p or (p, s1-p) not in all_pairs):
                stats['ops'] += 2
                elim_pairs_b.add((p, q))
                stats['ops'] += 1 + (1 if p == a else 0)
                if (p, q) == (a, b):
                    break

        stats['ops'] += 3
        if (a, b) in elim_pairs_b:
            break
        stats['ops'] += 2*len(elim_pairs_b)
        all_pairs -= elim_pairs_b
        questions += 1
    return questions

In [None]:
#@title All Two Princes Elimination
def all_two_princes_elimination(s1, s2):
    """Run pair elimination for every crown pair at once.

    Returns a dict mapping each pair (a, b) with a + b in {s1, s2} to the
    1-based question number (odd: prince A, even: prince B) at which the
    prince being asked can deduce his number.
    """
    remaining = set(all_partitions(2, [s1, s2]))
    answered = {}
    turn = 1
    while remaining:
        # Prince A sees q; his own number is s1 - q or s2 - q.  He knows it
        # when the alternative is not (or no longer) a candidate pair.
        known = {
            (p, q)
            for (p, q) in remaining
            if (p + q == s1 and (p == s2 - q or (s2 - q, q) not in remaining))
            or (p + q == s2 and (p == s1 - q or (s1 - q, q) not in remaining))
        }
        remaining -= known
        answered.update(dict.fromkeys(known, turn))
        turn += 1

        # Prince B sees p; the same test on the second coordinate.
        known = {
            (p, q)
            for (p, q) in remaining
            if (p + q == s1 and (q == s2 - p or (p, s2 - p) not in remaining))
            or (p + q == s2 and (q == s1 - p or (p, s1 - p) not in remaining))
        }
        remaining -= known
        answered.update(dict.fromkeys(known, turn))
        turn += 1
    return answered

### N Princes

In [None]:
def get_elim_tuples(sums, all, index):
    """Return the tuples in ``all`` for which prince ``index`` can name his number.

    Prince ``index`` sees every entry except his own.  For each candidate
    sum he works out the value his own entry would need to have.  He can
    answer exactly when none of those alternatives (other than the true
    value) yields a tuple that is still in ``all``.
    """
    def has_twin(p):
        others = sum(p) - p[index]
        for s in sums:
            guess = s - others
            if guess == p[index]:
                continue
            twin = tuple(p[:index]) + (guess,) + tuple(p[index + 1:])
            if twin in all:
                return True
        return False

    return {p for p in all if not has_twin(p)}

In [None]:
#@title N Princes Elimination
def n_princes_elimination(a, n, sums):
    """Return the question at which some prince first names his number.

    Princes ``0..n-1`` are asked in cyclic order and questions are numbered
    from 1.  ``a`` holds the crown numbers, one per prince.

    Returns -1 when no prince will ever answer "yes":
      * there are more candidate sums than princes (treated as unsolvable), or
      * a full cycle of ``n`` questions eliminates nothing.  The candidate set,
        and therefore every later answer, then repeats forever.

    Raises:
        ValueError: if ``a`` is not a partition of one of ``sums`` into
            ``n`` parts (it could never be eliminated).
    """
    if len(sums) > n:
        return -1
    a = tuple(a)  # accept any sequence; set lookups need a hashable tuple
    all_tuples = set(all_partitions(n, sums))
    if a not in all_tuples:
        raise ValueError(f'{a} is not a partition of any of {sums} into {n} parts')
    questions = 1
    while True:
        progress = False
        for i in range(n):
            elim_tuples = get_elim_tuples(sums, all_tuples, i)
            if a in elim_tuples:
                return questions
            if elim_tuples:
                progress = True
                all_tuples -= elim_tuples
            questions += 1
        if not progress:
            # Nothing changed during a whole cycle: the state is a fixed
            # point, so looping further would never terminate.
            return -1

In [None]:
#@title All N Princes Elimination
def all_n_princes_elimination(n, sums):
    """Map every candidate crown tuple to the question at which it is resolved.

    Princes ``0..n-1`` are asked in cyclic order; questions are numbered
    from 1.  Tuples that can never be resolved are mapped to -1:
      * every tuple, when there are more candidate sums than princes, and
      * the leftover tuples, when a full cycle of ``n`` questions
        eliminates nothing (the state would then repeat forever).
    """
    if len(sums) > n:
        return {p:-1 for p in all_partitions(n, sums)}
    all_tuples = set(all_partitions(n, sums))
    all_questions = {}
    questions = 1
    while all_tuples:
        progress = False
        for i in range(n):
            elim_tuples = get_elim_tuples(sums, all_tuples, i)
            all_tuples -= elim_tuples
            all_questions.update({p:questions for p in elim_tuples})
            progress = progress or bool(elim_tuples)
            questions += 1
        if not progress:
            # Stuck at a fixed point: mark the rest unsolvable instead of
            # looping forever.
            all_questions.update({p:-1 for p in all_tuples})
            break
    return all_questions

## Method 4: Visualization

In [None]:
%pip install ipympl

In [None]:
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from ipywidgets import interactive
import ipywidgets as widgets
from IPython.display import display, HTML
import ipympl

from google.colab import output, files
output.enable_custom_widget_manager()

In [None]:
from abc import ABC, abstractmethod

class Visualization(ABC):
    """Template base class for the interactive widget plots in this notebook.

    Construction runs ``init_widgets(*args, **kwargs)``, then ``init_plot()``,
    then ``display_widgets()``.  Inheriting from ``ABC`` makes the
    ``@abstractmethod`` markers effective: a subclass that misses any of the
    five hooks raises ``TypeError`` when instantiated, instead of silently
    using these empty bodies.
    """

    def __init__(self, *args, **kwargs):
        self.init_widgets(*args, **kwargs)
        self.init_plot()
        self.display_widgets()

    @abstractmethod
    def init_widgets(self, *args, **kwargs):
        """Create the control widgets from the constructor arguments."""

    @abstractmethod
    def init_plot(self):
        """Create the figure and axes."""

    @abstractmethod
    def setup_plot(self):
        """Reset the axes (title, labels, limits, grid) before drawing."""

    @abstractmethod
    def draw_plot(self, *args, **kwargs):
        """Draw the plot for the given widget values."""

    @abstractmethod
    def display_widgets(self):
        """Wire widgets to callbacks and display them."""

In [None]:
#@title 2d Visualization
%matplotlib widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

class Visualization2d(Visualization):
    """Interactive scatter plot of two-prince pair elimination.

    Each pair (a, b) of natural numbers with a + b = s1 (blue circles) or
    a + b = s2 (red squares) is a point.  At step ``i`` only the pairs still
    unresolved after ``i`` questions are drawn, so playing the animation
    shows the candidate set shrinking one question at a time.
    """

    def __init__(self, s1, s2, limit):
        super().__init__(s1, s2, limit)

    def init_widgets(self, s1, s2, limit):
        """Create the s1/s2 sliders (1..limit) and a Play-linked step slider."""
        self.limit = limit
        self.s1_slider = widgets.IntSlider(
            description='s1', value=s1,
            min=1, max=limit
        )
        self.s2_slider = widgets.IntSlider(
            description='s2', value=s2,
            min=1, max=limit
        )
        self.step_slider = widgets.IntSlider(description='step')
        self.play_cntrl = widgets.Play(interval=500)
        widgets.link((self.step_slider, 'value'), (self.play_cntrl, 'value'))
        widgets.link((self.step_slider, 'max'), (self.play_cntrl, 'max'))

    def init_plot(self):
        plt.close('all') # only one viz at a time
        self.fig, self.ax = plt.subplots()

    def update(self, s1, s2):
        """Recompute, for every step, the pairs still unresolved at that step."""
        questions = all_two_princes_elimination(s1, s2)
        # default=0: a board with no valid pair (s1 = s2 = 1) yields an empty
        # dict, on which a bare max() raises ValueError.
        rounds = max(questions.values(), default=0)
        # points[i] = pairs not resolved within the first i questions.
        self.points = [
            [p for p, q in questions.items() if q > i]
            for i in range(rounds + 1)
        ]
        self.step_slider.max = rounds
        # Keep the current step within the new range.
        self.step_slider.value = min(self.step_slider.value, rounds)

    def setup_plot(self):
        """Clear the axes and redraw title, labels, limits, ticks and grid."""
        self.ax.cla()
        self.ax.set_title('Prince and Princess 2d')
        self.ax.set_xlabel('Prince $A$')
        self.ax.set_ylabel('Prince $B$')
        self.ax.set_xlim(0, self.limit)
        self.ax.set_ylim(0, self.limit)
        self.ax.set_xticks(np.arange(self.limit+1))
        self.ax.set_yticks(np.arange(self.limit+1))
        self.ax.grid(True, ls=':', lw=0.5, c='gray')

    def draw_plot(self, s1, s2, step):
        """Draw the pairs that are still unresolved after ``step`` questions."""
        self.setup_plot()
        current = self.points[step]
        p1 = np.asarray([p for p in current if sum(p) == s1])
        p2 = np.asarray([p for p in current if sum(p) == s2])
        # len() asks "are there points?"; np.any() would test element values.
        if len(p1):
            self.ax.scatter(
                x=p1[:,0], y=p1[:,1],
                c='b', s=72, label='$S_1$',
                marker='o', zorder=2
            )
        if len(p2):
            self.ax.scatter(
                x=p2[:,0], y=p2[:,1],
                c='r', s=72, label='$S_2$',
                marker='s', zorder=2
            )
        if len(p1) or len(p2):
            self.ax.legend(fontsize=16)
        self.fig.canvas.draw()

    def display_widgets(self):
        """Hook the sliders to update/draw_plot and display the controls."""
        self.update(self.s1_slider.value, self.s2_slider.value)
        interactive(self.update, s1=self.s1_slider, s2=self.s2_slider)
        display(widgets.VBox([
            interactive(
                self.draw_plot,
                s1=self.s1_slider,
                s2=self.s2_slider,
                step=self.step_slider,
            ),
            self.play_cntrl,
        ]))

Visualization2d(8, 11, 11);

In [None]:
#@title 3d Visualization
%matplotlib widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

class Visualization3d(Visualization):
    """Interactive 3d scatter plot of three-prince elimination.

    Each point (a, b, c) is a crown triple whose sum is s1 (blue circles),
    s2 (red squares) or s3 (green triangles).  Step ``i`` shows the triples
    still unresolved after ``i`` questions.
    """

    def __init__(self, s1, s2, s3, limit):
        super().__init__(s1, s2, s3, limit)

    def init_widgets(self, s1, s2, s3, limit):
        """Create the s1/s2/s3 sliders (1..limit) and a Play-linked step slider."""
        self.limit = limit
        self.s1_slider = widgets.IntSlider(
            description='s1', value=s1,
            min=1, max=limit
        )
        self.s2_slider = widgets.IntSlider(
            description='s2', value=s2,
            min=1, max=limit
        )
        self.s3_slider = widgets.IntSlider(
            description='s3', value=s3,
            min=1, max=limit
        )
        self.step_slider = widgets.IntSlider(description='step')
        self.play_cntrl = widgets.Play(interval=500)
        widgets.link((self.step_slider, 'value'), (self.play_cntrl, 'value'))
        widgets.link((self.step_slider, 'max'), (self.play_cntrl, 'max'))

    def init_plot(self):
        plt.close('all') # only one viz at a time
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(projection='3d')

    def update(self, s1, s2, s3):
        """Recompute, for every step, the triples still unresolved at that step."""
        questions = all_n_princes_elimination(3, (s1, s2, s3))
        # Boards whose sums are all below 3 have no valid triple (empty dict),
        # and -1 marks unsolvable triples; clamp so there is at least step 0.
        rounds = max(0, max(questions.values(), default=0))
        # points[i] = triples not resolved within the first i questions.
        self.points = [
            [p for p, q in questions.items() if q > i]
            for i in range(rounds + 1)
        ]
        self.step_slider.max = rounds
        self.step_slider.value = min(self.step_slider.value, rounds)

    def setup_plot(self):
        """Clear the axes and redraw title, labels, limits, ticks and grid."""
        self.ax.cla()
        self.ax.set_title('Prince and Princess 3d')
        self.ax.set_xlabel('Prince $A$')
        self.ax.set_ylabel('Prince $B$')
        self.ax.set_zlabel('Prince $C$')
        self.ax.set_xlim(0, self.limit)
        self.ax.set_ylim(0, self.limit)
        self.ax.set_zlim(0, self.limit)
        self.ax.set_xticks(np.arange(self.limit+1))
        self.ax.set_yticks(np.arange(self.limit+1))
        self.ax.set_zticks(np.arange(self.limit+1))
        self.ax.grid(True, ls=':', lw=0.5, c='gray')

    def draw_plot(self, s1, s2, s3, step):
        """Draw the triples that are still unresolved after ``step`` questions."""
        current = self.points[step]
        p1 = np.asarray([p for p in current if np.sum(p) == s1])
        p2 = np.asarray([p for p in current if np.sum(p) == s2])
        p3 = np.asarray([p for p in current if np.sum(p) == s3])
        self.setup_plot()
        # len() asks "are there points?"; np.any() would test element values.
        if len(p1):
            self.ax.scatter3D(
                p1[:,0], p1[:,1], p1[:,2],
                c='b', s=48, marker='o', label='$S_1$'
            )
        if len(p2):
            self.ax.scatter3D(
                p2[:,0], p2[:,1], p2[:,2],
                c='r', s=48, marker='s', label='$S_2$'
            )
        if len(p3):
            self.ax.scatter3D(
                p3[:,0], p3[:,1], p3[:,2],
                c='g', s=48, marker='^', label='$S_3$'
            )
        if len(p1) or len(p2) or len(p3):
            self.ax.legend(fontsize=12)
        self.fig.canvas.draw()

    def display_widgets(self):
        """Hook the sliders to update/draw_plot and display the controls."""
        self.update(self.s1_slider.value, self.s2_slider.value, self.s3_slider.value)
        interactive(self.update, s1=self.s1_slider, s2=self.s2_slider, s3=self.s3_slider)
        display(widgets.VBox([
            interactive(
                self.draw_plot,
                s1=self.s1_slider,
                s2=self.s2_slider,
                s3=self.s3_slider,
                step=self.step_slider,
            ),
            self.play_cntrl,
        ]))

Visualization3d(6, 8, 11, 12);

## Testing

In [None]:
#@title Test Suite

# Compute difference between two sets of results X, Y
def diff_pairs(X, Y):
    pairs = set(X.keys()).union(Y.keys())
    diff = {}
    for p in pairs:
        x = X.get(p, 0)
        y = Y.get(p, 0)
        if x != y:
            diff[p] = (x, y)
    return diff

# Test algorithm f1 against f2, for one pair s1, s2
def test_case(f1, f2, s1, s2, *, details=True):
    diff = diff_pairs(f1(s1, s2), f2(s1, s2))
    if details:
        if diff:
            print(f'Failing s1={s1}, s2={s2} ❌')
            print('----------------------------')
            for ((a, b), (x, y)) in diff.items():
                print(f'a={a}\tb={b}\tf1={x}\tf2={y}')
        else:
            print(f'Passing s1={s1}, s2={s2} ✔️')
    return not diff

# Test algorithm f1 against f2, for many pairs s1, s2
def test_suite(f1, f2, S=200, *, details=True):
    verdicts = [
        test_case(f1, f2, s1, s2, details=details)
        for s1 in range(1, S)
        for s2 in range(s1+1, S+1)
    ]
    all, miss = 0, 0
    for result in verdicts:
        all += 1
        miss += not result
    if miss:
        print(f'Failed {miss}/{all} test cases ❌\n')
    else:
        print(f'Passed {all}/{all} test cases ✔️\n')

In [None]:
#@title Decision Tree v. Questions
%%time
# Decision tree against the closed-form question count, all 1 <= s1 < s2 <= 400.
test_suite(all_two_princes_decision_tree, all_two_princes_questions,
           S=400, details=False)


Passed 79800/79800 test cases ✔️

CPU times: user 1min 49s, sys: 244 ms, total: 1min 50s
Wall time: 1min 51s


In [None]:
#@title Decision Tree v. Two Princes Elimination
%%time
# Decision tree against explicit pair elimination, all 1 <= s1 < s2 <= 400.
test_suite(all_two_princes_decision_tree, all_two_princes_elimination,
           S=400, details=False)

Passed 79800/79800 test cases ✔️

CPU times: user 2min 18s, sys: 343 ms, total: 2min 18s
Wall time: 2min 19s


In [None]:
#@title Two Princes Elimination v. N Princes Elimination
%%time
# Two-prince elimination against the general n-prince version with n = 2.
test_suite(all_two_princes_elimination,
           lambda x, y: all_n_princes_elimination(2, (x, y)),
           S=400, details=False)

Passed 79800/79800 test cases ✔️

CPU times: user 5min 14s, sys: 564 ms, total: 5min 15s
Wall time: 5min 19s


## Performance Analysis

In [None]:
def get_ops(func):
    """Wrap a single-pair solver so that it returns its operation count.

    ``func`` must accept ``(a, b, s1, s2, stats)`` and store its operation
    count in ``stats['ops']``.  The wrapper discards ``func``'s own return
    value and returns that count instead.  ``functools.wraps`` keeps the
    solver's name and docstring on the wrapper.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(a, b, s1, s2):
        stats = {}
        func(a, b, s1, s2, stats)
        return stats['ops']
    return wrapper

# Operation-counting variants of the three single-pair solvers, used by the
# performance experiments.
(
    two_princes_decision_tree_ops,
    two_princes_questions_ops,
    two_princes_elimination_ops,
) = map(get_ops, (two_princes_decision_tree, two_princes_questions, two_princes_elimination))

In [None]:
#@title Experiment 1 & 2: Vary a and b
%matplotlib widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

class PerformancePlotAB(Visualization):
    """Plot each solver's operation count for every pair (a, s1 - a).

    The sliders choose the board numbers.  Redrawing runs all three solvers
    on s1 - 1 pairs, so it happens only when the "Run" button is pressed.
    """

    def __init__(self, s1, s2, limit):
        super().__init__(s1, s2, limit)

    def init_widgets(self, s1, s2, limit):
        """Create the s1 and s2 sliders, each ranging over 2..limit."""
        self.s1_slider = widgets.IntSlider(
            description='s1', value=s1,
            min=2, max=limit
        )
        self.s2_slider = widgets.IntSlider(
            description='s2', value=s2,
            min=2, max=limit
        )

    def init_plot(self):
        plt.close('all')
        self.fig, self.ax = plt.subplots()

    def setup_plot(self):
        """Clear the axes and set title, labels, log2 y scale and grid."""
        s1, s2 = self.s1_slider.value, self.s2_slider.value
        self.ax.cla()
        self.ax.set_title(f'Performance comparison ($a+b=S_1={s1}$, $S_2={s2}$)')
        self.ax.set_xlabel('$a$')
        self.ax.set_ylabel('Number of comparisons')
        self.ax.set_yscale('log', base=2)
        self.ax.grid(True, ls=':', lw=0.2, c='gray')

    def draw_plot(self, s1, s2):
        """Run every solver on each pair (a, s1 - a), 1 <= a < s1, and plot the counts."""
        self.setup_plot()
        if s1 == s2:
            # The decision-tree and closed-form solvers need two distinct
            # board numbers, so there is nothing to compare.
            self.ax.set_title(f'$S_1$ and $S_2$ must differ (both are {s1})')
            self.fig.canvas.draw()
            return
        a = np.arange(1, s1)
        elimination_ops = np.empty(len(a), dtype=np.int32)
        decision_ops = np.empty(len(a), dtype=np.int32)
        questions_ops = np.empty(len(a), dtype=np.int32)

        for i, x in enumerate(a):
            elimination_ops[i] = two_princes_elimination_ops(x, s1-x, s1, s2)
            decision_ops[i] = two_princes_decision_tree_ops(x, s1-x, s1, s2)
            questions_ops[i] = two_princes_questions_ops(x, s1-x, s1, s2)

        max_pw = int(np.ceil(np.log2(elimination_ops.max()))) + 1
        self.ax.set_xticks([1] + list(range(10, s1-1, 10)) + [s1-1])
        self.ax.set_yticks([1<<i for i in range(0, max_pw)])

        self.ax.plot(a, elimination_ops, c='r', ls=(0, (5, 1)), lw=2., label='Elimination')
        self.ax.plot(a, decision_ops, c='b', ls=(5, (10, 3)), lw=2., label='Decision tree')
        self.ax.plot(a, questions_ops, c='g', ls=(0, ()), lw=2., label='Questions')
        self.ax.legend(bbox_to_anchor=(0.5, 0.6), loc='center', borderaxespad=0)
        self.fig.canvas.draw()

    def display_widgets(self):
        """Draw once, then show the sliders with a manual "Run" button."""
        self.draw_plot(self.s1_slider.value, self.s2_slider.value)
        display(interactive(
            self.draw_plot,
            {
                'manual': True,
                'manual_name': 'Run',
            },
            s1=self.s1_slider,
            s2=self.s2_slider
        ))

PerformancePlotAB(100, 101, 200);

In [None]:
#@title Experiment 3: Vary S1
%matplotlib widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

class PerformancePlotS1(Visualization):
    """Plot each solver's operation count as S1 varies for a fixed S2.

    For every s1 in [2, s2), each solver is run on the most balanced pair
    a = floor(s1 / 2), b = ceil(s1 / 2).  Redrawing happens on "Run".
    """

    def __init__(self, s2, limit):
        super().__init__(s2, limit)

    def init_widgets(self, s2, limit):
        """Create the s2 slider, ranging over 2..limit."""
        self.s2_slider = widgets.IntSlider(
            description='s2', value=s2,
            min=2, max=limit
        )

    def init_plot(self):
        plt.close('all')
        self.fig, self.ax = plt.subplots()

    def setup_plot(self):
        """Clear the axes and set title, labels, log2 y scale and grid."""
        self.ax.cla()
        self.ax.set_title(f'Performance comparison ($S_2={self.s2_slider.value}$)')
        self.ax.set_xlabel('$S_1$')
        self.ax.set_ylabel('Number of comparisons')
        self.ax.set_yscale('log', base=2)
        self.ax.grid(True, ls=':', lw=0.2, c='gray')

    def draw_plot(self, s2):
        """Run every solver for each s1 in [2, s2) and plot the counts."""
        self.setup_plot()
        s1 = np.arange(2, s2)
        if len(s1) == 0:
            # s2 == 2 leaves no s1 to plot (indexing the empty results
            # below would raise IndexError).
            self.fig.canvas.draw()
            return
        elimination_ops = np.empty(len(s1), dtype=np.int32)
        decision_ops = np.empty(len(s1), dtype=np.int32)
        questions_ops = np.empty(len(s1), dtype=np.int32)

        for i, x in enumerate(s1):
            # Most balanced split: a = floor(x / 2), b = ceil(x / 2).
            a = int(x) // 2
            b = int(x) - a
            elimination_ops[i] = two_princes_elimination_ops(a, b, x, s2)
            decision_ops[i] = two_princes_decision_tree_ops(a, b, x, s2)
            questions_ops[i] = two_princes_questions_ops(a, b, x, s2)

        # Use the maximum (not the last value) so the ticks span the curve.
        max_pw = int(np.ceil(np.log2(elimination_ops.max()))) + 1
        self.ax.set_xticks([2] + list(range(10, s2-1, 10)) + [s2-1])
        self.ax.set_yticks([1<<i for i in range(0, max_pw)])

        self.ax.plot(s1, elimination_ops, c='r', ls=(0, (5, 1)), lw=2., label='Elimination')
        self.ax.plot(s1, decision_ops, c='b', ls=(5, (10, 3)), lw=2., label='Decision tree')
        self.ax.plot(s1, questions_ops, c='g', ls=(0, ()), lw=2., label='Questions')
        self.ax.axvline(x=s2, c='black', ls=':', lw=3.)

        self.ax.legend(loc='upper left')
        self.fig.canvas.draw()

    def display_widgets(self):
        """Draw once, then show the slider with a manual "Run" button."""
        self.draw_plot(self.s2_slider.value)
        display(interactive(
            self.draw_plot,
            {
                'manual': True,
                'manual_name': 'Run',
            },
            s2=self.s2_slider
        ))

PerformancePlotS1(100, 200);

In [None]:
#@title Experiment 4: Vary S2
%matplotlib widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

class PerformancePlotS2(Visualization):
    """Plot each solver's operation count as S2 varies for a fixed S1.

    For every s2 in (s1, 2*s1], each solver is run on the most balanced pair
    a = floor(s1 / 2), b = ceil(s1 / 2).  Redrawing happens on "Run".
    """

    def __init__(self, s1, limit):
        super().__init__(s1, limit)

    def init_widgets(self, s1, limit):
        """Create the s1 slider, ranging over 2..limit."""
        self.s1_slider = widgets.IntSlider(
            description='s1', value=s1,
            min=2, max=limit
        )

    def init_plot(self):
        plt.close('all')
        self.fig, self.ax = plt.subplots()

    def setup_plot(self):
        """Clear the axes and set title, labels, log2 y scale and grid."""
        self.ax.cla()
        self.ax.set_title(f'Performance comparison ($S_1={self.s1_slider.value}$)')
        self.ax.set_xlabel('$S_2$')
        self.ax.set_ylabel('Number of comparisons')
        self.ax.set_yscale('log', base=2)
        self.ax.grid(True, ls=':', lw=0.2, c='gray')

    def draw_plot(self, s1):
        """Run every solver for each s2 in (s1, 2*s1] and plot the counts."""
        self.setup_plot()
        s2 = np.arange(s1+1, 2*s1+1)
        elimination_ops = np.empty(len(s2), dtype=np.int32)
        decision_ops = np.empty(len(s2), dtype=np.int32)
        questions_ops = np.empty(len(s2), dtype=np.int32)

        # The pair only depends on s1, so compute it once: most balanced
        # split a = floor(s1 / 2), b = ceil(s1 / 2).
        a = int(s1) // 2
        b = int(s1) - a
        for i, y in enumerate(s2):
            elimination_ops[i] = two_princes_elimination_ops(a, b, s1, y)
            decision_ops[i] = two_princes_decision_tree_ops(a, b, s1, y)
            questions_ops[i] = two_princes_questions_ops(a, b, s1, y)

        # Use the maximum (not the first value) so the ticks span the curve.
        max_pw = int(np.ceil(np.log2(elimination_ops.max()))) + 1
        self.ax.set_xticks([s1+1] + list(range(10*((s1+12)//10), 2*s1, 10)) + [2*s1])
        self.ax.set_yticks([1<<i for i in range(0, max_pw)])

        self.ax.plot(s2, elimination_ops, c='r', ls=(0, (5, 1)), lw=2., label='Elimination')
        self.ax.plot(s2, decision_ops, c='b', ls=(5, (10, 3)), lw=2., label='Decision tree')
        self.ax.plot(s2, questions_ops, c='g', ls=(0, ()), lw=2., label='Questions')
        self.ax.axvline(x=s1, c='black', ls=':', lw=3.)

        self.ax.legend(loc='upper right')
        self.fig.canvas.draw()

    def display_widgets(self):
        """Draw once, then show the slider with a manual "Run" button."""
        self.draw_plot(self.s1_slider.value)
        display(interactive(
            self.draw_plot,
            {
                'manual': True,
                'manual_name': 'Run',
            },
            s1=self.s1_slider
        ))

PerformancePlotS2(100, 200);