In [1]:
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""

'\nFine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\nGPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\nusing a masked language modeling (MLM) loss.\n'

importing necessary libraries

In [2]:
import argparse
import glob
import logging
import re
import shutil
from typing import Dict, List, Tuple

In [3]:
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing

In [4]:
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
)

In [5]:
import time
import sys
import os
import wandb
import random

In [6]:
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

Defining a class Args with various attributes initialized with default values.

In [12]:
class Args:

    def __init__(self):
        self.output_root = 'OUTPUTS'
        self.model_type = 'roberta'
        #self.early_stop = early_stop
        self.line_by_line = True
        #self.should_continue = should_continue
        self.model_name_or_path = ''
        self.mlm = True
        self.mlm_probability = 0.2
        self.tokenizer_name = 'roberta-base'
        self.vocab_size = '50_000'
        self.cache_dir = 'cache'
        self.block_size = -1
        self.do_train = True
        self.do_eval = True
        self.evaluate_during_training = True
        self.per_gpu_train_batch_size = 2#wandb.config.batch_size
        self.per_gpu_eval_batch_size = 2#wandb.config.batch_size
        self.gradient_accumulation_steps = 4#wandb.config.gradient_accumulation_steps
        self.learning_rate = 5e-5 #wandb.config.learning_rate
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.num_train_epochs = 1 #wandb.config.epochs
        self.max_steps = -1
        self.warmup_steps = 0
        self.logging_steps = 500#logging_steps
        self.save_steps = 500#logging_steps
        self.save_total_limit = None
        self.eval_all_checkpoints = False
        self.no_cuda = False
        self.overwrite_output_dir = True
        self.seed = 42
        self.fp16 = False
        self.fp16_opt_level = "O1"
        self.local_rank = -1
        self.server_ip = ""
        self.server_port = ""
        # run.save()
        # # self.name = run.name
        # if self.should_continue:
        self.output_dir = self.output_root
        # else:
        #     self.output_dir = self.output_root  + '/' + self.name


args = Args()
args.vocab_size = 50_000
args.train_data_file = '/home/user1-selab3/shradha_test/roberta/roberta/DATASET/JAVA/TOKEN/RAW/training_masked_code-few'
args.eval_data_file = '/home/user1-selab3/shradha_test/roberta/roberta/DATASET/JAVA/TOKEN/RAW/eval_masked_code-few'
args.tokenizer_name = '/home/user1-selab3/shradha_test/roberta/roberta/CODE/TrainedTokenizer'

To checks whether args.local_rank is -1 or args.no_cuda is True. If either condition is met, it sets the device to CUDA if CUDA is available and args.no_cuda is not True, otherwise it sets the device to CPU. It also sets args.n_gpu to the number of available GPUs.

In [59]:
# if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
# else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
#         torch.cuda.set_device(args.local_rank)
#         device = torch.device("cuda", args.local_rank)
#         torch.distributed.init_process_group(backend="nccl")
#         args.n_gpu = 1
args.device = device
print(args.n_gpu)
print(args.device)

2
cuda


In [60]:

logger = logging.getLogger(__name__)

configures the logging format using basicConfig, specifying the format of the log messages, including the timestamp, log level, logger name, and the actual message. 

In [61]:
# Setup logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    args.local_rank,
    device,
    args.n_gpu,
    bool(args.local_rank != -1),
    args.fp16,
)



Setting the seed for random number generation and ensuring reproducibility of results across different runs of the code. 

In [16]:
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
        
set_seed(args)

In [17]:
# Dictionary mapping model types to their corresponding classes
MODEL_CLASSES = {
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
}

# Retrieve the classes corresponding to the specified model type from MODEL_CLASSES
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

print(config_class, model_class, tokenizer_class)

<class 'transformers.models.roberta.configuration_roberta.RobertaConfig'> <class 'transformers.models.roberta.modeling_roberta.RobertaForMaskedLM'> <class 'transformers.models.roberta.tokenization_roberta.RobertaTokenizer'>


 Generating a configuration dictionary for a RoBERTa model and creating a RobertaConfig object using the provided arguments. Then, initializing a tokenizer using the specified tokenizer name and adjusting the block size based on the tokenizer's maximum model length.

In [20]:
# Define configuration parameters for RoBERTa model
def get_config(args):
    config = {
        "model_type": "roberta",
        "attention_probs_dropout_prob": 0.1,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.3,
        "hidden_size": 768, #wandb.config.hidden_size,
        "initializer_range": 0.02,
        "num_attention_heads": 16, #wandb.config.num_attention_heads,
        "num_hidden_layers": 12, #wandb.config.num_hidden_layers,
        "vocab_size": 1_130, #args.vocab_size,
        "intermediate_size": 4_096, #wandb.config.intermediate_size,
        "max_position_embeddings": 1024,
        "cache_dir": '' #args.cache_dir
    }
# Return a RobertaConfig object initialized with the config parameters
    return RobertaConfig(**config)
# Get the configuration
config = get_config(args)

# Initialize tokenizer
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

# Adjust block size based on tokenizer's maximum model length
if args.block_size <= 0:
    args.block_size = tokenizer.model_max_length
    # Our input block size will be the max possible for the model
else:
    args.block_size = min(args.block_size, tokenizer.model_max_length)

print(tokenizer)
print(tokenizer.model_max_length)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'RobertaTokenizer'.


RobertaTokenizer(name_or_path='/home/user1-selab3/shradha_test/roberta/roberta/CODE/TrainedTokenizer', vocab_size=261, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	4: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),
}
100000000000000001988462483

In [21]:
len(tokenizer)

261

If an exception occurs during model initialization, it is caught, and an error message is logged using the logger.error() method.

In [22]:
logger.info("Training new model from scratch")
try:
    model = model_class(config=config)
except Exception as e:
    logger.error(f'{e} Configuration not correct')

04/02/2024 13:38:29 - INFO - __main__ -   Training new model from scratch


 Adds special tokens to both the tokenizer and the model if they have not already been added.

In [23]:
def add_special_tokens_(model, tokenizer):
    """ Add special tokens to the tokenizer and the model if they have not already been added. """
    orig_num_tokens = len(tokenizer.encoder)
    # Add special tokens to the tokenizer and get the number of tokens added
    num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ['<z>', '<x>']}) # doesn't add if they are already there
    # If tokens were added
    if num_added_tokens > 0:
        # Resize the token embeddings in the model to accommodate the newly added tokens
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)

