<a href="https://colab.research.google.com/github/pharringtonp19/mecon/blob/main/notebooks/practice_final/Teaching_Supply.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
#@title **Imports** { display-mode: "form" }
import jax 
import jax.numpy as jnp 
import matplotlib.pyplot as plt 
from functools import partial

### **Dot Products**
In this notebook, we'll take the dot product between two vectors. We've actually seen this before when we talked about probability, but I want to make sure that everyone has seen how to write it on the computer. To begin, let's say we have two vectors, $x,y \in \mathcal{R}^n$. Then their dot product, $\langle x, y \rangle$, is defined as follows.

\begin{align*}
\langle x, y \rangle = \sum _{i=1}^n x_i y_i
\end{align*}

On the computer, we can express this as `jnp.dot(x, y)`.
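As a small illustration (the vectors below are arbitrary), `jnp.dot` gives the same number as writing out the sum $\sum_i x_i y_i$ with elementwise multiplication:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

manual = jnp.sum(x * y)   # multiply entry by entry, then add: 1*4 + 2*5 + 3*6
builtin = jnp.dot(x, y)   # the built-in dot product

print(manual, builtin)    # 32.0 32.0
```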

### **Matrices**
A matrix can be thought of as stacked vectors. Let's say we have multiple vectors, $x, y, z \in \mathcal{R}^n$. Then we can stack them as the rows of a $3 \times n$ matrix with `jnp.vstack((x, y, z))`.
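A sketch with three made-up vectors in $\mathcal{R}^4$: `jnp.vstack` places each vector in its own row, and multiplying the resulting matrix by a vector `w` (an arbitrary all-ones example) takes the dot product of `w` with every row at once. This row-by-row dot product is the operation the job model below relies on.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 0.0, 2.0, 1.0])
y = jnp.array([0.0, 1.0, 1.0, 0.0])
z = jnp.array([3.0, 1.0, 0.0, 2.0])

M = jnp.vstack((x, y, z))   # 3 x 4 matrix: row 0 is x, row 1 is y, row 2 is z
print(M.shape)              # (3, 4)

w = jnp.ones(4)
row_dots = M @ w            # [<x, w>, <y, w>, <z, w>]
print(row_dots)             # [4. 2. 6.]
```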

### **Argmax**

As we've mentioned throughout the course, a vector can be thought of as a function defined over a subset of the natural numbers (its indices). Therefore, taking the argmax of a vector is like taking the argmax of a function: `jnp.argmax` returns the index whose corresponding value is largest, not the largest value itself.
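A minimal example with an arbitrary four-entry vector, reading `v` as a function on the indices $\{0, 1, 2, 3\}$ and separating the maximizer from the maximum:

```python
import jax.numpy as jnp

# Read v as a function on the indices {0, 1, 2, 3}: v(0) = 0.3, v(1) = 1.7, ...
v = jnp.array([0.3, 1.7, 0.9, 1.2])

best_index = jnp.argmax(v)   # argmax_i v(i): the index, not the value
best_value = v[best_index]   # max_i v(i): the value at that index

print(int(best_index))       # 1
```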

### **Model**
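We're going to model each individual as selecting the job that maximizes their utility:

\begin{align*}
\underset{x \in \textrm{Jobs}}{\textrm{maximize}} \ U_{\alpha}(x)
\end{align*}

where $\alpha$ is the individual's vector of preference weights. In the code below, utility is linear in the job's attributes, $U_{\alpha}(x) = \langle x, \alpha \rangle$, and each entry of $\alpha$ is drawn independently and uniformly from $[0, 1)$, so different individuals can choose different jobs.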
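A worked sketch of one choice, using a hypothetical menu of three jobs described by two made-up attributes. Stacking the jobs as rows lets `jobs @ alpha` compute $U_{\alpha}(x) = \langle x, \alpha \rangle$ for every job at once, and `jnp.argmax` picks the utility-maximizing job:

```python
import jax.numpy as jnp

# Hypothetical jobs, one row per job; columns are two made-up attributes
jobs = jnp.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [0.6, 0.6]])

alpha_a = jnp.array([0.9, 0.1])  # cares mostly about attribute 0
alpha_b = jnp.array([0.5, 0.5])  # weighs both attributes equally

# jobs @ alpha gives <x, alpha> for every row x
choice_a = jnp.argmax(jobs @ alpha_a)  # utilities [0.9, 0.1, 0.6] -> job 0
choice_b = jnp.argmax(jobs @ alpha_b)  # utilities [0.5, 0.5, 0.6] -> job 2
```

The same menu produces different choices because the two individuals weight the attributes differently.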

### **Jobs**

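We're going to represent a job as a vector in $\mathcal{R}^5$ whose entries score the job on five attributes: Wage, Flexibility, Social, Kids, and Security (in that order). Stacking the five job vectors gives a $5 \times 5$ matrix with one row per job. The cell below builds two such matrices, `jobs` and `jobs_post`, which differ only in the flexibility of the construction, government official, and data scientist jobs.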

In [112]:
# Columns: Wage, Flexibility, Social, Kids, Security
# Baseline scenario: one vector per job
teacher = jnp.array([0.1, 0.8, 0.5, 1.0, 0.5])
construction = jnp.array([0.5, 0.2, 1.0, 0.1, 0.5])
gov_official = jnp.array([0.7, 0.2, 0.4, 0.2, 1.0])
data_scientist = jnp.array([0.9, 0.2, 0.4, 0.2, 0.3])
designer = jnp.array([0.5, 0.3, 0.6, 0.3, 0.2])
jobs = jnp.vstack((teacher, construction, gov_official, data_scientist, designer))  # shape (5, 5); row 0 is teacher

# "Post" scenario: only Flexibility changes, for construction (0.2 -> 0.4),
# gov_official (0.2 -> 0.9), and data_scientist (0.2 -> 0.9)
teacher = jnp.array([0.1, 0.8, 0.5, 1.0, 0.5])
construction = jnp.array([0.5, 0.4, 1.0, 0.1, 0.5])
gov_official = jnp.array([0.7, 0.9, 0.4, 0.2, 1.0])
data_scientist = jnp.array([0.9, 0.9, 0.4, 0.2, 0.3])
designer = jnp.array([0.5, 0.3, 0.6, 0.3, 0.2])
jobs_post = jnp.vstack((teacher, construction, gov_official, data_scientist, designer))
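A quick way to confirm which entries differ between `jobs` and `jobs_post` is to compare the two matrices elementwise and list the nonzero positions. The `names` and `attributes` lists below are labels added for readability; they are not part of the original notebook.

```python
import jax.numpy as jnp

attributes = ["Wage", "Flexibility", "Social", "Kids", "Security"]
names = ["teacher", "construction", "gov_official", "data_scientist", "designer"]

jobs = jnp.array([[0.1, 0.8, 0.5, 1.0, 0.5],
                  [0.5, 0.2, 1.0, 0.1, 0.5],
                  [0.7, 0.2, 0.4, 0.2, 1.0],
                  [0.9, 0.2, 0.4, 0.2, 0.3],
                  [0.5, 0.3, 0.6, 0.3, 0.2]])
jobs_post = jnp.array([[0.1, 0.8, 0.5, 1.0, 0.5],
                       [0.5, 0.4, 1.0, 0.1, 0.5],
                       [0.7, 0.9, 0.4, 0.2, 1.0],
                       [0.9, 0.9, 0.4, 0.2, 0.3],
                       [0.5, 0.3, 0.6, 0.3, 0.2]])

# Row and column indices of every entry that differs between the two scenarios
rows, cols = jnp.nonzero(jobs != jobs_post)
changes = [(names[int(r)], attributes[int(c)],
            round(float(jobs[r, c]), 2), round(float(jobs_post[r, c]), 2))
           for r, c in zip(rows, cols)]
for change in changes:
    print(change)
```

Every changed entry falls in the Flexibility column, and the teacher and designer rows are identical across the two scenarios.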

In [115]:
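def preferences_fn(key, n_attributes=5):
  # One individual's preference weights: an independent Uniform[0, 1) draw per attribute
  return jax.random.uniform(key, shape=(n_attributes,), minval=0., maxval=1.)

def job_value_fn(job, preferences):
  # Linear utility: U_alpha(job) = <job, alpha>
  return jnp.dot(job, preferences)

def optimal_job_fn(jobs, preferences):
  # Value every job (each row of `jobs`) under the same preferences, then pick the best index
  value_of_jobs = jax.vmap(partial(job_value_fn, preferences=preferences))(jobs)
  return jnp.argmax(value_of_jobs)

def f(jobs, key):
  # Draw one random individual and return the index of the job they choose
  preferences = preferences_fn(key, jobs.shape[1])
  return optimal_job_fn(jobs, preferences)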
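The `jax.vmap` call in `optimal_job_fn` applies `job_value_fn` to each row of `jobs`. Because utility is linear, that is the same computation as a single matrix-vector product. A sketch checking the two agree on randomly generated (hypothetical) jobs and preferences:

```python
import jax
import jax.numpy as jnp
from functools import partial

def job_value_fn(job, preferences):
  return jnp.dot(job, preferences)

# Hypothetical inputs: 5 random jobs with 5 attributes, and one preference vector
jobs = jax.random.uniform(jax.random.PRNGKey(1), shape=(5, 5))
preferences = jax.random.uniform(jax.random.PRNGKey(2), shape=(5,))

values_vmap = jax.vmap(partial(job_value_fn, preferences=preferences))(jobs)  # row by row
values_matmul = jobs @ preferences                                            # all rows at once

print(jnp.allclose(values_vmap, values_matmul))  # True
```

The `vmap` version mirrors the one-job-at-a-time definition of utility, while `jobs @ preferences` is the more compact form once utility is known to be linear.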

In [116]:
init_key = jax.random.PRNGKey(0)              # Initial key
key_vec = jax.random.split(init_key, 1000)    # One key per simulated individual
f_jobs = partial(f, jobs)                     # Fix the job matrix (baseline scenario)
f_jobs_vec = jax.vmap(f_jobs)                 # Vectorize over keys
selected_jobs = f_jobs_vec(key_vec)           # Chosen job index for each individual
print(jnp.mean(selected_jobs == 0))           # Share choosing teacher (row 0)

0.65500003


In [117]:
init_key = jax.random.PRNGKey(0)              # Same initial key, so the same individuals
key_vec = jax.random.split(init_key, 1000)    # One key per simulated individual
f_jobs = partial(f, jobs_post)                # Fix the job matrix ("post" scenario)
f_jobs_vec = jax.vmap(f_jobs)                 # Vectorize over keys
selected_jobs = f_jobs_vec(key_vec)           # Chosen job index for each individual
print(jnp.mean(selected_jobs == 0))           # Share choosing teacher (row 0)

0.319
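The two cells above report only the share choosing teaching (index 0), which falls from about 0.655 to 0.319. Because both cells reuse `PRNGKey(0)`, each simulated individual has the same preferences in both scenarios; since only non-teaching jobs improve and the weights are nonnegative, nobody switches into teaching. A minimal sketch for tabulating the share for every job at once follows. It assumes the same uniform preference draws, replaces the per-job `vmap` with a matrix product, and uses `jnp.bincount`; `choice_shares` is a hypothetical helper, so its numbers should closely track, but are not guaranteed to match digit for digit, the cells above.

```python
import jax
import jax.numpy as jnp

names = ["teacher", "construction", "gov_official", "data_scientist", "designer"]

# Rows: jobs in the order above; columns: Wage, Flexibility, Social, Kids, Security
jobs = jnp.array([[0.1, 0.8, 0.5, 1.0, 0.5],
                  [0.5, 0.2, 1.0, 0.1, 0.5],
                  [0.7, 0.2, 0.4, 0.2, 1.0],
                  [0.9, 0.2, 0.4, 0.2, 0.3],
                  [0.5, 0.3, 0.6, 0.3, 0.2]])
# Same matrix with the "post" Flexibility values for rows 1-3
jobs_post = jobs.at[1:4, 1].set(jnp.array([0.4, 0.9, 0.9]))

def choice_shares(jobs, key, n_people=1000):
    """Share of simulated individuals choosing each job (hypothetical helper)."""
    keys = jax.random.split(key, n_people)
    # One preference vector per person: independent Uniform[0, 1) weights
    prefs = jax.vmap(lambda k: jax.random.uniform(k, shape=(jobs.shape[1],)))(keys)
    # Utilities for every (person, job) pair, then each person's best job
    choices = jnp.argmax(prefs @ jobs.T, axis=1)
    return jnp.bincount(choices, length=jobs.shape[0]) / n_people

key = jax.random.PRNGKey(0)
shares_pre = choice_shares(jobs, key)
shares_post = choice_shares(jobs_post, key)

for name, s_pre, s_post in zip(names, shares_pre, shares_post):
    print(f"{name:15s} {float(s_pre):.3f} -> {float(s_post):.3f}")
```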
