In this tutorial we introduce our unified model, which handles three tasks: sequence classification, span detection, and pair classification. The full model structure is depicted in the [figure](../assets/unicausal_model.png) below.

### Model Structure

The model adopts a BERT embedding layer as its backbone to transform input texts into sequence embeddings. We can use a [predefined HuggingFace model configuration](https://huggingface.co/docs/transformers/model_doc/bert) to initialize the BERT layer. The model then feeds the token embeddings into a series of downstream span token classifiers; the exact number depends on how many spans there are in a sequence. Although the task is called span detection, it is really a classification task: each input token is assigned one label from a fixed label set.
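Because span detection is framed as token classification, spans have to be recovered from the per-token labels afterwards. The sketch below assumes a standard BIO scheme over the example label set used later in this tutorial (`B-C`, `I-C` for cause, `B-E`, `I-E` for effect, `O` for everything else). The helper `decode_spans` is hypothetical and written for illustration only; it is not part of our repository.

```python
from typing import List, Tuple


def decode_spans(tags: List[str]) -> List[Tuple[str, int, int]]:
    """Group BIO tags into (span_type, start, end) triples, with end exclusive.

    Hypothetical helper: assumes 'B-X' opens a span of type X, 'I-X' continues
    it, and 'O' marks tokens outside any span.
    """
    spans = []
    current_type, start = None, None
    for i, tag in enumerate(tags):
        starts_new = tag.startswith("B-")
        breaks_span = tag == "O" or (tag.startswith("I-") and tag[2:] != current_type)
        if starts_new or breaks_span:
            # close the currently open span, if any
            if current_type is not None:
                spans.append((current_type, start, i))
                current_type, start = None, None
            if starts_new:
                current_type, start = tag[2:], i
            # a stray I- tag without a matching B- tag is simply ignored
        # an I- tag that matches current_type extends the open span
    if current_type is not None:
        spans.append((current_type, start, len(tags)))
    return spans


tokens = ["heavy", "rain", "caused", "flooding"]
tags = ["B-C", "I-C", "O", "B-E"]
print(decode_spans(tags))
```

Here "heavy rain" is decoded as a cause span covering tokens 0 to 1 and "flooding" as an effect span at token 3.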

The span classifiers are chained: each classifier feeds its output logits, along with the tokenized sequence, into the next one. This design suits the task because the cause and effect spans in a pair usually depend on each other.
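A minimal sketch of the chaining idea in plain Python, so the shapes are easy to follow: the second classifier receives each token's embedding concatenated with the first classifier's logits for that token. The helper names (`linear`, `random_layer`, `chained_span_classifiers`) and the use of concatenation are assumptions for illustration; the actual layers are PyTorch modules and may combine their inputs differently.

```python
import random
from typing import List, Tuple

Vector = List[float]
Layer = Tuple[List[Vector], Vector]  # (weight: out_dim x in_dim, bias: out_dim)


def linear(x: Vector, layer: Layer) -> Vector:
    """Dense layer computing weight @ x + bias."""
    weight, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def random_layer(in_dim: int, out_dim: int, rng: random.Random) -> Layer:
    """Small random weights and zero bias, standing in for a trained layer."""
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weight, [0.0] * out_dim


def chained_span_classifiers(token_embeddings: List[Vector], layers: List[Layer]) -> List[List[Vector]]:
    """Apply span classifiers in order; classifier k sees each token embedding
    concatenated with classifier k-1's logits for the same token."""
    all_logits = []
    prev_logits = None
    for layer in layers:
        step_logits = []
        for i, emb in enumerate(token_embeddings):
            # the first classifier sees only the embedding; later ones also see prior logits
            features = emb if prev_logits is None else emb + prev_logits[i]
            step_logits.append(linear(features, layer))
        all_logits.append(step_logits)
        prev_logits = step_logits
    return all_logits


rng = random.Random(0)
hidden, num_labels, seq_len = 4, 5, 3
embeddings = [[rng.gauss(0, 1) for _ in range(hidden)] for _ in range(seq_len)]
layers = [
    random_layer(hidden, num_labels, rng),               # first span classifier (e.g. cause)
    random_layer(hidden + num_labels, num_labels, rng),  # second one also sees the first one's logits
]
logits_per_classifier = chained_span_classifiers(embeddings, layers)
print(len(logits_per_classifier), len(logits_per_classifier[0]), len(logits_per_classifier[0][0]))
```

The input width of each later classifier grows by the number of token labels, which is the price of passing the earlier predictions forward.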

Sequence and pair classification require a more global view of the input sequence, so the token representations first pass through a pooling layer before being fed into separate sequence and pair classifiers. The cross-entropy losses from these components are combined by a weighted sum, where a tunable scaling factor $\alpha$ weighs the relative importance of the token and sequence classification losses.
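The pooling and loss aggregation can be sketched in a few lines of plain Python. This sketch assumes mean pooling over tokens and the weighting $\mathcal{L} = \alpha\,\mathcal{L}_{tok} + \mathcal{L}_{seq}$; both the pooling choice and which term $\alpha$ multiplies are our assumptions, so check the repository's loss code before relying on them. The default `alpha=3.0` mirrors the value passed when loading the model.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Softmax cross-entropy for one example, computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def mean_pool(token_vectors: List[List[float]]) -> List[float]:
    """Average token vectors into one sequence-level vector (assumed pooling)."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]


def unified_loss(tok_logits, tok_labels, seq_logits, seq_label, alpha=3.0):
    """Weighted sum of the mean token loss and the sequence loss.

    Assumption: alpha scales the token loss; the repository may place it elsewhere.
    """
    tok_loss = sum(cross_entropy(l, y) for l, y in zip(tok_logits, tok_labels)) / len(tok_labels)
    seq_loss = cross_entropy(seq_logits, seq_label)
    return alpha * tok_loss + seq_loss


# toy example: 3 tokens with 5 token labels each, 2 sequence labels
tok_logits = [[0.0] * 5 for _ in range(3)]
pooled = mean_pool(tok_logits)  # this vector would be fed to the sequence classifier
seq_logits = [0.0, 0.0]
print(unified_loss(tok_logits, [0, 4, 1], seq_logits, 1))
```

With uniform logits every token loss equals $\log 5$ and the sequence loss equals $\log 2$, so the printed value is $3\log 5 + \log 2$. Raising $\alpha$ pushes training toward better span boundaries at the expense of the sequence-level objective.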

These implementation details can be found in our [paper]() or in the Model Signature section below.

### Load the Model

Now we will show how to load the unified model.

In [None]:
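# import model
# NOTE: relative imports such as `from ..models...` fail inside a notebook;
# this assumes the notebook sits one folder below the repository root.
import sys
sys.path.append("..")
from models.classifiers.modeling_bert import BertForUnifiedCRBase
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
)  # other imports are omitted for brevity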

# Instantiate config
# either provide a pretrained config name or model path, or a model type to initialize a config.
# for tutorial of transformers config, see https://huggingface.co/docs/transformers/main_classes/configuration
config_name = model_name_or_path = None
model_type = 'roberta'
# Example token labels to ints mapping: {'B-C': 0, 'B-E': 1, 'I-C': 2, 'I-E': 3, 'O': 4}
# for the meaning of each label, see our paper: 
num_token_labels = 5

if config_name:
    config = AutoConfig.from_pretrained(config_name, num_labels=num_token_labels)
elif model_name_or_path:
    config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_token_labels)
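else:
    # pass num_labels here too; otherwise the fresh config defaults to 2 labels
    config = CONFIG_MAPPING[model_type](num_labels=num_token_labels)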

# load model from config
if model_name_or_path:
    model = BertForUnifiedCRBase.from_pretrained(
        model_name_or_path,
        from_tf=bool(".ckpt" in model_name_or_path),
        config=config,
        num_seq_labels=2,
        loss_function='simple',
        alpha=3 # parameter mentioned before, for weighing token and sequence loss significance
    )
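else:
    print("Training new model from scratch")
    # `from_config` is provided by the Auto* classes, not by PreTrainedModel subclasses,
    # so build the model through its constructor. This assumes __init__ accepts the
    # same extra keyword arguments that from_pretrained forwards above.
    model = BertForUnifiedCRBase(
        config,
        num_seq_labels=2,
        loss_function='simple',
        alpha=3,
    )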
# Now the model is ready for use. If you would like to see the full initialization process, check out 
# our training script at `../run.py`.

### Model Signature

It is also instructive to see how the model works internally as a subclass of HuggingFace's [`BertPreTrainedModel`](https://huggingface.co/docs/transformers/main_classes/model). Below are the relevant `BertForUnifiedCRBase` signatures (simplified), together with their inputs and outputs.

```python
class Pooler(nn.Module):
    """Pooler model to condense input sequence embeddings"""
    def __init__(self, seq_length, condensed_size):
        ...

class BertForUnifiedCRBase(BertPreTrainedModel):

    def __init__(self, config, ...):
        super().__init__(config)
        ...
        self.bert = BertModel(config, ...)
        self.dropout = nn.Dropout(...)
        ...
        
        self.reader = nn.LSTM(...)
        self.linear = nn.Linear(...)
        self.linear2 = nn.Linear(...) 
        self.tokclf = nn.Linear(...) # token classification layer
        self.pool = Pooler(...)
        self.seqclf = nn.Linear(...) # sequence classification layer

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        ... # other inputs to bert model
    ):

        # compute bert embeddings
        outputs = self.bert(
            input_ids=...,
            attention_mask=...,
            ...
        )
        
        # outputs[0] is the last hidden state: one embedding per input token
        sequence_output = outputs[0]
        sequence_output = some_function_to_add_other_features(sequence_output, ...)

        # lstm processing; nn.LSTM returns (output, (h_n, c_n)), so keep only the output
        sequence_output, _ = self.reader(some_processing_function(sequence_output))
        
        # condense
        sequence_output = self.linear(some_other_processing_function(sequence_output))
        sequence_output = self.linear2(some_other_processing_function(sequence_output))

        # perform token classification
        sequence_output = self.dropout(sequence_output)  # active in train mode, disabled by model.eval()
        tok_logits = self.tokclf(sequence_output)
        tok_loss = some_function_to_compute_loss(tok_logits, ...)

        # sequence/pair classification
        pooled_output = self.pool(tok_logits)  # pool the token logits into one sequence-level vector
        logits = self.seqclf(pooled_output)
        loss = some_other_function_to_compute_loss(logits, ...)
        
        # return predictions and losses
        return (loss, logits, tok_loss, tok_logits,)

    @classmethod
    def from_pretrained(cls,...):
        """Used when pretrained weights available. Will elaborate in the next notebook."""
        ...
```
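`forward` returns the tuple `(loss, logits, tok_loss, tok_logits)`. At inference time, predictions come from picking the highest-scoring label for each token and for the sequence. The sketch below does this in plain Python on toy logits for a single (unbatched) example, using the example token-label mapping from the loading cell (`{'B-C': 0, 'B-E': 1, 'I-C': 2, 'I-E': 3, 'O': 4}`). The `argmax` and `predict` helpers and the toy numbers are hypothetical; with real PyTorch tensors you would call `.argmax(dim=-1)` instead.

```python
from typing import Dict, List

label2id: Dict[str, int] = {"B-C": 0, "B-E": 1, "I-C": 2, "I-E": 3, "O": 4}
id2label = {i: label for label, i in label2id.items()}


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def predict(outputs):
    """Turn a (loss, logits, tok_loss, tok_logits) tuple into label predictions.

    Assumes a single example: `logits` holds one score per sequence label and
    `tok_logits` holds one row of token-label scores per token.
    """
    _loss, logits, _tok_loss, tok_logits = outputs
    seq_pred = argmax(logits)
    tok_preds = [id2label[argmax(row)] for row in tok_logits]
    return seq_pred, tok_preds


# toy outputs for a 3-token sequence; the loss values are placeholders
toy_outputs = (
    0.0,
    [0.2, 1.3],                      # sequence logits: label 1 scores highest
    0.0,
    [[2.0, 0.1, 0.3, 0.0, 0.5],      # highest score at index 0 -> B-C
     [0.1, 0.2, 0.1, 0.0, 1.5],      # highest score at index 4 -> O
     [0.0, 1.8, 0.2, 0.1, 0.4]],     # highest score at index 1 -> B-E
)
print(predict(toy_outputs))
```

The token predictions can then be grouped into cause and effect spans with any BIO decoder.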

### Compare with HF Classification

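You may wonder how our unified model compares to HuggingFace's off-the-shelf classifiers for the three tasks. Below is the corresponding instantiation script for a HuggingFace [pre-trained model](https://huggingface.co/docs/transformers/model_doc/auto). Note that `AutoModelForTokenClassification` covers span detection only; sequence and pair classification would each need a separate model, such as `AutoModelForSequenceClassification`.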

In [None]:
# Config instantiation process is the same as above
from transformers import AutoModelForTokenClassification

if model_name_or_path:
    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        from_tf=bool(".ckpt" in model_name_or_path),
        config=config,
    )
else:
    print("Training new model from scratch")  # `logger` is not defined in this notebook
    model = AutoModelForTokenClassification.from_config(config)

Now that we have covered how to load our unified model, it is time to train it. In the next tutorial, we will show how to make predictions using our pretrained model.