Skip to content

Commit

Permalink
Use magicmock to check whether reset_parameters was called
Browse files Browse the repository at this point in the history
  • Loading branch information
mberr committed Nov 6, 2020
1 parent 6f59baa commit 8742e7f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import traceback
import unittest
from typing import Any, ClassVar, Mapping, Optional, Type
from unittest.mock import MagicMock

import numpy
import pytest
Expand Down Expand Up @@ -36,7 +37,7 @@
symmetric_edge_weights,
)
from pykeen.models.unimodal.trans_d import _project_entity
from pykeen.nn import Embedding, RepresentationModule
from pykeen.nn import RepresentationModule
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop, TrainingLoop
from pykeen.triples import TriplesFactory
from pykeen.utils import all_in_bounds, clamp_norm, set_random_seed
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self, num_entities: int, embedding_dim: int = 2):
super().__init__()
self.num_embeddings = num_entities
self.embedding_dim = embedding_dim
self.x = nn.Parameter(torch.rand(embedding_dim,))
self.x = nn.Parameter(torch.rand(embedding_dim, ))

def forward(self, indices: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
n = self.num_embeddings if indices is None else indices.shape[0]
Expand Down Expand Up @@ -462,8 +463,7 @@ def test_score_t_with_score_hrt_equality(self) -> None:

def test_reset_parameters_constructor_call(self):
"""Tests whether reset_parameters is called in the constructor."""
self.model.reset_parameters_ = None
assert isinstance(self.model, (EntityEmbeddingModel, EntityRelationEmbeddingModel))
self.model.reset_parameters_ = MagicMock(return_value=None)
try:
self.model.__init__(
self.factory,
Expand All @@ -472,6 +472,7 @@ def test_reset_parameters_constructor_call(self):
)
except TypeError as error:
assert error.args == ("'NoneType' object is not callable",)
self.model.reset_parameters_.assert_called_once()

def test_custom_representations(self):
"""Tests whether we can provide custom representations."""
Expand Down Expand Up @@ -726,7 +727,7 @@ def _check_constraints(self):
Enriched embeddings have to be reset.
"""
assert self.model.enriched_embeddings is None
assert self.model.entity_representations.enriched_embeddings is None


class TestRGCNBasis(_TestRGCN, unittest.TestCase):
Expand Down

0 comments on commit 8742e7f

Please sign in to comment.