In [None]:
!pip install transformers
!pip install datasets

In [None]:
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import AutoModelForSequenceClassification
from datasets import Dataset
from scipy.special import softmax
import torch
import torch.nn as nn
from datasets import Dataset
import pandas as pd
from os.path import join
import matplotlib.pyplot as plt
from datasets.load import load_metric
import numpy as np

In [None]:
from transformers import DistilBertPreTrainedModel, PretrainedConfig, DistilBertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union, Dict, List, Any

For our improved model, we are going to combine a sentiment feature with the textual features that our baseline BERT model has learned from Reddit post text. We achieve this by concatenating the sentiment feature to the output of the pre_classifier layer (i.e., the layer just before the classifier layer), so that the classifier layer receives one extra input feature. We will also modify the weights of our classifier layer such that our new sentiment feature is assigned a weight of 1 and thus has a larger impact on the final classification result.

In [None]:
class CustomDistilBertForSequenceClassification(DistilBertPreTrainedModel):
    """DistilBERT sequence classifier whose head also receives a scalar sentiment feature.

    The first-token hidden state is passed through ``pre_classifier`` and a ReLU,
    the per-example ``sentiment`` value is appended as one extra column, and the
    resulting ``config.dim + 1`` features feed the final ``classifier`` layer.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # One extra input column for the concatenated sentiment feature.
        self.classifier = nn.Linear(config.dim + 1, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sentiment: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        severity: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        r"""
        sentiment (`torch.Tensor` with one value per example):
            Extra feature appended to the pooled text representation before classification. Required.
        severity (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Used as the loss targets only when `severity` is not given.

        Raises:
            ValueError: if `sentiment` is `None`.
        """
        if sentiment is None:
            # Fail early with a readable message instead of an AttributeError on `None.view`.
            raise ValueError(
                "`sentiment` is required: it is concatenated to the pooled representation before the classifier."
            )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `severity` is the primary target; fall back to the conventional `labels` argument.
        targets = severity if severity is not None else labels

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.functional.relu(pooled_output)  # (bs, dim)

        # Append the sentiment value as one additional feature column. The explicit cast keeps the
        # concatenation well-defined when `sentiment` arrives as an integer tensor or on another device.
        sentiment_feature = sentiment.reshape(-1, 1).to(device=pooled_output.device, dtype=pooled_output.dtype)
        pooled_output = torch.cat((pooled_output, sentiment_feature), dim=1)  # (bs, dim + 1)
        pooled_output = self.dropout(pooled_output)  # (bs, dim + 1)

        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if targets is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (targets.dtype == torch.long or targets.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), targets.squeeze())
                else:
                    loss = loss_fct(logits, targets)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), targets.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, targets)

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

Below, I have provided an example of how the model can be trained. It is important to remember that the severity label is only a weak signal at the level of an individual Reddit post, since labels are really assigned to users rather than to posts. Therefore, we employ the same strategy as we did for the baseline: we fine-tune DistilBERT on individual Reddit posts using the severity label, then aggregate the Reddit posts and their classifications by user before using a one-level Decision Tree classifier to assign a label to each user.

In [None]:
# Load the tokenizer for the "distilbert-base-cased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

In [None]:
# Encode three toy sentences (two for training, one for testing), padded to the tokenizer's max length.
sample_train_datapt = tokenizer.encode_plus(
    "I want to take you away.", padding='max_length', truncation=True
)
sample_train_datapt2 = tokenizer.encode_plus(
    "I want to eat.", padding='max_length', truncation=True
)
sample_test_datapt = tokenizer.encode_plus(
    "I want to take you away.", padding='max_length', truncation=True
)

# Two-row training frame: model inputs plus the severity target and the sentiment feature.
dummytrain_df = pd.DataFrame(
    {
        'input_ids': [sample_train_datapt['input_ids'], sample_train_datapt2['input_ids']],
        'attention_mask': [sample_train_datapt['attention_mask'], sample_train_datapt2['attention_mask']],
        'severity': [0, 1],
        'sentiment': [1, 0],
    }
)

# One-row evaluation frame with the same columns.
dummytest_df = pd.DataFrame(
    {
        'input_ids': [sample_test_datapt['input_ids']],
        'attention_mask': [sample_test_datapt['attention_mask']],
        'severity': [0],
        'sentiment': [1],
    }
)

In [None]:
# Wrap both toy DataFrames as `datasets.Dataset` objects in one pass.
dummytrain_dataset, dummytest_dataset = (
    Dataset.from_pandas(frame) for frame in (dummytrain_df, dummytest_df)
)

In [None]:
# One training epoch with an evaluation pass at the end of each epoch.
args = TrainingArguments(
    output_dir="distilbert-base-cased-checkpoint",
    do_train=True,
    do_eval=True,
    num_train_epochs=1,
    evaluation_strategy='epoch',
    # The target column is called `severity` rather than the default `labels`.
    label_names=['severity'],
)

In [None]:
# Number of severity classes predicted by the classification head.
num_categories = 2
model = CustomDistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-cased",
    num_labels=num_categories,
)

In [None]:
# Stock Trainer wired to the custom model and the toy train/eval datasets.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dummytrain_dataset,
    eval_dataset=dummytest_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Run the training loop for the trainer configured in the previous cell.
trainer.train()

We also want to experiment with using multiple outputs. For example, we may want to use the sentiment label for each Reddit post as a weak signal to fine-tune our modified DistilBERT model on the task of high-severity vs. low-severity classification. For this, we introduce a second classifier layer for our sentiment label. Then, we modify our model training process by subclassing the Trainer API from the HuggingFace library.

In [None]:
class CustomDistilBertMultipleOutputsForSequenceClassification(DistilBertPreTrainedModel):
    """DistilBERT with two classification heads sharing one pooled representation.

    ``classifier`` predicts severity (``config.num_labels`` classes) and
    ``sentiment_classifier`` predicts sentiment (``sentiment_labels`` classes).
    ``forward`` always returns the 3-tuple
    ``(total_loss, severity_output, sentiment_output)``.
    """

    def __init__(self, config: PretrainedConfig, sentiment_labels: int):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_sentiments = sentiment_labels
        self.config = config

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)

        # Two independent heads on top of the same pooled features: one per task.
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.sentiment_classifier = nn.Linear(config.dim, sentiment_labels)

        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        sentiment_label: Optional[torch.Tensor] = None,
        severity: Optional[torch.LongTensor] = None,
        sentiment: Optional[torch.LongTensor] = None
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        r"""
        severity (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices in `[0, ..., config.num_labels - 1]` for the severity head (Cross-Entropy loss).
        sentiment (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices for the sentiment head. They are one-hot encoded over `sentiment_labels` classes and
            scored with a binary cross-entropy-with-logits loss.
        labels, sentiment_label, return_dict:
            Accepted for signature compatibility; not used by this model.

        Returns:
            A tuple `(total_loss, severity_output, sentiment_output)`. `total_loss` is the sum of the head losses
            that could be computed, or `None` when neither `severity` nor `sentiment` was provided. Each head
            output is a `SequenceClassifierOutput` carrying that head's own loss and logits.
        """
        # The head outputs below read `.hidden_states` / `.attentions`, so the backbone must
        # always return a ModelOutput regardless of the caller's `return_dict`.
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.functional.relu(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)

        logits_severity = self.classifier(pooled_output)  # (bs, num_labels)
        logits_sentiment = self.sentiment_classifier(pooled_output)  # (bs, num_sentiments)

        loss_severity = None
        loss_sentiment = None

        if severity is not None:
            loss_fct = CrossEntropyLoss()
            loss_severity = loss_fct(logits_severity.view(-1, self.num_labels), severity.view(-1))

        if sentiment is not None:
            loss_fct = BCEWithLogitsLoss()
            # Vectorised one-hot encoding built on the logits' device and dtype. A row stays all-zero
            # when its index falls outside `[0, num_sentiments)`.
            class_ids = torch.arange(self.num_sentiments, device=logits_sentiment.device)
            sentiment_targets = (sentiment.reshape(-1, 1).to(logits_sentiment.device) == class_ids).to(
                logits_sentiment.dtype
            )
            loss_sentiment = loss_fct(logits_sentiment.view(-1, self.num_sentiments), sentiment_targets)

        # Sum only the losses that exist so that inference without labels does not raise.
        total_loss = None
        for head_loss in (loss_severity, loss_sentiment):
            if head_loss is not None:
                total_loss = head_loss if total_loss is None else total_loss + head_loss

        return (total_loss, SequenceClassifierOutput(
            loss=loss_severity,
            logits=logits_severity,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        ), SequenceClassifierOutput(
            loss=loss_sentiment,
            logits=logits_sentiment,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        ))

Now we create our custom trainer by subclassing the Trainer API. Because we need custom behavior for training and evaluating our model with multiple outputs, we override the compute_loss, evaluation_loop, and prediction_step methods.

In [None]:
from transformers.trainer_pt_utils import IterableDatasetShard, nested_detach, nested_concat, nested_numpify, nested_truncate, find_batch_size
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, denumpify_detensorize, has_length
from torch.utils.data import DataLoader
from transformers.deepspeed import deepspeed_init

In [None]:
def smp_forward_only(model, inputs):
  """Call ``model`` with ``inputs`` unpacked as keyword arguments and return its result."""
  outputs = model(**inputs)
  return outputs

def smp_nested_concat(tensor):
  """Recursively apply ``.concat().detach().cpu()`` to every leaf of a nested structure.

  Lists, tuples and dicts are rebuilt with their own type; any other object is
  treated as a leaf and must provide the ``concat``/``detach``/``cpu`` call chain.
  """
  if isinstance(tensor, dict):
    converted = {}
    for key, value in tensor.items():
      converted[key] = smp_nested_concat(value)
    return type(tensor)(converted)
  if isinstance(tensor, (list, tuple)):
    return type(tensor)(smp_nested_concat(item) for item in tensor)
  # A StepOutput check is not possible here: StepOutput lives in `smp.step`, which is also
  # the decorator's name, so the leaf is simply duck-typed.
  return tensor.concat().detach().cpu()

In [None]:
class CustomTrainer(Trainer):
    """Trainer for models whose forward pass returns ``(total_loss, severity_output, sentiment_output)``.

    Evaluation keeps the two heads separate: ``EvalLoopOutput.predictions`` is the
    pair ``(severity_logits, sentiment_logits)`` instead of one array in which the
    heads are mixed along the batch axis.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and return its combined loss.

        Args:
            model: The model to call with ``**inputs``.
            inputs: Keyword arguments for the model's forward pass.
            return_outputs: When true, also return the two per-head outputs.
            num_items_in_batch: Accepted and ignored, so the override also works with
                Trainer versions that pass this argument.

        Returns:
            ``outputs[0]`` (the combined loss), or ``(loss, severity_output, sentiment_output)``
            when ``return_outputs`` is true.
        """
        outputs = model(**inputs)
        if return_outputs:
            return (outputs[0], outputs[1], outputs[2])
        return outputs[0]

    @staticmethod
    def _extract_head_logits(head_output, ignore_keys, has_loss):
        """Return one head's prediction tensors as a tuple, leaving out its loss."""
        if isinstance(head_output, dict):
            return tuple(v for k, v in head_output.items() if k not in ignore_keys + ["loss"])
        # Tuple-style output: the loss, when present, sits in position 0.
        return head_output[1:] if has_loss else head_output

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
        Works both with or without labels.

        Predictions are accumulated as the tuple `(severity_logits, sentiment_logits)`, so `compute_metrics` and
        `preprocess_logits_for_metrics` receive both heads in that order.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits_severity, logits_sentiment, labels = self.prediction_step(
                model, inputs, prediction_loss_only, ignore_keys=ignore_keys
            )
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )

            if logits_severity is not None and logits_sentiment is not None:
                # Keep the heads as a pair. Concatenating them into one tensor would interleave
                # severity and sentiment rows along the batch axis and corrupt the predictions.
                logits = (logits_severity, logits_sentiment)
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)

            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Any, Any, Any]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Args:
            model (`nn.Module`):
                The model to evaluate. Its forward pass must return
                `(total_loss, severity_output, sentiment_output)`.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. The dictionary is unpacked before being fed to the model.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in each head output (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            A 4-tuple `(loss, logits_severity, logits_sentiment, labels)`. All three trailing entries are `None`
            when `prediction_loss_only` is true; `labels` is `None` when the inputs carry no labels.

        Raises:
            NotImplementedError: under SageMaker model parallelism, which this two-head step does not support.
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        if is_sagemaker_mp_enabled():
            # The generic SageMaker path yields a single logits structure and cannot fill the two
            # per-head results this method returns; fail clearly instead of crashing further down.
            raise NotImplementedError(
                "CustomTrainer.prediction_step does not support SageMaker model parallelism."
            )

        with torch.no_grad():
            if has_labels or loss_without_labels:
                with self.compute_loss_context_manager():
                    loss, output_severity, output_sentiment = self.compute_loss(model, inputs, return_outputs=True)
                if loss is not None:
                    loss = loss.mean().detach()
                head_has_loss = True
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                _, output_severity, output_sentiment = outputs
                head_has_loss = False
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

            logits_severity = self._extract_head_logits(output_severity, ignore_keys, head_has_loss)
            logits_sentiment = self._extract_head_logits(output_sentiment, ignore_keys, head_has_loss)

        if prediction_loss_only:
            return (loss, None, None, None)

        logits_severity = nested_detach(logits_severity)
        if len(logits_severity) == 1:
            logits_severity = logits_severity[0]

        logits_sentiment = nested_detach(logits_sentiment)
        if len(logits_sentiment) == 1:
            logits_sentiment = logits_sentiment[0]

        return (loss, logits_severity, logits_sentiment, labels)

Below, I have provided an example of how the model will be used. As with the previous custom DistilBERT model, we employ the same strategy as we did for the baseline: we fine-tune DistilBERT on individual Reddit posts using both the severity AND sentiment labels, then aggregate the Reddit posts and their classifications by user before using a one-level Decision Tree classifier to assign a label to each user.

In [None]:
# Load the tokenizer for the "distilbert-base-cased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

In [None]:
# Encode the toy sentences (two for training, one for testing), padded to the tokenizer's max length.
sample_train_datapt = tokenizer.encode_plus(
    "I want to take you away.", padding='max_length', truncation=True
)
sample_train_datapt2 = tokenizer.encode_plus(
    "I want to eat.", padding='max_length', truncation=True
)
sample_test_datapt = tokenizer.encode_plus(
    "I want to take you away.", padding='max_length', truncation=True
)

# Training frame: model inputs plus one severity label and one sentiment label per row.
dummytrain_df = pd.DataFrame(
    {
        'input_ids': [sample_train_datapt['input_ids'], sample_train_datapt2['input_ids']],
        'attention_mask': [sample_train_datapt['attention_mask'], sample_train_datapt2['attention_mask']],
        'severity': [0, 1],
        'sentiment': [1, 0],
    }
)

# Single-row evaluation frame with the same columns.
dummytest_df = pd.DataFrame(
    {
        'input_ids': [sample_test_datapt['input_ids']],
        'attention_mask': [sample_test_datapt['attention_mask']],
        'severity': [0],
        'sentiment': [1],
    }
)

In [None]:
# Convert both toy DataFrames to `datasets.Dataset` objects in one pass.
dummytrain_dataset, dummytest_dataset = (
    Dataset.from_pandas(frame) for frame in (dummytrain_df, dummytest_df)
)

In [None]:
# One training epoch with an evaluation pass at the end of each epoch.
args = TrainingArguments(
    output_dir="distilbert-base-cased-checkpoint",
    do_train=True,
    do_eval=True,
    num_train_epochs=1,
    evaluation_strategy='epoch',
    # Both columns are treated as labels: one per output head.
    label_names=['severity', 'sentiment'],
)

In [None]:
# CustomTrainer expects a model whose forward pass returns
# (total_loss, severity_output, sentiment_output), so build the two-head model here
# instead of reusing the single-output `model` from the earlier example.
num_sentiments = 2  # the toy sentiment column takes the values 0 and 1
model = CustomDistilBertMultipleOutputsForSequenceClassification.from_pretrained(
    "distilbert-base-cased",
    num_sentiments,  # extra positional args are forwarded to __init__ (here: `sentiment_labels`)
    num_labels=2,  # the toy severity column takes the values 0 and 1
)

trainer = CustomTrainer(
          model=model,
          args=args,
          train_dataset=dummytrain_dataset,
          eval_dataset=dummytest_dataset,
          tokenizer=tokenizer,
        )

In [None]:
# Run the training loop for the CustomTrainer configured in the previous cell.
trainer.train()