diff --git a/dowhy/api/causal_data_frame.py b/dowhy/api/causal_data_frame.py index 187c8575a..e4282dd84 100644 --- a/dowhy/api/causal_data_frame.py +++ b/dowhy/api/causal_data_frame.py @@ -27,10 +27,10 @@ def __init__(self, pandas_obj): def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, params=None, dot_graph=None, common_causes=None, instruments=None, estimand_type='ate', proceed_when_unidentifiable=False, - keep_original_treatment=False): + keep_original_treatment=False, use_previous_sampler=False): if not method: raise Exception("You must specify a do sampling method.") - if not self._obj._causal_model: + if not self._obj._causal_model or not use_previous_sampler: self._obj._causal_model = CausalModel(self._obj, [xi for xi in x.keys()][0], outcome, @@ -41,7 +41,7 @@ def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, param proceed_when_unidentifiable=proceed_when_unidentifiable) self._obj._identified_estimand = self._obj._causal_model.identify_effect() do_sampler_class = do_samplers.get_class_object(method + "_sampler") - if not self._obj._sampler: + if not self._obj._sampler or not use_previous_sampler: self._obj._sampler = do_sampler_class(self._obj, self._obj._identified_estimand, self._obj._causal_model._treatment, diff --git a/dowhy/do_samplers/mcmc_sampler.py b/dowhy/do_samplers/mcmc_sampler.py index 0563ce2d6..4ca747128 100644 --- a/dowhy/do_samplers/mcmc_sampler.py +++ b/dowhy/do_samplers/mcmc_sampler.py @@ -109,25 +109,16 @@ def make_intervention_effective(self, x): def do_sample(self, x): self.reset() - print(self._df.sample(10)) g_for_surgery = nx.DiGraph(self.g) g_modified = self.do_x_surgery(g_for_surgery, x) - print(self._df.sample(10)) - self._df = self.make_intervention_effective(x) - print(self._df.sample(10)) - g_modified, trace = self.sample_prior_causal_model(g_modified, self._df, self._variable_types, initialization_trace=self.fit_trace) - print(self._df.sample(10)) - for col in self._df: if col in trace and col not in self._treatment_names: self._df[col] = trace[col] - print(self._df.sample(10)) - return self._df.copy() def _construct_sampler(self): diff --git a/test mcmc do sampler.ipynb b/test mcmc do sampler.ipynb index 662f034f0..8aa92c330 100644 --- a/test mcmc do sampler.ipynb +++ b/test mcmc do sampler.ipynb @@ -54,7 +54,7 @@ "output_type": "stream", "text": [ "WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n", - "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n", + "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'U', 'X0'}\n", "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" ] }, @@ -66,7 +66,6 @@ "yes\n", "{'observed': 'yes'}\n", "Model to find the causal effect of treatment v on outcome y\n", - "{'observed': 'yes'}\n", "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", "All common causes are observed. Causal effect can be identified.\n", "McmcSampler\n" @@ -95,74 +94,13 @@ "INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n", "INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n", "INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n", - "Sampling 4 chains: 100%|██████████| 8000/8000 [00:05<00:00, 1548.09draws/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " X0 v y\n", - "563 1.974469 1.0 6.563333\n", - "983 2.483977 1.0 5.532828\n", - "512 -0.480378 0.0 0.392772\n", - "253 2.846581 1.0 6.208468\n", - "591 0.408251 1.0 5.124851\n", - "154 0.431255 0.0 -1.102835\n", - "684 -0.006090 0.0 -0.101303\n", - "222 -0.152137 0.0 1.193461\n", - "990 -1.001541 1.0 5.909003\n", - "26 0.955564 0.0 0.405158\n", - " X0 v y\n", - "13 1.211798 1.0 6.665465\n", - "531 1.121144 1.0 7.280435\n", - "487 2.806848 0.0 -0.353503\n", - "585 1.094204 0.0 2.032928\n", - "719 -0.132187 0.0 1.678639\n", - "656 0.212752 0.0 1.018769\n", - "171 0.841910 1.0 3.902883\n", - "994 1.531808 1.0 5.625160\n", - "304 1.198312 1.0 7.197449\n", - "731 1.983407 1.0 6.860088\n", - " X0 v y\n", - "671 2.130590 0.0 0.355311\n", - "107 0.895289 1.0 6.034806\n", - "854 0.866116 1.0 5.903469\n", - "225 -0.338098 0.0 1.529993\n", - "140 2.181477 1.0 6.882432\n", - "268 2.702453 1.0 6.771386\n", - "664 1.734149 1.0 7.530319\n", - "675 0.874949 1.0 5.074571\n", - "307 1.218045 1.0 6.340555\n", - "508 1.494511 1.0 5.604792\n", - " X0 v y\n", - "776 0.880533 1.0 7.209202\n", - "271 1.717009 0.0 1.475974\n", - "337 1.584927 0.0 1.290485\n", - "598 -0.678611 0.0 1.136225\n", - "300 0.390007 0.0 -0.480040\n", - "838 -0.122247 0.0 0.330018\n", - "139 0.500275 1.0 4.837644\n", - "650 0.103088 1.0 4.275985\n", - "339 3.203280 1.0 4.001425\n", - "694 1.727457 1.0 6.033716\n", - " X0 v y\n", - "437 1.936274 1.0 5.576163\n", - "167 0.196694 0.0 0.239224\n", - "348 1.678880 1.0 5.638999\n", - "608 1.565841 1.0 5.771532\n", - "451 0.619074 0.0 0.048471\n", - "50 2.056996 0.0 0.955019\n", - "101 2.986783 1.0 6.269176\n", - "557 2.511029 1.0 5.920370\n", - "749 1.654280 0.0 0.774465\n", - "176 -1.600362 1.0 4.238702\n" + "Sampling 4 chains: 100%|██████████| 8000/8000 [00:05<00:00, 1550.20draws/s]\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -171,7 +109,7 @@ }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAEQCAYAAACQip4+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC9dJREFUeJzt3VGMpXdZx/Hf4+42a2Bjw3aowFKnsVVTTKBmbUgwMaXRFks0JEogqZosZE2UBBKN1jRBvas3pphodBVSGpXVqiRYoEoUQoi0dJetSAsIqUWmATpdIdqLhcI+XnTWbofZztllzrzz3/P5JJPOOeftOc/F5Jv//s973lPdHQDG8T1TDwDA+RFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIPZPY8nveyyy3p5eXkeTw1wUTp+/PgT3b00y7FzCffy8nKOHTs2j6cGuChV1RdnPdZWCcBghBtgMMINMJi57HFv5KmnnsrKykpOnTq1XS95Xvbu3ZsDBw5kz549U48C8Jy2LdwrKyvZt29flpeXU1Xb9bIz6e6cPHkyKysrufLKK6ceB+A5bdtWyalTp7J///4dF+0kqars379/x/5rAOBs27rHvROjfcZOng3gbN6cBBjMtu1xr7d86/u39Pkevf3mLX0+gJ1qsnAD57bVC5tFd7Et7BZmq+Ttb3977rjjjv+/fdttt+Ud73jHhBMBXJiFCfehQ4dy1113JUlOnz6do0eP5pZbbpl4KoDztzBbJcvLy9m/f39OnDiRr371q7n22muzf//+qccCOG8LE+4kefOb35w777wzX/nKV3Lo0KGpxwG4IAuzVZIkr3vd63LvvffmgQceyI033jj1OAAXZLIV9xTv8l5yySW5/vrrc+mll2bXrl3b/voAW2GhtkpOnz6d++67L3fffffUowBcsJm2Sqrq0ar696p6sKqG/Gqbhx9+OFdddVVuuOGGXH311VOPA3DBzmfFfX13PzG3SebsmmuuySOPPDL1GADftW19c7K7t/PlzstOng3gbLOGu5P8U1Udr6rDF/JCe/fuzcmTJ3dkIM9cj3vv3r1TjwKwqVm3Sn6iux+rqhcm+VBVfba7P3r2AWtBP5wkV1xxxXc8wYEDB7KyspLV1dXvdua5OPMNOAA73Uzh7u7H1v77eFW9N8l1ST667pgjSY4kycGDB79jWb1nzx7fLgOwBTbdKqmq51XVvjO/J/npJJ+e92AAbGyWFfflSd679g0xu5P8VXffO9epADinTcPd3Y8kefk2zALADBbqWiUAFwPhBhiMcAMMRrgBBiPcAIMRboDBCDfAYIQbYDDCDTAY4QYYjHADDEa4AQYj3ACDEW6AwQg3wGCEG2Awwg0wGOEGGIxwAwxGuAEGI9wAgxFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBzBzuqtpVVSeq6p55DgTAczufFfdbk3xmXoMAMJuZwl1VB5LcnOTP5zsOAJuZdcV9R5LfTHJ6jrMAMINNw11Vr03yeHcf3+S4w1V1rKqOra6ubtmAADzbLCvuVyX52ap6NMnRJK+uqr9Yf1B3H+nug919cGlpaYvHBOCMTcPd3b/d3Qe6eznJG5L8S3ffMvfJANiQ87gBBrP7fA7u7o8k+chcJgFgJlbcAIMRboDBCDfAYIQbYDDCDTAY4QYYjHADDEa4AQYj3ACDEW6AwQg3wGCEG2Awwg0wGOEGGIxwAwxGuAEGI9wAgxFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBCDfAYIQbYDDCDTAY4QYYzKbhrqq9VfWJqvq3qnqoqn5vOwYDYGO7ZzjmG0le3d1PVtWeJB+rqg92931zng2ADWwa7u7uJE+u3dyz9tPzHAqAc5tpj7uqdlXVg0keT/Kh7r5/vmMBcC4zhbu7v93dr0hyIMl1VfWj64+pqsNVdayqjq2urm71nACsOa+zSrr760k+nOSmDR470t0Hu/vg0tLSVs0HwDqznFWyVFWXrv3+vUl+Ksln5z0YABub5aySFyV5d1XtytOh/5vuvme+YwFwLrOcVfKpJNduwywAzMAnJwEGI9wAgxFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBCDfAYIQbYDDCDTAY4QYYjHADDEa4AQYj3ACDEW6AwQg3wGCEG2Awwg0wGOEGGIxwAwxGuAEGI9wAgxFugMEIN8BghBtgMMINMJhNw11VL62qD1fVw1X1UFW9dTsGA2Bju2c45ltJfr27P1lV+5Icr6oPdffDc54NgA1suuLu7i939yfXfv/fJJ9J8pJ5DwbAxs5rj7uqlpNcm+T+eQwDwOZmDndVPT/J3yV5W3f/zwaPH66qY1V1bHV1dStnBOAsM4W7qvbk6Wj/ZXf//UbHdPeR7j7Y3QeXlpa2ckYAzjLLWSWV5J1JPtPdfzD/kQB4LrOsuF+V5BeTvLqqHlz7+Zk5zwXAOWx6OmB3fyxJbcMsAMzAJycBBiPcAIMRboDBCDfAYIQbYDDCDTAY4QYYjHADDEa4AQYj3ACDEW6AwQg3wGCEG2Awwg0wGOEGGIxwAwxGuAEGI9wAgxFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBCDfAYIQbYDDCDTCYTcNdVe+qqser6tPbMRAAz22WFfedSW6a8xwAzGjTcHf3R5P89zbMAsAM7HEDDGbLwl1Vh6vqWFUdW11d3aqnBWCdLQt3dx/p7oPdfXBpaWmrnhaAdWyVAAxmltMB35Pk40l+uKpWqupN8x8LgHPZvdkB3f3G7RgEgNnYKgEYjHADDEa4AQYj3ACDEW6AwQg3wGCEG2Awwg0wGOEGGIxwAwxGuAEGI9wAgxFugMEIN8BghBtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBCDfAYIQbYDDCDTCY3VMPMJXlW98/9QgXlUdvv3nqEWBhWHEDDEa4AQYj3ACDEW6AwQg3wGBmCndV3VRVn6uqL1TVrfMeCoBz2zTcVbUryR8leU2Sa5K8saqumfdgAGxslhX3dUm+0N2PdPc3kxxN8nPzHQuAc5kl3C9J8qWzbq+s3QfABLbsk5NVdTjJ4bWbT1bV57bquRfcZUmemHqIzdTvTz0BE/H3uXV+YNYDZwn3Y0leetbtA2v3PUt3H0lyZNYXZjZVday7D049B2zE3+c0ZtkqeSDJ1VV1ZVVdkuQNSd4337EAOJdNV9zd/a2qekuSf0yyK8m7uvuhuU8GwIZm2uPu7g8k+cCcZ2Fjtp/Yyfx9TqC6e+oZADgPPvIOMBjhBhiMcAMMRrh3qKp6QVW9YOo5gJ1HuHeQqrqiqo5W1WqS+5N8oqoeX7tvedrp4GlVdXlV/djaz+VTz7OInFWyg1TVx5PckeRvu/vba/ftSvILSd7W3a+ccj4WW1W9IsmfJPm+PPPp6QNJvp7kV7v7k1PNtmiEewepqs9399Xn+xhsh6p6MMmvdPf96+5/ZZI/7e6XTzPZ4tmyi0yxJY5X1R8neXeeuSLjS5P8cpITk00FT3ve+mgnSXffV1XPm2KgRWXFvYOsXQvmTXn6eudnLp27kuQfkryzu78x1WxQVX+Y5AeT3JVnLyx+Kcl/dvdbpppt0Qg3MLOqek2evbB4LMn71i6LwTYR7kFU1Wu7+56p5wCm53TAcfz41APAuax9kQrbxJuTO0xV/Ug2/qfo70w3FWyqph5gkVhx7yBV9Vt5+suYK8kn1n4qyXuq6tYpZ4NNfHPqARaJPe4dpKr+I8nLuvupdfdfkuQh53GzU1XVf3X3FVPPsShslewsp5O8OMkX193/orXHYDJV9alzPZTER9+3kXDvLG9L8s9V9fk8c57sFUmuSuIcWaZ2eZIbk3xt3f2V5F+3f5zFJdw7SHffW1U/lOS6PPvNyQfOXLsEJnRPkud394PrH6iqj2z/OIvLHjfAYJxVAjAY4QYYjHADDEa4AQYj3Fz0qur2qvq1s27/blX9xpQzwXdDuFkEf53k9Wfdfv3afTAk53Fz0evuE1X1wqp6cZKlJF/r7i9t9v/BTiXcLIq7k/x8ku+P1TaD8wEcFkJVvSzJnyW5LMlPdveXJx4JLpg9bhZCdz+UZF+Sx0Sb0VlxAwzGihtgMMINMBjhBhiMcAMMRrgBBiPcAIMRboDBCDfAYP4P+xhprKp6juYAAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEQCAYAAACk818iAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADXBJREFUeJzt3X2MpeVdxvHrctnNKt2UdDjSlmEcIoihJgUykpqaGsAKSNOGpm0gwbctmSa2piQaxZCg/sdfBjS+jW2lxMpWUGKldLVqSdNYlp1l14ZdWltXKrMpMKxt+pJsednLP/asXYZZduc8d+Y55zffTzJhz8vezy9k+fLsPc+Zx0kEAKjjh/oeAADQFmEHgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AiiHsAFDMGX0c9Oyzz87s7GwfhwaAibVnz57nkgxO9b5ewj47O6vFxcU+Dg0AE8v210/nfWzFAEAxncNu+yLb+074+rbtW1oMBwBYu85bMUm+IukSSbK9SdIhSQ90XRcAMJrWe+xXSfqvJKe1D3SiF154QUtLSzpy5EjjkdrZunWrpqentXnz5r5HAYCTah32GyTdu9oLtuclzUvSzMzMK15fWlrStm3bNDs7K9uNx+ouiQ4fPqylpSWdf/75fY8DACfV7JuntrdIeqek+1Z7PclCkrkkc4PBK6/WOXLkiKampsYy6pJkW1NTU2P9NwoAkNpeFXOtpMeSPDPqAuMa9ePGfT4AkNqG/UadZBsGALB+muyx2z5T0tslfaDFepI0e+unWy0lSXryjuuargf0qfV/HxtdtT40CXuS70maarEWAKAbPnk6dPvtt+vOO+/8/8e33Xab7rrrrh4nAoDREPah7du365577pEkHT16VDt27NBNN93U81QAsHa9/BCwcTQ7O6upqSnt3btXzzzzjC699FJNTbG7BGDyEPYT3Hzzzbr77rv19NNPa/v27X2PAwAjYSvmBNdff7127typ3bt36+qrr+57HAAYydiesfdx+dGWLVt0xRVX6KyzztKmTZvW/fgA0MLYhr0PR48e1SOPPKL77lv1pyIAwERgK2bowIEDuuCCC3TVVVfpwgsv7HscABgZZ+xDF198sQ4ePNj3GADQ2VidsSfpe4RXNe7zAYA0RmHfunWrDh8+PLbxPP7z2Ldu3dr3KADwqsZmK2Z6elpLS0taXl7ue5STOn4HJQAYZ2MT9s2bN3NnIgBoYGy2YgAAbRB2ACiGsANAMYQdAIppEnbbZ9m+3/aXbT9h+2darAsAWLtWV8XcJWlnkvfY3iLpRxqtCwBYo85ht/1aSW+T9KuSlOR5Sc93XRcAMJoWWzHnS1qW9Fe299r+iO0zV77J9rztRduL4/whJACYdC3CfoakyyT9WZJLJX1P0q0r35RkIclckrnBYNDgsACA1bQI+5KkpSS7ho/v17HQAwB60DnsSZ6W9JTti4ZPXSXpQNd1AQCjaXVVzG9I+sTwipiDkn6t0boAgDVqEvYk+yTNtVgLANANnzwFgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AiiHsAFAMYQeAYgg7ABRD2AGgGMIOAMUQdgAohrADQDGEHQCKIewAUAxhB4BimtxByfaTkr4j6SVJLybhbkoA0JNW9zyVpCuSPNdwPQDACNiKAYBiWoU9kv7Z9h7b86u9wfa87UXbi8vLy40OCwBYqVXYfzbJZZKulfRB229b+YYkC0nmkswNBoNGhwUArNQk7EkODf/5rKQHJF3eYl0AwNp1DrvtM21vO/5rSb8g6fGu6wIARtPiqphzJD1g+/h6f5NkZ4N1AQAj6Bz2JAclvbnBLACABrjcEQCKIewAUAxhB4BiCDsAFEPYAaAYwg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AimkWdtubbO+1/WCrNQEAa9fyjP3Dkp5ouB4AYARNwm57WtJ1kj7SYj0AwOhanbHfKem3JR092Rtsz9tetL24vLzc6LAAgJU6h932OyQ9m2TPq70vyUKSuSRzg8Gg62EBACfR4oz9rZLeaftJSTskXWn7rxusCwAYQeewJ/ndJNNJZiXdIOnfktzUeTIAwEi4jh0Aijmj5WJJHpb0cMs1AQBrwxk7ABRD2AGgGMIOAMUQdgAohrADQDGEHQCKIewAUAxhB4BiCDsAFEPYAaAYwg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAU0znstrfaftT2f9jeb/sPWgwGABhNi1vjfV/SlUm+a3uzpC/Y/kySRxqsDQBYo85hTxJJ3x0+3Dz8Std1AQCjabLHbnuT7X2SnpX02SS7VnnPvO1F24vLy8stDgsAWEWTsCd5KcklkqYlXW77p1Z5z0KSuSRzg8GgxWEBAKtoelVMkm9J+pyka1quCwA4fS2uihnYPmv46x+W9HZJX+66LgBgNC2uinmDpI/b3qRj/6P42yQPNlgXADCCFlfFfEnSpQ1mAQA0wCdPAaAYwg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AiiHsAFAMYQeAYgg7ABRD2AGgmBb3PD3P9udsH7C93/aHWwwGABhNi3uevijpN5M8ZnubpD22P5vkQIO1AQBr1PmMPck3kjw2/PV3JD0h6dyu6wIARtN0j932rI7d2HrXKq/N2160vbi8vNzysACAEzQLu+3XSPo7Sbck+fbK15MsJJlLMjcYDFodFgCwQpOw296sY1H/RJK/b7EmAGA0La6KsaSPSnoiyR92HwkA0EWLM/a3SvolSVfa3jf8+sUG6wIARtD5csckX5DkBrMAABrgk6cAUAxhB4BiCDsAFEPYAaAYwg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AiiHsAFBMq5tZf8z2s7Yfb7EeAGB0rc7Y75Z0TaO1AAAdNAl7ks9L+t8WawEAumGPHQCKWbew2563vWh7cXl5eb0OCwAbzrqFPclCkrkkc4PBYL0OCwAbDlsxAFBMq8sd75X0RUkX2V6y/f4W6wIA1u6MFoskubHFOgCA7tiKAYBiCDsAFEPYAaAYwg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAUQ9gBoBjCDgDFEHYAKIawA0AxhB0AiiHsAFAMYQeAYlrdGu8a21+x/TXbt7ZYEwAwms5ht71J0p9IulbSxZJutH1x13UBAKNpccZ+uaSvJTmY5HlJOyS9q8G6AIARtAj7uZKeOuHx0vA5AEAPzlivA9melzQvSTMzM+t12E5mb/103yOU8uQd1/U9Qhn8u8SraXHGfkjSeSc8nh4+9zJJFpLMJZkbDAYNDgsAWE2LsO+WdKHt821vkXSDpE81WBcAMILOWzFJXrT9IUn/JGmTpI8l2d95MgDASJrssSd5SNJDLdYCAHTDJ08BoBjCDgDFEHYAKIawA0AxhB0AiiHsAFAMYQeAYgg7ABRD2AGgGMIOAMUQdgAohrADQDGEHQCKIewAUMy63RpvEnH7MQCTiDN2ACiGsANAMZ3Cbvu9tvfbPmp7rtVQAIDRdT1jf1zSuyV9vsEsAIAGOn3zNMkTkmS7zTQAgM7YYweAYk55xm77XyS9fpWXbkvyD6d7INvzkuYlaWZm5rQHBACszSnDnuTnWxwoyYKkBUmam5tLizUBAK/EVgwAFONk9JNn29dL+mNJA0nfkrQvydWn8fuWJX195ANjpbMlPdf3EMAq+LPZ1o8lGZzqTZ3CjvFgezEJnyPA2OHPZj/YigGAYgg7ABRD2GtY6HsA4CT4s9kD9tgBoBjO2AGgGMIOAMUQdgAohrBPKNuvs/26vucAMH4I+wSxPWN7x/CTu7skPWr72eFzs/1OBxxj+xzblw2/zul7no2Iq2ImiO0vSrpT0v1JXho+t0nSeyXdkuQtfc6Hjc32JZL+XNJrJR0aPj2tYz9u5NeTPNbXbBsNYZ8gtr+a5MK1vgasB9v7JH0gya4Vz79F0l8keXM/k208ne6ghHW3x/afSvq4pKeGz50n6Vck7e1tKuCYM1dGXZKSPGL7zD4G2qg4Y58gtrdIer+kd0k6d/j0kqR/lPTRJN/vazbA9h9J+nFJ9+jlJx6/LOm/k3yor9k2GsIOoBnb1+rlJx6HJH0qyUP9TbXxEPYibL8jyYN9zwGgf1zuWMdP9z0AcDLDex5jnfDN0wlj+ye1+l91f6+/qYBTct8DbCScsU8Q278jaYeO/Ufy6PDLku61fWufswGn8HzfA2wk7LFPENv/KelNSV5Y8fwWSfu5jh3jyvb/JJnpe46Ngq2YyXJU0hv1yhuBv2H4GtAb21862UuS+NEC64iwT5ZbJP2r7a/qB9cJz0i6QBLXCKNv50i6WtI3VzxvSf++/uNsXIR9giTZafsnJF2ul3/zdPfxnx0D9OhBSa9Jsm/lC7YfXv9xNi722AGgGK6KAYBiCDsAFEPYAaAYwg4AxRB2QJLtO2x/8ITHv2/7t/qcCRgVYQeO+aSk953w+H3D54CJw3XsgKQke23/qO03ShpI+maSp071+4BxRNiBH7hP0nskvV6crWOC8QElYMj2myT9paSzJf1ckm/0PBIwEvbYgaEk+yVtk3SIqGOSccYOAMVwxg4AxRB2ACiGsANAMYQdAIoh7ABQDGEHgGIIOwAUQ9gBoJj/A8Sb/MKZv07KAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] @@ -193,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "scrolled": true }, @@ -204,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "scrolled": false }, @@ -213,7 +151,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n", + "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'U', 'X0'}\n", "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n", "INFO:dowhy.do_sampler:Using McmcSampler for do sampling.\n" ] @@ -228,7 +166,6 @@ "yes\n", "{'observed': 'yes'}\n", "Model to find the causal effect of treatment v on outcome y\n", - "{'observed': 'yes'}\n", "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", "All common causes are observed. Causal effect can be identified.\n", "McmcSampler\n", @@ -244,53 +181,8 @@ "INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n", "INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n", "INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n", - "Sampling 4 chains: 100%|██████████| 8000/8000 [00:04<00:00, 1610.17draws/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " X0 v y\n", - "612 1.096689 1.0 5.683769\n", - "751 1.086641 0.0 -0.149304\n", - "108 2.198076 0.0 2.670802\n", - "123 1.005457 1.0 6.318228\n", - "451 0.619074 0.0 -0.676561\n", - "898 1.790624 0.0 -0.119601\n", - "577 2.902210 1.0 5.910286\n", - "911 0.744255 1.0 6.569214\n", - "796 0.314369 1.0 5.655302\n", - "923 1.498186 1.0 4.984086\n", - " X0 v y\n", - "223 -0.780635 1.0 4.130883\n", - "297 1.068778 1.0 3.776232\n", - "914 2.270624 1.0 6.126630\n", - "377 -0.016168 1.0 5.808494\n", - "475 0.632153 0.0 1.915281\n", - "982 1.122380 1.0 5.332776\n", - "11 0.484034 0.0 -0.492809\n", - "108 2.198076 0.0 2.670802\n", - "773 2.116355 0.0 0.378585\n", - "871 2.440941 0.0 2.240204\n", - " X0 v y\n", - "222 -0.152137 1 1.193461\n", - "431 2.214968 1 1.676506\n", - "700 0.694453 1 0.447275\n", - "718 1.999470 1 7.856392\n", - "670 1.135020 1 6.696070\n", - "166 -0.988365 1 -0.927939\n", - "909 -0.031697 1 5.897646\n", - "186 1.263339 1 0.121611\n", - "433 1.467491 1 1.818254\n", - "272 2.083983 1 5.634306\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n", + "Sampling 4 chains: 100%|██████████| 8000/8000 [00:05<00:00, 1479.09draws/s]\n", + "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'U', 'X0'}\n", "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" ] }, @@ -298,87 +190,9 @@ "name": "stdout", "output_type": "stream", "text": [ - " X0 v y\n", - "703 0.268584 1 3.415063\n", - "152 -0.112613 1 4.855769\n", - "675 0.874949 1 5.074571\n", - "962 -0.343846 1 -1.281090\n", - "812 -0.673580 1 -1.201265\n", - "795 1.292742 1 0.040008\n", - "706 2.753909 1 6.168354\n", - "215 -0.388153 1 0.017930\n", - "168 0.600874 1 -1.447453\n", - "635 3.066023 1 1.706179\n", - " X0 v y\n", - "429 0.913090 1 3.503883\n", - "765 0.814357 1 3.275973\n", - "316 2.200575 1 8.462898\n", - "13 1.211798 1 4.502192\n", - "55 -0.210134 1 4.072864\n", - "38 -0.589735 1 -1.252626\n", - "857 -0.063631 1 10.943018\n", - "421 1.555166 1 2.712822\n", - "945 0.740692 1 7.542532\n", - "890 0.289802 1 4.344509\n", - "{'observed': 'yes'}\n", "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", "All common causes are observed. Causal effect can be identified.\n", - "McmcSampler\n", - " X0 v y\n", - "251 -0.717858 1.0 3.136038\n", - "558 0.654406 0.0 -1.329377\n", - "98 1.252980 0.0 0.332077\n", - "630 0.377851 1.0 6.450061\n", - "119 -0.291942 0.0 0.417807\n", - "927 2.427735 1.0 6.882456\n", - "793 -0.659126 1.0 4.477884\n", - "502 0.967928 1.0 3.797730\n", - "413 0.909902 1.0 5.546545\n", - "429 0.913090 0.0 -0.211590\n", - " X0 v y\n", - "281 1.504805 0.0 1.884065\n", - "419 0.032770 1.0 3.649192\n", - "404 1.782597 0.0 0.350244\n", - "985 -0.610096 0.0 -1.054331\n", - "120 2.554043 1.0 4.348485\n", - "59 2.829189 1.0 6.539892\n", - "673 -0.452400 1.0 3.788344\n", - "660 2.385558 1.0 6.034419\n", - "852 1.505619 0.0 0.493016\n", - "784 0.512136 0.0 -0.225890\n", - " X0 v y\n", - "756 0.286867 0 -0.590514\n", - "729 1.569174 0 4.433689\n", - "262 -0.079203 0 -0.191959\n", - "405 -0.348433 0 4.305770\n", - "705 0.741166 0 0.068463\n", - "949 1.734418 0 0.655759\n", - "825 0.354715 0 -0.386749\n", - "186 1.263339 0 0.121611\n", - "886 -1.367113 0 -1.408194\n", - "414 1.702780 0 5.148147\n", - " X0 v y\n", - "153 0.622880 0 1.122876\n", - "390 0.296689 0 0.321278\n", - "140 2.181477 0 6.882432\n", - "296 0.664508 0 2.646718\n", - "266 2.412813 0 6.192862\n", - "901 2.007180 0 -0.879939\n", - "403 1.649582 0 4.607203\n", - "137 2.918708 0 2.296482\n", - "630 0.377851 0 6.450061\n", - "223 -0.780635 0 4.130883\n", - " X0 v y\n", - "692 0.386427 0 0.173766\n", - "536 0.162485 0 1.416439\n", - "510 0.733712 0 -0.611215\n", - "915 -0.411775 0 -0.620634\n", - "169 0.460467 0 -2.872950\n", - "399 0.995904 0 1.360631\n", - "935 1.281168 0 1.049168\n", - "697 2.017054 0 0.830318\n", - "627 0.862794 0 0.593102\n", - "527 0.764523 0 -0.339923\n" + "McmcSampler\n" ] } ], @@ -395,12 +209,13 @@ " outcome='y',\n", " method='mcmc', \n", " dot_graph=data['dot_graph'],\n", - " proceed_when_unidentifiable=True)" + " proceed_when_unidentifiable=True,\n", + " use_previous_sampler=True)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": { "scrolled": true }, @@ -434,183 +249,183 @@ " \n", " \n", " 0\n", - " 1.521782\n", + " -0.679571\n", " 0\n", - " 0.734869\n", + " -0.540989\n", " \n", " \n", " 1\n", - " 1.393287\n", + " -0.243537\n", " 0\n", - " 3.111109\n", + " -0.816734\n", " \n", " \n", " 2\n", - " 1.744401\n", + " -0.195101\n", " 0\n", - " 2.814173\n", + " -0.420884\n", " \n", " \n", " 3\n", - " -0.984910\n", + " 0.923288\n", " 0\n", - " -0.547759\n", + " 0.543097\n", " \n", " \n", " 4\n", - " 1.531897\n", + " 0.389773\n", " 0\n", - " 0.250786\n", + " 2.000232\n", " \n", " \n", " 5\n", - " -0.414219\n", + " 0.345340\n", " 0\n", - " -0.263245\n", + " -0.235523\n", " \n", " \n", " 6\n", - " 3.440954\n", + " 0.539989\n", " 0\n", - " 3.228618\n", + " -0.191927\n", " \n", " \n", " 7\n", - " 0.419157\n", + " 1.294383\n", " 0\n", - " 0.527223\n", + " 1.841164\n", " \n", " \n", " 8\n", - " 0.336473\n", + " -0.557656\n", " 0\n", - " -0.805009\n", + " -1.179258\n", " \n", " \n", " 9\n", - " -1.038132\n", + " -0.581319\n", " 0\n", - " 0.206324\n", + " -1.585532\n", " \n", " \n", " 10\n", - " 3.020776\n", + " 0.089578\n", " 0\n", - " 0.930900\n", + " 0.371095\n", " \n", " \n", " 11\n", - " 0.484034\n", + " 0.826961\n", " 0\n", - " 1.303296\n", + " 3.333531\n", " \n", " \n", " 12\n", - " 1.494246\n", + " -0.671221\n", " 0\n", - " 0.315203\n", + " -2.539698\n", " \n", " \n", " 13\n", - " 1.211798\n", + " 0.986791\n", " 0\n", - " 2.533215\n", + " 1.801363\n", " \n", " \n", " 14\n", - " 3.052140\n", + " 1.594109\n", " 0\n", - " 1.405309\n", + " 2.040565\n", " \n", " \n", " 15\n", - " -0.294350\n", + " -0.245527\n", " 0\n", - " -0.697531\n", + " -0.289327\n", " \n", " \n", " 16\n", - " -0.627547\n", + " 0.178187\n", " 0\n", - " -1.323468\n", + " 0.141852\n", " \n", " \n", " 17\n", - " 2.113168\n", + " 2.024924\n", " 0\n", - " -0.652504\n", + " 2.980681\n", " \n", " \n", " 18\n", - " 1.918390\n", + " -1.760241\n", " 0\n", - " 0.496508\n", + " -3.197609\n", " \n", " \n", " 19\n", - " 1.409377\n", + " 1.911770\n", " 0\n", - " -0.916860\n", + " 4.160535\n", " \n", " \n", " 20\n", - " 1.751635\n", + " 0.411533\n", " 0\n", - " 1.383517\n", + " 0.891029\n", " \n", " \n", " 21\n", - " 1.508836\n", + " -1.601890\n", " 0\n", - " 0.090146\n", + " -3.900897\n", " \n", " \n", " 22\n", - " 1.940414\n", + " -0.779602\n", " 0\n", - " 2.880129\n", + " -3.024336\n", " \n", " \n", " 23\n", - " -0.708565\n", + " 2.117118\n", " 0\n", - " 1.141244\n", + " 3.186386\n", " \n", " \n", " 24\n", - " 0.953574\n", + " 0.515388\n", " 0\n", - " 1.505149\n", + " 1.042302\n", " \n", " \n", " 25\n", - " 0.737527\n", + " 0.195795\n", " 0\n", - " 1.087160\n", + " -1.043282\n", " \n", " \n", " 26\n", - " 0.955564\n", + " 0.112839\n", " 0\n", - " -0.356845\n", + " 0.106886\n", " \n", " \n", " 27\n", - " 1.030763\n", + " 0.508712\n", " 0\n", - " -0.037234\n", + " -0.138736\n", " \n", " \n", " 28\n", - " 2.442606\n", + " 1.449329\n", " 0\n", - " 3.705543\n", + " 2.797328\n", " \n", " \n", " 29\n", - " -0.150845\n", + " 0.948285\n", " 0\n", - " 0.385139\n", + " 2.310356\n", " \n", " \n", " ...\n", @@ -620,183 +435,183 @@ " \n", " \n", " 970\n", - " 0.481196\n", + " 0.733012\n", " 0\n", - " 0.190078\n", + " 2.388490\n", " \n", " \n", " 971\n", - " 2.352860\n", + " -0.970542\n", " 0\n", - " 0.402531\n", + " -2.013554\n", " \n", " \n", " 972\n", - " 1.984094\n", + " 1.451770\n", " 0\n", - " -0.040868\n", + " 3.385053\n", " \n", " \n", " 973\n", - " 0.565025\n", + " 0.086223\n", " 0\n", - " -0.361526\n", + " -0.224399\n", " \n", " \n", " 974\n", - " 1.148690\n", + " 0.062156\n", " 0\n", - " 1.527679\n", + " -0.295354\n", " \n", " \n", " 975\n", - " 0.788356\n", + " 1.178053\n", " 0\n", - " -0.933352\n", + " 1.891484\n", " \n", " \n", " 976\n", - " 0.755202\n", + " 0.038045\n", " 0\n", - " -1.300738\n", + " 0.491061\n", " \n", " \n", " 977\n", - " 0.727366\n", + " -0.420454\n", " 0\n", - " 1.825278\n", + " -1.158550\n", " \n", " \n", " 978\n", - " 0.705377\n", + " 0.498168\n", " 0\n", - " 1.484147\n", + " 0.916308\n", " \n", " \n", " 979\n", - " 1.259369\n", + " 0.179320\n", " 0\n", - " 0.821890\n", + " 1.385915\n", " \n", " \n", " 980\n", - " 0.216203\n", + " -1.081955\n", " 0\n", - " 0.279668\n", + " -4.169600\n", " \n", " \n", " 981\n", - " 0.717459\n", + " 0.792726\n", " 0\n", - " -0.083790\n", + " 0.799942\n", " \n", " \n", " 982\n", - " 1.122380\n", + " 0.817159\n", " 0\n", - " 0.965924\n", + " 1.660682\n", " \n", " \n", " 983\n", - " 2.483977\n", + " 0.277124\n", " 0\n", - " -0.015003\n", + " 0.915327\n", " \n", " \n", " 984\n", - " 1.550738\n", + " 1.139441\n", " 0\n", - " 0.805843\n", + " 2.971099\n", " \n", " \n", " 985\n", - " -0.610096\n", + " 2.636305\n", " 0\n", - " 0.654642\n", + " 4.528907\n", " \n", " \n", " 986\n", - " 0.469281\n", + " 0.849081\n", " 0\n", - " -1.003249\n", + " 0.664671\n", " \n", " \n", " 987\n", - " 1.385763\n", + " 2.231701\n", " 0\n", - " 0.039411\n", + " 4.656510\n", " \n", " \n", " 988\n", - " 2.566888\n", + " 1.065030\n", " 0\n", - " 1.139561\n", + " 2.355032\n", " \n", " \n", " 989\n", - " 0.109643\n", + " -0.134438\n", " 0\n", - " -0.717650\n", + " 0.298530\n", " \n", " \n", " 990\n", - " -1.001541\n", + " -2.540774\n", " 0\n", - " -0.879806\n", + " -5.619055\n", " \n", " \n", " 991\n", - " 1.827036\n", + " -0.368138\n", " 0\n", - " 0.566749\n", + " -0.208480\n", " \n", " \n", " 992\n", - " 0.693931\n", + " 1.050256\n", " 0\n", - " 0.437658\n", + " 1.603709\n", " \n", " \n", " 993\n", - " 1.658332\n", + " 0.669631\n", " 0\n", - " -0.883789\n", + " 0.961663\n", " \n", " \n", " 994\n", - " 1.531808\n", + " -0.508734\n", " 0\n", - " 2.709831\n", + " -0.771802\n", " \n", " \n", " 995\n", - " 2.011649\n", + " -0.103255\n", " 0\n", - " -0.629456\n", + " -0.251265\n", " \n", " \n", " 996\n", - " -0.407929\n", + " -0.906700\n", " 0\n", - " 1.881557\n", + " -2.711775\n", " \n", " \n", " 997\n", - " -0.076402\n", + " 0.156403\n", " 0\n", - " -0.211007\n", + " 0.602863\n", " \n", " \n", " 998\n", - " 1.302853\n", + " -0.276539\n", " 0\n", - " 0.317637\n", + " -0.986462\n", " \n", " \n", " 999\n", - " 1.987589\n", + " 0.608260\n", " 0\n", - " 0.353057\n", + " 0.644999\n", " \n", " \n", "\n", @@ -805,407 +620,86 @@ ], "text/plain": [ " X0 v y\n", - "0 1.521782 0 0.734869\n", - "1 1.393287 0 3.111109\n", - "2 1.744401 0 2.814173\n", - "3 -0.984910 0 -0.547759\n", - "4 1.531897 0 0.250786\n", - "5 -0.414219 0 -0.263245\n", - "6 3.440954 0 3.228618\n", - "7 0.419157 0 0.527223\n", - "8 0.336473 0 -0.805009\n", - "9 -1.038132 0 0.206324\n", - "10 3.020776 0 0.930900\n", - "11 0.484034 0 1.303296\n", - "12 1.494246 0 0.315203\n", - "13 1.211798 0 2.533215\n", - "14 3.052140 0 1.405309\n", - "15 -0.294350 0 -0.697531\n", - "16 -0.627547 0 -1.323468\n", - "17 2.113168 0 -0.652504\n", - "18 1.918390 0 0.496508\n", - "19 1.409377 0 -0.916860\n", - "20 1.751635 0 1.383517\n", - "21 1.508836 0 0.090146\n", - "22 1.940414 0 2.880129\n", - "23 -0.708565 0 1.141244\n", - "24 0.953574 0 1.505149\n", - "25 0.737527 0 1.087160\n", - "26 0.955564 0 -0.356845\n", - "27 1.030763 0 -0.037234\n", - "28 2.442606 0 3.705543\n", - "29 -0.150845 0 0.385139\n", + "0 -0.679571 0 -0.540989\n", + "1 -0.243537 0 -0.816734\n", + "2 -0.195101 0 -0.420884\n", + "3 0.923288 0 0.543097\n", + "4 0.389773 0 2.000232\n", + "5 0.345340 0 -0.235523\n", + "6 0.539989 0 -0.191927\n", + "7 1.294383 0 1.841164\n", + "8 -0.557656 0 -1.179258\n", + "9 -0.581319 0 -1.585532\n", + "10 0.089578 0 0.371095\n", + "11 0.826961 0 3.333531\n", + "12 -0.671221 0 -2.539698\n", + "13 0.986791 0 1.801363\n", + "14 1.594109 0 2.040565\n", + "15 -0.245527 0 -0.289327\n", + "16 0.178187 0 0.141852\n", + "17 2.024924 0 2.980681\n", + "18 -1.760241 0 -3.197609\n", + "19 1.911770 0 4.160535\n", + "20 0.411533 0 0.891029\n", + "21 -1.601890 0 -3.900897\n", + "22 -0.779602 0 -3.024336\n", + "23 2.117118 0 3.186386\n", + "24 0.515388 0 1.042302\n", + "25 0.195795 0 -1.043282\n", + "26 0.112839 0 0.106886\n", + "27 0.508712 0 -0.138736\n", + "28 1.449329 0 2.797328\n", + "29 0.948285 0 2.310356\n", ".. ... .. ...\n", - "970 0.481196 0 0.190078\n", - "971 2.352860 0 0.402531\n", - "972 1.984094 0 -0.040868\n", - "973 0.565025 0 -0.361526\n", - "974 1.148690 0 1.527679\n", - "975 0.788356 0 -0.933352\n", - "976 0.755202 0 -1.300738\n", - "977 0.727366 0 1.825278\n", - "978 0.705377 0 1.484147\n", - "979 1.259369 0 0.821890\n", - "980 0.216203 0 0.279668\n", - "981 0.717459 0 -0.083790\n", - "982 1.122380 0 0.965924\n", - "983 2.483977 0 -0.015003\n", - "984 1.550738 0 0.805843\n", - "985 -0.610096 0 0.654642\n", - "986 0.469281 0 -1.003249\n", - "987 1.385763 0 0.039411\n", - "988 2.566888 0 1.139561\n", - "989 0.109643 0 -0.717650\n", - "990 -1.001541 0 -0.879806\n", - "991 1.827036 0 0.566749\n", - "992 0.693931 0 0.437658\n", - "993 1.658332 0 -0.883789\n", - "994 1.531808 0 2.709831\n", - "995 2.011649 0 -0.629456\n", - "996 -0.407929 0 1.881557\n", - "997 -0.076402 0 -0.211007\n", - "998 1.302853 0 0.317637\n", - "999 1.987589 0 0.353057\n", + "970 0.733012 0 2.388490\n", + "971 -0.970542 0 -2.013554\n", + "972 1.451770 0 3.385053\n", + "973 0.086223 0 -0.224399\n", + "974 0.062156 0 -0.295354\n", + "975 1.178053 0 1.891484\n", + "976 0.038045 0 0.491061\n", + "977 -0.420454 0 -1.158550\n", + "978 0.498168 0 0.916308\n", + "979 0.179320 0 1.385915\n", + "980 -1.081955 0 -4.169600\n", + "981 0.792726 0 0.799942\n", + "982 0.817159 0 1.660682\n", + "983 0.277124 0 0.915327\n", + "984 1.139441 0 2.971099\n", + "985 2.636305 0 4.528907\n", + "986 0.849081 0 0.664671\n", + "987 2.231701 0 4.656510\n", + "988 1.065030 0 2.355032\n", + "989 -0.134438 0 0.298530\n", + "990 -2.540774 0 -5.619055\n", + "991 -0.368138 0 -0.208480\n", + "992 1.050256 0 1.603709\n", + "993 0.669631 0 0.961663\n", + "994 -0.508734 0 -0.771802\n", + "995 -0.103255 0 -0.251265\n", + "996 -0.906700 0 -2.711775\n", + "997 0.156403 0 0.602863\n", + "998 -0.276539 0 -0.986462\n", + "999 0.608260 0 0.644999\n", "\n", "[1000 rows x 3 columns]" ] }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cdf_0" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n", - "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'observed': 'yes'}\n", - "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", - "All common causes are observed. Causal effect can be identified.\n", - "McmcSampler\n", - " X0 v y\n", - "863 1.564016 1.0 5.108925\n", - "260 -0.321099 0.0 -0.677254\n", - "871 2.440941 0.0 2.240204\n", - "787 0.390931 1.0 5.349331\n", - "98 1.252980 0.0 0.332077\n", - "368 1.144241 0.0 -1.676433\n", - "765 0.814357 1.0 5.318945\n", - "667 -0.745254 0.0 -0.145617\n", - "838 -0.122247 0.0 0.330018\n", - "110 0.452500 1.0 2.972335\n", - " X0 v y\n", - "856 0.673717 1.0 3.456715\n", - "278 2.006345 1.0 7.413507\n", - "34 0.378670 0.0 1.629312\n", - "299 -0.612085 1.0 4.736626\n", - "293 -0.244447 0.0 0.007012\n", - "212 -0.944316 0.0 -0.126088\n", - "446 -0.200964 0.0 0.213676\n", - "590 0.789561 1.0 5.514221\n", - "289 0.858574 1.0 5.301086\n", - "518 0.563971 0.0 1.271799\n", - " X0 v y\n", - "996 -0.407929 1 0.144058\n", - "125 2.159065 1 5.595686\n", - "842 0.258104 1 4.303597\n", - "877 1.418883 1 0.774502\n", - "849 -0.133069 1 -0.850851\n", - "867 0.732025 1 4.252658\n", - "757 -1.953362 1 -0.464166\n", - "299 -0.612085 1 4.736626\n", - "662 0.910525 1 5.963600\n", - "540 1.428901 1 6.909129\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n", - "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " X0 v y\n", - "917 1.661385 1 -0.065593\n", - "205 0.521601 1 6.795833\n", - "507 0.781827 1 4.209187\n", - "794 0.117923 1 -1.575165\n", - "728 0.481260 1 -0.278846\n", - "574 -0.072040 1 0.104267\n", - "2 1.744401 1 6.409548\n", - "359 1.431587 1 6.761684\n", - "668 1.204832 1 0.660197\n", - "757 -1.953362 1 -0.464166\n", - " X0 v y\n", - "733 0.067462 1 5.597181\n", - "299 -0.612085 1 7.377299\n", - "501 0.152931 1 6.666780\n", - "479 2.444641 1 5.800005\n", - "416 1.549304 1 5.898808\n", - "150 1.289450 1 6.366260\n", - "54 2.066452 1 5.537184\n", - "928 2.967249 1 6.490645\n", - "842 0.258104 1 5.955817\n", - "113 1.097331 1 7.499741\n", - "{'observed': 'yes'}\n", - "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", - "All common causes are observed. Causal effect can be identified.\n", - "McmcSampler\n", - " X0 v y\n", - "567 -0.652549 0.0 0.793207\n", - "171 0.841910 1.0 3.902883\n", - "614 0.087104 1.0 4.605662\n", - "536 0.162485 0.0 0.138752\n", - "581 1.895297 0.0 1.138745\n", - "365 1.252492 1.0 6.751402\n", - "741 1.207304 0.0 2.145552\n", - "813 1.276272 0.0 -0.948142\n", - "336 0.311222 1.0 6.612389\n", - "305 0.323710 1.0 5.151747\n", - " X0 v y\n", - "232 2.995964 1.0 5.416620\n", - "833 0.112845 0.0 -0.088816\n", - "14 3.052140 1.0 8.518381\n", - "502 0.967928 1.0 3.797730\n", - "194 0.072495 0.0 -0.108534\n", - "951 0.315488 1.0 6.662354\n", - "638 0.845503 1.0 5.636279\n", - "691 -0.488899 0.0 0.724800\n", - "259 0.009683 1.0 5.050615\n", - "571 0.048128 1.0 4.654207\n", - " X0 v y\n", - "953 3.214783 0 5.874202\n", - "760 1.983350 0 6.006957\n", - "74 -0.121476 0 0.080851\n", - "670 1.135020 0 6.696070\n", - "199 0.638032 0 6.660963\n", - "840 1.098385 0 7.073482\n", - "316 2.200575 0 0.989949\n", - "503 1.145900 0 4.729096\n", - "813 1.276272 0 -0.948142\n", - "400 0.476705 0 5.555583\n", - " X0 v y\n", - "302 2.085505 0 -0.404540\n", - "730 0.411607 0 4.717481\n", - "566 0.164182 0 5.729023\n", - "461 -0.367278 0 5.088359\n", - "69 0.992770 0 0.713309\n", - "553 0.620375 0 5.713161\n", - "98 1.252980 0 0.332077\n", - "386 1.454488 0 -1.004379\n", - "468 0.915893 0 6.420551\n", - "165 1.176042 0 1.099864\n", - " X0 v y\n", - "366 0.359005 0 0.277255\n", - "620 -0.523936 0 -0.430879\n", - "413 0.909902 0 -0.127029\n", - "893 0.693098 0 -1.137795\n", - "761 2.568516 0 0.295067\n", - "520 0.740888 0 1.817991\n", - "231 1.365035 0 1.853342\n", - "660 2.385558 0 2.248198\n", - "791 0.304952 0 -2.404168\n", - "455 1.097787 0 -0.141627\n" - ] - } - ], - "source": [ - "cdf_1 = cdf.causal.do(x={'v': 1}, \n", - " variable_types={'v': 'b', 'y': 'c', 'X0': 'c'}, \n", - " outcome='y',\n", - " method='mcmc', \n", - " common_causes=['X0'],\n", - " proceed_when_unidentifiable=True)\n", - "cdf_0 = cdf.causal.do(x={'v': 0}, \n", - " variable_types={'v': 'b', 'y': 'c', 'X0': 'c'}, \n", - " outcome='y',\n", - " method='mcmc', \n", - " common_causes=['X0'],\n", - " proceed_when_unidentifiable=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKkAAAAPBAMAAABtvvLvAAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAzWYQMplU74mrdiK7RN1/7zyFAAAACXBIWXMAAA7EAAAOxAGVKw4bAAADCUlEQVQ4Ea2UyYsTQRhHX7qzdNLZmBGvRj24HWwP4nKY5KinCUpkDg42ouBtRqN40JEoKnEQjeKoIGgQxO1gQFxA0IAouMeDeHPGgyJ4mCUa1zF+VZXxL7AOXd39e/W6Ul9VCMwPbkCaXXieky6YVw/yeO1QA7d81Jf74Ss5nIsHwbDDC3Mdenn5HNwtP4bVV/o7sfFE2u2SjGQVzrR0yQqEixCC09wiPAXLGqE0s+mpodmAb382tL2DO56bZyBDnqhvYuMJHLmupKyH3dJdrrDqSBM2i40bPn/gEfG6NYiTRrMLYL+ho4NE87EWqbFwBrsTG09UO2EfXKjhzpW5WmI9Cj28bTBJ4pcQ8QrhJpp9AgOeplNpnInEEKm0U8ceM7HxzFjneFyQ5Qt0rKMPOSA2WYH4oPSpComWGfYDRn1NZ+vE1CcHMonfnpMxsfFE513KSaLauMeSGWu0vdCXVz0lUrOG+8nmSfxEsfZXsVY1HZG5Si3sT/Izpw6jYwmUJ+65qkrSEt+xqzNWjk96Ut8PHtkukv6DEsGfmnW/Q29V084EgSncY7KUTrtblsmoxKPaS31V5XeYsYb7x4fkdWyI7ASh4oM8QQ2/dGWuvSWzWZ6yrSXQ5QZb331rKItSqW0kbVNNdwXY8s+6iNg3mSy7vdQY1rRZAcXe7KyA0IQXr/kiXaRoVbhRlDutUskJOK4/E8tjl/5Zhen178ErP5nGmpJqhVuGlWoN+AitWrTp1oi3Ih7BTmyS7fIBNSm24t4eKWzaVVU7y56QEZm2J1bZldZ0vK7eavY5LPUUXZNBgUqqKdas3J40sUny8F5JgyViwkUqZr/ukcWv7VX1DDYJDZpToFl9CjRtFemtReokmxEZ+UwOmKiM5xbWHrJp7pcPnUesdSlRE157LGaePrFnWeHzkRU5OcHCJn17p6Fj3XYXVkbt126sqomNx+0b8VTZ5rTbf2SHvPqRCZ2ZXEdwRP5dEoWNDXnXdwpWXpOLZu3hq7kOXS5I/KbwAtb29Xdi45H5/f/2F7jeF5bOYBbeAAAAAElFTkSuQmCC\n", - "text/latex": [ - "$$5.241836050293852$$" - ], - "text/plain": [ - "5.241836050293852" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(cdf_1['y'] - cdf_0['y']).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAAAPBAMAAABATN1VAAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAiXZmMs1UEN0i77urRJlR0qN3AAAACXBIWXMAAA7EAAAOxAGVKw4bAAADOElEQVQ4EbWUzWtcVRjGf/N55+POnZmoIDSkNa6UUqdxVwQn/4AzEtGV9kp3LsztohhKadKFBaG005XWhRlQrFZCg4gF48dURbSLZjbddGFGtAj9GKLpdJJivD7nnNCkf4B38bzPPec857nnfd9zIfHko5jHxuPVX2AqeiEkX90VanBckxYYeaLO2eqX4PWX6njjY13HrNiBVTs6sr8Ox299obeXmWmZQRO9D3k/4mL8F3xAdg3vVW444Eg3aOZ7TNc4RmKTAH5yDN6IJBc4NdnzkAm9d/A67O6SapNoaoGNfhu/x4vfy/GZkHWKPZYc8C3FxeQ9ynP8CRc4BEcc8y4va28DTn1sdAgH4DGCiEKbYoesRlwsN0ms0tEr17sMWDFns1DaECstoBWfwHJrF8w4Bs+bEwicmpS2+xqmo8Ic/gblDqV7WmDj7CLJDWeASdFXmnBQbBsqWY2DEcv1lUs84ti2gVNbg3/0XWFmiH+X2R6lv6W0saATbPLDxB9mr5kKg8l9OoiB8ksjdzTmnTAznIn8eH/o2LaBUxsD718Z7NF8bsh8hbQxsFH5yaxxifkunD0deYMaJy0wO0VOrfW4CqBUqQl2D0xiDHuQIqc2BnkNN/ZoaqXCfI+0WeXi57xu8lWcEyQXvBhuvmUgml0lUHfwtKzJqW53ziw4tm2AUxsDnaBR0fyprdSIulRlJybv6iUwJeVCtC5114K6J7WpsYJx6cNzJO/rCGI7DJx6R4qCni1udqvINvpD3xbiHHwXvi11zUKuSWot36KotUnJ5NMILdthAL46SAaoyNMh7JN9cdGOPIiZjmoVDIkjGVwxJ7CgHk9tlofW4E3y765qsxpirYcMMh1nsGRHsz3OPXzRUudptAIluclnplcaETexkB4StAuLpi3SFZKtT/VpLcu2DazaGdiLxiQ8BSc5WvcGLiZf8abI9zjQ4llzD/yK97EDLnM0TNXMPXivOvYz1yImHNOlVzEMWLUyqBTlQu8j8ieqt+fg8PiPtto2Vvtd/YZG1Yyl/i3RsdH6FiR+07Jf+9/AwTheJ31VPzvLuH3xtZoDqw6uDH7HG9lbJxPHsQz+5+c/hXtRw8x6gPkAAAAASUVORK5CYII=\n", - "text/latex": [ - "$$0.09536663282119102$$" - ], - "text/plain": [ - "0.09536663282119102" - ] - }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "1.96*(cdf_1['y'] - cdf_0['y']).std() / np.sqrt(len(cdf))" + "cdf_0" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { - "scrolled": false + "scrolled": true }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "
OLS Regression Results
Dep. Variable: y R-squared: 0.941
Model: OLS Adj. R-squared: 0.941
Method: Least Squares F-statistic: 7979.
Date: Sun, 10 Feb 2019 Prob (F-statistic): 0.00
Time: 14:51:27 Log-Likelihood: -1431.1
No. Observations: 1000 AIC: 2866.
Df Residuals: 998 BIC: 2876.
Df Model: 2
Covariance Type: nonrobust
\n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "
coef std err t P>|t| [0.025 0.975]
X0 0.4582 0.029 15.991 0.000 0.402 0.514
v 4.9844 0.052 96.455 0.000 4.883 5.086
\n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "\n", - " \n", - "\n", - "
Omnibus: 8.676 Durbin-Watson: 2.023
Prob(Omnibus): 0.013 Jarque-Bera (JB): 6.070
Skew: -0.032 Prob(JB): 0.0481
Kurtosis: 2.624 Cond. No. 2.39


Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified." - ], - "text/plain": [ - "\n", - "\"\"\"\n", - " OLS Regression Results \n", - "==============================================================================\n", - "Dep. Variable: y R-squared: 0.941\n", - "Model: OLS Adj. R-squared: 0.941\n", - "Method: Least Squares F-statistic: 7979.\n", - "Date: Sun, 10 Feb 2019 Prob (F-statistic): 0.00\n", - "Time: 14:51:27 Log-Likelihood: -1431.1\n", - "No. Observations: 1000 AIC: 2866.\n", - "Df Residuals: 998 BIC: 2876.\n", - "Df Model: 2 \n", - "Covariance Type: nonrobust \n", - "==============================================================================\n", - " coef std err t P>|t| [0.025 0.975]\n", - "------------------------------------------------------------------------------\n", - "X0 0.4582 0.029 15.991 0.000 0.402 0.514\n", - "v 4.9844 0.052 96.455 0.000 4.883 5.086\n", - "==============================================================================\n", - "Omnibus: 8.676 Durbin-Watson: 2.023\n", - "Prob(Omnibus): 0.013 Jarque-Bera (JB): 6.070\n", - "Skew: -0.032 Prob(JB): 0.0481\n", - "Kurtosis: 2.624 Cond. No. 2.39\n", - "==============================================================================\n", - "\n", - "Warnings:\n", - "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n", - "\"\"\"" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = OLS(df['y'], df[['X0', 'v']])\n", - "result = model.fit()\n", - "result.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, "outputs": [ { "data": { @@ -1236,183 +730,183 @@ " \n", " \n", " 0\n", - " 1.521782\n", + " -0.679571\n", " 1\n", - " 5.219948\n", + " 3.545042\n", " \n", " \n", " 1\n", - " 1.393287\n", + " -0.243537\n", " 1\n", - " 6.244632\n", + " 4.665001\n", " \n", " \n", " 2\n", - " 1.744401\n", + " -0.195101\n", " 1\n", - " 7.404562\n", + " 4.499220\n", " \n", " \n", " 3\n", - " -0.984910\n", + " 0.923288\n", " 1\n", - " 5.158497\n", + " 6.673522\n", " \n", " \n", " 4\n", - " 1.531897\n", + " 0.389773\n", " 1\n", - " 4.609039\n", + " 5.976617\n", " \n", " \n", " 5\n", - " -0.414219\n", + " 0.345340\n", " 1\n", - " 4.628114\n", + " 5.553557\n", " \n", " \n", " 6\n", - " 3.440954\n", + " 0.539989\n", " 1\n", - " 5.222432\n", + " 6.375131\n", " \n", " \n", " 7\n", - " 0.419157\n", + " 1.294383\n", " 1\n", - " 7.573535\n", + " 7.499560\n", " \n", " \n", " 8\n", - " 0.336473\n", + " -0.557656\n", " 1\n", - " 6.183994\n", + " 3.655596\n", " \n", " \n", " 9\n", - " -1.038132\n", + " -0.581319\n", " 1\n", - " 3.692398\n", + " 4.275826\n", " \n", " \n", " 10\n", - " 3.020776\n", + " 0.089578\n", " 1\n", - " 5.733954\n", + " 5.298058\n", " \n", " \n", " 11\n", - " 0.484034\n", + " 0.826961\n", " 1\n", - " 8.189707\n", + " 6.531107\n", " \n", " \n", " 12\n", - " 1.494246\n", + " -0.671221\n", " 1\n", - " 5.556789\n", + " 3.541304\n", " \n", " \n", " 13\n", - " 1.211798\n", + " 0.986791\n", " 1\n", - " 7.776344\n", + " 7.004788\n", " \n", " \n", " 14\n", - " 3.052140\n", + " 1.594109\n", " 1\n", - " 6.377158\n", + " 7.981339\n", " \n", " \n", " 15\n", - " -0.294350\n", + " -0.245527\n", " 1\n", - " 7.915583\n", + " 4.159227\n", " \n", " \n", " 16\n", - " -0.627547\n", + " 0.178187\n", " 1\n", - " 3.188627\n", + " 5.359598\n", " \n", " \n", " 17\n", - " 2.113168\n", + " 2.024924\n", " 1\n", - " 5.123711\n", + " 9.260993\n", " \n", " \n", " 18\n", - " 1.918390\n", + " -1.760241\n", " 1\n", - " 6.868787\n", + " 1.503151\n", " \n", " \n", " 19\n", - " 1.409377\n", + " 1.911770\n", " 1\n", - " 5.932311\n", + " 8.763278\n", " \n", " \n", " 20\n", - " 1.751635\n", + " 0.411533\n", " 1\n", - " 5.963934\n", + " 6.014290\n", " \n", " \n", " 21\n", - " 1.508836\n", + " -1.601890\n", " 1\n", - " 3.310934\n", + " 1.895632\n", " \n", " \n", " 22\n", - " 1.940414\n", + " -0.779602\n", " 1\n", - " 3.648624\n", + " 3.509766\n", " \n", " \n", " 23\n", - " -0.708565\n", + " 2.117118\n", " 1\n", - " 5.069651\n", + " 9.529310\n", " \n", " \n", " 24\n", - " 0.953574\n", + " 0.515388\n", " 1\n", - " 5.687417\n", + " 5.812637\n", " \n", " \n", " 25\n", - " 0.737527\n", + " 0.195795\n", " 1\n", - " 6.356627\n", + " 5.080119\n", " \n", " \n", " 26\n", - " 0.955564\n", + " 0.112839\n", " 1\n", - " 6.093250\n", + " 5.088389\n", " \n", " \n", " 27\n", - " 1.030763\n", + " 0.508712\n", " 1\n", - " 4.887715\n", + " 5.831991\n", " \n", " \n", " 28\n", - " 2.442606\n", + " 1.449329\n", " 1\n", - " 6.403713\n", + " 8.066816\n", " \n", " \n", " 29\n", - " -0.150845\n", + " 0.948285\n", " 1\n", - " 5.130774\n", + " 6.386133\n", " \n", " \n", " ...\n", @@ -1422,183 +916,183 @@ " \n", " \n", " 970\n", - " 0.481196\n", + " 0.733012\n", " 1\n", - " 6.080241\n", + " 6.507144\n", " \n", " \n", " 971\n", - " 2.352860\n", + " -0.970542\n", " 1\n", - " 6.123784\n", + " 3.359742\n", " \n", " \n", " 972\n", - " 1.984094\n", + " 1.451770\n", " 1\n", - " 6.242540\n", + " 7.869123\n", " \n", " \n", " 973\n", - " 0.565025\n", + " 0.086223\n", " 1\n", - " 5.809739\n", + " 4.587526\n", " \n", " \n", " 974\n", - " 1.148690\n", + " 0.062156\n", " 1\n", - " 6.491556\n", + " 5.341651\n", " \n", " \n", " 975\n", - " 0.788356\n", + " 1.178053\n", " 1\n", - " 5.775652\n", + " 7.150807\n", " \n", " \n", " 976\n", - " 0.755202\n", + " 0.038045\n", " 1\n", - " 5.803881\n", + " 4.709112\n", " \n", " \n", " 977\n", - " 0.727366\n", + " -0.420454\n", " 1\n", - " 3.142826\n", + " 3.717059\n", " \n", " \n", " 978\n", - " 0.705377\n", + " 0.498168\n", " 1\n", - " 5.054480\n", + " 5.969144\n", " \n", " \n", " 979\n", - " 1.259369\n", + " 0.179320\n", " 1\n", - " 5.687205\n", + " 5.106763\n", " \n", " \n", " 980\n", - " 0.216203\n", + " -1.081955\n", " 1\n", - " 5.359506\n", + " 3.072383\n", " \n", " \n", " 981\n", - " 0.717459\n", + " 0.792726\n", " 1\n", - " 6.772600\n", + " 6.142553\n", " \n", " \n", " 982\n", - " 1.122380\n", + " 0.817159\n", " 1\n", - " 6.254867\n", + " 6.614257\n", " \n", " \n", " 983\n", - " 2.483977\n", + " 0.277124\n", " 1\n", - " 6.512420\n", + " 5.478525\n", " \n", " \n", " 984\n", - " 1.550738\n", + " 1.139441\n", " 1\n", - " 4.693647\n", + " 7.124206\n", " \n", " \n", " 985\n", - " -0.610096\n", + " 2.636305\n", " 1\n", - " 4.108820\n", + " 9.761611\n", " \n", " \n", " 986\n", - " 0.469281\n", + " 0.849081\n", " 1\n", - " 4.089297\n", + " 6.316136\n", " \n", " \n", " 987\n", - " 1.385763\n", + " 2.231701\n", " 1\n", - " 6.056986\n", + " 9.592429\n", " \n", " \n", " 988\n", - " 2.566888\n", + " 1.065030\n", " 1\n", - " 5.531015\n", + " 6.767247\n", " \n", " \n", " 989\n", - " 0.109643\n", + " -0.134438\n", " 1\n", - " 5.653737\n", + " 4.681228\n", " \n", " \n", " 990\n", - " -1.001541\n", + " -2.540774\n", " 1\n", - " 4.191305\n", + " 0.112361\n", " \n", " \n", " 991\n", - " 1.827036\n", + " -0.368138\n", " 1\n", - " 6.285340\n", + " 4.395120\n", " \n", " \n", " 992\n", - " 0.693931\n", + " 1.050256\n", " 1\n", - " 4.383980\n", + " 7.209847\n", " \n", " \n", " 993\n", - " 1.658332\n", + " 0.669631\n", " 1\n", - " 5.419160\n", + " 6.251246\n", " \n", " \n", " 994\n", - " 1.531808\n", + " -0.508734\n", " 1\n", - " 6.741684\n", + " 4.251908\n", " \n", " \n", " 995\n", - " 2.011649\n", + " -0.103255\n", " 1\n", - " 7.785070\n", + " 5.060815\n", " \n", " \n", " 996\n", - " -0.407929\n", + " -0.906700\n", " 1\n", - " 4.738703\n", + " 3.055053\n", " \n", " \n", " 997\n", - " -0.076402\n", + " 0.156403\n", " 1\n", - " 3.615250\n", + " 5.534543\n", " \n", " \n", " 998\n", - " 1.302853\n", + " -0.276539\n", " 1\n", - " 9.341458\n", + " 4.520962\n", " \n", " \n", " 999\n", - " 1.987589\n", + " 0.608260\n", " 1\n", - " 5.110972\n", + " 6.146062\n", " \n", " \n", "\n", @@ -1607,72 +1101,72 @@ ], "text/plain": [ " X0 v y\n", - "0 1.521782 1 5.219948\n", - "1 1.393287 1 6.244632\n", - "2 1.744401 1 7.404562\n", - "3 -0.984910 1 5.158497\n", - "4 1.531897 1 4.609039\n", - "5 -0.414219 1 4.628114\n", - "6 3.440954 1 5.222432\n", - "7 0.419157 1 7.573535\n", - "8 0.336473 1 6.183994\n", - "9 -1.038132 1 3.692398\n", - "10 3.020776 1 5.733954\n", - "11 0.484034 1 8.189707\n", - "12 1.494246 1 5.556789\n", - "13 1.211798 1 7.776344\n", - "14 3.052140 1 6.377158\n", - "15 -0.294350 1 7.915583\n", - "16 -0.627547 1 3.188627\n", - "17 2.113168 1 5.123711\n", - "18 1.918390 1 6.868787\n", - "19 1.409377 1 5.932311\n", - "20 1.751635 1 5.963934\n", - "21 1.508836 1 3.310934\n", - "22 1.940414 1 3.648624\n", - "23 -0.708565 1 5.069651\n", - "24 0.953574 1 5.687417\n", - "25 0.737527 1 6.356627\n", - "26 0.955564 1 6.093250\n", - "27 1.030763 1 4.887715\n", - "28 2.442606 1 6.403713\n", - "29 -0.150845 1 5.130774\n", + "0 -0.679571 1 3.545042\n", + "1 -0.243537 1 4.665001\n", + "2 -0.195101 1 4.499220\n", + "3 0.923288 1 6.673522\n", + "4 0.389773 1 5.976617\n", + "5 0.345340 1 5.553557\n", + "6 0.539989 1 6.375131\n", + "7 1.294383 1 7.499560\n", + "8 -0.557656 1 3.655596\n", + "9 -0.581319 1 4.275826\n", + "10 0.089578 1 5.298058\n", + "11 0.826961 1 6.531107\n", + "12 -0.671221 1 3.541304\n", + "13 0.986791 1 7.004788\n", + "14 1.594109 1 7.981339\n", + "15 -0.245527 1 4.159227\n", + "16 0.178187 1 5.359598\n", + "17 2.024924 1 9.260993\n", + "18 -1.760241 1 1.503151\n", + "19 1.911770 1 8.763278\n", + "20 0.411533 1 6.014290\n", + "21 -1.601890 1 1.895632\n", + "22 -0.779602 1 3.509766\n", + "23 2.117118 1 9.529310\n", + "24 0.515388 1 5.812637\n", + "25 0.195795 1 5.080119\n", + "26 0.112839 1 5.088389\n", + "27 0.508712 1 5.831991\n", + "28 1.449329 1 8.066816\n", + "29 0.948285 1 6.386133\n", ".. ... .. ...\n", - "970 0.481196 1 6.080241\n", - "971 2.352860 1 6.123784\n", - "972 1.984094 1 6.242540\n", - "973 0.565025 1 5.809739\n", - "974 1.148690 1 6.491556\n", - "975 0.788356 1 5.775652\n", - "976 0.755202 1 5.803881\n", - "977 0.727366 1 3.142826\n", - "978 0.705377 1 5.054480\n", - "979 1.259369 1 5.687205\n", - "980 0.216203 1 5.359506\n", - "981 0.717459 1 6.772600\n", - "982 1.122380 1 6.254867\n", - "983 2.483977 1 6.512420\n", - "984 1.550738 1 4.693647\n", - "985 -0.610096 1 4.108820\n", - "986 0.469281 1 4.089297\n", - "987 1.385763 1 6.056986\n", - "988 2.566888 1 5.531015\n", - "989 0.109643 1 5.653737\n", - "990 -1.001541 1 4.191305\n", - "991 1.827036 1 6.285340\n", - "992 0.693931 1 4.383980\n", - "993 1.658332 1 5.419160\n", - "994 1.531808 1 6.741684\n", - "995 2.011649 1 7.785070\n", - "996 -0.407929 1 4.738703\n", - "997 -0.076402 1 3.615250\n", - "998 1.302853 1 9.341458\n", - "999 1.987589 1 5.110972\n", + "970 0.733012 1 6.507144\n", + "971 -0.970542 1 3.359742\n", + "972 1.451770 1 7.869123\n", + "973 0.086223 1 4.587526\n", + "974 0.062156 1 5.341651\n", + "975 1.178053 1 7.150807\n", + "976 0.038045 1 4.709112\n", + "977 -0.420454 1 3.717059\n", + "978 0.498168 1 5.969144\n", + "979 0.179320 1 5.106763\n", + "980 -1.081955 1 3.072383\n", + "981 0.792726 1 6.142553\n", + "982 0.817159 1 6.614257\n", + "983 0.277124 1 5.478525\n", + "984 1.139441 1 7.124206\n", + "985 2.636305 1 9.761611\n", + "986 0.849081 1 6.316136\n", + "987 2.231701 1 9.592429\n", + "988 1.065030 1 6.767247\n", + "989 -0.134438 1 4.681228\n", + "990 -2.540774 1 0.112361\n", + "991 -0.368138 1 4.395120\n", + "992 1.050256 1 7.209847\n", + "993 0.669631 1 6.251246\n", + "994 -0.508734 1 4.251908\n", + "995 -0.103255 1 5.060815\n", + "996 -0.906700 1 3.055053\n", + "997 0.156403 1 5.534543\n", + "998 -0.276539 1 4.520962\n", + "999 0.608260 1 6.146062\n", "\n", "[1000 rows x 3 columns]" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -1687,477 +1181,145 @@ "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
X0vy
01.5217820.0-0.096805
11.3932871.07.398900
21.7444011.06.409548
3-0.9849100.00.488064
41.5318970.0-0.790932
5-0.4142190.00.813474
63.4409541.07.405016
70.4191571.04.270433
80.3364731.04.983390
9-1.0381321.02.763467
103.0207761.07.402877
110.4840340.0-0.492809
121.4942461.05.931295
131.2117981.06.665465
143.0521401.08.518381
15-0.2943501.05.806599
16-0.6275471.05.498996
172.1131681.07.732707
181.9183901.05.952340
191.4093770.0-1.237780
201.7516351.04.865461
211.5088360.0-0.750060
221.9404140.02.005464
23-0.7085651.05.059955
240.9535740.0-0.336221
250.7375271.06.438027
260.9555640.00.405158
271.0307631.06.071627
282.4426060.02.045977
29-0.1508450.01.174115
............
9700.4811960.0-0.115205
9712.3528600.0-0.642925
9721.9840941.05.497744
9730.5650250.0-1.255980
9741.1486900.0-1.271845
9750.7883560.0-1.048446
9760.7552021.04.592698
9770.7273660.01.295704
9780.7053770.0-0.074376
9791.2593690.0-0.443612
9800.2162031.04.783068
9810.7174591.04.488495
9821.1223801.05.332776
9832.4839771.05.532828
9841.5507381.06.385995
985-0.6100960.0-1.054331
9860.4692811.06.648060
9871.3857631.06.473496
9882.5668881.07.911278
9890.1096430.0-1.490529
990-1.0015411.05.909003
9911.8270361.04.255412
9920.6939311.04.591450
9931.6583321.04.909992
9941.5318081.05.625160
9952.0116490.02.146866
996-0.4079290.00.144058
997-0.0764020.00.212049
9981.3028531.06.749981
9991.9875890.0-0.643916
\n", - "

1000 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " X0 v y\n", - "0 1.521782 0.0 -0.096805\n", - "1 1.393287 1.0 7.398900\n", - "2 1.744401 1.0 6.409548\n", - "3 -0.984910 0.0 0.488064\n", - "4 1.531897 0.0 -0.790932\n", - "5 -0.414219 0.0 0.813474\n", - "6 3.440954 1.0 7.405016\n", - "7 0.419157 1.0 4.270433\n", - "8 0.336473 1.0 4.983390\n", - "9 -1.038132 1.0 2.763467\n", - "10 3.020776 1.0 7.402877\n", - "11 0.484034 0.0 -0.492809\n", - "12 1.494246 1.0 5.931295\n", - "13 1.211798 1.0 6.665465\n", - "14 3.052140 1.0 8.518381\n", - "15 -0.294350 1.0 5.806599\n", - "16 -0.627547 1.0 5.498996\n", - "17 2.113168 1.0 7.732707\n", - "18 1.918390 1.0 5.952340\n", - "19 1.409377 0.0 -1.237780\n", - "20 1.751635 1.0 4.865461\n", - "21 1.508836 0.0 -0.750060\n", - "22 1.940414 0.0 2.005464\n", - "23 -0.708565 1.0 5.059955\n", - "24 0.953574 0.0 -0.336221\n", - "25 0.737527 1.0 6.438027\n", - "26 0.955564 0.0 0.405158\n", - "27 1.030763 1.0 6.071627\n", - "28 2.442606 0.0 2.045977\n", - "29 -0.150845 0.0 1.174115\n", - ".. ... ... ...\n", - "970 0.481196 0.0 -0.115205\n", - "971 2.352860 0.0 -0.642925\n", - "972 1.984094 1.0 5.497744\n", - "973 0.565025 0.0 -1.255980\n", - "974 1.148690 0.0 -1.271845\n", - "975 0.788356 0.0 -1.048446\n", - "976 0.755202 1.0 4.592698\n", - "977 0.727366 0.0 1.295704\n", - "978 0.705377 0.0 -0.074376\n", - "979 1.259369 0.0 -0.443612\n", - "980 0.216203 1.0 4.783068\n", - "981 0.717459 1.0 4.488495\n", - "982 1.122380 1.0 5.332776\n", - "983 2.483977 1.0 5.532828\n", - "984 1.550738 1.0 6.385995\n", - "985 -0.610096 0.0 -1.054331\n", - "986 0.469281 1.0 6.648060\n", - "987 1.385763 1.0 6.473496\n", - "988 2.566888 1.0 7.911278\n", - "989 0.109643 0.0 -1.490529\n", - "990 -1.001541 1.0 5.909003\n", - "991 1.827036 1.0 4.255412\n", - "992 0.693931 1.0 4.591450\n", - "993 1.658332 1.0 4.909992\n", - "994 1.531808 1.0 5.625160\n", - "995 2.011649 0.0 2.146866\n", - "996 -0.407929 0.0 0.144058\n", - "997 -0.076402 0.0 0.212049\n", - "998 1.302853 1.0 6.749981\n", - "999 1.987589 0.0 -0.643916\n", - "\n", - "[1000 rows x 3 columns]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:dowhy.do_why:Causal Graph not provided. DoWhy will construct a graph based on data inputs.\n", + "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'U', 'X0'}\n", + "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n", + "INFO:dowhy.do_sampler:Using McmcSampler for do sampling.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['X0']\n", + "yes\n", + "{'observed': 'yes'}\n", + "Model to find the causal effect of treatment v on outcome y\n", + "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", + "All common causes are observed. Causal effect can be identified.\n", + "McmcSampler\n", + "treatments ['v']\n", + "backdoor ['X0']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pymc3:Auto-assigning NUTS sampler...\n", + "INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n", + "INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n", + "INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n", + "Sampling 4 chains: 100%|██████████| 8000/8000 [00:08<00:00, 919.62draws/s] \n", + "INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'U', 'X0'}\n", + "INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'label': 'Unobserved Confounders', 'observed': 'no'}\n", + "All common causes are observed. Causal effect can be identified.\n", + "McmcSampler\n" + ] } ], "source": [ - "cdf" + "cdf_1 = cdf.causal.do(x={'v': 1}, \n", + " variable_types={'v': 'b', 'y': 'c', 'X0': 'c'}, \n", + " outcome='y',\n", + " method='mcmc', \n", + " common_causes=['X0'],\n", + " proceed_when_unidentifiable=True,\n", + " use_previous_sampler=False)\n", + "cdf_0 = cdf.causal.do(x={'v': 0}, \n", + " variable_types={'v': 'b', 'y': 'c', 'X0': 'c'}, \n", + " outcome='y',\n", + " method='mcmc', \n", + " common_causes=['X0'],\n", + " proceed_when_unidentifiable=True,\n", + " use_previous_sampler=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(cdf_1['y'] - cdf_0['y']).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "1.96*(cdf_1['y'] - cdf_0['y']).std() / np.sqrt(len(cdf))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "model = OLS(df['y'], df[['X0', 'v']])\n", + "result = model.fit()\n", + "result.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "cdf_1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "cdf_0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cdf_do = cdf.causal.do(x={'v': 0}, \n", + " variable_types={'v': 'b', 'y': 'c', 'X0': 'c'}, \n", + " outcome='y',\n", + " method='mcmc', \n", + " common_causes=['X0'],\n", + " proceed_when_unidentifiable=True,\n", + " keep_original_treatment=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cdf_do" ] }, {