In [None]:
! pip install torch
! pip install transformers
! pip install scikit-learn
! pip install tqdm
! pip install numpy
! pip install datasets
! pip install nltk
import nltk
nltk.download('stopwords')
! pip install scipy
! pip install transformers[torch] accelerate


Collecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m110.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m77.2 MB/s[0m eta [36m0:00:

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


Collecting accelerate
  Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.2/244.2 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.21.0


In [None]:
from datasets import load_dataset
existing_dataset = load_dataset("lex_glue", 'ecthr_b')
train=existing_dataset["train"]
print(train[1])

{'text': ['9.  The applicant is the monarch of Liechtenstein, born in 1945 and living in Vaduz (Liechtenstein).', '10.  The applicant’s late father, the former monarch of Liechtenstein, had been the owner of the painting Szene an einem römischen Kalkofen (alias Der große Kalkofen) of Pieter van Laer, which had formed part of his family’s art collection since at least 1767. Until the end of the Second World War the painting had been in one of the family’s castles on the territory of the now Czech Republic.', '11.  In 1946 the former Czechoslovakia confiscated the property of the applicant’s father which was situated in its territory, including the painting in question, under Decree no. 12 on the “confiscation and accelerated allocation of agricultural property of German and Hungarian persons and of those having committed treason and acted as enemies of the Czech and Slovak people” (dekretu prezidenta republiky č. 12/1945 Sb. o konfiskaci a urychleném rozdělení majetku Němců, Mad’arů, zr

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import numpy as np
from torch import nn
from transformers.file_utils import ModelOutput


@dataclass
class SimpleOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


def sinusoidal_init(num_embeddings: int, embedding_dim: int):
    # keep dim 0 for padding token position encoding zero vector
    position_enc = np.array([
        [pos / np.power(10000, 2 * i / embedding_dim) for i in range(embedding_dim)]
        if pos != 0 else np.zeros(embedding_dim) for pos in range(num_embeddings)])

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return torch.from_numpy(position_enc).type(torch.FloatTensor)


class HierarchicalBert(nn.Module):

    def __init__(self, encoder, max_segments=64, max_segment_length=128):
        super(HierarchicalBert, self).__init__()
        supported_models = ['bert', 'roberta', 'deberta']
        assert encoder.config.model_type in supported_models  # other model types are not supported so far
        # Pre-trained segment (token-wise) encoder, e.g., BERT
        self.encoder = encoder
        # Specs for the segment-wise encoder
        self.hidden_size = encoder.config.hidden_size
        self.max_segments = max_segments
        self.max_segment_length = max_segment_length
        # Init sinusoidal positional embeddings
        self.seg_pos_embeddings = nn.Embedding(max_segments + 1, encoder.config.hidden_size,
                                               padding_idx=0,
                                               _weight=sinusoidal_init(max_segments + 1, encoder.config.hidden_size))
        # Init segment-wise transformer-based encoder
        self.seg_encoder = nn.Transformer(d_model=encoder.config.hidden_size,
                                          nhead=encoder.config.num_attention_heads,
                                          batch_first=True, dim_feedforward=encoder.config.intermediate_size,
                                          activation=encoder.config.hidden_act,
                                          dropout=encoder.config.hidden_dropout_prob,
                                          layer_norm_eps=encoder.config.layer_norm_eps,
                                          num_encoder_layers=2, num_decoder_layers=0).encoder

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                ):
        # Hypothetical Example
        # Batch of 4 documents: (batch_size, n_segments, max_segment_length) --> (4, 64, 128)
        # BERT-BASE encoder: 768 hidden units

        # Squash samples and segments into a single axis (batch_size * n_segments, max_segment_length) --> (256, 128)
        input_ids_reshape = input_ids.contiguous().view(-1, input_ids.size(-1))
        attention_mask_reshape = attention_mask.contiguous().view(-1, attention_mask.size(-1))
        if token_type_ids is not None:
            token_type_ids_reshape = token_type_ids.contiguous().view(-1, token_type_ids.size(-1))
        else:
            token_type_ids_reshape = None

        # Encode segments with BERT --> (256, 128, 768)
        encoder_outputs = self.encoder(input_ids=input_ids_reshape,
                                       attention_mask=attention_mask_reshape,
                                       token_type_ids=token_type_ids_reshape)[0]

        # Reshape back to (batch_size, n_segments, max_segment_length, output_size) --> (4, 64, 128, 768)
        encoder_outputs = encoder_outputs.contiguous().view(input_ids.size(0), self.max_segments,
                                                            self.max_segment_length,
                                                            self.hidden_size)

        # Gather CLS outputs per segment --> (4, 64, 768)
        encoder_outputs = encoder_outputs[:, :, 0]

        # Infer real segments, i.e., mask paddings
        seg_mask = (torch.sum(input_ids, 2) != 0).to(input_ids.dtype)
        # Infer and collect segment positional embeddings
        seg_positions = torch.arange(1, self.max_segments + 1).to(input_ids.device) * seg_mask
        # Add segment positional embeddings to segment inputs
        encoder_outputs += self.seg_pos_embeddings(seg_positions)

        # Encode segments with segment-wise transformer
        seg_encoder_outputs = self.seg_encoder(encoder_outputs)

        # Collect document representation
        outputs, _ = torch.max(seg_encoder_outputs, 1)

        return SimpleOutput(last_hidden_state=outputs, hidden_states=outputs)


if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Use as a stand-alone encoder
    bert = AutoModel.from_pretrained('bert-base-uncased')
    model = HierarchicalBert(encoder=bert, max_segments=64, max_segment_length=128)

    fake_inputs = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
    for i in range(4):
        # Tokenize segment
        temp_inputs = tokenizer(['dog ' * 126] * 64)
        fake_inputs['input_ids'].append(temp_inputs['input_ids'])
        fake_inputs['attention_mask'].append(temp_inputs['attention_mask'])
        fake_inputs['token_type_ids'].append(temp_inputs['token_type_ids'])

    fake_inputs['input_ids'] = torch.as_tensor(fake_inputs['input_ids'])
    fake_inputs['attention_mask'] = torch.as_tensor(fake_inputs['attention_mask'])
    fake_inputs['token_type_ids'] = torch.as_tensor(fake_inputs['token_type_ids'])

    output = model(fake_inputs['input_ids'], fake_inputs['attention_mask'], fake_inputs['token_type_ids'])

    # 4 document representations of 768 features are expected
    assert output[0].shape == torch.Size([4, 768])

    # Use with HuggingFace AutoModelForSequenceClassification and Trainer API

    # Init Classifier
    model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=10)
    # Replace flat BERT encoder with hierarchical BERT encoder
    model.bert = HierarchicalBert(encoder=model.bert, max_segments=64, max_segment_length=128)
    output = model(fake_inputs['input_ids'], fake_inputs['attention_mask'], fake_inputs['token_type_ids'])

    # 4 document outputs with 10 (num_labels) logits are expected
    assert output.logits.shape == torch.Size([4, 10])



Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import torch
from torch import nn
from transformers import DebertaPreTrainedModel, DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput, MultipleChoiceModelOutput
from transformers.activations import ACT2FN


class ContextPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
        self.dropout = StableDropout(config.pooler_dropout)
        self.config = config

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.

        context_token = hidden_states[:, 0]
        context_token = self.dropout(context_token)
        pooled_output = self.dense(context_token)
        pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
        return pooled_output

    @property
    def output_dim(self):
        return self.config.hidden_size


class DropoutContext(object):
    def __init__(self):
        self.dropout = 0
        self.mask = None
        self.scale = 1
        self.reuse_mask = True


def get_mask(input, local_context):
    if not isinstance(local_context, DropoutContext):
        dropout = local_context
        mask = None
    else:
        dropout = local_context.dropout
        dropout *= local_context.scale
        mask = local_context.mask if local_context.reuse_mask else None

    if dropout > 0 and mask is None:
        mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()

    if isinstance(local_context, DropoutContext):
        if local_context.mask is None:
            local_context.mask = mask

    return mask, dropout


class XDropout(torch.autograd.Function):
    """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""

    @staticmethod
    def forward(ctx, input, local_ctx):
        mask, dropout = get_mask(input, local_ctx)
        ctx.scale = 1.0 / (1 - dropout)
        if dropout > 0:
            ctx.save_for_backward(mask)
            return input.masked_fill(mask, 0) * ctx.scale
        else:
            return input

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.scale > 1:
            (mask,) = ctx.saved_tensors
            return grad_output.masked_fill(mask, 0) * ctx.scale, None
        else:
            return grad_output, None


class StableDropout(nn.Module):
    """
    Optimized dropout module for stabilizing the training

    Args:
        drop_prob (float): the dropout probabilities
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob
        self.count = 0
        self.context_stack = None

    def forward(self, x):
        """
        Call the module

        Args:
            x (:obj:`torch.tensor`): The input tensor to apply dropout
        """
        if self.training and self.drop_prob > 0:
            return XDropout.apply(x, self.get_context())
        return x

    def clear_context(self):
        self.count = 0
        self.context_stack = None

    def init_context(self, reuse_mask=True, scale=1):
        if self.context_stack is None:
            self.context_stack = []
        self.count = 0
        for c in self.context_stack:
            c.reuse_mask = reuse_mask
            c.scale = scale

    def get_context(self):
        if self.context_stack is not None:
            if self.count >= len(self.context_stack):
                self.context_stack.append(DropoutContext())
            ctx = self.context_stack[self.count]
            ctx.dropout = self.drop_prob
            self.count += 1
            return ctx
        else:
            return self.drop_prob


class DebertaForSequenceClassification(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels

        self.deberta = DebertaModel(config)

        self.classifier = nn.Linear(config.hidden_size, num_labels)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        self.init_weights()

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # regression task
                loss_fn = nn.MSELoss()
                logits = logits.view(-1).to(labels.dtype)
                loss = loss_fn(logits, labels.view(-1))
            elif labels.dim() == 1 or labels.size(-1) == 1:
                label_index = (labels >= 0).nonzero()
                labels = labels.long()
                if label_index.size(0) > 0:
                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
                    labels = torch.gather(labels, 0, label_index.view(-1))
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                else:
                    loss = torch.tensor(0).to(logits)
            else:
                log_softmax = nn.LogSoftmax(-1)
                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        else:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )


class DebertaForMultipleChoice(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.deberta = DebertaModel(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = StableDropout(drop_out)
        self.classifier = nn.Linear(output_dim, 1)

        self.init_weights()

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )



In [None]:
from torch import nn
from transformers import Trainer


class MultilabelTrainer(Trainer):
    padding=True
    truncation=True
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss





In [None]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [None]:
#!/usr/bin/env python
# coding=utf-8
""" Finetuning models on the ECtHR dataset (e.g. Bert, RoBERTa, LEGAL-BERT)."""

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset
from sklearn.metrics import f1_score
#from trainer import MultilabelTrainer
from scipy.special import expit
from torch import nn
import glob
import shutil
import torch
torch.cuda.empty_cache()
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
#from models.hierbert import HierarchicalBert
#from models.deberta import DebertaForSequenceClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.9.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

logger = logging.getLogger(__name__)

from transformers import AutoModel, AutoTokenizer

# First, load the tokenizer and pre-trained BERT model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert_model = AutoModel.from_pretrained('bert-base-uncased')

# Then, create an instance of HierarchicalBert
max_segments = 64
max_segment_length = 128
HierarchicalBertObj = HierarchicalBert(encoder=bert_model, max_segments=max_segments, max_segment_length=max_segment_length)
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb=256,512,1024"

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    max_seq_length: Optional[int] = field(
        default=4096,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_segments: Optional[int] = field(
        default=64,
        metadata={
            "help": "The maximum number of segments (paragraphs) to be considered. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_seg_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum segment (paragraph) length to be considered. Segments longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    task: Optional[str] = field(
        default='ecthr_a',
        metadata={
            "help": "Define downstream task"
        },
    )
    server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
    server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    hierarchical: bool = field(
        default=True, metadata={"help": "Whether to use a hierarchical variant or not"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    do_lower_case: Optional[bool] = field(
        default=True,
        metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


def main(training_args):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    model_args = ModelArguments(
        model_name_or_path="bert-base-uncased",
        hierarchical=True,
        do_lower_case=True,
        use_fast_tokenizer=True,
    )
    data_args = DataTrainingArguments(
        max_seq_length=128,
        max_segments=64,
        max_seg_length=128,
        overwrite_cache=False,
        pad_to_max_length=True,
    )


    # Fix boolean parameter
    if model_args.do_lower_case == 'False' or not model_args.do_lower_case:
        model_args.do_lower_case = False
    else:
        model_args.do_lower_case = True

    if model_args.hierarchical == 'False' or not model_args.hierarchical:
        model_args.hierarchical = False
    else:
        model_args.hierarchical = True

    # Setup distant debugging if needed
    if data_args.server_ip and data_args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    # Downloading and loading eurlex dataset from the hub.
    if training_args.do_train:
        train_dataset = load_dataset("lex_glue", name=data_args.task, split="train", data_dir='data', cache_dir=model_args.cache_dir)

    if training_args.do_eval:
        eval_dataset = load_dataset("lex_glue", name=data_args.task, split="validation", data_dir='data', cache_dir=model_args.cache_dir)

    if training_args.do_predict:
        predict_dataset = load_dataset("lex_glue", name=data_args.task, split="test", data_dir='data', cache_dir=model_args.cache_dir)

    # Labels
    label_list = list(range(10))
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=f"{data_args.task}",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        do_lower_case=model_args.do_lower_case,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    if config.model_type == 'deberta' and model_args.hierarchical:
        model = DebertaForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if model_args.hierarchical:
        # Hack the classifier encoder to use hierarchical BERT
        if config.model_type in ['bert', 'deberta']:
            if config.model_type == 'bert':
                segment_encoder = model.bert
            else:
                segment_encoder = model.deberta
            model_encoder = HierarchicalBert(encoder=segment_encoder,
                                             max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            if config.model_type == 'bert':
                model.bert = model_encoder
            elif config.model_type == 'deberta':
                model.deberta = model_encoder
            else:
                raise NotImplementedError(f"{config.model_type} is no supported yet!")
        elif config.model_type == 'roberta':
            model_encoder = HierarchicalBert(encoder=model.roberta, max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            model.roberta = model_encoder
            # Build a new classification layer, as well
            dense = nn.Linear(config.hidden_size, config.hidden_size)
            dense.load_state_dict(model.classifier.dense.state_dict())  # load weights
            dropout = nn.Dropout(config.hidden_dropout_prob).to(model.device)
            out_proj = nn.Linear(config.hidden_size, config.num_labels).to(model.device)
            out_proj.load_state_dict(model.classifier.out_proj.state_dict())  # load weights
            model.classifier = nn.Sequential(dense, dropout, out_proj).to(model.device)
        elif config.model_type in ['longformer', 'big_bird']:
            pass
        else:
            raise NotImplementedError(f"{config.model_type} is no supported yet!")

    # Preprocessing the datasets
    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    def preprocess_function(examples):
        # Tokenize the texts
        if model_args.hierarchical:
            case_template = [[0] * data_args.max_seg_length]
            if config.model_type == 'roberta':
                batch = {'input_ids': [], 'attention_mask': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                                data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                                data_args.max_segments - len(case_encodings['attention_mask'])))
            else:
                batch = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                            data_args.max_segments - len(case_encodings['attention_mask'])))
                    batch['token_type_ids'].append(case_encodings['token_type_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['token_type_ids'])))
        elif config.model_type in ['longformer', 'big_bird']:
            cases = []
            max_position_embeddings = config.max_position_embeddings - 2 if config.model_type == 'longformer' \
                else config.max_position_embeddings
            for case in examples['text']:
                cases.append(f' {tokenizer.sep_token} '.join(
                    [' '.join(fact.split()[:data_args.max_seg_length]) for fact in case[:data_args.max_segments]]))
            batch = tokenizer(cases, padding=padding, max_length=max_position_embeddings, truncation=True)
            if config.model_type == 'longformer':
                global_attention_mask = np.zeros((len(cases), max_position_embeddings), dtype=np.int32)
                # global attention on cls token
                global_attention_mask[:, 0] = 1
                batch['global_attention_mask'] = list(global_attention_mask)
        else:
            cases = []
            for case in examples['text']:
                cases.append(f'\n'.join(case))
            batch = tokenizer(cases, padding=padding, max_length=512, truncation=True)
        batch["original_labels"] = [[label for label in label_list if label in labels] for labels in examples["labels"]]
        batch["labels"] = [[1 if label in labels else 0 for label in label_list] for labels in examples["labels"]]

        return batch

    if training_args.do_train:
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        # Log a few random samples from the training set:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    if training_args.do_eval:
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        # Fix gold labels
        y_true = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_true[:, :-1] = p.label_ids
        y_true[:, -1] = (np.sum(p.label_ids, axis=1) == 0).astype('int32')
        # Fix predictions
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = (expit(logits) > 0.5).astype('int32')
        y_pred = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_pred[:, :-1] = preds
        y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32')
        # Compute scores
        macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
        micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
        return {'macro-f1': macro_f1, 'micro-f1': micro_f1}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = MultilabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        # Get the actual predicted indexes (class labels)
        actual_predictions = predictions[0][:, 1]  # Assuming positive class is at index 1

        # Apply the threshold for binary predictions
        threshold = 0.5
        binary_predictions = (actual_predictions > threshold).astype(int)


        # Get the input text from the dataset
        input_texts = predict_dataset['text']

        # Create a DataFrame to hold the predictions, labels, and input text
        with open("predictions.txt", "w") as f:
            for index, input_text, prediction, label in zip(range(len(input_texts)), input_texts, actual_predictions, labels):
                f.write(f"Index: {index}\n")
                f.write(f"Input Text: {input_text}\n")
                f.write(f"Predictions: {prediction}\n")
                f.write(f"Labels: {label}\n")
                f.write("\n")

    # ... Your existing code ...

    # Access the rows from the dataset associated with correct predictions
    # ... Your existing code ...

# Access the rows from the dataset associated with correct predictions
    correct_predictions_dataset = predict_dataset.filter(
    lambda example, idx: all(actual_predictions[idx] == example["labels"]),
    with_indices=True
)

# Access the rows from the dataset associated with incorrect predictions
    incorrect_predictions_dataset = predict_dataset.filter(
    lambda example, idx: any(actual_predictions[idx] != example["labels"]),
    with_indices=True
)

# Access the rows from the dataset associated with original labels
    original_labels = predict_dataset['labels']

# Get the input texts and input IDs from the datasets
    correct_input_texts = correct_predictions_dataset["text"]
    correct_input_ids = correct_predictions_dataset["input_ids"]
    incorrect_input_texts = incorrect_predictions_dataset["text"]
    incorrect_input_ids = incorrect_predictions_dataset["input_ids"]

# Save the correct predictions to a file
    with open("correct_predictions.txt", "w") as f:
       for idx, input_text, input_id, labels in zip(range(len(correct_input_texts)), correct_input_texts, correct_input_ids, correct_predictions_dataset["labels"]):
         f.write(f"Index: {idx}\n")
         f.write(f"Input Text: {input_text}\n")
         f.write(f"Input IDs: {input_id}\n")
         f.write(f"Labels: {labels}\n")
         f.write("\n")

# Save the incorrect predictions to a file
    # Save the incorrect predictions to a file
    with open("incorrect_predictions.txt", "w") as f:
        for idx, input_text, input_id, labels, original_labels  in zip(
            range(len(incorrect_input_texts)),
            incorrect_input_texts,
            incorrect_input_ids,
            incorrect_predictions_dataset["labels"],
            incorrect_predictions_dataset["original_labels"],

        ):
            f.write(f"Index: {idx}\n")
            f.write(f"Input Text: {input_text}\n")
            f.write(f"Input IDs: {input_id}\n")
            f.write(f"Labels: {labels}\n")  # Write the labels to the file
            f.write(f"Original_Labels: {original_labels}\n")
            f.write("\n")


    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        output_predict_file = os.path.join(training_args.output_dir, "test_predictions.csv")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                for index, pred_list in enumerate(predictions[0]):
                    pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
                    writer.write(f"{index}\t{pred_line}\n")

    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        output_predict_file = os.path.join(training_args.output_dir, "test_predictions.csv")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                for index, pred_list in enumerate(predictions[0]):
                    pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
                    writer.write(f"{index}\t{pred_line}\n")

    # Clean up checkpoints
    checkpoints = [filepath for filepath in glob.glob(f'{training_args.output_dir}/*/') if '/checkpoint' in filepath]
    for checkpoint in checkpoints:
        shutil.rmtree(checkpoint)


if __name__ == "__main__":
    #For training

    training_args = TrainingArguments(
        do_train = True,
        output_dir=os.getcwd(),
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        save_steps=500,
        save_total_limit=2,
        fp16=False,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_first_step=False,
        load_best_model_at_end = True,
        metric_for_best_model="macro-f1",
    )
    #main(training_args)

# For Validation
    training_args = TrainingArguments(
        do_train = False,
        do_eval = True,
        output_dir=os.getcwd(),
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        save_steps=500,
        save_total_limit=2,
        fp16=False,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_first_step=False,
        load_best_model_at_end = True,
        metric_for_best_model="macro-f1",
    )
    #main(training_args)

    # For Evaluation
    training_args = TrainingArguments(
        do_train = True,
        do_eval = True,
        do_predict = True,
        output_dir=os.getcwd(),
        overwrite_output_dir=True,
        num_train_epochs=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        save_steps=500,
        save_total_limit=2,
        fp16=False,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_first_step=False,
        load_best_model_at_end = True,
        metric_for_best_model="micro-f1",
    )
    main(training_args)


You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running tokenizer on train dataset:   0%|          | 0/9000 [00:00<?, ? examples/s]

Running tokenizer on validation dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running tokenizer on prediction dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]



Step,Training Loss,Validation Loss,Macro-f1,Micro-f1
500,0.1657,0.228937,0.341642,0.53053
1000,0.1398,0.215355,0.429466,0.568588
1500,0.1364,0.187615,0.494645,0.637363
2000,0.1426,0.173353,0.447102,0.634127
2500,0.1149,0.154034,0.581512,0.681041
3000,0.1024,0.153453,0.570143,0.677197
3500,0.1107,0.159753,0.560743,0.693396
4000,0.1063,0.144982,0.632543,0.705644
4500,0.1015,0.144922,0.624403,0.703897


***** train metrics *****
  epoch                    =        2.0
  total_flos               = 82258672GF
  train_loss               =      0.137
  train_runtime            = 1:43:28.89
  train_samples            =       9000
  train_samples_per_second =      2.899
  train_steps_per_second   =      0.725


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

***** predict metrics *****
  predict_loss               =     0.1539
  predict_macro-f1           =     0.5896
  predict_micro-f1           =     0.7025
  predict_runtime            = 0:01:42.93
  predict_samples            =       1000
  predict_samples_per_second =      9.715
  predict_steps_per_second   =      2.429


***** predict metrics *****
  predict_loss               =     0.1539
  predict_macro-f1           =     0.5896
  predict_micro-f1           =     0.7025
  predict_runtime            = 0:01:42.49
  predict_samples            =       1000
  predict_samples_per_second =      9.756
  predict_steps_per_second   =      2.439


In [None]:
import re

# Read the "incorrect-predictions.txt" file
with open("/content/incorrect_predictions.txt", "r") as f:
    lines = f.readlines()

# Initialize variables
current_entry = {}
data_entries = []
text_entries = []
# Process the lines and create data entries
for line in lines:
    line = line.strip()
    if line.startswith("Index: "):
        current_entry["Index"] = int(line.split(": ")[1])
    elif line.startswith("Input Text: "):
        textstr = line.split(": ")[1]
        text_entries.append(textstr)
    elif line.startswith("Input IDs: "):
        input_ids_str = line.split(": ")[1]
        input_ids = [int(id_str) for id_str in re.findall(r'\d+', input_ids_str)]
        current_entry["input_ids"] = input_ids
    elif line.startswith("Labels: "):  # Process labels
        labels_str = line.split(": ")[1]
        labels = [int(label_str) for label_str in re.findall(r'\d+', labels_str)]
        current_entry["labels"] = labels
    elif line.startswith("Original_Labels: "):  # Process original labels
        labels_str = line.split(": ")[1]
        original_labels = [int(label_str) for label_str in re.findall(r'\d+', labels_str)]
        current_entry["original_labels"] = original_labels
        data_entries.append(current_entry)
        current_entry = {}

# Extract the list of labels from data entries
#text_lists = [entry.get("Input Text", []) for entry in text_entries]
label_lists = [entry.get("original_labels", []) for entry in data_entries]

print(label_lists)
print(len(label_lists))
print(len(text_entries))
filtered_list = [item for item in set(tuple(lst) for lst in label_lists) if item]
#print(len(filtered_list))
# Print the list of labels for each entry
#for index, labels in enumerate(label_lists):
#    print(f"Entry {index}: Labels = {labels}")


[[6], [4], [3], [3], [], [1], [3], [1], [3, 4], [3], [2], [3], [3], [0], [0], [1], [6], [3], [9], [3, 9], [5], [2], [1], [3], [3, 9], [1], [], [1], [0], [1, 2], [3, 9], [3], [3], [1], [3], [2], [3, 9], [3, 9], [1, 2], [3, 9], [], [4], [2], [2], [], [2, 3], [2], [], [], [1], [2], [7], [9], [], [7], [0], [7], [], [1, 4], [], [1], [3], [9], [4], [3], [9], [], [9], [1, 8], [], [3, 9], [3], [4], [0], [4], [3], [3], [1], [4], [3], [4], [2, 4], [3], [1, 2, 4], [3], [1, 2], [1], [3], [0], [3, 9], [3], [3], [3, 9], [7], [0], [], [3], [7], [0], [1], [6], [4], [4], [3], [0], [2], [3], [2], [1], [], [], [], [3, 4], [0], [3], [3], [4], [2], [], [3], [1, 2], [1], [1], [0, 1], [2], [6], [1, 2, 4], [3], [1, 2], [3], [1, 2, 3], [2], [], [4, 8], [3], [8, 9], [5], [9], [], [1], [3], [3], [], [4], [9], [1, 2], [], [1, 2], [3], [2], [3], [1, 4], [3], [], [], [3], [3, 9], [], [4], [4], [3], [], [0], [3, 9], [], [], [], [0, 2], [1], [3], [3], [3], [1], [3, 9], [3], [1], [9], [], [], [4], [2], [], [1], [3], [

In [None]:
import json
import re

# Read the "incorrect-predictions.txt" file
with open("/content/incorrect_predictions.txt", "r") as f:
    lines = f.readlines()

# Initialize variables
current_entry = {}
data_entries = []


# Process the lines and create data entries
for line in lines:
    line = line.strip()
    if line.startswith("Index: "):
        current_entry["Index"] = int(line.split(": ")[1])
    elif line.startswith("Input Text: "):
        current_entry["text"] = line.split(": ")[1]
    elif line.startswith("Input IDs: "):
        input_ids_str = line.split(": ")[1]
        input_ids = [int(id_str) for id_str in re.findall(r'\d+', input_ids_str)]
        current_entry["input_ids"] = input_ids
        current_entry["labels"] = []  # Assuming no labels for incorrect predictions
        data_entries.append(current_entry)
        current_entry = {}
    elif line.startswith("Labels: "):  # Process labels
        labels_str = line.split(": ")[1]
        labels = [int(label_str) for label_str in re.findall(r'\d+', labels_str)]
        current_entry["labels"] = labels
        data_entries.append(current_entry)
        current_entry = {}
    elif line.startswith("Original_Labels: "):  # Process original labels
        labels_str = line.split(": ")[1]
        labels = [int(label_str) for label_str in re.findall(r'\d+', labels_str)]
        current_entry["original_labels"] = labels
        data_entries.append(current_entry)
        current_entry = {}

# Save the data entries to a JSON file
output_filename = "incorrect_predictions_dataset.json"
with open(output_filename, "w") as json_file:
    json.dump(data_entries, json_file, indent=4)

print(f"Converted incorrect predictions saved to {output_filename}")


Converted incorrect predictions saved to incorrect_predictions_dataset.json


In [None]:
from datasets import load_dataset
from datasets import load_dataset
# Load the dataset
dataset = load_dataset("lex_glue", "ecthr_b")
test_dataset=dataset["train"]
# List of target labels
target_labels = label_lists  # Replace with your list of target labels
print(label_lists)
# Initialize a list to store the matching texts
matching_texts = []

# Iterate through the dataset
for entry in test_dataset:
    if "text" in entry and "labels" in entry:
        labels = entry["labels"]
        text = entry["text"]

        # Check if any label matches the target labels
        if any(label in target_labels for label in labels):
            matching_texts.append(text)
# Print the matching texts
print(len(matching_texts))


[[6], [4], [3], [3], [1, 3], [1], [3], [1], [3, 4], [3], [2], [3], [3], [0], [0, 9], [1], [6], [3], [9], [3, 9], [5], [3, 9], [1], [3], [3, 9], [1], [1], [1], [0], [1, 2], [3, 9], [3], [3], [1], [3], [2], [3, 9], [3, 9], [1, 2], [3, 9], [3], [1, 4], [2], [2], [1, 4, 6], [2, 3], [2], [9], [2, 4], [1], [2], [7], [9], [4], [7], [0], [7], [3, 9], [1, 4], [1, 4, 8], [1], [3], [9], [3, 4], [3], [9], [3], [9], [1, 8], [3], [3, 9], [3], [4], [0], [4, 6], [3], [3], [1], [4], [3], [4], [2, 4], [3], [1, 2, 3, 4], [3], [1, 2], [1], [3], [0], [3, 9], [3], [3], [3, 9], [7], [0], [0, 1], [3], [5, 7], [0], [0, 1], [6], [4], [4], [3], [0], [2], [3], [2], [1], [1, 4], [3], [6], [3, 4], [0], [2, 3], [3], [4], [2], [3], [3, 9], [1, 2], [1], [0, 1], [0, 1], [2], [6], [1, 2, 4], [3], [1, 2], [3], [1, 2, 3], [2], [2, 4], [4, 8], [3, 9], [8], [5], [9], [6], [1], [3], [3], [0, 1, 3], [4], [9], [1, 2], [3, 9], [1, 2, 3], [3], [2], [3], [1, 4], [3], [1], [4], [3], [3, 9], [6], [4], [4, 8], [3], [3], [0], [3, 9],

In [None]:
import json
from datasets import Dataset
from datasets import load_dataset, concatenate_datasets
from datasets import concatenate_datasets, Dataset, Value, Sequence

import pandas as pd

# Load JSON data from file
file_path = '/content/incorrect_predictions_dataset.json'
with open(file_path, 'r') as json_file:
    data = json.load(json_file)

# Create a dictionary to store original_labels
labels_dict = {}
texts = []

# Initialize lists to store texts and labels

# Iterate through the data and extract relevant information
#texts = texts.cast(Sequence(feature=Value(dtype='string', id=None)))
for entry in data:
    if "text" in entry:
        text = entry["text"]
        if "Index" in entry:
            index = entry["Index"]
            #labels.append(labels_dict.get(index, []))
        texts.append(text)
labels=[]
for entry in data:
    if "original_labels" in entry:
        original_labels = entry["original_labels"]
        if "Index" in entry:
            index = entry["Index"]
            #labels.append(labels_dict.get(index, []))
        labels.append(original_labels)

print("Number of Texts:", len(texts))
print("Number of Labels:", len(labels))

data_dict = {"text": texts, "labels": labels}

# Create a dataset using the data dictionary
dataset = Dataset.from_dict(data_dict)

# Convert the "text" feature to a Sequence of strings
dataset = dataset.map(lambda example: {"text": [example["text"]]})

print("Number of Texts:", len(texts))
print("Number of Labels:", len(labels))

Number of Texts: 1000
Number of Labels: 1000


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Number of Texts: 1000
Number of Labels: 1000


In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets
existing_dataset = load_dataset("lex_glue", 'ecthr_a')
existing_train_dataset = existing_dataset["train"]


merged_dataset = Dataset.from_dict({
    "text":  existing_train_dataset["text"] + dataset["text"],
    "labels": existing_train_dataset["labels"] + dataset["labels"],
})


Generating train split:   0%|          | 0/9000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
print("Number of Texts:", len(merged_dataset["text"]))
print("Number of Labels:", len(merged_dataset["labels"]))

Number of Texts: 10000
Number of Labels: 10000


In [None]:
#!/usr/bin/env python
# coding=utf-8
""" Finetuning models on the ECtHR dataset (e.g. Bert, RoBERTa, LEGAL-BERT)."""

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset
from sklearn.metrics import f1_score
#from trainer import MultilabelTrainer
from scipy.special import expit
from torch import nn
import glob
import shutil

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
#from models.hierbert import HierarchicalBert
#from models.deberta import DebertaForSequenceClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.9.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

logger = logging.getLogger(__name__)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    max_seq_length: Optional[int] = field(
        default=4096,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_segments: Optional[int] = field(
        default=64,
        metadata={
            "help": "The maximum number of segments (paragraphs) to be considered. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_seg_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum segment (paragraph) length to be considered. Segments longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    task: Optional[str] = field(
        default='ecthr_a',
        metadata={
            "help": "Define downstream task"
        },
    )
    server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
    server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    hierarchical: bool = field(
        default=True, metadata={"help": "Whether to use a hierarchical variant or not"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    do_lower_case: Optional[bool] = field(
        default=True,
        metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


def main(training_args):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    model_args = ModelArguments(
        model_name_or_path="bert-base-uncased",
        hierarchical=True,
        do_lower_case=True,
        use_fast_tokenizer=True,
    )
    data_args = DataTrainingArguments(
        max_seq_length=128,
        max_segments=64,
        max_seg_length=128,
        overwrite_cache=False,
        pad_to_max_length=True,
    )

    # Fix boolean parameter
    if model_args.do_lower_case == 'False' or not model_args.do_lower_case:
        model_args.do_lower_case = False
    else:
        model_args.do_lower_case = True

    if model_args.hierarchical == 'False' or not model_args.hierarchical:
        model_args.hierarchical = False
    else:
        model_args.hierarchical = True

    # Setup distant debugging if needed
    if data_args.server_ip and data_args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    # Downloading and loading eurlex dataset from the hub.
    if training_args.do_train:
        train_dataset = merged_dataset

    if training_args.do_eval:
        eval_dataset = load_dataset("lex_glue", name=data_args.task, split="validation", data_dir='data', cache_dir=model_args.cache_dir)

    if training_args.do_predict:
        predict_dataset = load_dataset("lex_glue", name=data_args.task, split="test", data_dir='data', cache_dir=model_args.cache_dir)

    # Labels
    label_list = list(range(10))
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=f"{data_args.task}",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        do_lower_case=model_args.do_lower_case,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    if config.model_type == 'deberta' and model_args.hierarchical:
        model = DebertaForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if model_args.hierarchical:
        # Hack the classifier encoder to use hierarchical BERT
        if config.model_type in ['bert', 'deberta']:
            if config.model_type == 'bert':
                segment_encoder = model.bert
            else:
                segment_encoder = model.deberta
            model_encoder = HierarchicalBert(encoder=segment_encoder,
                                             max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            if config.model_type == 'bert':
                model.bert = model_encoder
            elif config.model_type == 'deberta':
                model.deberta = model_encoder
            else:
                raise NotImplementedError(f"{config.model_type} is no supported yet!")
        elif config.model_type == 'roberta':
            model_encoder = HierarchicalBert(encoder=model.roberta, max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            model.roberta = model_encoder
            # Build a new classification layer, as well
            dense = nn.Linear(config.hidden_size, config.hidden_size)
            dense.load_state_dict(model.classifier.dense.state_dict())  # load weights
            dropout = nn.Dropout(config.hidden_dropout_prob).to(model.device)
            out_proj = nn.Linear(config.hidden_size, config.num_labels).to(model.device)
            out_proj.load_state_dict(model.classifier.out_proj.state_dict())  # load weights
            model.classifier = nn.Sequential(dense, dropout, out_proj).to(model.device)
        elif config.model_type in ['longformer', 'big_bird']:
            pass
        else:
            raise NotImplementedError(f"{config.model_type} is no supported yet!")

    # Preprocessing the datasets
    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    def preprocess_function(examples):
        # Tokenize the texts
        if model_args.hierarchical:
            case_template = [[0] * data_args.max_seg_length]
            if config.model_type == 'roberta':
                batch = {'input_ids': [], 'attention_mask': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                                data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                                data_args.max_segments - len(case_encodings['attention_mask'])))
            else:
                batch = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                            data_args.max_segments - len(case_encodings['attention_mask'])))
                    batch['token_type_ids'].append(case_encodings['token_type_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['token_type_ids'])))
        elif config.model_type in ['longformer', 'big_bird']:
            cases = []
            max_position_embeddings = config.max_position_embeddings - 2 if config.model_type == 'longformer' \
                else config.max_position_embeddings
            for case in examples['text']:
                cases.append(f' {tokenizer.sep_token} '.join(
                    [' '.join(fact.split()[:data_args.max_seg_length]) for fact in case[:data_args.max_segments]]))
            batch = tokenizer(cases, padding=padding, max_length=max_position_embeddings, truncation=True)
            if config.model_type == 'longformer':
                global_attention_mask = np.zeros((len(cases), max_position_embeddings), dtype=np.int32)
                # global attention on cls token
                global_attention_mask[:, 0] = 1
                batch['global_attention_mask'] = list(global_attention_mask)
        else:
            cases = []
            for case in examples['text']:
                cases.append(f'\n'.join(case))
            batch = tokenizer(cases, padding=padding, max_length=512, truncation=True)

        batch["labels"] = [[1 if label in labels else 0 for label in label_list] for labels in examples["labels"]]

        return batch

    if training_args.do_train:
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        # Log a few random samples from the training set:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    if training_args.do_eval:
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        # Fix gold labels
        y_true = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_true[:, :-1] = p.label_ids
        y_true[:, -1] = (np.sum(p.label_ids, axis=1) == 0).astype('int32')
        # Fix predictions
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = (expit(logits) > 0.5).astype('int32')
        y_pred = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_pred[:, :-1] = preds
        y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32')
        # Compute scores
        macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
        micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
        return {'macro-f1': macro_f1, 'micro-f1': micro_f1}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = MultilabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=eval_dataset)

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        output_predict_file = os.path.join(training_args.output_dir, "test_predictions.csv")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                for index, pred_list in enumerate(predictions[0]):
                    pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
                    writer.write(f"{index}\t{pred_line}\n")

    # Clean up checkpoints
    checkpoints = [filepath for filepath in glob.glob(f'{training_args.output_dir}/*/') if '/checkpoint' in filepath]
    for checkpoint in checkpoints:
        shutil.rmtree(checkpoint)


if __name__ == "__main__":
    training_args = TrainingArguments(
        do_train = True,
        do_eval = True,
        do_predict = True,
        output_dir=os.getcwd(),
        overwrite_output_dir=True,
        num_train_epochs=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        save_steps=500,
        save_total_limit=2,
        fp16=False,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_first_step=False,
        load_best_model_at_end = True,
        metric_for_best_model="micro-f1",
    )
    main(training_args)



You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running tokenizer on train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Running tokenizer on validation dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running tokenizer on prediction dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]



Step,Training Loss,Validation Loss,Macro-f1,Micro-f1
500,0.1944,0.204295,0.373758,0.568284
1000,0.1562,0.196128,0.395173,0.594637
1500,0.1422,0.17672,0.471873,0.617355
2000,0.1441,0.165887,0.52558,0.657028
2500,0.1508,0.149872,0.559435,0.685668
3000,0.1226,0.149508,0.514612,0.676797
3500,0.1245,0.165385,0.561741,0.678783
4000,0.1085,0.143873,0.633314,0.70133
4500,0.1026,0.140962,0.636863,0.711469
5000,0.1024,0.143853,0.646531,0.710652


***** train metrics *****
  epoch                    =        2.0
  total_flos               = 91398524GF
  train_loss               =     0.1431
  train_runtime            = 1:55:16.96
  train_samples            =      10000
  train_samples_per_second =      2.891
  train_steps_per_second   =      0.723


***** eval metrics *****
  epoch                   =        2.0
  eval_loss               =      0.141
  eval_macro-f1           =     0.6369
  eval_micro-f1           =     0.7115
  eval_runtime            = 0:01:42.68
  eval_samples            =       1000
  eval_samples_per_second =      9.739
  eval_steps_per_second   =      2.435
***** predict metrics *****
  predict_loss               =     0.1317
  predict_macro-f1           =     0.6007
  predict_micro-f1           =     0.7276
  predict_runtime            = 0:01:42.68
  predict_samples            =       1000
  predict_samples_per_second =      9.738
  predict_steps_per_second   =      2.435


In [None]:
#!/usr/bin/env python
# coding=utf-8
""" Finetuning models on the ECtHR dataset (e.g. Bert, RoBERTa, LEGAL-BERT)."""

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset
from sklearn.metrics import f1_score
#from trainer import MultilabelTrainer
from scipy.special import expit
from torch import nn
import glob
import shutil
import torch
torch.cuda.empty_cache()
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
#from models.hierbert import HierarchicalBert
#from models.deberta import DebertaForSequenceClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.9.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

logger = logging.getLogger(__name__)

from transformers import AutoModel, AutoTokenizer

# First, load the tokenizer and pre-trained BERT model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert_model = AutoModel.from_pretrained('bert-base-uncased')

# Then, create an instance of HierarchicalBert
max_segments = 64
max_segment_length = 128
HierarchicalBertObj = HierarchicalBert(encoder=bert_model, max_segments=max_segments, max_segment_length=max_segment_length)
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb=256,512,1024"

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    max_seq_length: Optional[int] = field(
        default=4096,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_segments: Optional[int] = field(
        default=64,
        metadata={
            "help": "The maximum number of segments (paragraphs) to be considered. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_seg_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum segment (paragraph) length to be considered. Segments longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    task: Optional[str] = field(
        default=None,
        metadata={
            "help": "Define downstream task"
        },
    )
    server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
    server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    hierarchical: bool = field(
        default=True, metadata={"help": "Whether to use a hierarchical variant or not"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    do_lower_case: Optional[bool] = field(
        default=True,
        metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


def main(training_args):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    model_args = ModelArguments(
        model_name_or_path="bert-base-uncased",
        hierarchical=True,
        do_lower_case=True,
        use_fast_tokenizer=True,
    )
    data_args = DataTrainingArguments(
        max_seq_length=128,
        max_segments=64,
        max_seg_length=128,
        overwrite_cache=False,
        pad_to_max_length=True,
    )


    # Fix boolean parameter
    if model_args.do_lower_case == 'False' or not model_args.do_lower_case:
        model_args.do_lower_case = False
    else:
        model_args.do_lower_case = True

    if model_args.hierarchical == 'False' or not model_args.hierarchical:
        model_args.hierarchical = False
    else:
        model_args.hierarchical = True

    # Setup distant debugging if needed
    if data_args.server_ip and data_args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    # Downloading and loading eurlex dataset from the hub.
    if training_args.do_train:
        train_dataset = load_dataset("lex_glue","ecthr_b", split="train", data_dir='data', cache_dir=model_args.cache_dir)

    if training_args.do_eval:
        eval_dataset = load_dataset("lex_glue","ecthr_b", split="validation", data_dir='data', cache_dir=model_args.cache_dir)

    if training_args.do_predict:
        predict_dataset = load_dataset("lex_glue", "ecthr_b" ,split="test", data_dir='data', cache_dir=model_args.cache_dir)

    # Labels
    label_list = list(range(10))
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=f"{data_args.task}",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        do_lower_case=model_args.do_lower_case,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    if config.model_type == 'deberta' and model_args.hierarchical:
        model = DebertaForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if model_args.hierarchical:
        # Hack the classifier encoder to use hierarchical BERT
        if config.model_type in ['bert', 'deberta']:
            if config.model_type == 'bert':
                segment_encoder = model.bert
            else:
                segment_encoder = model.deberta
            model_encoder = HierarchicalBert(encoder=segment_encoder,
                                             max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            if config.model_type == 'bert':
                model.bert = model_encoder
            elif config.model_type == 'deberta':
                model.deberta = model_encoder
            else:
                raise NotImplementedError(f"{config.model_type} is no supported yet!")
        elif config.model_type == 'roberta':
            model_encoder = HierarchicalBert(encoder=model.roberta, max_segments=data_args.max_segments,
                                             max_segment_length=data_args.max_seg_length)
            model.roberta = model_encoder
            # Build a new classification layer, as well
            dense = nn.Linear(config.hidden_size, config.hidden_size)
            dense.load_state_dict(model.classifier.dense.state_dict())  # load weights
            dropout = nn.Dropout(config.hidden_dropout_prob).to(model.device)
            out_proj = nn.Linear(config.hidden_size, config.num_labels).to(model.device)
            out_proj.load_state_dict(model.classifier.out_proj.state_dict())  # load weights
            model.classifier = nn.Sequential(dense, dropout, out_proj).to(model.device)
        elif config.model_type in ['longformer', 'big_bird']:
            pass
        else:
            raise NotImplementedError(f"{config.model_type} is no supported yet!")

    # Preprocessing the datasets
    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    def preprocess_function(examples):
        # Tokenize the texts
        if model_args.hierarchical:
            case_template = [[0] * data_args.max_seg_length]
            if config.model_type == 'roberta':
                batch = {'input_ids': [], 'attention_mask': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                                data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                                data_args.max_segments - len(case_encodings['attention_mask'])))
            else:
                batch = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
                for case in examples['text']:
                    case_encodings = tokenizer(case[:data_args.max_segments], padding=padding,
                                               max_length=data_args.max_seg_length, truncation=True)
                    batch['input_ids'].append(case_encodings['input_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['input_ids'])))
                    batch['attention_mask'].append(case_encodings['attention_mask'] + case_template * (
                            data_args.max_segments - len(case_encodings['attention_mask'])))
                    batch['token_type_ids'].append(case_encodings['token_type_ids'] + case_template * (
                            data_args.max_segments - len(case_encodings['token_type_ids'])))
        elif config.model_type in ['longformer', 'big_bird']:
            cases = []
            max_position_embeddings = config.max_position_embeddings - 2 if config.model_type == 'longformer' \
                else config.max_position_embeddings
            for case in examples['text']:
                cases.append(f' {tokenizer.sep_token} '.join(
                    [' '.join(fact.split()[:data_args.max_seg_length]) for fact in case[:data_args.max_segments]]))
            batch = tokenizer(cases, padding=padding, max_length=max_position_embeddings, truncation=True)
            if config.model_type == 'longformer':
                global_attention_mask = np.zeros((len(cases), max_position_embeddings), dtype=np.int32)
                # global attention on cls token
                global_attention_mask[:, 0] = 1
                batch['global_attention_mask'] = list(global_attention_mask)
        else:
            cases = []
            for case in examples['text']:
                cases.append(f'\n'.join(case))
            batch = tokenizer(cases, padding=padding, max_length=512, truncation=True)

        batch["labels"] = [[1 if label in labels else 0 for label in label_list] for labels in examples["labels"]]

        return batch

    if training_args.do_train:
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        # Log a few random samples from the training set:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    if training_args.do_eval:
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        # Fix gold labels
        y_true = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_true[:, :-1] = p.label_ids
        y_true[:, -1] = (np.sum(p.label_ids, axis=1) == 0).astype('int32')
        # Fix predictions
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = (expit(logits) > 0.5).astype('int32')
        y_pred = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_pred[:, :-1] = preds
        y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32')
        # Compute scores
        macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
        micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
        return {'macro-f1': macro_f1, 'micro-f1': micro_f1}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = MultilabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=eval_dataset)

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        output_predict_file = os.path.join(training_args.output_dir, "test_predictions.csv")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                for index, pred_list in enumerate(predictions[0]):
                    pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
                    writer.write(f"{index}\t{pred_line}\n")

    # Clean up checkpoints
    checkpoints = [filepath for filepath in glob.glob(f'{training_args.output_dir}/*/') if '/checkpoint' in filepath]
    for checkpoint in checkpoints:
        shutil.rmtree(checkpoint)


if __name__ == "__main__":
    #For training

    # For Evaluation
    training_args = TrainingArguments(
        do_train = True,
        do_eval = True,
        do_predict = True,
        output_dir=os.getcwd(),
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        save_steps=500,
        save_total_limit=2,
        fp16=False,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_first_step=False,
        load_best_model_at_end = True,
        metric_for_best_model="micro-f1",
    )
    main(training_args)


You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running tokenizer on train dataset:   0%|          | 0/9000 [00:00<?, ? examples/s]

Running tokenizer on validation dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running tokenizer on prediction dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]



Step,Training Loss,Validation Loss


KeyboardInterrupt: ignored