# Add special tokens to the model and tokenizer
add_special_tokens_(model, tokenizer)
# Move the model to the specified device
model.to(args.device)

RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(263, 768)
      (position_embeddings): Embedding(1024, 768, padding_idx=1)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.3, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,)

To process a text file line by line using a Byte Level Byte Pair Encoding (ByteLevelBPETokenizer) tokenizer

In [24]:
class LineByLineDatasetWithBPETokenizer(Dataset):
    def __init__(self, file_path: str = None, tokenizer_path: str = None):
        # Initialize the ByteLevelBPETokenizer using the provided tokenizer path
        tokenizer = ByteLevelBPETokenizer(
            tokenizer_path + "/vocab.json",
            tokenizer_path + "/merges.txt",
        )
        # Set up special tokens for the tokenizer
        tokenizer._tokenizer.post_processor = BertProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )
        # Enable truncation of sequences to a maximum length of 512 tokens
        tokenizer.enable_truncation(max_length=512)

        # Initialize an empty list to store tokenized examples
        self.examples = []

        # Read lines from the input file
        with open(file_path, encoding="utf-8") as f:
            lines = f.readlines()
            # Filter out empty lines and lines with only whitespace characters
            lines = [line for line in lines if (len(line) > 0 and not line.isspace())]
            # Tokenize each non-empty line and add the token IDs to self.examples
            self.examples += [x.ids for x in tokenizer.encode_batch(lines)]

    def __len__(self):
        # Return the total number of examples in the dataset
        return len(self.examples)

    def __getitem__(self, i):
        # Return a single example from the dataset at the given index i
        # Represented as a PyTorch tensor containing token IDs
        # We’ll pad at the batch level.
        return torch.tensor(self.examples[i])

 loading and caching examples from the training or evaluation data files using the LineByLineDatasetWithBPETokenizer

In [25]:
def load_and_cache_examples(args, evaluate=False):
    # Determine the file path based on whether it's for evaluation or training
    file_path = args.eval_data_file if evaluate else args.train_data_file
    print(file_path)
    print(args.tokenizer_name)
    
    # Return a LineByLineDatasetWithBPETokenizer object initialized with the file path and tokenizer name
    return LineByLineDatasetWithBPETokenizer(file_path, args.tokenizer_name)


In [26]:
train_dataset = load_and_cache_examples(args, evaluate=False)

/home/user1-selab3/shradha_test/roberta/roberta/DATASET/JAVA/TOKEN/RAW/training_masked_code-few
/home/user1-selab3/shradha_test/roberta/roberta/CODE/TrainedTokenizer


to retrieve and sort checkpoint paths based on certain criteria

In [27]:
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    # Initialize an empty list to store checkpoint paths and their associated information
    ordering_and_checkpoint_path = []

    # Find all files matching the checkpoint pattern in the output directory
    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    # Process each found checkpoint path
    for path in glob_checkpoints:
        if use_mtime:
            # If use_mtime is True, append a tuple containing the modification time and path
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            # If use_mtime is False, extract the checkpoint number and append a tuple containing it and path
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    # Sort the list of checkpoint paths based on the criteria specified
    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    
    # Extract only the paths from the sorted list of tuples
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    
    # Return the sorted list of checkpoint paths
    return checkpoints_sorted

for managing checkpoints by deleting older checkpoints when the number of checkpoints exceeds a specified limit (args.save_total_limit)

In [28]:
import shutil

def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    # If save_total_limit is not set or is 0 or less, there's no need to rotate checkpoints
    if not args.save_total_limit or args.save_total_limit <= 0:
        return

    # Get a sorted list of checkpoint paths
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    
    # If the number of checkpoints is within the limit, there's no need to delete any checkpoints
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    # Calculate the number of checkpoints to delete
    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    # Select the oldest checkpoints to delete
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    
    # Delete the selected checkpoints
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)

In [29]:
def decision(probability): #this function takes probability as input and returns true with probability
    #It essentially simulates a random decision based on the provided probability. 
    return random.random() < probability

read_masked_dataset function processes a batch of inputs and their corresponding labels to prepare them for training. It adjusts the lengths of inputs and labels to match the maximum length among them, and pads them accordingly.

In [50]:
def read_masked_dataset(tokenizer: PreTrainedTokenizer, batch, labels_to_process) -> Tuple[torch.Tensor, torch.Tensor]:
    # The inputs are already masked in the training file
    # Clone the batch of inputs to prevent modifying the original data
    tmp_inputs = batch.clone()

    # Process inputs
    tmp_inputs_list = []
    for input in tmp_inputs:
        # Decode input, then re-encode it without special tokens
        decoded_input = tokenizer.decode(input)
        encoded_back = tokenizer.encode(decoded_input)[1:-1] # Removes the additional <s> and </s> added
        tmp_inputs_list.append(encoded_back)

    # Gets the maximum length between inputs and labels_lines
    # We then need to adapt one or the other to have the same length through padding
        
    # # Calculate maximum length among inputs and labels
    max_length_inputs = max([len(input) for input in tmp_inputs_list])
    max_length_labels_lines = max([len(label) for label in labels_to_process])
    max_length = max_length_inputs
    if max_length_labels_lines > max_length_inputs:
        max_length = max_length_labels_lines
    print(max_length)
    print("%%")
    print(tmp_inputs_list)

    # Create the labels tensor
    labels_to_convert_in_tensor = []

    #It encodes each label using the tokenizer, excluding special tokens, and pads the encoded label to match the maximum length.
    i = 0
    while i < len(batch):
        l1_tmp = tokenizer.encode(labels_to_process[i])
        label_to_add = []
        for token in l1_tmp:
            if token != tokenizer.bos_token_id and token != tokenizer.eos_token_id:  # Remove special tokens
                label_to_add.append(token)

        j = len(label_to_add)
        while j < max_length:
            label_to_add.append(-100)  # we only compute loss for masked tokens
            j += 1

        labels_to_convert_in_tensor.append(label_to_add)
        i += 1

    labels = torch.as_tensor(labels_to_convert_in_tensor)

    # Process inputs
    inputs_to_convert = []
    for input in tmp_inputs_list:
        tmp_input = []
        for token in input:
            tmp_input.append(token)
 #It pads each input to match the maximum length, using the tokenizer's pad token ID.
        i = len(tmp_input)
        while i < max_length:
            tmp_input.append(tokenizer.pad_token_id)
            i += 1
        inputs_to_convert.append(tmp_input)

    inputs = torch.as_tensor(inputs_to_convert)

    return inputs, labels

