In [None]:
%%capture
# Colab setup: install the Hugging Face stack (sentencepiece is needed by the mT5 tokenizer)
# and `datasets` for loading SQuAD / TyDi QA / mC4. bitsandbytes and peft are not used in
# this section; presumably they are for training code elsewhere — TODO confirm.
# NOTE(review): versions are unpinned; pin them (and prefer %pip) for reproducible runs.
!pip install sentencepiece transformers transformers[sentencepiece] accelerate
!pip install bitsandbytes
!pip install peft
!pip install datasets

import numpy as np
import torch

In [None]:
# Use the first GPU when the runtime has one; otherwise fall back to CPU so the
# data-preparation cells below still run (a hard-coded "cuda:0" fails on CPU runtimes).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device: ", device)

# Mount Google Drive: later cells read the mC4 shard from it and write the final dataset to it.
from google.colab import drive
drive.mount('/content/drive')

device:  cuda:0
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Initializing tokenizer for 'google/mt5-small'

In [None]:
# Load the mT5 tokenizer and register the task-prefix word as an added token.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer

model_name = 'google/mt5-small'
tokenizer = AutoTokenizer.from_pretrained(model_name)

len_tokenizer = len(tokenizer)  # size before any token is added
print(f"len_tokenizer = {len_tokenizer}")

# 'questiongeneration' is the task prefix used when tokenizing SQuAD contexts;
# registering it lets the prefix map to a single id.
tokenizer.add_tokens(['questiongeneration'])

len_tokenizer = len(tokenizer)  # size after registering the prefix token
print(f"len_tokenizer = {len_tokenizer}")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


len_tokenizer = 250100
len_tokenizer = 250101


In [None]:
# tokenizer1 = T5Tokenizer.from_pretrained(model_name, extra_ids=100)# or T5Tokenizer

# len_tokenizer =len(tokenizer1) # 32100 to get the sentinel ids
# print(f"len_tokenizer={len_tokenizer}")

# print(tokenizer1.convert_tokens_to_ids('<extra_id_0>'))

len_tokenizer=250200
250100


In [None]:
# Alternative (not used): register the prefix as an *additional special token* instead.
# special_tokens_dict = {'additional_special_tokens': ['<questiongeneration:>']}
# num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

# List the tokenizer's special tokens. 'questiongeneration' was added with add_tokens
# (a regular added token), so it does not appear in this list.
print("list of special tokens: ", tokenizer.all_special_tokens)

list of special tokens:  ['</s>', '<unk>', '<pad>']


## Masked LM / denoising training

https://huggingface.co/docs/transformers/main/model_doc/t5#training

