In [13]:
import srsly
import torch

from transformers import DebertaV2Config, DebertaV2ForMaskedLM, DebertaV2Tokenizer, DebertaV2ForTokenClassification

In [59]:
class ModelPaths:
    """Relative locations of the config and checkpoint files used by this notebook.

    All four paths live in the ``deberta-v3-xsmall-changed`` directory and are
    resolved relative to the kernel's working directory.
    """

    # Generator: JSON config and its PyTorch checkpoint.
    generator_config = "deberta-v3-xsmall-changed/generator_config.json"
    generator_weights = "deberta-v3-xsmall-changed/pytorch_model.generator.bin"

    # Discriminator: JSON config and its PyTorch checkpoint.
    discriminator_config = "deberta-v3-xsmall-changed/config.json"
    discriminator_weights = "deberta-v3-xsmall-changed/pytorch_model.bin"


# Shared instance handed to the initialize_* helpers defined in later cells.
mpaths = ModelPaths()

In [57]:
def initialize_generator(mpaths: ModelPaths) -> DebertaV2ForMaskedLM:
    """Create a ``DebertaV2ForMaskedLM`` and load the generator checkpoint into it.

    The model is built from the JSON config at ``mpaths.generator_config``.
    The checkpoint at ``mpaths.generator_weights`` is loaded onto the CPU,
    adapted to the model's parameter names, and applied with ``strict=False``.
    Adaptation consists of dropping three entries and renaming the four
    ``lm_predictions.lm_head.*`` transform entries to
    ``cls.predictions.transform.*``.

    Side effect: prints the result of ``load_state_dict`` (the lists of
    missing and unexpected keys) so the outcome can be inspected in the cell
    output.

    Args:
        mpaths: Object exposing ``generator_config`` and ``generator_weights``
            file paths.

    Returns:
        The model. Parameters without a matching entry in the adapted
        checkpoint keep the values they received at construction.

    Raises:
        KeyError: If the checkpoint lacks one of the entries that are dropped
            or renamed (assuming the checkpoint is a plain state dict).
    """
    config_kwargs = srsly.read_json(mpaths.generator_config)
    generator = DebertaV2ForMaskedLM(DebertaV2Config(**config_kwargs))

    # NOTE(review): depending on the torch version / `weights_only` setting,
    # torch.load unpickles arbitrary objects; only use trusted checkpoints.
    state_dict = torch.load(mpaths.generator_weights, map_location=torch.device('cpu'))

    # Entries that must not be loaded into the new model.
    dropped_keys = (
        "deberta.embeddings.word_embeddings.weight",  # because we use a different vocab
        # NOTE(review): no reason is recorded for dropping the next two entries.
        "deberta.embeddings.position_embeddings.weight",
        'lm_predictions.lm_head.bias',
    )
    for dropped in dropped_keys:
        del state_dict[dropped]

    # (checkpoint name, model name) pairs for the LM-head transform layers.
    renamed_keys = (
        ('lm_predictions.lm_head.dense.weight', 'cls.predictions.transform.dense.weight'),
        ('lm_predictions.lm_head.dense.bias', 'cls.predictions.transform.dense.bias'),
        ('lm_predictions.lm_head.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.weight'),
        ('lm_predictions.lm_head.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.bias'),
    )
    for checkpoint_name, model_name in renamed_keys:
        state_dict[model_name] = state_dict.pop(checkpoint_name)

    # strict=False tolerates the entries removed above; the report shows what was skipped.
    load_report = generator.load_state_dict(state_dict, strict=False)
    print(load_report)

    return generator

In [58]:
# Build the generator; the call prints the load_state_dict report (missing / unexpected keys).
generator = initialize_generator(mpaths)

_IncompatibleKeys(missing_keys=['deberta.embeddings.word_embeddings.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'], unexpected_keys=[])


In [None]:
def initialize_discriminator(mpaths: ModelPaths) -> DebertaV2ForTokenClassification:
    """Create a ``DebertaV2ForTokenClassification`` and load the discriminator checkpoint.

    The model is built from the JSON config at ``mpaths.discriminator_config``
    with ``num_labels`` forced to 1. The checkpoint at
    ``mpaths.discriminator_weights`` is loaded onto the CPU, its word-embedding
    matrix is removed, and the rest is applied with ``strict=False``.

    Side effect: prints the result of ``load_state_dict`` (the lists of
    missing and unexpected keys).

    Args:
        mpaths: Object exposing ``discriminator_config`` and
            ``discriminator_weights`` file paths.

    Returns:
        The model. Parameters without a matching checkpoint entry keep the
        values they received at construction.

    Raises:
        KeyError: If the checkpoint has no
            ``deberta.embeddings.word_embeddings.weight`` entry (assuming the
            checkpoint is a plain state dict).
    """
    config = DebertaV2Config(**srsly.read_json(mpaths.discriminator_config))
    # Single label, so the classification head emits one value per token.
    config.num_labels = 1
    discriminator = DebertaV2ForTokenClassification(config)

    # NOTE(review): depending on the torch version / `weights_only` setting,
    # torch.load unpickles arbitrary objects; only use trusted checkpoints.
    state_dict = torch.load(mpaths.discriminator_weights, map_location=torch.device('cpu'))

    # The word embeddings are not loaded because we use a different vocab.
    del state_dict["deberta.embeddings.word_embeddings.weight"]

    # strict=False tolerates the removed entry and any non-matching checkpoint keys.
    load_report = discriminator.load_state_dict(state_dict, strict=False)
    print(load_report)

    return discriminator