
ENH: change how Model is declared to remove sub-classing of vak.engine.Model #536

Closed
NickleDave opened this issue Jul 8, 2022 · 9 comments

@NickleDave
Collaborator

NickleDave commented Jul 8, 2022

edit: I think I have a better way to do this, see below

To decouple Model from Engine and address #362, add a Model attrs class with a from_config classmethod.

This will be a kind of "interface": if you want to be a Model, have a from_config classmethod that results in network, loss, optimizer, and metrics attributes.

This way the user doesn't have to actually subclass this Model attrs class.
They can use the built-in attrs class if they want, or they can make some other dataclass that can be used with the cli, as long as it obeys the "interface".
vak will "know" to instantiate the model using from_config.
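
A minimal sketch of what such an "interface" could look like, using only the standard library. The names ModelLike, ToyModel, and build are hypothetical, and the tuples stand in for real network, loss, and optimizer objects; this is not vak's API, just one way to express "has a from_config classmethod and these four attributes" without sub-classing.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ModelLike(Protocol):
    """Hypothetical interface: four attributes plus a from_config classmethod."""
    network: Any
    loss: Any
    optimizer: Any
    metrics: dict

    @classmethod
    def from_config(cls, config: dict) -> "ModelLike": ...


@dataclass
class ToyModel:
    """A user-defined class; note that it does not subclass ModelLike."""
    network: Any
    loss: Any
    optimizer: Any
    metrics: dict

    @classmethod
    def from_config(cls, config: dict) -> "ToyModel":
        # tuples are stand-ins for real network / loss / optimizer instances
        return cls(
            network=("net", config["network"]),
            loss=("loss", config["loss"]),
            optimizer=("opt", config["optimizer"]),
            metrics={"acc": None},
        )


def build(model_class, config: dict):
    """What a cli could do: check that from_config exists, then call it."""
    if not callable(getattr(model_class, "from_config", None)):
        raise TypeError(f"{model_class.__name__} has no from_config classmethod")
    return model_class.from_config(config)


toy = build(ToyModel, {"network": {"num_classes": 10}, "loss": {}, "optimizer": {"lr": 1e-3}})
print(isinstance(toy, ModelLike))
```

The isinstance check only verifies that the attributes exist on the instance; it says nothing about their types.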

@NickleDave NickleDave added the ENH: enhancement enhancement; new feature or request label Jul 8, 2022
@NickleDave NickleDave self-assigned this Jul 8, 2022
@NickleDave NickleDave added this to To do in ENH Jul 9, 2022
@NickleDave
Collaborator Author

Goal of this is to make it easier to instantiate a model, and fix #362 (linking here so I can close that one)

@NickleDave
Collaborator Author

The one thing this still does not give us is a __call__ dunder-method, in the same way one can do with a torch.nn.Module though :(

Would be really nice to be able to say y = network(x) and then easily visualize both x and y, e.g. assuming one is a spectrogram and the other is predicted segments

@NickleDave NickleDave changed the title ENH: Make Model just an attrs class with a from_config classmethod ENH: change how Model is declared to remove sub-classing of vak.engine.Model Nov 23, 2022
@NickleDave
Collaborator Author

picking this up again

what we want:

  • a way to get a Model instance for any model we declare where we can do two things:
    • just get the network output, e.g. for training
    • get an output with an optional post-processing transform applied, e.g. for prediction

how to implement it:

  • make a Model class with two methods:
    • forward, which just passes a tensor into Model.net and returns the output tensor, the same as one would do with a "raw" torch.nn.Module
    • the __call__ method, which internally calls Model.forward and then passes that output through a post-processing transform if one is specified, which will be something like torchvision.transforms.Compose
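
The two methods above can be sketched without torch as follows. The compose helper is a hypothetical stand-in for something like torchvision.transforms.Compose, and the toy network and transform are plain functions on lists, chosen only to make the difference between forward and __call__ visible.

```python
from typing import Callable, Optional


def compose(*funcs: Callable) -> Callable:
    """Chain callables left to right (stand-in for a Compose-style transform)."""
    def composed(x):
        for func in funcs:
            x = func(x)
        return x
    return composed


class Model:
    def __init__(self, network: Callable, post_tfm: Optional[Callable] = None):
        self.network = network
        self.post_tfm = post_tfm

    def forward(self, x):
        # raw network output, e.g. what a loss function would consume in training
        return self.network(x)

    def __call__(self, x):
        # network output, then the optional post-processing, e.g. for prediction
        out = self.forward(x)
        if self.post_tfm is not None:
            out = self.post_tfm(out)
        return out


# toy "network": doubles every value
network = lambda xs: [2 * x for x in xs]
# toy post-processing: threshold each value, then convert to a tuple
post_tfm = compose(lambda ys: [int(y > 2) for y in ys], tuple)

model = Model(network, post_tfm=post_tfm)
raw = model.forward([1, 2, 3])
processed = model([1, 2, 3])
print(raw, processed)
```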

Additionally, implement a decorator that accepts a user-specified model MyModel, in the form of a class with a required set of class variables that correspond to the attributes expected by both vak.Model and vak.Engine. The decorator returns a new class that holds the user-specified MyModel as a class attribute, which it uses when making a new instance of Model.

this starts to get pretty meta, I know (is this meta-programming yet?); I will post a code snippet next to illustrate a little better

@NickleDave
Collaborator Author

Here's a code snippet that (I think) illustrates what I have in mind:

def model(model):
    """a decorator that creates a model"""

    class Model:
        # keep a reference to the decorated class that declares the model
        _model = model

        def __init__(self,
                     network,
                     loss,
                     optimizer,
                     metrics,
                     post_tfm=None):
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config):
            # a classmethod has no self, so look up the declaration on cls
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)
    
    return Model

@vak.model
class TweetyNetModel:
    # class variables are the classes themselves, not instances;
    # from_config instantiates them with the config
    network = TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

@NickleDave
Collaborator Author

Unfinished business:

  • I want to be able to do the following: get a new model with defaults, so I don't need to pass in any config
VakMyModel = vak.model(MyModel)  # using the decorator directly as a function
VakMyModel()  # I don't have to pass in any config or anything, just make a new instance

We could do this by setting the default values for the Model.__init__ function to be instances of the user class attributes, I think? Or do I need to somehow override __new__ here? This is getting to be way meta

@NickleDave
Collaborator Author

I think I have a very minimal MVP of this working:

import functools

import vak
import torch
import tweetynet

def model(model):
    """a decorator that creates a model"""

    @functools.wraps(model, updated=()) 
    class Model:
        _model = model

        def __init__(self,
                     network=None, 
                     loss=None,
                     optimizer=None,
                     metrics=None,
                     post_tfm=None):
            if network is None:
                network = self._model.network()
            if loss is None:
                loss = self._model.loss()
            if optimizer is None:
                optimizer = self._model.optimizer(params=network.parameters())
            if metrics is None:
                metrics = {metric_name: metric_class()
                           for metric_name, metric_class in self._model.metrics.items()}

            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config: dict):
            # a classmethod has no self, so look up the declaration on cls
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm? 
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)
    
    return Model

@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

>>> tweety = TweetyNetModel(network=tweetynet.TweetyNet(num_classes=10))

>>> type(tweety)
__main__.TweetyNetModel

>>> tweety.__doc__
'Model that uses TweetyNet architecture'
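
The reason type(tweety) and tweety.__doc__ report the decorated class's name and docstring is functools.wraps with updated=(). A small torch-free sketch of just that mechanism follows; the Plain class and the default_works flag are hypothetical names added for illustration. With the default updated=('__dict__',), wraps tries to call update on the class __dict__, which is a read-only mappingproxy, so it raises AttributeError.

```python
import functools


def model(definition):
    """Toy decorator: return a new class that carries the decorated class's metadata."""

    # updated=() skips the __dict__ update step, which cannot work on a class
    @functools.wraps(definition, updated=())
    class Model:
        _model = definition

    return Model


@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""


print(TweetyNetModel.__name__)
print(TweetyNetModel.__doc__)
# wraps also records the original class as __wrapped__
print(TweetyNetModel._model is TweetyNetModel.__wrapped__)


class Plain:
    pass


# with the default `updated`, wrapping a class fails
try:
    functools.wraps(TweetyNetModel)(Plain)
    default_works = True
except AttributeError:
    default_works = False
print(default_works)
```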

@NickleDave NickleDave moved this from To do to In progress in ENH Dec 18, 2022
@NickleDave
Collaborator Author

NickleDave commented Dec 24, 2022

Feature branch is already in progress, but documenting here how I ended up implementing this.
Basically, as follows:

  • Have an attrs-like or dataclasses-like ModelDefinition that specifies a model's network, loss function, optimizer, and metrics as class variables, as in add a ModelDefinition class #406
  • Have a base Model class that sub-classes the lightning.LightningModule but additionally has a definition attribute, which is a ModelDefinition class (the whole class! not just an instance).
    • this Model does two things
      • it either accepts instances of the classes on the definition, i.e. an instance of the network that the definition specified, or if no instance is passed in, it makes a default version of the instance
      • it alternatively accepts a config or a path to a config that loads up and instantiates the classes from the definition using the config, i.e. through a class method
  • finally, specific families of models will further sub-class the Model class, and implement their own logic for the lightning.LightningModule methods like training_step, validation_step, predict_step, etc.
  • each of these will have a decorator that does the following
    • make a new subclass of the model family class that has the same name as the model definition, and has the model definition's class set as its definition attribute

So, from a user's perspective, they can in theory declare a model for any task they want by writing a model definition that they decorate with the decorator for a family of models. From their POV, they never do any sub-classing; they write a definition and apply a decorator to it.
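
The pattern described above, reduced to a torch-free sketch. All Toy* classes are hypothetical stand-ins for torch and lightning classes, and FrameClassificationModel and frame_classification_model are illustrative names rather than vak's API; the point is only to show the three layers (definition, base class, family plus decorator) in a few lines.

```python
REQUIRED_CLASSVARS = ("network", "loss", "optimizer", "metrics")


# --- toy stand-ins for torch classes (hypothetical, for illustration only) ---
class ToyNet:
    def __init__(self, num_classes=2):
        self.num_classes = num_classes

    def parameters(self):
        return [0.0] * self.num_classes


class ToyLoss:
    pass


class ToyOptimizer:
    def __init__(self, params, lr=0.001):
        self.params = list(params)
        self.lr = lr


class Model:
    """Base class; a decorator attaches a `definition` class attribute to sub-classes."""

    def __init__(self, network=None, loss=None, optimizer=None, metrics=None):
        definition = type(self).definition
        # use the instances passed in, else make defaults from the definition
        self.network = definition.network() if network is None else network
        self.loss = definition.loss() if loss is None else loss
        self.optimizer = (definition.optimizer(params=self.network.parameters())
                          if optimizer is None else optimizer)
        self.metrics = ({name: metric() for name, metric in definition.metrics.items()}
                        if metrics is None else metrics)

    @classmethod
    def from_config(cls, config):
        network = cls.definition.network(**config["network"])
        optimizer = cls.definition.optimizer(params=network.parameters(), **config["optimizer"])
        loss = cls.definition.loss(**config["loss"])
        return cls(network=network, loss=loss, optimizer=optimizer)


class FrameClassificationModel(Model):
    """A model family; training_step, validation_step, etc. would live here."""


def frame_classification_model(definition):
    """Decorator: make a sub-class of the family, named after the definition."""
    missing = [name for name in REQUIRED_CLASSVARS if not hasattr(definition, name)]
    if missing:
        raise ValueError(f"definition is missing class variables: {missing}")
    return type(definition.__name__, (FrameClassificationModel,),
                {"definition": definition, "__doc__": definition.__doc__})


@frame_classification_model
class ToyModel:
    """A toy model definition."""
    network = ToyNet
    loss = ToyLoss
    optimizer = ToyOptimizer
    metrics = {"loss": ToyLoss}


default_model = ToyModel()
configured = ToyModel.from_config(
    {"network": {"num_classes": 10}, "loss": {}, "optimizer": {"lr": 0.01}}
)
print(type(default_model).__name__, configured.network.num_classes)
```

The user only writes the last class and applies the decorator; the sub-classing happens inside the call to type.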

@NickleDave
Collaborator Author

Here's the Minimum Viable Implementation of that:

import functools
import pathlib
from typing import Callable, ClassVar, NewType

import lightning
import torch
import vak
from vak import labeled_timebins

class ModelDefinition:
    """A class that represents the definition of a model.

    A model definition should specify the following class variables:
        network: torch.nn.Module
        loss: torch.nn.Module
        optimizer: torch.optim.Optimizer
        metrics: dict
    """
    network: torch.nn.Module | dict[str, torch.nn.Module]
    loss: torch.nn.Module | dict[str, torch.nn.Module]
    optimizer: torch.optim.Optimizer
    metrics: dict[str, Callable]

class Model(lightning.LightningModule):
    definition: ClassVar[ModelDefinition]
    REQUIRED_CLASSVARS = ('network', 'loss', 'optimizer', 'metrics')

    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None):
        super().__init__()

        # check that we are a sub-class of some other class with required class variables
        if not hasattr(self, 'definition'):
            raise ValueError(
                'This model does not have a definition. '
                'Define a model by wrapping a class with the required class variables with '
                'the ``vak.models.model`` decorator.'
            )
        if not all(
                [hasattr(self.definition, reqd_classvar)
                 for reqd_classvar in self.REQUIRED_CLASSVARS]
        ):
            raise ValueError(
                'vak.Model classes should have all the following class variables defined:\n'
                f'{self.REQUIRED_CLASSVARS}'
            )

        if network is None:
            network = self.definition.network()
        if loss is None:
            loss = self.definition.loss()
        if optimizer is None:
            optimizer = self.definition.optimizer(params=network.parameters())
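        if metrics is None:
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in self.definition.metrics.items()}

        # store the instances so that sub-classes can use them in their step methods
        self.network = network
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics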

    @classmethod
    def from_config(cls, config: dict, post_tfm: Callable | None = None):
        if isinstance(cls.definition.network, dict):
            network = {net_name: net_class(**config['network'][net_name])
                       for net_name, net_class in cls.definition.network.items()}
        else:
            # the definition holds the network class itself, not an instance
            network = cls.definition.network(**config['network'])

        if isinstance(cls.definition.optimizer, dict):
            # TODO: handle network parameters here
            # simplest case: make parameters from all nets first as flattened list, then pass in
            # more complex case: allow for net_opt config, 
            # not in either opt or net config so we can still just **unpack those
            optimizer = {opt_name: opt_class(**config['optimizer'][opt_name])
                         for opt_name, opt_class in cls.definition.optimizer.items()}
        else:
            # the definition holds the optimizer class itself, not an instance
            optimizer = cls.definition.optimizer(params=network.parameters(), **config['optimizer'])

        loss = cls.definition.loss(**config['loss'])
        metrics = {metric_name: metric_class()
                   for metric_name, metric_class in cls.definition.metrics.items()}
        return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics, post_tfm=post_tfm)

    @classmethod
    def from_config_path(cls, config_path: str | pathlib.Path, post_tfm: Callable | None = None):
        # config = config.model  # need to figure out better config for models here
        # self.from_config(config)
        pass

# define this here instead of vak.typing to avoid circular imports
ModelSubclass = NewType('ModelSubclass', Model)

class WindowedFrameClassificationModel(Model):
    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None,
                 post_tfm: Callable | None = None,
                 ):
        """A LightningModule that represents
        a model that predicts a label for each frame
        in a window, e.g., each time bin in
        a window from a spectrogram.

        This task represents one way of
        predicting annotations for a vocalization,
        where the annotations consist of a sequence
        of segments, each with an onset, offset,
        and label.
        The model maps the spectrogram window
        to a vector of labels for each frame, i.e.,
        each time bin.

        To annotate a vocalization with such a model,
        the spectrogram is converted into a batch of
        consecutive non-overlapping windows,
        for which the model produces predictions.
        These predictions are then concatenated
        into a vector of labeled frames,
        from which the segments can be recovered.

        Post-processing can be applied to the vector
        to clean up noisy predictions
        before recovering the segments."""
        super().__init__(network=network, loss=loss,
                         optimizer=optimizer, metrics=metrics)
        self.lbl_tb2labels = labeled_timebins.lbl_tb2labels
        self.post_tfm = post_tfm

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_pred = self.network(x)
        loss = self.loss(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # TODO: rename "source" -> "spect"
        # TODO: a sample can have "spect", "audio", "annot", optionally other things ("padding"?)
        x, y = batch["source"], batch["annot"]
        # remove "batch" dimension added by collate_fn to x
        # we keep for y because loss still expects the first dimension to be batch
        # TODO: fix this weirdness. Diff't collate_fn?
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        else:
            raise ValueError(f"invalid shape for x: {x.shape}")

        out = self.network(x)
        # permute and flatten out
        # so that it has shape (1, number classes, number of time bins)
        # ** NOTICE ** just calling out.reshape(1, out.shape[1], -1) does not work, it mixes values from different classes into the same row
        out = out.permute(1, 0, 2)
        out = torch.flatten(out, start_dim=1)
        out = torch.unsqueeze(out, dim=0)
        # reduce to predictions, assuming class dimension is 1
        y_pred = torch.argmax(
            out, dim=1
        )  # y_pred has dims (batch size 1, predicted label per time bin)

        if "padding_mask" in batch:
            padding_mask = batch[
                "padding_mask"
            ]  # boolean: 1 where valid, 0 where padding
            # remove "batch" dimension added by collate_fn
            # because this extra dimension just makes it confusing to use the mask as indices
            if padding_mask.ndim == 2:
                if padding_mask.shape[0] == 1:
                    padding_mask = torch.squeeze(padding_mask, dim=0)
            else:
                raise ValueError(
                    f"invalid shape for padding mask: {padding_mask.shape}"
                )

            out = out[:, :, padding_mask]
            y_pred = y_pred[:, padding_mask]

        if self.post_tfm:
            y_pred = self.post_tfm(y_pred)

        y_labels = self.lbl_tb2labels(
            y.cpu().numpy(),
        )
        y_pred_labels = self.lbl_tb2labels(
            y_pred.cpu().numpy()
        )

        # TODO: figure out smarter way to do this
        for metric_name, metric_callable in self.metrics.items():
            if metric_name == "loss":
                self.log(f'val_{metric_name}', self.loss(out, y), batch_size=1)
            elif metric_name == "acc":
                self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
            elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
                self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)

    def predict_step(self, batch, batch_idx: int):
        x, spect_path = batch["source"].to(self.device), batch["spect_path"]
        if isinstance(spect_path, list) and len(spect_path) == 1:
            spect_path = spect_path[0]
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        y_pred = self.network(x)
        return {spect_path: y_pred}

def windowed_frame_classification_model(modeldef: ModelDefinition) -> ModelSubclass:
    """A decorator that creates a model"""    
    attributes = dict(WindowedFrameClassificationModel.__dict__)
    attributes.update({'definition': modeldef})
    wrapped_model = type(modeldef.__name__, (WindowedFrameClassificationModel,), attributes)
    # realized when testing we don't actually need the next line, since we are adding the ModelDefinition
    # as an attribute to the subclass we just made, i.e. we're not really wrapping a class here, so this doesn't make sense
    wrapped_model = functools.wraps(windowed_frame_classification_model, updated=())(wrapped_model)

    return wrapped_model

@windowed_frame_classification_model
class TweetyNetModel:
    network = vak.nets.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

After all this, one can do the following to instantiate a model, without needing a config:

tweetynet = vak.nets.TweetyNet(num_classes=10)
model = TweetyNetModel(network=tweetynet)

and its methods will be the underlying LightningModule methods, specifically those implemented by WindowedFrameClassificationModel:

model.predict_step
<bound method WindowedFrameClassificationModel.predict_step of windowed_frame_classification_model()>
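
A side note on the NOTICE comment in validation_step: the sketch below uses nested lists instead of tensors to show why the permute has to come before the flatten. The two helper functions are hypothetical pure-Python equivalents, written under the assumption that the network output has shape (windows, classes, time bins); the extra batch dimension added by unsqueeze is left out.

```python
def permute_then_flatten(out):
    """(windows, classes, bins) -> (classes, windows * bins), each class's bins in time order."""
    num_classes = len(out[0])
    return [[value for window in out for value in window[c]]
            for c in range(num_classes)]


def naive_reshape(out):
    """Row-major reshape of the same values to (classes, windows * bins), with no permute first."""
    num_windows, num_classes, num_bins = len(out), len(out[0]), len(out[0][0])
    flat = [value for window in out for row in window for value in row]
    row_len = num_windows * num_bins
    return [flat[i * row_len:(i + 1) * row_len] for i in range(num_classes)]


# two windows, two classes, three time bins; each value names its own position
out = [[[f"w{w}c{c}t{t}" for t in range(3)] for c in range(2)] for w in range(2)]

correct = permute_then_flatten(out)
scrambled = naive_reshape(out)

# row 0 should hold only class 0 values, window after window
print(correct[0])
# the naive reshape puts class 1 values from window 0 into row 0
print(scrambled[0])
```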

@NickleDave
Collaborator Author

Closed by #605