In [None]:
# Utility class for T5-style denoising (span corruption), adapted from the Hugging Face examples.
class FlaxDataCollatorForT5MLM:
  """
  Numpy helpers for T5 span corruption, adapted from
  https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py

  Args:
      tokenizer: tokenizer providing `eos_token_id` (appended by `filter_input_ids`) and,
          when `sentinel_start_id` is not given, `get_vocab()` / `vocab_size` / `__len__`.
      noise_density: approximate fraction of tokens to mask.
      mean_noise_span_length: mean length of a masked span.
      sentinel_start_id: sentinel number k (1-based, left to right) gets id
          `sentinel_start_id - k`. When None it is derived from the tokenizer
          (see `_default_sentinel_start_id`).
  """
  def __init__(self, tokenizer, noise_density, mean_noise_span_length, sentinel_start_id=None) -> None:
      self.tokenizer = tokenizer
      self.noise_density = noise_density
      self.mean_noise_span_length = mean_noise_span_length
      if sentinel_start_id is None:
          sentinel_start_id = self._default_sentinel_start_id(tokenizer)
      self.sentinel_start_id = sentinel_start_id

  @staticmethod
  def _default_sentinel_start_id(tokenizer):
      """
      Return one past the id of the first sentinel (`<extra_id_0>`).

      `len(tokenizer)` is only right while no tokens have been added: after
      `tokenizer.add_tokens([...])` it grows, and the first "sentinel" would collide with
      the newly added token. Prefer an explicit `<extra_id_0>` vocabulary entry, then the
      base vocabulary size (`vocab_size`, which excludes added tokens), then `len(tokenizer)`.
      """
      get_vocab = getattr(tokenizer, "get_vocab", None)
      if callable(get_vocab):
          first_sentinel = get_vocab().get("<extra_id_0>")
          if first_sentinel is not None:
              return first_sentinel + 1
      vocab_size = getattr(tokenizer, "vocab_size", None)
      if vocab_size is not None:
          return vocab_size
      return len(tokenizer)

  def create_sentinel_ids(self, mask_indices):
      """
      Sentinel ids creation given the indices that should be masked.

      `mask_indices` is a 2-D 0/1 integer array (batch, length). The first position of each
      masked span gets a sentinel id (sentinel k, counted from 1 left to right, gets
      `self.sentinel_start_id - k`). The other positions of the span get `-1` so that
      `filter_input_ids` deletes them. Unmasked positions get 0.
      """
      start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
      # np.roll wraps around, so reset column 0: it starts a span iff it is masked.
      start_indices[:, 0] = mask_indices[:, 0]

      sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
      sentinel_ids = np.where(sentinel_ids != 0, (self.sentinel_start_id - sentinel_ids), 0)
      sentinel_ids -= mask_indices - start_indices

      return sentinel_ids

  def filter_input_ids(self, input_ids, sentinel_ids):
      """
      Put sentinel ids on `input_ids` and fuse each masked span into a single sentinel.

      Positions marked `-1` by `create_sentinel_ids` are dropped, then EOS is appended
      to every row. All rows must drop the same number of positions (for example, they
      share one mask); otherwise the reshape fails.
      """
      batch_size = input_ids.shape[0]

      input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
      # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
      # masked tokens coming after sentinel tokens and should be removed
      input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
      input_ids = np.concatenate(
          [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
      )
      return input_ids

  def random_spans_noise_mask(self, length):
      """Copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

      Includes the correction from https://github.com/huggingface/transformers/pull/22938/files.
      Builds a noise mask made of random spans of noise tokens.
      The numbers of noise tokens, noise spans and non-noise spans are deterministic:
      num_noise_tokens = round(length * noise_density), clamped to [1, length - 1]
      num_nonnoise_spans = num_noise_spans = round(min(noise, non-noise tokens) / mean_noise_span_length), at least 1
      Spans alternate between non-noise and noise, beginning with non-noise.
      Subject to these restrictions, all masks are equally likely.

      Args:
          length: int >= 2, length of the incoming token sequence
      Returns:
          a boolean numpy array with shape [length]
      Raises:
          ValueError: if length < 2 (at least one noise and one non-noise token are required).
      """
      if length < 2:
          raise ValueError(f"random_spans_noise_mask needs length >= 2, got {length}")

      orig_length = length

      num_noise_tokens = int(np.round(length * self.noise_density))
      # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
      num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
      # Derive the non-noise count AFTER clamping so the two counts always sum to `length`
      # (otherwise e.g. noise_density=0 produces length + 1 positions and an IndexError).
      num_nonnoise_tokens = length - num_noise_tokens
      # num_noise_spans must not exceed either num_noise_tokens or num_nonnoise_tokens
      num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))

      # avoid degeneracy by ensuring positive number of noise spans
      num_noise_spans = max(num_noise_spans, 1)

      # pick the lengths of the noise spans and the non-noise spans
      def _random_segmentation(num_items, num_segments):
          """Partition a sequence of items randomly into non-empty segments.
          Args:
              num_items: an integer scalar > 0
              num_segments: an integer scalar in [1, num_items]
          Returns:
              an array with shape [num_segments] containing positive integers that add
              up to num_items
          """
          mask_indices = np.arange(num_items - 1) < (num_segments - 1)
          np.random.shuffle(mask_indices)
          first_in_segment = np.pad(mask_indices, [[1, 0]])
          segment_id = np.cumsum(first_in_segment)
          # count length of sub segments assuming that list is sorted
          _, segment_length = np.unique(segment_id, return_counts=True)
          return segment_length

      noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
      nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

      # [nonnoise_0, noise_0, nonnoise_1, noise_1, ...]
      interleaved_span_lengths = np.reshape(
          np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
      )
      span_starts = np.cumsum(interleaved_span_lengths)[:-1]
      span_start_indicator = np.zeros((length,), dtype=np.int8)
      span_start_indicator[span_starts] = True
      span_num = np.cumsum(span_start_indicator)
      # odd-numbered spans are the noise spans
      is_noise = np.equal(span_num % 2, 1)

      return is_noise[:orig_length]


