# **Env. Setup**

pip installation: Colab TPU
Colab TPU runtimes come with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU (we have already added these lines to our fine-tuning code). This code must be part of your Python file, not a separate Colab cell:
```
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```

Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so avoid installing jax[tpu] on Colab. If, for any reason, you would like to update the jax and jaxlib libraries on a Colab TPU runtime, follow the CPU instructions in the JAX README (i.e., install jax[cpu]).

Source: https://github.com/google/jax#pip-installation-colab-tpu

FLAX (JAX) also runs on GPU. To run on a GPU, remove the lines below from each script. You will find them at the beginning of each fine-tuning script in our examples.
```
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
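
If the same script has to run on both Colab TPU and GPU runtimes, one option is to guard the setup call instead of deleting it. The sketch below assumes that the Colab TPU runtime exposes the `COLAB_TPU_ADDR` environment variable, which is our understanding of what `setup_tpu()` reads in the JAX versions used here. `running_on_colab_tpu` is a hypothetical helper name, not part of JAX.

```python
import os


def running_on_colab_tpu(environ=os.environ):
    """Return True when the (assumed) Colab TPU address variable is set."""
    return "COLAB_TPU_ADDR" in environ


if running_on_colab_tpu():
    # Only touch the Colab TPU helper when a TPU runtime is attached.
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
```

On a GPU or CPU runtime the guard is false and the TPU helper is never imported.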



In [None]:
!pip3 install transformers==4.26.1
!pip install sentencepiece
!pip install datasets
!pip3 install evaluate
!pip install farasapy
!pip install pyarabic
!pip install fuzzysearch
!pip3 install bert_score
!pip3 install rouge_score
!pip3 install camel_tools
!pip3 install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.26.1
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers==4.26.1)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.26.1)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.15.1 tokenizers-0.13.3 transform

Google Colab ships a special build of jaxlib that works with its TPU runtime, and a newer jaxlib may only work on TPU VM devices. So we first check which JAX and jaxlib versions this Colab uses, and then force the Flax installation to match both versions.
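
The version check and the pinned install can also be derived programmatically rather than typed by hand. The following is only a sketch: `installed_version` and `build_pin_command` are hypothetical helper names, and it assumes the installed version strings are plain releases that pip can resolve (a build with a local suffix may not be).

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def build_pin_command(jax_version, jaxlib_version):
    """Build a pip command that pins jax/jaxlib and lets pip pick a matching flax."""
    return f"pip install jax=={jax_version} jaxlib=={jaxlib_version} flax"


jax_version = installed_version("jax")
jaxlib_version = installed_version("jaxlib")
if jax_version and jaxlib_version:
    print(build_pin_command(jax_version, jaxlib_version))
else:
    print("jax or jaxlib is not installed in this environment")
```

Pinning both packages keeps pip from upgrading jaxlib while it resolves a compatible Flax release.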

We also added wandb to our code so you can log your experiments. More about wandb: https://wandb.ai.

If you want to disable wandb, run the cell below:

In [None]:
!wandb offline
!wandb disabled

W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B.
W&B disabled.
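
W&B can usually be switched off from inside a Python script as well. A minimal sketch, assuming the `WANDB_MODE` environment variable that the wandb client reads at start-up, and assuming it is set before `wandb` is imported or initialized:

```python
import os

# Assumption: this runs before `wandb` is imported or `wandb.init()` is called.
os.environ["WANDB_MODE"] = "disabled"
print(os.environ["WANDB_MODE"])
```

This affects only the current process, so it is handy when you cannot run shell commands.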


In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
import jaxlib
print(jax.__version__)
print(jaxlib.__version__)

0.3.25
0.3.25


In [None]:
!pip install jax==0.3.25 jaxlib==0.3.25 flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
INFO: pip is looking at multiple versions of flax to determine which version is compatible with other requirements. This could take a while.
Collecting flax
  Downloading flax-0.6.10-py3-none-any.whl (226 kB)
  Downloading flax-0.6.8-py3-none-any.whl (216 kB)
  Downloading flax-0.6.7-py3-none-any.whl (214 kB)
  Downloading flax-0.6.6-py3-none-any.whl (210 kB)
  Downloading flax-0.6.4-py3-no

# **General Directions on How to Use FLAX and the T5 Model**

Seq2Seq models work by pairing an input text with a target text, and you can add a tag (prefix) to each one. In contrast to extractive models like BERT, ELECTRA, and ALBERT, we use the same input-target format for all tasks. In this Colab, we adapt this code: https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py . Further down, you will find a cell called "T5 QA", which is an adaptation of the summarization code for QA tasks.

For example:
- Let's say you target machine translation and you have a CSV file whose first column is "english_text" and whose second column is "arabic_text". Then your code will look like this:
```
for english_text,arabic_text in zip(examples["english_text"],examples["arabic_text"]):
            inputs.append("english : %s" % (english_text))
            targets.append("arabic : %s" % (arabic_text))
```
After the code reads all examples, each input and target will look like this:
```
input: <english : artificial intelligence is the future>
target: <arabic : الذكاء الاصطناعي هو المستقبل >
```
We use the zip function to iterate over the two columns in parallel, and the % operator substitutes the variables after the % sign into each %s placeholder. If your input has two fields (e.g., in question answering, both the context and the question are inputs), you can write something like this:
```
for context,question,answer in zip(examples["context"],examples["question"],examples['answers']):
            inputs.append("question: %s context: %s" % (question,context))
            targets.append("answer: %s" % (answer["text"]))
```
Your example after preprocessing will look like this :
```
 input: <question: ماهي عاصمة ايطاليا context: روما هي عاصمة ايطاليا واكبر مدنها >
 target: <answer: روما >
```
- For Sentiment Analysis:
Assuming your CSV files have columns named "sentence" and "label", you can write code like this:
```
for sentence,label in zip(examples["sentence"],examples["label"]):
            inputs.append("%s" % (sentence))
            targets.append("%s" % (label))
```
Your example after preprocessing will look like this:
```
 input: <التلفاز الذي قمت بشراءه من المتجر لم يعجبني >
 target: < negative >
```

- For Text Summarization:
Assuming your CSV files have columns named "article" and "summary", you can write code like this:
```
for article,summary in zip(examples["article"],examples["summary"]):
            inputs.append("%s" % (article))
            targets.append("%s" % (summary))
```
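
The snippets above are excerpts from a preprocessing function. To see the pattern run end to end, here is a minimal sketch for the question answering format. `build_qa_pairs` and `toy_batch` are hypothetical names, and the batch is a one-row toy example in the column-oriented layout that the `datasets` library passes to batched `map` calls.

```python
def build_qa_pairs(examples):
    """Turn a column-oriented batch into parallel input/target strings."""
    inputs, targets = [], []
    for context, question, answer in zip(
        examples["context"], examples["question"], examples["answers"]
    ):
        inputs.append("question: %s context: %s" % (question, context))
        targets.append("answer: %s" % (answer["text"]))
    return inputs, targets


# Toy batch: each key maps to a list with one value per example.
toy_batch = {
    "context": ["روما هي عاصمة ايطاليا واكبر مدنها"],
    "question": ["ماهي عاصمة ايطاليا"],
    "answers": [{"text": "روما"}],
}

inputs, targets = build_qa_pairs(toy_batch)
print(inputs[0])
print(targets[0])
```

The same function shape works for the other tasks: only the column names and the format strings change.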


Note that the sentiment analysis and summarization examples, which have a single input field, use no prefixes, whereas question answering, whose input combines two fields, needs them to mark where each field starts.

You also need a prefix for single-field examples if you plan to fine-tune the T5 model on several tasks. For example, suppose you first fine-tune the model on question answering with a prefix on the input and target, and then fine-tune it on sentiment analysis with its own prefix. The model can then tell from the prefix which task you are targeting at inference time.
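
As an illustration of task prefixes, the sketch below tags every input and target with its task name. The helper `add_task_prefix`, the tag strings, and the two toy rows are all made up for this example. Any consistent tag should work as long as training and inference use the same one.

```python
def add_task_prefix(task, text):
    """Prepend a task tag so one model can tell the tasks apart."""
    return "%s: %s" % (task, text)


# Hypothetical mixed training set: (task, input, target) triples.
mixed_examples = [
    ("sentiment", "التلفاز الذي قمت بشراءه من المتجر لم يعجبني", "negative"),
    ("summarize", "روما هي عاصمة ايطاليا واكبر مدنها", "روما عاصمة ايطاليا"),
]

inputs = [add_task_prefix(task, source) for task, source, _ in mixed_examples]
targets = [add_task_prefix(task, target) for task, _, target in mixed_examples]

for source, target in zip(inputs, targets):
    print(source, "->", target)
```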

To use a different evaluation metric, add your metric (e.g., F1 or accuracy) at lines 765-775 of the script, inside the compute_metrics function.

The Hugging Face evaluate library provides a list of evaluation metrics that you can load by passing the metric name to the evaluate.load("") function.

For example, to add EM/F1 scoring for the Arabic TyDi evaluation, we first add these lines at line 605 (this EM/F1 function is adapted from this [colab](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb))




```
# These imports already appear at the top of t5_qa.py; they are repeated here
# so the snippet is self-contained.
import re
import string
from collections import Counter

import pyarabic.araby as araby
from camel_tools.utils.normalize import (
    normalize_alef_ar,
    normalize_alef_maksura_ar,
    normalize_teh_marbuta_ar,
)


def normalize_answer(s):
    """Lowercase, strip punctuation and English articles, then normalize Arabic letters."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    def ar_normalize(text):
        text = text.strip()
        text = normalize_alef_ar(text)
        text = normalize_alef_maksura_ar(text)
        text = normalize_teh_marbuta_ar(text)
        text = araby.strip_diacritics(text)
        return text

    return ar_normalize(white_space_fix(remove_articles(remove_punc(lower(s)))))


def f1_score(prediction, ground_truth):
    """Token-overlap F1 between a prediction and one ground truth."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every ground truth and keep the best score."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate_qa(gold_answers, predictions):
    """Return (exact match, F1), each averaged over all examples."""
    f1 = exact_match = total = 0

    for ground_truths, prediction in zip(gold_answers, predictions):
        total += 1
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

    exact_match = exact_match / total
    f1 = f1 / total
    return exact_match, f1
```
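
To make the F1 arithmetic concrete, here is a small worked example. It is a simplified sketch: `token_f1` is a hypothetical stand-in for the `f1_score` function above that skips the Arabic normalization step so that it needs no third-party packages, and the prediction and references are toy strings.

```python
from collections import Counter


def token_f1(prediction, ground_truth):
    """Token-overlap F1 without the Arabic normalization step."""
    pred_tokens = prediction.split()
    truth_tokens = ground_truth.split()
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


prediction = "مدينة روما"
ground_truths = ["روما", "روما ايطاليا"]

# Score against every reference and keep the best one, as evaluate_qa does.
scores = [token_f1(prediction, truth) for truth in ground_truths]
print(scores, max(scores))
```

Against the first reference, one of the two predicted tokens matches (precision 1/2) and the single reference token is covered (recall 1), so F1 = 2 · 0.5 · 1 / 1.5 = 2/3. Against the second reference, precision and recall are both 1/2, so F1 = 0.5. The example is credited with the larger score.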

Then we call this evaluation function and report its results at line 709:


```
 result["EM"],result["F1"]=evaluate_qa(references,decoded_preds)
 ```

In short, any evaluation score that you want the code to report during evaluation must be inserted into the `result` dictionary inside the compute_metrics function.
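
The pattern can be sketched in isolation as follows. `compute_metrics_sketch` is a hypothetical, simplified stand-in for the script's compute_metrics function, and the two metrics are illustrative choices for a classification-style task rather than the ones the script reports.

```python
def compute_metrics_sketch(decoded_preds, decoded_labels):
    """Fill a `result` dictionary with custom scores, one key per metric."""
    result = {}
    matches = [
        pred.strip() == label.strip()
        for pred, label in zip(decoded_preds, decoded_labels)
    ]
    # Every key added here is a score that would be reported at evaluation time.
    result["accuracy"] = sum(matches) / len(matches)
    result["avg_pred_words"] = sum(len(p.split()) for p in decoded_preds) / len(decoded_preds)
    return result


scores = compute_metrics_sketch(["negative", "positive "], ["negative", "negative"])
print(scores)
```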

Also note that at line 69 we added this code:
```
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```

You can remove it if you are using GPU.

Finally, if you prefer to work with PyTorch code instead of FLAX, you can make the same changes we made here to https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py . The FLAX and PyTorch scripts are very similar, apart from small changes to fit the FLAX environment.

# **Running Flax Code on Question Answering**

First, let's process the Arabic TyDi QA dataset. The TyDi QA files combine questions from several different languages, so we need to extract the Arabic portion and do some preprocessing. All credit for the preprocessing scripts goes to AUB MIND Lab.

In [None]:
!git clone https://github.com/aub-mind/arabert
!cp arabert/examples/question-answering/utils_qa.py .
!cp arabert/examples/question-answering/trainer_qa.py .
!cp arabert/examples/question-answering/run_qa.py .
!cp arabert/examples/question-answering/squad_preprocessing.py .

Cloning into 'arabert'...
remote: Enumerating objects: 600, done.
remote: Counting objects: 100% (65/65), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 600 (delta 38), reused 45 (delta 30), pack-reused 535
Receiving objects: 100% (600/600), 9.14 MiB | 30.20 MiB/s, done.
Resolving deltas: 100% (339/339), done.


In [None]:
!wget https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-train.json
!wget https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-dev.json

--2023-06-13 19:20:23--  https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-train.json
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.103.128, 108.177.120.128, 142.250.159.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.103.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58004076 (55M) [application/json]
Saving to: ‘tydiqa-goldp-v1.1-train.json’


2023-06-13 19:20:23 (180 MB/s) - ‘tydiqa-goldp-v1.1-train.json’ saved [58004076/58004076]

--2023-06-13 19:20:23--  https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-dev.json
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.103.128, 108.177.120.128, 142.250.159.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.103.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5617409 (5.4M) [application/json]
Saving to: ‘tydiqa-goldp-v1.1-dev.json’


2023-06-13 19:20:23 (145 MB

In [None]:
model_name="sultan/ArabicT5-Large"

In [None]:
!rm -rf *-pre.json
!python squad_preprocessing.py \
  --input_file "tydiqa-goldp-v1.1-train.json" \
  --output_file "tydiqa-goldp-v1.1-train-pre.json" \
  --model_name=$model_name \
  --filter_tydiqa=True

sultan/ArabicT5-Large
W0613 19:20:28.659094 140288135685952 preprocess.py:264] Model provided is not in the accepted model list. Preprocessor will default to a base Arabic preprocessor
 'الإسكندر الثالث المقدوني ، المعروف بأسماء عديدة أخرى أبرزها : الإسكندر الأكبر ، و < b data - parsoid = ' { " dsr " : [1849 , 1870 , 3 , 3] } ' > الإسكندر الكبير ، و < b data - parsoid = ' { " dsr " : [1873 , 1896 , 3 , 3] } ' > الإسكندر المقدوني ، و < b data - parsoid = ' { " dsr " : [1899 , 1924 , 3 , 3] } ' > الإسكندر ذو القرنين ( باليونانية : ؛ نقحرة : ) ، هو أحد ملوك مقدونيا الإغريق ، ومن أشهر القادة العسكريين والفاتحين عبر التاريخ . ولد الإسكندر في مدينة يلا قرابة سنة 356 ق . م ، وتتلمذ على يد الفيلسوف والعالم الشهير أرسطو حتى بلغ ربيعه السادس عشر . وبحلول عامه الثلاثين ، كان قد أسس إحدى أكبر وأعظم الإمبراطوريات التي عرفها العالم القديم ، والتي امتدت من سواحل البحر الأيوني غربا وصولا إلى سلسلة جبال الهيمالايا شرقا . يعد أحد أنجح القادة العسكريين في مسيرتهم ، إذ لم يحصل أن هزم في أي معركة خاضها على

In [None]:
!python squad_preprocessing.py \
  --input_file "tydiqa-goldp-v1.1-dev.json" \
  --output_file "tydiqa-goldp-v1.1-dev-pre.json" \
  --model_name=$model_name \
  --filter_tydiqa=True

sultan/ArabicT5-Large
W0613 19:20:55.897143 140151005525824 preprocess.py:264] Model provided is not in the accepted model list. Preprocessor will default to a base Arabic preprocessor
100% 5077/5077 [00:01<00:00, 4842.00it/s]


In [None]:
import csv
import json


def squad_json_to_csv(json_file, csv_output):
    """Write the Arabic entries of a SQuAD-style TyDi QA JSON file to a CSV file."""
    with open(json_file, "r", encoding="utf-8") as tydi_json:
        tydi = json.load(tydi_json)["data"]
    with open(csv_output, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["id", "context", "title", "question", "answer1", "answer2"])
        for entry in tydi:
            paragraph = entry["paragraphs"][0]
            qa = paragraph["qas"][0]
            if "arabic" not in qa["id"]:
                continue
            # Some TyDi questions have more than one answer (ground truth).
            answers = [answer["text"] for answer in qa["answers"]]
            # With a single answer, the answer2 column gets the placeholder "-".
            answer2 = "-" if len(answers) == 1 else answers[1]
            writer.writerow(
                [qa["id"], paragraph["context"], entry["title"], qa["question"], answers[0], answer2]
            )


squad_json_to_csv("tydiqa-goldp-v1.1-train-pre.json", "tydiqa-goldp-v1.1-train-pre.csv")
squad_json_to_csv("tydiqa-goldp-v1.1-dev-pre.json", "tydiqa-goldp-v1.1-dev-pre.csv")
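
When the CSV is read back for evaluation, the two answer columns have to be turned into a list of ground truths per question. The sketch below shows one way to do that. It is an assumption about downstream use, not code taken from the script; `gold_answers` is a hypothetical helper and the rows are made-up placeholders.

```python
import csv
import io

# Toy CSV in the layout written above; the rows are made-up placeholders.
toy_csv = io.StringIO(
    "id,context,title,question,answer1,answer2\n"
    "arabic-1,ctx one,t1,q1,ans a,-\n"
    "arabic-2,ctx two,t2,q2,ans b,ans c\n"
)


def gold_answers(row):
    """Collect the ground truths of one row, skipping the '-' placeholder."""
    return [row[key] for key in ("answer1", "answer2") if row[key] != "-"]


references = [gold_answers(row) for row in csv.DictReader(toy_csv)]
print(references)
```

Note that a genuine answer consisting only of "-" would also be dropped by this filter, which is a limitation of using a placeholder value.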

Below we take the code from https://github.com/huggingface/transformers/tree/main/examples/flax/summarization in the Hugging Face library and modify it for our task. Click the play button beside "show code" so that the code is written to a file in this Colab. If you want to use this code on a local machine with a GPU, simply save the written file to that machine and run it there.

Make sure to remove the lines

```
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```

if you plan to use a GPU or a TPU VM directly on Google Cloud.


In addition, to make our code work with our baseline model AraBART, we change the following code:

```
decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id, config.decoder_start_token_id)
```

with :



```
if "MBartModel" in str(config.architectures):
    decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id)
else:
    decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id, config.decoder_start_token_id)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
```

This is because AraBART uses the MBart architecture, whose shift_tokens_right function takes only two arguments.
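
The difference between the two shift functions can be illustrated with plain Python lists. These are simplified sketches of our understanding of the two behaviours, not the library's implementations (which operate on arrays). The function names and the toy token ids are hypothetical.

```python
def shift_right_with_start(labels, pad_token_id, decoder_start_token_id):
    """Three-argument style: prepend a fixed start token and drop the last position."""
    shifted = []
    for row in labels:
        new_row = [decoder_start_token_id] + list(row[:-1])
        shifted.append([pad_token_id if tok == -100 else tok for tok in new_row])
    return shifted


def shift_right_wrap_last(labels, pad_token_id):
    """Two-argument style: the last non-pad token becomes the first decoder token."""
    shifted = []
    for row in labels:
        row = [pad_token_id if tok == -100 else tok for tok in row]
        last = sum(1 for tok in row if tok != pad_token_id) - 1
        shifted.append([row[last]] + row[:-1])
    return shifted


# Toy label rows: 2 plays the role of end-of-sequence, 9 of a language-id token,
# and -100 marks ignored (padded) label positions.
print(shift_right_with_start([[5, 6, 2, -100]], pad_token_id=1, decoder_start_token_id=0))
print(shift_right_wrap_last([[5, 6, 2, 9, -100]], pad_token_id=1))
```

In the second style no start token id is needed, because the start token is read from the label sequence itself. That is why the call takes one argument fewer.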


In [None]:
#@title T5 QA
%%writefile t5_qa.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for summarization.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Callable, Optional
from sklearn.metrics import f1_score
from collections import Counter
import string
import re
import csv
import pandas as pd
from jax.lib import xla_bridge
import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import pyarabic.araby as araby
from camel_tools.utils.normalize import normalize_alef_ar
from camel_tools.utils.normalize import normalize_alef_maksura_ar
from camel_tools.utils.normalize import normalize_teh_marbuta_ar
from huggingface_hub import Repository
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSeq2SeqLM,
    HfArgumentParser,
    is_tensorboard_available,
)
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry


logger = logging.getLogger(__name__)
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    eval_per_epoch: int = field(default=None, metadata={"help": "Run an evaluation every X epochs"})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: str = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
    gradient_checkpointing: bool = field(
        default=False,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )

    def __post_init__(self):
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
        the token values by removing their value.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
                "This argument is also used to override the `max_length` param of `model.generate`, which is used "
                "during evaluation."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
                "which is used during evaluation."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    "cnn_dailymail": ("article", "highlights"),
    "orange_sum": ("text", "summary"),
    "pn_summary": ("article", "summary"),
    "psc": ("extract_text", "summary_text"),
    "samsum": ("dialogue", "summary"),
    "thaisum": ("body", "summary"),
    "xglue": ("news_body", "news_title"),
    "xsum": ("document", "summary"),
    "wiki_summary": ("article", "highlights"),
}


class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is `False`, the indices are split into
    `ceil(len(dataset) / batch_size)` near-equal batches with `np.array_split`, so some batches may be smaller than
    `batch_size`. The examples are shuffled before batching if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear-warmup, linear-decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Log the training/evaluation arguments (INFO-level messages are shown on the main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            from_pt=True,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    if training_args.gradient_checkpointing:
        model.enable_gradient_checkpointing()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"'--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"'--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # Flax seq2seq models don't accept `labels`, so we need to build `decoder_input_ids` ourselves.
    # To do that, dynamically import the `shift_tokens_right` function from the model file.
    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # QA scoring helpers: SQuAD-style exact match and token-level F1, with extra Arabic normalization.
    def normalize_answer(s):
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        def ar_normalize(text):
            text = text.strip()
            text = normalize_alef_ar(text)
            text = normalize_alef_maksura_ar(text)
            text = normalize_teh_marbuta_ar(text)
            text = araby.strip_diacritics(text)
            return text

        return ar_normalize(white_space_fix(remove_articles(remove_punc(lower(s)))))


    def f1_score(prediction, ground_truth):
        prediction_tokens = normalize_answer(prediction).split()
        ground_truth_tokens = normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1


    def exact_match_score(prediction, ground_truth):
        return (normalize_answer(prediction) == normalize_answer(ground_truth))


    def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
        scores_for_ground_truths = []
        for ground_truth in ground_truths:
            score = metric_fn(prediction, ground_truth)
            scores_for_ground_truths.append(score)
        return max(scores_for_ground_truths)


    def evaluate_qa(gold_answers, predictions):
        f1 = exact_match = total = 0

        for ground_truths, prediction in zip(gold_answers, predictions):
            total += 1
            exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
            f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

        exact_match = exact_match / total
        f1 = f1 / total
        return exact_match, f1


    def preprocess_function(examples):
        inputs = []
        targets = []
        # Question answering: some TyDi QA questions have more than one answer. Only the first
        # answer (`answer1`) is used for training; evaluation uses both.
        for context, question, answer in zip(examples["context"], examples["question"], examples["answer1"]):
            inputs.append("question: %s context: %s" % (question, context))
            targets.append("%s" % (answer))
        # Setting padding="max_length" as we need fixed length inputs for jitted functions
        model_inputs = tokenizer(
            inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
        )
        labels = tokenizer(
            text_target=targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
        model_inputs["labels"] = labels["input_ids"]
        if "MBartModel" in str(config.architectures):
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id)
        else:
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id, config.decoder_start_token_id)
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs


    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    def norm_seq(seq_text):
        # Strip the task tags, including truncated variants that lost the leading "<" or the first letter.
        for tag in ("answer", "title", "question", "context"):
            for marker in (f"<{tag}>", f"<{tag[1:]}>", f"{tag}>"):
                seq_text = seq_text.replace(marker, "")
        return seq_text

    def postprocess_text(preds, labels):
        preds = [norm_seq(pred.strip()) for pred in preds]
        labels = [norm_seq(label.strip()) for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        # Save the decoded references and predictions side by side for inspection.
        pred_dict = pd.DataFrame({"label": decoded_labels, "pred": decoded_preds})
        pred_dict.to_csv("pred.csv", index=False)
        ### Add the lines below if you want to calculate ROUGE results ###
        # metric = evaluate.load("rouge")
        # result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        ##############################
        # prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        # result["BLEU"] = evaluate.load("bleu").compute(predictions=decoded_preds, references=decoded_labels)["bleu"]
        # result["gen_len"] = np.mean(prediction_lens)
        #### This code block is for TyDi QA evaluation. Remove it if you have a different task. ####
        # NOTE: the gold answers are always read from --validation_file, so EM/F1 are only
        # meaningful when the predictions come from the validation set.
        result = {}
        tydi_dev_file = pd.read_csv(data_args.validation_file)
        references = []
        for index, row in tydi_dev_file.iterrows():
            answers = [str(row["answer1"])]
            # pandas reads an empty cell as NaN and str(NaN) is "nan", so check for NaN explicitly.
            if pd.notna(row["answer2"]) and str(row["answer2"]) != "":
                answers.append(str(row["answer2"]))
            references.append(answers)
        result["EM"], result["F1"] = evaluate_qa(references, decoded_preds)
        result = {k: round(v * 100, 4) for k, v in result.items()}
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set(
            [
                layer[-2:]
                for layer_norm_name in layer_norm_candidates
                for layer in flat_params.keys()
                if layer_norm_name in "".join(layer).lower()
            ]
        )
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum()
        num_labels = padding_mask.sum()
        return loss, num_labels

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
            return loss, num_labels

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, num_labels), grad = grad_fn(state.params)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total number of non-padded label tokens
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        # true grad = total grad / total number of non-padded label tokens
        grad = jax.lax.psum(grad, "batch")
        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]

        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total number of non-padded label tokens
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        metrics = {"loss": loss}
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0, leave=True)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=0, leave=True):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
            f" {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=0, leave=True):
            # Model forward
            batch = next(eval_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # compute QA metrics (EM/F1); the `rouge_*` names are kept from the summarization script
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics and update progress bar
        loss_score=round(float(eval_metrics['loss']),4)
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {loss_score} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
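            # Use the step count at the end of this epoch so the per-step train metrics land on non-negative steps.
            cur_step = (epoch + 1) * (len(train_dataset) // train_batch_size)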
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save the checkpoint after the final epoch only and, if requested, push it to the hub
        if jax.process_index() == 0 and epoch == int(training_args.num_train_epochs)-1:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
        pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
        for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=True):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(labels)

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)

        # compute QA metrics (EM/F1); the `rouge_*` names are kept from the summarization script
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc}"
        logger.info(desc)

        # save final metrics in json (ROUGE metrics only exist when predict_with_generate is set)
        if jax.process_index() == 0 and data_args.predict_with_generate:
            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()


Writing t5_qa.py


In [None]:
!python3 t5_qa.py --output_dir out --model_name_or_path sultan/ArabicT5-Base \
--tokenizer_name sultan/ArabicT5-Base \
--train_file=tydiqa-goldp-v1.1-train-pre.csv \
--validation_file=tydiqa-goldp-v1.1-dev-pre.csv \
--do_train \
--do_eval \
--predict_with_generate \
--num_train_epochs 9 \
--overwrite_cache \
--learning_rate 2e-4 \
--warmup_steps 0 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 8 \
--overwrite_output_dir \
--max_source_length 512 \
--max_target_length 128

INFO:__main__:Training/evaluation parameters TrainingArguments(output_dir='out', overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, per_device_train_batch_size=2, per_device_eval_batch_size=8, learning_rate=0.0002, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, label_smoothing_factor=0.0, adafactor=False, num_train_epochs=9.0, warmup_steps=0, logging_steps=500, save_steps=500, eval_steps=None, eval_per_epoch=None, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None, gradient_checkpointing=False)
Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-c3e0da5085c6c983/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...
Downloading data files: 100% 2/2 [00:00<00:00, 6949.97it/s]
Extracting data files: 100% 2/2 [00:00<00:00, 1072.44it/s]
Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-c3e0da5085c6c983/0.0.0/6954658bab30a358235fa864b05c

The script also writes its predictions to a "pred.csv" file.
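As a quick sanity check on the saved predictions, the sketch below summarizes a predictions table. It assumes the file has the two columns "label" and "pred" (the column names used by the metric code in this notebook); the helper name `summarize_predictions` is hypothetical and not part of the training script.

```python
import pandas as pd


def summarize_predictions(df: pd.DataFrame) -> dict:
    """Return simple sanity statistics for a table with 'label' and 'pred' columns."""
    # Treat missing values as empty strings and ignore surrounding whitespace.
    labels = df["label"].fillna("").astype(str).str.strip()
    preds = df["pred"].fillna("").astype(str).str.strip()
    return {
        "rows": len(df),
        "empty_predictions": int((preds == "").sum()),
        "exact_matches": int((labels == preds).sum()),
    }


# Usage, once the script has written pred.csv:
# print(summarize_predictions(pd.read_csv("pred.csv")))
```

A large number of empty predictions usually points to a generation or decoding problem rather than a modelling one, so it is worth checking before reading the scores.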

# **Running Flax Code on Text Summarization**

Here we made a small change to the Question Answering FLAX code, at line 665. The preprocessing loop now reads two columns called "text" and "summary", so both must be present in your dataset:



```
for text, summary in zip(examples["text"], examples["summary"]):
    inputs.append(text)
    targets.append(summary)
```
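Because the loop above indexes `examples["text"]` and `examples["summary"]` directly, a CSV without those headers will fail inside preprocessing. The following is a minimal sketch of a pre-flight check; the function name `missing_columns` is hypothetical, and it assumes the file is UTF-8 encoded with a header row.

```python
import csv

REQUIRED_COLUMNS = ("text", "summary")


def missing_columns(csv_path, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from the CSV header."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        # Only the first record (the header) is read.
        header = next(csv.reader(f), [])
    return [name for name in required if name not in header]


# Usage: an empty list means the file can be passed to the script.
# print(missing_columns("XL-Sum-Train.csv"))
```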



Please note that the rouge package that comes with the evaluate library will not compute the ROUGE score properly for Arabic: the scores will be very low (<10 for ROUGE-L). This is because the raw ROUGE implementation removes all non-English characters during tokenization.
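To see why the scores collapse, here is a small sketch of an ASCII-only tokenizer. It is an approximation written for illustration (assuming the default tokenizer lowercases the text and keeps only runs of `a-z` and `0-9`), not the package's actual code, and `ascii_only_tokens` is a hypothetical name.

```python
import re


def ascii_only_tokens(text: str) -> list:
    """Approximate an English-only tokenizer: lowercase, keep runs of a-z and 0-9."""
    return re.findall(r"[a-z0-9]+", text.lower())


# English text keeps all of its words.
print(ascii_only_tokens("Cairo is the capital of Egypt"))
# Arabic text has no a-z or 0-9 characters, so no tokens are left to compare.
print(ascii_only_tokens("القاهرة عاصمة مصر"))
```

With no tokens left on either side, there is no n-gram overlap to measure, which is why an Arabic-aware scorer is needed.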

You should also note that some papers use a stemmer when calculating the ROUGE-L score and some don't. The XL-Sum dataset uses a stemmer for Arabic, and we follow the same approach. Using the stemmer increases the ROUGE-L score by 4-5%.

In [None]:
!git clone https://github.com/csebuetnlp/xl-sum
!pip3 install -r xl-sum/multilingual_rouge_scoring/requirements.txt
!pip3 install  --upgrade xl-sum/multilingual_rouge_scoring

Cloning into 'xl-sum'...
remote: Enumerating objects: 1159, done.
remote: Counting objects: 100% (50/50), done.
remote: Compressing objects: 100% (36/36), done.
remote: Total 1159 (delta 21), reused 30 (delta 12), pack-reused 1109
Receiving objects: 100% (1159/1159), 5.53 MiB | 19.33 MiB/s, done.
Resolving deltas: 100% (311/311), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/otuncelli/turkish-stemmer-python (from -r xl-sum/multilingual_rouge_scoring/requirements.txt (line 1))
  Cloning https://github.com/otuncelli/turkish-stemmer-python to /tmp/pip-req-build-6fi8vruq
  Running command git clone --filter=blob:none --quiet https://github.com/otuncelli/turkish-stemmer-python /tmp/pip-req-build-6fi8vruq
  Resolved https://github.com/otuncelli/turkish-stemmer-python to commit 0c22380bf84a5ab1f219f4a905274c78afa04ed1
  Preparing metadata (setup.py) ... done
Collecting git+http

We will evaluate our model on the XL-Sum dataset. Download the Arabic portion from this link: https://github.com/csebuetnlp/xl-sum#datasets

After downloading it (arabic_XLSum_v2.0.tar.bz2), upload it to this notebook and run the cell below.

In [None]:
!tar -xvf arabic_XLSum_v2.0.tar.bz2

./
./arabic_test.jsonl
./arabic_val.jsonl
./arabic_train.jsonl


In [None]:
import pandas as pd
import csv

def convert_jsonl_to_csv(json_file, csv_file):
  # Write the id, text, title and summary fields of each JSON Lines record as one CSV row.
  with open(csv_file, 'w', newline='', encoding='utf-8') as f:
    csv_writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
    csv_writer.writerow(["id","text","title","summary"])
    data = pd.read_json(json_file, lines=True, encoding='utf-8')
    for index, row in data.iterrows():
      csv_writer.writerow([row["id"],row["text"],row["title"],row["summary"]])

convert_jsonl_to_csv("arabic_train.jsonl","XL-Sum-Train.csv")
convert_jsonl_to_csv("arabic_val.jsonl","XL-Sum-Dev.csv")
convert_jsonl_to_csv("arabic_test.jsonl","XL-Sum-Test.csv")

Let's look at a sample from the dataset.

In [None]:
!head -100 XL-Sum-Train.csv > XL-Sum-Train-Sample.csv

Go to the left panel and look for the file XL-Sum-Train-Sample.csv and double-click on it to see the sample dataset.
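If you prefer to inspect the sample inside the notebook, the sketch below reads the first few records with pandas. Unlike a line-based cut, it parses whole CSV records, so an article whose text spans several lines is kept as one row. The helper name `preview_csv` is hypothetical.

```python
import pandas as pd


def preview_csv(csv_path: str, n_rows: int = 5) -> pd.DataFrame:
    """Read only the first n_rows records of a CSV file."""
    return pd.read_csv(csv_path, nrows=n_rows)


# Usage, after the conversion cell above has been run:
# preview_csv("XL-Sum-Train.csv", n_rows=3)
```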

We need to download the NLTK stopwords package.

In [None]:
import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

We updated our compute_metrics function to calculate the ROUGE scores using this code:
```
    scorer = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeL'], use_stemmer=True, lang="arabic")
    bleu = evaluate.load("bleu")
    def Average(lst):
        return sum(lst) / len(lst)
    def calc_rougeL(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rougeL"].fmeasure
    def calc_rouge1(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rouge1"].fmeasure
    def calc_rouge2(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rouge2"].fmeasure
    def compute_metrics(preds, labels):
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        pred_dict=pd.DataFrame(columns=["label","pred"])
        for label_a,pred_a in zip(decoded_labels,decoded_preds):
            pred_dict=pd.concat([pred_dict, pd.DataFrame([{"label":label_a,"pred":pred_a}])], ignore_index=True)
        pred_dict.to_csv("pred.csv",index=False)
        rougeL_score_list=[]
        rouge1_score_list=[]
        rouge2_score_list=[]
        for pred,label in zip(decoded_preds,decoded_labels):
            rougeL_score_list.append(calc_rougeL(pred,label))
            rouge1_score_list.append(calc_rouge1(pred,label))
            rouge2_score_list.append(calc_rouge2(pred,label))
        rougeL_score=Average(rougeL_score_list)
        rouge1_score=Average(rouge1_score_list)
        rouge2_score=Average(rouge2_score_list)

```



Here we compute the ROUGE scores for each entry and then take the average (the "mid" value) over all entries. This is in line with the XL-Sum method of calculating the ROUGE score. After training, we will verify this method using the official evaluation metric from the XL-Sum authors.
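The per-entry averaging can be illustrated without the ROUGE package. The sketch below uses a plain unigram F-measure as a stand-in scorer, which is an assumption made for illustration only and not the multilingual ROUGE scorer used in the script. The names `unigram_f1` and `average_score` are hypothetical.

```python
from collections import Counter
from statistics import mean


def unigram_f1(prediction: str, reference: str) -> float:
    """Stand-in scorer: F-measure over whitespace-separated unigrams."""
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most as often as it appears on both sides.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def average_score(predictions, references, score_fn=unigram_f1):
    """Score every (prediction, reference) pair, then average the per-entry scores."""
    scores = [score_fn(p, r) for p, r in zip(predictions, references)]
    return mean(scores) if scores else 0.0
```

Averaging per-entry scores gives every example the same weight regardless of its length, which differs from pooling all tokens into one corpus-level score.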

In [None]:
#@title t5_summary.py
%%writefile t5_summary.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for summarization.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Callable, Optional
from sklearn.metrics import f1_score
from collections import Counter
import string
import re
import csv
import pandas as pd
from jax.lib import xla_bridge
import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
# nltk is already imported above; the stopwords corpus is downloaded in a separate notebook cell
from rouge_score import rouge_scorer
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSeq2SeqLM,
    HfArgumentParser,
    is_tensorboard_available,
)
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
import wandb
wandb.init(project="ArabicT5-Eval")
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
logger = logging.getLogger(__name__)
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    eval_per_epoch: int = field(default=None, metadata={"help": "Run an evaluation every X epochs"})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: str = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
    gradient_checkpointing: bool = field(
        default=False,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )

    def __post_init__(self):
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
        the token values by removing their value.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the `max_length` param of `model.generate`, which is used "
                "during evaluation."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
                "which is used during evaluation."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    "cnn_dailymail": ("article", "highlights"),
    "orange_sum": ("text", "summary"),
    "pn_summary": ("article", "summary"),
    "psc": ("extract_text", "summary_text"),
    "samsum": ("dialogue", "summary"),
    "thaisum": ("body", "summary"),
    "xglue": ("news_body", "news_title"),
    "xsum": ("document", "summary"),
    "wiki_summary": ("article", "highlights"),
}


class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    wandb.log({"task_name":"Xl_Sum_Summary","model_name":model_args.model_name_or_path,"per_device_train_batch_size":training_args.per_device_train_batch_size, "learning_rate": training_args.learning_rate, "weight_decay":training_args.weight_decay, "adam_beta1":training_args.adam_beta1, "adam_beta2" : training_args.adam_beta2, "adam_epsilon":training_args.adam_epsilon,"label_smoothing_factor": training_args.label_smoothing_factor,"num_train_epochs":training_args.num_train_epochs, "warmup_steps":training_args.warmup_steps,"seed": training_args.seed})
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            from_pt=True,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    if training_args.gradient_checkpointing:
        model.enable_gradient_checkpointing()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"'--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"'--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
    # for that dynamically import the `shift_tokens_right` function from the model file
    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Setting padding="max_length" as we need fixed length inputs for jitted functions

    def preprocess_function(examples):
        #inputs = examples["context_question"]
        inputs=[]
        targets=[]

        for text,summary in zip(examples["text"],examples["summary"]):
            inputs.append(text)
            targets.append(summary)
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np")
        labels = tokenizer(
            text_target=targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
        model_inputs["labels"] = labels["input_ids"]
        if "MBartModel" in str(config.architectures):
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id)
        else:
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id, config.decoder_start_token_id)
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs


    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    #metric = evaluate.load("rouge")
    def norm_seq(seq_text):
        seq_text=seq_text.replace("<answer>","")
        seq_text=seq_text.replace("<nswer>","")
        seq_text=seq_text.replace("answer>","")
        seq_text=seq_text.replace("<title>","")
        seq_text=seq_text.replace("<itle>","")
        seq_text=seq_text.replace("title>","")
        seq_text=seq_text.replace("<question>","")
        seq_text=seq_text.replace("<uestion>","")
        seq_text=seq_text.replace("question>","")
        seq_text=seq_text.replace("<context>","")
        seq_text=seq_text.replace("<ontext>","")
        seq_text=seq_text.replace("context>","")
        return seq_text
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels
    scorer = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeL'], use_stemmer=True, lang="arabic")
    bleu = evaluate.load("bleu")
    def Average(lst):
        return sum(lst) / len(lst)
    def calc_rougeL(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rougeL"].fmeasure
    def calc_rouge1(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rouge1"].fmeasure
    def calc_rouge2(sent1,sent2):
        scores = scorer.score(sent1,sent2)
        return scores["rouge2"].fmeasure
    def compute_metrics(preds, labels):
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        pred_dict = pd.DataFrame({"label": decoded_labels, "pred": decoded_preds})
        pred_dict.to_csv("pred.csv", index=False)
        rougeL_score_list=[]
        rouge1_score_list=[]
        rouge2_score_list=[]
        for pred,label in zip(decoded_preds,decoded_labels):
            rougeL_score_list.append(calc_rougeL(pred,label))
            rouge1_score_list.append(calc_rouge1(pred,label))
            rouge2_score_list.append(calc_rouge2(pred,label))
        rougeL_score=Average(rougeL_score_list)
        rouge1_score=Average(rouge1_score_list)
        rouge2_score=Average(rouge2_score_list)
        #result = metric.compute(predictions=decoded_preds, references=decoded_labels,tokenizer=lambda x: x.split())
        result={}
        result["rouge1"]=rouge1_score
        result["rouge2"]=rouge2_score
        result["rougeL"]=rougeL_score
        result["bleu"]=bleu.compute(predictions=decoded_preds, references=decoded_labels)["bleu"]
        result = {k: round(v * 100, 4) for k, v in result.items()}
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set(
            [
                layer[-2:]
                for layer_norm_name in layer_norm_candidates
                for layer in flat_params.keys()
                if layer_norm_name in "".join(layer).lower()
            ]
        )
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create AdamW optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum()
        num_labels = padding_mask.sum()
        return loss, num_labels

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
            return loss, num_labels

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, num_labels), grad = grad_fn(state.params)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total samples
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        # true grad = total grad / total samples
        grad = jax.lax.psum(grad, "batch")
        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]

        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total samples
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        metrics = {"loss": loss}
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})",position=0, leave=True)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=0, leave=True):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
            f" {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=0, leave=True):
            # Model forward
            batch = next(eval_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_log={"Eval_" + str(key): val for key, val in rouge_metrics.items()}
            wandb.log(eval_log)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics and update progress bar
        loss_score=round(float(eval_metrics['loss']),4)
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {loss_score} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # Checkpoint saving is disabled by the always-false `7==9` guard below;
        # remove that guard to save the final-epoch checkpoint (and push it to the hub if requested)
        if jax.process_index() == 0 and epoch == int(training_args.num_train_epochs)-1 and 7==9:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
        pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
        for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=True):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(labels)

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            eval_log={"Test_" + str(key): val for key, val in rouge_metrics.items()}
            wandb.log(eval_log)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()


Writing t5_summary.py
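
The preprocessing step in `t5_summary.py` builds `decoder_input_ids` by shifting the label ids one position to the right and prepending the decoder start token. The sketch below illustrates that idea on plain Python lists. `shift_right` is a hypothetical stand-in written for this note, not the model's own `shift_tokens_right`, and the token ids are made up. Replacing `-100` with the pad id is an assumption about how masked label positions are usually handled.

```python
def shift_right(label_ids, pad_token_id, decoder_start_token_id):
    """Shift every row one position to the right and prepend the start token."""
    shifted = []
    for row in label_ids:
        # Drop the last token so the row keeps its fixed length.
        new_row = [decoder_start_token_id] + list(row[:-1])
        # Assumption: positions masked with -100 become padding.
        new_row = [pad_token_id if tok == -100 else tok for tok in new_row]
        shifted.append(new_row)
    return shifted


# Two padded label rows (illustrative ids: 1 = end of sequence, 0 = padding).
labels = [[11, 12, 13, 1, 0, 0], [21, 22, 1, 0, 0, 0]]
decoder_input_ids = shift_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(decoder_input_ids)
```

At each position the decoder then sees the previous target token and is trained to predict the current one, which is why the script passes `decoder_input_ids` alongside `labels`.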


In [None]:
!python3 t5_summary.py --output_dir out \
--model_name_or_path sultan/ArabicT5-Base \
--tokenizer_name sultan/ArabicT5-Base \
--train_file=XL-Sum-Train.csv \
--validation_file=XL-Sum-Dev.csv \
--test_file=XL-Sum-Test.csv \
--do_train \
--do_eval \
--do_predict \
--predict_with_generate \
--num_train_epochs 11 \
--overwrite_cache \
--learning_rate 2e-4 \
--warmup_steps 0 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 3 \
--overwrite_output_dir \
--max_source_length 512 \
--max_target_length 64

INFO:__main__:Training/evaluation parameters TrainingArguments(output_dir='out', overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=True, per_device_train_batch_size=3, per_device_eval_batch_size=3, learning_rate=0.0002, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, label_smoothing_factor=0.0, adafactor=False, num_train_epochs=11.0, warmup_steps=0, logging_steps=500, save_steps=500, eval_steps=None, eval_per_epoch=None, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None, gradient_checkpointing=False)
Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-f5859a8dfb27cfd3/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...
Downloading data files: 100% 3/3 [00:00<00:00, 8060.80it/s]
Extracting data files: 100% 3/3 [00:00<00:00, 1196.89it/s]
Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-f5859a8dfb27cfd3/0.0.0/6954658bab30a358235fa864b05c

Let's now see our predictions (labels on the left and predictions on the right).

In [None]:
!tail pred.csv | sed  's/,/ |||| /g'

أجلت السلطات العراقية أكثر من ألف مدني من القرى الموجودة على جبهة القتال الأمامية مع مواقع سيطرة تنظيم الدولة الإسلامية في معقله بمدينة الموصل. |||| قالت الأمم المتحدة إن نحو مليون شخص فروا من مدينة الموصل العراقية، التي يسيطر عليها مسلحو تنظيم الدولة الإسلامية.
أدانت المكسيك القواعد الجديدة للهجرة التي أعلنت عنها الولايات المتحدة لترحيل المهاجرين حتى من هم غير المدانين. |||| أثارت الولايات المتحدة جدلا بشأن خطط أمريكية جديدة للهجرة، في وقت تشهد فيه البلاد جدلا بشأن خطط جديدة للهجرة.
عرض المتحدث باسم جماعة أنصار الله الحوثية في اليمن ما قال إنه تفاصيل العملية العسكرية التي استهدفت قوات التحالف الذي تقوده السعودية وذلك في المنطقة الحدودية بين البلدين. |||| قال المتحدث باسم القوات المسلحة اليمنية إن القوات الحوثية قتلت أكثر من 100 من جنودها في عملية عسكرية واسعة بمنطقة نجران، جنوبي اليمن.
استكمل الخبراء في مصر ترميم الكنيسة المعلقة في القاهرة بعد أعمال ترميم استغرقت 16 عاما. |||| افتتح رئيس الوزراء المصري، هشام قنديل، كنيسة جديدة في القاهرة، بعد ترميمها لمدة 14 عاما.
يخشى خبراء في مجال ا

Let's verify our ROUGE scores using the official multilingual ROUGE scoring script from XL-Sum:

https://github.com/csebuetnlp/xl-sum/tree/master/multilingual_rouge_scoring


We will use our final predictions on the test dataset, which our code stores in pred.csv.

The script overwrites pred.csv on every evaluation, for both the dev and test datasets. Since the last evaluation ran on the test dataset, the file now holds the test predictions.

In [None]:
import pandas as pd

predictions = []
references = []
pred = pd.read_csv("pred.csv")
with open("pred.txt", "w", encoding="utf-8") as p, open("target.txt", "w", encoding="utf-8") as t:
    for index, row in pred.iterrows():
        # Write one example per line, so flatten any newlines inside the text.
        pred_value = row["pred"].strip().replace("\n", " ").replace("\r", "")
        predictions.append(pred_value)
        p.write(pred_value + "\n")
        label_value = row["label"].strip().replace("\n", " ").replace("\r", "")
        references.append(label_value)
        t.write(label_value + "\n")

In [None]:
!python -m rouge_score.rouge --target_filepattern=target.txt \
--prediction_filepattern=pred.txt \
--output_filename=scores.csv \
--use_stemmer=True \
--lang="arabic"

I0613 17:42:17.584808 140595781957440 io.py:108] Reading targets from target.txt.
I0613 17:42:17.585003 140595781957440 io.py:109] Reading predictions from pred.txt.
I0613 17:42:24.185865 140595781957440 io.py:135] Writing results to scores.csv.
I0613 17:42:24.186297 140595781957440 io.py:148] Finished writing results.


In [None]:
!cat /content/scores.csv | grep -E "score|F"

score_type,low,mid,high
rouge1-F,0.344347,0.348071,0.352368
rouge2-F,0.145052,0.148253,0.151825
rougeL-F,0.287075,0.290823,0.294424


Compare the mid F-scores from the cell above with the scores our script logged for the test dataset:

`INFO:__main__:Predict Loss: 2.512944221496582 | Predict rouge1: 34.8084 | Predict rouge2: 14.8332 | Predict rougeL: 29.0797 | Predict bleu: 7.5588 | Predict gen_len: 27.30688846235871 |)`
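
The two outputs use different scales: the official tool reports fractions, while our script logs percentages rounded to four decimals. As a quick sanity check, the snippet below copies the numbers shown above and prints the absolute gap per metric. The variable names are only illustrative.

```python
# Mid F-scores from scores.csv above (fractions).
official_mid = {"rouge1": 0.348071, "rouge2": 0.148253, "rougeL": 0.290823}
# Scores logged by t5_summary.py on the test dataset (percentages).
script_scores = {"rouge1": 34.8084, "rouge2": 14.8332, "rougeL": 29.0797}

diffs = {}
for name, official in official_mid.items():
    # Bring the official score to the percentage scale before comparing.
    diffs[name] = abs(official * 100 - script_scores[name])
    print(f"{name}: official={official * 100:.4f} script={script_scores[name]:.4f} diff={diffs[name]:.4f}")
```

All three gaps are below 0.01 points. The remaining difference is plausibly due to the way the command-line tool aggregates scores (it reports low/mid/high estimates), whereas our script takes a plain average over the examples.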


Without the stemmer:

In [None]:
!python -m rouge_score.rouge --target_filepattern=target.txt \
--prediction_filepattern=pred.txt \
--output_filename=scores.csv \
--use_stemmer=False \
--lang="arabic"

I0613 17:44:37.366029 140330213648192 io.py:108] Reading targets from target.txt.
I0613 17:44:37.366205 140330213648192 io.py:109] Reading predictions from pred.txt.
I0613 17:44:41.766485 140330213648192 io.py:135] Writing results to scores.csv.
I0613 17:44:41.766892 140330213648192 io.py:148] Finished writing results.


In [None]:
!cat /content/scores.csv | grep -E "score|F"

score_type,low,mid,high
rouge1-F,0.282811,0.286766,0.290816
rouge2-F,0.122022,0.125431,0.128796
rougeL-F,0.244366,0.247818,0.251644


If you want to disable the stemmer in our T5 code, change line 706 of t5_summary.py to:

`scorer = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeL'], use_stemmer=False, lang="arabic")`

# **Running Flax Code on Text Classification (Sentiment Analysis)**

In this example, we will finetune ArabicT5 on ArSarcasm-v2, the dataset of the WANLP 2021 shared task on sarcasm and sentiment detection in Arabic. ArSarcasm-v2 has two tasks:
 * Sentiment Analysis (POS, NEG, NEU)
 * Sarcasm Detection (True, False)

Here we will focus on the Sentiment Analysis task. Its official evaluation metric is F1 PN, the macro-averaged F1 over the POS and NEG classes only.
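
To make the metric concrete, here is a minimal pure-Python sketch of F1 PN: compute F1 separately for the negative and positive classes, then average the two. Neutral examples still matter, because they can show up as false positives or false negatives for the other two classes. The helper names and the toy labels (including the string `"neutral"`) are assumptions made for illustration. The fine-tuning script itself relies on scikit-learn's `f1_score`.

```python
def f1_for_class(y_true, y_pred, label):
    """F1 score of a single class, treating every other class as negative."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def f1_pn(y_true, y_pred, labels=("negative", "positive")):
    """Macro-average of the per-class F1 over the listed classes only."""
    return sum(f1_for_class(y_true, y_pred, label) for label in labels) / len(labels)


# Toy example with hypothetical gold labels and predictions.
y_true = ["positive", "negative", "neutral", "positive", "negative", "neutral"]
y_pred = ["positive", "negative", "positive", "neutral", "negative", "neutral"]
score = f1_pn(y_true, y_pred)
print(score)
```

In this toy example the negative class is predicted perfectly (F1 = 1.0), while the positive class has one false positive and one false negative (F1 = 0.5), so F1 PN is 0.75.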

Refer to this paper to get reported results of other Arabic NLP models:

[Benchmarking Transformer-based Language Models for Arabic Sentiment and Sarcasm Detection](https://aclanthology.org/2021.wanlp-1.3/). Please note that generative models like GPT-2, BART, and T5 are mainly designed for generative tasks such as Machine Translation, Summarization, and Question Generation. Here we extend the evaluation of T5 models to Text Classification to show another application of T5. Encoder-only language models like AraBERT, AraELECTRA, and our model [ArabicTransformer](https://aclanthology.org/2021.findings-emnlp.108/) should perform better in these tasks.


Here we made a small change to the Question Answering FLAX code at line 665, so that the input loop reads:



```
for tweet,sentiment in zip(examples["tweet"],examples["sentiment"]):
            inputs.append(tweet)
            targets.append(sentiment)
```
We also added a function at line 603 that calculates the accuracy and the F1 PN score:



```
def calc_scarcasm_score(y_pred,y_true):
        accuracy=accuracy_score(y_true, y_pred)
        f1_pn=f1_score(y_true, y_pred,labels=['negative','positive'],average="macro")
        return accuracy,f1_pn
```


and we use this function at line 721:

```
result={}
accuracy,f1_pn=calc_scarcasm_score(decoded_preds,decoded_labels)
result["Accuracy"]=accuracy
result["F1_PN"]=f1_pn
```



In [None]:
#@title T5 FLAX Text Classification
%%writefile t5_text_class.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for text classification cast as a sequence-to-sequence task.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
from sklearn.metrics import f1_score,classification_report,accuracy_score
import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Callable, Optional
from sklearn.metrics import f1_score
from collections import Counter
import string
import re
import csv
import pandas as pd
from jax.lib import xla_bridge
import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSeq2SeqLM,
    HfArgumentParser,
    is_tensorboard_available,
)
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry


logger = logging.getLogger(__name__)
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    eval_per_epoch: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X epochs."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
    gradient_checkpointing: bool = field(
        default=False,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )

    def __post_init__(self):
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance, replacing `Enum` members with their values (for JSON serialization support). It
        obfuscates token values by replacing them with a placeholder.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the `max_length` param of `model.generate`, which is used "
                "during evaluation."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
                "which is used during evaluation."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    "cnn_dailymail": ("article", "highlights"),
    "orange_sum": ("text", "summary"),
    "pn_summary": ("article", "summary"),
    "psc": ("extract_text", "summary_text"),
    "samsum": ("dialogue", "summary"),
    "thaisum": ("body", "summary"),
    "xglue": ("news_body", "news_title"),
    "xsum": ("document", "summary"),
    "wiki_summary": ("article", "highlights"),
}


class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch
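

# Illustrative sketch, not part of the original training logic: a minimal look at how
# `data_loader` above groups example indices into batches. `_sketch_batch_indices` is a
# hypothetical helper that nothing in this script calls; it assumes no shuffling and
# mirrors only the index arithmetic. Note that with `drop_last=False`, `np.array_split`
# produces nearly equal chunks rather than full batches followed by one short batch.
import math

import numpy as np


def _sketch_batch_indices(num_examples, batch_size, drop_last=True):
    idx = np.arange(num_examples)
    if drop_last:
        # Keep only full batches and discard the trailing remainder.
        steps = num_examples // batch_size
        return [b.tolist() for b in idx[: steps * batch_size].reshape((steps, batch_size))]
    # Split all indices into ceil(num_examples / batch_size) nearly equal chunks.
    steps = math.ceil(num_examples / batch_size)
    return [b.tolist() for b in np.array_split(idx, steps)]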


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
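

# Illustrative sketch, not part of the original training logic: a plain-Python view of the
# curve that `create_learning_rate_fn` above builds, assuming the usual optax semantics
# (linear interpolation inside each schedule, and the step counter restarting at the
# boundary passed to `join_schedules`). `_sketch_warmup_decay_lr` is a hypothetical
# helper that nothing in this script calls.
def _sketch_warmup_decay_lr(step, peak_lr, num_warmup_steps, num_train_steps):
    if step < num_warmup_steps:
        # Warmup: rise linearly from 0 to peak_lr.
        return peak_lr * step / num_warmup_steps
    # Decay: fall linearly from peak_lr to 0 over the remaining steps, then stay at 0.
    decay_steps = max(1, num_train_steps - num_warmup_steps)
    progress = min(step - num_warmup_steps, decay_steps) / decay_steps
    return peak_lr * (1.0 - progress)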


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Log the training/evaluation parameters (shown at INFO level on the main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            from_pt=True,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    if training_args.gradient_checkpointing:
        model.enable_gradient_checkpointing()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"`--text_column` value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"`--summary_column` value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
    # for that dynamically import the `shift_tokens_right` function from the model file
    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    def calc_scarcasm_score(y_pred, y_true):
        """Return the accuracy and the macro F1 over the `negative` and `positive` labels (F1-PN)."""
        accuracy = accuracy_score(y_true, y_pred)
        f1_pn = f1_score(y_true, y_pred, labels=["negative", "positive"], average="macro")
        return accuracy, f1_pn

    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    def preprocess_function(examples):
        # The column names are hardcoded for this task: `tweet` is the input text, `sentiment` is the target label
        inputs = list(examples["tweet"])
        targets = list(examples["sentiment"])
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np")
        labels = tokenizer(
            text_target=targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
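        model_inputs["labels"] = labels["input_ids"]
        # mBART's `shift_tokens_right` takes no `decoder_start_token_id`, so detect it by model type as well
        if config.model_type == "mbart" or "MBartModel" in str(config.architectures):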
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id)
        else:
            decoder_input_ids = shift_tokens_right_fn(labels["input_ids"], config.pad_token_id, config.decoder_start_token_id)
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs


    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    def norm_seq(seq_text):
        seq_text=seq_text.replace("<answer>","")
        seq_text=seq_text.replace("<nswer>","")
        seq_text=seq_text.replace("answer>","")
        seq_text=seq_text.replace("<title>","")
        seq_text=seq_text.replace("<itle>","")
        seq_text=seq_text.replace("title>","")
        seq_text=seq_text.replace("<question>","")
        seq_text=seq_text.replace("<uestion>","")
        seq_text=seq_text.replace("question>","")
        seq_text=seq_text.replace("<context>","")
        seq_text=seq_text.replace("<ontext>","")
        seq_text=seq_text.replace("context>","")
        return seq_text
    def postprocess_text(preds, labels):
        preds = [norm_seq(pred.strip()) for pred in preds]
        labels = [norm_seq(label.strip()) for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        # Save the decoded labels and predictions side by side for later inspection
        pred_df = pd.DataFrame({"label": decoded_labels, "pred": decoded_preds})
        pred_df.to_csv("pred.csv", index=False)

        accuracy, f1_pn = calc_scarcasm_score(decoded_preds, decoded_labels)
        result = {"Accuracy": accuracy, "F1_PN": f1_pn}
        # Report both scores as percentages
        result = {k: round(v * 100, 4) for k, v in result.items()}
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        num_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set(
            [
                layer[-2:]
                for layer_norm_name in layer_norm_candidates
                for layer in flat_params.keys()
                if layer_norm_name in "".join(layer).lower()
            ]
        )
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum()
        num_labels = padding_mask.sum()
        return loss, num_labels

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
            return loss, num_labels

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, num_labels), grad = grad_fn(state.params)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total number of non-padded label tokens
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        # true grad = total grad / total number of non-padded label tokens
        grad = jax.lax.psum(grad, "batch")
        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]

        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total number of non-padded label tokens
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        metrics = {"loss": loss}
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0, leave=True)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=0, leave=True):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
            f" {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=0, leave=True):
            # Model forward
            batch = next(eval_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics and update progress bar
        loss_score = round(float(eval_metrics["loss"]), 4)
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {loss_score} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save the checkpoint only after the final epoch, and push it to the hub if requested
        if jax.process_index() == 0 and epoch == int(training_args.num_train_epochs) - 1:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
        pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
        for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=True):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(labels)

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)

        # compute ROUGE metrics
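        rouge_desc = ""
        # stays empty when --predict_with_generate is off, so the JSON dump below never sees an undefined name
        rouge_metrics = {}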
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc}"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()


Overwriting t5_text_class.py
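
The `loss_fn` in the script above subtracts a normalizing constant from the label-smoothed cross entropy, so that the loss reaches zero when the predicted distribution matches the smoothed labels exactly. The following is a minimal pure-Python sketch of that arithmetic, not the script's actual code path: it uses `math` instead of JAX/optax, and the function names `log_softmax`, `smoothed_token_loss`, and `masked_loss` are hypothetical, introduced only for illustration. It assumes a smoothing factor strictly below 1.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over a plain list of floats.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_token_loss(logits, label, smoothing=0.0):
    # Assumption: 0 <= smoothing < 1 (smoothing == 1 would require log(0)).
    vocab_size = len(logits)
    confidence = 1.0 - smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the soft-label distribution; subtracting it shifts the
    # minimum achievable loss to zero.
    normalizing_constant = -(
        confidence * math.log(confidence)
        + (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20)
    )
    log_probs = log_softmax(logits)
    soft_labels = [confidence if i == label else low_confidence for i in range(vocab_size)]
    cross_entropy = -sum(p * lp for p, lp in zip(soft_labels, log_probs))
    return cross_entropy - normalizing_constant


def masked_loss(batch_logits, labels, padding_mask, smoothing=0.0):
    # Sum the loss over non-padded tokens and return the token count,
    # mirroring the (loss, num_labels) pair returned by loss_fn.
    total = 0.0
    for logits, label, keep in zip(batch_logits, labels, padding_mask):
        total += keep * smoothed_token_loss(logits, label, smoothing)
    return total, sum(padding_mask)
```

Dividing the summed loss by the returned token count gives the per-token average, which is what `train_step` and `eval_step` do after summing both quantities across devices.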


In [None]:
!wget -O arsarcasm-v2_dev.csv https://raw.githubusercontent.com/iabufarha/ArSarcasm-v2/main/ArSarcasm-v2/testing_data.csv
!wget -O arsarcasm-v2_train.csv https://raw.githubusercontent.com/iabufarha/ArSarcasm-v2/main/ArSarcasm-v2/training_data.csv
'''
We found an issue with the T5 model when the labels are ["POS", "NEG", "NEU"]: the model truncates the last letter,
so we replace them with full words. This issue may be caused by how our tokenizer handles non-Arabic characters. We will fix it soon.
This should not affect evaluation or training on this task.
'''
### Replace the abbreviated labels with full words
!sed -i 's/POS/positive/g' arsarcasm-v2_dev.csv
!sed -i 's/NEG/negative/g' arsarcasm-v2_dev.csv
!sed -i 's/NEU/neutral/g' arsarcasm-v2_dev.csv
!sed -i 's/POS/positive/g' arsarcasm-v2_train.csv
!sed -i 's/NEG/negative/g' arsarcasm-v2_train.csv
!sed -i 's/NEU/neutral/g' arsarcasm-v2_train.csv

--2023-06-13 17:56:25--  https://raw.githubusercontent.com/iabufarha/ArSarcasm-v2/main/ArSarcasm-v2/testing_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 585081 (571K) [text/plain]
Saving to: ‘arsarcasm-v2_dev.csv’


2023-06-13 17:56:25 (13.5 MB/s) - ‘arsarcasm-v2_dev.csv’ saved [585081/585081]

--2023-06-13 17:56:25--  https://raw.githubusercontent.com/iabufarha/ArSarcasm-v2/main/ArSarcasm-v2/training_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2393168 (2.3M) [text/plain]
Saving to: ‘arsarcasm-v2_train.csv
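
The `sed` commands in the cell above rewrite every occurrence of `POS`, `NEG`, and `NEU` in the files, including any that happen to appear inside tweet text. If only the label column should change, something like the sketch below could be used instead. It is an illustration rather than part of the original pipeline: the helper name `expand_labels` is hypothetical, and the column name `sentiment` is an assumption that should be checked against the CSV header before use.

```python
import csv

# Mapping from the abbreviated labels to the full words used for training.
LABEL_MAP = {"POS": "positive", "NEG": "negative", "NEU": "neutral"}


def expand_labels(in_path, out_path, column="sentiment"):
    # Assumption: the label column is called "sentiment"; adjust if the header differs.
    with open(in_path, newline="", encoding="utf-8") as src:
        reader = csv.DictReader(src)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Only the label column is touched; tweet text is left as-is.
    for row in rows:
        row[column] = LABEL_MAP.get(row[column], row[column])

    with open(out_path, "w", newline="", encoding="utf-8") as dst:
        writer = csv.DictWriter(dst, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

For example, `expand_labels("arsarcasm-v2_train.csv", "arsarcasm-v2_train_full.csv")` would write a copy with expanded labels; the output file name here is only a placeholder.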

In [None]:
!python3 t5_text_class.py --output_dir out --model_name_or_path sultan/ArabicT5-Base --tokenizer_name sultan/ArabicT5-Base --train_file=arsarcasm-v2_train.csv --validation_file=arsarcasm-v2_dev.csv --do_train --do_eval --predict_with_generate --num_train_epochs 6 --learning_rate 1e-4 --warmup_steps 0 --per_device_train_batch_size 2 --per_device_eval_batch_size 4 --overwrite_output_dir --max_source_length 128 --max_target_length 8

INFO:__main__:Training/evaluation parameters TrainingArguments(output_dir='out', overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, per_device_train_batch_size=2, per_device_eval_batch_size=4, learning_rate=0.0001, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, label_smoothing_factor=0.0, adafactor=False, num_train_epochs=6.0, warmup_steps=0, logging_steps=500, save_steps=500, eval_steps=None, eval_per_epoch=None, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None, gradient_checkpointing=False)
100% 2/2 [00:00<00:00, 725.85it/s]
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--sultan--ArabicT5-Base/snapshots/3bb572922fd98ff6eee794b8d58da3fd5d310814/config.json
Model config T5Config {
  "_name_or_path": "sultan/ArabicT5-Base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",


Refer to this paper for SOTA results on this task: [Benchmarking Transformer-based Language Models for Arabic Sentiment and Sarcasm Detection](https://aclanthology.org/2021.wanlp-1.3/) (see Table 2, Sentiment Analysis section, FPN score).
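
The FPN score mentioned above is, to our understanding, the macro-average of the F1 scores of the positive and negative classes only; neutral examples still count as errors when misclassified, but the neutral class has no F1 term of its own. A minimal sketch of that computation follows. The function names `f1_for_class` and `f_pn` are hypothetical, and the label strings assume the full-word labels produced by the replacement step earlier in this notebook.

```python
def f1_for_class(y_true, y_pred, cls):
    # Per-class F1 from true positives, false positives, and false negatives.
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def f_pn(y_true, y_pred, positive="positive", negative="negative"):
    # Average the F1 of the positive and negative classes only.
    return (f1_for_class(y_true, y_pred, positive) + f1_for_class(y_true, y_pred, negative)) / 2
```

When comparing against the paper's numbers, it is worth confirming that the evaluation script uses the same definition and the same test split.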