Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph subnetwork: Multiple outputs #515

Merged
merged 5 commits into from
Apr 7, 2020
Merged

Conversation

mreppen
Copy link
Contributor

@mreppen mreppen commented Apr 6, 2020

Hi, I finally got around to adding this small change. This was discussed in #491 I personally don't need it, but for the sake of completeness, here it is.

Example:

let nn = Graph.(
  let inp = input [| 2 |] in
  let x1 = lambda ~name:"x1" ~out_shape:[|1|] (fun x -> Algodiff.Maths.get_slice [[]; [0]] x) inp in
  let x2 = lambda ~name:"x2" ~out_shape:[|1|] (fun x -> Algodiff.Maths.get_slice [[]; [1]] x) inp in
  let f1 = fully_connected ~name:"f1" 1 x1 in
  let f2 = fully_connected ~name:"f2" 1 x2 in
  let sum = add ~name:"sum" [| f1; f2 |] in
  sum |> get_network);;
Graph.get_subnetwork ~make_inputs:[|"x1"; "x2"|] nn [| "f1"; "f2" |];;

Two comments:

  • I also added a flag for whether to copy the weights.
  • I did change the signature (likely not many are using this yet, so hopefully not a problem). This also allowed me to change the default copy behavior to true. If this is undesirable, d684020 is a version without and I can PR that instead.

PS. I have myself found this subnetwork feature really useful :)

Copy link
Member

@mseri mseri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a minor comment, based on personal preference, otherwise it looks good to me. Please wait some other review before making any change.

the old one.
val get_subnetwork : ?copy:bool -> ?make_inputs:string array -> network -> string array -> network
(**
Constructs a subnetwork of nodes on which ``output_names`` depend,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be useful to name the options here:

``get_subnetwork ?copy ?make_inputs network output_names`` constructs a subnetwork of nodes of ``network`` ...

@mreppen
Copy link
Contributor Author

mreppen commented Apr 6, 2020

@mseri's comment makes sense to me as well. I have a commit ready to push.

Another point I forgot to bring up: Would it perhaps make sense to drop or rename the get_-part of the function name? After all, it is not a getter function. Either plain subnetwork or something like make_subnetwork?

@mseri
Copy link
Member

mseri commented Apr 6, 2020

Sounds reasonable to me to rename it, this could allow to keep backward compatibility although I don't see the need for this. I think we should wait to see what @jzstark @ryanrhymes or @tachukao think about it.

@tachukao
Copy link
Member

tachukao commented Apr 6, 2020

agree with the name change, looks good to me!

@ryanrhymes ryanrhymes added enhancement R&D Core research and development labels Apr 6, 2020
@ryanrhymes
Copy link
Member

Looks very good to me. Both make_subnetwork and subnetwork are fine by me.

@mreppen
Copy link
Contributor Author

mreppen commented Apr 6, 2020

@ryanrhymes, I picked make_subnetwork. @mseri's doc comment is also in the latest commit.

@ryanrhymes ryanrhymes self-requested a review April 7, 2020 18:04
@ryanrhymes ryanrhymes merged commit 45d9b80 into owlbarn:master Apr 7, 2020
mseri added a commit to mseri/opam-repository that referenced this pull request Oct 4, 2020
CHANGES:

* various documentation improvements (thanks @pveber, @UnixJunkie, @Fourchaux)
* Fix use of access operators (owlbarn/owl#543)
* Upgrade to ocamlformat 0.15.0 (thanks @gpetiot owlbarn/owl#535)
* keep_dims option (owlbarn/owl#531)
* stats: fix infinite loop in ecdf
* Use Fun.protect to ensure all file descriptors are being closed
* owl_ndarray_maths: improve user experience in case of errors
* owl_io: close file descriptors also in case of errors
* owl_dense_ndarray_generic: fix error on printing 0-ary arrays
* fixed bug in sub forward mode (owlbarn/owl#533)
* Add stack to Algodiff (owlbarn/owl#528)
* added log_sum_exp to Ndarray and Algodiff (owlbarn/owl#527)
* added single-precision and double-precision Bessel functions to Ndarray  (owlbarn/owl#526)
* Fixes owlbarn/owl#518 by introducing another `/` to resolve data directory (@jotterbach owlbarn/owl#519)
* Graph Slice node (resolves owlbarn/owl#483) (@mreppen owlbarn/owl#517)
* Graph subnetwork: Multiple outputs (@mreppen owlbarn/owl#515)
* Added kron and swap to Algodiff operations (owlbarn/owl#512)
* various other small fixes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement R&D Core research and development
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants