In [1]:
"""
The dataset contains 50 Indian Supreme Court case documents. F : facts, RLC : Ruling by Lower Court, A : Arguments, P : Precedent, S : Statute, R : Ratio of the decision, RPC : Ruling by Present Court/Final judgement.
"""

'\nThe dataset contains 50 Indian Supreme Court case documents. F : facts, RLC : Ruling by Lower Court, A : Arguments, P : Precedent, S : Statute, R : Ratio of the decision, RPC : Ruling by Present Court/Final judgement.\n'

In [None]:
!pip install transformers==4.2.1 pandas torch

In [None]:
import transformers

print(f"Running on transformers v{transformers.__version__}")

In [3]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from transformers import (AutoTokenizer, AutoModelForSequenceClassification, 
                          PreTrainedModel, BertModel, BertForSequenceClassification,
                          TrainingArguments, Trainer)
from transformers.modeling_outputs import SequenceClassifierOutput

In [6]:
df = pd.read_csv("/content/data1.csv")
df.head()


Unnamed: 0,id,sentence,F,R,RLC,A,P,S,RPC
0,0,they had married the plaintiff and had a numbe...,1,0,0,0,0,0,0
1,1,one lakshminarayana iyer a hindu brahmin who o...,1,0,0,0,0,0,0
2,2,ramalakshmi had married the plaintiff and had ...,1,0,0,0,0,0,0
3,3,they were all alive in december 1924 when laks...,1,0,0,0,0,0,0
4,4,before his death he executed a will on 16th no...,0,1,0,0,0,0,0


In [7]:
label_cols = [c for c in df.columns if c not in ["id", "sentence"]]
label_cols

['F', 'R', 'RLC', 'A', 'P', 'S', 'RPC']

In [8]:
df["labels"] = df[label_cols].values.tolist()
df.head()

Unnamed: 0,id,sentence,F,R,RLC,A,P,S,RPC,labels
0,0,they had married the plaintiff and had a numbe...,1,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0]"
1,1,one lakshminarayana iyer a hindu brahmin who o...,1,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0]"
2,2,ramalakshmi had married the plaintiff and had ...,1,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0]"
3,3,they were all alive in december 1924 when laks...,1,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0]"
4,4,before his death he executed a will on 16th no...,0,1,0,0,0,0,0,"[0, 1, 0, 0, 0, 0, 0]"


In [9]:
mask = np.random.rand(len(df)) < 0.8
df_train = df[mask]
df_test = df[~mask]

(df_train.shape, df_test.shape)

((6858, 10), (1786, 10))

In [10]:
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [11]:
train_encodings = tokenizer(df_train["sentence"].values.tolist(), truncation=True)
test_encodings = tokenizer(df_test["sentence"].values.tolist(), truncation=True)

In [12]:
train_labels = df_train["labels"].values.tolist()
test_labels = df_test["labels"].values.tolist()

In [13]:
class JigsawDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [14]:
train_dataset = JigsawDataset(train_encodings, train_labels)
test_dataset = JigsawDataset(test_encodings, test_labels)

In [15]:
class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
      super().__init__(config)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), 
                            labels.float().view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

In [16]:
num_labels=7
model = BertForMultilabelSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to('cuda')

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing BertForMultilabelSequenceClassification: ['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transforme

In [17]:
def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True): 
    y_pred = torch.from_numpy(y_pred)
    y_true = torch.from_numpy(y_true)
    if sigmoid: 
      y_pred = y_pred.sigmoid()
    return ((y_pred>thresh)==y_true.bool()).float().mean().item()

In [18]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return {'accuracy_thresh': accuracy_thresh(predictions, labels)}

In [19]:
batch_size = 8
# configure logging so we see training loss
logging_steps = len(train_dataset) // batch_size

args = TrainingArguments(
    output_dir="jigsaw",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps
)

In [20]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer)

In [21]:
trainer.evaluate()

{'eval_loss': 0.7231423854827881,
 'eval_accuracy_thresh': 0.5303151607513428,
 'eval_runtime': 3.9202,
 'eval_samples_per_second': 455.584}

In [22]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy Thresh,Runtime,Samples Per Second
1,0.3563,0.329995,0.857143,3.7545,475.695
2,0.3124,0.298995,0.87506,3.3477,533.499
3,0.2749,0.277626,0.882179,3.3263,536.937


TrainOutput(global_step=2574, training_loss=0.314532766554276, metrics={'train_runtime': 171.116, 'train_samples_per_second': 15.042, 'total_flos': 1114381231062204, 'epoch': 3.0})

In [23]:
save_directory = "/content"
model.save_pretrained(save_directory)