This notebook demonstrates how to quantify uncertainty by having a neural network output the parameters of a probability distribution rather than a single point estimate.

We consider the [OLS Regression Challenge](https://data.world/nrippner/ols-regression-challenge), which aims to predict cancer death rates from a range of socioeconomic factors.

Example: https://romainstrock.com/blog/modeling-uncertainty-with-pytorch.html
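
As a rough illustration of the idea, the sketch below has a small network output the mean and standard deviation of a Normal distribution and scores it with the negative log-likelihood. The class name `GaussianRegressor`, the layer sizes, and the toy tensors are assumptions made for this example; they are not taken from the challenge or from the linked post.

```python
import torch
from torch import nn


class GaussianRegressor(nn.Module):
    """Hypothetical model: maps features to the mean and std of a Normal."""

    def __init__(self, n_features: int, n_hidden: int = 16):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_features, n_hidden), nn.ReLU())
        self.mean_head = nn.Linear(n_hidden, 1)
        self.std_head = nn.Linear(n_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.distributions.Normal:
        h = self.body(x)
        mean = self.mean_head(h).squeeze(-1)
        # softplus keeps the standard deviation strictly positive
        std = nn.functional.softplus(self.std_head(h)).squeeze(-1) + 1e-6
        return torch.distributions.Normal(mean, std)


torch.manual_seed(0)
x = torch.randn(8, 10)  # toy batch: 8 samples, 10 features (assumed sizes)
y = torch.randn(8)      # toy targets

model = GaussianRegressor(n_features=10)
dist = model(x)

# Training minimizes the negative log-likelihood of the observed targets
loss = -dist.log_prob(y).mean()
loss.backward()
```

Each prediction is then a full distribution: `dist.mean` gives the point estimate and `dist.stddev` gives a per-sample measure of uncertainty.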

**Data Dictionary**
- TARGET_deathRate: Dependent variable. Mean per capita (100,000) cancer mortalities (a)
- avgAnnCount: Mean number of reported cases of cancer diagnosed annually (a)
- avgDeathsPerYear: Mean number of reported mortalities due to cancer (a)
- incidenceRate: Mean per capita (100,000) cancer diagnoses (a)
- medIncome: Median income per county (b)
- popEst2015: Population of county (b)
- povertyPercent: Percent of populace in poverty (b)
- studyPerCap: Per capita number of cancer-related clinical trials per county (a)
- binnedInc: Median income per capita binned by decile (b)
- MedianAge: Median age of county residents (b)
- MedianAgeMale: Median age of male county residents (b)
- MedianAgeFemale: Median age of female county residents (b)
- Geography: County name (b)
- AvgHouseholdSize: Mean household size of county (b)
- PercentMarried: Percent of county residents who are married (b)
- PctNoHS18_24: Percent of county residents ages 18-24 highest education attained: less than high school (b)
- PctHS18_24: Percent of county residents ages 18-24 highest education attained: high school diploma (b)
- PctSomeCol18_24: Percent of county residents ages 18-24 highest education attained: some college (b)
- PctBachDeg18_24: Percent of county residents ages 18-24 highest education attained: bachelor's degree (b)
- PctHS25_Over: Percent of county residents ages 25 and over highest education attained: high school diploma (b)
- PctBachDeg25_Over: Percent of county residents ages 25 and over highest education attained: bachelor's degree (b)
- PctEmployed16_Over: Percent of county residents ages 16 and over employed (b)
- PctUnemployed16_Over: Percent of county residents ages 16 and over unemployed (b)
- PctPrivateCoverage: Percent of county residents with private health coverage (b)
- PctPrivateCoverageAlone: Percent of county residents with private health coverage alone (no public assistance) (b)
- PctEmpPrivCoverage: Percent of county residents with employee-provided private health coverage (b)
- PctPublicCoverage: Percent of county residents with government-provided health coverage (b)
- PctPublicCoverageAlone: Percent of county residents with government-provided health coverage alone (b)
- PctWhite: Percent of county residents who identify as White (b)
- PctBlack: Percent of county residents who identify as Black (b)
- PctAsian: Percent of county residents who identify as Asian (b)
- PctOtherRace: Percent of county residents who identify in a category which is not White, Black, or Asian (b)
- PctMarriedHouseholds: Percent of married households (b)
- BirthRate: Number of live births relative to number of women in county (b)

Notes:

- (a): years 2010-2016
- (b): 2013 Census Estimates

## Imports

In [16]:
import os
import pandas as pd

import torch
from torch.utils.data import Dataset
from torch.nn import Embedding

import matplotlib.pyplot as plt

## Dataset

In [20]:
df = pd.read_csv(os.path.join("data", "cancer_reg.csv"))
df.head()

| | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
| 1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.3725 | 4.333096 |
| 2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.92219 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
| 3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
| 4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.0 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.66583 | 0.492135 | 54.02746 | 6.796657 |


In [11]:
df.describe()

| | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | ... | 2438.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 | 3047.0 |
| mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 102637.4 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
| std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 329059.2 | 6.409087 | 529.628366 | 45.30448 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.51771 | 6.572814 | 1.985816 |
| min | 6.0 | 3.0 | 59.7 | 201.3 | 22640.0 | 827.0 | 3.2 | 0.0 | 22.3 | 22.4 | ... | 15.7 | 13.5 | 11.2 | 2.6 | 10.199155 | 0.0 | 0.0 | 0.0 | 22.99249 | 0.0 |
| 25% | 76.0 | 28.0 | 161.2 | 420.3 | 38882.5 | 11684.0 | 12.15 | 0.0 | 37.7 | 36.35 | ... | 41.0 | 34.5 | 30.9 | 14.85 | 77.29618 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
| 50% | 171.0 | 61.0 | 178.1 | 453.549422 | 45207.0 | 26643.0 | 15.9 | 0.0 | 41.0 | 39.6 | ... | 48.7 | 41.1 | 36.3 | 18.8 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
| 75% | 518.0 | 149.0 | 195.2 | 480.85 | 52492.0 | 68671.0 | 20.4 | 83.650776 | 44.0 | 42.5 | ... | 55.6 | 47.7 | 41.55 | 23.1 | 95.451693 | 10.509732 | 1.221037 | 2.17796 | 55.395132 | 6.493677 |
| max | 38150.0 | 14010.0 | 362.8 | 1206.9 | 125635.0 | 10170290.0 | 47.4 | 9762.308998 | 624.0 | 64.7 | ... | 78.9 | 70.7 | 65.1 | 46.6 | 100.0 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |


### Train/Test

In [21]:
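class TabularDataset(Dataset):
    """Wrap a feature matrix and a target vector as a PyTorch dataset."""

    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        self.data = x    # features, shape (n_samples, n_features)
        self.labels = y  # targets, shape (n_samples,)

    def __len__(self):
        # Number of samples
        return len(self.labels)

    def __getitem__(self, idx):
        # Return the feature row and its target for sample `idx`
        return self.data[idx, :], self.labels[idx]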

In [26]:
# Smoke test with random data: 20 samples, 10 features
x = torch.randn(20, 10)
y = torch.randn(20)

dataset = TabularDataset(x, y)
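
A dataset like this is usually consumed through a `DataLoader`, which handles batching and shuffling. The sketch below repeats the class so it runs on its own; the batch size of 8 is an arbitrary choice for illustration and is not a setting taken from this notebook.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TabularDataset(Dataset):
    # Same class as defined above, repeated so this snippet is self-contained
    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        self.data = x
        self.labels = y

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx, :], self.labels[idx]


x = torch.randn(20, 10)  # random features: 20 samples, 10 columns
y = torch.randn(20)      # random targets
dataset = TabularDataset(x, y)

# batch_size=8 is an assumed value; shuffle=True reorders samples each epoch
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Collect the shape of every (features, targets) batch in one pass
batch_shapes = [(xb.shape, yb.shape) for xb, yb in loader]
```

With 20 samples and a batch size of 8, the loader yields two full batches and a final batch of 4, since `drop_last` defaults to `False`.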