def get_denoised(examples, noise_density=.55, mean_noise_span_length=1.5, max_length=128):
    """
    Add T5-style span-corruption ("denoising") `input_ids` and `labels` to `examples`.

    Works with `Dataset.map(get_denoised, batched=True)` (``examples["text"]`` is a list of
    strings) and on a single row (``examples["text"]`` is one string).

    Each text is handled on its own. It is tokenized without special tokens and without
    padding, gets its own random noise mask, and `filter_input_ids` appends exactly one EOS.
    This keeps pad tokens out of the corrupted spans, avoids sharing one mask across a
    batch, and avoids a double EOS.

    Relies on the module-level `tokenizer` and `FlaxDataCollatorForT5MLM`.

    Args:
        examples: dict-like row/batch with a "text" field; it is updated in place.
        noise_density: fraction of tokens to corrupt (default 0.55).
        mean_noise_span_length: mean corrupted-span length (default 1.5).
        max_length: maximum number of text tokens kept before corruption.

    Returns:
        `examples` with "input_ids" and "labels" added: lists of id lists for a batch,
        or (1, n) numpy arrays for a single row.
    """
    texts = examples["text"]
    is_single_row = isinstance(texts, str)
    if is_single_row:
        texts = [texts]

    denoiser = FlaxDataCollatorForT5MLM(tokenizer, noise_density, mean_noise_span_length)
    batch_input_ids, batch_labels = [], []
    for text in texts:
        token_ids = tokenizer(text, truncation=True, max_length=max_length,
                              add_special_tokens=False)["input_ids"]
        input_ids, labels = _denoise_one(denoiser, token_ids)
        batch_input_ids.append(input_ids)
        batch_labels.append(labels)

    model_inputs = examples
    if is_single_row:
        model_inputs['input_ids'] = np.asarray(batch_input_ids)
        model_inputs['labels'] = np.asarray(batch_labels)
    else:
        model_inputs['input_ids'] = batch_input_ids
        model_inputs['labels'] = batch_labels
    return model_inputs


def _denoise_one(denoiser, token_ids):
    """Span-corrupt one unpadded token-id list; return (input_ids, labels) as flat lists."""
    eos_id = denoiser.tokenizer.eos_token_id
    if not token_ids:
        # Nothing to corrupt: both sides are just EOS.
        return [eos_id], [eos_id]
    ids = np.asarray([token_ids])
    if len(token_ids) < 2:
        # random_spans_noise_mask needs at least one noise and one non-noise token,
        # so a single token is masked entirely.
        noise_mask = np.ones((1, len(token_ids)), dtype=bool)
    else:
        noise_mask = np.asarray([denoiser.random_spans_noise_mask(len(token_ids))])
    input_sentinels = denoiser.create_sentinel_ids(noise_mask.astype(np.int8))
    label_sentinels = denoiser.create_sentinel_ids((~noise_mask).astype(np.int8))
    return (denoiser.filter_input_ids(ids, input_sentinels)[0].tolist(),
            denoiser.filter_input_ids(ids, label_sentinels)[0].tolist())


def print_token_id(tokenizer,token):
  """Return the first id that `tokenizer.encode(token)` yields (nothing is printed).

  `encode` may split `token` into several ids and may add special tokens; only the
  leading id is returned. An empty encoding raises IndexError.
  """
  token_ids = tokenizer.encode(token)
  first_id = token_ids[0]
  return first_id

def print_special_tokens(tokenizer, verbose=True):
    """Map each special-token attribute (e.g. 'eos_token') to its id(s), print it, and return it.

    Args:
        tokenizer: tokenizer exposing `special_tokens_map` and `convert_tokens_to_ids`.
        verbose: print the mapping (default True, matching the function's name).

    Returns:
        dict of attribute name -> id (or list of ids for list-valued entries).
    """
    special_tokens = {
        attr: tokenizer.convert_tokens_to_ids(token)
        for attr, token in tokenizer.special_tokens_map.items()
    }
    if verbose:
        print(special_tokens)
    return special_tokens

In [None]:
def shift_tokens_right(input_ids, pad_token_id, eos_token_id):
  """Shift input ids one position right, put `pad_token_id` first and `eos_token_id` last.

  Args:
    input_ids: 2-D tensor (batch, seq_len).
    pad_token_id: id written at position 0 of every row.
    eos_token_id: id written at the last position of every row.

  Returns:
    Tensor of shape (batch, seq_len + 1) with input_ids' dtype and device:
    [pad, input_ids[0], ..., input_ids[-2], eos] per row.

  NOTE: the single extra slot is used by the leading pad, so the last input token is
  overwritten by EOS. For EOS-terminated input this yields exactly [pad, *input_ids].
  """
  batch_size, seq_len = input_ids.shape
  # new_zeros keeps input_ids' device (torch.zeros would always allocate on CPU).
  shifted_input_ids = input_ids.new_zeros((batch_size, seq_len + 1))

  # Shift input_ids one step to the right
  shifted_input_ids[:, 1:] = input_ids

  # Set the first token to the pad_token_id
  shifted_input_ids[:, 0] = pad_token_id

  # Set the last token to the eos_token_id (replaces the final input token)
  shifted_input_ids[:, -1] = eos_token_id

  return shifted_input_ids

In [None]:
# # prompt = "The cute dog walks in the green park"
# labels, input_ids = get_denoised(FlaxDataCollatorForT5MLM, tokenizer, prompt1)
# print(f"denoised input_ids decoded = {tokenizer.decode(*input_ids,skip_special_tokens=False)}")
# print(f"denoised labels decoded   = {tokenizer.decode(*labels,skip_special_tokens=False)}")
# print(f"input_ids.shape {input_ids.shape} labels.shape {labels.shape}") # todo should this be equal
# denoised_input_ids = torch.from_numpy(input_ids).to(device)
# denoised_labels = torch.from_numpy(labels).to(device)
# denoised_attention_mask = torch.ones(input_ids.shape).to(device)

