Refactor linear machine api #4319
@@ -40,59 +32,51 @@ void CAveragedPerceptron::init() | |||
SG_ADD(&learn_rate, "learn_rate", "Learning rate.", MS_AVAILABLE); | |||
} | |||
|
|||
bool CAveragedPerceptron::train_machine(CFeatures* data) | |||
void CAveragedPerceptron::train_machine(CFeatures* features, CLabels* labels) |
Don't we want these to be const parameters?
of course. unfortunately we can't use const for the time being because there are many non-const methods that should be const logically, e.g. get_feature_matrix
I see, sure
(though get_feature_matrix can be made const without problem)
{ | ||
output[i] = features->dense_dot(i, w.vector, w.vlen) + bias; | ||
output[i] = dot_features->dense_dot(i, w.vector, w.vlen) + bias; |
i think this could be done via DotIterator
bias=tmp_bias/(num_vec*iter); | ||
|
||
SG_FREE(output); | ||
SG_FREE(tmp_w); | ||
|
||
set_w(w); |
I would prefer to remove this and update the state vector itself inside the main loop.
@shubham808 can elaborate as he is doing similar stuff for the perceptron
This comment is from a long time ago, but basically we want to start updating model states inside the loops for iterative machines.
@shubham808 comments?
@karlnapf I think set_w has already been called inside the iteration function?
You are right, sorry.
src/shogun/classifier/Perceptron.cpp
Outdated
{ | ||
ASSERT(m_labels) | ||
if (!features->has_property(FP_DOT)) |
Why not use as() here? It would be a bit cleaner. Training is expensive, so the additional cost shouldn't matter, or?
We can provide a more informative error message here, but as() is just as good.
src/shogun/machine/Machine.cpp
Outdated
@@ -66,6 +66,23 @@ bool CMachine::train(CFeatures* data) | |||
return result; | |||
} | |||
|
|||
void CMachine::fit(CFeatures* features) | |||
{ | |||
REQUIRE(train(features), "Failed to fit machine %s\n", get_name()); |
maybe we should throw an exception rather than this boolean stuff?
require failed -> sg_io->msg(MSG_ERROR, "blabla") -> https://github.com/shogun-toolbox/shogun/blob/develop/src/shogun/io/SGIO.cpp#L125
Sure but I mean, shouldn't the train method itself just throw an exception?
We don't get any context information here. If train would throw an exception, we could catch it here and then say "training failed for reason X". No?
I'm going to just throw exceptions in fit. For now train still returns a boolean, so I keep this check so that fit consistently throws an exception on failure.
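The approach discussed above (a boolean `train` wrapped by a `fit` that always throws, with room for `train` to supply its own context) could look roughly like the following. This is a minimal sketch, not Shogun code: `Machine`, `FailingMachine`, `name()` and `fit_error_message` are hypothetical stand-ins, and standard exceptions are used in place of Shogun's own error reporting.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a machine base class (not Shogun's CMachine).
class Machine
{
public:
    virtual ~Machine() = default;

    // New-style entry point: no status code, failure is always an exception.
    void fit()
    {
        try
        {
            // Old-style hooks may still signal failure with a boolean.
            if (!train())
                throw std::runtime_error("train() returned false");
        }
        catch (const std::exception& e)
        {
            // Add the machine name so the caller sees which model failed and why.
            throw std::runtime_error(
                "Failed to fit machine " + name() + ": " + e.what());
        }
    }

protected:
    virtual bool train() = 0;
    virtual std::string name() const = 0;
};

// A machine whose training throws with a specific reason.
class FailingMachine : public Machine
{
protected:
    bool train() override
    {
        throw std::invalid_argument("labels are missing");
    }
    std::string name() const override { return "FailingMachine"; }
};

// Helper for demonstration: returns the error message, or "" on success.
std::string fit_error_message(Machine& machine)
{
    try
    {
        machine.fit();
    }
    catch (const std::runtime_error& e)
    {
        return e.what();
    }
    return "";
}
```

With this shape, a `train` that only returns `false` still produces an exception from `fit`, while a `train` that throws gets its reason carried into the final message.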
b7d2ff9
to
f211b65
Compare
f211b65
to
0fa4faa
Compare
550639e
to
af8af72
Compare
a241ab1
to
dbfd69e
Compare
Hi! |
@karlnapf I found there are too many things involved here. When you change the |
a2078d9
to
fb4930c
Compare
|
||
ASSERT(m_labels) | ||
init_linear_term(); | ||
void CLibLinear::train_machine(CFeatures* features, CLabels* labels) |
const possible?
There are many non-const internal algorithm methods. We can either use non-const arguments here, or use const_cast to drop const in the internal methods.
@karlnapf One problem with ref counting is IterativeMachine. We need to increase the ref count of features and labels in init_model.
ah crap yes, the ref counting is non-const....
ok so we need to postpone the const making for now until we have another way to do the reference counting
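For illustration only, here is one way reference counting can coexist with const pointers: the counter is declared `mutable`, so `ref()` and `unref()` can be const methods. This is a sketch under that assumption and is not how Shogun's `CSGObject` works; `RefCounted`, `Features` and `IterativeModel` are hypothetical names, and deletion at zero is left out to keep the example short.

```cpp
#include <atomic>

// Hypothetical ref-counted base class (Shogun's CSGObject differs).
class RefCounted
{
public:
    virtual ~RefCounted() = default;

    // Both are const: the count is bookkeeping, not logical object state.
    int ref() const { return ++m_count; }
    int unref() const { return --m_count; } // a real version would delete at 0
    int ref_count() const { return m_count.load(); }

private:
    mutable std::atomic<int> m_count{0};
};

class Features : public RefCounted
{
};

// Hypothetical iterative model that keeps the training data between steps.
class IterativeModel
{
public:
    ~IterativeModel()
    {
        if (m_features)
            m_features->unref();
    }

    // Accepts a const pointer and still takes a reference on it.
    void init_model(const Features* features)
    {
        features->ref();
        if (m_features)
            m_features->unref();
        m_features = features;
    }

private:
    const Features* m_features = nullptr;
};
```

The trade-off is that `mutable` moves the const question into the base class: callers get const-correct signatures, but the counter itself must be safe to modify through const access, which is why it is atomic here.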
Nice! Let me know when we should have a look at this... |
I like the idea of porting old methods to the new nicer api using const casts. Although we might run into trouble doing that... @lisitsyn @iglesias @vigsterkr ?
|
||
virtual void CLPBoost::train_machine(CFeatures* features, CLabels* labels); | ||
{ | ||
ASSERT(labels) |
I think handling those asserts happens (should happen?) in the base class?
Agree, it should be in the base class. I will check if it is possible for now.
@@ -567,6 +567,20 @@ class CSGObject | |||
demangled_type<T>().c_str()); | |||
return nullptr; | |||
} | |||
|
|||
template <class T> const T* as() const |
you are using this below I assume when changing the signatures of some methods?
This is useful for const pointers (const CFeatures*). Although we decided to move to const methods later, I leave it here for possible future usage.
Cool some progress :) Let us know how you are getting on here |
eefbcf8
to
7caeb67
Compare
@@ -24,4 +18,3 @@ CLeastSquaresRegression::CLeastSquaresRegression(CDenseFeatures<float64_t>* data | |||
: CLinearRidgeRegression(0, data, lab) |
@karlnapf Got some linking issue
Undefined symbols for architecture x86_64:
2019-01-07T16:59:36.8868140Z "bool shogun::CLinearRidgeRegression::train_machine_templated<double>(shogun::CDenseFeatures<double> const*)", referenced from:
2019-01-07T16:59:36.8911800Z shogun::CDenseRealDispatch<shogun::CLinearRidgeRegression, shogun::CLinearMachine>::train_dense(shogun::CFeatures*) in LeastSquaresRegression.cpp.o
2019-01-07T16:59:36.9066190Z "bool shogun::CLinearRidgeRegression::train_machine_templated<long double>(shogun::CDenseFeatures<long double> const*)", referenced from:
2019-01-07T16:59:36.9110490Z shogun::CDenseRealDispatch<shogun::CLinearRidgeRegression, shogun::CLinearMachine>::train_dense(shogun::CFeatures*) in LeastSquaresRegression.cpp.o
2019-01-07T16:59:36.9263540Z "bool shogun::CLinearRidgeRegression::train_machine_templated<float>(shogun::CDenseFeatures<float> const*)", referenced from:
2019-01-07T16:59:36.9306660Z shogun::CDenseRealDispatch<shogun::CLinearRidgeRegression, shogun::CLinearMachine>::train_dense(shogun::CFeatures*) in LeastSquaresRegression.cpp.o
CLinearRidgeRegression::train_machine_templated is a template method; its definition is in LinearRidgeRegression.cpp and is not visible in this cpp file.
@karlnapf I only saw this error on CI. I cannot reproduce this locally with docker image.
Yes I remember this issue somehow with the same reproducibility issues. Maybe @Saurabh7 remembers
I assume this works now?
Yes, we need to explicitly instantiate the templates in the cpp file to make them linkable.
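The fix mentioned here, explicit instantiation of a member function template whose definition lives in a .cpp file, can be sketched as follows. The names `Ridge` and `DenseFeatures` are simplified, hypothetical stand-ins for the Shogun classes in the linker log, and the header and source file are shown in one listing for brevity.

```cpp
// ---- ridge.h (hypothetical header): declarations only ----
template <typename T>
struct DenseFeatures
{
    T value;
};

class Ridge
{
public:
    // Declared here, defined in the .cpp file below.
    template <typename T>
    bool train_machine_templated(const DenseFeatures<T>* feats);
};

// ---- ridge.cpp (hypothetical source): definition + instantiations ----
template <typename T>
bool Ridge::train_machine_templated(const DenseFeatures<T>* feats)
{
    // Placeholder body: a real implementation would solve the regression.
    return feats != nullptr;
}

// Explicit instantiations. Other translation units only see the
// declaration, so without these lines the symbols are never emitted and
// the linker reports them as undefined.
template bool
Ridge::train_machine_templated<float>(const DenseFeatures<float>*);
template bool
Ridge::train_machine_templated<double>(const DenseFeatures<double>*);
template bool
Ridge::train_machine_templated<long double>(const DenseFeatures<long double>*);
```

The three instantiated types mirror the three undefined symbols in the log above (`float`, `double`, `long double`).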
Looks like this is solved.
Very interesting PR btw! So we are planning to change all APIs to .fit, .predict?
Sorry I actually meant to ping @shubham808 ;)
But yes @Saurabh7 we are intending to do that
779b11a
to
29d3bb6
Compare
|
||
SG_SERROR( | ||
"Object of type %s cannot be converted to type %s.\n", | ||
demangled_type<std::remove_pointer_t<decltype(this)>>().c_str(), |
btw can we re-use code from the non-const version in here? Like one calls the other?
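One common way to share the code between the two overloads is to implement the const version once and have the non-const version call it, casting the result back. The sketch below assumes a hypothetical `Object` base class and uses `std::runtime_error` where Shogun would use its own error macro; it is not the actual `CSGObject::as` implementation.

```cpp
#include <stdexcept>

// Hypothetical base class standing in for CSGObject.
class Object
{
public:
    virtual ~Object() = default;

    // Const overload: holds the single implementation of the checked cast.
    template <class T>
    const T* as() const
    {
        const T* result = dynamic_cast<const T*>(this);
        if (!result)
            throw std::runtime_error(
                "Object cannot be converted to the requested type");
        return result;
    }

    // Non-const overload: reuses the const one. Casting const away is safe
    // here because *this is known to be non-const in this overload.
    template <class T>
    T* as()
    {
        return const_cast<T*>(static_cast<const Object*>(this)->as<T>());
    }
};

struct Features : Object
{
    int dim = 3;
};

struct Labels : Object
{
};
```

Delegating in this direction (non-const calls const) keeps the `const_cast` sound; the reverse direction would call a non-const method on an object that may really be const.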
372b720
to
0cd2e63
Compare
0cd2e63
to
f58cade
Compare
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
keeping this alive as I think this is still the direction we want to go in |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
bump |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
This issue is now being closed due to a lack of activity. Feel free to reopen it. |
[WIP]
Adapt machine api to fit + predict
This needs some heavy refactoring, so let's start with linear machines and iterate a few times.
The idea is to keep both the old and new apis working and then remove the old apis gradually.
train_machine(CFeatures*, CLabels*) is added to machines that need labels. The old method, train_machine(CFeatures*), will redirect to the new api and pass m_labels as the labels argument. In this way, the old api (set_labels + train(CFeatures*)) still works.
Roadmap
- void fit(CFeatures*), void fit(CFeatures*, CLabels*) to CMachine
- bool train_machine(CFeatures*) to void train_machine(CFeatures*, CLabels*) in LinearMachine
- LinearMachine as a base class method after the new api works for all LinearMachine subclasses
- bool train_machine(CFeatures*) and bool train_machine_templated(CFeatures*) return void (the latter should be easier, as it was added last year and we haven't used it in many places)
Known issues for moving to const methods:
- regression_labels) doesn't accept const args
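The redirect described in this PR description (the old one-argument `train_machine` forwarding to the new two-argument overload with `m_labels`) can be sketched as below. This is a simplified illustration with hypothetical stand-in types (`Features`, `Labels`, `LinearMachine`, `CountingMachine`), not the Shogun classes; the argument check in `fit` also shows where the shared asserts raised in review could live in the base class.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for CFeatures and CLabels.
struct Features
{
};
struct Labels
{
};

class LinearMachine
{
public:
    virtual ~LinearMachine() = default;

    // Old api: set_labels(...) followed by train(features).
    void set_labels(Labels* labels) { m_labels = labels; }
    bool train(Features* features) { return train_machine(features); }

    // New api: labels are passed explicitly and failure is an exception.
    void fit(Features* features, Labels* labels)
    {
        // Shared argument checks live in the base class, not in subclasses.
        if (!features || !labels)
            throw std::invalid_argument("features and labels are required");
        train_machine(features, labels);
    }

protected:
    // Old hook: redirects to the new one, passing the stored labels.
    virtual bool train_machine(Features* features)
    {
        if (!features || !m_labels)
            return false;
        train_machine(features, m_labels);
        return true;
    }

    // New hook: the only method a subclass has to implement.
    virtual void train_machine(Features* features, Labels* labels) = 0;

    Labels* m_labels = nullptr;
};

// Toy subclass that only counts how often it was trained.
class CountingMachine : public LinearMachine
{
public:
    int trained = 0;

protected:
    using LinearMachine::train_machine; // keep the old overload visible
    void train_machine(Features*, Labels*) override { ++trained; }
};
```

Both entry points end up in the same subclass method, which is what lets the old api keep working while subclasses are ported one at a time.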