Skip to content

Commit

Permalink
CLN: Make fixes to work with vak 0.8.0, fixes #218
Browse files Browse the repository at this point in the history
- Add post_tfm parameter to TweetyNetModel.from_config
- DEV: Raise minimum required version of vak to 0.8.0
  • Loading branch information
NickleDave committed Feb 10, 2023
1 parent 0dff786 commit d426e7a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ requires-python = ">=3.8"
license = {file = "LICENSE"}
dependencies = [
"torch>=1.7.1",
"vak>=0.6.0"
"vak>=0.8.0"
]
[project.optional-dependencies]
test = [
Expand Down
4 changes: 2 additions & 2 deletions src/tweetynet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ def acc(y_pred, y):

class TweetyNetModel(vak.Model):
@classmethod
def from_config(cls, config):
def from_config(cls, config, post_tfm=None):
network = TweetyNet(**config['network'])
loss = torch.nn.CrossEntropyLoss(**config['loss'])
optimizer = torch.optim.Adam(params=network.parameters(), **config['optimizer'])
metrics = {'acc': vak.metrics.Accuracy(),
'levenshtein': vak.metrics.Levenshtein(),
'segment_error_rate': vak.metrics.SegmentErrorRate(),
'loss': torch.nn.CrossEntropyLoss()}
return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)
return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics, post_tfm=post_tfm)

0 comments on commit d426e7a

Please sign in to comment.