# Statistical Natural Language Processing (COMP0087) Demo Notebook
---
Welcome to the demo notebook for our NLP research project. In this notebook we present an automatic text summarisation technique and demonstrate how it generates brief summaries of lengthy legal documents. We then load three pre-trained models, covering binary classification, multi-class classification and regression, that were trained on these summaries, and apply them to our withheld test set. For the purposes of this demonstration, we focus solely on the Long-T5 model for text summarisation.

In [1]:
# imports and pip installs
!pip install gdown --quiet
!pip install datasets --quiet
!pip install transformers --quiet
!pip install sumy --quiet
from datasets import load_dataset
import pandas as pd
import numpy as np
import time
import gdown
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, LongT5ForConditionalGeneration, AutoModel
from sklearn.metrics import classification_report, mean_absolute_error
from tqdm.notebook import tqdm
from pathlib import Path
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.text_rank import TextRankSummarizer
from sumy.summarizers.reduction import ReductionSummarizer
from sumy.nlp.stemmers import Stemmer
from sumy.utils import get_stop_words
import nltk
nltk.download('punkt')
# led_url = "https://drive.google.com/drive/u/2/folders/1-gj7w6zxUiyHlz43I3gayV-QjN7EjL_c"
# tr_url = "https://drive.google.com/drive/folders/14wCaRAp9wGe97srtzeVxAfnTvL3ynOkF?usp=sharing"
# model1_url = "https://drive.google.com/drive/folders/17Kke3WwGsnY6MGQIHphPekXA0N7V2Gvu?usp=sharing"
url_lt5 = "https://drive.google.com/drive/folders/1okkRzSbEwjAQMgib1tfqAfx3rlmzcqpp?usp=sharing"
model_lt5_url = "https://drive.google.com/drive/folders/1JB1oaKR2ymt-HHMmlCL8sP_DJf8-iK6m?usp=sharing"
gdown.download_folder(url_lt5)
gdown.download_folder(model_lt5_url)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Retrieving folder list


Processing file 1-0HfMsdn4bAVxz_RrCJ-tZc_g7JlCuq0 anon_test.pkl
Processing file 1dlxnn2sPaCnGYoy9CEZ3nPRTywvbScmb anon_train.pkl
Processing file 1-AnGpzloxNUYQNDmNbU4CqtIpboU117t anon_valid.pkl
Processing file 1-1y03q8cvGyRhq0JQob7hZ8Q_VWyIp5b non-anon_test.pkl
Processing file 1Pu1ShrXaw7jfao6povJeHKHD29qB7ffJ non-anon_train.pkl
Processing file 1dR0YI60OKTjSw5R2fXSD41-nMEpTAo3f non-anon_valid.pkl
Building directory structure completed


Retrieving folder list completed
Building directory structure
Downloading...
From: https://drive.google.com/uc?id=1-0HfMsdn4bAVxz_RrCJ-tZc_g7JlCuq0
To: /content/long_t5_summary/anon_test.pkl
100%|██████████| 33.6M/33.6M [00:00<00:00, 54.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1dlxnn2sPaCnGYoy9CEZ3nPRTywvbScmb
To: /content/long_t5_summary/anon_train.pkl
100%|██████████| 103M/103M [00:00<00:00, 235MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1-AnGpzloxNUYQNDmNbU4CqtIpboU117t
To: /content/long_t5_summary/anon_valid.pkl
100%|██████████| 21.2M/21.2M [00:00<00:00, 205MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-1y03q8cvGyRhq0JQob7hZ8Q_VWyIp5b
To: /content/long_t5_summary/non-anon_test.pkl
100%|██████████| 39.4M/39.4M [00:00<00:00, 166MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Pu1ShrXaw7jfao6povJeHKHD29qB7ffJ
To: /content/long_t5_summary/non-anon_train.pkl
100%|██████████| 113M/113M [00:00<00:00, 203MB/s] 
Downloading...
From: ht

Processing file 18LsuCi8z__elW4BBkzSpdz6zUVuIRBMP long_t5_opt_binary_cls.pt
Processing file 155Np_Cw0Jfup44gERNhiEKc4iGr2bdU9 long_t5_opt_multi_cls.pt
Processing file 16IR_6ENYoTJj75yr36yj8cCsNtefW4wQ long_t5_opt_regression.pt
Building directory structure completed


Retrieving folder list completed
Building directory structure
Downloading...
From: https://drive.google.com/uc?id=18LsuCi8z__elW4BBkzSpdz6zUVuIRBMP
To: /content/long_t5_binary_cls/long_t5_opt_binary_cls.pt
100%|██████████| 140M/140M [00:00<00:00, 219MB/s]
Downloading...
From: https://drive.google.com/uc?id=155Np_Cw0Jfup44gERNhiEKc4iGr2bdU9
To: /content/long_t5_binary_cls/long_t5_opt_multi_cls.pt
100%|██████████| 140M/140M [00:00<00:00, 232MB/s]
Downloading...
From: https://drive.google.com/uc?id=16IR_6ENYoTJj75yr36yj8cCsNtefW4wQ
To: /content/long_t5_binary_cls/long_t5_opt_regression.pt
100%|██████████| 140M/140M [00:01<00:00, 109MB/s] 
Download completed


['/content/long_t5_binary_cls/long_t5_opt_binary_cls.pt',
 '/content/long_t5_binary_cls/long_t5_opt_multi_cls.pt',
 '/content/long_t5_binary_cls/long_t5_opt_regression.pt']

## 1. Data
---
In this section of the notebook, we download and explain the ECHR dataset used in this research project.


In [2]:
# ECHR DATASET
def clean_echr(df):
    """
    Clean ECHR dataset.
    Params:
    `df` (pd.DataFrame): dataframe to clean
    """

    # Keep only rows where the language is English
    df = df[df['languageisocode'] == 'ENG']

    # Drop rows where text is empty
    df = df[df['text'].apply(len) > 0]

    # Work on a copy so the column assignments below do not modify a slice
    df = df.copy()

    # Convert array of strings to string in the 'text' column
    df['text'] = df['text'].apply(lambda x: ' '.join(x))

    # Dummy variable for when conclusion = 'Inadmissible'
    df['inadmissible'] = df['conclusion'] == 'Inadmissible'

    # Create column for articles raised
    df['all_articles'] = df['violated_articles'] + df['non_violated_articles'] 

    # Keep only columns of interest
    df = df[['itemid', 'text', 'violated_articles', 'violated', 'non_violated_articles', 'all_articles', 'importance', 'inadmissible', 'date', 'docname']]

    return df

def download_echr(name):
    """
    Download and clean ECHR dataset from Hugging Face datasets library.
    Params:
    `name` (str): name of dataset to download, either 'anon' or 'non-anon'
    """

    # Check if name is valid
    if name not in ['anon', 'non-anon']:
        raise ValueError("Name must be either 'anon' or 'non-anon'")

    # Download dataset
    data = load_dataset(path = "jonathanli/echr", name = name)

    # Convert test, train and validation to dataframes
    test_df = pd.DataFrame(data['test'])
    train_df = pd.DataFrame(data['train'])
    valid_df = pd.DataFrame(data['validation'])

    # Clean each dataframe
    test_df = clean_echr(test_df)
    train_df = clean_echr(train_df)
    valid_df = clean_echr(valid_df)

    # Make data and echr folders if necessary
    os.makedirs('data/echr', exist_ok=True)

    # Save each dataframe to pickle file
    test_df.to_pickle(f'data/echr/{name}_test.pkl')
    train_df.to_pickle(f'data/echr/{name}_train.pkl')
    valid_df.to_pickle(f'data/echr/{name}_valid.pkl')

# Download and clean ECHR datasets
download_echr('non-anon')




### The ECHR Dataset:

The ECHR (European Court of Human Rights) dataset is a corpus of roughly 11.5k court cases sourced from the ECHR public database, with each case containing the relevant facts leading to the judgement. Each datapoint consists of the text of a legal case, a binary variable indicating whether any article of the European Convention on Human Rights was violated, a multi-class variable indicating which articles were violated, and an importance rating from 1 to 4 assigned to the case by the ECHR.


In [3]:
# read the data
df = pd.read_pickle("data/echr/non-anon_train.pkl")
df.head()

|   | itemid | text | violated_articles | violated | non_violated_articles | all_articles | importance | inadmissible | date | docname |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 001-60714 | The applicant was born in 1943 and lives in La... | [6] | True | [] | [6] | 4 | False | 2002 | CASE OF PIETILAINEN v. FINLAND |
| 1 | 001-100920 | The applicant, Mr Panayiotis Panayi, is a Cypr... | [] | False | [] | [] | 4 | True | 2010 | PANAYI v. CYPRUS |
| 2 | 001-77249 | The Salvation Army worked officially in Russia... | [11, 9] | True | [] | [11, 9] | 1 | False | 2006 | CASE OF THE MOSCOW BRANCH OF THE SALVATION ARM... |
| 3 | 001-4589 | The applicant is a British national, born in 1... | [] | False | [] | [] | 4 | True | 1999 | A.J. v. THE UNITED KINGDOM |
| 4 | 001-83374 | The applicant was born in 1967 and lives in Li... | [5] | True | [] | [5] | 3 | False | 2007 | CASE OF GAULT v. THE UNITED KINGDOM |


### Example Legal Case:

This is an example of the raw legal case text. The dataset provides a list of facts extracted from the case description using regular expressions.

----


The applicant was born in 1943 and lives in Laukaa. On 5 January 1987 criminal investigations were instituted against the applicant who was taken into police custody the same day in respect of, inter alia, alleged tax frauds. He was released on 16 January 1987. On 5 July and 31 August 1990 the applicant was summoned to appear before the Helsinki City Court (raastuvanoikeus, rådstuvurätt, as from 1 December 1993 Helsinki District Court, käräjäoikeus, tingsrätt) indicted for several aggravated tax frauds. The alleged offences concerned the importation of parts of vehicles and failure to pay relevant tax for them ...

## 2. Automatic Text Summarisation
---
In this section we showcase one of the seven automatic text summarisation methods that were used in our research project - Long-T5 - and demonstrate how it can be used to generate accurate and concise summaries of long legal case information. Long-T5 is an abstractive transformer model that generates summaries by predicting the most likely words to appear in a summary, given the input text.
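To make the idea of "predicting the most likely words" concrete, the toy sketch below runs greedy decoding over a hand-written table of next-word scores. The table, the `greedy_decode` helper and the `<start>`/`<end>` markers are invented for illustration only; Long-T5 computes such scores with a neural network conditioned on the full input text, and the decoding strategy used by `model.generate` depends on its configuration.

```python
# Toy next-word scores, invented purely for illustration. A real model such as
# Long-T5 computes these scores from the input document and the words so far.
NEXT_WORD_SCORES = {
    "<start>": {"the": 0.6, "applicant": 0.3, "<end>": 0.1},
    "the": {"applicant": 0.7, "court": 0.2, "<end>": 0.1},
    "applicant": {"complained": 0.5, "the": 0.1, "<end>": 0.4},
    "complained": {"<end>": 0.9, "the": 0.1},
    "court": {"<end>": 1.0},
}


def greedy_decode(scores, max_length=10):
    """Repeatedly append the highest-scoring next word until <end> or max_length."""
    words = []
    previous = "<start>"
    for _ in range(max_length):
        candidates = scores[previous]
        # Pick the candidate word with the highest score
        next_word = max(candidates, key=candidates.get)
        if next_word == "<end>":
            break
        words.append(next_word)
        previous = next_word
    return words


summary_words = greedy_decode(NEXT_WORD_SCORES)
print(" ".join(summary_words))
```

Because each word is generated rather than copied, an abstractive model can in principle produce sentences that do not appear verbatim in the source text.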




### 2.1. Abstractive Text Summarisation Example (Long-T5)

For demonstration purposes we will produce summaries for the first 10 case texts in the non-anonymised ECHR training set.
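The texts are processed in small batches to keep memory use manageable. As a rough sketch of that batching pattern, the hypothetical helper `iter_batches` below (not part of the project code) splits a list into consecutive batches and keeps track of where each batch starts, which is what is needed to write the summaries back to the right rows.

```python
def iter_batches(items, batch_size):
    """Yield (start_index, batch) pairs that cover `items` in order."""
    for start in range(0, len(items), batch_size):
        yield start, items[start:start + batch_size]


# Ten placeholder texts stand in for the first ten case texts
texts = [f"case text {i}" for i in range(10)]
batches = list(iter_batches(texts, batch_size=2))

for start, batch in batches:
    print(start, batch)
```

With ten texts and a batch size of 2 this gives five batches, matching the five iterations used in this notebook.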

In [4]:
# Function for summarisation
def summarise(texts, tokenizer, model, device, min_summary_length, max_summary_length):
    """
    Summarise a batch of texts.
    Args:
        texts (list): list of texts to be summarised.
        tokenizer (transformers.PreTrainedTokenizerBase): tokenizer for the summarisation model.
        model (transformers.LongT5ForConditionalGeneration): model for summarisation.
        device (str): device to run the model on. Either 'cpu' or 'cuda'.
        min_summary_length (int): minimum length of the summary.
        max_summary_length (int): maximum length of the summary.
    Returns:
        batch_summaries (list): batch of summaries.
    """
    # tokenize the batch of texts
    inputs = tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True, max_length=10000).to(device)

    # generate summaries for the batch of texts
    summary_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=max_summary_length, min_length=min_summary_length)

    # decode the summary ids back into text
    batch_summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return batch_summaries

# move to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# longT5 model and tokenizer
model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base").to(device)
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

if "summary" not in df.columns:
    df["summary"] = ""

batch_size = 2
for j in range(5):
    start_idx = j * batch_size
    end_idx = (j + 1) * batch_size
    texts = df['text'].iloc[start_idx:end_idx].tolist()
    batch_summaries = summarise(texts, tokenizer, model, device, 512, 512)

    for k, summary in enumerate(batch_summaries):
        # Assign with .loc to avoid chained assignment on a copy
        df.loc[df.index[start_idx + k], "summary"] = summary

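pandas can emit a `SettingWithCopyWarning` when a value is written through chained indexing of the form `df["summary"][i] = ...`, because the write may land on a temporary copy rather than on the dataframe itself. A minimal sketch of the usual alternative, a single `.loc` assignment, is shown below on a small made-up dataframe; the index labels and texts are placeholders.

```python
import pandas as pd

# Small made-up dataframe whose index labels differ from row positions
df_demo = pd.DataFrame({"text": ["case a", "case b", "case c"]}, index=[10, 11, 12])
df_demo["summary"] = ""

new_summaries = ["summary a", "summary b"]
for position, summary in enumerate(new_summaries):
    # Translate the row position into its index label, then assign in one step
    row_label = df_demo.index[position]
    df_demo.loc[row_label, "summary"] = summary

print(df_demo)
```

Looking up the label with `df_demo.index[position]` keeps the assignment aligned with positional slicing even when the index is not a simple 0, 1, 2, ... range.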


### Example Summarised Legal Case:

This is the summary generated for the first legal case in the dataset.

----


In [5]:
df['summary'][0]

'On 5 July and 31 August 1990 the applicant was summoned to appear before the Helsinki City Court (raastuvanoikeus, rdstuvurätt, as from 1 December 1993 Helsinki District Court, käräjäoikeus, tingsrätt) indicted for several aggravated tax frauds. Furthermore, Chapter 16, Section 5, of the Code of Judicial Procedure provided: “When it is important to wait for a decision of another tribunal or some other body before a decision is given in a pending case, or when some other long-lasting impediment exists, a court may order that the hearing of the case will not be pursued until that obstacle ceases to exist.” According to Chapter 14, Section 7a (19.4.1991/708), of the Code of Judicial Procedure, which came into force on 1 April 1992, charges against defendants accused of committing the same offence must, in principle, be tried together. According to Chapter 14, Section 7a (19.4.1991/708), of the Code of Judicial Procedure, which came into force on 1 April 1992, charges against defendants a

## 3. Inference
---
In this section we load our pre-trained and optimised models for binary classification, multi-class classification and regression, each fine-tuned on the summaries produced by Long-T5 for our training set. The binary classification model takes the summary of a legal case as input and predicts whether any article was violated. The multi-class classification model predicts which articles were violated, and the regression model predicts the importance rating of the case. We evaluate each model on a withheld and unseen test set and report the resulting metrics. This section highlights the practical use of our NLP techniques for automated legal analysis and how they could provide rapid insights to legal practitioners.
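To help interpret the classification outputs, the sketch below shows one way to turn per-article probabilities into a list of predicted articles. The `decode_multi_label` helper and the example probabilities are illustrative assumptions; the article ordering copies the `article_dict` label encoding used in this notebook, and the 0.5 threshold approximates the rounding applied to the sigmoid outputs during testing.

```python
# Article names in the same order as the notebook's `article_dict` encoding
ARTICLE_NAMES = [
    '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14',
    '18', '25', '34', '38', '46', 'P1', 'P4', 'P6', 'P7', 'P12',
]


def decode_multi_label(probabilities, threshold=0.5):
    """Return the articles whose predicted probability exceeds the threshold."""
    return [name for name, p in zip(ARTICLE_NAMES, probabilities) if p > threshold]


# Made-up probabilities for one case: high for Article 6 and Article 13
probabilities = [0.05] * len(ARTICLE_NAMES)
probabilities[4] = 0.91
probabilities[11] = 0.62

predicted_articles = decode_multi_label(probabilities)
print(predicted_articles)

# For the binary model, a single probability is thresholded in the same way
binary_probability = 0.73
predicted_violation = binary_probability > 0.5
print(predicted_violation)
```

The class indices shown in the multi-class classification report correspond to positions in this list, so index 4 refers to Article 6 and index 11 to Article 13.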

In [6]:
# cell for all the required functions

class LegalBertBinaryCls(nn.Module):
    """
    legal-bert-small-uncased with a binary classification head.
    """
    def __init__(self, legalbert):
        super().__init__()
        self.bert = legalbert
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(self.bert.pooler.dense.out_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        # With return_dict=False, BertModel returns (last_hidden_state, pooler_output)
        _, pooled_output = self.bert(input_ids, attention_mask, return_dict=False)
        # Apply dropout to the pooled output before the linear head
        outputs = self.dropout(pooled_output)
        outputs = self.linear(outputs)
        preds = self.sigmoid(outputs)

        return preds
    
class LegalBertMultiCls(nn.Module):
    """
    legal-bert-small-uncased with a multi-class classification head.
    """
    def __init__(self, legalbert, num_classes=23):
        super().__init__()
        self.bert = legalbert
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(self.bert.pooler.dense.out_features, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        # With return_dict=False, BertModel returns (last_hidden_state, pooler_output)
        _, pooled_output = self.bert(input_ids, attention_mask, return_dict=False)
        # Apply dropout to the pooled output before the linear head
        outputs = self.dropout(pooled_output)
        outputs = self.linear(outputs)
        preds = self.sigmoid(outputs)

        return preds
    
class LegalBertRegression(nn.Module):
    """
    legal-bert-small-uncased with a regression head.
    """
    def __init__(self, legalbert):
        super().__init__()
        self.bert = legalbert
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(self.bert.pooler.dense.out_features, 1)

    def forward(self, input_ids, attention_mask):
        # With return_dict=False, BertModel returns (last_hidden_state, pooler_output)
        _, pooled_output = self.bert(input_ids, attention_mask, return_dict=False)
        # Apply dropout to the pooled output before the linear head
        outputs = self.dropout(pooled_output)
        outputs = self.linear(outputs)

        return outputs

def get_model(task, model):
    """
    Select the model based on the task.
    Params:
    `task` (str): the task to be performed
    `model` (torch.nn.Module): the base model
    """
    if task == "binary_cls":
        model = LegalBertBinaryCls(model)
    elif task == "multi_cls":
        model = LegalBertMultiCls(model)
    elif task == "regression":
        model = LegalBertRegression(model)
    else:
        raise ValueError(f"Unknown task: {task}")

    return model

def get_loss_func(task):
    """
    Get the loss function based on the task.
    Params:
    `task` (str): the task to be performed
    """
    if task == "binary_cls" or task == "multi_cls":
        loss_func = nn.BCELoss()
    elif task == "regression":
        loss_func = nn.L1Loss()
    else:
        raise ValueError(f"Unknown task: {task}")

    return loss_func

# label encoding
article_dict = {'2': 0, '3': 1, '4': 2, '5': 3, '6': 4, '7': 5, '8': 6, '9': 7, '10': 8, '11': 9, '12': 10, '13': 11, '14': 12, '18': 13, '25': 14, '34': 15, '38': 16, '46': 17, 'P1': 18, 'P4': 19, 'P6': 20, 'P7': 21, 'P12': 22}

def get_one_hot_labels(df, article_dict):
    """
    Create a list of lists with a one-hot encoding of the articles violated for each row in the input dataframe
    Params:
    `df` (pd.DataFrame): dataframe containing the violated articles
    `article_dict` (dict): dictionary mapping article names to indices
    """
    # Use the mapping passed in as `article_dict` rather than redefining it here
    labels = []

    for articles in df["violated_articles"]:
        label = [int(key in articles) for key in article_dict.keys()]
        labels.append(label)

    return labels

def load_data(folder="echr", task="binary_cls", anon=False):
    """
    Load data from pickle files and return train, validation and test sets.
    Params:
    `folder` (str): folder containing the data
    `task` (str): task to perform, either 'binary_cls', 'multi_cls' or 'regression'
    `anon` (bool): whether to load the anonymised data or not
    """
    # All splits are read from `folder`; only the filename prefix depends on `anon`
    prefix = "anon" if anon else "non-anon"
    train_df = pd.read_pickle(f"{folder}/{prefix}_train.pkl")
    val_df = pd.read_pickle(f"{folder}/{prefix}_valid.pkl")
    test_df = pd.read_pickle(f"{folder}/{prefix}_test.pkl")

    if folder == "echr":
        text_column = "text"
    else:
        text_column = "summary"

    train_texts, val_texts, test_texts = train_df[text_column].tolist(), val_df[text_column].tolist(), test_df[text_column].tolist()

    if task == "binary_cls":
        train_labels = train_df["violated"].astype(int).tolist()
        val_labels = val_df["violated"].astype(int).tolist()
        test_labels = test_df["violated"].astype(int).tolist()
    elif task == "multi_cls":
        train_labels = get_one_hot_labels(train_df, article_dict)
        val_labels = get_one_hot_labels(val_df, article_dict)
        test_labels = get_one_hot_labels(test_df, article_dict)
    elif task == "regression":
        train_labels = train_df["importance"].astype(int).tolist()
        val_labels = val_df["importance"].astype(int).tolist()
        test_labels = test_df["importance"].astype(int).tolist()

    return train_texts, train_labels, val_texts, val_labels, test_texts, test_labels

def generate_tokens(tokenizer, texts, max_length=512):
    """
    Tokenize the input texts.
    Params:
    `tokenizer` (transformers.PreTrainedTokenizer): tokenizer to use
    `texts` (list): list of texts to tokenize
    `max_length` (int): maximum length of the tokenized texts
    """
    # padding="max_length" already pads every text to `max_length`,
    # so the deprecated `pad_to_max_length` flag is not needed
    tokens = tokenizer.batch_encode_plus(texts,
        return_tensors = "pt",
        padding = "max_length",
        truncation = True,
        max_length = max_length,
        return_token_type_ids = False
    )

    return tokens

def create_dataloader(tokens, labels, batch_size, type):
    """
    Create a dataloader for the input data.
    Params:
    `tokens` (torch.Tensor): tensor containing the tokenized texts
    `labels` (torch.Tensor): tensor containing the labels
    `batch_size` (int): batch size
    `type` (str): type of dataloader to create, either 'train', 'val' or 'test'
    """
    if not isinstance(labels[0], list):
        labels = torch.tensor(labels).unsqueeze(1)
    else:
        labels = torch.tensor(labels)

    data = TensorDataset(tokens.input_ids, tokens.attention_mask, labels)

    if type == "train":
        sampler = RandomSampler(data)
        dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
    elif type == "val":
        sampler = SequentialSampler(data)
        dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
    elif type == "test":
        dataloader = DataLoader(data, batch_size=batch_size)

    return dataloader

def test(test_loader, model, task, model_name):
    """
    Test the inference model on the test set.
    Params:
    `test_loader` (torch.utils.data.DataLoader): dataloader for the test set
    `model` (torch.nn.Module): the model to be tested
    `task` (str): the task to be performed
    `model_name` (str): the name of the model
    """
    start = time.time()
    # load trained model (map_location lets the checkpoint load on CPU-only machines)
    model.load_state_dict(torch.load(model_name, map_location=device))

    all_preds = torch.tensor([])
    all_labels = torch.tensor([])
    running_loss = 0

    loss_func = get_loss_func(task)
    model.eval()

    # iterate over batches
    for i, batch in enumerate(test_loader):
        # progress update after every 100 batches.
        if i % 100 == 0:
            print("--> batch {:} of {:}.".format(i, len(test_loader)))
        # push the batch to gpu
        batch = [r.to(device) for r in batch]
        input_ids, attention_mask, labels = batch
        with torch.no_grad():
            # forward pass  
            preds = model(input_ids, attention_mask)
            preds, labels = preds.type(torch.FloatTensor), labels.type(torch.FloatTensor)
            # compute the loss between actual and predicted values
            loss = loss_func(preds, labels)
            # add on to the total loss
            running_loss += loss.item()
            if task == "binary_cls" or task == "multi_cls":
                preds = torch.round(preds)
            all_preds = torch.cat((all_preds, preds), dim=0)
            all_labels = torch.cat((all_labels, labels), dim=0)

    running_loss = running_loss / len(test_loader)
    print(f"----> test loss {running_loss}")
    print(f"----> time taken {time.time()-start}")

    return all_preds, all_labels, running_loss

In [7]:
# use gpu!
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

cuda


In [8]:
if __name__ == "__main__":
    print("=========================")
    print("script starts...")

    folder = "long_t5_summary"
    tasks = ["binary_cls", "multi_cls", "regression"]
    model_names = [
        "/content/long_t5_binary_cls/long_t5_opt_binary_cls.pt",
        "/content/long_t5_binary_cls/long_t5_opt_multi_cls.pt",
        "/content/long_t5_binary_cls/long_t5_opt_regression.pt",
    ]
    max_seq_length = 512
    batch_size = 4
    pretrained_model = "nlpaueb/legal-bert-small-uncased"

    print(f"using {device}")
    print(f"folder: {folder}")
    print(f"We will be evaluating the following tasks: {tasks}")

    # loop through each task
    for i, task in enumerate(tasks):
        print("task: ", task)
        # load data
        train_texts, train_labels, val_texts, val_labels, test_texts, test_labels = load_data(folder, task)
        # load model
        model = AutoModel.from_pretrained(pretrained_model, return_dict=False)
        # adapt model
        model = get_model(task, model)
        # move model to gpu
        model.to(device)
        # tokenizer
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        # generate tokens
        test_tokens = generate_tokens(tokenizer, test_texts, max_seq_length)
        # dataloader
        test_loader = create_dataloader(test_tokens, test_labels, batch_size, type="test")
        # run test script
        test_preds, test_labels, test_loss = test(test_loader, model, task, model_names[i])
        # print results
        if task == "binary_cls" or task == "multi_cls":
            report = classification_report(test_labels, test_preds)
            print(f'{task} classification report:')
            print(report)
        elif task == "regression":
            mae = mean_absolute_error(test_labels, test_preds)
            print(f'{task} mean absolute error:')
            print(mae)

    print("script finishes")
    print("=========================")

script starts...
using cuda
folder: long_t5_summary
We will be evaluating the following tasks: ['binary_cls', 'multi_cls', 'regression']
task:  binary_cls


Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--> batch 0 of 750.
--> batch 100 of 750.
--> batch 200 of 750.
--> batch 300 of 750.
--> batch 400 of 750.
--> batch 500 of 750.
--> batch 600 of 750.
--> batch 700 of 750.
----> test loss 0.8703353334168593
----> time taken 25.854230880737305
binary_cls classification report:
              precision    recall  f1-score   support

         0.0       0.40      0.58      0.48      1024
         1.0       0.72      0.55      0.62      1974

    accuracy                           0.56      2998
   macro avg       0.56      0.57      0.55      2998
weighted avg       0.61      0.56      0.57      2998

task:  multi_cls




--> batch 0 of 750.
--> batch 100 of 750.
--> batch 200 of 750.
--> batch 300 of 750.
--> batch 400 of 750.
--> batch 500 of 750.
--> batch 600 of 750.
--> batch 700 of 750.
----> test loss 0.24676336768393714
----> time taken 26.093497276306152
multi_cls classification report:
              precision    recall  f1-score   support

           0       0.47      0.48      0.48       118
           1       0.60      0.43      0.50       524
           2       0.00      0.00      0.00         3
           3       0.68      0.43      0.53       383
           4       0.54      0.51      0.53       726
           5       0.00      0.00      0.00         5
           6       0.34      0.20      0.26       215
           7       0.31      0.24      0.27        17
           8       0.54      0.24      0.33       105
           9       0.57      0.30      0.39        54
          10       0.00      0.00      0.00         1
          11       0.39      0.17      0.24       322
          12      


--> batch 0 of 750.
--> batch 100 of 750.
--> batch 200 of 750.
--> batch 300 of 750.
--> batch 400 of 750.
--> batch 500 of 750.
--> batch 600 of 750.
--> batch 700 of 750.
----> test loss 0.5627017760276795
----> time taken 26.297304391860962
regression mean absolute error:
0.56305516
script finishes
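As a worked check on how the summary rows of a classification report relate to the per-class rows, the sketch below recomputes the macro and weighted averages from the two per-class rows of the `binary_cls` report printed above. The helper functions are hypothetical; because the inputs are already rounded to two decimals, a recomputed value can occasionally differ from the report in the last digit.

```python
# Per-class rows copied from the binary_cls classification report above
rows = {
    "0.0": {"precision": 0.40, "recall": 0.58, "f1": 0.48, "support": 1024},
    "1.0": {"precision": 0.72, "recall": 0.55, "f1": 0.62, "support": 1974},
}


def macro_average(rows, metric):
    """Unweighted mean of a metric over classes."""
    return sum(row[metric] for row in rows.values()) / len(rows)


def weighted_average(rows, metric):
    """Mean of a metric over classes, weighted by the support of each class."""
    total_support = sum(row["support"] for row in rows.values())
    return sum(row[metric] * row["support"] for row in rows.values()) / total_support


for metric in ("precision", "recall", "f1"):
    print(metric,
          "macro:", round(macro_average(rows, metric), 2),
          "weighted:", round(weighted_average(rows, metric), 2))
```

The weighted average gives more influence to the larger class (1974 cases with a violation against 1024 without), which is why it differs from the macro average.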


All code for this project can be found in our public GitHub repository (https://github.com/rorycreedon/comp0087_assignment.git).