Freeze embeddings #136

Open · wants to merge 5 commits into base: master
102 changes: 99 additions & 3 deletions kge/model/embedder/lookup_embedder.py
@@ -7,7 +7,7 @@
from kge.model import KgeEmbedder
from kge.misc import round_to_points

from typing import List, Dict
from typing import List, Dict, Union


class LookupEmbedder(KgeEmbedder):
@@ -44,6 +44,8 @@ def __init__(
# initialize weights
self._init_embeddings(self._embeddings.weight.data)

self._embeddings_frozen = None

# TODO handling negative dropout because using it with ax searches for now
dropout = self.get_option("dropout")
if dropout < 0:
@@ -89,7 +91,10 @@ def init_pretrained(self, pretrained_embedder: KgeEmbedder) -> None:
)

def embed(self, indexes: Tensor) -> Tensor:
return self._postprocess(self._embeddings(indexes.long()))
return self._postprocess(self._embed(indexes))

def _embed(self, indexes: Tensor) -> Tensor:
return self._embeddings(indexes.long())

def embed_all(self) -> Tensor:
return self._postprocess(self._embeddings_all())
@@ -109,6 +114,97 @@ def _embeddings_all(self) -> Tensor:
def _get_regularize_weight(self) -> Tensor:
return self.get_option("regularize_weight")

def freeze(self, freeze_indexes: Union[List, Tensor]) -> None:
"""Freeze the embeddings at the positions given by freeze_indexes.

This method replaces the instance's _embed() and _embeddings_all() methods.

"""
num_freeze = len(freeze_indexes)

original_weights = self._embeddings.weight.data

if isinstance(freeze_indexes, list):
freeze_indexes = torch.tensor(
freeze_indexes, device=self.config.get("job.device")
).long()

self._embeddings_frozen = torch.nn.Embedding(
num_freeze, self.dim, sparse=self.sparse,
)
self._embeddings = torch.nn.Embedding(
self.vocab_size - num_freeze, self.dim, sparse=self.sparse,
)

# boolean mask that is True at position i if global index i
# corresponds to a frozen parameter
freeze_mask = torch.zeros(
self.vocab_size, dtype=torch.bool, device=self.config.get("job.device")
)
freeze_mask[freeze_indexes] = 1

# assign current values to the new embeddings
self._embeddings_frozen.weight.data = original_weights[freeze_mask]
self._embeddings.weight.data = original_weights[~freeze_mask]

# freeze
self._embeddings_frozen.weight.requires_grad = False

# maps a global index i to its local index within either the
# frozen or the non-frozen embedding tensor
global_to_local_mapper = torch.zeros(
self.vocab_size, dtype=torch.long, device=self.config.get("job.device")
)
global_to_local_mapper[freeze_mask] = torch.arange(
num_freeze, device=self.config.get("job.device")
)
global_to_local_mapper[~freeze_mask] = torch.arange(
self.vocab_size - num_freeze, device=self.config.get("job.device")
)

def _embed(indexes: Tensor) -> Tensor:

emb = torch.empty(
(len(indexes), self.dim), device=self._embeddings.weight.device
)

frozen_indexes_mask = freeze_mask[indexes.long()]

emb[frozen_indexes_mask] = self._embeddings_frozen(
global_to_local_mapper[indexes[frozen_indexes_mask].long()]
)

emb[~frozen_indexes_mask] = self._embeddings(
global_to_local_mapper[indexes[~frozen_indexes_mask].long()]
)
return emb

def _embeddings_all() -> Tensor:

emb = torch.empty(
(self.vocab_size, self.dim), device=self._embeddings.weight.device
)

emb[freeze_mask] = self._embeddings_frozen(
torch.arange(
num_freeze,
dtype=torch.long,
device=self._embeddings_frozen.weight.device,
)
)

emb[~freeze_mask] = self._embeddings(
torch.arange(
self.vocab_size - num_freeze,
dtype=torch.long,
device=self._embeddings.weight.device,
)
)
return emb

self._embeddings_all = _embeddings_all
self._embed = _embed

def penalty(self, **kwargs) -> List[Tensor]:
# TODO factor out to a utility method
result = super().penalty(**kwargs)
@@ -135,7 +231,7 @@ def penalty(self, **kwargs) -> List[Tensor]:
unique_indexes, counts = torch.unique(
kwargs["indexes"], return_counts=True
)
parameters = self._embeddings(unique_indexes)
parameters = self._embed(unique_indexes)
if p % 2 == 1:
parameters = torch.abs(parameters)
result += [
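The heart of freeze() above is index bookkeeping: the original embedding matrix is split into a frozen and a trainable torch.nn.Embedding, and a global-to-local mapping translates the usual global indexes into positions within the respective table. The following standalone sketch illustrates that mechanism outside of LibKGE; all names (vocab_size, embed, ...) are illustrative and not part of the library.

import torch

vocab_size, dim = 6, 4
weights = torch.randn(vocab_size, dim)        # original embedding matrix
freeze_indexes = torch.tensor([0, 2, 4])      # global indexes to freeze
num_freeze = freeze_indexes.numel()

# boolean mask: True at position i if global index i is frozen
freeze_mask = torch.zeros(vocab_size, dtype=torch.bool)
freeze_mask[freeze_indexes] = True

# split the weights into a frozen and a trainable embedding table
frozen = torch.nn.Embedding.from_pretrained(weights[freeze_mask], freeze=True)
trainable = torch.nn.Embedding(vocab_size - num_freeze, dim)
trainable.weight.data = weights[~freeze_mask]

# map each global index to its local index in its table
global_to_local = torch.zeros(vocab_size, dtype=torch.long)
global_to_local[freeze_mask] = torch.arange(num_freeze)
global_to_local[~freeze_mask] = torch.arange(vocab_size - num_freeze)

def embed(indexes: torch.Tensor) -> torch.Tensor:
    # look up each index in the table it belongs to and reassemble the batch
    emb = torch.empty(len(indexes), dim)
    is_frozen = freeze_mask[indexes]
    emb[is_frozen] = frozen(global_to_local[indexes[is_frozen]])
    emb[~is_frozen] = trainable(global_to_local[indexes[~is_frozen]])
    return emb

print(embed(torch.arange(vocab_size)).shape)  # torch.Size([6, 4])

Only trainable.weight receives gradients during training, which is exactly what rebinding self._embeddings and self._embeddings_frozen achieves in the patch.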
7 changes: 7 additions & 0 deletions kge/model/embedder/lookup_embedder.yaml
@@ -39,6 +39,13 @@ lookup_embedder:
# the packaged model
# if false initialize other embeddings normally
ensure_all: False

# Freeze a subset of the embeddings during training. Expects a file with one
# entity/relation id per line. Provide either an absolute path or just the
# filename if the file is located in the dataset folder. Embeddings associated
# with these ids are held constant during training. Leave empty to disable
# freezing.
freeze:
ids_file: ""

# Dropout used for the embeddings.
dropout: 0.
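To make the new option concrete, here is a hedged sketch of preparing such an ids file from Python. The folder, file name, and ids are made up for illustration; the job config would then point the option at the file, e.g. via a key along the lines of lookup_embedder.freeze.ids_file (the exact key depends on which embedder the option is set for).

import os

# hypothetical dataset folder and file name; adjust to your setup
dataset_folder = "data/toy"
os.makedirs(dataset_folder, exist_ok=True)
ids_file = os.path.join(dataset_folder, "freeze_entities.txt")

# one entity (or relation) id per line, matching the ids used by the dataset
with open(ids_file, "w") as f:
    f.write("\n".join(["entity_0", "entity_7", "entity_42"]))

Because only a file name is given here, the loader resolves it relative to the dataset folder; an absolute path works as well.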
44 changes: 44 additions & 0 deletions kge/model/kge_model.py
@@ -12,6 +12,7 @@
from kge.misc import filename_in_module
from kge.util import load_checkpoint
from typing import Any, Dict, List, Optional, Union, Tuple
from kge.util.io import file_to_list

from typing import TYPE_CHECKING

@@ -436,6 +437,49 @@ def load_pretrained_model(
self._relation_embedder.init_pretrained(
pretrained_relations_model.get_p_embedder()
)
# freeze embeddings if desired
for embedder, name in [
(self._relation_embedder, "relation"),
(self._entity_embedder, "entity"),
]:
freeze_file = embedder.get_option("freeze.ids_file")
if freeze_file != "":
if not os.path.isfile(freeze_file):
freeze_file = os.path.join(self.dataset.folder, freeze_file)
if not os.path.isfile(freeze_file):
raise FileNotFoundError(
f"Could not find freeze files for {name} embedder"
)
else:
ids = file_to_list(freeze_file)
id_map = self.dataset.load_map(f"{name}_ids")
freeze_indexes = [id_map.index(_id) for _id in ids]
model = self.config.get("model")
if (
model == "reciprocal_relations_model"
and name == "relation"
):
# this is the base model, where num_relations already counts
# both the forward and the inverse relations
reciprocal_indexes = [idx + self.dataset.num_relations() // 2 for idx in freeze_indexes]
freeze_indexes.extend(reciprocal_indexes)
if len(freeze_indexes) > len(set(freeze_indexes)):
raise ValueError(
f"Duplicate ids found when freezing {name} embeddings; "
"ids must be unique."
)

self.config.log(
f"Freezing {name} embeddings found in {freeze_file}"
)
embedder.freeze(freeze_indexes)

#: Scorer
self._scorer: RelationalScorer
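One detail worth spelling out is the reciprocal relations case above: the base model stores a forward and an inverse embedding per relation, so every frozen relation index is duplicated with an offset of half the (already doubled) relation count. A small worked example of that arithmetic, with illustrative numbers only:

# suppose the original dataset has 4 relations; the base model of the
# reciprocal relations model then holds 8 relation embeddings
num_relations = 8                  # count as seen by the base model
freeze_indexes = [1, 3]            # forward relations to freeze

# freezing a forward relation must also freeze its inverse counterpart
reciprocal_indexes = [idx + num_relations // 2 for idx in freeze_indexes]
freeze_indexes.extend(reciprocal_indexes)
print(freeze_indexes)              # [1, 3, 5, 7]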
7 changes: 7 additions & 0 deletions kge/util/io.py
@@ -44,3 +44,10 @@ def load_checkpoint(checkpoint_file: str, device="cpu"):
checkpoint["file"] = checkpoint_file
checkpoint["folder"] = os.path.dirname(checkpoint_file)
return checkpoint


def file_to_list(file: str):
"""Return lines of a file as list. """
with open(file, "r") as f:
data = f.read().rstrip("\n").splitlines()
return data
96 changes: 96 additions & 0 deletions tests/test_freeze.py
@@ -0,0 +1,96 @@
import unittest
import os
import torch
from tests.util import create_config, empty_cache, get_cache_dir
from kge.misc import kge_base_dir
from kge.model.kge_model import KgeModel
from kge.job import TrainingJob
from kge.dataset import Dataset


class TestFreeze(unittest.TestCase):
def setUp(self) -> None:
self.dataset_name = "toy"
self.folder = os.path.join(get_cache_dir(), "test_freeze")
self.config = create_config(self.dataset_name)
self.config.folder = self.folder
self.config.init_folder()
self.config.set("train.max_epochs", 1)
self.dataset = Dataset.create(config=self.config)

def tearDown(self) -> None:
empty_cache()

def test_freeze(self) -> None:
"""Test if frozen embeddings are correctly frozen.

Ensure, after calling freeze() of the LookupEmbedder, embeddings are hold
constant during training.

"""

model = KgeModel.create(config=self.config, dataset=self.dataset)

# freeze every other entity and relation embedding
freeze_indexes_ent = list(range(0, model.dataset.num_entities(), 2))
freeze_indexes_rel = list(range(0, model.dataset.num_relations(), 2))

entity_embedder = model.get_o_embedder()
relation_embedder = model.get_p_embedder()

# copy before freeze
frozen_emb_rel = (
relation_embedder.embed(torch.tensor(freeze_indexes_rel)).clone().detach()
)

frozen_emb_ent = (
entity_embedder.embed(torch.tensor(freeze_indexes_ent)).clone().detach()
)

# freeze
entity_embedder.freeze(freeze_indexes_ent)
relation_embedder.freeze(freeze_indexes_rel)

training_job = TrainingJob.create(
config=model.config, dataset=model.dataset, model=model
)
training_job.run()

frozen_emb_rel_after = relation_embedder.embed(torch.tensor(freeze_indexes_rel))
frozen_emb_ent_after = entity_embedder.embed(torch.tensor(freeze_indexes_ent))

# Ensure the frozen embeddings have not been changed
self.assertTrue(
torch.all(torch.eq(frozen_emb_ent, frozen_emb_ent_after)),
msg="Frozen parameter changed during training",
)

self.assertTrue(
torch.all(torch.eq(frozen_emb_rel, frozen_emb_rel_after)),
msg="Frozen parameter changed during training",
)

def test_scores_after_freeze(self) -> None:
"""Test if score calculation is correct after calling freeze() on Embeddings."""

model = KgeModel.create(config=self.config, dataset=self.dataset)

# freeze every other entity and relation embedding
freeze_indexes_ent = list(range(0, model.dataset.num_entities(), 2))
freeze_indexes_rel = list(range(0, model.dataset.num_relations(), 2))

entity_embedder = model.get_o_embedder()
relation_embedder = model.get_p_embedder()

triples = self.dataset.split("train")
scores_before = model.score_spo(triples[:, 0], triples[:, 1], triples[:, 2])

entity_embedder.freeze(freeze_indexes_ent)
relation_embedder.freeze(freeze_indexes_rel)

scores_after = model.score_spo(triples[:, 0], triples[:, 1], triples[:, 2])

self.assertTrue(
torch.all(torch.eq(scores_before, scores_after)),
msg="Model score computation has changed after calling freeze."
)
16 changes: 15 additions & 1 deletion tests/util.py
@@ -1,7 +1,8 @@
import os
from kge import Config
from kge.misc import kge_base_dir

from os import path
import shutil

def create_config(test_dataset_name: str, model: str = "complex") -> Config:
config = Config()
@@ -16,3 +17,16 @@ def create_config(test_dataset_name: str, model: str = "complex") -> Config:

def get_dataset_folder(dataset_name):
return os.path.join(kge_base_dir(), "tests", "data", dataset_name)


def get_cache_dir():
return os.path.join(kge_base_dir(), "tests", "data", "cache")


def empty_cache():
for file in os.listdir(get_cache_dir()):
obj = path.join(get_cache_dir(), file)
if os.path.isfile(obj):
os.remove(obj)
else:
shutil.rmtree(obj)