In [21]:
import json
import logging
import os
from argparse import Namespace

import click
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import WEIGHTS_NAME

from utils.entity_vocab import MASK_TOKEN

from exp_utils import set_seed
from exp_utils.trainer import Trainer, trainer_args
from re_model import LukeForRelationClassification
from re_utils import HEAD_TOKEN, TAIL_TOKEN, convert_examples_to_features, DatasetProcessor
from transformers.tokenization_roberta import RobertaTokenizer

import numpy as np

from transformers import (
    AutoConfig,
    AutoModelForPreTraining,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from luke.model import LukeConfig
from luke.optimization import LukeAdamW
from luke.pretraining.batch_generator import LukePretrainingBatchGenerator, MultilingualBatchGenerator
from luke.pretraining.dataset import WikipediaPretrainingDataset
from luke.pretraining.model import LukePretrainingModel
from luke.utils.model_utils import ENTITY_VOCAB_FILE


In [22]:
from types import SimpleNamespace
metadata_folder = "luke_model/"

class obj(object):
    def __init__(self, d):
        for a, b in d.items():
            if isinstance(b, (list, tuple)):
               setattr(self, a, [obj(x) if isinstance(x, dict) else x for x in b])
            else:
               setattr(self, a, obj(b) if isinstance(b, dict) else b)

with open(os.path.join(metadata_folder, "metadata.json")) as f:
    model_config = obj(json.load(f)["model_config"])

print(model_config.vocab_size)

50265


In [23]:
class params:
    def __init__(self, model_config):
        self.data_dir = "data/tacred/json"
        self.do_train = "--no-train"
        self.train_batch_size = 4
        self.num_train_epochs = 5.0
        self.do_val = "--no-eval"
        self.eval_batch_size = 128
        self.seed = 42
        self.bert_model_name = "roberta-large"
        self.max_mention_length = 30
        self.local_rank = -1
        self.tokenizer =  RobertaTokenizer.from_pretrained(self.bert_model_name)
        self.model_config = model_config
        self.model_weights = {"embeddings.word_embeddings.weight":0.25, "entity_embeddings.entity_embeddings.weight":0.25}
        self.adam_b1 =  0.9
        self.adam_b2 =  0.999
        self.adam_eps =  1e-06
        self.batch_size =  2048
        self.dataset_dir =  "enwiki_20181220_dataset_500k_roberta_cand30_unk"
        self.entity_emb_size =  256
#         self.fix_bert_weights =  False
        self.fp16 =  False
        self.fp16_master_weights =  True
        self.fp16_max_loss_scale =  4
        self.fp16_min_loss_scale =  1
        self.fp16_opt_level =  "O2"
#         self.global_step =  96454
#         self.grad_avg_on_cpu =  False
        self.gradient_accumulation_steps =  64
        self.learning_rate =  1e-05
#         self.local_rank =  0
#         self.log_dir =  "log_mon/roberta_large_luke7_500k_nonoise_mlm0.15_ment0.15_emb256_b2048_lrate1e-5_warm2500_epoch20_pre20"
        self.lr_schedule =  "warmup_linear"
        self.masked_entity_prob =  0.15
        self.masked_lm_prob =  0.15
        self.max_grad_norm =  0.0
#         self.model_file =  "out_mon/roberta_large_luke7_500k_nonoise_mlm0.15_ment0.15_emb256_b2048_lrate1e-5_warm2500_epoch20_pre20/model_step0096454.bin"
#         self.num_epochs =  20
        self.optimizer_file =  "out_mon/roberta_large_luke7_500k_nonoise_mlm0.15_ment0.15_emb256_b2048_lrate1e-5_warm2500_epoch20_pre20/optimizer_step0096454.bin"
        self.output_dir =  "out_mon/roberta_large_luke7_500k_nonoise_mlm0.15_ment0.15_emb256_b2048_lrate1e-5_warm2500_epoch20_pre20"
#         self.parallel =  True
#         self.save_interval_sec =  1600
#         self.scheduler_file =  "out_mon/roberta_large_luke7_500k_nonoise_mlm0.15_ment0.15_emb256_b2048_lrate1e-5_warm2500_epoch20_pre20/scheduler_step0096454.bin"
        self.warmup_steps =  2500
        self.weight_decay =  0.01
#         self.whole_word_masking =  True
# BASED ON GUSS from Here
        self.adam_correct_bias = 0  #based on guess
        self.warmup_proportion = 0.025 #based on guess
        self.device = "cpu" #cuda - didn't work
#         self.tokenizer = {"max_len": 512, "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>", "init_inputs": []}


In [24]:
args = params(model_config)
args.tokenizer.pad_token_id
logger = logging.getLogger(__name__)


        


In [25]:
def load_and_cache_examples(args, fold="train"):

    processor = DatasetProcessor()
    if fold == "train":
        examples = processor.get_train_examples(args.data_dir)
    elif fold == "dev":
        examples = processor.get_dev_examples(args.data_dir)
    else:
        examples = processor.get_test_examples(args.data_dir)

    label_list = processor.get_label_list(args.data_dir)

    bert_model_name = args.bert_model_name

    cache_file = os.path.join(
        args.data_dir,
        "cached_" + "_".join((args.bert_model_name.split("-")[0], str(args.max_mention_length), fold)) + ".pkl",
    )
    print('cache_file', cache_file)
    
    if os.path.exists(cache_file):
        logger.info("Loading features from cached file %s", cache_file)
        tokens = [] # This is is only used for testing and is available when we are calculating first time. It is not available when we read from the cache.
        features = torch.load(cache_file)
    else:
        logger.info("Creating features from dataset file")
        features, tokens = convert_examples_to_features(examples, label_list, args.tokenizer, args.max_mention_length)

        if args.local_rank in (-1, 0):
            torch.save(features, cache_file)

    
    def collate_fn(batch):
        def create_padded_sequence(attr_name, padding_value):
            tensors = [torch.tensor(getattr(o, attr_name), dtype=torch.long) for o in batch]
            return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        return dict(
            word_ids=create_padded_sequence("word_ids", args.tokenizer.pad_token_id),
            word_attention_mask=create_padded_sequence("word_attention_mask", 0),
            word_segment_ids=create_padded_sequence("word_segment_ids", 0),
            entity_ids=create_padded_sequence("entity_ids", 0),
            entity_attention_mask=create_padded_sequence("entity_attention_mask", 0),
            entity_position_ids=create_padded_sequence("entity_position_ids", -1),
            entity_segment_ids=create_padded_sequence("entity_segment_ids", 0),
            label=torch.tensor([o.label for o in batch], dtype=torch.long),
        )

    if fold in ("dev", "test"):
        dataloader = DataLoader(features, batch_size=args.eval_batch_size, shuffle=False, collate_fn=collate_fn)
    else:
        if args.local_rank == -1:
            sampler = RandomSampler(features)
        else:
            sampler = DistributedSampler(features)
        dataloader = DataLoader(features, sampler=sampler, batch_size=args.train_batch_size, collate_fn=collate_fn)

    return dataloader, examples, features, tokens, label_list

In [26]:
# args.model_config.vocab_size += 2
# word_emb = args.model_weights["embeddings.word_embeddings.weight"]
# head_emb = word_emb[args.tokenizer.convert_tokens_to_ids(["@"])[0]].unsqueeze(0)
# tail_emb = word_emb[args.tokenizer.convert_tokens_to_ids(["#"])[0]].unsqueeze(0)
# args.model_weights["embeddings.word_embeddings.weight"] = torch.cat([word_emb, head_emb, tail_emb])
# args.tokenizer.add_special_tokens(dict(additional_special_tokens=[HEAD_TOKEN, TAIL_TOKEN]))

# entity_emb = args.model_weights["entity_embeddings.entity_embeddings.weight"]
# mask_emb = entity_emb[args.entity_vocab[MASK_TOKEN]].unsqueeze(0).expand(2, -1)
# args.model_config.entity_vocab_size = 3
# args.model_weights["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb[:1], mask_emb])


In [27]:
dataloader, examples, features, tokens, label_list = load_and_cache_examples(args)

  0%|                                                                              | 75/68124 [00:00<01:31, 744.56it/s]

cache_file data/tacred/json\cached_roberta_30_train.pkl


100%|██████████████████████████████████████████████████████████████████████████| 68124/68124 [00:27<00:00, 2459.50it/s]


In [28]:
type(dataloader), type(examples), type(features), type(label_list)

(torch.utils.data.dataloader.DataLoader, list, list, list)

In [29]:
np.array(examples).shape, np.array(features).shape, np.array(label_list).shape, 

((68124,), (68124,), (42,))

In [30]:
print(label_list[0:20])

['no_relation', 'org:alternate_names', 'org:city_of_headquarters', 'org:country_of_headquarters', 'org:dissolved', 'org:founded', 'org:founded_by', 'org:member_of', 'org:members', 'org:number_of_employees/members', 'org:parents', 'org:political/religious_affiliation', 'org:shareholders', 'org:stateorprovince_of_headquarters', 'org:subsidiaries', 'org:top_members/employees', 'org:website', 'per:age', 'per:alternate_names', 'per:cause_of_death']


In [31]:
from re_utils import InputExample
for a in examples[0:5]:
    print ( a.span_a, a.span_b, a.type_a, a.type_b, a.label)

[54, 77] [0, 12] ORGANIZATION PERSON org:founded_by
[35, 44] [95, 103] PERSON PERSON no_relation
[137, 141] [36, 49] ORGANIZATION ORGANIZATION no_relation
[306, 313] [155, 167] ORGANIZATION NUMBER no_relation
[128, 159] [72, 78] ORGANIZATION DATE no_relation


In [35]:
for b in tokens[100:150]:
    print (b)

ĠRes
istant
Ġ:
Ġ100
m
/
330
ft
ĠCrystal
Ġ:
ĠSc
ratch
ĠRes
istant
ĠSapphire
ĠOur
ĠPrice
Ġ:
Ġhttp
://
www
.
wh
oles
ale
-
w
atches
.
org
/
w
rist
watch
-
251
.
html
</s>


In [39]:
for b in features[0:10]:
    print ( b.word_ids,
         b.word_segment_ids,
         b.word_attention_mask,
         b.entity_ids,
         b.entity_position_ids,
         b.entity_segment_ids,
         b.entity_attention_mask,
         b.label)

[0, 3, 1560, 2032, 873, 1728, 1437, 3, 6490, 11, 779, 94, 76, 7, 1026, 5, 1437, 3, 404, 7093, 6157, 139, 9127, 1437, 3, 111, 574, 26770, 12, 3943, 111, 25733, 387, 12, 2156, 6724, 5, 1929, 19, 601, 453, 9, 3589, 2156, 3735, 6100, 20303, 1745, 40702, 324, 6395, 7, 30887, 3589, 8, 486, 5, 6788, 729, 479, 2] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [1, 2] [[17, 18, 19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]] [0, 0] [1, 1] 6
[0, 96, 13668, 2156, 10, 76, 71, 5, 2669, 2156, 1437, 3, 18095, 

In [13]:
args.tokenizer.convert_tokens_to_ids(["@"])[0]

1039

In [14]:
num_labels = len(label_list)
# model_weights= [0.34,0.33,0.33]

In [15]:
# from luke.pretraining.model import LukePretrainingModel
# path = "saved_model/luke.bin"
# model = LukePretrainingModel.from_pretrained(path, cache_dir=None)

In [16]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/pretrain_model')

In [17]:
# modelA = TheModelAClass(*args, **kwargs)
model = torch.load("saved_model/luke.bin",map_location=torch.device('cpu'))
type(model)
# model = torch.load("C:/prabhu/edu/code/w266/Luke/model/luke_20200528.tar")
# writer.add_graph(model)
# writer.close()
i=0
for key, value in model.items():
    i+=1
    print (i, "\n", key, " ", torch.tensor(value),  "\n")
    if i > 10:
        break

1 
 encoder.layer.0.attention.self.query.weight   tensor([[-0.0029,  0.0352,  0.0007,  ...,  0.0023,  0.0595, -0.0426],
        [-0.0248,  0.0529, -0.0145,  ..., -0.0303, -0.0143,  0.0116],
        [ 0.0061,  0.0708, -0.0336,  ...,  0.0807,  0.0115, -0.0131],
        ...,
        [-0.0589,  0.0206, -0.0426,  ..., -0.0298,  0.0041,  0.0700],
        [ 0.0421,  0.0225, -0.0608,  ..., -0.0552, -0.0157,  0.0173],
        [-0.0184, -0.0457, -0.0103,  ...,  0.0474,  0.0225, -0.0182]]) 

2 
 encoder.layer.0.attention.self.query.bias   tensor([ 0.3121,  0.0556, -0.0751,  ..., -0.0704, -0.0500, -0.0664]) 

3 
 encoder.layer.0.attention.self.key.weight   tensor([[-0.0043, -0.0184, -0.0136,  ..., -0.0037,  0.0096, -0.0156],
        [-0.0238, -0.0002,  0.0253,  ...,  0.0403,  0.0436, -0.0195],
        [-0.0264, -0.0522, -0.0125,  ..., -0.0359,  0.0077,  0.0150],
        ...,
        [-0.0718, -0.0261, -0.0203,  ..., -0.0186,  0.0097,  0.1023],
        [ 0.0157,  0.0065, -0.0171,  ..., -0.0038, -0.

  print (i, "\n", key, " ", torch.tensor(value),  "\n")


## Below code works for loading the model from saved model. 

In [18]:
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from luke.pretraining.model import EntityPredictionHeadTransform, LukePretrainingModel,EntityPredictionHead

In [19]:
path = "saved_model/luke.bin"


bert_model_name = "roberta-large"


bert_config = AutoConfig.from_pretrained(bert_model_name)

print("bert configuration", bert_config.to_dict())

config = LukeConfig(
    entity_vocab_size=500000,
    bert_model_name=bert_model_name,
    entity_emb_size=256,
    **bert_config.to_dict(),
)
model = LukePretrainingModel(config)

bert configuration {'return_dict': False, 'output_hidden_states': False, 'output_attentions': False, 'use_cache': True, 'torchscript': False, 'use_bfloat16': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'is_encoder_decoder': False, 'is_decoder': False, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'chunk_size_feed_forward': 0, 'architectures': ['RobertaForMaskedLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 0, 'pad_token_id': 1, 'eos_token_id': 2, 'decoder_start_token_id': None, 'task_specific_params': None, 'xla_device': None, 'model_type': 'roberta', 'vocab_size': 50265, 

In [20]:
from numba import cuda
from torch import nn

# model.to(device='cpu')
# model = nn.DataParallel(model)
model.load_state_dict(torch.load(path), strict=False)
# model.load_state_dict(torch.load(path), device='cpu')
# writer.add_graph(model)
model.eval()


LukePretrainingModel(
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=1024, out_features=4096, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), ep

## Trying to use trainer class -- didn't work

In [21]:
# gradient_accumulation_steps = 64
# num_train_steps_per_epoch = len(dataloader) // gradient_accumulation_steps
# num_train_steps = int(num_train_steps_per_epoch * args.num_train_epochs)

# print('train steps per epoch: ', num_train_steps_per_epoch ,  '  train steps: ', num_train_steps  )

# best_dev_f1 = [-1]
# best_weights = [None]

# def step_callback(model, global_step):
#     print('global  step: ', global_step )
#     if global_step % num_train_steps_per_epoch == 0 and args.local_rank in (0, -1):
#         epoch = int(global_step / num_train_steps_per_epoch - 1)
#         dev_results = evaluate(args, model, fold="dev")
#         args.experiment.log_metrics({f"dev_{k}_epoch{epoch}": v for k, v in dev_results.items()}, epoch=epoch)
#         results.update({f"dev_{k}_epoch{epoch}": v for k, v in dev_results.items()})
#         tqdm.write("dev: " + str(dev_results))

#         if dev_results["f1"] > best_dev_f1[0]:
#             if hasattr(model, "module"):
#                 best_weights[0] = {k: v.to("cpu").clone() for k, v in model.module.state_dict().items()}
#             else:
#                 best_weights[0] = {k: v.to("cpu").clone() for k, v in model.state_dict().items()}
#             best_dev_f1[0] = dev_results["f1"]
#             results["best_epoch"] = epoch

#         model.train()

# trainer = Trainer(
#     args, model=model, dataloader=dataloader, num_train_steps=num_train_steps, step_callback=step_callback
# )
# trainer.train()


## Exploring more on Model parameters

### Create optimizer.  Seperating LayerNorm weight and bias from others.

In [22]:
from transformers import WEIGHTS_NAME, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.weight"]
optimizer_parameters = [
    {
        "params": [p for n, p in param_optimizer if p.requires_grad and not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in param_optimizer if p.requires_grad and any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
print( np.array(optimizer_parameters).shape)
optimizer= AdamW(
            optimizer_parameters,
            lr= args.learning_rate,
            eps= args.adam_eps,
            betas=( args.adam_b1, args.adam_b2),
            correct_bias=args.adam_correct_bias,
        )

(2,)


In [23]:
optimizer

AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: 0
    eps: 1e-06
    lr: 1e-05
    weight_decay: 0.01

Parameter Group 1
    betas: (0.9, 0.999)
    correct_bias: 0
    eps: 1e-06
    lr: 1e-05
    weight_decay: 0.0
)

### Create Scheduler

In [24]:
num_train_steps = 10 # temporarilty assigned some value, need to find right value
warmup_steps = int(num_train_steps * args.warmup_proportion)
if args.lr_schedule == "warmup_linear":
    scheduler= get_linear_schedule_with_warmup(optimizer, warmup_steps, num_train_steps)
if args.lr_schedule == "warmup_constant":
    scheduler= get_constant_schedule_with_warmup(optimizer, warmup_steps)
else:
    RuntimeError("Unsupported scheduler: " + args.lr_schedule)

In [25]:
scheduler

<torch.optim.lr_scheduler.LambdaLR at 0x24390ab6610>

### Parallel data loader - Ignore for now

In [26]:
# model = torch.nn.parallel.DistributedDataParallel(
#     model,
#     device_ids=[args.local_rank],
#     output_device=args.local_rank,
#     find_unused_parameters=True,
# )

In [27]:
def _create_model_arguments( batch):
    return batch

In [28]:
for step, batch in enumerate(dataloader):
    print(step, batch)
    if step > 4:
        break

0 {'word_ids': tensor([[    0,     3,  5991,   344,   324,  1437,     3,     8,  1437,     3,
         43154,  8316,  3181,   493,  1437,     3,  1008,    11,     5,   200,
           457,   296,     7,   483,   436,     7,    10,   132,    12,   288,
           339,    81,   188,  3324,  2156,  2749,    62,    10,   297,  6156,
          8430,   136,  8683,    15,   395,   479,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1],
        [    0,    20,  5226,    34,    57,     5,   652,     9,  1437,     3,
          3943,  1437,     3,   128,    29,  1953,    31,  1437,     3,   666,
          1437,     3,     8,  1752,  2156,  8165,  4595,    15,     5,  1546,
           128,    29,  9852,  2308,  1230,   340,   479,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,

## Start Training

In [29]:
epoch = 0
global_step = 0
tr_loss = 0.0

num_workers = torch.cuda.device_count()

def maybe_no_sync(step):
    if (
        hasattr(model, "no_sync")
        and num_workers > 1
        and (step + 1) % self.args.gradient_accumulation_steps != 0
    ):
        return model.no_sync()
    else:
        return contextlib.ExitStack()
counter = 0;
model.train()

counter = 0

with tqdm(total=num_train_steps, disable=args.local_rank not in (-1, 0)) as pbar:
    while True:
        for step, batch in enumerate(dataloader):
            inputs = {k: v.to(args.device) for k, v in _create_model_arguments(batch).items()}
            outputs = model(**inputs)
            print('inputs:', inputs)
            print('outputs', outputs)
            counter += 1
            if counter > 0:
                break
        break
#             loss = outputs[0]
#             if args.gradient_accumulation_steps > 1:
#                 loss = loss / args.gradient_accumulation_steps

#             with maybe_no_sync(step):
#                 if args.fp16:
#                     with amp.scale_loss(loss, optimizer) as scaled_loss:
#                         scaled_loss.backward()
#                 else:
#                     loss.backward()

#             tr_loss += loss.item()
#             if (step + 1) % args.gradient_accumulation_steps == 0:
#                 if args.max_grad_norm != 0.0:
#                     if args.fp16:
#                         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
#                     else:
#                         torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

#                 optimizer.step()
#                 scheduler.step()
#                 model.zero_grad()

#                 pbar.set_description("epoch: %d loss: %.7f" % (epoch, loss.item()))
#                 pbar.update()
#                 global_step += 1

#                 if step_callback is not None:
#                     step_callback(model, global_step)

#                 if (
#                     args.local_rank in (-1, 0)
#                     and args.output_dir
#                     and args.save_steps > 0
#                     and global_step % args.save_steps == 0
#                 ):
#                     output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))

#                     if hasattr(model, "module"):
#                         torch.save(model.module.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
#                     else:
#                         torch.save(model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))

#                 if global_step == num_train_steps:
#                     break
# USED FOR TESTING
#         if global_step == num_train_steps:
#             break
#         epoch += 1

# logger.info("global_step = %s, average loss = %s", global_step, tr_loss / global_step)


  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

inputs: {'word_ids': tensor([[    0,   382,  2096,  1863,     9,   331,  1437,     3, 10730, 11247,
          1437,     3,    21, 14589,    15,     5,  2570,    30,    10, 27622,
          8090,   259,    15,  1437,     3,   294,  1437,     3,  2156,    10,
           183,   137,     5,   563,    21,  2633,     7,  2604,  1863,  1292,
          5981, 11488,    12, 16956,  2156,     5,   331,   641,    26,   479,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1],
        [    0, 34628,  1313,  5457,   270,  5141,  2235,  2156,  3287,   270,
          4282,  1284,  2156,  1437,     3,  1863,     9,  4545,  1437,     3,
          2156,  1437,     3,   610,  9153,  1437,     3,  2156,  4753,     4,
           479,     2,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1, 




example1 =  {'word_ids': tensor([[    0,     3,   229,  8629,  1222,  1437,     3,  2156,    10,  1437,
             3,  4423,  1437,     3,  6239,     8,  1030,  3313,     9,  1600,
          2156,    21,    11,     5,   609,     9,  1959,    10,  1859,  6239,
            77,    37,    21,  1128,    11,  1752,    11,   628,  5155,   479,
             2,     1,     1,     1,     1,     1,     1],
        [    0,   286,  1246,  2156,    13,   107,    89,   128,    29,    57,
             5, 38162,   470,   788,   268,  1544,   111,   574, 26770,    12,
          1437,     3,    83,  3813,  1437,     3,   111, 25733,   387,    12,
          4248,     5,    55,  1611, 20391,   730,   128,    29,  1437,     3,
          2573,   788,   268,  1437,     3,   479,     2],
        [    0,    96,  1285,     7,  1437,     3,    69,  1437,     3,   979,
          2156,  1437,     3,    79,  1437,     3,    16,  5601,    30,    10,
         21002,     8,    10,   372,    12, 11377, 26243,   479,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1],
        [    0,   370,    32,  9180,    14,     5,  1437,     3,  1148,     9,
           391,  1704,  4466,  1890,  2485,  1437,     3,    16,  1826,    10,
           529,     9,    63,  1505,  1674,   111,   574, 26770,    12,  1437,
             3,  9841,  1437,     3,   111, 25733,   387,    12,    31,   601,
            12,   844,   772,  3010,   479,     2,     1]]), 
            'word_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]), 
             'word_segment_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 
             'entity_ids': tensor([[1, 2],
        [1, 2],
        [1, 2],
        [1, 2]]), 
             'entity_attention_mask': tensor([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1]]), 
             'entity_position_ids': tensor([[[ 1,  2,  3,  4,  5,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],

        [[21, 22, 23, 24, 25, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [39, 40, 41, 42, 43, 44, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],

        [[12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [ 5,  6,  7,  8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],

        [[30, 31, 32, 33, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [ 7,  8,  9, 10, 11, 12, 13, 14, 15, 16, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]]), 'entity_segment_ids': tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0]]), 
             'label': tensor([31,  0,  0,  0])}
