Commit 57a651e: Revise freeze embeddings

Nzteb committed Sep 1, 2020 (1 parent: 5154efe)

Showing 3 changed files with 57 additions and 46 deletions.
kge/model/embedder/lookup_embedder.py (40 changes: 22 additions & 18 deletions)

@@ -7,7 +7,7 @@
 from kge.model import KgeEmbedder
 from kge.misc import round_to_points
 
-from typing import List, Dict
+from typing import List, Dict, Union
 
 
 class LookupEmbedder(KgeEmbedder):
@@ -44,7 +44,7 @@ def __init__(
         # initialize weights
         self._init_embeddings(self._embeddings.weight.data)
 
-        self._embeddings_freeze = None
+        self._embeddings_frozen = None
 
         # TODO handling negative dropout because using it with ax searches for now
         dropout = self.get_option("dropout")
@@ -114,18 +114,22 @@ def _embeddings_all(self) -> Tensor:
     def _get_regularize_weight(self) -> Tensor:
         return self.get_option("regularize_weight")
 
-    def freeze(self, freeze_indexes) -> Tensor:
+    def freeze(self, freeze_indexes: Union[List, Tensor]) -> Tensor:
         """Freeze the embeddings of the entities specified by freeze_indexes.
         This method overrides the _embed() and _embeddings_all() methods.
         """
 
         num_freeze = len(freeze_indexes)
 
         original_weights = self._embeddings.weight.data
 
-        self._embeddings_freeze = torch.nn.Embedding(
+        if isinstance(freeze_indexes, list):
+            freeze_indexes = torch.tensor(
+                freeze_indexes, device=self.config.get("job.device")
+            ).long()
+
+        self._embeddings_frozen = torch.nn.Embedding(
             num_freeze, self.dim, sparse=self.sparse,
         )
         self._embeddings = torch.nn.Embedding(
@@ -135,27 +139,27 @@ def freeze(self, freeze_indexes) -> Tensor:
         # for a global index i stores at position i a 1
         # when it corresponds to a frozen parameter
         freeze_mask = torch.zeros(
-            self.vocab_size, dtype=torch.bool, device=original_weights.device
+            self.vocab_size, dtype=torch.bool, device=self.config.get("job.device")
         )
         freeze_mask[freeze_indexes] = 1
 
         # assign current values to the new embeddings
-        self._embeddings_freeze.weight.data = original_weights[freeze_mask]
+        self._embeddings_frozen.weight.data = original_weights[freeze_mask]
         self._embeddings.weight.data = original_weights[~freeze_mask]
 
         # freeze
-        self._embeddings_freeze.weight.requires_grad = False
+        self._embeddings_frozen.weight.requires_grad = False
 
         # for a global index i stores at position i its index in either the
         # frozen or the non-frozen embedding tensor
-        positions = torch.zeros(
-            self.vocab_size, dtype=torch.long, device=self._embeddings.weight.device
+        global_to_local_mapper = torch.zeros(
+            self.vocab_size, dtype=torch.long, device=self.config.get("job.device")
         )
-        positions[freeze_mask] = torch.arange(
+        global_to_local_mapper[freeze_mask] = torch.arange(
             num_freeze, device=self.config.get("job.device")
         )
-        positions[~freeze_mask] = torch.arange(
-            self.vocab_size - num_freeze, device=self._embeddings.weight.device
+        global_to_local_mapper[~freeze_mask] = torch.arange(
+            self.vocab_size - num_freeze, device=self.config.get("job.device")
         )
 
         def _embed(indexes: Tensor) -> Tensor:
@@ -166,12 +170,12 @@ def _embed(indexes: Tensor) -> Tensor:
 
             frozen_indexes_mask = freeze_mask[indexes.long()]
 
-            emb[frozen_indexes_mask] = self._embeddings_freeze(
-                positions[indexes[frozen_indexes_mask].long()]
+            emb[frozen_indexes_mask] = self._embeddings_frozen(
+                global_to_local_mapper[indexes[frozen_indexes_mask].long()]
             )
 
             emb[~frozen_indexes_mask] = self._embeddings(
-                positions[indexes[~frozen_indexes_mask].long()]
+                global_to_local_mapper[indexes[~frozen_indexes_mask].long()]
             )
             return emb
 
@@ -181,11 +185,11 @@ def _embeddings_all() -> Tensor:
                 (self.vocab_size, self.dim), device=self._embeddings.weight.device
             )
 
-            emb[freeze_mask] = self._embeddings_freeze(
+            emb[freeze_mask] = self._embeddings_frozen(
                 torch.arange(
                     num_freeze,
                     dtype=torch.long,
-                    device=self._embeddings_freeze.weight.device,
+                    device=self._embeddings_frozen.weight.device,
                 )
             )
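
Editor's note: the mechanics of freeze() are easier to see in isolation. Below is a minimal, self-contained sketch (not part of the commit; all names and sizes are illustrative) of the same trick the diff implements: split one embedding table into a frozen and a trainable partition, keep a global-to-local index map, and route each lookup to the right partition.

    import torch

    vocab_size, dim = 10, 4
    weights = torch.randn(vocab_size, dim)
    freeze_indexes = torch.tensor([1, 3, 7])
    num_frozen = len(freeze_indexes)

    # boolean mask: True at global positions whose embeddings stay fixed
    freeze_mask = torch.zeros(vocab_size, dtype=torch.bool)
    freeze_mask[freeze_indexes] = True

    # split the original table into a frozen and a trainable partition
    frozen = torch.nn.Embedding(num_frozen, dim)
    trainable = torch.nn.Embedding(vocab_size - num_frozen, dim)
    frozen.weight.data = weights[freeze_mask]
    trainable.weight.data = weights[~freeze_mask]
    frozen.weight.requires_grad = False  # the optimizer never updates these rows

    # map each global index to its row inside the partition that holds it
    global_to_local = torch.zeros(vocab_size, dtype=torch.long)
    global_to_local[freeze_mask] = torch.arange(num_frozen)
    global_to_local[~freeze_mask] = torch.arange(vocab_size - num_frozen)

    def embed(indexes: torch.Tensor) -> torch.Tensor:
        # route each requested index to the frozen or the trainable table
        emb = torch.empty(len(indexes), dim)
        is_frozen = freeze_mask[indexes]
        emb[is_frozen] = frozen(global_to_local[indexes[is_frozen]])
        emb[~is_frozen] = trainable(global_to_local[indexes[~is_frozen]])
        return emb

    print(embed(torch.tensor([0, 1, 2])).shape)  # torch.Size([3, 4])

Only the trainable partition contributes gradients, which is why the commit also reroutes _embed() and _embeddings_all() rather than masking gradients on a single table.
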
kge/model/kge_model.py (56 changes: 28 additions & 28 deletions)

@@ -12,6 +12,7 @@
 from kge.misc import filename_in_module
 from kge.util import load_checkpoint
 from typing import Any, Dict, List, Optional, Union, Tuple
+from kge.util.io import file_to_list
 
 from typing import TYPE_CHECKING
 
@@ -450,36 +451,35 @@ def load_pretrained_model(
                             f"Could not find freeze files for {name} embedder"
                         )
                     else:
-                        with open(freeze_file, "r") as file:
-                            ids = file.read().rstrip("\n").splitlines()
-                            id_map = self.dataset.load_map(f"{name}_ids")
-                            freeze_indexes = list(
-                                map(lambda _id: id_map.index(_id), ids)
-                            )
-                            model = self.config.get("model")
-                            if (
-                                model == "reciprocal_relations_model"
-                                and name == "relation"
-                            ):
-                                # this is the base model and num_relations is twice
-                                # the number of relations already
-                                reciprocal_indexes = list(
-                                    map(
-                                        lambda idx: idx
-                                        + self.dataset.num_relations() / 2,
-                                        freeze_indexes,
-                                    )
-                                )
-                                freeze_indexes.extend(reciprocal_indexes)
-                            if len(freeze_indexes) > len(set(freeze_indexes)):
-                                raise Exception(
-                                    f"Unique set of ids needed for freezing {name}'s."
-                                )
-
-                            self.config.log(
-                                f"Freezing {name} embeddings found in {freeze_file}"
-                            )
-                            embedder.freeze(freeze_indexes)
+                        ids = file_to_list(freeze_file)
+                        id_map = self.dataset.load_map(f"{name}_ids")
+                        freeze_indexes = list(
+                            map(lambda _id: id_map.index(_id), ids)
+                        )
+                        model = self.config.get("model")
+                        if (
+                            model == "reciprocal_relations_model"
+                            and name == "relation"
+                        ):
+                            # this is the base model and num_relations is twice
+                            # the number of relations already
+                            reciprocal_indexes = list(
+                                map(
+                                    lambda idx: idx
+                                    + self.dataset.num_relations() / 2,
+                                    freeze_indexes,
+                                )
+                            )
+                            freeze_indexes.extend(reciprocal_indexes)
+                        if len(freeze_indexes) > len(set(freeze_indexes)):
+                            raise Exception(
+                                f"Unique set of ids needed for freezing {name}'s."
+                            )
+
+                        self.config.log(
+                            f"Freezing {name} embeddings found in {freeze_file}"
+                        )
+                        embedder.freeze(freeze_indexes)
 
         #: Scorer
         self._scorer: RelationalScorer
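
Editor's note: the reciprocal_relations_model branch above deserves a remark. The base model stores 2 * R relation embeddings, with the inverse of relation i at index i + R, so freezing a relation must also freeze its reciprocal. A standalone sketch of that bookkeeping follows (function name hypothetical; note that the commit's "/ 2" is float division in Python 3, while the sketch uses "//" so the resulting indexes stay integers).

    def with_reciprocals(freeze_indexes, num_relations_total):
        # num_relations_total counts both directions, i.e. 2 * R
        half = num_relations_total // 2
        combined = freeze_indexes + [i + half for i in freeze_indexes]
        if len(combined) != len(set(combined)):
            raise ValueError("freeze ids must be unique")
        return combined

    print(with_reciprocals([0, 2], 10))  # [0, 2, 5, 7]
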
kge/util/io.py (7 changes: 7 additions & 0 deletions)

@@ -44,3 +44,10 @@ def load_checkpoint(checkpoint_file: str, device="cpu"):
     checkpoint["file"] = checkpoint_file
     checkpoint["folder"] = os.path.dirname(checkpoint_file)
     return checkpoint
+
+
+def file_to_list(file: str):
+    """Return the lines of a file as a list."""
+    with open(file, "r") as f:
+        data = f.read().rstrip("\n").splitlines()
+    return data
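
Editor's note: a short usage sketch of the new helper (file name and contents hypothetical):

    from kge.util.io import file_to_list

    # a freeze file holds one identifier per line; the trailing newline is stripped
    with open("entity_ids.txt", "w") as f:
        f.write("e1\ne2\ne3\n")

    print(file_to_list("entity_ids.txt"))  # ['e1', 'e2', 'e3']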
