Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,334 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Maximum likelihood estimatation from observed and unobserved data**\n", | ||
"\n", | ||
"You are given a bag containing red and blue coins. All the red coins have the same probability of heads. All the blue coins have the same probability of heads (possibly different from that of the red coins).\n", | ||
"\n", | ||
"Your task is to estimate the proportion of red coins in the bag and the probability of heads for both the red and the blue coin." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "49e2d4bbdfc84d28b6aca607243af6c8", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"FloatSlider(value=0.0, description='prob_red', max=1.0)" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "c4907bb473ee4fbb87fa47244a331f52", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"FloatSlider(value=0.0, description='head_red', max=1.0)" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "e3f9c2542f254da782073847a43d67bc", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"FloatSlider(value=0.0, description='head_blue', max=1.0)" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"import ipywidgets as widgets\n", | ||
"prob_red = widgets.FloatSlider(min=0.0, max=1.0, description='prob_red')\n", | ||
"prob_head_red = widgets.FloatSlider(min=0.0, max=1.0, description='head_red')\n", | ||
"prob_head_blue = widgets.FloatSlider(min=0.0, max=1.0, description='head_blue')\n", | ||
"display(prob_red, prob_head_red, prob_head_blue)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Use these widgets to control the model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import random\n", | ||
"def choose_coin():\n", | ||
" return 'R' if random.random() < prob_red.value else 'B'\n", | ||
"\n", | ||
"def flip_coin(coin):\n", | ||
" uar = random.random()\n", | ||
" if coin == 'R':\n", | ||
" if uar < prob_head_red.value:\n", | ||
" return 'H'\n", | ||
" elif uar < prob_head_blue.value:\n", | ||
" return 'H'\n", | ||
" return 'T'\n", | ||
"\n", | ||
"def flip_random_coin_n_times(n, hidden=False):\n", | ||
" coin = choose_coin()\n", | ||
" return ('_' if hidden else coin, ''.join([flip_coin(coin) for i in range(n)]))\n", | ||
"\n", | ||
"def flip_m_random_coins_n_times(m, n, hidden=False):\n", | ||
" return [flip_random_coin_n_times(n, hidden) for i in range(m)]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Use the above methods to sample from the model. The optional parameter 'hidden' controls whether the colour of the coin is observed in the samples." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[('B',\n", | ||
" 'TTTTTHTHTHTTTTTTTTTTTTTTTHTTTTTTTTTTTHTTTHTHTTTTTTTTHTHTTHTTTHHTTTTHTTTTHTTTHTTTTTHTTTTTHTTTTTTTHTTT'),\n", | ||
" ('B',\n", | ||
" 'TTTTTHHHHTHTTHTHTTTTTHTTTTTTTTTTTHTTTTTTTHTHTTHHTHTHHTTHTTTTHHHHTHTTHTTTHHTTTTTHTTTTTHTTTTTHHTTTHTTT'),\n", | ||
" ('B',\n", | ||
" 'HTHTTHTTTTTTTHTTTHTTHTTTTHHHTTHTTHHTHHHTHTTTTTHTTTTHTTTTTTTHHHTTTTTTTTHTTTTTTHHHTTHHHTTTTTTTHTHTTTTT'),\n", | ||
" ('B',\n", | ||
" 'TTHTTTHTTHTTTTTTTTHTTTTTHTTTTHHTTTTHTHTTHTTTTTTHHHHTTTTTTTHTTTHTTTTTTTTTHHHHTHTHTTTTTHHTHTTTHHHTTHTT'),\n", | ||
" ('R',\n", | ||
" 'THHHTHTTHTTTHHHHHHTHHTTTHTTHHHHTHHTHTHHHHTTHHHHTHTHHHHHTHHTTTTHHHHHTHHHHTHTTHHHTHHHTHHHHHHTHTHHTTTHH')]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"flip_m_random_coins_n_times(5, 100)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[('_',\n", | ||
" 'TTHTHHTTHHHTTTTTTTTHTHTHTTTHTTTTHHHHTTTTTTTTTTHTTHTTTTTTTTHTTTHTTTHTHTTTHTTTTHTTTTTTHTTHTHHHTTHTHTHH'),\n", | ||
" ('_',\n", | ||
" 'HHHTTTHHHHHHTHTHHTHHTHHHTTHHHHHHTHHHHTTTHHHTHHHTTHTHHHHHTHTHHHHHHHHHTHHHHHHHHHHTTHHHTHHHHTHHHHHHTHHH'),\n", | ||
" ('_',\n", | ||
" 'HTTTHHTTTTTTTHTTTTHHTTTHTHTTTTTHTTHTHTTTTTHTTHTHTTHHHTHTTHTHHTTTHHHTTHTHHTHHTTTHTHHTTHTHHHTHTTTHTTTH'),\n", | ||
" ('_',\n", | ||
" 'TTTTTHTTTTHTTHTHHHTTHHTTTTTTTHHTHHTTHTTHTHTTHHHTHTTHTTTTTHTHHTTTHHTTTHTTHTTHTTTHHTTHTTTTHTTTTTTTTHTH'),\n", | ||
" ('_',\n", | ||
" 'HTTHHHHHHHTTTHHTHHHHHHTTTTHHTHHHHHHHHTHHHHHHTHHHTHHHHHHHTTHHTHHHHHHTHHHTHHHHTTHTHHTHTHTTHHHHTHTHHTHH')]" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"flip_m_random_coins_n_times(5, 100, hidden=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**TASK 1** Implement the following two functions to estimate parameters for the model in the observed case. Splitting the work into two separate functions will simplify things for the next task. \n", | ||
"\n", | ||
"* How could you measure the error in your estimates?\n", | ||
"* How does the error decrease with the sample size?\n", | ||
"* If you were only allowed to flip coins a total of N times how would you choose m (the number of coins) and n the number of times to flip each coin? Why?" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute_sufficient_statistics(samples):\n", | ||
" total = len(samples) * len(samples[0][1])\n", | ||
" count_red = sum([len(sample[1]) for sample in samples if sample[0] == 'R']) \n", | ||
" count_red_head = sum([sample[1].count('H') for sample in samples if sample[0] == 'R'])\n", | ||
" count_blue_head = sum([sample[1].count('H') for sample in samples if sample[0] == 'B'])\n", | ||
" return total, count_red, count_red_head, count_blue_head\n", | ||
"\n", | ||
"def mle(total, count_red, count_red_head, count_blue_head):\n", | ||
" estimate_prob_red = count_red / total\n", | ||
" estimate_prob_head_red = count_red_head / count_red\n", | ||
" estimate_prob_head_blue = count_blue_head / (total - count_red)\n", | ||
" return estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 32, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.304 0.20045723684210526 0.3996954022988506\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"samples = flip_m_random_coins_n_times(10000, 100)\n", | ||
"estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue = mle(*compute_sufficient_statistics(samples))\n", | ||
"print(estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**TASK 2** Given a sample from a single coin whose colour is unobserved, estimate the posterior probability that the coin is red, given some estimates of the model parameters.\n", | ||
"\n", | ||
"* If you pass in the true model parameters (e.g. prob_red.value, prob_head_red.value and prob_head_blue.value), how quickly does the posterior change? Use the plot_distribution function to view this.\n", | ||
"* How does this depend on the model parameters?" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 61, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute_posterior_prob_red(sample, estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue):\n", | ||
" count_head = sample.count('H')\n", | ||
" count_tail = len(sample) - count_head\n", | ||
" joint_red = estimate_prob_red * estimate_prob_head_red**count_head * (1 - estimate_prob_head_red)**count_tail\n", | ||
" joint_blue = (1 - estimate_prob_red) * estimate_prob_head_blue**count_head * (1 - estimate_prob_head_blue)**count_tail\n", | ||
" return joint_red / (joint_red + joint_blue)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**TASK 3** Reusing your code from Tasks 1 and 2, implement expectation maximization algorithm to find a (locally optimal) solution to the parameters when the colour of the coins is not observed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 67, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute_expected_statistics(samples, estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue):\n", | ||
" total, expected_count_red, expected_count_red_head, expected_count_blue_head = 0, 0.0, 0.0, 0.0\n", | ||
" for sample in samples:\n", | ||
" total += len(sample[1])\n", | ||
" posterior_prob_red = compute_posterior_prob_red(sample[1], estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue)\n", | ||
" expected_count_red += posterior_prob_red * len(sample[1])\n", | ||
" expected_count_red_head += posterior_prob_red * sample[1].count('H')\n", | ||
" expected_count_blue_head += (1 - posterior_prob_red) * sample[1].count('H')\n", | ||
" return total, expected_count_red, expected_count_red_head, expected_count_blue_head\n", | ||
"\n", | ||
"def expectation_maximization(samples, estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue):\n", | ||
" for i in range(10):\n", | ||
" total, expected_count_red, expected_count_red_head, expected_count_blue_head = compute_expected_statistics(\n", | ||
" samples, estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue)\n", | ||
" estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue = mle(\n", | ||
" total, expected_count_red, expected_count_red_head, expected_count_blue_head)\n", | ||
" print(estimate_prob_red, estimate_prob_head_red, estimate_prob_head_blue)\n", | ||
" \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 78, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.555294461770073 0.6019919515493285 0.3951315829659137\n", | ||
"0.3266628264755212 0.699974269087489 0.4178358745991849\n", | ||
"0.2999963891190526 0.719998947269077 0.4200019986750612\n", | ||
"0.2999832413939155 0.7200037090276086 0.4200055926313944\n", | ||
"0.2999832315266686 0.7200037114282721 0.4200055958313131\n", | ||
"0.29998323151917455 0.7200037114300724 0.4200055958337533\n", | ||
"0.29998323151916884 0.7200037114300739 0.4200055958337552\n", | ||
"0.29998323151916884 0.7200037114300738 0.4200055958337553\n", | ||
"0.29998323151916884 0.7200037114300738 0.4200055958337553\n", | ||
"0.29998323151916884 0.7200037114300738 0.4200055958337553\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"samples = flip_m_random_coins_n_times(10, 100, hidden=True)\n", | ||
"expectation_maximization(samples, 0.5, 0.7, 0.2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |