<!DOCTYPE html>
<!-- Generated by pkgdown: do not edit by hand --><html lang="en">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Defining New `sl3` Learners • sl3</title>
<!-- jquery --><script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.4.1/jquery.min.js" integrity="sha256-CSXorXvZcTkaix6Yvo6HppcZGetbYMGWSFlBw8HfCJo=" crossorigin="anonymous"></script><!-- Bootstrap --><link href="https://cdnjs.cloudflare.com/ajax/libs/bootswatch/3.4.0/flatly/bootstrap.min.css" rel="stylesheet" crossorigin="anonymous">
<script src="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.4.1/js/bootstrap.min.js" integrity="sha256-nuL8/2cJ5NDSSwnKD8VqreErSWHtnEP9E7AySL+1ev4=" crossorigin="anonymous"></script><!-- bootstrap-toc --><link rel="stylesheet" href="../bootstrap-toc.css">
<script src="../bootstrap-toc.js"></script><!-- Font Awesome icons --><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/all.min.css" integrity="sha256-mmgLkCYLUQbXn0B1SRqzHar6dCnv9oZFPEC1g1cwlkk=" crossorigin="anonymous">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/v4-shims.min.css" integrity="sha256-wZjR52fzng1pJHwx4aV2AO3yyTOXrcDW7jBpJtTwVxw=" crossorigin="anonymous">
<!-- clipboard.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/clipboard.js/2.0.6/clipboard.min.js" integrity="sha256-inc5kl9MA1hkeYUt+EC3BhlIgyp/2jDIyBLS6k3UxPI=" crossorigin="anonymous"></script><!-- headroom.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/headroom.min.js" integrity="sha256-AsUX4SJE1+yuDu5+mAVzJbuYNPHj/WroHuZ8Ir/CkE0=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/jQuery.headroom.min.js" integrity="sha256-ZX/yNShbjqsohH1k95liqY9Gd8uOiE1S4vZc+9KQ1K4=" crossorigin="anonymous"></script><!-- pkgdown --><link href="../pkgdown.css" rel="stylesheet">
<script src="../pkgdown.js"></script><meta property="og:title" content="Defining New `sl3` Learners">
<meta property="og:description" content="sl3">
<!-- mathjax --><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js" integrity="sha256-nvJJv9wWKEm88qvoQl9ekL2J+k/RWIsaSScxxlsrv8k=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/config/TeX-AMS-MML_HTMLorMML.js" integrity="sha256-84DKXVJXs0/F8OTMzX4UR909+jtl4G7SPypPavF+GfA=" crossorigin="anonymous"></script><!--[if lt IE 9]>
<script src="https://oss.maxcdn.com/html5shiv/3.7.3/html5shiv.min.js"></script>
<script src="https://oss.maxcdn.com/respond/1.4.2/respond.min.js"></script>
<![endif]--><!-- Global site tag (gtag.js) - Google Analytics --><script async src="https://www.googletagmanager.com/gtag/js?id=UA-115145808-1"></script><script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'UA-115145808-1');
</script>
</head>
<body data-spy="scroll" data-target="#toc">
<div class="container template-article">
<header><div class="navbar navbar-default navbar-fixed-top" role="navigation">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<span class="navbar-brand">
<a class="navbar-link" href="../index.html">sl3</a>
<span class="version label label-default" data-toggle="tooltip" data-placement="bottom" title="Released version">1.4.3</span>
</span>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul class="nav navbar-nav">
<li>
<a href="https://tlverse.org">tlverse</a>
</li>
<li>
<a href="../index.html">sl3</a>
</li>
<li>
<a href="../reference/index.html">Reference</a>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-expanded="false">
Articles
<span class="caret"></span>
</a>
<ul class="dropdown-menu" role="menu">
<li>
<a href="../articles/intro_sl3.html">Intro to sl3</a>
</li>
<li>
<a href="../articles/SuperLearner_benchmarks.html">sl3 Benchmarks</a>
</li>
<li>
<a href="../slides/sl3-slides.html">sl3 Presentation</a>
</li>
<li>
<a href="../articles/custom_lrnrs.html">Defining Custom sl3 Learners</a>
</li>
</ul>
</li>
</ul>
<ul class="nav navbar-nav navbar-right">
<li>
<a href="https://github.com/tlverse/sl3/">
<span class="fab fa-github fa-lg"></span>
</a>
</li>
</ul>
</div>
<!--/.nav-collapse -->
</div>
<!--/.container -->
</div>
<!--/.navbar -->
</header><script src="custom_lrnrs_files/header-attrs-2.7.2/header-attrs.js"></script><script src="custom_lrnrs_files/accessible-code-block-0.0.1/empty-anchor.js"></script><div class="row">
<div class="col-md-9 contents">
<div class="page-header toc-ignore">
<h1 data-toc-skip>Defining New <code>sl3</code> Learners</h1>
<h4 class="author">Jeremy Coyle, Nima Hejazi, Ivana Malenica, Oleg Sofrygin</h4>
<h4 class="date">2021-10-20</h4>
<small class="dont-index">Source: <a href="https://github.com/tlverse/sl3/blob/master/vignettes/custom_lrnrs.Rmd"><code>vignettes/custom_lrnrs.Rmd</code></a></small>
<div class="hidden name"><code>custom_lrnrs.Rmd</code></div>
</div>
<div id="introduction" class="section level2">
<h2 class="hasAnchor">
<a href="#introduction" class="anchor"></a>Introduction</h2>
<p>This guide describes the process of implementing a learner class for a new machine learning algorithm. By writing a learner class for your favorite machine learning algorithm, you will be able to use it anywhere you could use any other <code>sl3</code> learner, including <code>Pipeline</code>s, <code>Stack</code>s, and Super Learner. We have done our best to streamline the process of creating new <code>sl3</code> learners.</p>
<p>Before diving into defining a new learner, it will likely be helpful to read some background material. If you haven’t already read it, the <a href="intro_sl3.html">“Modern Machine Learning in R”</a> vignette is a good introduction to the <code>sl3</code> package and its underlying architecture. The <a href="https://cran.r-project.org/web/packages/R6/vignettes/Introduction.html"><code>R6</code></a> documentation will help you understand how <code>R6</code> classes are defined. In addition, the help files for <a href="https://sl3.tlverse.org/reference/sl3_Task.html"><code>sl3_Task</code></a> and <a href="https://sl3.tlverse.org/reference/Lrnr_base.html"><code>Lrnr_base</code></a> are good references on how those objects can be used. If you’re interested in defining learners that fit sub-learners, the documentation of the <a href="https://delayed.tlverse.org/articles/delayed.html"><code>delayed</code></a> package will also be helpful.</p>
<p>In the following sections, we introduce and review a template for a new <code>sl3</code> learner, describing the sections that can be used to define your new learner. This is followed by a discussion of the important task of documenting and testing your new learner. Finally, we conclude by explaining how you can add your learner to <code>sl3</code> so that others may make use of it.</p>
<div class="sourceCode" id="cb1"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span class="kw"><a href="https://rdrr.io/r/base/library.html">library</a></span><span class="op">(</span><span class="va"><a href="https://tlverse.org/sl3">sl3</a></span><span class="op">)</span></code></pre></div>
</div>
<div id="learner-template" class="section level2">
<h2 class="hasAnchor">
<a href="#learner-template" class="anchor"></a>Learner Template</h2>
<p><code>sl3</code> provides a template of a learner for use in defining new learners. You can make a copy of the template to work on by invoking <code>write_learner_template</code>:</p>
<div class="sourceCode" id="cb2"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span class="co">## Not run:</span>
<span class="fu"><a href="../reference/write_learner_template.html">write_learner_template</a></span><span class="op">(</span><span class="st">"path/to/write/Learner_template.R"</span><span class="op">)</span></code></pre></div>
<p>Let’s take a look at that template:</p>
<div class="sourceCode" id="cb3"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span class="co">##' Template of a \code{sl3} Learner.</span>
<span class="co">##'</span>
<span class="co">##' This is a template for defining a new learner.</span>
<span class="co">##' This can be copied to a new file using \code{\link{write_learner_template}}.</span>
<span class="co">##' The remainder of this documentation is an example of how you might write documentation for your new learner.</span>
<span class="co">##' This learner uses \code{\link[my_package]{my_ml_fun}} from \code{my_package} to fit my favorite machine learning algorithm.</span>
<span class="co">##'</span>
<span class="co">##' @docType class</span>
<span class="co">##' @importFrom R6 R6Class</span>
<span class="co">##' @export</span>
<span class="co">##' @keywords data</span>
<span class="co">##' @return Learner object with methods for training and prediction. See \code{\link{Lrnr_base}} for documentation on learners.</span>
<span class="co">##' @format \code{\link{R6Class}} object.</span>
<span class="co">##' @family Learners</span>
<span class="co">##'</span>
<span class="co">##' @section Parameters:</span>
<span class="co">##' \describe{</span>
<span class="co">##' \item{\code{param_1="default_1"}}{ This parameter does something.</span>
<span class="co">##' }</span>
<span class="co">##' \item{\code{param_2="default_2"}}{ This parameter does something else.</span>
<span class="co">##' }</span>
<span class="co">##' \item{\code{...}}{ Other parameters passed directly to \code{\link[my_package]{my_ml_fun}}. See its documentation for details.</span>
<span class="co">##' }</span>
<span class="co">##' }</span>
<span class="co">##'</span>
<span class="co">##' @section Methods:</span>
<span class="co">##' \describe{</span>
<span class="co">##' \item{\code{special_function(arg_1)}}{</span>
<span class="co">##' My learner is special so it has a special function.</span>
<span class="co">##'</span>
<span class="co">##' \itemize{</span>
<span class="co">##' \item{\code{arg_1}: A very special argument.</span>
<span class="co">##' }</span>
<span class="co">##' }</span>
<span class="co">##' }</span>
<span class="co">##' }</span>
<span class="va">Lrnr_template</span> <span class="op"><-</span> <span class="kw">R6Class</span><span class="op">(</span>
classname <span class="op">=</span> <span class="st">"Lrnr_template"</span>, inherit <span class="op">=</span> <span class="va">Lrnr_base</span>,
portable <span class="op">=</span> <span class="cn">TRUE</span>, class <span class="op">=</span> <span class="cn">TRUE</span>,
<span class="co"># Above, you should change Lrnr_template (in both the object name and the classname argument)</span>
<span class="co"># to a name that indicates what your learner does</span>
public <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/list.html">list</a></span><span class="op">(</span>
<span class="co"># you can define default parameter values here</span>
<span class="co"># if possible, your learner should define defaults for all required parameters</span>
initialize <span class="op">=</span> <span class="kw">function</span><span class="op">(</span><span class="va">param_1</span> <span class="op">=</span> <span class="st">"default_1"</span>, <span class="va">param_2</span> <span class="op">=</span> <span class="st">"default_2"</span>, <span class="va">...</span><span class="op">)</span> <span class="op">{</span>
<span class="co"># this captures all parameters to initialize and saves them as self$params</span>
<span class="va">params</span> <span class="op"><-</span> <span class="fu"><a href="../reference/args_to_list.html">args_to_list</a></span><span class="op">(</span><span class="op">)</span>
<span class="va">super</span><span class="op">$</span><span class="fu">initialize</span><span class="op">(</span>params <span class="op">=</span> <span class="va">params</span>, <span class="va">...</span><span class="op">)</span>
<span class="op">}</span>,
<span class="co"># you can define public functions that allow your learner to do special things here</span>
<span class="co"># for instance, a glm learner might return prediction standard errors</span>
special_function <span class="op">=</span> <span class="kw">function</span><span class="op">(</span><span class="va">arg_1</span><span class="op">)</span> <span class="op">{</span>
<span class="op">}</span>
<span class="op">)</span>,
private <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/list.html">list</a></span><span class="op">(</span>
<span class="co"># list properties your learner supports here.</span>
<span class="co"># Use sl3_list_properties() for a list of options</span>
.properties <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/c.html">c</a></span><span class="op">(</span><span class="st">""</span><span class="op">)</span>,
<span class="co"># list any packages required for your learner here.</span>
.required_packages <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/c.html">c</a></span><span class="op">(</span><span class="st">"my_package"</span><span class="op">)</span>,
<span class="co"># .train takes task data and returns a fit object that can be used to generate predictions</span>
.train <span class="op">=</span> <span class="kw">function</span><span class="op">(</span><span class="va">task</span><span class="op">)</span> <span class="op">{</span>
<span class="co"># generate an argument list from the parameters that were</span>
<span class="co"># captured when your learner was initialized.</span>
<span class="co"># this allows users to pass arguments directly to your ml function</span>
<span class="va">args</span> <span class="op"><-</span> <span class="va">self</span><span class="op">$</span><span class="va">params</span>
<span class="co"># get outcome variable type</span>
<span class="co"># preferring learner$params$outcome_type first, then task$outcome_type</span>
<span class="va">outcome_type</span> <span class="op"><-</span> <span class="va">self</span><span class="op">$</span><span class="fu">get_outcome_type</span><span class="op">(</span><span class="va">task</span><span class="op">)</span>
<span class="co"># should pass something on to your learner indicating outcome_type</span>
<span class="co"># e.g. family or objective</span>
<span class="co"># add task data to the argument list</span>
<span class="co"># what these arguments are called depends on the learner you are wrapping</span>
<span class="va">args</span><span class="op">$</span><span class="va">x</span> <span class="op"><-</span> <span class="fu"><a href="https://Rdatatable.gitlab.io/data.table/reference/as.matrix.html">as.matrix</a></span><span class="op">(</span><span class="va">task</span><span class="op">$</span><span class="va">X_intercept</span><span class="op">)</span>
<span class="va">args</span><span class="op">$</span><span class="va">y</span> <span class="op"><-</span> <span class="va">outcome_type</span><span class="op">$</span><span class="fu">format</span><span class="op">(</span><span class="va">task</span><span class="op">$</span><span class="va">Y</span><span class="op">)</span>
<span class="co"># only add arguments on weights and offset</span>
<span class="co"># if those were specified when the task was generated</span>
<span class="kw">if</span> <span class="op">(</span><span class="va">task</span><span class="op">$</span><span class="fu">has_node</span><span class="op">(</span><span class="st">"weights"</span><span class="op">)</span><span class="op">)</span> <span class="op">{</span>
<span class="va">args</span><span class="op">$</span><span class="va">weights</span> <span class="op"><-</span> <span class="va">task</span><span class="op">$</span><span class="va">weights</span>
<span class="op">}</span>
<span class="kw">if</span> <span class="op">(</span><span class="va">task</span><span class="op">$</span><span class="fu">has_node</span><span class="op">(</span><span class="st">"offset"</span><span class="op">)</span><span class="op">)</span> <span class="op">{</span>
<span class="va">args</span><span class="op">$</span><span class="va">offset</span> <span class="op"><-</span> <span class="va">task</span><span class="op">$</span><span class="va">offset</span>
<span class="op">}</span>
<span class="co"># call a function that fits your algorithm</span>
<span class="co"># with the argument list you constructed</span>
<span class="va">fit_object</span> <span class="op"><-</span> <span class="fu"><a href="../reference/call_with_args.html">call_with_args</a></span><span class="op">(</span><span class="va">my_ml_fun</span>, <span class="va">args</span><span class="op">)</span>
<span class="co"># return the fit object, which will be stored</span>
<span class="co"># in the trained learner object returned from the call</span>
<span class="co"># to learner$train</span>
<span class="kw"><a href="https://rdrr.io/r/base/function.html">return</a></span><span class="op">(</span><span class="va">fit_object</span><span class="op">)</span>
<span class="op">}</span>,
<span class="co"># .predict takes a task and returns predictions from that task</span>
.predict <span class="op">=</span> <span class="kw">function</span><span class="op">(</span><span class="va">task</span> <span class="op">=</span> <span class="cn">NULL</span><span class="op">)</span> <span class="op">{</span>
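<span class="co"># the training task, its outcome type, and the fit object</span>
<span class="co"># returned by .train are all available here:</span>
<span class="va">self</span><span class="op">$</span><span class="va">training_task</span>
<span class="va">self</span><span class="op">$</span><span class="va">training_outcome_type</span>
<span class="va">self</span><span class="op">$</span><span class="va">fit_object</span>
<span class="co"># build the design matrix the same way .train did (with an intercept column)</span>
<span class="va">predictions</span> <span class="op"><-</span> <span class="fu"><a href="https://rdrr.io/r/stats/predict.html">predict</a></span><span class="op">(</span><span class="va">self</span><span class="op">$</span><span class="va">fit_object</span>, <span class="fu"><a href="https://Rdatatable.gitlab.io/data.table/reference/as.matrix.html">as.matrix</a></span><span class="op">(</span><span class="va">task</span><span class="op">$</span><span class="va">X_intercept</span><span class="op">)</span><span class="op">)</span>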
<span class="kw"><a href="https://rdrr.io/r/base/function.html">return</a></span><span class="op">(</span><span class="va">predictions</span><span class="op">)</span>
<span class="op">}</span>
<span class="op">)</span>
<span class="op">)</span></code></pre></div>
<p>The template has comments indicating where details specific to the learner you’re trying to implement should be filled in. In the next section, we will discuss those details further.</p>
</div>
<div id="defining-your-learner" class="section level2">
<h2 class="hasAnchor">
<a href="#defining-your-learner" class="anchor"></a>Defining your Learner</h2>
<div id="learner-name-and-class" class="section level3">
<h3 class="hasAnchor">
<a href="#learner-name-and-class" class="anchor"></a>Learner Name and Class</h3>
<p>At the top of the template, we define an object <code>Lrnr_template</code> and set <code>classname = "Lrnr_template"</code>. You should modify these to match the name of your new learner, which should also match the name of the corresponding R file. Note that the name should be prefixed by <code>Lrnr_</code> and use <a href="https://en.wikipedia.org/wiki/Snake_case"><code>snake_case</code></a>.</p>
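<p>To make the naming convention concrete, here is a minimal sketch of a complete learner, assuming only the <code>sl3</code> and <code>R6</code> packages. The class <code>Lrnr_ols_sketch</code> is hypothetical (by the convention above it would live in <code>Lrnr_ols_sketch.R</code>); it wraps base R’s <code>stats::lm.fit</code> for illustration only and supports continuous outcomes only.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(R6)
library(sl3)

# Hypothetical learner: ordinary least squares via stats::lm.fit.
# The object name and the classname argument match and use the Lrnr_ prefix.
Lrnr_ols_sketch <- R6Class(
  classname = "Lrnr_ols_sketch", inherit = Lrnr_base,
  portable = TRUE, class = TRUE,
  public = list(
    initialize = function(...) {
      # capture any user-supplied parameters as self$params
      params <- args_to_list()
      super$initialize(params = params, ...)
    }
  ),
  private = list(
    # this sketch only supports continuous outcomes
    .properties = c("continuous"),
    .train = function(task) {
      x <- as.matrix(task$X_intercept)
      fit <- stats::lm.fit(x, task$Y)
      coefs <- fit$coefficients
      # treat coefficients of aliased (collinear) columns as zero
      coefs[is.na(coefs)] <- 0
      # return only what prediction needs, not copies of the data
      list(coefficients = coefs)
    },
    .predict = function(task = NULL) {
      x <- as.matrix(task$X_intercept)
      as.vector(x %*% self$fit_object$coefficients)
    }
  )
)

task <- sl3_Task$new(mtcars, covariates = c("wt", "hp"), outcome = "mpg")
ols_fit <- Lrnr_ols_sketch$new()$train(task)
ols_preds <- ols_fit$predict(task)
</code></pre></div>
<p>If the sketch behaves as intended, <code>ols_preds</code> should match the fitted values of <code>lm(mpg ~ wt + hp, data = mtcars)</code>, since both solve the same least squares problem.</p>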
</div>
<div id="publicinitialize" class="section level3">
<h3 class="hasAnchor">
<a href="#publicinitialize" class="anchor"></a><code>public$initialize</code>
</h3>
<p>This function defines the constructor for your learner, and it stores the arguments (if any) provided when a user calls <code><a href="../reference/Lrnr_base.html">make_learner(Lrnr_your_learner, ...)</a></code>. You can also provide default parameter values, just as the template does with <code>param_1 = "default_1"</code> and <code>param_2 = "default_2"</code>. All parameters used by your newly defined learner should have defaults whenever possible, so that users can apply your learner without having to work out reasonable parameter values themselves. Parameter values should be documented; see the section below on <a href="#doctest">documentation</a> for details.</p>
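<p>As an illustration, the hypothetical learner below declares the same defaults as the template and is constructed with one default overridden and one extra argument (<code>tol</code>, also hypothetical) passed through <code>...</code>. This sketch assumes, as the template’s comment states, that <code>args_to_list()</code> captures every argument to <code>initialize</code>, including defaults, and that the result is exposed as <code>learner$params</code>.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(R6)
library(sl3)

# Hypothetical learner used only to inspect parameter capture;
# it defines no .train or .predict methods.
Lrnr_params_demo <- R6Class(
  classname = "Lrnr_params_demo", inherit = Lrnr_base,
  portable = TRUE, class = TRUE,
  public = list(
    initialize = function(param_1 = "default_1", param_2 = "default_2", ...) {
      params <- args_to_list()
      super$initialize(params = params, ...)
    }
  )
)

# keep the default for param_1, override param_2, and pass an extra argument
lrnr <- Lrnr_params_demo$new(param_2 = "custom", tol = 1e-6)
str(lrnr$params)
</code></pre></div>
<p>Calling <code>make_learner(Lrnr_params_demo, param_2 = "custom", tol = 1e-6)</code> should be equivalent to the <code>$new()</code> call above.</p>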
</div>
<div id="publicspecial_functions" class="section level3">
<h3 class="hasAnchor">
<a href="#publicspecial_functions" class="anchor"></a><code>public$special_function</code>s</h3>
<p>You can of course define functions for things only your learner can do. These should be public functions like the <code>special_function</code> defined in the example. These should be documented; see the section below on <a href="#doctest">documentation</a> for details.</p>
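<p>For example, a learner that extends an existing one could expose a public method that reads from its fit object. The sketch below assumes that <code>Lrnr_glm</code>, which it inherits from, stores the output of <code>stats::glm.fit</code> as its <code>fit_object</code>; the class name and the <code>get_coefficients</code> method are hypothetical.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(R6)
library(sl3)

# Hypothetical learner that reuses Lrnr_glm's training and prediction
# and adds one special public method.
Lrnr_glm_coef_sketch <- R6Class(
  classname = "Lrnr_glm_coef_sketch", inherit = Lrnr_glm,
  portable = TRUE, class = TRUE,
  public = list(
    get_coefficients = function() {
      # special functions often only make sense on a trained learner
      if (is.null(self$fit_object)) {
        stop("train the learner before requesting coefficients")
      }
      self$fit_object$coefficients
    }
  )
)

task <- sl3_Task$new(mtcars, covariates = c("wt", "hp"), outcome = "mpg")
glm_fit <- Lrnr_glm_coef_sketch$new()$train(task)
glm_fit$get_coefficients()
</code></pre></div>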
</div>
<div id="private-properties" class="section level3">
<h3 class="hasAnchor">
<a href="#private-properties" class="anchor"></a><code>private$.properties</code>
</h3>
<p>This field defines the properties supported by your learner, such as the outcome types it can handle and whether it accepts offsets and observation weights, among many other possibilities. To see a list of all properties supported or used by at least one learner, you may invoke <code>sl3_list_properties</code>:</p>
<div class="sourceCode" id="cb4"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span class="fu"><a href="../reference/sl3_list_properties.html">sl3_list_properties</a></span><span class="op">(</span><span class="op">)</span></code></pre></div>
<pre><code>## [1] "binomial" "categorical" "continuous" "cv"
## [5] "density" "h2o" "ids" "importance"
## [9] "offset" "preprocessing" "sampling" "screener"
## [13] "timeseries" "weights" "wrapper"</code></pre>
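<p>Properties are declared as a character vector in <code>private$.properties</code>; for instance, a learner that handles continuous and binary outcomes and accepts observation weights might declare <code>.properties = c("continuous", "binomial", "weights")</code>. The sketch below shows how declared properties can be inspected, assuming the <code>properties</code> field of a learner object and the <code>sl3_list_learners()</code> helper behave as in recent <code>sl3</code> releases.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(sl3)

# properties declared by an existing learner
glm_props <- Lrnr_glm$new()$properties
print(glm_props)

# learners that declare support for both binomial outcomes and weights
weighted_binomial <- sl3_list_learners(c("binomial", "weights"))
head(weighted_binomial)
</code></pre></div>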
</div>
<div id="private-required_packages" class="section level3">
<h3 class="hasAnchor">
<a href="#private-required_packages" class="anchor"></a><code>private$.required_packages</code>
</h3>
<p>This field defines other R packages required for your learner to work properly. These will be loaded when an object of your new learner class is initialized.</p>
</div>
<div id="user-interface-for-learners" class="section level3">
<h3 class="hasAnchor">
<a href="#user-interface-for-learners" class="anchor"></a>User Interface for Learners</h3>
<p>If you’ve used <code>sl3</code> before, you may have noticed that while users are instructed to use <code>learner$train</code>, <code>learner$predict</code>, and <code>learner$chain</code> to train, generate predictions, and generate a chained task for a given learner object, respectively, the template does not implement these methods. Instead, the template implements private methods called <code>.train</code>, <code>.predict</code>, and <code>.chain</code>. The specifics of these methods are explained below; however, it is helpful to first understand how the two sets of methods are related. At the risk of complicating things further, it is worth noting that there is actually a third set of methods (<code>learner$base_train</code>, <code>learner$base_predict</code>, and <code>learner$base_chain</code>) of which you may not be aware.</p>
<p>So, what happens when a user calls <code>learner$train</code>? That method generates a <code>delayed</code> object using the <code>delayed_learner_train</code> function and then computes that delayed object. In turn, <code>delayed_learner_train</code> defines a delayed computation that calls <code>base_train</code>, a user-facing function that can be used to train a learner on a task without using the facilities of the <code>delayed</code> package. <code>base_train</code> validates the user input and then calls <code>private$.train</code>. When <code>private$.train</code> returns a <code>fit_object</code>, <code>base_train</code> uses it to generate a learner fit object, which it returns to the user.</p>
<p>Each call to <code>learner$train</code> involves three separate training methods:</p>
<ol style="list-style-type: decimal">
<li>The user-facing <code>learner$train</code> – trains a learner in a manner that can be parallelized using <code>delayed</code>, which calls <code>...</code>
</li>
<li>
<code>...</code> the user-facing <code>learner$base_train</code> that validates user input, and which calls <code>...</code>
</li>
<li>
<code>...</code> the internal <code>private$.train</code>, which does the actual work of fitting the learner and returning the fit object.</li>
</ol>
<p>The logic in the user-facing <code>learner$train</code> and <code>learner$base_train</code> is defined in the <code>Lrnr_base</code> base class and is shared across all learners. As such, these methods need not be reimplemented in individual learners. By contrast, <code>private$.train</code> contains the behavior that is specific to each individual learner and should be reimplemented at the level of each individual learner. Since <code>learner$base_train</code> does not use <code>delayed</code>, it may be helpful to use it when debugging the training code in a new learner. The program flow used for prediction and chaining is analogous.</p>
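<p>One way to see the two training paths side by side is the sketch below, which uses the built-in <code>Lrnr_glm</code> as a stand-in for your own learner. Both paths should yield the same predictions; the <code>base_train</code> path simply avoids <code>delayed</code>, which makes it easier to step through with tools such as <code>browser()</code>.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(sl3)

task <- sl3_Task$new(mtcars, covariates = c("wt", "hp"), outcome = "mpg")
lrnr <- Lrnr_glm$new()

# user-facing entry point: builds and evaluates a delayed computation
fit_delayed <- lrnr$train(task)

# skips delayed; convenient when debugging a learner's private$.train
fit_direct <- lrnr$base_train(task)

preds_delayed <- fit_delayed$predict(task)
preds_direct <- fit_direct$base_predict(task)
</code></pre></div>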
</div>
<div id="private-train" class="section level3">
<h3 class="hasAnchor">
<a href="#private-train" class="anchor"></a><code>private$.train</code>
</h3>
<p>This is the main training function, which takes in a task and returns a <code>fit_object</code> that contains all information needed to generate predictions. The fit object should not contain more data than is absolutely necessary, as including excess information will create needless inefficiencies. Many learner functions (like <code>glm</code>) store one or more copies of their training data – this uses unnecessary memory and will hurt learner performance for large sample sizes. Thus, these copies of the data should be removed from the fit object before it is returned. You may make use of <code>true_obj_size</code> to estimate the size of your <code>fit_object</code>. For most learners, <code>fit_object</code> size should <em>not grow</em> linearly with training sample size. If it does, and this is unexpected, please try to reduce the size of the <code>fit_object</code>.</p>
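<p>The base R sketch below (no <code>sl3</code> required) illustrates the point with <code>glm</code>: the full fit stores per-observation components, so its size grows with the sample size, whereas a trimmed fit object holding only the coefficients stays the same size and still supports prediction. The simulated data are purely illustrative; with <code>sl3</code> loaded, <code>true_obj_size</code> can be used in place of <code>object.size</code>.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">set.seed(123)

# simulate a training set with n observations (illustrative data only)
simulate_data <- function(n) {
  x1 <- rnorm(n)
  x2 <- rnorm(n)
  data.frame(x1 = x1, x2 = x2, y = 1 + 2 * x1 - x2 + rnorm(n))
}

# size in bytes of a full glm fit versus a coefficients-only fit object
fit_sizes <- function(n) {
  df <- simulate_data(n)
  full_fit <- glm(y ~ x1 + x2, data = df)
  slim_fit <- list(coefficients = coef(full_fit))
  c(full = as.numeric(object.size(full_fit)),
    slim = as.numeric(object.size(slim_fit)))
}

sizes <- rbind(n_1000 = fit_sizes(1000), n_10000 = fit_sizes(10000))
print(sizes)

# the trimmed fit object still reproduces predictions on new data
df <- simulate_data(200)
full_fit <- glm(y ~ x1 + x2, data = df)
slim_fit <- list(coefficients = coef(full_fit))
new_x <- cbind(1, as.matrix(df[1:5, c("x1", "x2")]))
pred_slim <- drop(new_x %*% slim_fit$coefficients)
pred_full <- predict(full_fit, newdata = df[1:5, ])
</code></pre></div>
<p>In a learner, the analogue would be for <code>private$.train</code> to return something like <code>slim_fit</code> rather than <code>full_fit</code>, with <code>private$.predict</code> computing predictions from the stored coefficients.</p>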
<p>Most of the time, the learner you are implementing will be fit using a function that already exists elsewhere. We’ve built some tools to facilitate passing parameter values directly to such functions. The <code>private$.train</code> function in the template uses a common pattern: it builds an argument list from the learner’s parameter values and the task data, and then uses <code>call_with_args</code> to call <code>my_ml_fun</code> with that argument list. Learners are not required to use this pattern, but it is helpful in the common case where the learner simply wraps an underlying <code>my_ml_fun</code>.</p>
<p>By default, <code>call_with_args</code> passes only those arguments in the argument list whose names match the formal arguments of the function it is calling. This allows the learner to silently drop irrelevant parameters from the call to <code>my_ml_fun</code>. Some wrapped functions, however, either capture important arguments through dot arguments (<code>...</code>) or pass important arguments through such dot arguments on to a secondary function. Both of these cases can be handled using the <code>other_valid</code> and <code>keep_all</code> options to <code>call_with_args</code>: the former allows you to list additional valid argument names, and the latter disables argument filtering altogether.</p>
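<p>To make the filtering behavior concrete, the sketch below is a simplified, hypothetical re-implementation of the idea behind <code>call_with_args</code>, written in base R; it is not the package’s actual code. The functions <code>call_with_args_sketch</code>, <code>my_ml_fun</code>, <code>inner_fun</code>, and <code>wrapper_fun</code> are all hypothetical. In a real learner you would call <code>call_with_args</code> itself, supplying <code>other_valid</code> or <code>keep_all</code> as needed.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R"># Simplified, hypothetical re-implementation of argument filtering;
# not sl3's actual call_with_args().
call_with_args_sketch <- function(fun, args, other_valid = character(0),
                                  keep_all = FALSE) {
  if (!keep_all) {
    # keep only arguments named in fun's formals, plus any extras allowed
    valid <- c(names(formals(fun)), other_valid)
    args <- args[names(args) %in% valid]
  }
  do.call(fun, args)
}

# hypothetical fitting function with explicit arguments
my_ml_fun <- function(x, y, alpha = 1) {
  list(n = length(y), alpha = alpha)
}
ml_args <- list(x = matrix(rnorm(6), 3), y = c(1, 2, 3), alpha = 0.5,
                param_1 = "default_1")
# param_1 is not a formal of my_ml_fun, so it is dropped
res <- call_with_args_sketch(my_ml_fun, ml_args)

# hypothetical wrapper that forwards beta to a secondary function via ...
inner_fun <- function(x, beta = 0) x + beta
wrapper_fun <- function(x, ...) inner_fun(x, ...)

dropped <- call_with_args_sketch(wrapper_fun, list(x = 1, beta = 2))  # 1
kept <- call_with_args_sketch(wrapper_fun, list(x = 1, beta = 2),
                              other_valid = "beta")                 # 3
all_kept <- call_with_args_sketch(wrapper_fun, list(x = 1, beta = 2),
                                  keep_all = TRUE)                  # 3
</code></pre></div>
<p>Without filtering, <code>my_ml_fun</code> would fail with an “unused argument” error on <code>param_1</code>; with filtering, <code>beta</code> is lost unless it is listed in <code>other_valid</code> or filtering is turned off with <code>keep_all</code>.</p>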
</div>
<div id="private-predict" class="section level3">
<h3 class="hasAnchor">
<a href="#private-predict" class="anchor"></a><code>private$.predict</code>
</h3>
<p>This is the main prediction function, and takes in a task and generates predictions for that task using the <code>fit_object</code>. If those predictions are 1-dimensional, they will be coerced to a vector by <code>base_predict</code>.</p>
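<p>The base R sketch below (no <code>sl3</code> required) shows where such 1-dimensional matrices typically come from: computing predictions by matrix multiplication yields an n-by-1 matrix, which <code>base_predict</code> would flatten into a vector. It treats the outcome as binary and returns predicted probabilities from coefficients stored at training time.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R"># fit a logistic regression and keep only its coefficients
fit <- glm(am ~ wt, data = mtcars, family = binomial())
beta <- coef(fit)

# design matrix for new observations, with an intercept column
new_x <- cbind(1, mtcars$wt[1:4])

# the matrix product is a 4 x 1 matrix of linear predictors
lin_pred <- new_x %*% beta
dim(lin_pred)  # 4 1

# flatten to a plain vector of probabilities, as base_predict would
probs <- as.vector(plogis(lin_pred))
</code></pre></div>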
</div>
<div id="private-chain" class="section level3">
<h3 class="hasAnchor">
<a href="#private-chain" class="anchor"></a><code>private$.chain</code>
</h3>
<p>This is the main chaining function. It takes in a task and generates a chained task (based on the input task) using the given <code>fit_object</code>. If this method is not implemented, your learner will use the default chaining behavior, which is to return a new task where the covariates are defined as your learner’s predictions for the current task.</p>
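<p>The default behavior can be observed with an existing learner. This sketch uses <code>Lrnr_glm</code>, which is assumed not to override <code>private$.chain</code>, so the chained task’s only covariate should be the fit’s predictions while the outcome carries over unchanged.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(sl3)

task <- sl3_Task$new(mtcars, covariates = c("wt", "hp"), outcome = "mpg")
glm_fit <- Lrnr_glm$new()$train(task)

# default chaining: a new task whose covariates are the fit's predictions
chained_task <- glm_fit$chain(task)
chained_task$nodes$covariates
</code></pre></div>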
</div>
<div id="advanced-learners-with-sub-learners" class="section level3">
<h3 class="hasAnchor">
<a href="#advanced-learners-with-sub-learners" class="anchor"></a><em>Advanced</em>: Learners with sub-learners</h3>
<p>Generally speaking, the above sections will be all that’s required for implementing a new learner in the <code>sl3</code> framework. In some cases, it may be desirable to define learners that have “sub-learners” or other learners on which they depend. Examples of such learners are <code>Stack</code>, <code>Pipeline</code>, <code>Lrnr_cv</code>, and <code>Lrnr_sl</code>. In order to parallelize the fitting of these sub-learners using <code>delayed</code>, these learners implement a specialized <code>private$.train_sublearners</code> method that calls <code>delayed_learner_train</code> on their sub-learners, returning a single <code>delayed</code> object that, when evaluated, returns all relevant fit objects from these sub-learners. The result of that call is then passed as a second argument to their <code>private$.train</code> method, which now has the function prototype <code>private$.train(task, trained_sublearners)</code>. Because their sub-learners can be fit in parallel, learners defined in this manner can often be trained considerably faster. The <code>predict</code> and <code>chain</code> methods are not currently parallelized in this way, although this may change in the future.</p>
<p>If, like these learners, your learner depends on sub-learners, you have two options:</p>
<ol style="list-style-type: decimal">
<li>Don’t worry about parallelizing sub-learners. Simply implement <code>private$.train</code> as discussed above, being sure to call <code>sublearner$base_train</code> and not <code>sublearner$train</code>, to avoid nesting calls to <code>delayed</code>, which may result in sub-optimal performance.</li>
<li>Implement <code>private$.train_sublearners(task)</code> and <code>private$.train(task, trained_sublearners)</code> to parallelize sub-learners using <code>delayed</code>. We suggest reviewing the implementations of the <code>Stack</code>, <code>Pipeline</code>, <code>Lrnr_cv</code>, and <code>Lrnr_sl</code> learners to get a better understanding of how to implement parallelized sub-learners.</li>
</ol>
<p>In either case, you should be careful to call <code>sublearner$base_predict</code> and <code>sublearner$base_chain</code>, instead of <code>sublearner$predict</code> and <code>sublearner$chain</code>, except in the context of the <code>private$.train_sublearners</code> function, where you should use <code>delayed_learner_fit_predict</code> and <code>delayed_learner_fit_chain</code>.</p>
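<p>A minimal sketch of the first option follows: a hypothetical learner, <code>Lrnr_average_sketch</code>, that trains each sub-learner with <code>base_train</code> and averages the predictions it obtains with <code>base_predict</code>. It is not part of <code>sl3</code>, and it assumes that <code>args_to_list()</code> stores the <code>learners</code> argument in <code>self$params</code>; <code>Stack</code> and <code>Lrnr_sl</code> remain the reference implementations for parallelized sub-learners.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(R6)
library(sl3)

# Hypothetical learner that averages the predictions of its sub-learners,
# following option 1: sub-learners are trained with base_train, not train.
Lrnr_average_sketch <- R6Class(
  classname = "Lrnr_average_sketch", inherit = Lrnr_base,
  portable = TRUE, class = TRUE,
  public = list(
    initialize = function(learners, ...) {
      params <- args_to_list()
      super$initialize(params = params, ...)
    }
  ),
  private = list(
    .properties = c("continuous"),
    .train = function(task) {
      # base_train avoids nesting delayed computations
      lapply(self$params$learners, function(lrnr) lrnr$base_train(task))
    },
    .predict = function(task = NULL) {
      # one column of predictions per trained sub-learner
      preds <- sapply(self$fit_object, function(fit) fit$base_predict(task))
      rowMeans(preds)
    }
  )
)

task <- sl3_Task$new(mtcars, covariates = c("wt", "hp"), outcome = "mpg")
avg_lrnr <- Lrnr_average_sketch$new(
  learners = list(Lrnr_mean$new(), Lrnr_glm$new())
)
avg_preds <- avg_lrnr$train(task)$predict(task)
</code></pre></div>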
</div>
</div>
<div id="doctest" class="section level2">
<h2 class="hasAnchor">
<a href="#doctest" class="anchor"></a>Documenting and Testing your Learner</h2>
<p>If you want other people to be able to use your learner, you will need to document and provide unit tests for it. The above template has example documentation, written in the <a href="http://r-pkgs.had.co.nz/man.html">roxygen</a> format. Most importantly, you should describe what your learner does, reference any external code it uses, and document any parameters and public methods defined by it.</p>
<p>It’s also important to <a href="http://r-pkgs.had.co.nz/tests.html">test</a> your learner. You should write unit tests to verify that your learner can train and predict on new data, and, if applicable, generate a chained task. It might also be a good idea to use the <code>risk</code> function in <code>sl3</code> to verify your learner’s performance on a sample dataset. That way, if you change your learner and performance drops, you know something may have gone wrong.</p>
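<p>A minimal <code>testthat</code> sketch of such tests follows, with the built-in <code>Lrnr_glm</code> standing in for your learner and <code>mtcars</code> serving as the sample dataset. For the performance check, it computes squared-error risk by hand in place of <code>sl3</code>’s <code>risk</code> function and compares it against a mean-only baseline.</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R">library(sl3)
library(testthat)

# Lrnr_glm stands in for your new learner in this sketch
covars <- c("wt", "hp")
task <- sl3_Task$new(mtcars, covariates = covars, outcome = "mpg")
train_task <- sl3_Task$new(mtcars[1:20, ], covariates = covars, outcome = "mpg")
new_task <- sl3_Task$new(mtcars[21:32, ], covariates = covars, outcome = "mpg")

test_result <- test_that("learner trains, predicts on new data, and chains", {
  fit <- Lrnr_glm$new()$train(train_task)
  new_preds <- fit$predict(new_task)
  expect_length(new_preds, 12)
  expect_true(all(is.finite(new_preds)))

  # chaining should produce a task with one row per input row
  chained <- fit$chain(new_task)
  expect_equal(nrow(chained$X), 12)

  # simple performance check: in-sample squared-error risk should
  # be lower than that of a mean-only baseline on the full task
  full_fit <- Lrnr_glm$new()$train(task)
  risk_glm <- mean((full_fit$predict(task) - task$Y)^2)
  risk_mean <- mean((Lrnr_mean$new()$train(task)$predict(task) - task$Y)^2)
  expect_lt(risk_glm, risk_mean)
})
</code></pre></div>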
</div>
<div id="submitting-your-learner-to-sl3" class="section level2">
<h2 class="hasAnchor">
<a href="#submitting-your-learner-to-sl3" class="anchor"></a>Submitting your Learner to <code>sl3</code>
</h2>
<p>Once you’ve implemented your new learner (and made sure that it has quality documentation and unit tests), please consider adding it to the <code>sl3</code> project. This will make it possible for other <code>sl3</code> users to use and build on your work. Make sure to add any R packages listed in <code>.required_packages</code> to the <code>Suggests:</code> field of the <code>DESCRIPTION</code> file of the <code>sl3</code> package. Once this is done, please submit a <strong>Pull Request</strong> to the <code>sl3</code> package <a href="https://github.com/tlverse/sl3">on GitHub</a> to request that your learner be added. If you’ve never made a pull request before, see this helpful guide: <a href="https://yangsu.github.io/pull-request-tutorial/" class="uri">https://yangsu.github.io/pull-request-tutorial/</a>.</p>
<p>From the <code>sl3</code> team, thanks for your interest in extending <code>sl3</code>!</p>
</div>
</div>
<div class="col-md-3 hidden-xs hidden-sm" id="pkgdown-sidebar">
<nav id="toc" data-toggle="toc"><h2 data-toc-skip>Contents</h2>
</nav>
</div>
</div>
<footer><div class="copyright">
<p>Developed by Jeremy Coyle, Nima Hejazi, Oleg Sofrygin, Ivana Malenica, Rachael Phillips.</p>
</div>
<div class="pkgdown">
<p>Site built with <a href="https://pkgdown.r-lib.org/">pkgdown</a> 1.6.1.</p>
</div>
</footer>
</div>
</body>
</html>