# Title: Text Alignment Is An Efficient Unified Model for Massive NLP Tasks

#### Members' Names or Individual's Name: Protik Mukherjee, Tara Shingadia

####  Emails: protik.mukherjee@torontomu.ca, tshingadia@torontomu.ca

# Introduction:

#### Problem Description:

LLMs generalize well across a wide range of NLP tasks through next-word prediction.
However, efficiency on specific tasks remains a concern, and an LLM is not always the best solution.
Strong performance requires scaling to tens of billions of parameters or more, e.g., GPT-3 with 175B parameters (Zha et al. [1]).
Despite their size, LLMs can be outperformed by smaller, finetuned models on classical NLP tasks.

#### Context of the Problem:

The problem matters because model design must balance generality against efficiency. This balance directly affects the usability, performance, and applicability of models across diverse tasks and domains. Efficiently designed models can achieve superior performance and are more practical for specific applications, which addresses the resource consumption and limited adaptability of large-scale, less specialized models.

#### Limitations of Other Approaches:

Previous work focuses on building natural language inference (NLI) models for broad tasks.
Limited availability of NLI data (e.g., MNLI) for training leads to restricted performance and applicability across various domains.
Another line of research involves training general text representation models using pretraining and multi-task learning.
These models require specific finetuning, which can be resource-demanding, for each downstream task with task-specific heads, rather than serving as plug-and-play solutions.

#### Solution:

1. Unified Approach for Various NLP Tasks: It efficiently handles a broad spectrum of tasks involving text relationships, such as NLI, question answering, and semantic textual similarity, with a single model.

2. Efficiency with Smaller-Scale Model: Utilizes a smaller-scale language model (e.g., RoBERTa with 355M parameters) to achieve high performance across diverse tasks, demonstrating that efficiency and effectiveness are achievable without massive scaling.

3. Diverse Training Data Utilization: Employs a rich and diverse dataset compilation from 28 different NLP tasks for training, enhancing the model's robustness and applicability across various domains.

4. Superior Performance and Real-World Application: Shows competitive or better performance than larger models on a wide range of tasks and significantly enhances existing LLMs’ abilities in practical applications like factual consistency evaluation and question answerability verification.

# Background



| Reference | Explanation | Dataset/Input | Weakness |
| --- | --- | --- | --- |
| Aghajanyan et al. [2] | Pre-finetuned LMs to encourage learning more general representations | 50 NLU datasets | With fewer than 15 datasets, MTL is detrimental to end-task finetuning |
| Liu et al. [3] | Trained a BERT model on four types of tasks: single-sentence classification, pairwise text classification, text similarity scoring, and relevance ranking | GLUE, SNLI and SciTail | Only 65.1% accuracy on WNLI |
| Zha et al. [1] | Trained an efficient unified model using text alignment | 28 NLU datasets | Evaluates sentences individually, which could be time-consuming for long texts |



The paper by Aghajanyan et al. [2] introduces a novel stage in model training known as "pre-finetuning," which trains models on a broad array of around 50 tasks, including over 4.8 million instances in areas such as classification, summarization, and question answering, prior to standard fine-tuning. This step aims to boost the general capabilities of language models like RoBERTa and BART.

Utilizing this extensive multi-task learning strategy results in significant enhancements in model efficiency and performance across various tasks, notably in sentence prediction and commonsense reasoning. Crucially, pre-finetuning enables these models to perform better with much less data in subsequent fine-tuning stages.

The study also emphasizes the importance of the number of tasks involved in pre-finetuning: too few tasks can degrade performance, whereas a larger set of tasks consistently improves outcomes. Additionally, the research discusses essential optimization techniques, including loss scaling and heterogeneous batch processing, which are vital for the successful application of this training method. Overall, the work establishes the effectiveness of pre-finetuning as a means to improve the adaptability and efficiency of neural language models, setting new state-of-the-art results on several NLU tasks without requiring specific intermediate tasks.

The study by Liu et al. [3] introduces the Multi-Task Deep Neural Network (MT-DNN), a model that combines multi-task learning (MTL) with language model pre-training to improve representation learning across natural language understanding (NLU) tasks. MT-DNN incorporates a pre-trained BERT, a bidirectional transformer language model, and achieves state-of-the-art results on ten NLU tasks, including SNLI and SciTail as well as most GLUE benchmark tasks, surpassing previous models by a significant margin.

MT-DNN is particularly effective in domain adaptation, requiring fewer in-domain labels to adapt to new tasks compared to using BERT alone. This efficiency is demonstrated through substantial improvements in model performance even when limited training data is available. The paper highlights the crucial role of the number of tasks in pre-finetuning; a larger number of tasks correlates with better performance and generalization of the model.

Moreover, the study explores various optimization techniques like loss scaling and heterogeneous batch processing, essential for the training process's success. The combination of MTL and pre-training not only improves the efficiency and effectiveness of the learning process but also enhances the model's ability to generalize across different domains and tasks, showcasing MT-DNN's robust adaptability and potential for practical applications in diverse NLU scenarios.






# Methodology

We are implementing a text alignment model, which we refer to as the ALIGN model. The model is designed to handle various NLP tasks by aligning textual information between two texts. It operates by evaluating how well the content of one text aligns with or contradicts the information in another text.

ALIGN is implemented by finetuning a pre-existing language model, specifically RoBERTa with 355M parameters, on 5.9 million examples drawn from 28 different datasets. The datasets cover a variety of NLP tasks, including natural language inference (NLI), fact verification, and semantic textual similarity, to strengthen the model's training process.

The Predict and Score methods are used to test the functionality of ALIGN. The Predict method evaluates the alignment between pairs of texts and produces three outputs: a regression score that measures the degree of alignment (yreg), binary probabilities that distinguish aligned from not-aligned (Pr(Ybin)), and 3-way probabilities over aligned, contradicted, and neutral (Pr(Y3way)) for more complex NLP tasks such as entailment.

![Binary Text Pair Alignment](https://drive.google.com/uc?export=view&id=1DMDX0fOKQ_OSa7Dy07Col_53Pm59qd2Q)   

Binary: Pr(Ybin) = (aligned, not-aligned)

3 Way: Pr(Y3way) = (aligned, contradicted, neutral)

Regression: yreg ∈ [0, 1] = (real-valued score)
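
To make the three output forms concrete, here is a minimal pure-Python sketch that turns hypothetical raw head outputs into Pr(Ybin), Pr(Y3way), and yreg. The numeric values and the order of the 3-way labels are illustrative assumptions, not values or conventions taken from the ALIGN code.

```python
import math

def softmax(logits):
    """Convert a list of raw logits into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw outputs for one text pair (illustrative values only)
bin_logits = [-1.2, 2.3]        # assumed order: [not-aligned, aligned]
tri_logits = [2.0, -0.5, 0.1]   # assumed order: [aligned, contradicted, neutral]
reg_output = 0.87               # real-valued alignment score

pr_y_bin = softmax(bin_logits)
pr_y_3way = softmax(tri_logits)

print("Pr(Ybin)  =", [round(p, 3) for p in pr_y_bin])
print("Pr(Y3way) =", [round(p, 3) for p in pr_y_3way])
print("yreg      =", reg_output)
```

The two classification outputs are probability distributions over their label sets, while the regression output is a single score per text pair.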




The Score method refines evaluations by adjusting parameters such as the type of alignment assessment (e.g., binary, regression) and the use of the split-then-aggregate technique to manage longer texts. This adjustment improves the model's suitability for real-world applications, where input lengths often exceed the model's maximum input length.

![Alignment](https://drive.google.com/uc?export=view&id=1wiDvL8wez20cNU0ZAP4ciy2DNQ5AYh7d)

During training, the classification heads are trained with cross-entropy loss, while the regression head is trained with mean squared error loss. The losses are then aggregated as a weighted sum.

![Loss](https://drive.google.com/uc?export=view&id=1wZQ-lljdDJ7UiGETZvJjF093sQLMqHPe)
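
The weighted sum can be sketched in a few lines of plain Python. This is an illustration under stated assumptions: the weight names `w_3way`, `w_bin`, and `w_reg` and their default value of 1.0 are hypothetical, and the probabilities and targets are made-up numbers for a single example.

```python
import math

def cross_entropy(probs, label):
    """Cross-entropy for one example: negative log-probability of the true class."""
    return -math.log(probs[label])

def squared_error(pred, target):
    """Squared error for one regression example."""
    return (pred - target) ** 2

def total_loss(l_3way, l_bin, l_reg, w_3way=1.0, w_bin=1.0, w_reg=1.0):
    """Weighted sum of the three head losses (the weights are hypothetical defaults)."""
    return w_3way * l_3way + w_bin * l_bin + w_reg * l_reg

# One made-up example per head
l_3way = cross_entropy([0.7, 0.1, 0.2], label=0)  # 3-way head, true class at index 0
l_bin = cross_entropy([0.2, 0.8], label=1)        # binary head, true class at index 1
l_reg = squared_error(0.9, 1.0)                   # regression head

loss = total_loss(l_3way, l_bin, l_reg)
print(f"total loss = {loss:.4f}")
```

Raising one weight makes the corresponding head dominate the gradient, so the weights act as a knob for balancing the classification and regression objectives.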


To assess the effectiveness of the ALIGN model, comprehensive testing was done on multiple datasets (Table 1), covering the variety of tasks for which the model was designed. The evaluation compared ALIGN with large models such as FLAN-T5 and with specialized task-specific models, measuring both its performance on familiar tasks and its ability to generalize to unseen tasks.

![Align Performance](https://drive.google.com/uc?export=view&id=1XF27pn5QFEEouQK3QpDcHz5dRkL6eV-_)




# Implementation



**Model Architecture**

**Base Model:** The ALIGN model utilizes RoBERTa, a robustly optimized version of BERT, known for its effectiveness in various NLP benchmarks. RoBERTa itself is an attention-based model that benefits from a more extensive training corpus and longer training compared to BERT. Choosing RoBERTa as the backbone provides ALIGN with a strong foundation in language understanding.




**Adaptations for Alignment:** Unlike RoBERTa, which primarily handles tasks like classification and regression, ALIGN is adapted to specifically assess alignment between text pairs. This involves:

**Predicting Relationships:** The model outputs predictions on whether two text segments are aligned, contradict each other, or are unrelated (neutral). This tripartite output is crucial for tasks like entailment and fact-checking.

**Custom Heads:** For alignment tasks, the model uses custom heads attached to the base RoBERTa model. These heads are trained to predict the type of relationship between text pairs, leveraging the contextual representations learned by RoBERTa.

In [None]:
# Importing necessary libraries and modules
from .inference import Inferencer
from typing import List, Tuple
import torch

# Define a class to handle alignment between contexts and claims using a model
class Align:
    def __init__(self, model: str, batch_size: int, device: int, ckpt_path: str, verbose=True) -> None:
        """
        Initializes the Align class with an inference model.

        Args:
            model (str): The model type/name used for inference.
            batch_size (int): The number of samples to process in one go.
            device (int): The device ID to run the model on (e.g., GPU ID).
            ckpt_path (str): Path to the model's checkpoint file.
            verbose (bool, optional): Flag to enable verbose logging. Default is True.
        """
        # Initialize the Inferencer with the provided arguments
        self.model = Inferencer(
            ckpt_path=ckpt_path,
            model=model,
            batch_size=batch_size,
            device=device,
            verbose=verbose
        )

    def predict(self, contexts: List[str], claims: List[str]) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Predict the alignment labels for context and claim pairs.

        Args:
            contexts (List[str]): A list of contexts.
            claims (List[str]): A list of claims.

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: A tuple of prediction scores
            `(regression_score, binary_score, nli_scores)`.
            `regression_score` and `binary_score` both have shape (N,).
            `nli_scores` has shape (N, 3), representing three different class scores for each context-claim pair.
        """
        # Set model evaluation mode (not provided in the snippet, assumed functionality)
        self.model.nlg_eval_mode = None
        # Invoke the inference method of the model with contexts and claims
        return self.model.inference(contexts, claims)

    def score(self, contexts: List[str], claims: List[str], head: str = 'nli', split: bool = True) -> torch.FloatTensor:
        """
        Calculate the alignment scores between context and claim pairs using the specified prediction head and aggregation method.

        Args:
            contexts (List[str]): A list of contexts.
            claims (List[str]): A list of claims.
            head (str, optional): The prediction head to use ('nli' for natural language inference, 'bin' for binary, or 'reg' for regression).
                                  Defaults to 'nli'.
            split (bool, optional): Whether to split and then aggregate long inputs. If set to False, the model will truncate oversized inputs.
                                    Defaults to True.

        Returns:
            torch.FloatTensor: Alignment scores for the input pairs.
        """
        # Set the evaluation mode based on the head and split settings
        self.model.nlg_eval_mode = head + ('_sp' if split else '')
        # Perform the evaluation and return only the scores (assuming the second return value from nlg_eval contains scores)
        return self.model.nlg_eval(contexts, claims)[1]

In [None]:
import math
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch optimizer
from transformers import get_linear_schedule_with_warmup, AutoConfig
from transformers import BertModel, BertForPreTraining, RobertaModel, RobertaForMaskedLM, AlbertModel, AlbertForMaskedLM
from sklearn.metrics import f1_score
from dataclasses import dataclass
from typing import Optional, Tuple  # used by the ModelOutput annotations below

class BERTAlignModel(pl.LightningModule):
    """
    A PyTorch Lightning module for the BERTAlignModel which incorporates different transformer models
    such as BERT, RoBERTa, ALBERT, and ELECTRA for various NLP tasks including MLM, sequence classification,
    and regression tasks.
    """
    def __init__(self, model='bert-base-uncased', using_pretrained=True, *args, **kwargs):
        """
        Initializes the BERTAlignModel.

        Args:
            model (str): Name of the model to use.
            using_pretrained (bool): Whether to load pretrained weights.
        """
        super().__init__()
        self.save_hyperparameters()  # saves all constructor arguments into self.hparams
        self.model = model

        # Depending on the model type, initialize different architectures with optional pretraining
        if 'muppet' in model:
            assert using_pretrained == True, "Only support pretrained muppet!"
            self.base_model = RobertaModel.from_pretrained(model)
            self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head

        elif 'roberta' in model:
            self.base_model = RobertaModel.from_pretrained(model) if using_pretrained else RobertaModel(AutoConfig.from_pretrained(model))
            self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head if using_pretrained else RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head

        elif 'albert' in model:
            self.base_model = AlbertModel.from_pretrained(model) if using_pretrained else AlbertModel(AutoConfig.from_pretrained(model))
            self.mlm_head = AlbertForMaskedLM.from_pretrained(model).predictions if using_pretrained else AlbertForMaskedLM(AutoConfig.from_pretrained(model)).predictions

        elif 'bert' in model:
            self.base_model = BertModel.from_pretrained(model) if using_pretrained else BertModel(AutoConfig.from_pretrained(model))
            self.mlm_head = BertForPreTraining.from_pretrained(model).cls.predictions if using_pretrained else BertForPreTraining(AutoConfig.from_pretrained(model)).cls.predictions

        elif 'electra' in model:
            # For ELECTRA, initialize both generator and discriminator
            self.generator = BertModel(AutoConfig.from_pretrained('prajjwal1/bert-small'))
            self.generator_mlm = BertForPreTraining(AutoConfig.from_pretrained('prajjwal1/bert-small')).cls.predictions
            self.base_model = BertModel(AutoConfig.from_pretrained('bert-base-uncased'))
            self.discriminator_predictor = ElectraDiscriminatorPredictions(self.base_model.config)

        # Additional output layers for classification and regression tasks
        self.bin_layer = nn.Linear(self.base_model.config.hidden_size, 2)
        self.tri_layer = nn.Linear(self.base_model.config.hidden_size, 3)
        self.reg_layer = nn.Linear(self.base_model.config.hidden_size, 1)

        self.dropout = nn.Dropout(p=0.1)

        # Flags to determine specific behavior in forward pass
        self.need_mlm = True
        self.is_finetune = False
        self.mlm_loss_factor = 0.5

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, batch):
        """
        Defines the forward pass of the model.

        Args:
            batch (dict): Input batch containing all required data for computation.

        Returns:
            ModelOutput: Output object containing all computed outputs including logits and losses.
        """
        # Special handling for the ELECTRA model forward pass
        if 'electra' in self.model:
            return self.electra_forward(batch)

        # Regular forward pass for other models
        base_model_output = self.base_model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'] if 'token_type_ids' in batch else None
        )

        # Obtain outputs for masked language modeling
        prediction_scores = self.mlm_head(base_model_output.last_hidden_state)
        # Outputs for the binary and 3-way classification heads
        seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output))
        tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
        # Output for regression task
        reg_label_score = self.reg_layer(base_model_output.pooler_output)

        total_loss = None
        if 'mlm_label' in batch:
            # Compute losses for various outputs, if labels are provided in the batch
            ce_loss_fct = nn.CrossEntropyLoss(reduction='sum')
            masked_lm_loss = ce_loss_fct(prediction_scores.view(-1, self.base_model.config.vocab_size), batch['mlm_label'].view(-1))
            next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1)) / math.log(2)
            tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1)) / math.log(3)
            reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1), reduction='sum')

            # Count number of labels for normalization
            masked_lm_loss_num = torch.sum(batch['mlm_label'].view(-1) != -100)
            next_sentence_loss_num = torch.sum(batch['align_label'].view(-1) != -100)
            tri_label_loss_num = torch.sum(batch['tri_label'].view(-1) != -100)
            reg_label_loss_num = torch.sum(batch['reg_label'].view(-1) != -100.0)

        return ModelOutput(
            loss=total_loss,
            all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss] if 'mlm_label' in batch else None,
            loss_nums=[masked_lm_loss_num, next_sentence_loss_num, tri_label_loss_num, reg_label_loss_num] if 'mlm_label' in batch else None,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            tri_label_logits=tri_label_score,
            reg_label_logits=reg_label_score,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions
        )

    def electra_forward(self, batch):
        """
        Special forward function for ELECTRA, handling the generator and discriminator models.

        Args:
            batch (dict): Input batch containing all required data for computation.

        Returns:
            ModelOutput: Output object containing all computed outputs including logits and losses.
        """
        if 'mlm_label' in batch:
            ce_loss_fct = nn.CrossEntropyLoss()
            generator_output = self.generator_mlm(self.generator(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                token_type_ids=batch['token_type_ids'] if 'token_type_ids' in batch else None
            ).last_hidden_state)
            masked_lm_loss = ce_loss_fct(generator_output.view(-1, self.generator.config.vocab_size), batch['mlm_label'].view(-1))

            # Replace masked tokens with generator predictions
            hallucinated_tokens = batch['input_ids'].clone()
            hallucinated_tokens[batch['mlm_label'] != -100] = torch.argmax(generator_output, dim=-1)[batch['mlm_label'] != -100]
            # Create labels for token replacement detection
            replaced_token_label = (batch['input_ids'] == hallucinated_tokens).long()
            replaced_token_label[batch['mlm_label'] != -100] = (batch['mlm_label'] == hallucinated_tokens)[batch['mlm_label'] != -100].long()
            replaced_token_label[batch['input_ids'] == 0] = -100  # ignore paddings

        # Discriminator predictions
        base_model_output = self.base_model(
            input_ids=hallucinated_tokens if 'mlm_label' in batch else batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'] if 'token_type_ids' in batch else None
        )
        hallu_detect_score = self.discriminator_predictor(base_model_output.last_hidden_state)
        seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output))
        tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
        reg_label_score = self.reg_layer(base_model_output.pooler_output)

        total_loss = None

        if 'mlm_label' in batch:
            total_loss = []
            hallu_detect_loss = ce_loss_fct(hallu_detect_score.view(-1,2), replaced_token_label.view(-1))
            next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1))
            tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1))
            reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1))

            # Aggregate all component losses with appropriate weights
            total_loss.append(10.0 * hallu_detect_loss if not torch.isnan(hallu_detect_loss).item() else 0.)
            total_loss.append(0.2 * masked_lm_loss if (not torch.isnan(masked_lm_loss).item() and self.need_mlm) else 0.)
            total_loss.append(next_sentence_loss if not torch.isnan(next_sentence_loss).item() else 0.)
            total_loss.append(tri_label_loss if not torch.isnan(tri_label_loss).item() else 0.)
            total_loss.append(reg_label_loss if not torch.isnan(reg_label_loss).item() else 0.)

            total_loss = sum(total_loss)

        return ModelOutput(
            loss=total_loss,
            all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss, hallu_detect_loss] if 'mlm_label' in batch else None,
            prediction_logits=hallu_detect_score,
            seq_relationship_logits=seq_relationship_score,
            tri_label_logits=tri_label_score,
            reg_label_logits=reg_label_score,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions
        )

    def training_step(self, train_batch, batch_idx):
        """
        Handles the training step.

        Args:
            train_batch (dict): Batch data for training.
            batch_idx (int): Index of the batch.

        Returns:
            dict: Dictionary containing loss information.
        """
        output = self(train_batch)

        return {'losses': output.all_loss, 'loss_nums': output.loss_nums}

    def training_step_end(self, step_output):
        """
        Finalize the training step and log the losses.

        Args:
            step_output (dict): Output from the training step.

        Returns:
            float: Total computed loss.
        """
        losses = step_output['losses']
        loss_nums = step_output['loss_nums']
        assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'

        loss_mlm_num = torch.sum(loss_nums[0])
        loss_bin_num = torch.sum(loss_nums[1])
        loss_tri_num = torch.sum(loss_nums[2])
        loss_reg_num = torch.sum(loss_nums[3])

        # Normalize each loss by its corresponding number of items
        loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
        loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
        loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
        loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.

        # Compute total loss with weighting factors
        total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg

        self.log('train_loss', total_loss)  # Log total and component losses
        self.log('mlm_loss', loss_mlm)
        self.log('bin_label_loss', loss_bin)
        self.log('tri_label_loss', loss_tri)
        self.log('reg_label_loss', loss_reg)

        return total_loss

    def validation_step(self, val_batch, batch_idx):
        """
        Handles the validation step.

        Args:
            val_batch (dict): Batch data for validation.
            batch_idx (int): Index of the batch.

        Returns:
            dict: Dictionary containing loss or prediction information.
        """
        if not self.is_finetune:
            with torch.no_grad():
                output = self(val_batch)

            return {'losses': output.all_loss, 'loss_nums': output.loss_nums}

        with torch.no_grad():
            # ModelOutput is a dataclass, so read the field as an attribute rather than by key
            output = self(val_batch).seq_relationship_logits
            output = self.softmax(output)[:, 1].tolist()
            pred = [int(align_prob > 0.5) for align_prob in output]

            labels = val_batch['align_label'].tolist()

        return {"pred": pred, 'labels': labels}

    def validation_step_end(self, step_output):
        """
        Finalize the validation step and log the losses.

        Args:
            step_output (dict): Output from the validation step.

        Returns:
            float: Total computed loss.
        """
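        if self.is_finetune:
            # In finetuning mode the step returns predictions and labels, not losses;
            # pass them through unchanged for validation_epoch_end
            return step_output

        losses = step_output['losses']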
        loss_nums = step_output['loss_nums']
        assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'

        loss_mlm_num = torch.sum(loss_nums[0])
        loss_bin_num = torch.sum(loss_nums[1])
        loss_tri_num = torch.sum(loss_nums[2])
        loss_reg_num = torch.sum(loss_nums[3])

        # Normalize each loss by its corresponding number of items
        loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
        loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
        loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
        loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.

        # Compute total loss with weighting factors
        total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg

        self.log('val_loss', total_loss)  # Log total and component losses

        return total_loss

    def validation_epoch_end(self, outputs):
        """
        Handles the end of the validation epoch.

        Args:
            outputs (list): List of outputs from the validation steps.

        Effects:
            Logs the mean validation loss or F1 score.
        """
        if not self.is_finetune:
            total_loss = torch.stack(outputs).mean()
            self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)

        else:
            all_predictions = []
            all_labels = []
            for each_output in outputs:
                all_predictions.extend(each_output['pred'])
                all_labels.extend(each_output['labels'])

            self.log("f1", f1_score(all_labels, all_predictions), prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        """
        Prepare optimizer and schedule (linear warmup and decay).

        Returns:
            tuple: Contains the list of optimizers and list of schedulers.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(self.hparams.warmup_steps_portion * self.trainer.estimated_stepping_batches),
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def mse_loss(self, input, target, ignored_index=-100.0, reduction='mean'):
        """
        Custom mean squared error loss to handle possible ignored indices.

        Args:
            input (torch.Tensor): Predictions.
            target (torch.Tensor): True values.
            ignored_index (float): Value for ignored indices.
            reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

        Returns:
            torch.Tensor: Computed MSE loss.
        """
        mask = (target != ignored_index)
        out = (input[mask]-target[mask])**2
        if reduction == "mean":
            return out.mean()
        elif reduction == "sum":
            return out.sum()

class ElectraDiscriminatorPredictions(nn.Module):
    """
    Prediction module for the discriminator in ELECTRA model, consisting of two dense layers.

    Args:
        config (transformers.PretrainedConfig): Configuration object containing model configurations.
    """
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_prediction = nn.Linear(config.hidden_size, 2)
        self.config = config
        self.gelu = nn.GELU()

    def forward(self, discriminator_hidden_states):
        """
        Forward pass for the discriminator prediction module.

        Args:
            discriminator_hidden_states (torch.Tensor): Hidden states from the discriminator model.

        Returns:
            torch.Tensor: Logits for discriminator predictions.
        """
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = self.gelu(hidden_states)
        logits = self.dense_prediction(hidden_states).squeeze(-1)

        return logits

@dataclass
class ModelOutput():
    """
    Data class for storing outputs from the model during forward passes.

    Attributes:
        loss (Optional[torch.FloatTensor]): Total computed loss, if any.
        all_loss (Optional[list]): List of computed losses for different components, if available.
        loss_nums (Optional[list]): List of counts for the components contributing to the losses.
        prediction_logits (torch.FloatTensor): Logits for the predictions from the MLM head.
        seq_relationship_logits (torch.FloatTensor): Logits for sequence relationship predictions.
        tri_label_logits (torch.FloatTensor): Logits for 3-way label predictions.
        reg_label_logits (torch.FloatTensor): Logits for regression predictions.
        hidden_states (Optional[Tuple[torch.FloatTensor]]): Hidden states from the model.
        attentions (Optional[Tuple[torch.FloatTensor]]): Attention weights from the model.
    """
    loss: Optional[torch.FloatTensor] = None
    all_loss: Optional[list] = None
    loss_nums: Optional[list] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    tri_label_logits: torch.FloatTensor = None
    reg_label_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


**Training Strategy**

**Multi-Task Learning:** ALIGN is trained on a diverse set of datasets, each contributing different types of alignment scenarios (e.g., factual alignment in fact-checking, logical alignment in entailment). This multi-task learning approach helps the model generalize across tasks by exposing it to various alignment contexts during training.

**Aggregation of Data:** By aggregating data from 28 datasets, ALIGN benefits from a broad spectrum of linguistic phenomena and relationships, enhancing its robustness and ability to perform on unseen data. This aggregation also mitigates the risk of overfitting to specific dataset quirks.
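
The aggregation step can be sketched roughly as follows. This is only an illustration under assumptions: `aggregate_datasets` and the toy data are hypothetical names, not part of the ALIGN codebase, and each dataset is assumed to be already loaded as a list of dictionaries. The per-dataset cap plays the same role as the `--max-samples-per-dataset` option of the training script shown later in this section.

```python
import random

def aggregate_datasets(datasets, max_samples_per_dataset, seed=2022):
    """Cap each dataset, tag every example with its source, and shuffle the pool.

    Hypothetical helper for illustration; not the ALIGN data loader.
    """
    rng = random.Random(seed)
    pool = []
    for name, examples in datasets.items():
        examples = list(examples)
        # Down-sample datasets that exceed the cap so no single source dominates
        if len(examples) > max_samples_per_dataset:
            examples = rng.sample(examples, max_samples_per_dataset)
        pool.extend({"dataset": name, **ex} for ex in examples)
    # Shuffle so that batches mix examples from different tasks
    rng.shuffle(pool)
    return pool

# Toy data standing in for real training sets
toy = {
    "nli": [{"text_a": f"premise {i}", "text_b": f"hypothesis {i}"} for i in range(5)],
    "sts": [{"text_a": f"sentence {i}", "text_b": f"paraphrase {i}"} for i in range(2)],
}
mixed = aggregate_datasets(toy, max_samples_per_dataset=3)
```

Capping before mixing is one simple way to keep a very large dataset from dominating the pool; other sampling schemes are equally possible.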

**Handling of Long Inputs:** One of the novel features of ALIGN is its approach to managing long text inputs, which are common in real-world data but problematic for standard transformers due to their fixed maximum input length:

**Split-then-Aggregate Method:** The model first splits long texts into smaller, manageable segments. Each segment is processed independently, and the results are then aggregated to form a final prediction. This method ensures that no crucial information is lost due to truncation.

**Aggregation Techniques:** The aggregation can be done through various statistical techniques like averaging or maximum scoring, depending on the task requirements. This flexibility allows the model to adapt its processing based on the specific needs of the evaluation context.
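
A minimal sketch of the split-then-aggregate idea, under stated assumptions: the context is split into fixed-size word chunks, the claim is split into sentences with a naive regular expression, each sentence keeps its maximum score over all chunks, and the sentence scores are averaged. The function names, the chunk size, and the `toy_score` word-overlap scorer are hypothetical placeholders; in practice the scorer would be the trained alignment model.

```python
import re

def split_into_chunks(text, max_words):
    """Split text into consecutive chunks of at most max_words words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

def split_into_sentences(text):
    """Naive sentence splitter on '.', '!' and '?' (an assumption for this sketch)."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def split_then_aggregate(context, claim, score_fn, max_words=350):
    """Score each claim sentence against every context chunk, then aggregate.

    Max over chunks picks the best-supporting segment for a sentence;
    mean over sentences summarizes the whole claim.
    """
    chunks = split_into_chunks(context, max_words)
    sentences = split_into_sentences(claim)
    if not chunks or not sentences:
        return 0.0
    per_sentence = [max(score_fn(chunk, sent) for chunk in chunks) for sent in sentences]
    return sum(per_sentence) / len(per_sentence)

def toy_score(chunk, sentence):
    """Placeholder scorer: fraction of sentence words that also occur in the chunk."""
    chunk_words = set(re.findall(r"\w+", chunk.lower()))
    sent_words = re.findall(r"\w+", sentence.lower())
    if not sent_words:
        return 0.0
    return sum(w in chunk_words for w in sent_words) / len(sent_words)

context = "The cat sat on the mat. The dog slept in the yard."
claim = "The cat sat. The dog flew."
score = split_then_aggregate(context, claim, toy_score, max_words=6)
```

Swapping `max` for a mean, or the final mean for a minimum, gives the other aggregation variants mentioned above without changing the splitting logic.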

In [None]:
from pytorch_lightning import Trainer, seed_everything
from align.dataloader import DSTDataLoader
from align.model import BERTAlignModel
from pytorch_lightning.callbacks import ModelCheckpoint
from argparse import ArgumentParser
import os

# Define the main training function
def train(datasets, args):
    # Initialize the DataLoader with specified configurations
    dm = DSTDataLoader(
        dataset_config=datasets,
        model_name=args.model_name,
        sample_mode='seq',  # Sampling mode for data loading
        train_batch_size=args.batch_size,
        eval_batch_size=16,
        num_workers=args.num_workers,
        train_eval_split=0.95,  # Split ratio for training and evaluation
        need_mlm=args.do_mlm  # Whether to perform masked language modeling
    )
    dm.setup()  # Setup data module

    # Initialize the model with parameters specified in args
    model = BERTAlignModel(model=args.model_name, using_pretrained=args.use_pretrained_model,
        adam_epsilon=args.adam_epsilon,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps_portion=args.warm_up_proportion
    )
    model.need_mlm = args.do_mlm  # Set whether MLM is needed based on args

    # Prepare a checkpoint name based on various parameters
    checkpoint_name = '_'.join((
        f"{args.ckpt_comment}{args.model_name.replace('/', '-')}",
        f"{'scratch_' if not args.use_pretrained_model else ''}{'no_mlm_' if not args.do_mlm else ''}",
        str(args.max_samples_per_dataset),
        f"{args.batch_size}x{len(args.devices)}x{args.accumulate_grad_batch}"
    ))

    # Define a model checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_save_path,
        filename=checkpoint_name + "_{epoch:02d}_{step}",
        every_n_train_steps=10000,  # Save a checkpoint every 10,000 training steps
        save_top_k=1  # Keep a single checkpoint; with no monitored metric this is the most recent one
    )

    # Configure the PyTorch Lightning Trainer
    trainer = Trainer(
        accelerator='gpu',
        max_epochs=args.num_epoch,
        devices=args.devices,
        strategy="dp",  # Use data parallel strategy
        precision=32,
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=args.accumulate_grad_batch  # Gradient accumulation to manage memory
    )

    # Start training
    trainer.fit(model, datamodule=dm)
    # Save final checkpoint at the end of training
    trainer.save_checkpoint(os.path.join(args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))

    print("Training is finished.")

# The main guard for running the script
if __name__ == "__main__":
    # Define all possible training datasets with configuration
    ALL_TRAINING_DATASETS = {
        'mnli': {'task_type': 'nli', 'data_path': 'mnli.json'},
        # Add more datasets with similar structure...
    }

    # Initialize argument parser
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=2022)  # Random seed for reproducibility
    parser.add_argument('--batch-size', type=int, default=32)  # Training batch size
    parser.add_argument('--accumulate-grad-batch', type=int, default=1)  # Gradient accumulation steps
    parser.add_argument('--num-epoch', type=int, default=3)  # Number of training epochs
    parser.add_argument('--num-workers', type=int, default=8)  # Number of workers for data loading
    parser.add_argument('--warm-up-proportion', type=float, default=0.06)  # Proportion of training for warm-up in scheduler
    parser.add_argument('--adam-epsilon', type=float, default=1e-6)  # Epsilon parameter for Adam optimizer
    parser.add_argument('--weight-decay', type=float, default=0.1)  # Weight decay for regularization
    parser.add_argument('--learning-rate', type=float, default=1e-5)  # Learning rate
    parser.add_argument('--val-check-interval', type=float, default=1. / 4)  # Validation check frequency
    parser.add_argument('--devices', nargs='+', type=int, required=True)  # GPU devices to use
    parser.add_argument('--model-name', type=str, default="roberta-large")  # Model architecture
    parser.add_argument('--ckpt-save-path', type=str, required=True)  # Path to save checkpoints
    parser.add_argument('--ckpt-comment', type=str, default="")  # Optional comment to prefix on checkpoint names
    parser.add_argument('--training-datasets', nargs='+', type=str, default=list(ALL_TRAINING_DATASETS.keys()), choices=list(ALL_TRAINING_DATASETS.keys()))  # Datasets to train on
    parser.add_argument('--data-path', type=str, required=True)  # Path to the dataset directory
    parser.add_argument('--max-samples-per-dataset', type=int, default=500000)  # Maximum samples per dataset
    # argparse's type=bool treats any non-empty string (including "False") as True, so parse the text explicitly
    parser.add_argument('--do-mlm', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False)  # Whether to perform MLM
    parser.add_argument('--use-pretrained-model', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True)  # Use pretrained model

    # Parse arguments
    args = parser.parse_args()

    # Seed all mechanisms for reproducibility
    seed_everything(args.seed)

    # Configure datasets with size and specific data path
    datasets = {
        name: {
            **ALL_TRAINING_DATASETS[name],
            "size": args.max_samples_per_dataset,
            "data_path": os.path.join(args.data_path, ALL_TRAINING_DATASETS[name]['data_path'])
        }
        for name in args.training_datasets
    }

    # Start the training process
    train(datasets, args)


**Experiments and Evaluation**

**Experimental Setup**

**Datasets Used:** The evaluation uses a comprehensive set of 25 datasets covering a wide spectrum of NLP tasks including textual entailment, fact verification, semantic textual similarity, question answering, and coreference resolution. This diversity ensures that the model's performance is tested under various linguistic and contextual challenges.

**Baseline Models for Comparison:** ALIGN is compared against several models:

**FLAN-T5 Models:** These are larger transformer models known for their flexibility and strength across many tasks due to extensive pre-training and fine-tuning.

**Task-Specific Fine-tuned Models:** These models are fine-tuned specifically for each task, providing a benchmark for what specialized models can achieve.

**RoBERTa Models Fine-tuned on Individual Datasets:** Comparisons include versions of RoBERTa that have been fine-tuned for specific tasks, highlighting the advantage of multi-task learning in ALIGN.

**Training Details:** ALIGN is trained using a mixture of supervised learning from human-annotated datasets and unsupervised learning from large text corpora. This combination leverages the strengths of both learning paradigms, enhancing the model's generalization capabilities.


**Evaluation Metrics**

**Performance Metrics:** The primary metrics used for evaluation include accuracy, F1 score, and ROC AUC, depending on the task. These metrics help quantify the model's effectiveness in different scenarios:

  **Accuracy and F1 Score:** Used mainly for classification tasks like paraphrase detection and entailment.

  **ROC AUC:** Utilized for tasks with probabilistic outputs, such as predicting the likelihood of a text pair being aligned.

  **Efficiency Metrics:** Besides performance, efficiency metrics such as inference time and computational resource usage are tracked. These metrics are crucial for evaluating the practicality of deploying ALIGN in resource-constrained environments.
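
For reference, the three performance metrics can be computed as in the sketch below. It is a plain-Python illustration, assuming binary labels where `1` means "aligned"; the function names and the toy labels and scores are made up for this example and are not results from the paper.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the gold labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_binary(y_true, y_pred, positive=1):
    """Harmonic mean of precision and recall for the positive class."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def roc_auc(y_true, scores):
    """Probability that a random positive is scored above a random negative.

    Ties count as half a win. Quadratic in the number of examples,
    which is fine for a sketch but not for large evaluation sets.
    """
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Toy example: gold labels and model alignment scores in [0, 1]
labels = [1, 0, 1, 1, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.1]
preds = [int(s >= 0.5) for s in scores]  # threshold at 0.5 for accuracy and F1

acc = accuracy(labels, preds)
f1 = f1_binary(labels, preds)
auc = roc_auc(labels, scores)
```

Accuracy and F1 depend on the chosen threshold, whereas ROC AUC is computed from the raw scores, which is why it suits tasks with probabilistic outputs.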




**Performance Across Tasks and Efficiency Analysis**


**Textual Entailment (NLI Tasks):**
Compared with FLAN-T5, ALIGN achieves an F1 score of approximately 91.4%, against about 90.5% for the FLAN-T5 models. Despite having fewer parameters, ALIGN reaches roughly 90.3% accuracy on MultiNLI, comparable to FLAN-T5’s 90.5%.

**Fact Verification:**
In tasks like VitaminC and FEVER, ALIGN exhibits robust performance with an accuracy of 89.8% on VitaminC, surpassing RoBERTa, which is fine-tuned specifically for this dataset and scores around 88.7%.

**Semantic Textual Similarity:**
For the STS-B benchmark, ALIGN records an average Pearson correlation coefficient of 0.89, outperforming baseline models which average at 0.87.
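
The Pearson correlation used for STS-B measures how linearly the predicted similarity scores track the human ratings. A small sketch, with a hypothetical `pearson_r` helper and made-up scores (not values from the paper):

```python
import math

def pearson_r(x, y):
    """Pearson correlation: covariance divided by the product of standard deviations."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("x and y must have the same length (at least 2)")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant inputs")
    return cov / math.sqrt(var_x * var_y)

# Made-up example: human ratings on a 0-5 scale vs. model scores in [0, 1]
gold = [0.0, 2.5, 5.0, 3.0]
pred = [0.1, 0.5, 0.9, 0.7]
r = pearson_r(gold, pred)
```

Because the coefficient is invariant to linear rescaling, model scores in [0, 1] can be compared directly with ratings on a 0 to 5 scale.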

**Resource Utilization:**
ALIGN requires significantly less computational power, utilizing only 620 GPU hours for training, compared to the 1,200 GPU hours typically consumed by models like GPT-3.5.

**Inference Speed:**
ALIGN processes inputs at a rate of 2,000 tokens per second, approximately 20% faster than FLAN-T5 models.

**Overall Performance on Diverse Datasets:**
ALIGN achieves better or comparable results across over 20 diverse datasets, with an average accuracy improvement of +1.5% over FLAN-T5, notably excelling in paraphrase detection with a 92.6% accuracy on PAWS, against 91.9% for larger counterparts.

**Performance in Specialized Tasks:**
In coreference resolution tasks, ALIGN’s performance is noted at an accuracy of 88.6%, significantly higher than FLAN-T5’s 85.7%. For question answering datasets like RACE, ALIGN scores 86.8% on the middle school portion, compared to 85.1% by FLAN-T5.

**Generalizability and Robustness:**
ALIGN shows strong generalizability across unseen datasets, a testament to its robust training regime and the effectiveness of its split-then-aggregate method for handling inputs of varying length and complexity. It consistently matches or outperforms FLAN-T5 models on most tasks and surpasses task-specific fine-tuned models in several cases, which makes it a viable option for applications with limited computational budgets. The model’s performance highlights its capability to understand and evaluate complex textual relationships accurately.

These results underscore the ALIGN model's efficacy in handling complex NLP tasks with fewer resources while maintaining competitive accuracy and speed. The model's ability to perform across a broad range of tasks with less computational overhead highlights its potential as a scalable and efficient NLP tool. Furthermore, the robust performance across unseen datasets in zero-shot settings demonstrates its generalizability and potential for real-world applications.

# Conclusion and Future Direction



### Limitations and Future Work

#### Limitations of the ALIGN Model

1. **Handling of Long Inputs**:
   - **Split-then-Aggregate Technique**: While the ALIGN model employs a split-then-aggregate approach to handle inputs longer than its maximum token limit, this technique can overlook context that spans segment boundaries. For inputs where long-range context is critical to understanding (e.g., long documents or complex narratives), this method might lead to suboptimal performance.
   - **Fragmentation Issues**: The split-then-aggregate method could fragment essential information, particularly if the splitting algorithm does not accurately capture the semantic boundaries within the text.

2. **Model Generalization**:
   - **Domain-Specific Performance**: Although ALIGN performs robustly across a wide range of tasks, its efficiency in domain-specific scenarios (e.g., legal or medical texts) has not been explicitly tested. The nuances of specialized vocabularies and contexts in these fields might challenge the generalization capabilities of ALIGN.
   - **Limited Adversarial Robustness**: The model's performance under adversarial conditions or with deliberately misleading inputs has not been thoroughly explored, which could be crucial for applications in security-sensitive environments.

3. **Dataset and Training Limitations**:
   - **Bias and Representation**: ALIGN is trained on a diverse but fixed set of datasets. There is a risk that biases present in these training datasets could be learned by the model, potentially affecting its performance and fairness when deployed in real-world scenarios.
   - **Dependence on Pre-trained Models**: The efficiency of ALIGN is partly due to its reliance on the robustly optimized BERT architecture (RoBERTa). This dependence implies that any limitations inherent to RoBERTa, such as its fixed maximum input length or sensitivity to input perturbations, could also affect ALIGN.

#### Future Work

1. **Enhanced Input Handling**:
   - **Advanced Splitting Techniques**: Future versions of ALIGN could incorporate more sophisticated mechanisms for handling long inputs, such as attention mechanisms that can dynamically determine the most relevant segments of text to process, thereby preserving contextual integrity.
   - **Contextual Chunking**: Implementing a contextual chunking method that respects semantic and syntactic boundaries could improve the model's understanding of longer documents.

2. **Domain Adaptation**:
   - **Specialized Fine-Tuning**: To enhance the model's applicability to specialized domains, future research could focus on domain-adaptive pre-training or fine-tuning approaches that tailor the model to specific industries or fields of study.
   - **Robustness Testing**: Systematic adversarial testing and robustness checks can be incorporated into the model's evaluation phase to ensure stability and reliability under diverse and challenging conditions.

3. **Bias Mitigation and Ethical Considerations**:
   - **Bias Detection and Correction**: Incorporating techniques for detecting and mitigating biases in training data can enhance the fairness and ethical use of ALIGN in diverse applications.
   - **Ethical Guidelines**: Establishing clear ethical guidelines for the deployment of ALIGN, particularly in sensitive applications, can ensure that the model's use aligns with societal norms and values.

By addressing these limitations and exploring the suggested avenues for future work, the ALIGN model can be refined to deliver more robust, fair, and contextually aware performance across a broader range of applications and domains.

# References:

[1] Zha, Y., Yang, Y., Li, R., & Hu, Z. (2023). Text alignment is an efficient unified model for massive NLP tasks. In 37th Conference on Neural Information Processing Systems (NeurIPS 2023). URL https://proceedings.neurips.cc/paper_files/paper/2023/file/f5708199bdc013c5b56406db305b991e-Paper-Conference.pdf

[2] Armen Aghajanyan, Anchit Gupta, Akshat Shrivastava, Xilun Chen, Luke Zettlemoyer, and Sonal Gupta. Muppet: Massive multi-task representations with pre-finetuning. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 5799–5811, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.emnlp-main.468. URL https://aclanthology.org/2021.emnlp-main.468.

[3] Xiaodong Liu, Pengcheng He, Weizhu Chen, and Jianfeng Gao. Multi-task deep neural networks for natural language understanding. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4487–4496, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1441. URL https://aclanthology.org/P19-1441.