In [2]:
:dep tch = { version = "0.1.7", path = "../../LaurentMazare/tch-rs" }
// :dep plotters = { git = "https://github.com/38/plotters", features = ["evcxr"] }
:dep plotters = { path = "../../38/plotters", features = ["evcxr"] }
:dep qr = { path = "./" }

In [3]:
use tch::{nn, nn::ModuleT, Tensor, Kind};
use qr::{create_model, train_qr_model, load_data};
use plotters::prelude::*;

In [10]:
/// Plots the model's quantile predictions on top of the raw training data.
///
/// * `xs` — prediction inputs, read as `xs[i, 0]`: one row per point.
/// * `ys` — model outputs, read as `ys[i, j]`: one row per row of `xs` and one
///   column per predicted quantile. Every column is drawn as its own line.
///
/// The training data returned by `load_data` is overlaid as a scatter plot.
/// Errors from loading the data or from drawing are returned through
/// `evcxr_figure` rather than being ignored or causing a panic.
fn plot_result(xs: &Tensor, ys: &Tensor) -> plotters::evcxr::SVGWrapper {
    evcxr_figure((640, 480), |root| {
        // Chart context with fixed axis ranges.
        // NOTE(review): the x range 12..22 is narrower than the 10..24 grid that
        // `train` predicts on — confirm whether the curve ends should be visible.
        let mut chart = ChartBuilder::on(&root)
            .caption("Training data", ("Arial", 20).into_font())
            .x_label_area_size(40)
            .y_label_area_size(40)
            .build_ranged(12f64..22f64, -2f64..3f64)?;

        chart
            .configure_mesh()
            .disable_x_mesh()
            .disable_y_mesh()
            .draw()?;

        // One prediction line per quantile column of `ys`. The column count is
        // read from the tensor instead of being hard-coded, so the plot stays
        // correct if the model is built with a different number of quantiles.
        let n_points = xs.size()[0];
        let n_quants = ys.size()[1];
        for j in 0..n_quants {
            chart.draw_series(LineSeries::new(
                (0..n_points).map(|i| (xs.double_value(&[i, 0]), ys.double_value(&[i, j]))),
                &BLUE,
            ))?;
        }

        // Scatter plot of the raw training data (indexed as 1-D tensors).
        // The loader's error is only known to be `Debug` (the original called
        // `unwrap`), so it is formatted into a `String`, which converts into the
        // closure's boxed error type.
        let (data_xs, data_ys) =
            load_data().map_err(|e| format!("failed to load training data: {:?}", e))?;
        let n_data = data_xs.size()[0];
        chart.draw_series((0..n_data).map(|i| {
            Circle::new(
                (data_xs.double_value(&[i]), data_ys.double_value(&[i])),
                3,
                BLUE.filled(),
            )
        }))?;

        Ok(())
    })
    .style("width:60%")
}

In [11]:
/// Builds the quantile-regression model, trains it, and plots its predictions.
///
/// The network takes 1 input and outputs 7 quantiles, with a hidden-layer
/// spec of `vec![8]` passed to `create_model`. After training, it is evaluated
/// on the grid `10.0, 10.2, ..., < 24.0` (reshaped into a single column). The
/// prediction tensor's shape is printed, and the plot from `plot_result` is
/// returned.
fn train() -> plotters::evcxr::SVGWrapper {
    // Model dimensions.
    let n_inputs: i64 = 1;
    let n_quants: i64 = 7;

    // Use a CUDA device when one is available, otherwise the CPU.
    let device = tch::Device::cuda_if_available();
    let var_store = nn::VarStore::new(device);
    let untrained = create_model(&var_store.root(), n_inputs, n_quants, vec![8]);

    // Training happens here rather than in a separate cell because evcxr cannot
    // persist a value of type `impl tch::nn::module::ModuleT` between cells.
    let trained = train_qr_model(untrained, n_quants, device, &var_store);

    // Evaluate on an evenly spaced grid shaped as a column of inputs;
    // `false` selects evaluation (non-training) mode in `apply_t`.
    let grid = Tensor::arange2(10.0, 24.0, 0.2, (Kind::Float, device)).view_(&[-1, 1]);
    let predictions = grid.apply_t(&trained, false);
    println!("{:?}", predictions.size());

    plot_result(&grid, &predictions)
}

In [12]:
// Run training and render the plot. The call is left without a trailing
// semicolon so the returned SVG is the cell's value, which evcxr displays.
train()

Quantiles
 0.1250
 0.2500
 0.3750
 0.5000
 0.6250
 0.7500
 0.8750
[ CPUFloatType{7} ]
iter =     0, loss = 0.3218071460723877
iter =  2000, loss = 0.2108207643032074
iter =  4000, loss = 0.20822012424468994
iter =  6000, loss = 0.20785333216190338
iter =  8000, loss = 0.2078661173582077
iter = 10000, loss = 0.20727649331092834
iter = 12000, loss = 0.20629681646823883
iter = 14000, loss = 0.20462894439697266
iter = 16000, loss = 0.20273444056510925
iter = 18000, loss = 0.20055894553661346
[70, 7]
