In [2]:
// :dep tch = { version = "0.1.7", path = "../../LaurentMazare/tch-rs" }
// use tch::{Tensor, IndexOp, NewAxis};
// let x: Tensor = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
// x.print();
// let y: Tensor = x.i((.., 0, NewAxis));
// y.print();
// x.print();
// println!("{:?}", y.size());

# Quantile regression

Python code is [here](https://gist.github.com/taku-y/471b7eae5ef85badb1ddcfb389982bd6).

In [3]:
:dep tch = { version = "0.1.7", path = "../../LaurentMazare/tch-rs" }
// :dep plotters = { git = "https://github.com/38/plotters", features = ["evcxr"] }
:dep plotters = { path = "../../38/plotters", features = ["evcxr"] }
:dep qr = { path = "./" }

In [4]:
use tch::{nn, nn::ModuleT, Tensor, Kind};
use qr::{create_model, train_qr_model, load_data};
use plotters::prelude::*;

In [5]:
/// Draws the predicted curves and the data returned by `load_data()` into one SVG figure.
///
/// * `xs` - 2-D tensor of inputs; the x coordinate of point `i` is read from `xs[i, 0]`.
/// * `ys` - 2-D tensor of predictions; every column `j` is drawn as one blue line
///   through the points `(xs[i, 0], ys[i, j])`.
///
/// The number of lines is taken from the second dimension of `ys`, so the
/// function works for any number of predicted columns. The axis ranges are
/// fixed to `12..22` (x) and `-2..3` (y).
///
/// Errors raised while drawing are propagated to `evcxr_figure` instead of
/// being discarded.
///
/// # Panics
///
/// Panics if `load_data()` returns an error, or if `ys` has fewer than two
/// dimensions.
fn plot_result(xs: &Tensor, ys: &Tensor) -> plotters::evcxr::SVGWrapper {
    evcxr_figure((640, 480), |root| {
        // Chart context with a caption and room for both axis labels.
        let mut chart = ChartBuilder::on(&root)
            .caption("Training data", ("Arial", 20).into_font())
            .x_label_area_size(40)
            .y_label_area_size(40)
            .build_ranged(12f64..22f64, -2f64..3f64)?;

        // Draw the axes only; grid lines are turned off in both directions.
        chart
            .configure_mesh()
            .disable_x_mesh()
            .disable_y_mesh()
            .draw()?;

        // Prediction curves: one line per column of `ys`.
        let n_points = xs.size()[0];
        let n_curves = ys.size()[1];
        for j in 0..n_curves {
            chart.draw_series(LineSeries::new(
                (0..n_points).map(|i| (xs.double_value(&[i, 0]), ys.double_value(&[i, j]))),
                &BLUE,
            ))?;
        }

        // Scatter plot of the data set. Distinct names are used so the
        // function parameters are not shadowed.
        let (data_xs, data_ys) = load_data().expect("failed to load the data for the scatter plot");
        let n_data = data_xs.size()[0];
        chart.draw_series((0..n_data).map(|i| {
            Circle::new(
                (data_xs.double_value(&[i]), data_ys.double_value(&[i])),
                3,
                BLUE.filled(),
            )
        }))?;

        Ok(())
    }).style("width:60%")
}

In [8]:
/// Builds a quantile regression model, trains it, and plots its predictions.
///
/// `cumsum` is forwarded unchanged to `create_model()`. The whole pipeline
/// lives in one function because the model type (`impl ModuleT`) cannot
/// currently be persisted between notebook cells.
///
/// Prints the size of the prediction tensor and returns the figure produced
/// by `plot_result()`.
fn train(cumsum: bool) -> plotters::evcxr::SVGWrapper {
    // Model dimensions: one input feature, seven outputs.
    let n_inputs = 1i64;
    let n_quants = 7i64;

    // Variable store placed on the GPU when one is available.
    let device = tch::Device::cuda_if_available();
    let vs = nn::VarStore::new(device);

    // `vec![8]` is presumably the list of hidden layer sizes; confirm in lib.rs.
    let untrained = create_model(&vs.root(), n_inputs, n_quants, vec![8], cumsum);

    // Training consumes the model and hands it back.
    let trained = train_qr_model(untrained, n_quants, device, &vs);

    // Evaluate on a grid from 10 to 24 in steps of 0.2, shaped as a column.
    let grid = Tensor::arange2(10.0, 24.0, 0.2, (Kind::Float, device)).view_(&[-1, 1]);
    // The second argument is `train`; false for prediction.
    let preds = trained.forward_t(&grid, false);
    println!("{:?}", preds.size());

    // Figure with the predicted curves over the data.
    plot_result(&grid, &preds)
}

## Without cumsum

The quantile regression curves are estimated independently, so their order may be inconsistent: a curve for a lower quantile can end up above a curve for a higher quantile.

In [12]:
// Run the pipeline with `cumsum = false`; the returned SVG figure is the value of this cell.
train(false)

Quantiles
 0.1250
 0.2500
 0.3750
 0.5000
 0.6250
 0.7500
 0.8750
[ CPUFloatType{7} ]
iter =     0, loss = 0.36304357647895813
iter =  2000, loss = 0.20935633778572083
iter =  4000, loss = 0.2083345651626587
iter =  6000, loss = 0.2075996696949005
iter =  8000, loss = 0.20779283344745636
iter = 10000, loss = 0.2074453979730606
iter = 12000, loss = 0.2065182477235794
iter = 14000, loss = 0.20544277131557465
iter = 16000, loss = 0.20386028289794922
iter = 18000, loss = 0.20136895775794983
[70, 7]


## With cumsum

The regression curves are cumulative sums of positive values, except for the lowest quantile, so they are always ordered from lower to higher quantiles. See the implementation of the output layer in `create_model()` in `lib.rs`.

In [17]:
// Run the pipeline with `cumsum = true`; the returned SVG figure is the value of this cell.
train(true)

Quantiles
 0.1250
 0.2500
 0.3750
 0.5000
 0.6250
 0.7500
 0.8750
[ CPUFloatType{7} ]
iter =     0, loss = 0.8314995169639587
iter =  2000, loss = 0.2114652693271637
iter =  4000, loss = 0.20878703892230988
iter =  6000, loss = 0.20742233097553253
iter =  8000, loss = 0.20734447240829468
iter = 10000, loss = 0.2072826772928238
iter = 12000, loss = 0.2068628966808319
iter = 14000, loss = 0.20639674365520477
iter = 16000, loss = 0.20528718829154968
iter = 18000, loss = 0.20336098968982697
[70, 7]
