# Learning networks

In [1]:
using MLJ, StableRNGs
import DataFrames
@load RidgeRegressor pkg=MultivariateStats

┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/sandhya/.julia/packages/MLJModels/E8BbE/src/loading.jl:168


import MLJMultivariateStatsInterface

┌ Info: Precompiling MLJMultivariateStatsInterface [1b6a4a23-ba22-4f51-9698-8599985d3728]
└ @ Base loading.jl:1317


 ✔


MLJMultivariateStatsInterface.RidgeRegressor

In [2]:
rng = StableRNG(551234) # for reproducibility

x1 = rand(rng, 300)
x2 = rand(rng, 300)
x3 = rand(rng, 300)
y = exp.(x1 - x2 -2x3 + 0.1*rand(rng, 300))

X = DataFrames.DataFrame(x1=x1, x2=x2, x3=x3)
first(X, 3) |> pretty

┌────────────┬────────────┬────────────┐
│[1m x1         [0m│[1m x2         [0m│[1m x3         [0m│
│[90m Float64    [0m│[90m Float64    [0m│[90m Float64    [0m│
│[90m Continuous [0m│[90m Continuous [0m│[90m Continuous [0m│
├────────────┼────────────┼────────────┤
│ 0.984002   │ 0.771482   │ 0.232099   │
│ 0.891795   │ 0.747399   │ 0.770914   │
│ 0.806395   │ 0.0182751  │ 0.0721645  │
└────────────┴────────────┴────────────┘


In [3]:
test, train = partition(eachindex(y), 0.8);

## Defining a learning network
In MLJ, a learning network is a directed acyclic graph (DAG) whose nodes apply trained or untrained operations such as a predict or transform (trained) or +, vcat etc. (untrained). Learning networks can be seen as pipelines on steroids.

Let's consider the following simple DAG:

Operation DAG

It corresponds to a fairly standard regression workflow: the data is standardized, the target is transformed using a Box-Cox transformation, a ridge regression is applied and the result is converted back by inverting the transform.

Note: actually this DAG is simple enough that it could also have been done with a pipeline.

In [4]:
Xs = source(X)
ys = source(y)

[34mSource @982[39m ⏎ `AbstractVector{Continuous}`

In [5]:
stand = machine(Standardizer(), Xs)
W = transform(stand, Xs)

[34mNode{Machine{Standardizer,…}} @520[39m
  args:
    1:	[34mSource @348[39m
  formula:
    transform(
        [0m[1m[34mMachine{Standardizer,…} @643[39m[22m, 
        [34mSource @348[39m)

In [6]:
fit!(W, rows=train);

┌ Info: Training [34mMachine{Standardizer,…} @643[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342


In [7]:
W()             # transforms all data
W(rows=test, )  # transforms only test data
W(X[3:4, :])    # transforms specific data

Unnamed: 0_level_0,x1,x2,x3
Unnamed: 0_level_1,Float64,Float64,Float64
1,0.856967,-1.59115,-1.48215
2,-1.06436,-1.5056,-0.234452


In [16]:
box_model = UnivariateBoxCoxTransformer()
box = machine(box_model, ys)
z = transform(box, ys)

ridge_model = MLJMultivariateStatsInterface.RidgeRegressor(lambda=0.1)
ridge = machine(ridge_model, W, z)
ẑ = MLJ.predict(ridge, W)

ŷ = inverse_transform(box, ẑ)

[34mNode{Machine{UnivariateBoxCoxTransformer,…}} @876[39m
  args:
    1:	[34mNode{Machine{RidgeRegressor,…}} @426[39m
  formula:
    inverse_transform(
        [0m[1m[34mMachine{UnivariateBoxCoxTransformer,…} @332[39m[22m, 
        predict(
            [0m[1m[34mMachine{RidgeRegressor,…} @589[39m[22m, 
            transform(
                [0m[1m[34mMachine{Standardizer,…} @643[39m[22m, 
                [34mSource @348[39m)))

In [14]:
info("RidgeRegressor", pkg="MultivariateStats").load_path

"MLJMultivariateStatsInterface.RidgeRegressor"

In [17]:
MLJ.fit!(ŷ, rows=train);

┌ Info: Training [34mMachine{UnivariateBoxCoxTransformer,…} @332[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342
┌ Info: Not retraining [34mMachine{Standardizer,…} @643[39m. Use `force=true` to force.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:345
┌ Info: Training [34mMachine{RidgeRegressor,…} @589[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342


In [18]:
rms(y[test], ŷ(rows=test))

0.033604963634078514

## Modifying hyperparameters

In [19]:
ridge_model.lambda = 5.0;

In [20]:
MLJ.fit!(ŷ, rows=train)
rms(y[test], ŷ(rows=test))

┌ Info: Not retraining [34mMachine{UnivariateBoxCoxTransformer,…} @332[39m. Use `force=true` to force.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:345
┌ Info: Not retraining [34mMachine{Standardizer,…} @643[39m. Use `force=true` to force.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:345
┌ Info: Updating [34mMachine{RidgeRegressor,…} @589[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:343


0.03834272597361202

## "Arrow" syntax

In [21]:
W = X |> Standardizer()
z = y |> UnivariateBoxCoxTransformer()

[34mNode{Machine{UnivariateBoxCoxTransformer,…}} @388[39m
  args:
    1:	[34mSource @986[39m
  formula:
    transform(
        [0m[1m[34mMachine{UnivariateBoxCoxTransformer,…} @457[39m[22m, 
        [34mSource @986[39m)

In [24]:
ẑ = (W, z) |> MLJMultivariateStatsInterface.RidgeRegressor(lambda=0.1);

In [25]:
ŷ = ẑ |> inverse_transform(z);

In [28]:
MLJ.fit!(ŷ, rows=train)
rms(y[test], ŷ(rows=test))

┌ Info: Training [34mMachine{UnivariateBoxCoxTransformer,…} @457[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342
┌ Info: Training [34mMachine{Standardizer,…} @529[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342
┌ Info: Training [34mMachine{RidgeRegressor,…} @409[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:342


0.033604963634078514

In [29]:
ẑ[:lambda] = 5.0;

In [30]:
ẑ.machine.model.lambda = 5.0;

In [31]:
MLJ.fit!(ŷ, rows=train)
rms(y[test], ŷ(rows=test))

┌ Info: Not retraining [34mMachine{UnivariateBoxCoxTransformer,…} @457[39m. Use `force=true` to force.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:345
┌ Info: Not retraining [34mMachine{Standardizer,…} @529[39m. Use `force=true` to force.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:345
┌ Info: Updating [34mMachine{RidgeRegressor,…} @409[39m.
└ @ MLJBase /home/sandhya/.julia/packages/MLJBase/pCCd7/src/machines.jl:343


0.03834272597361202