# Automatic Relevance Determination
The Bayesian approach to feature selection is to use Automatic Relevance Determination (ARD).  Rather than being specified up-front, the parameters of the prior are optimized based on the observed data.  When the amount of data is small relative to the number of features, this technique effectively prunes features without explanatory power from the model.

The idea behind ARD/Empirical Bayes is to implement Occam's Razor by maximizing the *marginal likelihood* - AKA the *evidence*.  The evidence represents the probability that the dataset is observed, given a model.  A model in this case is determined by particular values of parameters for the prior.  Maximization therefore corresponds to finding the optimal values of these hyperparameters.

Variational inference corresponds to optimization of the *Evidence Lower BOund* (ELBO).  As the name suggests, the function that is optimized is a lower bound on the evidence.  The hope is that when this bound is maximized, the actual evidence is maximized as well.  We'll take the pragmatic approach and assume that this is indeed the case.

In this notebook, we explore what happens when we have two parameters: one relevant and one irrelevant.
Each parameter gets a hyperparameter in the form of a standard deviation for its prior.

In [1]:
// Add the scala-infer Maven repository to the notebook's resolvers so that the
// `$ivy` imports in the next cell can be resolved.
// NOTE(review): Bintray has been sunset; confirm this URL still resolves, or
// point at wherever the scala-infer artifacts are published now.
interp.repositories() ++= Seq(
    coursier.MavenRepository("https://dl.bintray.com/scala-infer/maven")
)

In [2]:
import $ivy.`scala-infer::scala-infer:0.3`
import $ivy.`org.jupyter-scala::kernel-api:0.4.1`