# model.train()
# for epoch in range(2):
#     outputs = model(input_ids=denoised_input_ids,attention_mask=denoised_attention_mask,
#                     labels=denoised_labels)
#     loss = outputs.loss
#     if epoch % 20 == 0:
#         print(f"Epoch {epoch}  Loss {loss}")
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()
# print(f"Epoch {epoch}  Loss {loss.cpu().detach()}")
# #-------------------------------------------------------------

In [None]:
# !torch.cuda.empty_cache()

# Preparing the datasets

In [None]:
def reduce_data(raw_datasets, no_samples, seed=42):
  """Return a reproducible random subset of at most `no_samples` rows.

  Args:
    raw_datasets: a dataset supporting `shuffle(seed=...)`, `select(indices)` and `len()`.
    no_samples: maximum number of rows to keep; clamped to the dataset size, so asking
      for more rows than exist returns the whole (shuffled) dataset instead of failing.
    seed: shuffle seed (default 42, the previously hard-coded value).

  Returns:
    The shuffled dataset restricted to its first min(no_samples, len) rows.
  """
  shuffled = raw_datasets.shuffle(seed=seed)
  n_keep = min(no_samples, len(shuffled))
  return shuffled.select(range(n_keep))

In [None]:
# SQuAD (English) train split: keep only context + question, then draw a reproducible
# random subset of 80,000 rows for question generation.
from datasets import load_dataset

squad_dataset = reduce_data(
    load_dataset("squad")['train'].remove_columns(['id', 'title', 'answers']),
    80000,
)
squad_dataset

Downloading builder script:   0%|          | 0.00/5.27k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.67k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Dataset({
    features: ['context', 'question'],
    num_rows: 80000
})

In [None]:
# Inspect one shuffled SQuAD row (only 'context' and 'question' remain).
print(squad_dataset[0])

{'context': 'The Pew Forum on Religion & Public Life ranks Egypt as the fifth worst country in the world for religious freedom. The United States Commission on International Religious Freedom, a bipartisan independent agency of the US government, has placed Egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. According to a 2010 Pew Global Attitudes survey, 84% of Egyptians polled supported the death penalty for those who leave Islam; 77% supported whippings and cutting off of hands for theft and robbery; and 82% support stoning a person who commits adultery.', 'question': 'What percentage of Egyptians polled support death penalty for those leaving Islam?'}


In [None]:
# mC4 Telugu text, read from a Drive JSONL file that (per its name) is a partial dump
# of the mC4 Telugu *train* split.
# n_rows: number of rows drawn from each Telugu source in the following cells.
n_rows = 5000

mc4_dataset_te = (
    load_dataset('json', data_files='/content/drive/MyDrive/IRE/mc4_te-train_partial.jsonl')['train']
    .remove_columns(['timestamp', 'url'])
)
mc4_dataset_te

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 5000
})

In [None]:
from datasets import load_dataset, Dataset

# TyDi QA (primary task) train split: keep only the question text, restricted to Telugu.
tydiqa_dataset_te = load_dataset("tydiqa", "primary_task", split="train")
tydiqa_dataset_te = tydiqa_dataset_te.remove_columns(['passage_answer_candidates', 'document_title','annotations', 'document_plaintext', 'document_url'])

# Keep Telugu rows only. Batched filtering over just the 'language' column avoids
# materializing every full row in Python.
filtered_dataset = tydiqa_dataset_te.filter(
    lambda languages: [language == 'telugu' for language in languages],
    input_columns='language',
    batched=True,
)

# Take the first n_rows Telugu questions (n_rows is defined in the mC4 cell).
tydiqa_dataset_te_ = filtered_dataset.select(range(n_rows))

# Drop the language column and expose the question as 'text' to match the mC4 schema.
tydiqa_dataset_te_ = tydiqa_dataset_te_.remove_columns(['language'])
final_tydiqa_dataset = tydiqa_dataset_te_.rename_column('question_text', 'text')
# Shuffle using the same n_rows bound, so the two row counts cannot drift apart.
final_tydiqa_dataset = reduce_data(final_tydiqa_dataset, n_rows)
print(final_tydiqa_dataset)
print(final_tydiqa_dataset[0])

Downloading builder script:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/58.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.62M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/166916 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18670 [00:00<?, ? examples/s]

Filter:   0%|          | 0/166916 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 5000
})
{'text': 'రైట్ బ్రదర్స్ కనుగొన్న ప్రయాణ సాధనం ఏమిటి?'}