reads masked instances but provides them as non-masked to the model. It essentially removes the masking and prepares the inputs for the model, where the model is expected to understand that if no masked tokens are present, nothing must be produced except for the <z> token.

In [51]:
# Reads the masked instances, but provide them as non-masked to the model
# This is used in 10% of cases during training
def get_non_masked_instances(tokenizer: PreTrainedTokenizer, batch, labels_to_process) -> Tuple[torch.Tensor, torch.Tensor]:
    tmp_inputs_list = []
    i = 0
    while i < len(batch):
        # Decode the input, replace <x> tokens with labels, then re-encode the modified input
        decoded_input = tokenizer.decode(batch[i]).replace('<x>',str(labels_to_process[i]).replace('<z>\n',''))
        encoded_back = tokenizer.encode(decoded_input)[1:-1]
        tmp_inputs_list.append(encoded_back)
        i += 1

    # Gets the maximum length
    max_length = max([len(input) for input in tmp_inputs_list])

    inputs_to_convert = []
    labels_to_convert = []
    for input in tmp_inputs_list:
        tmp_input = []
        tmp_label = []
        tmp_label.append(tokenizer.convert_tokens_to_ids('<z>'))
        for token in input:
            tmp_input.append(token)
            tmp_label.append(-100)

        del tmp_label[-1] #Accounts for the fact that tmp_label already contains <z>

        i = len(tmp_input)
        while i < max_length:
            tmp_input.append(tokenizer.pad_token_id)
            tmp_label.append(-100)
            i += 1
        labels_to_convert.append(tmp_label)
        inputs_to_convert.append(tmp_input)

    inputs = torch.as_tensor(inputs_to_convert)
    labels = torch.as_tensor(labels_to_convert)

    # We train the model to understand that if no masked tokens are present, nothing must be produced, only <z>
    return inputs, labels

read_perfect_predictions_from_file function reads the number of perfect predictions from a file. It searches for a specific pattern perfect_predictions = <number> in the content of the file and extracts the number following that pattern.

In [52]:
def read_perfect_predictions_from_file(file_path):
    # Open the file and read its content
    with open(file_path) as f:
        content = f.read()
    
    # Compile a regular expression pattern to search for perfect predictions
    p = re.compile('perfect_predictions = (.*?)\n')
    
    # Search for the pattern in the content and extract the number
    perfect_predictions = int(p.search(content).group(1))
    
    return perfect_predictions


This get_number_perfect_predictions function computes the number of perfect predictions made by a model on a given evaluation dataset. It compares the model's predictions with the ground truth labels and counts the instances where the prediction matches the label perfectly.

In [53]:
def get_number_perfect_predictions(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, eval_data_file):
    labels_file = str(eval_data_file).replace('masked_code_', 'mask_')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Inputs
    with open(eval_data_file) as f:
        inputs = f.readlines()
    inputs = [x.strip() for x in inputs]

    # Targets
    with open(labels_file) as f:
        targets = f.readlines()
    targets = [x.strip() for x in targets]

    n_perfect_predictions = 0
    i = 0
    while i < len(inputs):
        input = inputs[i]
        target = "".join(targets[i].split()).replace('<z>', '')

        indexed_tokens = tokenizer.encode(input)
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens_tensor = tokens_tensor.to(device)
        with torch.no_grad():
            outputs = model(tokens_tensor)
            predictions = outputs[0]

        predicted_sentence = []
        for token in torch.argmax(predictions[0], 1).cpu().numpy():
            if token != tokenizer.convert_tokens_to_ids('<z>'):
                predicted_sentence.append(token)
            else:
                break

        prediction = tokenizer.decode(predicted_sentence)
        prediction = "".join(prediction.split())
        if target == prediction:
            n_perfect_predictions += 1
        i += 1

    return n_perfect_predictions, len(inputs)

to evaluate the performance of a trained model on an evaluation dataset. 

In [54]:
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    labels_file = str(args.eval_data_file).replace('masked_code_', 'mask_')
    labels_lines = [line.rstrip() for line in open(labels_file)]

    step = 0
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # Get the labels lines to process
        start = step * len(batch)
        end = start + len(batch) + 1
        lables_to_process = labels_lines[start:end]

        step += 1

        inputs, labels = read_masked_dataset(tokenizer, batch, lables_to_process)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))
    perfect_predictions, num_examples = get_number_perfect_predictions(model, tokenizer, args.eval_data_file)
    result = {"perplexity": perplexity, "loss": eval_loss,
              "perfect_predictions": perfect_predictions, "total_eval_examples": num_examples}

    logger.log({'val_perplexity': perplexity, 'avg_val_loss': eval_loss})
    logger.log({'perfect_predictions': perfect_predictions})
    logger.log({'perfect_predictions_percentage': perfect_predictions / num_examples})

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results_" + str(time.time()) + ".txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    if args.early_stop > 0:
        # Early stop has been required by the user, check performance
        eval_results_files = glob.glob(os.path.join(eval_output_dir, prefix,'eval_results_*.txt'))
        eval_results_files.sort(key=lambda x: os.stat(os.path.join(eval_output_dir, x)).st_mtime)
        if len(eval_results_files) > args.early_stop:
            perfect_predictions_before = read_perfect_predictions_from_file(eval_results_files[len(eval_results_files)-(args.early_stop+1)])
            if perfect_predictions <= perfect_predictions_before:
                return None

    return result

The train function is responsible for training the provided model using the provided training dataset

