In [1]:
from datasets import load_from_disk

from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling

import torch
from torch.utils.data import DataLoader

from models import initialize_discriminator, initialize_generator, ModelPaths

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed PyTorch's global random number generator so torch-driven randomness in
# the cells below (e.g. the torch.multinomial draw in topk_sampling) is repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x7f3fc37d1150>

In [3]:
class TrainArgs(ModelPaths):
    """Settings for this notebook, added on top of whatever ``ModelPaths`` declares."""
    # Batch size passed to the DataLoader that feeds the models.
    per_device_train_batch_size: int = 1
    # Temperature the generator logits are divided by before sampling replacements.
    temperature: float = 1.0
    # NOTE(review): presumably the weight applied to the replaced-token-detection
    # loss; it is not read anywhere in the visible cells — confirm.
    rtd_lambda: float = 50.
# Single shared instance read by the cells below.
targs = TrainArgs()

In [4]:
# Load the dataset saved on disk under the relative path 'ds_subset_encoded'.
# NOTE(review): the name suggests an already-tokenized subset — confirm.
dataset = load_from_disk('ds_subset_encoded')

In [5]:
# Tokenizer is loaded from the "debertinha-v2-tokenizer" path/identifier.
tokenizer = AutoTokenizer.from_pretrained("debertinha-v2-tokenizer")
# Both models are built by project helpers from the same TrainArgs instance.
# NOTE(review): the captured output below reports missing/unexpected checkpoint
# keys for both models — verify that the intended weights were actually loaded.
discriminator = initialize_discriminator(targs)
generator = initialize_generator(targs)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


_IncompatibleKeys(missing_keys=['deberta.embeddings.word_embeddings.weight', 'classifier.weight', 'classifier.bias'], unexpected_keys=['lm_predictions.lm_head.bias', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.dense.weight', 'mask_predictions.dense.bias', 'mask_predictions.LayerNorm.weight', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.weight', 'mask_predictions.classifier.bias', 'deberta.embeddings.position_embeddings._weight', 'deberta.embeddings.position_embeddings.weight', 'deberta.embeddings.word_embeddings._weight'])
_IncompatibleKeys(missing_keys=['deberta.embeddings.word_embeddings.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'], unexpected_keys=[])


In [7]:
def _set_param(module, param_name, value):
    if hasattr(module, param_name):
      delattr(module, param_name)
    module.register_buffer(param_name, value)

def disentangled_hook(module, *inputs):
    """Forward pre-hook that rebuilds the discriminator's word-embedding table.

    On every call the discriminator's ``word_embeddings.weight`` is replaced by
    ``generator_embeddings.detach() + delta``, where ``delta`` is the
    discriminator's own tensor. The generator table is detached, so gradients
    of the discriminator do not flow back into the generator through it.

    ``delta`` is kept under ``word_embeddings._weight``. Reading it back from
    ``weight`` would be wrong after the first call: by then ``weight`` already
    holds the previous sum, so each further forward pass would add the
    generator embeddings one more time.

    Args:
        module: module the hook fired on (unused; the notebook-level
            ``generator`` and ``discriminator`` are read instead).
        *inputs: positional inputs of the forward call (unused).

    Returns:
        None, so the forward inputs are passed through unchanged.
    """
    g_w_ebd = generator.deberta.embeddings.word_embeddings
    d_w_ebd = discriminator.deberta.embeddings.word_embeddings
    if not hasattr(d_w_ebd, '_weight'):
        # First call: `weight` is still the discriminator's own tensor. Keep it
        # reachable under `_weight` before `weight` is overwritten with the sum.
        d_w_ebd._weight = d_w_ebd.weight
    _set_param(d_w_ebd, 'weight', g_w_ebd.weight.detach() + d_w_ebd._weight)

# Attach the hook as a forward pre-hook on the discriminator, so the shared
# embedding table is refreshed before each discriminator forward call.
discriminator.register_forward_pre_hook(disentangled_hook)

<torch.utils.hooks.RemovableHandle at 0x7f3f1ab25630>

In [8]:
# Collator that builds masked-language-modeling batches with the tokenizer above;
# mlm_probability=0.15 is the per-token probability of being selected for masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15
)

In [9]:
# Shuffled loader over the dataset; each batch is assembled by the MLM collator
# and its size comes from TrainArgs.per_device_train_batch_size.
train_dataloader = DataLoader(
    dataset, shuffle=True, collate_fn=data_collator, batch_size=targs.per_device_train_batch_size
)