In [None]:
# Interleave the two Telugu sources: tydiqa question, mC4 document, tydiqa question, ...
# Slicing the datasets once yields column lists (much cheaper than indexing row by row),
# and zip stops at the shorter source instead of raising IndexError if either one has
# fewer than n_rows rows.
tydiqa_texts = final_tydiqa_dataset[:n_rows]['text']
mc4_texts = mc4_dataset_te[:n_rows]['text']
interleaved_rows = [text for pair in zip(tydiqa_texts, mc4_texts) for text in pair]

combined_mlm_data = Dataset.from_dict({'text': interleaved_rows})
combined_mlm_data

Dataset({
    features: ['text'],
    num_rows: 10000
})

In [None]:
# Rows alternate between sources: index 0 is a TyDi QA question, index 1 an mC4 document.
print(combined_mlm_data[0])
print(combined_mlm_data[1])

{'text': 'రైట్ బ్రదర్స్ కనుగొన్న ప్రయాణ సాధనం ఏమిటి?'}
{'text': 'మిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ లో Indiaఆఫర్స్ , Pictures & పూర్తి లక్షణాలుధర | PriceDekho.com\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫\n655 రేటింగ్స్\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ ధర\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ మీరుIndianవిఫణిలో విడుదల 2013-02-05 మరియు కొనుగోలు అందుబాటులో ఉంది.\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - వేరియంట్ జాబితా\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ టాబ్లెట్ బ్లాక్\n(663 రేటింగ్స్)\nఉత్తమ 4,698 వివరాలు చూడండి\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - ధర తనది కాదను వ్యక్తి\nతాజా ధర మిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ 07 Dec 2017 పొందిన జరిగినది.\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ యూజర్ సమీక్షలు\nగుడ్ , 655 రేటింగ్ల ఆధారంగా\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - లక్షణాలు\nమోడల్ నామ Micromax Funbook Infinity\nడిస్ప్లే ఫీచర్స్ Pinch-to-Zoom, Multi-touch Screen\nవెయిట్ 360 g\nసౌండ్ MP3, WAV, FLAC\nఅలెర్ట్ టైప్స్ WAV, FLAC, MP3\nనెట్వర్క్ టైపు Sim Not Supported\nబాటరీ టైపు 4000 mAh\nటాక్ టైం 6

In [None]:
# Spot-check the interleaving: tydiqa row k should reappear at index 2k of combined_mlm_data,
# because the interleave cell appends the tydiqa text before the mC4 text.
for i in range(3):
  # print(filtered_dataset[i])
  print(final_tydiqa_dataset[i])
  print(combined_mlm_data[i])
# print(mc4_dataset_te[0])
# print(combined_mlm_data[1])

{'text': 'రైట్ బ్రదర్స్ కనుగొన్న ప్రయాణ సాధనం ఏమిటి?'}
{'text': 'రైట్ బ్రదర్స్ కనుగొన్న ప్రయాణ సాధనం ఏమిటి?'}
{'text': 'పరమాణు సంఖ్య 65 గల మూలకం ఏది?'}
{'text': 'మిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ లో Indiaఆఫర్స్ , Pictures & పూర్తి లక్షణాలుధర | PriceDekho.com\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫\n655 రేటింగ్స్\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ ధర\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ మీరుIndianవిఫణిలో విడుదల 2013-02-05 మరియు కొనుగోలు అందుబాటులో ఉంది.\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - వేరియంట్ జాబితా\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ టాబ్లెట్ బ్లాక్\n(663 రేటింగ్స్)\nఉత్తమ 4,698 వివరాలు చూడండి\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - ధర తనది కాదను వ్యక్తి\nతాజా ధర మిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ 07 Dec 2017 పొందిన జరిగినది.\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ యూజర్ సమీక్షలు\nగుడ్ , 655 రేటింగ్ల ఆధారంగా\nమిక్రోమాస్ ఫంబూక్ ఇన్ఫినిటీ ప్రొ౨౭౫ - లక్షణాలు\nమోడల్ నామ Micromax Funbook Infinity\nడిస్ప్లే ఫీచర్స్ Pinch-to-Zoom, Multi-touch Screen\nవెయిట్ 360 g\nసౌండ్ MP3, WAV, FLAC\

In [None]:
# Smoke-test get_denoised on a single row (here examples["text"] is a plain string)
# before mapping it over the whole dataset.
reduced_dataset = reduce_data(combined_mlm_data, 5)
get_denoised(reduced_dataset[0])

{'text': 'సంతానం చిత్ర నిర్మాత ఎవరు?',
 'input_ids': array([[109955, 250100,  33648, 250099,   4819,    291, 250098,      1]]),
 'labels': array([[250100,  99870, 250099,  89901,   6999,  85266, 250098,      1,
              1]])}

In [None]:
# Apply span corruption to every row in batches; adds 'input_ids' and 'labels' next to 'text'.
# NOTE: despite its name, MLM_data_tuple is a datasets.Dataset, not a tuple.
MLM_data_tuple = combined_mlm_data.map(get_denoised, batched=True)
MLM_data_tuple

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'input_ids', 'labels'],
    num_rows: 10000
})

