ENH: change how Model is declared to remove sub-classing of vak.engine.Model
#536
The goal of this is to make it easier to instantiate a model, and to fix #362 (linking here so I can close that one).
The one thing this still does not give us is a `from_config` classmethod on `vak.engine.Model`. Would be really nice to be able to say …
Picking this up again. What we want:
…

How to implement it:
…

and additionally, implement a decorator which accepts a user-specified model. This starts to get pretty meta, I know (is this meta-programming yet?); I will post a code snippet next to illustrate a little better.
Here's a code snippet that (I think) illustrates what I have in mind:

```python
import torch
import tweetynet
import vak


def model(model):
    """a decorator that creates a model"""
    class Model:
        _model = model

        def __init__(self,
                     network,
                     loss,
                     optimizer,
                     metrics,
                     post_tfm=None):
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.network(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)

    return Model


# the definition declares classes; ``from_config`` instantiates them with the config
@vak.model
class TweetyNet:
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```
Unfinished business:

```python
VakMyModel = vak.model(MyModel)  # using the decorator directly as a function
VakMyModel()  # I don't have to pass in any config or anything, just make a new instance
```

We could do this by setting the default values for the …
I think I have a very minimal viable product (MVP) of this working:

```python
import functools

import torch
import tweetynet
import vak


def model(model):
    """a decorator that creates a model"""
    @functools.wraps(model, updated=())
    class Model:
        _model = model

        def __init__(self,
                     network=None,
                     loss=None,
                     optimizer=None,
                     metrics=None,
                     post_tfm=None):
            if network is None:
                network = self._model.network()
            if loss is None:
                loss = self._model.loss()
            if optimizer is None:
                optimizer = self._model.optimizer(params=network.parameters())
            if metrics is None:
                metrics = {metric_name: metric_class()
                           for metric_name, metric_class in self._model.metrics.items()}
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.network(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config: dict):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)

    return Model


@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```

```python
>>> tweety = TweetyNetModel(network=tweetynet.TweetyNet(num_classes=10))
>>> type(tweety)
__main__.TweetyNetModel
>>> tweety.__doc__
'Model that uses TweetyNet architecture'
```
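The `@functools.wraps(model, updated=())` line is what makes `type(tweety)` and `tweety.__doc__` report the user's class name and docstring. Below is a stdlib-only sketch of the same trick. `make_model`, `Wrapped`, and `MyModel` are hypothetical stand-ins, not vak API. It also shows why `updated=()` is needed: the default `updated=('__dict__',)` tries to call `.update()` on the new class's `__dict__`, which is a read-only `mappingproxy`, so it raises `AttributeError`.

```python
import functools


def make_model(model_def, updated=()):
    """Hypothetical stand-in for the ``model`` decorator above.

    Builds a new class that keeps a reference to the user's definition and
    copies its ``__name__``, ``__qualname__``, ``__doc__``, etc. onto it.
    """
    @functools.wraps(model_def, updated=updated)
    class Wrapped:
        _model = model_def

        def describe(self):
            return f"wraps {self._model.__name__}"

    return Wrapped


class MyModel:
    """Model that uses a made-up architecture"""


VakMyModel = make_model(MyModel)
print(VakMyModel.__name__)      # MyModel
print(VakMyModel.__doc__)       # Model that uses a made-up architecture
print(VakMyModel().describe())  # wraps MyModel

# With the default ``updated=('__dict__',)``, update_wrapper calls
# ``Wrapped.__dict__.update(...)``; a class __dict__ is a read-only
# mappingproxy with no ``update`` method, so this fails.
try:
    make_model(MyModel, updated=functools.WRAPPER_UPDATES)
except AttributeError as err:
    print("AttributeError:", err)
```

Only the listed attributes (`__module__`, `__name__`, `__qualname__`, `__doc__`, ...) are copied, plus a `__wrapped__` reference back to the definition; the methods still come from the inner class.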
The feature branch is already in progress, but I'm documenting here how I ended up implementing this.

So, from a user's perspective, they can in theory declare a model for any task they want by writing a model definition and decorating it with the decorator for a family of models. From their point of view, they never do any sub-classing: they write a definition and apply a decorator to it.
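"Write a definition, apply a decorator, never subclass" boils down to the decorator building the subclass on the user's behalf with the three-argument form of `type()`. Here is a stdlib-only sketch. `BaseModel`, `model_family`, and `MyDefinition` are hypothetical names standing in for the Lightning-based classes, and plain strings stand in for the network, loss, and optimizer classes. This variant names the new class after the definition and checks the required class variables at decoration time rather than at instantiation.

```python
REQUIRED_CLASSVARS = ('network', 'loss', 'optimizer', 'metrics')


class BaseModel:
    """Hypothetical stand-in for a model-family base class."""
    definition = None  # filled in by the decorator

    def describe(self):
        return f"{type(self).__name__} uses {self.definition.network}"


def model_family(base):
    """Return a decorator that turns a plain definition class into a subclass of ``base``."""
    def decorator(modeldef):
        missing = [name for name in REQUIRED_CLASSVARS if not hasattr(modeldef, name)]
        if missing:
            raise ValueError(f"{modeldef.__name__} is missing class variables: {missing}")
        # three-argument type(): (name, bases, namespace) creates the subclass;
        # methods are inherited from ``base``, so only the definition is added
        return type(
            modeldef.__name__,
            (base,),
            {
                'definition': modeldef,
                '__doc__': modeldef.__doc__,
                '__module__': modeldef.__module__,
            },
        )
    return decorator


@model_family(BaseModel)
class MyDefinition:
    """A made-up model definition."""
    network = 'SomeNet'
    loss = 'SomeLoss'
    optimizer = 'SomeOptimizer'
    metrics = {}


m = MyDefinition()  # the user never wrote ``class ...(BaseModel)``
print(isinstance(m, BaseModel))  # True
print(m.describe())              # MyDefinition uses SomeNet
```

Because the subclass inherits every method, the namespace passed to `type()` only needs the definition, which avoids copying the parent's `__dict__` and also makes a separate `functools.wraps` call unnecessary for naming.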
Here's the minimum viable implementation of that:

```python
import functools
import pathlib
from typing import Callable, ClassVar, NewType

import lightning
import torch

import vak
from vak import labeled_timebins


class ModelDefinition:
    """A class that represents the definition of a model.

    A model definition should specify the following class variables:
        network: torch.nn.Module
        loss: torch.nn.Module
        optimizer: torch.optim.Optimizer
        metrics: dict
    """
    network: torch.nn.Module | dict[str, torch.nn.Module]
    loss: torch.nn.Module | dict[str, torch.nn.Module]
    optimizer: torch.optim.Optimizer
    metrics: dict[str, Callable]


class Model(lightning.LightningModule):
    definition: ClassVar[ModelDefinition]

    REQUIRED_CLASSVARS = ('network', 'loss', 'optimizer', 'metrics')

    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None):
        super().__init__()

        # check that we are a sub-class of some other class with required class variables
        if not hasattr(self, 'definition'):
            raise ValueError(
                'This model does not have a definition. '
                'Define a model by wrapping a class with the required class variables with '
                'the ``vak.models.model`` decorator.'
            )
        if not all(
            [hasattr(self.definition, reqd_classvar)
             for reqd_classvar in self.REQUIRED_CLASSVARS]
        ):
            raise ValueError(
                'vak.Model classes should have all the following class variables defined:\n'
                f'{self.REQUIRED_CLASSVARS}'
            )

        # fall back to default instances built from the definition's classes
        if network is None:
            network = self.definition.network()
        if loss is None:
            loss = self.definition.loss()
        if optimizer is None:
            optimizer = self.definition.optimizer(params=network.parameters())
        if metrics is None:
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in self.definition.metrics.items()}

        # store them so training_step, validation_step, etc. can use them
        self.network = network
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics

    @classmethod
    def from_config(cls, config: dict, post_tfm: Callable | None = None):
        if isinstance(cls.definition.network, dict):
            network = {net_name: net_class(**config['network'][net_name])
                       for net_name, net_class in cls.definition.network.items()}
        else:
            # a single network: ``definition.network`` is a class, so instantiate it
            network = cls.definition.network(**config['network'])

        if isinstance(cls.definition.optimizer, dict):
            # TODO: handle network parameters here
            # simplest case: make parameters from all nets first as flattened list, then pass in
            # more complex case: allow for net_opt config,
            # not in either opt or net config so we can still just **unpack those
            optimizer = {opt_name: opt_class(**config['optimizer'][opt_name])
                         for opt_name, opt_class in cls.definition.optimizer.items()}
        else:
            # a single optimizer class
            optimizer = cls.definition.optimizer(params=network.parameters(), **config['optimizer'])

        loss = cls.definition.loss(**config['loss'])
        metrics = {metric_name: metric_class()
                   for metric_name, metric_class in cls.definition.metrics.items()}
        # post_tfm is accepted by subclasses such as WindowedFrameClassificationModel
        return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics, post_tfm=post_tfm)

    @classmethod
    def from_config_path(cls, config_path: str | pathlib.Path, post_tfm: Callable | None = None):
        # config = config.model  # need to figure out better config for models here
        # self.from_config(config)
        pass


# define this here instead of vak.typing to avoid circular imports
ModelSubclass = NewType('ModelSubclass', Model)


class WindowedFrameClassificationModel(Model):
    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None,
                 post_tfm: Callable | None = None,
                 ):
        """A LightningModule that represents
        a model that predicts a label for each frame
        in a window, e.g., each time bin in
        a window from a spectrogram.

        This task represents one way of
        predicting annotations for a vocalization,
        where the annotations consist of a sequence
        of segments, each with an onset, offset,
        and label.
        The model maps the spectrogram window
        to a vector of labels for each frame, i.e.,
        each time bin.

        To annotate a vocalization with such a model,
        the spectrogram is converted into a batch of
        consecutive non-overlapping windows,
        for which the model produces predictions.
        These predictions are then concatenated
        into a vector of labeled frames,
        from which the segments can be recovered.

        Post-processing can be applied to the vector
        to clean up noisy predictions
        before recovering the segments."""
        super().__init__(network=network, loss=loss,
                         optimizer=optimizer, metrics=metrics)
        self.lbl_tb2labels = labeled_timebins.lbl_tb2labels
        self.post_tfm = post_tfm

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_pred = self.network(x)
        loss = self.loss(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # TODO: rename "source" -> "spect"
        # TODO: a sample can have "spect", "audio", "annot", optionally other things ("padding"?)
        x, y = batch["source"], batch["annot"]
        # remove "batch" dimension added by collate_fn to x
        # we keep for y because loss still expects the first dimension to be batch
        # TODO: fix this weirdness. Diff't collate_fn?
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
            else:
                raise ValueError(f"invalid shape for x: {x.shape}")

        out = self.network(x)
        # permute and flatten out
        # so that it has shape (1, number classes, number of time bins)
        # ** NOTICE ** just calling out.reshape(1, out.shape[1], -1) does not work, it will change the data
        out = out.permute(1, 0, 2)
        out = torch.flatten(out, start_dim=1)
        out = torch.unsqueeze(out, dim=0)
        # reduce to predictions, assuming class dimension is 1
        y_pred = torch.argmax(
            out, dim=1
        )  # y_pred has dims (batch size 1, predicted label per time bin)

        if "padding_mask" in batch:
            padding_mask = batch[
                "padding_mask"
            ]  # boolean: 1 where valid, 0 where padding
            # remove "batch" dimension added by collate_fn
            # because this extra dimension just makes it confusing to use the mask as indices
            if padding_mask.ndim == 2:
                if padding_mask.shape[0] == 1:
                    padding_mask = torch.squeeze(padding_mask, dim=0)
                else:
                    raise ValueError(
                        f"invalid shape for padding mask: {padding_mask.shape}"
                    )

            out = out[:, :, padding_mask]
            y_pred = y_pred[:, padding_mask]

        if self.post_tfm:
            y_pred = self.post_tfm(y_pred)

        y_labels = self.lbl_tb2labels(
            y.cpu().numpy(),
        )
        y_pred_labels = self.lbl_tb2labels(
            y_pred.cpu().numpy()
        )

        # TODO: figure out smarter way to do this
        for metric_name, metric_callable in self.metrics.items():
            if metric_name == "loss":
                self.log(f'val_{metric_name}', self.loss(out, y), batch_size=1)
            elif metric_name == "acc":
                self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
            elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
                self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)

    def predict_step(self, batch, batch_idx: int):
        x, spect_path = batch["source"].to(self.device), batch["spect_path"]
        if isinstance(spect_path, list) and len(spect_path) == 1:
            spect_path = spect_path[0]
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        y_pred = self.network(x)
        return {spect_path: y_pred}


def windowed_frame_classification_model(modeldef: ModelDefinition) -> ModelSubclass:
    """A decorator that creates a model"""
    attributes = dict(WindowedFrameClassificationModel.__dict__)
    attributes.update({'definition': modeldef})
    wrapped_model = type(modeldef.__name__, (WindowedFrameClassificationModel,), attributes)
    # realized when testing we don't actually need next line since we are adding the ModelDefinition as an attribute
    # to the subclass we just made, i.e. we're not really wrapping a class here, so this doesn't make sense
    wrapped_model = functools.wraps(windowed_frame_classification_model, updated=())(wrapped_model)
    return wrapped_model


@windowed_frame_classification_model
class TweetyNetModel:
    network = vak.nets.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```

After all this, one can do the following to instantiate a model, without needing a config:

```python
tweetynet = vak.nets.TweetyNet(num_classes=10)
model = TweetyNetModel(network=tweetynet)
```

and its methods will be those of the underlying `WindowedFrameClassificationModel`:

```python
>>> model.predict_step
<bound method WindowedFrameClassificationModel.predict_step of windowed_frame_classification_model()>
```
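The `** NOTICE **` comment in `validation_step` is worth a small worked example. The network output has shape (windows, classes, time bins), and the goal is one row per class holding that class's scores for every window in order. A plain row-major reshape to (classes, -1) regroups the flat buffer instead of moving values between axes, so a row ends up mixing classes. Below is a stdlib-only sketch in which nested lists stand in for tensors and the helper names are hypothetical. The final `torch.unsqueeze(out, dim=0)` that adds the batch dimension is left out.

```python
def flatten(nested):
    """Row-major (C-order) flatten of a 3-D nested list."""
    return [v for window in nested for row in window for v in row]


def naive_reshape(out, n_classes):
    """Mimic ``out.reshape(n_classes, -1)`` on a (windows, classes, time) buffer."""
    flat = flatten(out)
    width = len(flat) // n_classes
    return [flat[i * width:(i + 1) * width] for i in range(n_classes)]


def permute_then_flatten(out):
    """Mimic ``out.permute(1, 0, 2)`` followed by ``torch.flatten(out, start_dim=1)``."""
    n_classes = len(out[0])
    return [[v for window in out for v in window[c]] for c in range(n_classes)]


# 2 windows, 2 classes, 3 time bins; each value is 100 * window + 10 * class + time
out = [
    [[0, 1, 2], [10, 11, 12]],            # window 0: class 0 row, class 1 row
    [[100, 101, 102], [110, 111, 112]],   # window 1: class 0 row, class 1 row
]
print(naive_reshape(out, n_classes=2))
# [[0, 1, 2, 10, 11, 12], [100, 101, 102, 110, 111, 112]]  <- rows mix classes
print(permute_then_flatten(out))
# [[0, 1, 2, 100, 101, 102], [10, 11, 12, 110, 111, 112]]  <- one row per class
```

Moving the class axis to the front before flattening is what keeps each class's scores together across windows.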
Closed by #605
edit: I think I have a better way to do this, see the discussion in this issue.

To decouple `Model` from `Engine` to address #362, add a `Model` attrs class with a `from_config` classmethod. This will be a kind of "interface": if you want to be a Model, have a `from_config` method that will result in network + loss + optimizer + metrics attributes. This way the user doesn't have to actually subclass this `Model` attrs class. They can use the built-in attrs class if they want, or they can make some other dataclass that can be used with the CLI, as long as it obeys the "interface". `vak` will "know" to instantiate the model using `from_config`.
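This "interface" idea means anything with a `from_config` classmethod counts as a Model, with no subclassing required. That maps naturally onto structural typing. Below is a sketch assuming `typing.Protocol`. The names `ModelLike`, `SimpleModel`, and `instantiate` are hypothetical, the config is a plain dict, and strings stand in for real network, loss, and optimizer objects. An attrs class would work the same way; a stdlib dataclass keeps the sketch dependency-free.

```python
import dataclasses
from typing import Protocol, runtime_checkable


@runtime_checkable
class ModelLike(Protocol):
    """Anything with a ``from_config`` method satisfies this 'interface'."""

    @classmethod
    def from_config(cls, config: dict) -> "ModelLike":
        ...


@dataclasses.dataclass
class SimpleModel:
    """A user's own dataclass; note that it does not subclass ModelLike."""
    network: str
    loss: str
    optimizer: str
    metrics: dict

    @classmethod
    def from_config(cls, config: dict) -> "SimpleModel":
        return cls(**config)


def instantiate(model_class, config: dict):
    """How a CLI could 'know' to build any model: rely only on ``from_config``."""
    if not isinstance(model_class, ModelLike):
        raise TypeError(f"{model_class!r} has no from_config method")
    return model_class.from_config(config)


config = {'network': 'TweetyNet', 'loss': 'CrossEntropyLoss', 'optimizer': 'Adam', 'metrics': {}}
model = instantiate(SimpleModel, config)
print(model.network)  # TweetyNet
```

The check is structural: `runtime_checkable` only verifies that a `from_config` attribute exists, not its signature, so a mismatched config still fails when `from_config` is called.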