In [10]:
# Pull one batch from a fresh iterator; the rest of the notebook walks through
# a single step on this batch.
batch = next(iter(train_dataloader))

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [11]:
def topk_sampling(logits, topk = 1, temp=1):
    """Draw token ids from the temperature-scaled softmax of ``logits``.

    Args:
        logits: scores whose last dimension is the one being normalised.
        topk: number of samples drawn per row; values below 1 are raised to 1.
        temp: temperature the logits are divided by before the softmax.

    Returns:
        A pair ``(samples, probs)``: the indices drawn by ``torch.multinomial``
        and the softmax probabilities they were drawn from.
    """
    probs = torch.softmax(logits / temp, dim=-1)
    # Always draw at least one sample per row.
    num_draws = topk if topk > 1 else 1
    samples = torch.multinomial(probs, num_draws)
    return samples, probs

In [12]:
# Run the generator on the collated batch; the batch entries are passed as
# keyword arguments.
gen_outputs = generator(**batch)

In [13]:
# Pull the tensors used by the following cells out of the batch.
# NOTE(review): with DataCollatorForLanguageModeling, `labels` presumably holds
# the original token id at masked positions and a negative ignore value
# elsewhere, and `input_ids` is the already-masked sequence — confirm.
mlm_labels = batch['labels']
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']

In [14]:
# Flatten the generator logits to 2-D (rows x last-dimension size) and sample
# one token id per row at the configured temperature. The returned
# probabilities are not needed here.
gen_logits = gen_outputs.logits
gen_logits = gen_logits.view(-1, gen_logits.size(-1))
topk_labels, _ = topk_sampling(gen_logits, topk=1, temp=targs.temperature)

In [15]:
# Build the discriminator input: keep the input tokens everywhere except at the
# positions selected for masking (label > 0), which receive the token sampled
# by the generator.
_flat_labels = mlm_labels.view(-1)
mask_index = (_flat_labels>0).nonzero().view(-1)
_sampled = topk_labels.view(-1).long()
if _sampled.numel() == _flat_labels.numel():
    # One sample was drawn per position. scatter_ pairs index[i] with src[i],
    # so the samples belonging to the masked positions must be selected first;
    # otherwise masked slots would receive the tokens sampled for the first
    # len(mask_index) positions of the flattened batch.
    _sampled = _sampled[mask_index]
# Otherwise the samples are assumed to be aligned with mask_index already
# (one sample per masked position).
top_ids = torch.zeros_like(_flat_labels)
top_ids.scatter_(index=mask_index.long(), src=_sampled, dim=-1)
top_ids = top_ids.view(mlm_labels.size())
new_ids = torch.where(mlm_labels>0, top_ids, input_ids)

In [16]:
# Discriminator inputs: the token ids with generator samples substituted in,
# plus the attention mask from the original batch.
disc_batch = {
    'input_ids': new_ids,
    'attention_mask': batch['attention_mask'],
}

In [17]:
# Run the discriminator on the corrupted batch; the registered forward pre-hook
# fires before this call.
disc_outputs = discriminator(**disc_batch)

In [18]:
# Discriminator logits; the shape is displayed as a sanity check.
# NOTE(review): the saved output shows a leading dimension of 4 although
# per_device_train_batch_size is 1 — the output looks stale; re-run the
# notebook from a fresh kernel to confirm.
disc_logits = disc_outputs.logits
disc_logits.shape

torch.Size([4, 512, 1])

In [19]:
# Replaced-token-detection loss, computed only over positions where the
# attention mask is set.
mask_logits = disc_logits.view(-1)
_input_mask = attention_mask.view(-1).to(mask_logits)
input_idx = (_input_mask>0).nonzero().view(-1)
# Target is 1 where a position was selected for masking AND the token shown to
# the discriminator (new_ids) differs from the original token in mlm_labels.
# Comparing against the masked `input_ids` instead would flag almost every
# masked position as replaced, even when the generator sampled the original token.
mask_labels = ((mlm_labels>0) & (mlm_labels!=new_ids)).view(-1)
mask_labels = torch.gather(mask_labels.to(mask_logits), 0, input_idx)
mask_loss_fn = torch.nn.BCEWithLogitsLoss()
mask_logits = torch.gather(mask_logits, 0, input_idx).float()
mask_loss = mask_loss_fn(mask_logits, mask_labels)