Support TensorModule of distinct input/output types #34

Open

davoclavo opened this issue Jun 27, 2023 · 2 comments
Labels: design (API ergonomics and design)

Comments


davoclavo commented Jun 27, 2023

Currently `TensorModule` is parameterized on a single type, so it keeps the transformation within the same `DType`:

```scala
trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):
  override def toString(): String = "TensorModule"
```

However, there are modules where the input type might be different from the output type, such as `nn.Embedding`, which accepts `Int`s or `Long`s as input indices, while the output could be any `DType`. So the trait would need to be parameterized like this:

```scala
trait TensorModule[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]):
  override def toString(): String = "TensorModule"
```

and the example implementation would be something like:

```scala
final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
    numEmbeddings: Int,
    ...
) extends ...
    with TensorModule[IntNN, ParamType]:

  def apply(t: Tensor[IntNN]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
```
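
For illustration, a hypothetical call site, assuming the definitions above (the argument values and `???` placeholders are mine, and `Float32` is assumed to have a `Default` instance):

```scala
// Hypothetical usage of the two-parameter module sketched above:
val embedding = Embedding[Float32](numEmbeddings = 1000, ...)
val indices: Tensor[IntNN] = ???                  // e.g. token ids
val vectors: Tensor[Float32] = embedding(indices) // IntNN in, Float32 out
```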

This is doable. However, there are useful combinators on `TensorModule`, such as `nn.Sequential`, which expects a sequence of modules to chain. With a single type parameter the compile-time validation is straightforward, but with distinct input/output types it gets a bit more complex to validate at compile time.
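
To make the difficulty concrete, here is a simplified sketch (a hypothetical `Sequential`, not the current storch one, assuming `Module` has no further abstract members): with matching input/output dtypes a plain sequence folds cleanly, while with distinct dtypes a plain sequence forgets how adjacent modules line up.

```scala
// Homogeneous case: every module maps Tensor[D] => Tensor[D], so a plain
// varargs sequence can be folded and the whole chain checks at compile time.
class Sequential[D <: DType](modules: TensorModule[D, D]*)
    extends TensorModule[D, D]:
  def apply(t: Tensor[D]): Tensor[D] = modules.foldLeft(t)((acc, m) => m(acc))

// Heterogeneous case: a Seq[TensorModule[?, ?]] loses the information that
// each module's output dtype matches the next one's input dtype. Pairwise
// composition still works, because modules are functions, but the result is
// a plain function rather than a Module:
val embed: TensorModule[IntNN, Float32] = ???
val dense: TensorModule[Float32, Float32] = ???
val model: Tensor[IntNN] => Tensor[Float32] = embed.andThen(dense)
```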

I will do some research on how to solve this.

Any pointers or ideas are more than welcome!


davoclavo commented Jun 27, 2023

After doing a bit of research, it seems like the `nn.Embedding` module is an exception to the norm. There may be other types of PyTorch modules that change dtype from input to output, but I am not yet aware of them (perhaps `nn.Upsample`, but in practice I couldn't get it to change types).

So, a couple of thoughts:

  1. An additional `TensorModuleDistinct[D <: DType, D2 <: DType]` must be added if we want to have the `Embedding` module

  2. Regarding `nn.Sequential` - here are the options I can think of:
    a) It would have to accept a Tuple/HList of tensors in order to have all the proper type validations if we wanted to also accept these modules
    b) Adding an extra optional parameter which allows for an optional initial module that can do type transformations (see the sketch after this list)
    c) Ignore this for now and let the user know that they should handle Embedding modules as an extra step when defining their model's layer structure
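
A rough sketch of what b) might look like (names hypothetical, assuming the two-parameter trait from above):

```scala
// Option b) as a sketch: a Sequential with one leading module that may change
// the dtype (e.g. an Embedding going from integer indices to floats), followed
// by a homogeneous tail that is validated exactly like today's Sequential.
class SequentialWithInit[In <: DType, D <: DType](
    init: TensorModule[In, D],
    rest: TensorModule[D, D]*
) extends TensorModule[In, D]:
  def apply(t: Tensor[In]): Tensor[D] =
    rest.foldLeft(init(t))((acc, m) => m(acc))
```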


sbrunk commented Jun 28, 2023

> After doing a bit of research, it seems like the `nn.Embedding` module is an exception to the norm. There may be other types of PyTorch modules that change dtype from input to output, but I am not yet aware of them (perhaps `nn.Upsample`, but in practice I couldn't get it to change types).

That's good to know. One thing to keep in mind is that the included modules are mostly basic building blocks, and I'm not sure what input/output structures and types look like in more complex custom modules. Another good reason to implement a few more architectures like transformers, to get a better feel for it. :)

> 1. An additional `TensorModuleDistinct[D <: DType, D2 <: DType]` must be added if we want to have the `Embedding` module

👍

> 2. Regarding `nn.Sequential` - here are the options I can think of:
>    a) It would have to accept a Tuple/HList of tensors in order to have all the proper type validations if we wanted to also accept these modules
>    b) Adding an extra optional parameter which allows for an optional initial module that can do type transformations
>    c) Ignore this for now and let the user know that they should handle Embedding modules as an extra step when defining their model's layer structure

It would be interesting to see if it's possible to create an easy-to-use API for a).
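
One conceivable shape for a), as a sketch rather than a concrete API proposal (all names hypothetical): encode the chain as a typed cons-list, so the compiler checks that adjacent dtypes line up.

```scala
// A heterogeneous, type-checked chain of modules: each Cons link requires the
// head's output dtype to equal the tail's input dtype.
sealed trait Chain[In <: DType, Out <: DType] extends (Tensor[In] => Tensor[Out])

final case class Last[In <: DType, Out <: DType](m: TensorModule[In, Out])
    extends Chain[In, Out]:
  def apply(t: Tensor[In]): Tensor[Out] = m(t)

final case class Cons[In <: DType, Mid <: DType, Out <: DType](
    head: TensorModule[In, Mid],
    tail: Chain[Mid, Out]
) extends Chain[In, Out]:
  def apply(t: Tensor[In]): Tensor[Out] = tail(head(t))

// Cons(embedding, Last(dense)) compiles only if embedding's output dtype
// matches dense's input dtype; a Tuple-based API could desugar to this.
```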

One general thought regarding the parameter types of modules is that we need to consider their recursive structure and mutability. Here's what I said in another discussion about that:

> In PyTorch, modules (which are also mutable) let you convert the parameter types of all submodules recursively. Now if we have a `val myModule: MyModule[Float32]` and call `myModule.to(dtype=torch.float16)`, we break type-safety (I think).

So perhaps we need to provide an immutable module API and make sure to copy the module and all its submodules recursively to make this safe, or perhaps you have another idea how to deal with it.
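
For the immutable direction, a minimal sketch (an assumed API, not what storch currently has; the trait name and dtype value parameter are hypothetical) where `to` returns a converted copy instead of mutating in place:

```scala
// Immutable conversion: `to` returns a new module (recursively copying
// submodules) typed with the target dtype, so the static type can never get
// out of sync with the runtime dtype. The original module stays valid.
trait ImmutableModule[D <: DType]:
  def to[D2 <: FloatNN](dtype: D2): ImmutableModule[D2]

def halve(m: ImmutableModule[Float32], float16: Float16): ImmutableModule[Float16] =
  m.to(float16) // m is unchanged and still usable as a Float32 module
```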

sbrunk added the design (API ergonomics and design) label Jul 30, 2023