In [55]:
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()
        
    #sets the batch size for training.
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    #defines the data sampler for training.
    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    #Creates a DataLoader to iterate over the training dataset.
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        #calculates the total number of training steps based on the number of epochs and gradient accumulation steps.
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    # initializes the AdamW optimizer with the defined parameters and  creates a linear scheduler with warmup for adjusting the learning rate during training.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    #iterates over the training data for the specified number of epochs.
    #Within each epoch, it iterates over batches of data.
    #It handles cases where training is resumed from a checkpoint (global_step, epochs_trained, steps_trained_in_current_epoch).
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility

    labels_file = str(args.train_data_file).replace('masked_code_','mask_')
    labels_lines = [line.rstrip() for line in open(labels_file)]

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        # step is the count of the steps performed, batch contains the actual input data

        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # Get the labels lines to process
            start = step * len(batch)
            end = start + len(batch) + 1
            lables_to_process = labels_lines[start:end]

            # In 90% of cases, we used the inputs with the masked tokens
            # In 10% of cases we don't mask any token
            if decision(0.9):
                inputs, labels = read_masked_dataset(tokenizer, batch, lables_to_process)
            else:
                inputs, labels = get_non_masked_instances(tokenizer, batch, lables_to_process)

            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            print("$$")
            print(inputs.shape)
            print(labels.shape)
            print("**")
            model.train()
            outputs = model(inputs, labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        if results is None:
                            print("Stopping condition reached, no improvement in evaluation set")
                            sys.exit(0)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step

In [56]:
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

04/02/2024 13:48:33 - INFO - __main__ -   ***** Running training *****
04/02/2024 13:48:33 - INFO - __main__ -     Num examples = 16906
04/02/2024 13:48:33 - INFO - __main__ -     Num Epochs = 1
04/02/2024 13:48:33 - INFO - __main__ -     Instantaneous batch size per GPU = 2
04/02/2024 13:48:33 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 8
04/02/2024 13:48:33 - INFO - __main__ -     Gradient Accumulation steps = 4
04/02/2024 13:48:33 - INFO - __main__ -     Total optimization steps = 2113
Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

372
%%
[[0, 84, 89, 70, 80, 77, 71, 225, 90, 83, 77, 72, 225, 89, 84, 72, 69, 88, 73, 48, 83, 71, 79, 72, 83, 91, 82, 41, 92, 71, 73, 84, 88, 77, 83, 82, 87, 12, 49, 69, 82, 69, 75, 73, 72, 51, 70, 78, 73, 71, 88, 54, 73, 74, 73, 86, 73, 82, 71, 73, 225, 67, 88, 76, 77, 87, 16, 225, 55, 88, 86, 77, 82, 75, 63, 65, 225, 89, 87, 73, 86, 87, 13, 225, 88, 76, 86, 83, 91, 87, 225, 37, 89, 88, 76, 49, 77, 82, 77, 81, 89, 81, 37, 72, 81, 77, 82, 52, 73, 86, 81, 77, 87, 87, 77, 83, 82, 16, 225, 54, 73, 81, 83, 88, 73, 41, 92, 71, 73, 84, 88, 77, 83, 82, 16, 225, 54, 89, 82, 88, 77, 81, 73, 42, 69, 89, 80, 88, 225, 261, 225, 37, 86, 75, 89, 81, 73, 82, 88, 63, 65, 225, 84, 69, 86, 69, 81, 87, 225, 33, 225, 82, 73, 91, 225, 37, 86, 75, 89, 81, 73, 82, 88, 63, 22, 65, 31, 225, 84, 69, 86, 69, 81, 87, 63, 20, 65, 225, 33, 225, 82, 73, 91, 225, 37, 86, 75, 89, 81, 73, 82, 88, 12, 6, 67, 88, 76, 77, 87, 6, 16, 225, 6, 49, 69, 82, 69, 75, 73, 72, 51, 70, 78, 73, 71, 88, 54, 73, 74, 73, 86, 73, 82, 71



474
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 39, 83, 80, 80, 73, 71, 88, 77, 83, 82, 32, 37, 75, 73, 82, 88, 52, 86, 83, 78, 73, 71, 88, 45, 82, 74, 83, 34, 225, 71, 83, 80, 80, 73, 71, 88, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 12, 55, 88, 86, 77, 82, 75, 225, 74, 83, 80, 72, 73, 86, 13, 225, 95, 225, 77, 74, 225, 12, 5, 77, 82, 71, 80, 89, 72, 73, 40, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 13, 225, 261, 225, 72, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 225, 33, 225, 74, 77, 82, 72, 40, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 12, 74, 83, 80, 72, 73, 86, 13, 31, 225, 97, 225, 42, 77, 80, 73, 225, 93, 69, 86, 82, 48, 83, 71, 79, 225, 33, 225, 82, 73, 91, 225, 42, 77, 80, 73, 12, 74, 83, 80, 72, 73, 86, 225, 15, 225, 74, 77, 80, 73, 55, 73, 84, 69, 86, 69, 88, 83, 86, 225, 15, 225, 61, 37, 54, 50, 67, 48, 51, 39, 47, 13, 31, 225, 70, 83, 83, 80, 73, 69, 82, 225, 93, 69, 86, 82, 48, 83, 71, 79, 4



473
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 39, 83, 80, 80, 73, 71, 88, 77, 83, 82, 32, 37, 75, 73, 82, 88, 52, 86, 83, 78, 73, 71, 88, 45, 82, 74, 83, 34, 225, 71, 83, 80, 80, 73, 71, 88, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 12, 55, 88, 86, 77, 82, 75, 225, 74, 83, 80, 72, 73, 86, 13, 225, 95, 225, 77, 74, 225, 12, 5, 77, 82, 71, 80, 89, 72, 73, 40, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 13, 95, 225, 72, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 225, 33, 225, 74, 77, 82, 72, 40, 73, 90, 40, 73, 84, 73, 82, 72, 73, 82, 71, 77, 73, 87, 12, 74, 83, 80, 72, 73, 86, 13, 31, 225, 97, 225, 42, 77, 80, 73, 225, 93, 69, 86, 82, 48, 83, 71, 79, 225, 33, 225, 82, 73, 91, 225, 42, 77, 80, 73, 12, 74, 83, 80, 72, 73, 86, 225, 15, 225, 74, 77, 80, 73, 55, 73, 84, 69, 86, 69, 88, 83, 86, 225, 15, 225, 61, 37, 54, 50, 67, 48, 51, 39, 47, 13, 31, 225, 70, 83, 83, 80, 73, 69, 82, 225, 93, 69, 86, 82, 48, 83, 71, 79, 42, 83,



301
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 55, 88, 86, 77, 82, 75, 38, 89, 77, 80, 72, 73, 86, 225, 261, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 87, 89, 84, 73, 86, 18, 75, 73, 88, 56, 69, 87, 79, 12, 13, 13, 31, 225, 77, 74, 225, 12, 77, 87, 39, 83, 81, 84, 80, 73, 88, 73, 12, 13, 13, 225, 95, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 30, 225, 71, 83, 81, 84, 80, 73, 88, 73, 18, 6, 13, 31, 225, 97, 225, 73, 80, 87, 73, 225, 95, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 225, 7, 6, 13, 18, 69, 84, 84, 73, 82, 72, 12, 75, 73, 88, 52, 86, 83, 71, 73, 87, 87, 73, 72, 12, 13, 15, 21, 13, 18, 69, 84, 84, 73, 82, 72, 12, 11, 19, 11, 13, 18, 69, 84, 84, 73, 82, 72, 12, 75, 73, 88, 56, 83, 88, 69, 80, 12, 13, 13, 31, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 30, 225, 6, 13, 18, 69, 84, 84, 73, 82, 72, 12, 75, 73, 88, 55, 88, 73, 84, 56, 77, 88, 80, 73, 12, 13, 13, 31, 225, 97, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 



304
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 55, 88, 86, 77, 82, 75, 38, 89, 77, 80, 72, 73, 86, 225, 69, 84, 84, 73, 82, 72, 56, 83, 38, 89, 74, 74, 73, 86, 12, 55, 88, 86, 77, 82, 75, 38, 89, 77, 80, 72, 73, 86, 225, 70, 89, 74, 13, 225, 95, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 87, 89, 84, 73, 86, 18, 75, 73, 88, 56, 69, 87, 79, 12, 13, 13, 31, 225, 77, 74, 225, 12, 77, 87, 39, 83, 81, 84, 80, 73, 88, 73, 12, 13, 13, 225, 95, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 30, 225, 71, 83, 81, 84, 80, 73, 88, 73, 18, 6, 13, 31, 225, 97, 225, 73, 80, 87, 73, 225, 95, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 225, 7, 6, 13, 18, 69, 84, 84, 73, 82, 72, 12, 75, 73, 88, 52, 86, 83, 71, 73, 87, 87, 73, 72, 12, 13, 15, 21, 13, 18, 69, 84, 84, 73, 82, 72, 12, 11, 19, 11, 13, 18, 69, 84, 84, 73, 82, 72, 12, 75, 73, 88, 56, 83, 88, 69, 80, 12, 13, 13, 31, 225, 70, 89, 74, 18, 69, 84, 84, 73, 82, 72, 12, 6, 30, 225, 6, 13, 18, 69, 8



356
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 39, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 225, 71, 86, 73, 69, 88, 73, 12, 80, 83, 82, 75, 225, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 45, 72, 13, 225, 95, 225, 39, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 225, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 225, 33, 225, 82, 73, 91, 225, 39, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 45, 81, 84, 80, 12, 13, 31, 225, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 18, 87, 73, 88, 50, 73, 91, 12, 88, 86, 89, 73, 13, 31, 225, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 18, 87, 73, 88, 52, 86, 77, 81, 69, 86, 93, 47, 73, 93, 12, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 45, 72, 13, 31, 225, 55, 88, 86, 77, 82, 75, 225, 225, 261, 225, 71, 83, 81, 81, 73, 86, 71, 73, 39, 89, 86, 86, 73, 82, 71, 93, 18, 87,



360
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 57, 54, 48, 225, 71, 86, 73, 69, 88, 73, 57, 54, 48, 12, 74, 77, 82, 69, 80, 225, 55, 88, 86, 77, 82, 75, 225, 89, 86, 80, 55, 88, 86, 77, 82, 75, 13, 225, 88, 76, 86, 83, 91, 87, 225, 49, 69, 80, 74, 83, 86, 81, 73, 72, 57, 54, 48, 41, 92, 71, 73, 84, 88, 77, 83, 82, 225, 95, 225, 57, 54, 48, 225, 89, 86, 80, 31, 225, 88, 86, 93, 225, 95, 225, 89, 86, 80, 225, 33, 225, 37, 71, 71, 73, 87, 87, 39, 83, 82, 88, 86, 83, 80, 80, 73, 86, 18, 72, 83, 52, 86, 77, 90, 77, 80, 73, 75, 73, 72, 12, 82, 73, 91, 225, 52, 86, 77, 90, 77, 80, 73, 75, 73, 72, 41, 92, 71, 73, 84, 88, 77, 83, 82, 37, 71, 88, 77, 83, 82, 32, 57, 54, 48, 34, 12, 13, 225, 95, 225, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 57, 54, 48, 225, 86, 89, 82, 12, 13, 225, 88, 76, 86, 83, 91, 87, 225, 49, 69, 80, 74, 83, 86, 81, 73, 72, 57, 54, 48, 41, 92, 71, 73, 84, 88, 77, 83, 82, 225, 95, 225, 86, 73, 88, 89, 86, 82, 225, 82, 73, 91, 225, 261, 225, 97, 225,



278
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 90, 83, 77, 72, 225, 72, 73, 80, 73, 88, 73, 52, 86, 73, 90, 77, 83, 89, 87, 12, 55, 88, 86, 77, 82, 75, 225, 88, 73, 92, 88, 13, 225, 95, 225, 55, 53, 48, 77, 88, 73, 51, 84, 73, 82, 44, 73, 80, 84, 73, 86, 225, 76, 73, 80, 84, 73, 86, 225, 33, 225, 82, 73, 91, 225, 261, 225, 88, 86, 93, 225, 12, 55, 53, 48, 77, 88, 73, 40, 69, 88, 69, 70, 69, 87, 73, 225, 72, 70, 225, 33, 225, 76, 73, 80, 84, 73, 86, 18, 75, 73, 88, 59, 86, 77, 88, 69, 70, 80, 73, 40, 69, 88, 69, 70, 69, 87, 73, 12, 13, 13, 225, 95, 225, 72, 70, 18, 72, 73, 80, 73, 88, 73, 12, 40, 38, 44, 73, 80, 84, 73, 86, 18, 56, 37, 38, 48, 41, 67, 50, 37, 49, 41, 16, 225, 40, 38, 44, 73, 80, 84, 73, 86, 18, 56, 41, 60, 56, 67, 39, 51, 48, 225, 15, 225, 6, 33, 35, 6, 16, 225, 82, 73, 91, 225, 55, 88, 86, 77, 82, 75, 63, 65, 225, 95, 225, 88, 73, 92, 88, 225, 97, 13, 31, 225, 97, 225, 71, 69, 88, 71, 76, 225, 12, 55, 53, 48, 41, 92, 71, 73, 84, 88, 77, 83, 82, 225, 87, 85, 80, 73, 13, 2



210
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 90, 83, 77, 72, 225, 74, 77, 80, 88, 73, 86, 12, 55, 88, 86, 77, 82, 75, 225, 57, 54, 48, 16, 225, 70, 93, 88, 73, 63, 65, 225, 71, 83, 82, 88, 73, 82, 88, 16, 225, 40, 83, 71, 89, 81, 73, 82, 88, 42, 86, 69, 75, 81, 73, 82, 88, 225, 72, 83, 71, 16, 225, 52, 69, 86, 87, 73, 54, 73, 87, 89, 80, 88, 225, 84, 69, 86, 87, 73, 13, 225, 95, 225, 55, 88, 86, 77, 82, 75, 63, 65, 225, 88, 69, 75, 87, 225, 33, 225, 71, 83, 80, 80, 73, 71, 88, 77, 83, 82, 87, 225, 261, 225, 77, 74, 225, 12, 88, 69, 75, 87, 18, 80, 73, 82, 75, 88, 76, 225, 34, 225, 20, 13, 225, 95, 225, 84, 69, 86, 87, 73, 18, 75, 73, 88, 12, 57, 54, 48, 13, 18, 75, 73, 88, 49, 73, 88, 69, 72, 69, 88, 69, 12, 13, 18, 87, 73, 88, 58, 69, 80, 89, 73, 87, 12, 79, 73, 93, 16, 225, 88, 69, 75, 87, 13, 31, 225, 97, 225, 97, 203, 2], [0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 90, 83, 77, 72, 225, 74, 77, 80, 88, 73, 86, 12, 55



329
%%
[[0, 36, 51, 90, 73, 86, 86, 77, 72, 73, 225, 84, 89, 70, 80, 77, 71, 225, 49, 69, 84, 51, 70, 78, 73, 71, 88, 49, 69, 84, 225, 76, 69, 90, 77, 82, 75, 12, 225, 55, 88, 86, 77, 82, 75, 225, 74, 77, 73, 80, 72, 21, 16, 225, 51, 70, 78, 73, 71, 88, 225, 90, 69, 80, 89, 73, 21, 13, 225, 95, 225, 77, 74, 12, 83, 70, 78, 225, 33, 33, 225, 82, 89, 80, 80, 13, 225, 86, 73, 88, 89, 86, 82, 225, 88, 76, 77, 87, 31, 225, 83, 70, 78, 225, 33, 225, 49, 69, 84, 51, 70, 78, 73, 71, 88, 57, 88, 77, 80, 87, 18, 76, 69, 90, 77, 82, 75, 12, 12, 39, 83, 80, 80, 73, 71, 88, 77, 83, 82, 13, 83, 70, 78, 16, 225, 74, 77, 73, 80, 72, 21, 16, 225, 90, 69, 80, 89, 73, 21, 13, 31, 225, 86, 73, 88, 89, 86, 82, 225, 261, 225, 97, 203, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1



250
%%
[[0, 84, 89, 70, 80, 77, 71, 225, 49, 83, 72, 89, 80, 73, 45, 82, 74, 83, 48, 77, 87, 88, 225, 74, 77, 80, 88, 73, 86, 12, 74, 77, 82, 69, 80, 225, 49, 83, 72, 89, 80, 73, 45, 82, 74, 83, 42, 77, 80, 88, 73, 86, 225, 74, 77, 80, 88, 73, 86, 13, 225, 95, 225, 74, 77, 82, 69, 80, 225, 49, 83, 72, 89, 80, 73, 45, 82, 74, 83, 48, 77, 87, 88, 225, 81, 83, 72, 89, 80, 73, 45, 82, 74, 83, 42, 77, 80, 88, 73, 86, 73, 72, 225, 261, 225, 74, 83, 86, 225, 12, 74, 77, 82, 69, 80, 225, 49, 83, 72, 89, 80, 73, 45, 82, 74, 83, 225, 86, 73, 87, 83, 89, 86, 71, 73, 225, 30, 225, 88, 76, 77, 87, 13, 225, 95, 225, 77, 74, 225, 12, 74, 77, 80, 88, 73, 86, 18, 69, 71, 71, 73, 84, 88, 12, 86, 73, 87, 83, 89, 86, 71, 73, 13, 13, 225, 95, 225, 81, 83, 72, 89, 80, 73, 45, 82, 74, 83, 42, 77, 80, 88, 73, 86, 73, 72, 18, 69, 72, 72, 12, 86, 73, 87, 83, 89, 86, 71, 73, 13, 31, 225, 97, 225, 97, 225, 86, 73, 88, 89, 86, 82, 225, 81, 83, 72, 89, 80, 73, 45, 82, 74, 83, 42, 77, 80, 88, 73, 86, 73, 72, 31, 225



255
%%
[[0, 84, 89, 70, 80, 77, 71, 225, 56, 73, 87, 88, 55, 88, 73, 84, 225, 71, 86, 73, 69, 88, 73, 12, 13, 225, 95, 225, 56, 73, 87, 88, 55, 88, 73, 84, 225, 87, 88, 73, 84, 225, 33, 225, 82, 73, 91, 225, 56, 73, 87, 88, 55, 88, 73, 84, 12, 13, 31, 225, 87, 88, 73, 84, 18, 87, 73, 88, 39, 89, 86, 86, 73, 82, 88, 39, 69, 87, 73, 50, 83, 12, 6, 20, 20, 21, 6, 13, 31, 225, 87, 88, 73, 84, 18, 87, 73, 88, 56, 73, 87, 88, 40, 69, 88, 69, 12, 6, 20, 20, 21, 6, 225, 261, 225, 86, 73, 88, 89, 86, 82, 225, 87, 88, 73, 84, 31, 225, 97, 203, 2, 1, 1], [0, 84, 89, 70, 80, 77, 71, 225, 56, 73, 87, 88, 55, 88, 73, 84, 225, 71, 86, 73, 69, 88, 73, 12, 13, 225, 95, 225, 56, 73, 87, 88, 55, 88, 73, 84, 225, 87, 88, 73, 84, 225, 33, 225, 82, 73, 91, 225, 56, 73, 87, 88, 55, 88, 73, 84, 12, 13, 31, 225, 87, 88, 73, 84, 18, 87, 73, 88, 39, 89, 86, 86, 73, 82, 88, 39, 69, 87, 73, 50, 83, 12, 6, 20, 20, 21, 6, 13, 31, 225, 87, 88, 73, 84, 18, 87, 73, 88, 56, 73, 87, 88, 40, 69, 88, 69, 12, 6, 20, 20, 21,



271
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 55, 88, 86, 77, 82, 75, 225, 75, 73, 88, 39, 80, 69, 87, 87, 84, 69, 88, 76, 12, 13, 225, 95, 225, 77, 74, 225, 12, 71, 80, 69, 87, 87, 52, 69, 88, 76, 225, 33, 33, 225, 82, 89, 80, 80, 13, 225, 95, 225, 77, 74, 225, 12, 71, 80, 69, 87, 87, 48, 83, 69, 72, 73, 86, 225, 77, 82, 87, 88, 69, 82, 71, 73, 83, 74, 225, 57, 54, 48, 39, 80, 69, 87, 87, 48, 83, 69, 72, 73, 86, 13, 225, 95, 225, 71, 80, 69, 87, 87, 52, 69, 88, 76, 225, 33, 225, 75, 73, 88, 39, 80, 69, 87, 87, 52, 69, 88, 76, 12, 12, 57, 54, 48, 39, 80, 69, 87, 87, 48, 83, 69, 72, 73, 86, 13, 225, 71, 80, 69, 87, 87, 48, 83, 69, 72, 73, 86, 13, 31, 225, 97, 225, 73, 80, 87, 73, 225, 261, 225, 88, 76, 86, 83, 91, 225, 82, 73, 91, 225, 45, 80, 80, 73, 75, 69, 80, 37, 86, 75, 89, 81, 73, 82, 88, 41, 92, 71, 73, 84, 88, 77, 83, 82, 12, 6, 57, 82, 87, 89, 84, 84, 83, 86, 88, 73, 72, 225, 39, 80, 69, 87, 87, 48, 83, 69, 72, 73, 86, 225, 6, 225, 15, 225, 71, 80, 69, 87, 87, 48, 83, 69, 72, 73



263
%%
[[0, 84, 89, 70, 80, 77, 71, 225, 90, 83, 77, 72, 225, 87, 73, 88, 37, 88, 88, 86, 77, 70, 89, 88, 73, 12, 55, 88, 86, 77, 82, 75, 225, 82, 69, 81, 73, 16, 225, 37, 88, 88, 86, 77, 70, 89, 88, 73, 225, 90, 69, 80, 89, 73, 13, 225, 95, 225, 77, 74, 12, 90, 69, 80, 89, 73, 5, 33, 225, 82, 89, 80, 80, 13, 225, 95, 225, 88, 76, 77, 87, 18, 69, 88, 88, 86, 77, 70, 89, 88, 73, 87, 18, 84, 89, 88, 12, 82, 69, 81, 73, 16, 225, 90, 69, 80, 89, 73, 13, 31, 225, 97, 225, 73, 80, 87, 73, 225, 261, 225, 88, 76, 77, 87, 18, 69, 88, 88, 86, 77, 70, 89, 88, 73, 87, 18, 86, 73, 81, 83, 90, 73, 12, 82, 69, 81, 73, 13, 31, 225, 97, 225, 97, 203, 2], [0, 84, 89, 70, 80, 77, 71, 225, 90, 83, 77, 72, 225, 87, 73, 88, 37, 88, 88, 86, 77, 70, 89, 88, 73, 12, 55, 88, 86, 77, 82, 75, 225, 82, 69, 81, 73, 16, 225, 37, 88, 88, 86, 77, 70, 89, 88, 73, 225, 90, 69, 80, 89, 73, 13, 225, 95, 225, 77, 74, 12, 90, 69, 80, 89, 73, 5, 33, 225, 82, 89, 80, 80, 13, 225, 95, 225, 88, 76, 77, 87, 18, 69, 88, 88, 86, 7



222
%%
[[0, 84, 89, 70, 80, 77, 71, 225, 42, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 225, 225, 261, 225, 42, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 225, 74, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 225, 33, 225, 75, 73, 88, 42, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 12, 69, 87, 87, 77, 75, 82, 81, 73, 82, 88, 13, 31, 225, 37, 87, 87, 77, 75, 82, 81, 73, 82, 88, 225, 74, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 37, 87, 87, 77, 75, 82, 81, 73, 82, 88, 225, 33, 225, 90, 69, 86, 77, 69, 70, 80, 73, 87, 18, 88, 83, 37, 87, 87, 77, 75, 82, 81, 73, 82, 88, 12, 69, 87, 87, 77, 75, 82, 81, 73, 82, 88, 13, 31, 225, 86, 73, 88, 89, 86, 82, 225, 74, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 18, 71, 83, 82, 72, 77, 88, 77, 83, 82, 69, 80, 12, 74, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 37, 87, 87, 77, 75, 82, 81, 73, 82, 88, 13, 31, 225, 97, 203, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 84, 89, 70, 80, 77, 71, 225, 42, 69, 71, 88, 83, 86, 43, 86, 69, 84, 76, 225, 71, 83, 82, 72, 77, 88, 77, 83,



281
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 52, 69, 87, 87, 91, 83, 86, 72, 38, 69, 87, 73, 72, 39, 77, 84, 76, 73, 86, 42, 69, 71, 88, 83, 86, 93, 225, 75, 73, 88, 52, 38, 41, 42, 69, 71, 88, 83, 86, 93, 12, 55, 88, 86, 77, 82, 75, 225, 76, 77, 82, 88, 13, 225, 95, 225, 88, 86, 93, 225, 95, 225, 86, 73, 88, 89, 86, 82, 225, 88, 76, 77, 87, 18, 81, 69, 82, 69, 75, 73, 86, 18, 75, 73, 88, 45, 82, 87, 88, 69, 82, 71, 73, 12, 52, 69, 87, 87, 91, 83, 86, 72, 38, 69, 87, 73, 72, 39, 77, 84, 76, 73, 86, 42, 69, 71, 88, 83, 86, 93, 225, 261, 225, 97, 225, 71, 69, 88, 71, 76, 225, 12, 39, 83, 81, 84, 83, 82, 73, 82, 88, 48, 83, 83, 79, 89, 84, 41, 92, 71, 73, 84, 88, 77, 83, 82, 225, 73, 13, 225, 95, 225, 88, 76, 86, 83, 91, 225, 82, 73, 91, 225, 57, 82, 87, 89, 84, 84, 83, 86, 88, 73, 72, 51, 84, 73, 86, 69, 88, 77, 83, 82, 41, 92, 71, 73, 84, 88, 77, 83, 82, 12, 6, 52, 69, 87, 87, 91, 83, 86, 72, 225, 70, 69, 87, 73, 72, 225, 71, 77, 84, 76, 73, 86, 225, 74, 69, 71, 88, 83, 86, 93, 225, 82,



$$
torch.Size([2, 336])
torch.Size([2, 336])
**
371
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 70, 83, 83, 80, 73, 69, 82, 225, 77, 87, 39, 40, 45, 41, 82, 69, 70, 80, 73, 72, 225, 261, 225, 38, 83, 83, 80, 73, 69, 82, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 71, 72, 77, 55, 88, 69, 88, 89, 87, 49, 69, 84, 18, 75, 73, 88, 12, 70, 72, 69, 18, 75, 73, 88, 45, 72, 12, 13, 13, 31, 225, 77, 74, 225, 12, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 33, 225, 82, 89, 80, 80, 13, 225, 95, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 70, 72, 69, 18, 76, 69, 87, 38, 73, 69, 82, 87, 12, 13, 225, 96, 96, 225, 70, 72, 69, 18, 77, 87, 41, 92, 88, 73, 82, 87, 77, 83, 82, 12, 13, 31, 225, 71, 72, 77, 55, 88, 69, 88, 89, 87, 49, 69, 84, 18, 84, 89, 88, 12, 70, 72, 69, 18, 75, 73, 88, 45, 72, 12, 13, 16, 225, 76, 69, 87, 38, 73, 69, 82, 87, 13, 31, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 96, 96, 225, 77, 87, 39, 40, 45, 41, 82, 69, 70, 80, 73, 72,



374
%%
[[0, 84, 86, 77, 90, 69, 88, 73, 225, 70, 83, 83, 80, 73, 69, 82, 225, 77, 87, 39, 40, 45, 41, 82, 69, 70, 80, 73, 72, 12, 59, 73, 70, 55, 84, 76, 73, 86, 73, 38, 73, 69, 82, 40, 73, 84, 80, 83, 93, 81, 73, 82, 88, 37, 86, 71, 76, 77, 90, 73, 225, 70, 72, 69, 13, 225, 95, 225, 38, 83, 83, 80, 73, 69, 82, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 71, 72, 77, 55, 88, 69, 88, 89, 87, 49, 69, 84, 18, 75, 73, 88, 12, 70, 72, 69, 18, 75, 73, 88, 45, 72, 12, 13, 13, 31, 225, 77, 74, 225, 12, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 33, 225, 82, 89, 80, 80, 13, 225, 95, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 70, 72, 69, 18, 76, 69, 87, 38, 73, 69, 82, 87, 12, 13, 225, 96, 96, 225, 70, 72, 69, 18, 77, 87, 41, 92, 88, 73, 82, 87, 77, 83, 82, 12, 13, 31, 225, 71, 72, 77, 55, 88, 69, 88, 89, 87, 49, 69, 84, 18, 84, 89, 88, 12, 70, 72, 69, 18, 75, 73, 88, 45, 72, 12, 13, 16, 225, 76, 69, 87, 38, 73, 69, 82, 87, 13, 31, 225, 76, 69, 87, 38, 73, 69, 82, 87, 225, 33, 225, 7

Iteration:   1%|          | 53/8453 [00:04<11:23, 12.29it/s]
Epoch:   0%|          | 0/1 [00:04<?, ?it/s]


$$
torch.Size([2, 1065])
torch.Size([2, 1065])
**


RuntimeError: The expanded size of the tensor (1065) must match the existing size (1024) at non-singleton dimension 1.  Target sizes: [2, 1065].  Tensor sizes: [1, 1024]

In [None]:
# if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
#         # Create output directory if needed
#         if args.local_rank in [-1, 0]:
#             os.makedirs(args.output_dir, exist_ok=True)

#         logger.info("Saving model checkpoint to %s", args.output_dir)
#         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
#         # They can then be reloaded using `from_pretrained()`
#         model_to_save = (
#             model.module if hasattr(model, "module") else model
#         )  # Take care of distributed/parallel training
#         model_to_save.save_pretrained(args.output_dir)
#         tokenizer.save_pretrained(args.output_dir)

#         # Good practice: save your training arguments together with the trained model
#         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

#         # Load a trained model and vocabulary that you have fine-tuned
#         model = model_class.from_pretrained(args.output_dir)
#         tokenizer = tokenizer_class.from_pretrained(args.output_dir)
#         model.to(args.device)

#     # Evaluation
# results = {}
# if args.do_eval and args.local_rank in [-1, 0]:
#         checkpoints = [args.output_dir]
#         if args.eval_all_checkpoints:
#             checkpoints = list(
#                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
#             )
#             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
#         logger.info("Evaluate the following checkpoints: %s", checkpoints)
#         for checkpoint in checkpoints:
#             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
#             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

#             model = model_class.from_pretrained(checkpoint)
#             model.to(args.device)
#             result = evaluate(args, model, tokenizer, prefix=prefix)
#             if result is None:
#                 print("Stopping condition reached, no improvement in evaluation set")
#                 sys.exit(0)
#             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
#             results.update(result)

# return results