In [None]:
# Inspect the first two denoised rows (text plus sentinel-masked input_ids / labels).
print(MLM_data_tuple[0])
print(MLM_data_tuple[1])

{'text': 'రైట్ బ్రదర్స్ కనుగొన్న ప్రయాణ సాధనం ఏమిటి?', 'input_ids': [21940, 250100, 95236, 250099, 97793, 250098, 4664, 250097, 24057, 159112, 12384, 76643, 250096, 101722, 250095, 291, 250094, 0, 0, 250093, 0, 250092, 0, 0, 0, 250091, 0, 250090, 0, 250089, 0, 250088, 0, 250087, 0, 0, 0, 250086, 0, 0, 250085, 0, 0, 250084, 0, 0, 250083, 0, 250082, 0, 250081, 0, 250080, 0, 250079, 0, 250078, 0, 250077, 0, 250076, 0, 250075, 0, 250074, 0, 0, 0, 250073, 0, 0, 250072, 0, 250071, 0, 0, 250070, 0, 250069, 0, 250068, 0, 250067, 0, 0, 250066, 0, 250065, 0, 250064, 0, 0, 250063, 0, 0, 0, 250062, 1], 'labels': [250100, 167146, 250099, 18425, 250098, 4974, 250097, 79846, 250096, 29113, 250095, 147349, 250094, 1, 0, 250093, 0, 0, 250092, 0, 250091, 0, 0, 0, 250090, 0, 250089, 0, 0, 0, 0, 0, 250088, 0, 250087, 0, 0, 0, 250086, 0, 250085, 0, 250084, 0, 250083, 0, 250082, 0, 250081, 0, 0, 0, 0, 250080, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 250079, 0, 250078, 0, 250077, 0, 0, 250076, 0, 0, 250075, 0, 0, 0, 25

In [None]:
# input_ids, labels = MLM_data_tuple[0][0].tolist(), MLM_data_tuple[0][1].tolist()

In [None]:
# len(MLM_data_tuple)

In [None]:
# print(type(input_ids))
# print(len(input_ids))
# print(input_ids[0])
# print(len(input_ids[0]))
# print("######################")
# print(type(labels))
# print(len(labels))
# print(labels[0])
# print(len(labels[0]))

In [None]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    """Tokenize a batch of SQuAD rows for question generation.

    Each context is prefixed with the task tag "questiongeneration: " and becomes the
    encoder input (truncated to max_input_length). The question is tokenized as the
    target (truncated to max_target_length). Uses the module-level `tokenizer`.

    Returns the tokenizer output for the contexts, with a "labels" key holding the
    question token ids.
    """
    prompts = []
    for context in examples["context"]:
        prompts.append('questiongeneration: ' + context)

    encoded = tokenizer(prompts, max_length=max_input_length, truncation=True)
    targets = tokenizer(
        text_target=examples["question"],
        max_length=max_target_length,
        truncation=True,
    )
    encoded["labels"] = targets["input_ids"]
    return encoded

In [None]:
# Re-inspect the first SQuAD row before tokenizing it.
print(squad_dataset[0])

{'context': 'The Pew Forum on Religion & Public Life ranks Egypt as the fifth worst country in the world for religious freedom. The United States Commission on International Religious Freedom, a bipartisan independent agency of the US government, has placed Egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. According to a 2010 Pew Global Attitudes survey, 84% of Egyptians polled supported the death penalty for those who leave Islam; 77% supported whippings and cutting off of hands for theft and robbery; and 82% support stoning a person who commits adultery.', 'question': 'What percentage of Egyptians polled support death penalty for those leaving Islam?'}


In [None]:
# Tokenize every SQuAD row, then drop the raw text columns so only the model-facing
# columns (input_ids, attention_mask, labels) remain.
tokenized_datasets = (
    squad_dataset
    .map(preprocess_function, batched=True)
    .remove_columns(['context', 'question'])
)
print(tokenized_datasets)
print(tokenized_datasets[0])

Map:   0%|          | 0/80000 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 80000
})
{'input_ids': [250100, 259, 267, 486, 1668, 481, 6273, 351, 71394, 549, 7305, 7785, 23022, 263, 32213, 527, 287, 62181, 807, 48461, 270, 11395, 281, 287, 4836, 332, 259, 81421, 259, 69737, 260, 486, 4248, 9629, 22407, 351, 4265, 157663, 438, 259, 70277, 261, 259, 262, 837, 325, 74767, 22285, 259, 48970, 304, 287, 2279, 13749, 261, 1070, 2554, 285, 32213, 351, 2476, 17577, 2602, 304, 28901, 533, 27906, 9192, 52342, 5370, 288, 287, 13992, 305, 90865, 270, 304, 22049, 6540, 304, 259, 81421, 259, 69737, 52764, 285, 281, 631, 120249, 7670, 455, 287, 13749, 260, 259, 39609, 288, 259, 262, 1068, 1668, 481, 9454, 2584, 110386, 32944, 261, 259, 129416, 304, 32213, 39796, 42441, 345, 259, 66337, 287, 20862, 19208, 1421, 332, 259, 5480, 1866, 259, 15176, 7400, 296, 259, 139262, 259, 66337, 87023, 64860, 263, 305, 259, 66127, 4368, 304, 31521, 332, 287, 2508, 305, 186071, 20099, 296, 305, 259, 119999, 2478, 4

In [None]:
from tqdm import tqdm

# Build the multitask dataset: each group has SQUAD_PER_MLM question-generation rows
# followed by one denoising (MLM) row.
SQUAD_PER_MLM = 8
loops = 3000  # number of groups -> loops * (SQUAD_PER_MLM + 1) rows in total

# Fail early with a clear message instead of an IndexError halfway through the loop.
assert loops * SQUAD_PER_MLM <= len(tokenized_datasets), "not enough SQuAD rows for `loops`"
assert loops <= len(MLM_data_tuple), "not enough MLM rows for `loops`"

final_interleaved_rows_input_ids = []
final_interleaved_rows_labels = []
for i in tqdm(range(loops)):
    # One slice per group returns a dict of column lists, which is cheaper than
    # materializing each row twice by index.
    squad_group = tokenized_datasets[i * SQUAD_PER_MLM:(i + 1) * SQUAD_PER_MLM]
    final_interleaved_rows_input_ids.extend(squad_group['input_ids'])
    final_interleaved_rows_labels.extend(squad_group['labels'])

    # mlm task dataset
    mlm_row = MLM_data_tuple[i]
    final_interleaved_rows_input_ids.append(mlm_row['input_ids'])
    final_interleaved_rows_labels.append(mlm_row['labels'])

final_mt5_data = Dataset.from_dict({'input_ids': final_interleaved_rows_input_ids, 'labels': final_interleaved_rows_labels})
print(final_mt5_data)

# Sanity check that the Dataset keeps the interleaved rows in order.
for i in range(min(20, len(final_mt5_data))):
    assert final_mt5_data[i]['input_ids'] == final_interleaved_rows_input_ids[i]

100%|██████████| 1000/1000 [00:06<00:00, 149.09it/s]


Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 9000
})
0 [250100, 259, 267, 486, 1668, 481, 6273, 351, 71394, 549, 7305, 7785, 23022, 263, 32213, 527, 287, 62181, 807, 48461, 270, 11395, 281, 287, 4836, 332, 259, 81421, 259, 69737, 260, 486, 4248, 9629, 22407, 351, 4265, 157663, 438, 259, 70277, 261, 259, 262, 837, 325, 74767, 22285, 259, 48970, 304, 287, 2279, 13749, 261, 1070, 2554, 285, 32213, 351, 2476, 17577, 2602, 304, 28901, 533, 27906, 9192, 52342, 5370, 288, 287, 13992, 305, 90865, 270, 304, 22049, 6540, 304, 259, 81421, 259, 69737, 52764, 285, 281, 631, 120249, 7670, 455, 287, 13749, 260, 259, 39609, 288, 259, 262, 1068, 1668, 481, 9454, 2584, 110386, 32944, 261, 259, 129416, 304, 32213, 39796, 42441, 345, 259, 66337, 287, 20862, 19208, 1421, 332, 259, 5480, 1866, 259, 15176, 7400, 296, 259, 139262, 259, 66337, 87023, 64860, 263, 305, 259, 66127, 4368, 304, 31521, 332, 287, 2508, 305, 186071, 20099, 296, 305, 259, 119999, 2478, 4788, 1009, 259, 262, 2985, 1866

In [None]:
# Peek at the first two denoised input_ids; they reappear at indices 8 and 17 of
# final_mt5_data (one MLM row after every 8 SQuAD rows).
print(MLM_data_tuple[0]['input_ids'])
print(MLM_data_tuple[1]['input_ids'])

[21940, 250100, 95236, 250099, 97793, 250098, 4664, 250097, 24057, 159112, 12384, 76643, 250096, 101722, 250095, 291, 250094, 0, 0, 250093, 0, 250092, 0, 0, 0, 250091, 0, 250090, 0, 250089, 0, 250088, 0, 250087, 0, 0, 0, 250086, 0, 0, 250085, 0, 0, 250084, 0, 0, 250083, 0, 250082, 0, 250081, 0, 250080, 0, 250079, 0, 250078, 0, 250077, 0, 250076, 0, 250075, 0, 250074, 0, 0, 0, 250073, 0, 0, 250072, 0, 250071, 0, 0, 250070, 0, 250069, 0, 250068, 0, 250067, 0, 0, 250066, 0, 250065, 0, 250064, 0, 0, 250063, 0, 0, 0, 250062, 1]
[5442, 250100, 11831, 250099, 6972, 250098, 46375, 250097, 11582, 104296, 148873, 3769, 250096, 161765, 250095, 242470, 250094, 4783, 61010, 250093, 259, 250092, 47390, 549, 68048, 250091, 8248, 250090, 307, 250089, 5442, 250088, 11831, 250087, 46375, 10107, 11582, 250086, 148873, 3769, 250085, 161765, 238978, 250084, 239888, 133075, 250083, 88358, 250082, 5442, 250081, 17104, 250080, 239888, 250079, 5442, 250078, 11831, 250077, 17104, 250076, 11582, 250075, 16564, 2

In [None]:
# _reduced = reduce_data(tokenized_datasets, 9000)
# Convert to plain Python lists for the JSON export below, and spot-check the layout:
# index 0 is a SQuAD row; 8 and 17 are the first two MLM rows (every 9th row).
data_dicts = final_mt5_data.to_dict()
print(data_dicts['input_ids'][0])
print(data_dicts['input_ids'][8])
print(data_dicts['input_ids'][17])

[250100, 259, 267, 486, 1668, 481, 6273, 351, 71394, 549, 7305, 7785, 23022, 263, 32213, 527, 287, 62181, 807, 48461, 270, 11395, 281, 287, 4836, 332, 259, 81421, 259, 69737, 260, 486, 4248, 9629, 22407, 351, 4265, 157663, 438, 259, 70277, 261, 259, 262, 837, 325, 74767, 22285, 259, 48970, 304, 287, 2279, 13749, 261, 1070, 2554, 285, 32213, 351, 2476, 17577, 2602, 304, 28901, 533, 27906, 9192, 52342, 5370, 288, 287, 13992, 305, 90865, 270, 304, 22049, 6540, 304, 259, 81421, 259, 69737, 52764, 285, 281, 631, 120249, 7670, 455, 287, 13749, 260, 259, 39609, 288, 259, 262, 1068, 1668, 481, 9454, 2584, 110386, 32944, 261, 259, 129416, 304, 32213, 39796, 42441, 345, 259, 66337, 287, 20862, 19208, 1421, 332, 259, 5480, 1866, 259, 15176, 7400, 296, 259, 139262, 259, 66337, 87023, 64860, 263, 305, 259, 66127, 4368, 304, 31521, 332, 287, 2508, 305, 186071, 20099, 296, 305, 259, 119999, 2478, 4788, 1009, 259, 262, 2985, 1866, 91386, 263, 90161, 1472, 260, 1]
[21940, 250100, 95236, 250099, 97793, 

In [None]:
import json
import os

# Persist the interleaved dataset (column name -> list of id lists) to Drive.
FINAL_DATASET_PATH = '/content/drive/MyDrive/IRE/datasets/final_mt5_dataset.json'
# Create the datasets/ folder if needed; open() would otherwise raise FileNotFoundError.
os.makedirs(os.path.dirname(FINAL_DATASET_PATH), exist_ok=True)

with open(FINAL_DATASET_PATH, 'w', encoding='utf-8') as fp:
    json.dump(data_dicts, fp, sort_keys=True, indent=4)

In [None]:
# # Initialize final_dataset with empty lists if not already initialized
# from tqdm import tqdm
# final_dataset = {"input_ids": [], "labels": []}

# # Using list comprehensions for efficiency
# final_dataset["input_ids"] = [item for i in tqdm(range(6000, 7000)) for item in tokenized_datasets["train"]["input_ids"][i*8:(i+1)*8] + MLM_data_tuple[i][0].tolist()]
# final_dataset["labels"] = [item for i in tqdm(range(6000, 7000)) for item in tokenized_datasets["train"]["labels"][i*8:(i+1)*8] + MLM_data_tuple[i][1].tolist()]

In [None]:
# # saving dataset
# import json
# with open('/content/drive/MyDrive/IRE/project/IRE_final_data_7.json', 'w') as fp:
#     json.dump(final_dataset, fp, sort_keys=True, indent=4)

In [None]:
# # Loading dataset
# with open('data.json', 'r') as fp:
#     data = json.load(fp)

In [None]:
# print(final_dataset["input_ids"][0])
# print(len(final_dataset["input_ids"][0]))
# print(final_dataset["labels"][0])
# print(len(final_dataset["labels"][0]))
# print(final_dataset["input_ids"][8])
# print(len(final_dataset["input_ids"][8]))
# print(final_dataset["labels"][8])
# print(len(final_dataset["labels"][8]))
# #########################################
# print(final_dataset["input_ids"][17])
# print(len(final_dataset["input_ids"][17]))
# print(final_dataset["labels"][17])
# print(len(final_dataset["labels"][17]))