[32mimport [39m[36m$ivy.$                             
[39m
[32mimport [39m[36m$ivy.$                                    [39m

In [3]:
import scappla._
import scappla.Functions._
import scappla.distributions._
import scappla.guides._
import scappla.optimization._
import scappla.tensor.Tensor._
import scappla.tensor._

[32mimport [39m[36mscappla._
[39m
[32mimport [39m[36mscappla.Functions._
[39m
[32mimport [39m[36mscappla.distributions._
[39m
[32mimport [39m[36mscappla.guides._
[39m
[32mimport [39m[36mscappla.optimization._
[39m
[32mimport [39m[36mscappla.tensor.Tensor._
[39m
[32mimport [39m[36mscappla.tensor._[39m

In [4]:
import scala.util.Random

[32mimport [39m[36mscala.util.Random[39m

In [5]:
// One synthetic observation: the two input features and the response.
case class Record(a: Float, b: Float, y: Float)

// Tensor shape - let's make it typed!
case class Batch(size: Int) extends Dim[Batch]
val batch = Batch(1000)

// Synthetic data set with one relevant feature (a, weight 1.0) and one
// irrelevant feature (b, weight 0.0).  The response is the weighted sum of the
// features plus zero-mean gaussian noise with standard deviation `noise`.
// Each column is packed into its own tensor along the `batch` dimension.
val (a_vals, b_vals, y_vals) = {
    val a_weight = 1.0
    val b_weight = 0.0
    val noise = 0.5

    val data = for { _ <- 0 until batch.size } yield {
        val a = Random.nextGaussian()
        val b = Random.nextGaussian()
        // `noise` scales a fresh gaussian draw; adding it as a bare constant
        // would shift every y by 0.5 instead of making the response noisy.
        val y = a_weight * a + b_weight * b + noise * Random.nextGaussian()
        Record(a.toFloat, b.toFloat, y.toFloat)
    }

    (
        Value(ArrayTensor(batch.sizes, data.map { _.a }.toArray), batch),
        Value(ArrayTensor(batch.sizes, data.map { _.b }.toArray), batch),
        Value(ArrayTensor(batch.sizes, data.map { _.y }.toArray), batch)
    )
}


defined [32mclass[39m [36mRecord[39m
defined [32mclass[39m [36mBatch[39m
[36mbatch[39m: [32mBatch[39m = [33mBatch[39m([32m1000[39m)
[36ma_vals[39m: [32mValue[39m[[32mArrayTensor[39m, [32mBatch[39m] = scappla.Constant@12f3f093
[36mb_vals[39m: [32mValue[39m[[32mArrayTensor[39m, [32mBatch[39m] = scappla.Constant@31e80b51
[36my_vals[39m: [32mValue[39m[[32mArrayTensor[39m, [32mBatch[39m] = scappla.Constant@5f6ae152

In [6]:
// Hyperparameters: one per weight.  Each is passed through `exp` where the
// prior is built in the model, so the value stored here is on a log scale.
val a_prior_s = Param(0.0)
val b_prior_s = Param(0.0)

// Variational posterior (guide) for the weight of feature a: a Normal whose
// location is `a_post_mu` and whose scale is `exp(a_post_s)`.
val a_post_mu = Param(0.0)
val a_post_s = Param(0.0)
val a_guide = ReparamGuide(Normal(a_post_mu, exp(a_post_s)))

// Guide for the weight of feature b, built the same way.
val b_post_mu = Param(0.0)
val b_post_s = Param(0.0)
val b_guide = ReparamGuide(Normal(b_post_mu, exp(b_post_s)))

// Guide for the `noise` variable of the model, built the same way.
val noise_mu = Param(0.0)
val noise_s = Param(0.0)
val noise_guide = ReparamGuide(Normal(noise_mu, exp(noise_s)))

// Linear model without intercept: y ~ Normal(a_weight * a + b_weight * b, exp(noise)).
val model = infer {
    // Zero-centered priors on the weights; their scales are the hyperparameters
    // exp(a_prior_s) and exp(b_prior_s).  Each draw uses its own guide.
    val a_weight = sample(Normal(0.0, exp(a_prior_s)), a_guide)
    val b_weight = sample(Normal(0.0, exp(b_prior_s)), b_guide)
    // `noise` is only used as exp(noise) below, i.e. it lives on a log scale.
    val noise = sample(Normal(0.0, 1.0), noise_guide)
    
    // The scalar weights and the noise scale are broadcast over the batch
    // dimension before being combined with the per-record tensors.
    observe(Normal(
        broadcast(a_weight, batch) * a_vals
        + broadcast(b_weight, batch) * b_vals,
        broadcast[Batch, ArrayTensor](exp(noise), batch)
    ), y_vals)
}

[36ma_prior_s[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@7e13f371
[36mb_prior_s[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@314da4f9
[36ma_post_mu[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@33764907
[36ma_post_s[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@3a272583
[36ma_guide[39m: [32mReparamGuide[39m[[32mDouble[39m, [32mUnit[39m] = [33mReparamGuide[39m(
  [33mNormal[39m(scappla.Param@33764907, [33mApply1[39m(scappla.Param@3a272583, <function1>))
)
[36mb_post_mu[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@5cbb302d
[36mb_post_s[39m: [32mParam[39m[[32mDouble[39m, [32mUnit[39m] = scappla.Param@5d9b0110
[36mb_guide[39m: [32mReparamGuide[39m[[32mDouble[39m, [32mUnit[39m] = [33mReparamGuide[39m(
  [33mNormal[39m(scappla.Param@5cbb302d, [33mApply1[39m(scappla.Param@5d9b0110, <function1>))
)
[36mnoise_mu[39m: [

In [7]:
// Adam optimizer constructed with 0.1 (presumably the learning rate - confirm
// against the Adam constructor), and the interpreter that drives it.
val opt = new Adam(0.1)
val interpreter = new OptimizingInterpreter(opt)

[36mopt[39m: [32mAdam[39m = scappla.optimization.Adam@7bc3e74f
[36minterpreter[39m: [32mOptimizingInterpreter[39m = scappla.OptimizingInterpreter@115f4b3

In [15]:
// 10000 iterations; each one resets the interpreter and then samples the
// model with it.
(0 until 10000).foreach { _ =>
    interpreter.reset()
    model.sample(interpreter)
}

In [17]:
// Labelled expressions to report after training.  Every `_s` parameter is
// passed through `exp`, matching how it is used in the priors and guides.
// Note that the "noise_mu" entry is exp(noise_mu), not the raw parameter: the
// model uses exp(noise) as the scale of the observation distribution.
val params = Seq(
    ("a_prior", exp(a_prior_s)),
    ("b_prior", exp(b_prior_s)),
    ("a_post_mu", a_post_mu),
    ("a_post_s", exp(a_post_s)),
    ("b_post_mu", b_post_mu),
    ("b_post_s", exp(b_post_s)),
    ("noise_mu", exp(noise_mu)),
    ("noise_s", exp(noise_s))
)

[36mparams[39m: [32mSeq[39m[([32mString[39m, [32mExpr[39m[[32mDouble[39m, [32mUnit[39m])] = [33mList[39m(
  ([32m"a_prior"[39m, [33mApply1[39m(scappla.Param@7e13f371, <function1>)),
  ([32m"b_prior"[39m, [33mApply1[39m(scappla.Param@314da4f9, <function1>)),
  ([32m"a_post_mu"[39m, scappla.Param@33764907),
  ([32m"a_post_s"[39m, [33mApply1[39m(scappla.Param@3a272583, <function1>)),
  ([32m"b_post_mu"[39m, scappla.Param@5cbb302d),
  ([32m"b_post_s"[39m, [33mApply1[39m(scappla.Param@5d9b0110, <function1>)),
  ([32m"noise_mu"[39m, [33mApply1[39m(scappla.Param@33f65cad, <function1>)),
  ([32m"noise_s"[39m, [33mApply1[39m(scappla.Param@58927bd0, <function1>))
)

In [18]:
// Evaluate each labelled expression with the interpreter and print one
// "name : value" line per entry.
params.foreach { case (name, param) =>
    println(s"$name : ${interpreter.eval(param).v}")
}

a_prior : 0.9748566249159144
b_prior : 0.016316337237029584
a_post_mu : 0.9810327294740097
a_post_s : 0.015555917424498474
b_post_mu : 0.010369466775204813
b_post_s : 0.010296939008060712
noise_mu : 0.5004524053050853
noise_s : 0.022398470675536467
