In [None]:
# Transformers installation
!pip install datasets evaluate
!pip install transformers==4.28.0

# To install from source instead of the last release, comment the command above and uncomment the following one.
# ! pip install git+https://github.com/huggingface/transformers.git

# Text classification

Text classification is a common NLP task that assigns a label or class to text. Some of the largest companies run text classification in production for a wide range of practical applications. One of the most popular forms of text classification is sentiment analysis, which assigns a label like 🙂 positive, 🙁 negative, or 😐 neutral to a sequence of text.

This guide will show you how to:

1. Finetune [LEGAL-BERT small](https://huggingface.co/nlpaueb/legal-bert-small-uncased) on U.S. Supreme Court majority opinions (1994–2020) to predict which justice wrote a given passage.
2. Use your finetuned model for inference.

<Tip>
The task illustrated in this tutorial is supported by the following model architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/albert), [BART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bart), [BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert), [BigBird](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bigbird_pegasus), [BLOOM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/camembert), [CANINE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/canine), [ConvBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/convbert), [CTRL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ctrl), [Data2VecText](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/data2vec-text), [DeBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta), [DeBERTa-v2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta-v2), [DistilBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/distilbert), [ELECTRA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie), [ErnieM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie_m), [ESM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/esm), [FlauBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/flaubert), [FNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/fnet), [Funnel Transformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/funnel), [GPT-Sw3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt-sw3), [OpenAI 
GPT-2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt2), [GPTBigCode](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_bigcode), [GPT Neo](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox), [GPT-J](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gptj), [I-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ibert), [LayoutLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlm), [LayoutLMv2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv2), [LayoutLMv3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv3), [LED](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/led), [LiLT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/lilt), [LLaMA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/llama), [Longformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/longformer), [LUKE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/luke), [MarkupLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/markuplm), [mBART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/megatron-bert), [MobileBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mobilebert), [MPNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mpnet), [MVP](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mvp), [Nezha](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nezha), 
[Nyströmformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nystromformer), [OpenAI GPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/openai-gpt), [OPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/opt), [Perceiver](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/perceiver), [PLBart](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/plbart), [QDQBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roformer), [SqueezeBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/squeezebert), [TAPAS](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/tapas), [Transformer-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/transfo-xl), [XLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm), [XLM-RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xmod), [YOSO](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/yoso)


<!--End of the generated tip-->

</Tip>

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate
```

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

In [None]:
from huggingface_hub import notebook_login

notebook_login()


## Load dataset

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Start by loading the data. The opinions are stored as CSV files on Google Drive, so read them with pandas and convert the result to a 🤗 Datasets `Dataset`:

In [3]:
from datasets import Dataset

Old dataset (cleaned opinions from nine justices):

In [None]:
import pandas as pd
path = '/content/drive/MyDrive/research/projects/supreme_court/cleaned_opninions_9judges.csv'
df = pd.read_csv(path)
df = df.drop(columns=['Unnamed: 0', 'Justice','Case'])
df = df.rename(columns={"Label": "label", "Short_Opinion": "text"}, errors="raise")
df.head()

|   | label | text |
|---|---|---|
| 0 | 6 | internal revenue code taxpayer may carry back ... |
| 1 | 6 | example company taxable income deductible ples... |
| 2 | 6 | unless otherwise noted treasury regulation ref... |
| 3 | 6 | amca petitioned internal revenue service refun... |
| 4 | 6 | intermet involved specified liability losses s... |


In [4]:
import pandas as pd
path = '/content/drive/MyDrive/research/projects/supreme_court/data/opinions_1994_2020_split.csv'
df = pd.read_csv(path)
df = df[df['text'].str.len()>1000]
df = df[df.category=='majority']
# df = df.drop(columns=['url'])
# df = df.rename(columns={"Label": "label", "Short_Opinion": "text"}, errors="raise")
df.head()

|   | author_name | label | category | case_name | url | text |
|---|---|---|---|---|---|---|
| 0 | Justice Roberts | 12 | majority | McCutcheon v. Federal Election Comm'n | https://www.courtlistener.com/opinion/2659301/... | There is no right more basic in our democracy ... |
| 1 | Justice Roberts | 12 | majority | McCutcheon v. Federal Election Comm'n | https://www.courtlistener.com/opinion/2659301/... | Any regulation must instead target what we hav... |
| 2 | Justice Roberts | 12 | majority | McCutcheon v. Federal Election Comm'n | https://www.courtlistener.com/opinion/2659301/... | the original donor to the specified For the 20... |
| 3 | Justice Roberts | 12 | majority | McCutcheon v. Federal Election Comm'n | https://www.courtlistener.com/opinion/2659301/... | the future. In the 2013–2014 election cycle, h... |
| 4 | Justice Roberts | 12 | majority | McCutcheon v. Federal Election Comm'n | https://www.courtlistener.com/opinion/2659301/... | to refuse adjudication of the case on its meri... |


In [5]:
df['category'].value_counts()

majority    19711
Name: category, dtype: int64

In [6]:
df['author_name'].value_counts()

Justice Kennedy      2936
Justice Ginsburg     2380
Justice Breyer       2179
Justice Scalia       1850
Justice Thomas       1792
Justice Alito        1593
Justice Stevens      1323
Justice Souter       1269
Justice Kagan        1205
Justice O'Connor     1101
Justice Sotomayor    1048
Justice Rehnquist     775
Justice Roberts       260
Name: author_name, dtype: int64

In [13]:
df = df.groupby('author_name',as_index = False,group_keys=False).apply(lambda s: s.sample(1500,replace=True))

In [14]:
df['author_name'].value_counts()

Justice Alito        1500
Justice Breyer       1500
Justice Ginsburg     1500
Justice Kagan        1500
Justice Kennedy      1500
Justice O'Connor     1500
Justice Rehnquist    1500
Justice Roberts      1500
Justice Scalia       1500
Justice Sotomayor    1500
Justice Souter       1500
Justice Stevens      1500
Justice Thomas       1500
Name: author_name, dtype: int64
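Sampling with replacement does two things at once: it oversamples small classes and it duplicates rows in large ones, so Justice Kennedy's 2,936 rows are not simply downsampled to 1,500 distinct passages. The stdlib sketch below is a toy reproduction with hypothetical row ids (`resample_per_class` is an illustrative name, not the notebook's pandas code); it counts how many distinct rows survive per class.

```python
import random


def resample_per_class(rows, n_per_class, seed=0):
    """Draw n_per_class rows *with replacement* from every class.

    Mirrors the idea of `groupby(...).apply(lambda s: s.sample(n, replace=True))`.
    rows: list of (row_id, label) pairs.
    """
    rng = random.Random(seed)
    by_class = {}
    for row in rows:
        by_class.setdefault(row[1], []).append(row)
    resampled = []
    for group in by_class.values():
        resampled.extend(rng.choices(group, k=n_per_class))  # with replacement
    return resampled


# Hypothetical row ids, sized like Justice Roberts (260 rows) and Justice Kennedy (2936 rows).
rows = [(i, "Roberts") for i in range(260)] + [(260 + i, "Kennedy") for i in range(2936)]
resampled = resample_per_class(rows, 1500)
for justice in ("Roberts", "Kennedy"):
    ids = [row_id for row_id, label in resampled if label == justice]
    print(justice, len(ids), "draws,", len(set(ids)), "distinct rows")
```

Expect nearly all 260 Roberts rows, each repeated about six times on average, and only roughly 1,170 distinct Kennedy rows, since about 2936 · (1 − e^(−1500/2936)) rows are drawn at least once.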

In [15]:
court = Dataset.from_pandas(df[['label','text']])
court = court.remove_columns('__index_level_0__')
court = court.class_encode_column("label")
court

Dataset({
    features: ['label', 'text'],
    num_rows: 19500
})
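`class_encode_column` stringifies the `label` column and sorts the resulting strings before assigning ids, so the new class ids no longer match the raw 0–12 author ids. That is why the `court['label']` output starts with 2 even though the first resampled group is Justice Alito (raw label 10). The stdlib sketch below mimics that ordering; it illustrates the behavior and is not the 🤗 Datasets implementation (`sorted_string_class_ids` is a hypothetical helper).

```python
def sorted_string_class_ids(raw_labels):
    """Mimic the string-sorted class ordering used by `class_encode_column`.

    Returns a dict mapping each raw integer label to its new class id.
    """
    # Stringify, deduplicate, and sort lexicographically ("10" sorts before "2").
    class_names = sorted({str(label) for label in raw_labels})
    return {int(name): new_id for new_id, name in enumerate(class_names)}


mapping = sorted_string_class_ids(range(13))  # raw author ids 0..12 from the CSV
print(mapping[10], mapping[7], mapping[9])  # 2 10 12 (Alito, Stevens, Kagan)
```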

In [22]:
df.iloc[-100:,1]

18492    8
23348    8
24300    8
15687    8
24079    8
        ..
1100     8
16262    8
16602    8
32131    8
5935     8
Name: label, Length: 100, dtype: int64

In [17]:
court['label']

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...


In [None]:
court = court.train_test_split(0.2, stratify_by_column='label')
court

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 15600
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 3900
    })
})
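Because `sample(1500, replace=True)` ran before `train_test_split`, copies of the same passage can land in both splits, which would likely inflate the test accuracy reported during training. One way to avoid that, sketched below with the stdlib under the assumption that every row carries a unique id (`row_id` and `split_then_oversample` are hypothetical names), is to split the distinct rows first and oversample only the training portion.

```python
import random


def split_then_oversample(rows, test_frac=0.2, n_per_class=1500, seed=0):
    """Stratified split on *distinct* rows first, then resample only the training split.

    rows: list of (row_id, label) pairs. Returns (train, test) lists of pairs.
    """
    rng = random.Random(seed)
    by_class = {}
    for row in rows:
        by_class.setdefault(row[1], []).append(row)
    train, test = [], []
    for group in by_class.values():
        group = group[:]                      # copy before shuffling
        rng.shuffle(group)
        n_test = int(round(len(group) * test_frac))
        test.extend(group[:n_test])           # held-out rows are never duplicated
        train.extend(rng.choices(group[n_test:], k=n_per_class))  # resample train only
    return train, test


rows = [(i, "Roberts") for i in range(260)] + [(260 + i, "Kennedy") for i in range(2936)]
train, test = split_then_oversample(rows)
overlap = {r for r, _ in train} & {r for r, _ in test}
print(len(train), len(test), len(overlap))  # 3000 639 0
```

Each CSV row is a chunk of a longer opinion (several rows share a `case_name`), so a stricter variant would split by case, keeping passages from the same opinion out of both splits; whether that is needed depends on what the evaluation is meant to measure.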

In [None]:
court["test"][0]

{'label': 10,
 'text': 'as Amicus Curiae 25–26 (hereinafter Brief for United States). Yet even if this reading were correct, state and local adminis trative reports, hearings, audits, and investigations of a legislative-type character are presumably just as public, and just as likely to put the Federal Government on notice of a potential fraud, as state and local administrative hearings of an adjudicatory character.9 —————— Sylvia, The False Claims Act: Fraud Against the Government p. 642 (2004) (hereinafter Sylvia). 9 See (“Indeed, the statute would seem to be inconsistent if it included state and local administrative hearings as sources of public disclosures [in Category ] and then, in the next breath, excluded state administrative reports as sources”); In re Natu ral Gas Royalties Qui Tam Litigation, 43–44 (“There is no reason to conclude that Congress intended to limit administrative reports, audits, and investigations to federal actions, while simultaneously allowing all state and

There are two fields in this dataset:

- `text`: an excerpt from a majority opinion.
- `label`: the class id of the justice who wrote the opinion.

## Preprocess

The next step is to load a tokenizer to preprocess the `text` field:

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-small-uncased")

Create a preprocessing function to tokenize `text` and truncate sequences to be no longer than LEGAL-BERT's maximum input length (512 tokens):

In [None]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

To apply the preprocessing function over the entire dataset, use 🤗 Datasets [map](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) function. You can speed up `map` by setting `batched=True` to process multiple elements of the dataset at once:

In [None]:
# tokenized_imdb = imdb.map(preprocess_function, batched=True)
tokenized_court = court.map(preprocess_function, batched=True)

In [None]:
tokenized_court

DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 15600
    })
    test: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 3900
    })
})

In [None]:
tokenized_court['test'][0]['text']

'as Amicus Curiae 25–26 (hereinafter Brief for United States). Yet even if this reading were correct, state and local adminis trative reports, hearings, audits, and investigations of a legislative-type character are presumably just as public, and just as likely to put the Federal Government on notice of a potential fraud, as state and local administrative hearings of an adjudicatory character.9 —————— Sylvia, The False Claims Act: Fraud Against the Government p. 642 (2004) (hereinafter Sylvia). 9 See (“Indeed, the statute would seem to be inconsistent if it included state and local administrative hearings as sources of public disclosures [in Category ] and then, in the next breath, excluded state administrative reports as sources”); In re Natu ral Gas Royalties Qui Tam Litigation, 43–44 (“There is no reason to conclude that Congress intended to limit administrative reports, audits, and investigations to federal actions, while simultaneously allowing all state and local civil litiga tio

Now create a batch of examples using [DataCollatorWithPadding](https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorWithPadding). It's more efficient to *dynamically pad* the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.

In [None]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
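Conceptually, `DataCollatorWithPadding` pads each batch only to the longest sequence in that batch and builds the matching `attention_mask`. The pure-Python sketch below illustrates the idea; it is not the library's implementation, `pad_batch` is a hypothetical helper, and the token ids are made up. `pad_token_id=0` matches the `padding_idx=0` of the model's word embeddings.

```python
def pad_batch(batch_input_ids, pad_token_id=0):
    """Pad a list of token-id lists to the longest sequence in *this* batch.

    Returns (input_ids, attention_mask), both lists of equal-length lists.
    """
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
    return input_ids, attention_mask


ids, mask = pad_batch([[101, 7592, 102], [101, 102]])
print(ids)   # [[101, 7592, 102], [101, 102, 0]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```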

## Evaluate

Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")

Then create a function that passes your predictions and labels to [compute](https://huggingface.co/docs/evaluate/main/en/package_reference/main_classes#evaluate.EvaluationModule.compute) to calculate the accuracy:

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
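To sanity-check the argmax-then-compare logic without a GPU, here is a stdlib re-implementation on toy logits. The values are hypothetical, and `accuracy_from_logits` is an illustrative name, not part of 🤗 Evaluate.

```python
def accuracy_from_logits(logits, labels):
    """Same logic as `compute_metrics`: argmax over classes, then fraction correct."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# 4 examples, 3 classes; predictions are [0, 2, 1, 0]
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 0.9, 0.8]]
labels = [0, 2, 2, 0]
print(accuracy_from_logits(logits, labels))  # {'accuracy': 0.75}
```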

Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.

## Train

Before you start training your model, create a map of the expected ids to their labels with `id2label` and `label2id`:

In [None]:
# Raw author ids, as stored in the CSV `label` column.
raw_id2label = {0: "Justice Breyer", 1: "Justice Ginsburg", 2: "Justice Kennedy", 3: "Justice O'Connor",
                4: "Justice Rehnquist", 5: "Justice Scalia", 6: "Justice Souter", 7: "Justice Stevens",
                8: "Justice Thomas", 9: "Justice Kagan", 10: "Justice Alito", 11: "Justice Sotomayor",
                12: "Justice Roberts"}

# `class_encode_column` re-indexed the labels by sorting their string forms
# ("0", "1", "10", "11", "12", "2", ...), so class id 2 is raw label 10 (Alito), not Kennedy.
# Translate through the ClassLabel names so the model's label names match the encoded ids.
class_names = court["train"].features["label"].names
id2label = {i: raw_id2label[int(name)] for i, name in enumerate(class_names)}
label2id = {name: i for i, name in id2label.items()}

<Tip>

If you aren't familiar with finetuning a model with the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer), take a look at the basic tutorial [here](https://huggingface.co/docs/transformers/main/en/tasks/../training#train-with-pytorch-trainer)!

</Tip>

You're ready to start training your model now! Load LEGAL-BERT (`nlpaueb/legal-bert-small-uncased`) with [AutoModelForSequenceClassification](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSequenceClassification) along with the number of expected labels (13 justices) and the label mappings:

In [None]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch.nn as nn
model = AutoModelForSequenceClassification.from_pretrained(
    "nlpaueb/legal-bert-small-uncased", num_labels=13, id2label=id2label, label2id=label2id,
    # hidden_dropout_prob = 0.4, attention_probs_dropout_prob = 0.2, revision =
)

# model = AutoModelForSequenceClassification.from_pretrained(
#     "raminass/scotus", num_labels=13, id2label=id2label, label2id=label2id,
#     revision = 'de256299358d88003756b0345ef1562784a4cc8c'
# )

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

At this point, only three steps remain:

1. Define your training hyperparameters in [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments). The only required parameter is `output_dir` which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). At the end of each epoch, the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) will evaluate the accuracy and save the training checkpoint.
2. Pass the training arguments to [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) along with the model, dataset, tokenizer, data collator, and `compute_metrics` function.
3. Call [train()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train) to finetune your model.

In [None]:
model.classifier

Linear(in_features=512, out_features=13, bias=True)

In [None]:
# Optional deeper classification head (disabled).
# model.classifier = nn.Sequential(
#             nn.Linear(512, 100),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(100, 50),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(50, 13),  # output size must match num_labels
#         )

In [None]:
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, e

In [None]:
# for param in model.bert.parameters():
#     param.requires_grad = False
for name, param in model.named_parameters():
    print(name, param.requires_grad)

bert.embeddings.word_embeddings.weight True
bert.embeddings.position_embeddings.weight True
bert.embeddings.token_type_embeddings.weight True
bert.embeddings.LayerNorm.weight True
bert.embeddings.LayerNorm.bias True
bert.encoder.layer.0.attention.self.query.weight True
bert.encoder.layer.0.attention.self.query.bias True
bert.encoder.layer.0.attention.self.key.weight True
bert.encoder.layer.0.attention.self.key.bias True
bert.encoder.layer.0.attention.self.value.weight True
bert.encoder.layer.0.attention.self.value.bias True
bert.encoder.layer.0.attention.output.dense.weight True
bert.encoder.layer.0.attention.output.dense.bias True
bert.encoder.layer.0.attention.output.LayerNorm.weight True
bert.encoder.layer.0.attention.output.LayerNorm.bias True
bert.encoder.layer.0.intermediate.dense.weight True
bert.encoder.layer.0.intermediate.dense.bias True
bert.encoder.layer.0.output.dense.weight True
bert.encoder.layer.0.output.dense.bias True
bert.encoder.layer.0.output.LayerNorm.weight True


In [None]:
model.num_parameters()

35075085
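As a worked check of that number, the count can be rebuilt layer by layer. Hidden size 512, 6 layers, a 30,522-token vocabulary, 512 positions, and 2 token types are read off the printed module summary; the feed-forward size of 2048 is an assumption (that part of the summary is cut off), but with it the arithmetic reproduces `model.num_parameters()` exactly. `bert_param_count` is a hypothetical helper.

```python
def bert_param_count(vocab=30522, hidden=512, layers=6, intermediate=2048,
                     max_positions=512, type_vocab=2, num_labels=13):
    """Count parameters of a BertForSequenceClassification with the given sizes."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out          # weight + bias

    layer_norm = 2 * hidden                  # gamma + beta
    embeddings = (vocab + max_positions + type_vocab) * hidden + layer_norm
    per_layer = (
        4 * linear(hidden, hidden)           # query, key, value, attention output
        + layer_norm                         # attention LayerNorm
        + linear(hidden, intermediate)       # feed-forward up-projection
        + linear(intermediate, hidden)       # feed-forward down-projection
        + layer_norm                         # output LayerNorm
    )
    pooler = linear(hidden, hidden)
    classifier = linear(hidden, num_labels)
    return embeddings + layers * per_layer + pooler + classifier


print(bert_param_count())  # 35075085
```

The classification head accounts for only 512 · 13 + 13 = 6,669 of those parameters; the word embeddings alone are about 15.6 million.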

In [None]:
training_args = TrainingArguments(
    output_dir="scotus_new",
    logging_dir="scotus_new/runs/001",
    learning_rate=2e-5,  # one of the fine-tuning rates recommended in the BERT paper
    per_device_train_batch_size=16,  # LEGAL-BERT used 256, which does not fit in memory here
    per_device_eval_batch_size=16,  # LEGAL-BERT used 256, which does not fit in memory here
    num_train_epochs=20,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_court["train"],
    eval_dataset=tokenized_court["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Cloning https://huggingface.co/raminass/scotus_new into local empty directory.


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 2.2202 | 1.444826 | 0.573333 |
| 2 | 1.2664 | 0.875578 | 0.738974 |
| 3 | 0.6543 | 0.562627 | 0.840256 |
| 4 | 0.3136 | 0.450264 | 0.877436 |
| 5 | 0.1522 | 0.360827 | 0.902821 |
| 6 | 0.0604 | 0.375223 | 0.907436 |
| 7 | 0.0301 | 0.411397 | 0.910769 |
| 8 | 0.0164 | 0.419255 | 0.913077 |
| 9 | 0.0108 | 0.4827 | 0.912051 |
| 10 | 0.0062 | 0.4799 | 0.915385 |


TrainOutput(global_step=19500, training_loss=0.21541113685797422, metrics={'train_runtime': 9342.0919, 'train_samples_per_second': 33.397, 'train_steps_per_second': 2.087, 'total_flos': 1.8387799252992e+16, 'train_loss': 0.21541113685797422, 'epoch': 20.0})
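The log is consistent with training on one device: 15,600 training rows at batch size 16 give 975 steps per epoch, and 20 epochs give the reported `global_step=19500`. Only the first 10 epochs appear in the table, and among those, validation loss bottoms out at epoch 5 while accuracy keeps creeping up. With `load_best_model_at_end=True` and no `metric_for_best_model`, the Trainer compares checkpoints by evaluation loss, so the quick check below shows which of the logged epochs each criterion would pick.

```python
import math

# Values taken from the training cell and its log.
train_rows, batch_size, epochs = 15600, 16, 20
steps_per_epoch = math.ceil(train_rows / batch_size)
print(steps_per_epoch, steps_per_epoch * epochs)  # 975 19500

# (epoch, validation loss, accuracy) for the 10 epochs shown in the log
log = [
    (1, 1.444826, 0.573333), (2, 0.875578, 0.738974), (3, 0.562627, 0.840256),
    (4, 0.450264, 0.877436), (5, 0.360827, 0.902821), (6, 0.375223, 0.907436),
    (7, 0.411397, 0.910769), (8, 0.419255, 0.913077), (9, 0.4827, 0.912051),
    (10, 0.4799, 0.915385),
]
best_by_loss = min(log, key=lambda row: row[1])[0]      # default criterion: eval loss
best_by_accuracy = max(log, key=lambda row: row[2])[0]  # if metric_for_best_model="accuracy"
print(best_by_loss, best_by_accuracy)  # 5 10
```

If accuracy is the criterion you care about, setting `metric_for_best_model="accuracy"` in `TrainingArguments` is one option, and `EarlyStoppingCallback` could stop a run whose validation loss has stopped improving; neither is used in this notebook.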

<Tip>

[Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) applies dynamic padding by default when you pass `tokenizer` to it. In this case, you don't need to specify a data collator explicitly.

</Tip>

Once training is completed, share your model on the Hub with the [push_to_hub()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub) method so everyone can use your model:

In [None]:
trainer.push_to_hub()

Several commits (2) will be pushed upstream.
The progress bars may be unreliable.


To https://huggingface.co/raminass/scotus_new
   8e06715..5bb0cab  main -> main

To https://huggingface.co/raminass/scotus_new
   5bb0cab..21cc475  main -> main



'https://huggingface.co/raminass/scotus_new/commit/5bb0cab07c7c627436a392ff2cd4c6e4b0d4303d'

<Tip>

For a more in-depth example of how to finetune a model for text classification, take a look at the corresponding
[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)
or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).

</Tip>

## Inference

Great, now that you've finetuned a model, you can use it for inference!

Grab some text you'd like to run inference on:

In [None]:
text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."

The simplest way to try out your finetuned model for inference is to use it in a [pipeline()](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline). Instantiate a `pipeline` for text classification with your model, and pass your text to it:

In [None]:
from transformers import pipeline

classifier = pipeline("text-classification", model="raminass/my_legal_model")
classifier(text)

[{'label': 'SOUTER', 'score': 0.15111537277698517}]

You can also manually replicate the results of the `pipeline` if you'd like:

Tokenize the text and return PyTorch tensors:

In [None]:
from transformers import AutoTokenizer

# Use the same checkpoint as the pipeline above so the results are comparable.
tokenizer = AutoTokenizer.from_pretrained("raminass/my_legal_model")
inputs = tokenizer(text, truncation=True, return_tensors="pt")

Pass your inputs to the model and return the `logits`:

In [None]:
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("raminass/my_legal_model")
with torch.no_grad():
    logits = model(**inputs).logits

Get the class with the highest probability, and use the model's `id2label` mapping to convert it to a text label:

In [None]:
predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
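The pipeline's `score` should correspond to the softmax probability of the top class. For a 13-way classifier a uniform guess scores 1/13 ≈ 0.077, so the 0.151 returned for the movie review is low confidence, which is unsurprising for text that is not a court opinion. A stdlib sketch of the conversion follows; with PyTorch you would call `torch.softmax(logits, dim=-1)` on the `logits` tensor instead.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    shift = max(logits)                       # subtracting the max avoids overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


uniform = softmax([0.0] * 13)
print(round(uniform[0], 4))  # 0.0769, i.e. 1/13: a model with no preference
```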