In [8]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'..')))
from utils import load_tensor, save_model
from models import BiLSTM, CNN_BiLSTM
from evaluation_functions import class_accuracy, class_f1_score

In [9]:
def remap_targets(target_tensor: torch.TensorType, old_le, new_le):
    inverse_tensor = old_le.inverse_transform(target_tensor.long())
    for idx, label in enumerate(inverse_tensor):
        if label == 'ARG_RESPONDENT' or label == 'ARG_PETITIONER':
            inverse_tensor[idx] = "ARG"
        elif label == 'PRE_NOT_RELIED' or label == 'PRE_RELIED':
            inverse_tensor[idx] = 'PRE'
    new_tensor = torch.tensor(new_le.transform(inverse_tensor))
    return new_tensor

In [10]:
list_of_targets_old = ['ISSUE', 'FAC', 'NONE', 'ARG_PETITIONER', 'PRE_NOT_RELIED', 'STA', 'RPC', 'ARG_RESPONDENT', 'PREAMBLE', 'ANALYSIS', 'RLC', 'PRE_RELIED', 'RATIO']
label_encoder_old = LabelEncoder().fit(list_of_targets_old)
list_of_targets_new = ['ISSUE', 'FAC', 'NONE', 'ARG', 'PRE', 'STA', 'RPC', 'PREAMBLE', 'ANALYSIS', 'RLC', 'RATIO']
label_encoder_new = LabelEncoder().fit(list_of_targets_new)

In [11]:
sample_input, sample_target = None, None
for idx in range(246):
    if sample_input is None:
        sample_input = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        sample_target = load_tensor(filepath=f"../train_document/doc_{idx}/label")
    else:
        sample_input = torch.cat((sample_input,load_tensor(filepath=f"../train_document/doc_{idx}/embedding")), dim=0)
        sample_target = torch.cat((sample_target,load_tensor(filepath=f"../train_document/doc_{idx}/label")), dim=0)

In [12]:
sample_target.size()

torch.Size([28864])

In [13]:
remapped_target = remap_targets(sample_target, label_encoder_old, label_encoder_new)

In [14]:
from sklearn.utils.class_weight import compute_class_weight
    
class_weights = compute_class_weight(class_weight = "balanced",
                                    classes = np.unique(remapped_target.numpy()),
                                    y = remapped_target.numpy())
class_weights = torch.FloatTensor(class_weights)
class_weights

tensor([0.2460, 1.3107, 0.4584, 7.1499, 1.8544, 1.6639, 0.6347, 3.8990, 3.5033,
        2.4341, 5.4895])

In [17]:
def calculate_confusion_matrix(test_emb, test_labels, model, num_labels = 11):
    """

    Parameters:

    Returns:
    """
    model.eval()
    output = model(test_emb)
    return confusion_matrix(output, test_labels, num_labels)

def confusion_matrix(y_pred, y_true, num_classes):
    """
    Create a confusion matrix for label encodings in PyTorch.

    Parameters:
    y_pred (torch.Tensor): Predicted labels tensor.
    y_true (torch.Tensor): True labels tensor.
    num_classes (int): Number of classes.

    Returns:
    numpy.ndarray: Confusion matrix.
    """ 
    conf_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    y_pred_np = y_pred.argmax(dim=1).cpu().numpy()
    y_true_np = y_true.cpu().numpy()

    for pred, true in zip(y_pred_np, y_true_np):
        conf_matrix[pred, true] += 1

    return conf_matrix

# BERT-Base

## BiLSTM

In [19]:
model1 = BiLSTM(hidden_size=128, dropout= 0.25, output_size= 11)
optimizer = torch.optim.Adam(model1.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model1.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model1(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.03:
        break
# batch_loss.append(loss.item())


-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:14, 16.83it/s]

100%|██████████| 246/246 [00:11<00:00, 22.25it/s]


Epoch: 1 	 Loss: 1.41150


100%|██████████| 246/246 [00:08<00:00, 27.47it/s]


Epoch: 2 	 Loss: 0.91454


100%|██████████| 246/246 [00:07<00:00, 33.86it/s]


Epoch: 3 	 Loss: 0.77575


100%|██████████| 246/246 [00:07<00:00, 34.66it/s]


Epoch: 4 	 Loss: 0.69199


100%|██████████| 246/246 [00:07<00:00, 33.37it/s]


Epoch: 5 	 Loss: 0.63587


100%|██████████| 246/246 [00:07<00:00, 34.97it/s]


Epoch: 6 	 Loss: 0.56841


100%|██████████| 246/246 [00:07<00:00, 32.49it/s]


Epoch: 7 	 Loss: 0.51521


100%|██████████| 246/246 [00:07<00:00, 34.95it/s]


Epoch: 8 	 Loss: 0.48346


100%|██████████| 246/246 [00:06<00:00, 35.24it/s]


Epoch: 9 	 Loss: 0.42910


100%|██████████| 246/246 [00:08<00:00, 27.94it/s]


Epoch: 10 	 Loss: 0.39686


100%|██████████| 246/246 [00:07<00:00, 30.97it/s]


Epoch: 11 	 Loss: 0.36762


100%|██████████| 246/246 [00:07<00:00, 31.83it/s]


Epoch: 12 	 Loss: 0.33663


100%|██████████| 246/246 [00:07<00:00, 32.47it/s]


Epoch: 13 	 Loss: 0.31116


100%|██████████| 246/246 [00:07<00:00, 31.76it/s]


Epoch: 14 	 Loss: 0.29202


100%|██████████| 246/246 [00:08<00:00, 30.60it/s]


Epoch: 15 	 Loss: 0.26526


100%|██████████| 246/246 [00:07<00:00, 31.77it/s]


Epoch: 16 	 Loss: 0.27301


100%|██████████| 246/246 [00:07<00:00, 32.70it/s]


Epoch: 17 	 Loss: 0.24896


100%|██████████| 246/246 [00:07<00:00, 31.42it/s]


Epoch: 18 	 Loss: 0.23896


100%|██████████| 246/246 [00:07<00:00, 32.74it/s]


Epoch: 19 	 Loss: 0.21332


100%|██████████| 246/246 [00:08<00:00, 30.22it/s]


Epoch: 20 	 Loss: 0.21086


100%|██████████| 246/246 [00:07<00:00, 32.45it/s]


Epoch: 21 	 Loss: 0.20686


100%|██████████| 246/246 [00:07<00:00, 32.37it/s]


Epoch: 22 	 Loss: 0.19512


100%|██████████| 246/246 [00:07<00:00, 32.95it/s]


Epoch: 23 	 Loss: 0.17104


100%|██████████| 246/246 [00:07<00:00, 32.14it/s]


Epoch: 24 	 Loss: 0.15712


100%|██████████| 246/246 [00:07<00:00, 31.24it/s]


Epoch: 25 	 Loss: 0.14813


100%|██████████| 246/246 [00:07<00:00, 32.98it/s]


Epoch: 26 	 Loss: 0.13339


100%|██████████| 246/246 [00:09<00:00, 27.33it/s]


Epoch: 27 	 Loss: 0.12302


100%|██████████| 246/246 [00:08<00:00, 29.52it/s]


Epoch: 28 	 Loss: 0.11978


100%|██████████| 246/246 [00:07<00:00, 33.02it/s]


Epoch: 29 	 Loss: 0.11987


100%|██████████| 246/246 [00:07<00:00, 31.09it/s]


Epoch: 30 	 Loss: 0.12943


100%|██████████| 246/246 [00:07<00:00, 31.35it/s]


Epoch: 31 	 Loss: 0.11539


100%|██████████| 246/246 [00:07<00:00, 32.42it/s]


Epoch: 32 	 Loss: 0.10335


100%|██████████| 246/246 [00:08<00:00, 29.97it/s]


Epoch: 33 	 Loss: 0.09870


100%|██████████| 246/246 [00:07<00:00, 32.64it/s]


Epoch: 34 	 Loss: 0.11609


100%|██████████| 246/246 [00:08<00:00, 30.70it/s]


Epoch: 35 	 Loss: 0.11285


100%|██████████| 246/246 [00:07<00:00, 32.66it/s]


Epoch: 36 	 Loss: 0.08564


100%|██████████| 246/246 [00:07<00:00, 31.61it/s]


Epoch: 37 	 Loss: 0.08279


100%|██████████| 246/246 [00:07<00:00, 32.59it/s]


Epoch: 38 	 Loss: 0.07980


100%|██████████| 246/246 [00:07<00:00, 31.91it/s]


Epoch: 39 	 Loss: 0.09614


100%|██████████| 246/246 [00:08<00:00, 30.40it/s]


Epoch: 40 	 Loss: 0.11100


100%|██████████| 246/246 [00:07<00:00, 32.51it/s]


Epoch: 41 	 Loss: 0.09006


100%|██████████| 246/246 [00:07<00:00, 31.30it/s]


Epoch: 42 	 Loss: 0.06765


100%|██████████| 246/246 [00:08<00:00, 30.60it/s]


Epoch: 43 	 Loss: 0.05784


100%|██████████| 246/246 [00:07<00:00, 32.76it/s]


Epoch: 44 	 Loss: 0.05760


100%|██████████| 246/246 [00:08<00:00, 30.38it/s]


Epoch: 45 	 Loss: 0.06051


100%|██████████| 246/246 [00:07<00:00, 32.12it/s]


Epoch: 46 	 Loss: 0.07958


100%|██████████| 246/246 [00:07<00:00, 32.11it/s]


Epoch: 47 	 Loss: 0.08196


100%|██████████| 246/246 [00:08<00:00, 30.65it/s]


Epoch: 48 	 Loss: 0.05529


100%|██████████| 246/246 [00:07<00:00, 32.96it/s]


Epoch: 49 	 Loss: 0.05809


100%|██████████| 246/246 [00:07<00:00, 31.71it/s]


Epoch: 50 	 Loss: 0.08483


100%|██████████| 246/246 [00:07<00:00, 31.88it/s]


Epoch: 51 	 Loss: 0.05141


100%|██████████| 246/246 [00:07<00:00, 32.26it/s]


Epoch: 52 	 Loss: 0.05096


100%|██████████| 246/246 [00:07<00:00, 31.66it/s]


Epoch: 53 	 Loss: 0.04988


100%|██████████| 246/246 [00:07<00:00, 30.85it/s]


Epoch: 54 	 Loss: 0.03855


100%|██████████| 246/246 [00:07<00:00, 32.32it/s]


Epoch: 55 	 Loss: 0.05664


100%|██████████| 246/246 [00:07<00:00, 33.12it/s]


Epoch: 56 	 Loss: 0.05706


100%|██████████| 246/246 [00:07<00:00, 31.27it/s]


Epoch: 57 	 Loss: 0.04339


100%|██████████| 246/246 [00:08<00:00, 29.62it/s]


Epoch: 58 	 Loss: 0.04975


100%|██████████| 246/246 [00:07<00:00, 31.36it/s]


Epoch: 59 	 Loss: 0.03555


100%|██████████| 246/246 [00:07<00:00, 31.59it/s]


Epoch: 60 	 Loss: 0.04113


100%|██████████| 246/246 [00:07<00:00, 31.63it/s]


Epoch: 61 	 Loss: 0.04124


100%|██████████| 246/246 [00:09<00:00, 26.94it/s]


Epoch: 62 	 Loss: 0.05722


100%|██████████| 246/246 [00:08<00:00, 30.70it/s]


Epoch: 63 	 Loss: 0.05300


100%|██████████| 246/246 [00:07<00:00, 32.00it/s]


Epoch: 64 	 Loss: 0.04585


100%|██████████| 246/246 [00:07<00:00, 33.19it/s]

Epoch: 65 	 Loss: 0.02967





In [20]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model1, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))
# result.append((config, (average_accuracy, average_f1)))

Accuracies: [0.81427072 0.41414141 0.84507042 0.80392157 0.93157895 0.49206349
 0.99603175 0.57377049 0.67857143 0.84946237 0.53125   ] 
 Average acccuracy: 0.7209211455239293
F1 Scores: [0.8049792  0.40394084 0.83696595 0.81188114 0.9315789  0.61386134
 0.9960317  0.52631574 0.44186042 0.87292813 0.557377  ] 
 Average F1: 0.7088836671511293


## CNN-BiLSTM

In [21]:
model2 = CNN_BiLSTM(hidden_size=128, dropout= 0.25, output_size= 11)
optimizer = torch.optim.Adam(model2.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model2.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model2(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.03:
        break
# batch_loss.append(loss.item())

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:13, 18.56it/s]

100%|██████████| 246/246 [00:08<00:00, 27.69it/s]


Epoch: 1 	 Loss: 1.57144


100%|██████████| 246/246 [00:08<00:00, 30.27it/s]


Epoch: 2 	 Loss: 1.03683


100%|██████████| 246/246 [00:08<00:00, 29.91it/s]


Epoch: 3 	 Loss: 0.88318


100%|██████████| 246/246 [00:08<00:00, 28.86it/s]


Epoch: 4 	 Loss: 0.80489


100%|██████████| 246/246 [00:08<00:00, 28.64it/s]


Epoch: 5 	 Loss: 0.72276


100%|██████████| 246/246 [00:08<00:00, 27.90it/s]


Epoch: 6 	 Loss: 0.66418


100%|██████████| 246/246 [00:08<00:00, 29.36it/s]


Epoch: 7 	 Loss: 0.61806


100%|██████████| 246/246 [00:09<00:00, 26.97it/s]


Epoch: 8 	 Loss: 0.58899


100%|██████████| 246/246 [00:08<00:00, 28.70it/s]


Epoch: 9 	 Loss: 0.56366


100%|██████████| 246/246 [00:08<00:00, 28.62it/s]


Epoch: 10 	 Loss: 0.52156


100%|██████████| 246/246 [00:08<00:00, 29.12it/s]


Epoch: 11 	 Loss: 0.48422


100%|██████████| 246/246 [00:08<00:00, 27.80it/s]


Epoch: 12 	 Loss: 0.44806


100%|██████████| 246/246 [00:08<00:00, 29.27it/s]


Epoch: 13 	 Loss: 0.44566


100%|██████████| 246/246 [00:08<00:00, 28.27it/s]


Epoch: 14 	 Loss: 0.41404


100%|██████████| 246/246 [00:08<00:00, 28.57it/s]


Epoch: 15 	 Loss: 0.40906


100%|██████████| 246/246 [00:08<00:00, 28.67it/s]


Epoch: 16 	 Loss: 0.37738


100%|██████████| 246/246 [00:08<00:00, 27.94it/s]


Epoch: 17 	 Loss: 0.34730


100%|██████████| 246/246 [00:08<00:00, 28.75it/s]


Epoch: 18 	 Loss: 0.31065


100%|██████████| 246/246 [00:08<00:00, 28.84it/s]


Epoch: 19 	 Loss: 0.29175


100%|██████████| 246/246 [00:08<00:00, 29.21it/s]


Epoch: 20 	 Loss: 0.31108


100%|██████████| 246/246 [00:08<00:00, 27.49it/s]


Epoch: 21 	 Loss: 0.26481


100%|██████████| 246/246 [00:08<00:00, 28.46it/s]


Epoch: 22 	 Loss: 0.24377


100%|██████████| 246/246 [00:08<00:00, 29.57it/s]


Epoch: 23 	 Loss: 0.23882


100%|██████████| 246/246 [00:08<00:00, 28.61it/s]


Epoch: 24 	 Loss: 0.21601


100%|██████████| 246/246 [00:08<00:00, 27.83it/s]


Epoch: 25 	 Loss: 0.21788


100%|██████████| 246/246 [00:08<00:00, 29.15it/s]


Epoch: 26 	 Loss: 0.19879


100%|██████████| 246/246 [00:08<00:00, 28.50it/s]


Epoch: 27 	 Loss: 0.20479


100%|██████████| 246/246 [00:08<00:00, 29.16it/s]


Epoch: 28 	 Loss: 0.18498


100%|██████████| 246/246 [00:08<00:00, 28.23it/s]


Epoch: 29 	 Loss: 0.17723


100%|██████████| 246/246 [00:08<00:00, 29.68it/s]


Epoch: 30 	 Loss: 0.17821


100%|██████████| 246/246 [00:08<00:00, 28.72it/s]


Epoch: 31 	 Loss: 0.15285


100%|██████████| 246/246 [00:08<00:00, 27.61it/s]


Epoch: 32 	 Loss: 0.15021


100%|██████████| 246/246 [00:10<00:00, 23.44it/s]


Epoch: 33 	 Loss: 0.14174


100%|██████████| 246/246 [00:10<00:00, 23.62it/s]


Epoch: 34 	 Loss: 0.15063


100%|██████████| 246/246 [00:11<00:00, 21.90it/s]


Epoch: 35 	 Loss: 0.12808


100%|██████████| 246/246 [00:12<00:00, 19.99it/s]


Epoch: 36 	 Loss: 0.12488


100%|██████████| 246/246 [00:10<00:00, 22.57it/s]


Epoch: 37 	 Loss: 0.11149


100%|██████████| 246/246 [00:09<00:00, 26.22it/s]


Epoch: 38 	 Loss: 0.11871


100%|██████████| 246/246 [00:08<00:00, 29.54it/s]


Epoch: 39 	 Loss: 0.13073


100%|██████████| 246/246 [00:08<00:00, 28.58it/s]


Epoch: 40 	 Loss: 0.10375


100%|██████████| 246/246 [00:08<00:00, 28.61it/s]


Epoch: 41 	 Loss: 0.09156


100%|██████████| 246/246 [00:08<00:00, 29.85it/s]


Epoch: 42 	 Loss: 0.08509


100%|██████████| 246/246 [00:08<00:00, 28.71it/s]


Epoch: 43 	 Loss: 0.08204


100%|██████████| 246/246 [00:08<00:00, 29.69it/s]


Epoch: 44 	 Loss: 0.08164


100%|██████████| 246/246 [00:09<00:00, 27.31it/s]


Epoch: 45 	 Loss: 0.18878


100%|██████████| 246/246 [00:08<00:00, 28.21it/s]


Epoch: 46 	 Loss: 0.18346


100%|██████████| 246/246 [00:08<00:00, 30.26it/s]


Epoch: 47 	 Loss: 0.08662


100%|██████████| 246/246 [00:08<00:00, 29.26it/s]


Epoch: 48 	 Loss: 0.06925


100%|██████████| 246/246 [00:08<00:00, 29.06it/s]


Epoch: 49 	 Loss: 0.05800


100%|██████████| 246/246 [00:08<00:00, 29.46it/s]


Epoch: 50 	 Loss: 0.05669


100%|██████████| 246/246 [00:08<00:00, 28.51it/s]


Epoch: 51 	 Loss: 0.05313


100%|██████████| 246/246 [00:08<00:00, 29.19it/s]


Epoch: 52 	 Loss: 0.05896


100%|██████████| 246/246 [00:08<00:00, 28.44it/s]


Epoch: 53 	 Loss: 0.05973


100%|██████████| 246/246 [00:08<00:00, 28.94it/s]


Epoch: 54 	 Loss: 0.05347


100%|██████████| 246/246 [00:08<00:00, 28.27it/s]


Epoch: 55 	 Loss: 0.06421


100%|██████████| 246/246 [00:08<00:00, 29.07it/s]


Epoch: 56 	 Loss: 0.05566


100%|██████████| 246/246 [00:08<00:00, 30.18it/s]


Epoch: 57 	 Loss: 0.06848


100%|██████████| 246/246 [00:08<00:00, 29.21it/s]


Epoch: 58 	 Loss: 0.07935


100%|██████████| 246/246 [00:08<00:00, 29.13it/s]


Epoch: 59 	 Loss: 0.08931


100%|██████████| 246/246 [00:08<00:00, 30.31it/s]


Epoch: 60 	 Loss: 0.07645


100%|██████████| 246/246 [00:08<00:00, 29.22it/s]


Epoch: 61 	 Loss: 0.04987


100%|██████████| 246/246 [00:09<00:00, 26.18it/s]


Epoch: 62 	 Loss: 0.04256


100%|██████████| 246/246 [00:08<00:00, 29.21it/s]


Epoch: 63 	 Loss: 0.03557


100%|██████████| 246/246 [00:08<00:00, 29.35it/s]


Epoch: 64 	 Loss: 0.03461


100%|██████████| 246/246 [00:08<00:00, 29.82it/s]


Epoch: 65 	 Loss: 0.03443


100%|██████████| 246/246 [00:08<00:00, 28.27it/s]


Epoch: 66 	 Loss: 0.03365


100%|██████████| 246/246 [00:08<00:00, 30.10it/s]


Epoch: 67 	 Loss: 0.03651


100%|██████████| 246/246 [00:08<00:00, 28.93it/s]


Epoch: 68 	 Loss: 0.05880


100%|██████████| 246/246 [00:08<00:00, 28.61it/s]


Epoch: 69 	 Loss: 0.06639


100%|██████████| 246/246 [00:08<00:00, 28.70it/s]


Epoch: 70 	 Loss: 0.04269


100%|██████████| 246/246 [00:08<00:00, 30.04it/s]


Epoch: 71 	 Loss: 0.05038


100%|██████████| 246/246 [00:08<00:00, 29.19it/s]


Epoch: 72 	 Loss: 0.04729


100%|██████████| 246/246 [00:08<00:00, 29.05it/s]


Epoch: 73 	 Loss: 0.04009


100%|██████████| 246/246 [00:08<00:00, 27.92it/s]


Epoch: 74 	 Loss: 0.04420


100%|██████████| 246/246 [00:08<00:00, 27.67it/s]


Epoch: 75 	 Loss: 0.03215


100%|██████████| 246/246 [00:08<00:00, 29.58it/s]

Epoch: 76 	 Loss: 0.02146





In [22]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model2, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.79356846 0.37878788 0.78288431 0.73076923 0.93333333 0.65405405
 0.98242188 0.59375    0.62686567 0.85714286 0.5       ] 
 Average acccuracy: 0.7121434250979548
F1 Scores: [0.78906648 0.2941176  0.81652888 0.74509799 0.90810806 0.71810084
 0.99015743 0.55882348 0.45901635 0.90322576 0.57971009] 
 Average F1: 0.7056320859173362


# LEGAL-BERT

## BiLSTM

In [23]:
model3 = BiLSTM(hidden_size=128, dropout= 0.25, output_size= 11)
optimizer = torch.optim.Adam(model3.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model3.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model3(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.1:
        break

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:17, 14.13it/s]

100%|██████████| 246/246 [00:08<00:00, 29.79it/s]


Epoch: 1 	 Loss: 2.22225


100%|██████████| 246/246 [00:07<00:00, 34.01it/s]


Epoch: 2 	 Loss: 1.87958


100%|██████████| 246/246 [00:07<00:00, 35.06it/s]


Epoch: 3 	 Loss: 1.73143


100%|██████████| 246/246 [00:07<00:00, 34.34it/s]


Epoch: 4 	 Loss: 1.63950


100%|██████████| 246/246 [00:07<00:00, 32.79it/s]


Epoch: 5 	 Loss: 1.55702


100%|██████████| 246/246 [00:07<00:00, 33.15it/s]


Epoch: 6 	 Loss: 1.47517


100%|██████████| 246/246 [00:07<00:00, 31.82it/s]


Epoch: 7 	 Loss: 1.43047


100%|██████████| 246/246 [00:07<00:00, 34.15it/s]


Epoch: 8 	 Loss: 1.37735


100%|██████████| 246/246 [00:07<00:00, 34.25it/s]


Epoch: 9 	 Loss: 1.32780


100%|██████████| 246/246 [00:07<00:00, 34.89it/s]


Epoch: 10 	 Loss: 1.29118


100%|██████████| 246/246 [00:07<00:00, 32.85it/s]


Epoch: 11 	 Loss: 1.23132


100%|██████████| 246/246 [00:06<00:00, 35.29it/s]


Epoch: 12 	 Loss: 1.19783


100%|██████████| 246/246 [00:07<00:00, 34.77it/s]


Epoch: 13 	 Loss: 1.16965


100%|██████████| 246/246 [00:07<00:00, 33.17it/s]


Epoch: 14 	 Loss: 1.15211


100%|██████████| 246/246 [00:07<00:00, 34.94it/s]


Epoch: 15 	 Loss: 1.09974


100%|██████████| 246/246 [00:07<00:00, 34.67it/s]


Epoch: 16 	 Loss: 1.06550


100%|██████████| 246/246 [00:07<00:00, 34.33it/s]


Epoch: 17 	 Loss: 1.03870


100%|██████████| 246/246 [00:06<00:00, 35.28it/s]


Epoch: 18 	 Loss: 0.98359


100%|██████████| 246/246 [00:07<00:00, 33.49it/s]


Epoch: 19 	 Loss: 0.96213


100%|██████████| 246/246 [00:07<00:00, 33.98it/s]


Epoch: 20 	 Loss: 0.91447


100%|██████████| 246/246 [00:07<00:00, 34.63it/s]


Epoch: 21 	 Loss: 0.88301


100%|██████████| 246/246 [00:06<00:00, 35.48it/s]


Epoch: 22 	 Loss: 0.85659


100%|██████████| 246/246 [00:06<00:00, 35.27it/s]


Epoch: 23 	 Loss: 0.81432


100%|██████████| 246/246 [00:07<00:00, 33.83it/s]


Epoch: 24 	 Loss: 0.79681


100%|██████████| 246/246 [00:07<00:00, 33.42it/s]


Epoch: 25 	 Loss: 0.82372


100%|██████████| 246/246 [00:07<00:00, 35.00it/s]


Epoch: 26 	 Loss: 0.75768


100%|██████████| 246/246 [00:07<00:00, 34.15it/s]


Epoch: 27 	 Loss: 0.73933


100%|██████████| 246/246 [00:07<00:00, 34.44it/s]


Epoch: 28 	 Loss: 0.70246


100%|██████████| 246/246 [00:07<00:00, 33.87it/s]


Epoch: 29 	 Loss: 0.65663


100%|██████████| 246/246 [00:08<00:00, 30.16it/s]


Epoch: 30 	 Loss: 0.67047


100%|██████████| 246/246 [00:06<00:00, 35.66it/s]


Epoch: 31 	 Loss: 0.66627


100%|██████████| 246/246 [00:07<00:00, 34.09it/s]


Epoch: 32 	 Loss: 0.63097


100%|██████████| 246/246 [00:07<00:00, 33.65it/s]


Epoch: 33 	 Loss: 0.62003


100%|██████████| 246/246 [00:07<00:00, 33.30it/s]


Epoch: 34 	 Loss: 0.59919


100%|██████████| 246/246 [00:07<00:00, 34.83it/s]


Epoch: 35 	 Loss: 0.57966


100%|██████████| 246/246 [00:07<00:00, 34.14it/s]


Epoch: 36 	 Loss: 0.60327


100%|██████████| 246/246 [00:07<00:00, 33.58it/s]


Epoch: 37 	 Loss: 0.56896


100%|██████████| 246/246 [00:07<00:00, 32.61it/s]


Epoch: 38 	 Loss: 0.53711


100%|██████████| 246/246 [00:07<00:00, 33.79it/s]


Epoch: 39 	 Loss: 0.51120


100%|██████████| 246/246 [00:08<00:00, 29.43it/s]


Epoch: 40 	 Loss: 0.49128


100%|██████████| 246/246 [00:07<00:00, 32.02it/s]


Epoch: 41 	 Loss: 0.46845


100%|██████████| 246/246 [00:07<00:00, 34.69it/s]


Epoch: 42 	 Loss: 0.45939


100%|██████████| 246/246 [00:07<00:00, 33.41it/s]


Epoch: 43 	 Loss: 0.47733


100%|██████████| 246/246 [00:06<00:00, 35.43it/s]


Epoch: 44 	 Loss: 0.45669


100%|██████████| 246/246 [00:06<00:00, 35.33it/s]


Epoch: 45 	 Loss: 0.44983


100%|██████████| 246/246 [00:07<00:00, 33.37it/s]


Epoch: 46 	 Loss: 0.43282


100%|██████████| 246/246 [00:07<00:00, 34.13it/s]


Epoch: 47 	 Loss: 0.41081


100%|██████████| 246/246 [00:07<00:00, 31.49it/s]


Epoch: 48 	 Loss: 0.39618


100%|██████████| 246/246 [00:07<00:00, 35.12it/s]


Epoch: 49 	 Loss: 0.41691


100%|██████████| 246/246 [00:07<00:00, 32.77it/s]


Epoch: 50 	 Loss: 0.37877


100%|██████████| 246/246 [00:07<00:00, 34.58it/s]


Epoch: 51 	 Loss: 0.38221


100%|██████████| 246/246 [00:06<00:00, 35.53it/s]


Epoch: 52 	 Loss: 0.38266


100%|██████████| 246/246 [00:07<00:00, 33.83it/s]


Epoch: 53 	 Loss: 0.37498


100%|██████████| 246/246 [00:07<00:00, 34.12it/s]


Epoch: 54 	 Loss: 0.36377


100%|██████████| 246/246 [00:07<00:00, 33.25it/s]


Epoch: 55 	 Loss: 0.35854


100%|██████████| 246/246 [00:06<00:00, 35.50it/s]


Epoch: 56 	 Loss: 0.35696


100%|██████████| 246/246 [00:07<00:00, 33.32it/s]


Epoch: 57 	 Loss: 0.32549


100%|██████████| 246/246 [00:07<00:00, 34.79it/s]


Epoch: 58 	 Loss: 0.35632


100%|██████████| 246/246 [00:07<00:00, 34.50it/s]


Epoch: 59 	 Loss: 0.39597


100%|██████████| 246/246 [00:07<00:00, 33.76it/s]


Epoch: 60 	 Loss: 0.34619


100%|██████████| 246/246 [00:07<00:00, 34.87it/s]


Epoch: 61 	 Loss: 0.31556


100%|██████████| 246/246 [00:07<00:00, 34.02it/s]


Epoch: 62 	 Loss: 0.29689


100%|██████████| 246/246 [00:07<00:00, 34.71it/s]


Epoch: 63 	 Loss: 0.29899


100%|██████████| 246/246 [00:07<00:00, 35.09it/s]


Epoch: 64 	 Loss: 0.29913


100%|██████████| 246/246 [00:07<00:00, 32.38it/s]


Epoch: 65 	 Loss: 0.28181


100%|██████████| 246/246 [00:07<00:00, 33.44it/s]


Epoch: 66 	 Loss: 0.27789


100%|██████████| 246/246 [00:07<00:00, 34.19it/s]


Epoch: 67 	 Loss: 0.27979


100%|██████████| 246/246 [00:07<00:00, 34.66it/s]


Epoch: 68 	 Loss: 0.31257


100%|██████████| 246/246 [00:06<00:00, 35.35it/s]


Epoch: 69 	 Loss: 0.33179


100%|██████████| 246/246 [00:07<00:00, 31.81it/s]


Epoch: 70 	 Loss: 0.28723


100%|██████████| 246/246 [00:07<00:00, 32.78it/s]


Epoch: 71 	 Loss: 0.26724


100%|██████████| 246/246 [00:07<00:00, 32.67it/s]


Epoch: 72 	 Loss: 0.25117


100%|██████████| 246/246 [00:07<00:00, 35.10it/s]


Epoch: 73 	 Loss: 0.24616


100%|██████████| 246/246 [00:07<00:00, 35.13it/s]


Epoch: 74 	 Loss: 0.24090


100%|██████████| 246/246 [00:07<00:00, 31.84it/s]


Epoch: 75 	 Loss: 0.26177


100%|██████████| 246/246 [00:07<00:00, 33.64it/s]


Epoch: 76 	 Loss: 0.25151


100%|██████████| 246/246 [00:07<00:00, 34.92it/s]


Epoch: 77 	 Loss: 0.24981


100%|██████████| 246/246 [00:07<00:00, 32.04it/s]


Epoch: 78 	 Loss: 0.24931


100%|██████████| 246/246 [00:07<00:00, 33.68it/s]


Epoch: 79 	 Loss: 0.25764


100%|██████████| 246/246 [00:07<00:00, 33.28it/s]


Epoch: 80 	 Loss: 0.25987


100%|██████████| 246/246 [00:07<00:00, 35.02it/s]


Epoch: 81 	 Loss: 0.24557


100%|██████████| 246/246 [00:06<00:00, 35.57it/s]


Epoch: 82 	 Loss: 0.22094


100%|██████████| 246/246 [00:07<00:00, 33.72it/s]


Epoch: 83 	 Loss: 0.21837


100%|██████████| 246/246 [00:07<00:00, 34.07it/s]


Epoch: 84 	 Loss: 0.22183


100%|██████████| 246/246 [00:06<00:00, 35.17it/s]


Epoch: 85 	 Loss: 0.22174


100%|██████████| 246/246 [00:07<00:00, 33.87it/s]


Epoch: 86 	 Loss: 0.23137


100%|██████████| 246/246 [00:07<00:00, 35.06it/s]


Epoch: 87 	 Loss: 0.21554


100%|██████████| 246/246 [00:07<00:00, 32.92it/s]


Epoch: 88 	 Loss: 0.21196


100%|██████████| 246/246 [00:07<00:00, 31.82it/s]


Epoch: 89 	 Loss: 0.20912


100%|██████████| 246/246 [00:07<00:00, 34.66it/s]


Epoch: 90 	 Loss: 0.23325


100%|██████████| 246/246 [00:07<00:00, 34.94it/s]


Epoch: 91 	 Loss: 0.21308


100%|██████████| 246/246 [00:07<00:00, 34.35it/s]


Epoch: 92 	 Loss: 0.19399


100%|██████████| 246/246 [00:07<00:00, 33.18it/s]


Epoch: 93 	 Loss: 0.19569


100%|██████████| 246/246 [00:06<00:00, 35.52it/s]


Epoch: 94 	 Loss: 0.21256


100%|██████████| 246/246 [00:07<00:00, 34.55it/s]


Epoch: 95 	 Loss: 0.22090


100%|██████████| 246/246 [00:07<00:00, 33.69it/s]


Epoch: 96 	 Loss: 0.26664


100%|██████████| 246/246 [00:07<00:00, 34.84it/s]


Epoch: 97 	 Loss: 0.22236


100%|██████████| 246/246 [00:07<00:00, 33.27it/s]


Epoch: 98 	 Loss: 0.20081


100%|██████████| 246/246 [00:07<00:00, 34.20it/s]


Epoch: 99 	 Loss: 0.18144


100%|██████████| 246/246 [00:07<00:00, 35.06it/s]

Epoch: 100 	 Loss: 0.19187





In [24]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model3, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.66444233 0.0952381  0.52470588 0.7        0.76315789 0.2020202
 0.83030303 0.3125     0.33333333 0.49193548 0.3       ] 
 Average acccuracy: 0.47433056798910883
F1 Scores: [0.68873513 0.07185624 0.62421269 0.2333333  0.43609018 0.1593625
 0.82282277 0.32894732 0.01680672 0.57547165 0.15384611] 
 Average F1: 0.3737713290617482


## CNN-BiLSTM

In [25]:
model4 = CNN_BiLSTM(hidden_size=128, dropout= 0.25, output_size= 11)
optimizer = torch.optim.Adam(model4.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model4.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model4(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.1:
        break

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:19, 12.63it/s]

100%|██████████| 246/246 [00:08<00:00, 27.54it/s]


Epoch: 1 	 Loss: 2.24902


100%|██████████| 246/246 [00:08<00:00, 29.44it/s]


Epoch: 2 	 Loss: 1.91365


100%|██████████| 246/246 [00:07<00:00, 30.88it/s]


Epoch: 3 	 Loss: 1.76526


100%|██████████| 246/246 [00:08<00:00, 30.60it/s]


Epoch: 4 	 Loss: 1.64916


100%|██████████| 246/246 [00:07<00:00, 30.87it/s]


Epoch: 5 	 Loss: 1.57735


100%|██████████| 246/246 [00:08<00:00, 29.53it/s]


Epoch: 6 	 Loss: 1.50676


100%|██████████| 246/246 [00:08<00:00, 29.39it/s]


Epoch: 7 	 Loss: 1.44801


100%|██████████| 246/246 [00:08<00:00, 30.26it/s]


Epoch: 8 	 Loss: 1.40970


100%|██████████| 246/246 [00:08<00:00, 29.53it/s]


Epoch: 9 	 Loss: 1.37216


100%|██████████| 246/246 [00:08<00:00, 29.29it/s]


Epoch: 10 	 Loss: 1.30634


100%|██████████| 246/246 [00:08<00:00, 29.86it/s]


Epoch: 11 	 Loss: 1.30337


100%|██████████| 246/246 [00:08<00:00, 29.94it/s]


Epoch: 12 	 Loss: 1.21658


100%|██████████| 246/246 [00:08<00:00, 28.99it/s]


Epoch: 13 	 Loss: 1.17096


100%|██████████| 246/246 [00:09<00:00, 26.60it/s]


Epoch: 14 	 Loss: 1.13447


100%|██████████| 246/246 [00:08<00:00, 28.30it/s]


Epoch: 15 	 Loss: 1.10313


100%|██████████| 246/246 [00:08<00:00, 29.66it/s]


Epoch: 16 	 Loss: 1.05179


100%|██████████| 246/246 [00:08<00:00, 29.69it/s]


Epoch: 17 	 Loss: 0.99747


100%|██████████| 246/246 [00:08<00:00, 28.62it/s]


Epoch: 18 	 Loss: 0.97412


100%|██████████| 246/246 [00:08<00:00, 29.64it/s]


Epoch: 19 	 Loss: 0.95235


100%|██████████| 246/246 [00:08<00:00, 30.28it/s]


Epoch: 20 	 Loss: 0.92728


100%|██████████| 246/246 [00:08<00:00, 29.17it/s]


Epoch: 21 	 Loss: 0.87524


100%|██████████| 246/246 [00:08<00:00, 28.52it/s]


Epoch: 22 	 Loss: 0.82393


100%|██████████| 246/246 [00:08<00:00, 30.39it/s]


Epoch: 23 	 Loss: 0.77405


100%|██████████| 246/246 [00:08<00:00, 29.24it/s]


Epoch: 24 	 Loss: 0.75430


100%|██████████| 246/246 [00:08<00:00, 29.30it/s]


Epoch: 25 	 Loss: 0.72233


100%|██████████| 246/246 [00:08<00:00, 27.73it/s]


Epoch: 26 	 Loss: 0.68391


100%|██████████| 246/246 [00:08<00:00, 30.37it/s]


Epoch: 27 	 Loss: 0.68558


100%|██████████| 246/246 [00:08<00:00, 29.94it/s]


Epoch: 28 	 Loss: 0.69265


100%|██████████| 246/246 [00:08<00:00, 28.91it/s]


Epoch: 29 	 Loss: 0.61364


100%|██████████| 246/246 [00:08<00:00, 29.00it/s]


Epoch: 30 	 Loss: 0.56908


100%|██████████| 246/246 [00:08<00:00, 30.17it/s]


Epoch: 31 	 Loss: 0.56068


100%|██████████| 246/246 [00:08<00:00, 30.27it/s]


Epoch: 32 	 Loss: 0.51761


100%|██████████| 246/246 [00:08<00:00, 28.27it/s]


Epoch: 33 	 Loss: 0.48788


100%|██████████| 246/246 [00:08<00:00, 28.86it/s]


Epoch: 34 	 Loss: 0.47996


100%|██████████| 246/246 [00:07<00:00, 31.09it/s]


Epoch: 35 	 Loss: 0.46907


100%|██████████| 246/246 [00:07<00:00, 32.72it/s]


Epoch: 36 	 Loss: 0.48487


100%|██████████| 246/246 [00:06<00:00, 37.20it/s]


Epoch: 37 	 Loss: 0.45817


100%|██████████| 246/246 [00:07<00:00, 32.89it/s]


Epoch: 38 	 Loss: 0.44543


100%|██████████| 246/246 [00:08<00:00, 29.65it/s]


Epoch: 39 	 Loss: 0.40855


100%|██████████| 246/246 [00:08<00:00, 28.99it/s]


Epoch: 40 	 Loss: 0.38756


100%|██████████| 246/246 [00:08<00:00, 29.34it/s]


Epoch: 41 	 Loss: 0.38475


100%|██████████| 246/246 [00:08<00:00, 28.39it/s]


Epoch: 42 	 Loss: 0.37399


100%|██████████| 246/246 [00:08<00:00, 27.90it/s]


Epoch: 43 	 Loss: 0.34509


100%|██████████| 246/246 [00:08<00:00, 29.73it/s]


Epoch: 44 	 Loss: 0.31760


100%|██████████| 246/246 [00:08<00:00, 29.50it/s]


Epoch: 45 	 Loss: 0.33389


100%|██████████| 246/246 [00:08<00:00, 28.81it/s]


Epoch: 46 	 Loss: 0.33086


100%|██████████| 246/246 [00:09<00:00, 25.37it/s]


Epoch: 47 	 Loss: 0.31541


100%|██████████| 246/246 [00:08<00:00, 28.42it/s]


Epoch: 48 	 Loss: 0.30046


100%|██████████| 246/246 [00:08<00:00, 29.33it/s]


Epoch: 49 	 Loss: 0.29429


100%|██████████| 246/246 [00:08<00:00, 28.78it/s]


Epoch: 50 	 Loss: 0.30938


100%|██████████| 246/246 [00:08<00:00, 30.09it/s]


Epoch: 51 	 Loss: 0.28290


100%|██████████| 246/246 [00:08<00:00, 30.22it/s]


Epoch: 52 	 Loss: 0.26773


100%|██████████| 246/246 [00:08<00:00, 28.53it/s]


Epoch: 53 	 Loss: 0.25016


100%|██████████| 246/246 [00:08<00:00, 28.74it/s]


Epoch: 54 	 Loss: 0.24538


100%|██████████| 246/246 [00:08<00:00, 30.00it/s]


Epoch: 55 	 Loss: 0.25153


100%|██████████| 246/246 [00:08<00:00, 27.83it/s]


Epoch: 56 	 Loss: 0.25368


100%|██████████| 246/246 [00:08<00:00, 28.89it/s]


Epoch: 57 	 Loss: 0.26190


100%|██████████| 246/246 [00:08<00:00, 29.55it/s]


Epoch: 58 	 Loss: 0.23116


100%|██████████| 246/246 [00:08<00:00, 29.58it/s]


Epoch: 59 	 Loss: 0.22171


100%|██████████| 246/246 [00:08<00:00, 29.47it/s]


Epoch: 60 	 Loss: 0.20656


100%|██████████| 246/246 [00:08<00:00, 28.22it/s]


Epoch: 61 	 Loss: 0.20930


100%|██████████| 246/246 [00:08<00:00, 29.25it/s]


Epoch: 62 	 Loss: 0.20763


100%|██████████| 246/246 [00:08<00:00, 30.16it/s]


Epoch: 63 	 Loss: 0.22035


100%|██████████| 246/246 [00:08<00:00, 29.30it/s]


Epoch: 64 	 Loss: 0.22100


100%|██████████| 246/246 [00:08<00:00, 28.82it/s]


Epoch: 65 	 Loss: 0.21747


100%|██████████| 246/246 [00:08<00:00, 28.38it/s]


Epoch: 66 	 Loss: 0.19248


100%|██████████| 246/246 [00:08<00:00, 29.74it/s]


Epoch: 67 	 Loss: 0.21077


100%|██████████| 246/246 [00:08<00:00, 29.68it/s]


Epoch: 68 	 Loss: 0.30549


100%|██████████| 246/246 [00:08<00:00, 29.35it/s]


Epoch: 69 	 Loss: 0.24017


100%|██████████| 246/246 [00:08<00:00, 29.39it/s]


Epoch: 70 	 Loss: 0.17663


100%|██████████| 246/246 [00:08<00:00, 28.45it/s]


Epoch: 71 	 Loss: 0.16058


100%|██████████| 246/246 [00:08<00:00, 30.03it/s]


Epoch: 72 	 Loss: 0.16157


100%|██████████| 246/246 [00:08<00:00, 28.96it/s]


Epoch: 73 	 Loss: 0.15341


100%|██████████| 246/246 [00:08<00:00, 28.29it/s]


Epoch: 74 	 Loss: 0.15173


100%|██████████| 246/246 [00:08<00:00, 30.17it/s]


Epoch: 75 	 Loss: 0.15135


100%|██████████| 246/246 [00:08<00:00, 29.25it/s]


Epoch: 76 	 Loss: 0.15036


100%|██████████| 246/246 [00:08<00:00, 28.50it/s]


Epoch: 77 	 Loss: 0.14359


100%|██████████| 246/246 [00:08<00:00, 30.06it/s]


Epoch: 78 	 Loss: 0.16266


100%|██████████| 246/246 [00:08<00:00, 29.68it/s]


Epoch: 79 	 Loss: 0.15336


100%|██████████| 246/246 [00:08<00:00, 29.76it/s]


Epoch: 80 	 Loss: 0.14943


100%|██████████| 246/246 [00:08<00:00, 28.57it/s]


Epoch: 81 	 Loss: 0.14886


100%|██████████| 246/246 [00:08<00:00, 28.96it/s]


Epoch: 82 	 Loss: 0.22113


100%|██████████| 246/246 [00:08<00:00, 29.12it/s]


Epoch: 83 	 Loss: 0.25992


100%|██████████| 246/246 [00:08<00:00, 29.26it/s]


Epoch: 84 	 Loss: 0.16622


100%|██████████| 246/246 [00:08<00:00, 28.16it/s]


Epoch: 85 	 Loss: 0.13830


100%|██████████| 246/246 [00:08<00:00, 30.29it/s]


Epoch: 86 	 Loss: 0.12486


100%|██████████| 246/246 [00:08<00:00, 28.97it/s]


Epoch: 87 	 Loss: 0.12716


100%|██████████| 246/246 [00:08<00:00, 28.54it/s]


Epoch: 88 	 Loss: 0.11628


100%|██████████| 246/246 [00:08<00:00, 28.43it/s]


Epoch: 89 	 Loss: 0.11352


100%|██████████| 246/246 [00:08<00:00, 28.74it/s]


Epoch: 90 	 Loss: 0.11190


100%|██████████| 246/246 [00:08<00:00, 29.03it/s]


Epoch: 91 	 Loss: 0.11674


100%|██████████| 246/246 [00:08<00:00, 27.88it/s]


Epoch: 92 	 Loss: 0.11443


100%|██████████| 246/246 [00:08<00:00, 27.48it/s]


Epoch: 93 	 Loss: 0.13843


100%|██████████| 246/246 [00:08<00:00, 27.83it/s]


Epoch: 94 	 Loss: 0.15472


100%|██████████| 246/246 [00:08<00:00, 29.37it/s]


Epoch: 95 	 Loss: 0.13552


100%|██████████| 246/246 [00:08<00:00, 28.58it/s]


Epoch: 96 	 Loss: 0.11822


100%|██████████| 246/246 [00:09<00:00, 26.96it/s]


Epoch: 97 	 Loss: 0.10485


100%|██████████| 246/246 [00:08<00:00, 29.75it/s]


Epoch: 98 	 Loss: 0.12097


100%|██████████| 246/246 [00:08<00:00, 28.07it/s]


Epoch: 99 	 Loss: 0.11448


100%|██████████| 246/246 [00:08<00:00, 28.18it/s]

Epoch: 100 	 Loss: 0.11275





In [26]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model4, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.65126812 0.0212766  0.55426918 0.34090909 0.54022989 0.11403509
 0.82916667 0.31325301 0.05660377 0.52941176 0.11111111] 
 Average acccuracy: 0.3692303889634425
F1 Scores: [0.69167864 0.01324499 0.60314956 0.31914889 0.33935014 0.09774431
 0.80894304 0.33548382 0.03550292 0.568421   0.14457827] 
 Average F1: 0.3597495971500088


In [27]:
torch.save(model1, 'bert-base-bilstm.pth')
torch.save(model2, 'bert-base-cnnbilstm.pth')
torch.save(model3, 'bert-legal-bilstm.pth')
torch.save(model4, 'bert-legal-cnnbilstm.pth')

# InLEGAL BERT

In [28]:
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'..')))

from Dataset_Reader import Dataset_Reader
from utils import read_json, data_to_embeddings, save_tensor, label_encode, organize_data
from utils import document_max_length, write_dictionary_to_json

from main import TRAIN_DATA_PATH, TEST_DATA_PATH

In [29]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
model = AutoModel.from_pretrained("law-ai/InLegalBERT")

tokenizer_config.json: 100%|██████████| 516/516 [00:00<00:00, 419kB/s]
vocab.txt: 100%|██████████| 222k/222k [00:00<00:00, 1.06MB/s]
special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 115kB/s]
config.json: 100%|██████████| 671/671 [00:00<00:00, 492kB/s]
pytorch_model.bin: 100%|██████████| 534M/534M [00:51<00:00, 10.5MB/s] 
  return self.fget.__get__(instance, owner)()


In [32]:
train_data = Dataset_Reader('../data/train.json')
test_data = Dataset_Reader('../data/dev.json')

print(f"Number of sentences in training data: {len(train_data.texts)}")
print(f"Number of sentences in test data: {len(test_data.texts)}")

# Manually defining labels for convenience
list_of_targets = ['ISSUE', 'FAC', 'NONE', 'ARG_PETITIONER', 'PRE_NOT_RELIED',
                'STA', 'RPC', 'ARG_RESPONDENT', 'PREAMBLE', 'ANALYSIS', 'RLC', 'PRE_RELIED', 'RATIO']

# Numerically encode labels
label_encoder = label_encode(list_of_targets)


#Compute the maximum sentence length for each document in the training and test data (to ensure all embeddings will be the same size within a document)
max_length_dict_TRAIN = document_max_length(train_data, tokenizer=tokenizer)
max_length_dict_TEST = document_max_length(test_data, tokenizer=tokenizer)

# # To same time during training process, write these documents to json file
# write_dictionary_to_json(max_length_dict_TRAIN, 'max_length_dicts/max_length_train.json')
# write_dictionary_to_json(max_length_dict_TEST, 'max_length_dicts/max_length_test.json')


#retrieve max_length dictionaries to compute word embeddings
# max_length_dict_TRAIN = read_json('max_length_dicts/max_length_train.json', reading_max_length=True)
# max_length_dict_TEST = read_json('max_length_dicts/max_length_test.json', reading_max_length=True)


#organize and process data

train_doc_idxs, train_batched_texts, train_batched_labels = organize_data(train_data, batch_size= 1) 
test_doc_idxs, test_batched_texts, test_batched_labels = organize_data(test_data, batch_size= 1) 

for idx, train_idx in tqdm(enumerate(train_doc_idxs)):
    TRAIN_emb, TRAIN_labels = data_to_embeddings(train_idx, train_batched_texts[idx], train_batched_labels[idx],
                                                label_encoder,max_length_dict_TRAIN, tokenizer=tokenizer,
                                                emb_model=model)
    save_tensor(TRAIN_emb, 'train_document/doc_'+str(idx)+'_inlegal/',"embedding")
    save_tensor(TRAIN_labels, 'train_document/doc_'+str(idx)+'_inlegal/',"label")
    

for idx, test_idx in tqdm(enumerate(test_doc_idxs)):
    TEST_emb, TEST_labels = data_to_embeddings(test_idx, test_batched_texts[idx], test_batched_labels[idx],
                                            label_encoder,max_length_dict_TEST, tokenizer=tokenizer,
                                            emb_model=model)
    save_tensor(TEST_emb, 'test_document/doc_'+str(idx)+'_inlegal/',"embedding")
    save_tensor(TEST_labels, 'test_document/doc_'+str(idx)+'_inlegal/',"label")
    

Number of sentences in training data: 28986
Number of sentences in test data: 2890


1it [00:32, 32.08s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_0_inlegal/embedding'
Tensor saved to 'train_document/doc_0_inlegal/label'


2it [01:07, 33.86s/it]

X_train size: torch.Size([72, 1, 768])	Y_train size: torch.Size([72])
Tensor saved to 'train_document/doc_1_inlegal/embedding'
Tensor saved to 'train_document/doc_1_inlegal/label'


3it [03:33, 85.00s/it]

X_train size: torch.Size([200, 1, 768])	Y_train size: torch.Size([200])
Tensor saved to 'train_document/doc_2_inlegal/embedding'
Tensor saved to 'train_document/doc_2_inlegal/label'


4it [04:49, 81.65s/it]

X_train size: torch.Size([119, 1, 768])	Y_train size: torch.Size([119])
Tensor saved to 'train_document/doc_3_inlegal/embedding'
Tensor saved to 'train_document/doc_3_inlegal/label'


5it [06:58, 98.67s/it]

X_train size: torch.Size([184, 1, 768])	Y_train size: torch.Size([184])
Tensor saved to 'train_document/doc_4_inlegal/embedding'
Tensor saved to 'train_document/doc_4_inlegal/label'


6it [08:03, 87.17s/it]

X_train size: torch.Size([211, 1, 768])	Y_train size: torch.Size([211])
Tensor saved to 'train_document/doc_5_inlegal/embedding'
Tensor saved to 'train_document/doc_5_inlegal/label'


7it [09:26, 85.89s/it]

X_train size: torch.Size([140, 1, 768])	Y_train size: torch.Size([140])
Tensor saved to 'train_document/doc_6_inlegal/embedding'
Tensor saved to 'train_document/doc_6_inlegal/label'


8it [09:52, 66.69s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_7_inlegal/embedding'
Tensor saved to 'train_document/doc_7_inlegal/label'


9it [11:18, 72.84s/it]

X_train size: torch.Size([228, 1, 768])	Y_train size: torch.Size([228])
Tensor saved to 'train_document/doc_8_inlegal/embedding'
Tensor saved to 'train_document/doc_8_inlegal/label'


10it [11:57, 62.53s/it]

X_train size: torch.Size([99, 1, 768])	Y_train size: torch.Size([99])
Tensor saved to 'train_document/doc_9_inlegal/embedding'
Tensor saved to 'train_document/doc_9_inlegal/label'


11it [12:14, 48.52s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_10_inlegal/embedding'
Tensor saved to 'train_document/doc_10_inlegal/label'


12it [13:40, 59.98s/it]

X_train size: torch.Size([213, 1, 768])	Y_train size: torch.Size([213])
Tensor saved to 'train_document/doc_11_inlegal/embedding'
Tensor saved to 'train_document/doc_11_inlegal/label'


13it [14:25, 55.28s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_12_inlegal/embedding'
Tensor saved to 'train_document/doc_12_inlegal/label'


14it [16:39, 79.20s/it]

X_train size: torch.Size([199, 1, 768])	Y_train size: torch.Size([199])
Tensor saved to 'train_document/doc_13_inlegal/embedding'
Tensor saved to 'train_document/doc_13_inlegal/label'


15it [18:23, 86.63s/it]

X_train size: torch.Size([188, 1, 768])	Y_train size: torch.Size([188])
Tensor saved to 'train_document/doc_14_inlegal/embedding'
Tensor saved to 'train_document/doc_14_inlegal/label'


16it [23:44, 157.16s/it]

X_train size: torch.Size([271, 1, 768])	Y_train size: torch.Size([271])
Tensor saved to 'train_document/doc_15_inlegal/embedding'
Tensor saved to 'train_document/doc_15_inlegal/label'


17it [23:58, 114.12s/it]

X_train size: torch.Size([43, 1, 768])	Y_train size: torch.Size([43])
Tensor saved to 'train_document/doc_16_inlegal/embedding'
Tensor saved to 'train_document/doc_16_inlegal/label'


18it [24:25, 87.84s/it] 

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_17_inlegal/embedding'
Tensor saved to 'train_document/doc_17_inlegal/label'


19it [25:24, 79.24s/it]

X_train size: torch.Size([171, 1, 768])	Y_train size: torch.Size([171])
Tensor saved to 'train_document/doc_18_inlegal/embedding'
Tensor saved to 'train_document/doc_18_inlegal/label'


20it [28:05, 103.88s/it]

X_train size: torch.Size([149, 1, 768])	Y_train size: torch.Size([149])
Tensor saved to 'train_document/doc_19_inlegal/embedding'
Tensor saved to 'train_document/doc_19_inlegal/label'


21it [28:51, 86.31s/it] 

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_20_inlegal/embedding'
Tensor saved to 'train_document/doc_20_inlegal/label'


22it [29:15, 67.81s/it]

X_train size: torch.Size([56, 1, 768])	Y_train size: torch.Size([56])
Tensor saved to 'train_document/doc_21_inlegal/embedding'
Tensor saved to 'train_document/doc_21_inlegal/label'


23it [29:28, 51.33s/it]

X_train size: torch.Size([47, 1, 768])	Y_train size: torch.Size([47])
Tensor saved to 'train_document/doc_22_inlegal/embedding'
Tensor saved to 'train_document/doc_22_inlegal/label'


24it [30:07, 47.49s/it]

X_train size: torch.Size([116, 1, 768])	Y_train size: torch.Size([116])
Tensor saved to 'train_document/doc_23_inlegal/embedding'
Tensor saved to 'train_document/doc_23_inlegal/label'


25it [30:38, 42.56s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_24_inlegal/embedding'
Tensor saved to 'train_document/doc_24_inlegal/label'


26it [30:59, 36.20s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_25_inlegal/embedding'
Tensor saved to 'train_document/doc_25_inlegal/label'


27it [32:14, 47.70s/it]

X_train size: torch.Size([109, 1, 768])	Y_train size: torch.Size([109])
Tensor saved to 'train_document/doc_26_inlegal/embedding'
Tensor saved to 'train_document/doc_26_inlegal/label'


28it [32:58, 46.78s/it]

X_train size: torch.Size([155, 1, 768])	Y_train size: torch.Size([155])
Tensor saved to 'train_document/doc_27_inlegal/embedding'
Tensor saved to 'train_document/doc_27_inlegal/label'


29it [33:58, 50.57s/it]

X_train size: torch.Size([198, 1, 768])	Y_train size: torch.Size([198])
Tensor saved to 'train_document/doc_28_inlegal/embedding'
Tensor saved to 'train_document/doc_28_inlegal/label'


30it [35:37, 65.26s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_29_inlegal/embedding'
Tensor saved to 'train_document/doc_29_inlegal/label'


31it [35:59, 52.09s/it]

X_train size: torch.Size([50, 1, 768])	Y_train size: torch.Size([50])
Tensor saved to 'train_document/doc_30_inlegal/embedding'
Tensor saved to 'train_document/doc_30_inlegal/label'


32it [36:12, 40.58s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_31_inlegal/embedding'
Tensor saved to 'train_document/doc_31_inlegal/label'


33it [37:47, 56.95s/it]

X_train size: torch.Size([264, 1, 768])	Y_train size: torch.Size([264])
Tensor saved to 'train_document/doc_32_inlegal/embedding'
Tensor saved to 'train_document/doc_32_inlegal/label'


34it [38:31, 52.83s/it]

X_train size: torch.Size([114, 1, 768])	Y_train size: torch.Size([114])
Tensor saved to 'train_document/doc_33_inlegal/embedding'
Tensor saved to 'train_document/doc_33_inlegal/label'


35it [39:01, 46.06s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_34_inlegal/embedding'
Tensor saved to 'train_document/doc_34_inlegal/label'


36it [40:39, 61.68s/it]

X_train size: torch.Size([243, 1, 768])	Y_train size: torch.Size([243])
Tensor saved to 'train_document/doc_35_inlegal/embedding'
Tensor saved to 'train_document/doc_35_inlegal/label'


37it [41:58, 66.78s/it]

X_train size: torch.Size([107, 1, 768])	Y_train size: torch.Size([107])
Tensor saved to 'train_document/doc_36_inlegal/embedding'
Tensor saved to 'train_document/doc_36_inlegal/label'


38it [43:11, 68.86s/it]

X_train size: torch.Size([127, 1, 768])	Y_train size: torch.Size([127])
Tensor saved to 'train_document/doc_37_inlegal/embedding'
Tensor saved to 'train_document/doc_37_inlegal/label'


39it [45:00, 80.62s/it]

X_train size: torch.Size([188, 1, 768])	Y_train size: torch.Size([188])
Tensor saved to 'train_document/doc_38_inlegal/embedding'
Tensor saved to 'train_document/doc_38_inlegal/label'


40it [46:06, 76.44s/it]

X_train size: torch.Size([189, 1, 768])	Y_train size: torch.Size([189])
Tensor saved to 'train_document/doc_39_inlegal/embedding'
Tensor saved to 'train_document/doc_39_inlegal/label'


41it [47:13, 73.67s/it]

X_train size: torch.Size([157, 1, 768])	Y_train size: torch.Size([157])
Tensor saved to 'train_document/doc_40_inlegal/embedding'
Tensor saved to 'train_document/doc_40_inlegal/label'


42it [47:32, 57.04s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_41_inlegal/embedding'
Tensor saved to 'train_document/doc_41_inlegal/label'


43it [48:17, 53.41s/it]

X_train size: torch.Size([120, 1, 768])	Y_train size: torch.Size([120])
Tensor saved to 'train_document/doc_42_inlegal/embedding'
Tensor saved to 'train_document/doc_42_inlegal/label'


44it [49:32, 60.00s/it]

X_train size: torch.Size([142, 1, 768])	Y_train size: torch.Size([142])
Tensor saved to 'train_document/doc_43_inlegal/embedding'
Tensor saved to 'train_document/doc_43_inlegal/label'


45it [50:10, 53.30s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_44_inlegal/embedding'
Tensor saved to 'train_document/doc_44_inlegal/label'


46it [51:11, 55.71s/it]

X_train size: torch.Size([146, 1, 768])	Y_train size: torch.Size([146])
Tensor saved to 'train_document/doc_45_inlegal/embedding'
Tensor saved to 'train_document/doc_45_inlegal/label'


47it [52:33, 63.54s/it]

X_train size: torch.Size([185, 1, 768])	Y_train size: torch.Size([185])
Tensor saved to 'train_document/doc_46_inlegal/embedding'
Tensor saved to 'train_document/doc_46_inlegal/label'


48it [52:54, 50.83s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_47_inlegal/embedding'
Tensor saved to 'train_document/doc_47_inlegal/label'


49it [53:12, 40.93s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_48_inlegal/embedding'
Tensor saved to 'train_document/doc_48_inlegal/label'


50it [53:44, 38.24s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_49_inlegal/embedding'
Tensor saved to 'train_document/doc_49_inlegal/label'


51it [54:49, 46.42s/it]

X_train size: torch.Size([187, 1, 768])	Y_train size: torch.Size([187])
Tensor saved to 'train_document/doc_50_inlegal/embedding'
Tensor saved to 'train_document/doc_50_inlegal/label'


52it [55:22, 42.41s/it]

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_51_inlegal/embedding'
Tensor saved to 'train_document/doc_51_inlegal/label'


53it [55:56, 39.74s/it]

X_train size: torch.Size([88, 1, 768])	Y_train size: torch.Size([88])
Tensor saved to 'train_document/doc_52_inlegal/embedding'
Tensor saved to 'train_document/doc_52_inlegal/label'


54it [56:24, 36.24s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'train_document/doc_53_inlegal/embedding'
Tensor saved to 'train_document/doc_53_inlegal/label'


55it [56:46, 32.06s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'train_document/doc_54_inlegal/embedding'
Tensor saved to 'train_document/doc_54_inlegal/label'


56it [57:28, 34.84s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_55_inlegal/embedding'
Tensor saved to 'train_document/doc_55_inlegal/label'


57it [58:37, 45.22s/it]

X_train size: torch.Size([141, 1, 768])	Y_train size: torch.Size([141])
Tensor saved to 'train_document/doc_56_inlegal/embedding'
Tensor saved to 'train_document/doc_56_inlegal/label'


58it [59:00, 38.69s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_57_inlegal/embedding'
Tensor saved to 'train_document/doc_57_inlegal/label'


59it [59:18, 32.21s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_58_inlegal/embedding'
Tensor saved to 'train_document/doc_58_inlegal/label'


60it [59:41, 29.70s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_59_inlegal/embedding'
Tensor saved to 'train_document/doc_59_inlegal/label'


61it [1:00:02, 27.04s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_60_inlegal/embedding'
Tensor saved to 'train_document/doc_60_inlegal/label'


62it [1:00:15, 22.82s/it]

X_train size: torch.Size([39, 1, 768])	Y_train size: torch.Size([39])
Tensor saved to 'train_document/doc_61_inlegal/embedding'
Tensor saved to 'train_document/doc_61_inlegal/label'


63it [1:02:50, 62.49s/it]

X_train size: torch.Size([208, 1, 768])	Y_train size: torch.Size([208])
Tensor saved to 'train_document/doc_62_inlegal/embedding'
Tensor saved to 'train_document/doc_62_inlegal/label'


64it [1:03:40, 58.81s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_63_inlegal/embedding'
Tensor saved to 'train_document/doc_63_inlegal/label'


65it [1:05:33, 75.04s/it]

X_train size: torch.Size([155, 1, 768])	Y_train size: torch.Size([155])
Tensor saved to 'train_document/doc_64_inlegal/embedding'
Tensor saved to 'train_document/doc_64_inlegal/label'


66it [1:06:36, 71.40s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_65_inlegal/embedding'
Tensor saved to 'train_document/doc_65_inlegal/label'


67it [1:07:11, 60.35s/it]

X_train size: torch.Size([52, 1, 768])	Y_train size: torch.Size([52])
Tensor saved to 'train_document/doc_66_inlegal/embedding'
Tensor saved to 'train_document/doc_66_inlegal/label'


68it [1:08:15, 61.39s/it]

X_train size: torch.Size([143, 1, 768])	Y_train size: torch.Size([143])
Tensor saved to 'train_document/doc_67_inlegal/embedding'
Tensor saved to 'train_document/doc_67_inlegal/label'


69it [1:10:03, 75.56s/it]

X_train size: torch.Size([142, 1, 768])	Y_train size: torch.Size([142])
Tensor saved to 'train_document/doc_68_inlegal/embedding'
Tensor saved to 'train_document/doc_68_inlegal/label'
X_train size: torch.Size([0, 1, 768])	Y_train size: torch.Size([0])
Tensor saved to 'train_document/doc_69_inlegal/embedding'
Tensor saved to 'train_document/doc_69_inlegal/label'


71it [1:10:24, 45.49s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'train_document/doc_70_inlegal/embedding'
Tensor saved to 'train_document/doc_70_inlegal/label'


72it [1:11:01, 43.31s/it]

X_train size: torch.Size([116, 1, 768])	Y_train size: torch.Size([116])
Tensor saved to 'train_document/doc_71_inlegal/embedding'
Tensor saved to 'train_document/doc_71_inlegal/label'


73it [1:12:02, 47.87s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_72_inlegal/embedding'
Tensor saved to 'train_document/doc_72_inlegal/label'


74it [1:13:15, 54.95s/it]

X_train size: torch.Size([122, 1, 768])	Y_train size: torch.Size([122])
Tensor saved to 'train_document/doc_73_inlegal/embedding'
Tensor saved to 'train_document/doc_73_inlegal/label'


75it [1:14:01, 52.44s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_74_inlegal/embedding'
Tensor saved to 'train_document/doc_74_inlegal/label'


76it [1:14:50, 51.30s/it]

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_75_inlegal/embedding'
Tensor saved to 'train_document/doc_75_inlegal/label'


77it [1:16:27, 64.73s/it]

X_train size: torch.Size([184, 1, 768])	Y_train size: torch.Size([184])
Tensor saved to 'train_document/doc_76_inlegal/embedding'
Tensor saved to 'train_document/doc_76_inlegal/label'


78it [1:17:16, 60.04s/it]

X_train size: torch.Size([144, 1, 768])	Y_train size: torch.Size([144])
Tensor saved to 'train_document/doc_77_inlegal/embedding'
Tensor saved to 'train_document/doc_77_inlegal/label'


79it [1:17:32, 47.00s/it]

X_train size: torch.Size([35, 1, 768])	Y_train size: torch.Size([35])
Tensor saved to 'train_document/doc_78_inlegal/embedding'
Tensor saved to 'train_document/doc_78_inlegal/label'


80it [1:19:15, 63.56s/it]

X_train size: torch.Size([140, 1, 768])	Y_train size: torch.Size([140])
Tensor saved to 'train_document/doc_79_inlegal/embedding'
Tensor saved to 'train_document/doc_79_inlegal/label'


81it [1:19:36, 51.02s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_80_inlegal/embedding'
Tensor saved to 'train_document/doc_80_inlegal/label'


82it [1:22:14, 82.73s/it]

X_train size: torch.Size([137, 1, 768])	Y_train size: torch.Size([137])
Tensor saved to 'train_document/doc_81_inlegal/embedding'
Tensor saved to 'train_document/doc_81_inlegal/label'


83it [1:22:26, 61.67s/it]

X_train size: torch.Size([24, 1, 768])	Y_train size: torch.Size([24])
Tensor saved to 'train_document/doc_82_inlegal/embedding'
Tensor saved to 'train_document/doc_82_inlegal/label'


84it [1:24:00, 71.24s/it]

X_train size: torch.Size([113, 1, 768])	Y_train size: torch.Size([113])
Tensor saved to 'train_document/doc_83_inlegal/embedding'
Tensor saved to 'train_document/doc_83_inlegal/label'


85it [1:24:29, 58.69s/it]

X_train size: torch.Size([105, 1, 768])	Y_train size: torch.Size([105])
Tensor saved to 'train_document/doc_84_inlegal/embedding'
Tensor saved to 'train_document/doc_84_inlegal/label'


86it [1:24:51, 47.77s/it]

X_train size: torch.Size([80, 1, 768])	Y_train size: torch.Size([80])
Tensor saved to 'train_document/doc_85_inlegal/embedding'
Tensor saved to 'train_document/doc_85_inlegal/label'


87it [1:25:37, 47.17s/it]

X_train size: torch.Size([106, 1, 768])	Y_train size: torch.Size([106])
Tensor saved to 'train_document/doc_86_inlegal/embedding'
Tensor saved to 'train_document/doc_86_inlegal/label'


88it [1:25:55, 38.51s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_87_inlegal/embedding'
Tensor saved to 'train_document/doc_87_inlegal/label'


89it [1:26:41, 40.57s/it]

X_train size: torch.Size([105, 1, 768])	Y_train size: torch.Size([105])
Tensor saved to 'train_document/doc_88_inlegal/embedding'
Tensor saved to 'train_document/doc_88_inlegal/label'


90it [1:26:58, 33.56s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_89_inlegal/embedding'
Tensor saved to 'train_document/doc_89_inlegal/label'


91it [1:28:13, 46.15s/it]

X_train size: torch.Size([136, 1, 768])	Y_train size: torch.Size([136])
Tensor saved to 'train_document/doc_90_inlegal/embedding'
Tensor saved to 'train_document/doc_90_inlegal/label'


92it [1:28:47, 42.36s/it]

X_train size: torch.Size([83, 1, 768])	Y_train size: torch.Size([83])
Tensor saved to 'train_document/doc_91_inlegal/embedding'
Tensor saved to 'train_document/doc_91_inlegal/label'


93it [1:30:36, 62.51s/it]

X_train size: torch.Size([221, 1, 768])	Y_train size: torch.Size([221])
Tensor saved to 'train_document/doc_92_inlegal/embedding'
Tensor saved to 'train_document/doc_92_inlegal/label'


94it [1:31:29, 59.66s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'train_document/doc_93_inlegal/embedding'
Tensor saved to 'train_document/doc_93_inlegal/label'


95it [1:32:41, 63.32s/it]

X_train size: torch.Size([114, 1, 768])	Y_train size: torch.Size([114])
Tensor saved to 'train_document/doc_94_inlegal/embedding'
Tensor saved to 'train_document/doc_94_inlegal/label'


96it [1:33:11, 53.38s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'train_document/doc_95_inlegal/embedding'
Tensor saved to 'train_document/doc_95_inlegal/label'


97it [1:33:31, 43.37s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'train_document/doc_96_inlegal/embedding'
Tensor saved to 'train_document/doc_96_inlegal/label'


98it [1:34:12, 42.63s/it]

X_train size: torch.Size([70, 1, 768])	Y_train size: torch.Size([70])
Tensor saved to 'train_document/doc_97_inlegal/embedding'
Tensor saved to 'train_document/doc_97_inlegal/label'


99it [1:35:57, 61.11s/it]

X_train size: torch.Size([264, 1, 768])	Y_train size: torch.Size([264])
Tensor saved to 'train_document/doc_98_inlegal/embedding'
Tensor saved to 'train_document/doc_98_inlegal/label'


100it [1:37:59, 79.51s/it]

X_train size: torch.Size([167, 1, 768])	Y_train size: torch.Size([167])
Tensor saved to 'train_document/doc_99_inlegal/embedding'
Tensor saved to 'train_document/doc_99_inlegal/label'


101it [1:38:15, 60.59s/it]

X_train size: torch.Size([49, 1, 768])	Y_train size: torch.Size([49])
Tensor saved to 'train_document/doc_100_inlegal/embedding'
Tensor saved to 'train_document/doc_100_inlegal/label'


102it [1:38:48, 52.31s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_101_inlegal/embedding'
Tensor saved to 'train_document/doc_101_inlegal/label'


103it [1:39:06, 41.76s/it]

X_train size: torch.Size([74, 1, 768])	Y_train size: torch.Size([74])
Tensor saved to 'train_document/doc_102_inlegal/embedding'
Tensor saved to 'train_document/doc_102_inlegal/label'


104it [1:40:25, 53.03s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_103_inlegal/embedding'
Tensor saved to 'train_document/doc_103_inlegal/label'


105it [1:41:18, 53.09s/it]

X_train size: torch.Size([147, 1, 768])	Y_train size: torch.Size([147])
Tensor saved to 'train_document/doc_104_inlegal/embedding'
Tensor saved to 'train_document/doc_104_inlegal/label'


106it [1:42:58, 67.10s/it]

X_train size: torch.Size([180, 1, 768])	Y_train size: torch.Size([180])
Tensor saved to 'train_document/doc_105_inlegal/embedding'
Tensor saved to 'train_document/doc_105_inlegal/label'


107it [1:43:27, 55.64s/it]

X_train size: torch.Size([71, 1, 768])	Y_train size: torch.Size([71])
Tensor saved to 'train_document/doc_106_inlegal/embedding'
Tensor saved to 'train_document/doc_106_inlegal/label'


108it [1:44:25, 56.28s/it]

X_train size: torch.Size([174, 1, 768])	Y_train size: torch.Size([174])
Tensor saved to 'train_document/doc_107_inlegal/embedding'
Tensor saved to 'train_document/doc_107_inlegal/label'


109it [1:45:14, 54.10s/it]

X_train size: torch.Size([126, 1, 768])	Y_train size: torch.Size([126])
Tensor saved to 'train_document/doc_108_inlegal/embedding'
Tensor saved to 'train_document/doc_108_inlegal/label'


110it [1:45:53, 49.79s/it]

X_train size: torch.Size([129, 1, 768])	Y_train size: torch.Size([129])
Tensor saved to 'train_document/doc_109_inlegal/embedding'
Tensor saved to 'train_document/doc_109_inlegal/label'


111it [1:46:06, 38.81s/it]

X_train size: torch.Size([32, 1, 768])	Y_train size: torch.Size([32])
Tensor saved to 'train_document/doc_110_inlegal/embedding'
Tensor saved to 'train_document/doc_110_inlegal/label'


112it [1:46:43, 38.12s/it]

X_train size: torch.Size([121, 1, 768])	Y_train size: torch.Size([121])
Tensor saved to 'train_document/doc_111_inlegal/embedding'
Tensor saved to 'train_document/doc_111_inlegal/label'


113it [1:47:09, 34.59s/it]

X_train size: torch.Size([60, 1, 768])	Y_train size: torch.Size([60])
Tensor saved to 'train_document/doc_112_inlegal/embedding'
Tensor saved to 'train_document/doc_112_inlegal/label'


114it [1:48:19, 45.19s/it]

X_train size: torch.Size([94, 1, 768])	Y_train size: torch.Size([94])
Tensor saved to 'train_document/doc_113_inlegal/embedding'
Tensor saved to 'train_document/doc_113_inlegal/label'


115it [1:49:03, 44.67s/it]

X_train size: torch.Size([121, 1, 768])	Y_train size: torch.Size([121])
Tensor saved to 'train_document/doc_114_inlegal/embedding'
Tensor saved to 'train_document/doc_114_inlegal/label'


116it [1:49:49, 45.13s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_115_inlegal/embedding'
Tensor saved to 'train_document/doc_115_inlegal/label'


117it [1:51:40, 64.94s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_116_inlegal/embedding'
Tensor saved to 'train_document/doc_116_inlegal/label'


118it [1:53:29, 78.27s/it]

X_train size: torch.Size([308, 1, 768])	Y_train size: torch.Size([308])
Tensor saved to 'train_document/doc_117_inlegal/embedding'
Tensor saved to 'train_document/doc_117_inlegal/label'


119it [1:53:52, 61.44s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_118_inlegal/embedding'
Tensor saved to 'train_document/doc_118_inlegal/label'


120it [1:54:42, 58.19s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'train_document/doc_119_inlegal/embedding'
Tensor saved to 'train_document/doc_119_inlegal/label'


121it [1:55:10, 48.92s/it]

X_train size: torch.Size([71, 1, 768])	Y_train size: torch.Size([71])
Tensor saved to 'train_document/doc_120_inlegal/embedding'
Tensor saved to 'train_document/doc_120_inlegal/label'


122it [1:55:36, 42.05s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_121_inlegal/embedding'
Tensor saved to 'train_document/doc_121_inlegal/label'


123it [1:56:56, 53.70s/it]

X_train size: torch.Size([201, 1, 768])	Y_train size: torch.Size([201])
Tensor saved to 'train_document/doc_122_inlegal/embedding'
Tensor saved to 'train_document/doc_122_inlegal/label'


124it [1:57:11, 41.98s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_123_inlegal/embedding'
Tensor saved to 'train_document/doc_123_inlegal/label'


125it [1:58:52, 59.69s/it]

X_train size: torch.Size([168, 1, 768])	Y_train size: torch.Size([168])
Tensor saved to 'train_document/doc_124_inlegal/embedding'
Tensor saved to 'train_document/doc_124_inlegal/label'


126it [2:00:15, 66.66s/it]

X_train size: torch.Size([213, 1, 768])	Y_train size: torch.Size([213])
Tensor saved to 'train_document/doc_125_inlegal/embedding'
Tensor saved to 'train_document/doc_125_inlegal/label'


127it [2:00:54, 58.31s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'train_document/doc_126_inlegal/embedding'
Tensor saved to 'train_document/doc_126_inlegal/label'


128it [2:01:48, 57.19s/it]

X_train size: torch.Size([85, 1, 768])	Y_train size: torch.Size([85])
Tensor saved to 'train_document/doc_127_inlegal/embedding'
Tensor saved to 'train_document/doc_127_inlegal/label'


129it [2:02:57, 60.54s/it]

X_train size: torch.Size([174, 1, 768])	Y_train size: torch.Size([174])
Tensor saved to 'train_document/doc_128_inlegal/embedding'
Tensor saved to 'train_document/doc_128_inlegal/label'


130it [2:03:18, 48.71s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_129_inlegal/embedding'
Tensor saved to 'train_document/doc_129_inlegal/label'


131it [2:04:02, 47.31s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_130_inlegal/embedding'
Tensor saved to 'train_document/doc_130_inlegal/label'


132it [2:05:17, 55.69s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'train_document/doc_131_inlegal/embedding'
Tensor saved to 'train_document/doc_131_inlegal/label'


133it [2:06:19, 57.55s/it]

X_train size: torch.Size([126, 1, 768])	Y_train size: torch.Size([126])
Tensor saved to 'train_document/doc_132_inlegal/embedding'
Tensor saved to 'train_document/doc_132_inlegal/label'


134it [2:07:06, 54.36s/it]

X_train size: torch.Size([158, 1, 768])	Y_train size: torch.Size([158])
Tensor saved to 'train_document/doc_133_inlegal/embedding'
Tensor saved to 'train_document/doc_133_inlegal/label'


135it [2:07:48, 50.80s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_134_inlegal/embedding'
Tensor saved to 'train_document/doc_134_inlegal/label'


136it [2:08:47, 53.10s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'train_document/doc_135_inlegal/embedding'
Tensor saved to 'train_document/doc_135_inlegal/label'


137it [2:09:27, 49.14s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_136_inlegal/embedding'
Tensor saved to 'train_document/doc_136_inlegal/label'


138it [2:10:24, 51.47s/it]

X_train size: torch.Size([98, 1, 768])	Y_train size: torch.Size([98])
Tensor saved to 'train_document/doc_137_inlegal/embedding'
Tensor saved to 'train_document/doc_137_inlegal/label'


139it [2:11:37, 57.98s/it]

X_train size: torch.Size([161, 1, 768])	Y_train size: torch.Size([161])
Tensor saved to 'train_document/doc_138_inlegal/embedding'
Tensor saved to 'train_document/doc_138_inlegal/label'


140it [2:12:11, 50.96s/it]

X_train size: torch.Size([93, 1, 768])	Y_train size: torch.Size([93])
Tensor saved to 'train_document/doc_139_inlegal/embedding'
Tensor saved to 'train_document/doc_139_inlegal/label'


141it [2:12:37, 43.34s/it]

X_train size: torch.Size([67, 1, 768])	Y_train size: torch.Size([67])
Tensor saved to 'train_document/doc_140_inlegal/embedding'
Tensor saved to 'train_document/doc_140_inlegal/label'


142it [2:14:02, 55.89s/it]

X_train size: torch.Size([173, 1, 768])	Y_train size: torch.Size([173])
Tensor saved to 'train_document/doc_141_inlegal/embedding'
Tensor saved to 'train_document/doc_141_inlegal/label'


143it [2:17:02, 93.03s/it]

X_train size: torch.Size([234, 1, 768])	Y_train size: torch.Size([234])
Tensor saved to 'train_document/doc_142_inlegal/embedding'
Tensor saved to 'train_document/doc_142_inlegal/label'


144it [2:20:55, 135.13s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_143_inlegal/embedding'
Tensor saved to 'train_document/doc_143_inlegal/label'


145it [2:21:56, 112.77s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_144_inlegal/embedding'
Tensor saved to 'train_document/doc_144_inlegal/label'


146it [2:22:15, 84.69s/it] 

X_train size: torch.Size([42, 1, 768])	Y_train size: torch.Size([42])
Tensor saved to 'train_document/doc_145_inlegal/embedding'
Tensor saved to 'train_document/doc_145_inlegal/label'


147it [2:23:13, 76.76s/it]

X_train size: torch.Size([164, 1, 768])	Y_train size: torch.Size([164])
Tensor saved to 'train_document/doc_146_inlegal/embedding'
Tensor saved to 'train_document/doc_146_inlegal/label'


148it [2:23:32, 59.28s/it]

X_train size: torch.Size([36, 1, 768])	Y_train size: torch.Size([36])
Tensor saved to 'train_document/doc_147_inlegal/embedding'
Tensor saved to 'train_document/doc_147_inlegal/label'


149it [2:23:54, 48.28s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'train_document/doc_148_inlegal/embedding'
Tensor saved to 'train_document/doc_148_inlegal/label'


150it [2:27:15, 93.92s/it]

X_train size: torch.Size([386, 1, 768])	Y_train size: torch.Size([386])
Tensor saved to 'train_document/doc_149_inlegal/embedding'
Tensor saved to 'train_document/doc_149_inlegal/label'


151it [2:28:24, 86.62s/it]

X_train size: torch.Size([119, 1, 768])	Y_train size: torch.Size([119])
Tensor saved to 'train_document/doc_150_inlegal/embedding'
Tensor saved to 'train_document/doc_150_inlegal/label'


152it [2:29:09, 74.00s/it]

X_train size: torch.Size([83, 1, 768])	Y_train size: torch.Size([83])
Tensor saved to 'train_document/doc_151_inlegal/embedding'
Tensor saved to 'train_document/doc_151_inlegal/label'


153it [2:30:58, 84.57s/it]

X_train size: torch.Size([200, 1, 768])	Y_train size: torch.Size([200])
Tensor saved to 'train_document/doc_152_inlegal/embedding'
Tensor saved to 'train_document/doc_152_inlegal/label'


154it [2:33:45, 109.29s/it]

X_train size: torch.Size([159, 1, 768])	Y_train size: torch.Size([159])
Tensor saved to 'train_document/doc_153_inlegal/embedding'
Tensor saved to 'train_document/doc_153_inlegal/label'


155it [2:34:12, 84.42s/it] 

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_154_inlegal/embedding'
Tensor saved to 'train_document/doc_154_inlegal/label'


156it [2:34:42, 68.16s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_155_inlegal/embedding'
Tensor saved to 'train_document/doc_155_inlegal/label'


157it [2:35:09, 55.97s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_156_inlegal/embedding'
Tensor saved to 'train_document/doc_156_inlegal/label'


158it [2:36:04, 55.58s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_157_inlegal/embedding'
Tensor saved to 'train_document/doc_157_inlegal/label'


159it [2:37:07, 57.70s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'train_document/doc_158_inlegal/embedding'
Tensor saved to 'train_document/doc_158_inlegal/label'


160it [2:37:38, 49.94s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_159_inlegal/embedding'
Tensor saved to 'train_document/doc_159_inlegal/label'


161it [2:38:40, 53.52s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'train_document/doc_160_inlegal/embedding'
Tensor saved to 'train_document/doc_160_inlegal/label'


162it [2:38:55, 41.78s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_161_inlegal/embedding'
Tensor saved to 'train_document/doc_161_inlegal/label'


163it [2:39:35, 41.28s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_162_inlegal/embedding'
Tensor saved to 'train_document/doc_162_inlegal/label'


164it [2:40:16, 41.33s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_163_inlegal/embedding'
Tensor saved to 'train_document/doc_163_inlegal/label'


165it [2:41:29, 50.60s/it]

X_train size: torch.Size([190, 1, 768])	Y_train size: torch.Size([190])
Tensor saved to 'train_document/doc_164_inlegal/embedding'
Tensor saved to 'train_document/doc_164_inlegal/label'


166it [2:41:45, 40.36s/it]

X_train size: torch.Size([33, 1, 768])	Y_train size: torch.Size([33])
Tensor saved to 'train_document/doc_165_inlegal/embedding'
Tensor saved to 'train_document/doc_165_inlegal/label'


167it [2:44:03, 69.81s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_166_inlegal/embedding'
Tensor saved to 'train_document/doc_166_inlegal/label'


168it [2:44:51, 63.18s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'train_document/doc_167_inlegal/embedding'
Tensor saved to 'train_document/doc_167_inlegal/label'


169it [2:45:30, 55.74s/it]

X_train size: torch.Size([52, 1, 768])	Y_train size: torch.Size([52])
Tensor saved to 'train_document/doc_168_inlegal/embedding'
Tensor saved to 'train_document/doc_168_inlegal/label'


170it [2:46:16, 52.88s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_169_inlegal/embedding'
Tensor saved to 'train_document/doc_169_inlegal/label'


171it [2:47:48, 64.53s/it]

X_train size: torch.Size([251, 1, 768])	Y_train size: torch.Size([251])
Tensor saved to 'train_document/doc_170_inlegal/embedding'
Tensor saved to 'train_document/doc_170_inlegal/label'


172it [2:48:41, 61.15s/it]

X_train size: torch.Size([117, 1, 768])	Y_train size: torch.Size([117])
Tensor saved to 'train_document/doc_171_inlegal/embedding'
Tensor saved to 'train_document/doc_171_inlegal/label'


173it [2:49:32, 58.19s/it]

X_train size: torch.Size([146, 1, 768])	Y_train size: torch.Size([146])
Tensor saved to 'train_document/doc_172_inlegal/embedding'
Tensor saved to 'train_document/doc_172_inlegal/label'


174it [2:50:03, 49.94s/it]

X_train size: torch.Size([84, 1, 768])	Y_train size: torch.Size([84])
Tensor saved to 'train_document/doc_173_inlegal/embedding'
Tensor saved to 'train_document/doc_173_inlegal/label'


175it [2:52:23, 77.15s/it]

X_train size: torch.Size([223, 1, 768])	Y_train size: torch.Size([223])
Tensor saved to 'train_document/doc_174_inlegal/embedding'
Tensor saved to 'train_document/doc_174_inlegal/label'


176it [2:52:50, 62.11s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_175_inlegal/embedding'
Tensor saved to 'train_document/doc_175_inlegal/label'


177it [2:55:55, 98.97s/it]

X_train size: torch.Size([248, 1, 768])	Y_train size: torch.Size([248])
Tensor saved to 'train_document/doc_176_inlegal/embedding'
Tensor saved to 'train_document/doc_176_inlegal/label'


178it [2:56:31, 80.09s/it]

X_train size: torch.Size([69, 1, 768])	Y_train size: torch.Size([69])
Tensor saved to 'train_document/doc_177_inlegal/embedding'
Tensor saved to 'train_document/doc_177_inlegal/label'


179it [2:57:31, 73.86s/it]

X_train size: torch.Size([89, 1, 768])	Y_train size: torch.Size([89])
Tensor saved to 'train_document/doc_178_inlegal/embedding'
Tensor saved to 'train_document/doc_178_inlegal/label'


180it [2:58:13, 64.29s/it]

X_train size: torch.Size([124, 1, 768])	Y_train size: torch.Size([124])
Tensor saved to 'train_document/doc_179_inlegal/embedding'
Tensor saved to 'train_document/doc_179_inlegal/label'


181it [2:58:29, 49.89s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_180_inlegal/embedding'
Tensor saved to 'train_document/doc_180_inlegal/label'


182it [2:59:06, 46.01s/it]

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_181_inlegal/embedding'
Tensor saved to 'train_document/doc_181_inlegal/label'


183it [2:59:48, 44.88s/it]

X_train size: torch.Size([131, 1, 768])	Y_train size: torch.Size([131])
Tensor saved to 'train_document/doc_182_inlegal/embedding'
Tensor saved to 'train_document/doc_182_inlegal/label'


184it [3:00:43, 47.89s/it]

X_train size: torch.Size([117, 1, 768])	Y_train size: torch.Size([117])
Tensor saved to 'train_document/doc_183_inlegal/embedding'
Tensor saved to 'train_document/doc_183_inlegal/label'


185it [3:00:55, 37.04s/it]

X_train size: torch.Size([27, 1, 768])	Y_train size: torch.Size([27])
Tensor saved to 'train_document/doc_184_inlegal/embedding'
Tensor saved to 'train_document/doc_184_inlegal/label'


186it [3:01:41, 39.68s/it]

X_train size: torch.Size([106, 1, 768])	Y_train size: torch.Size([106])
Tensor saved to 'train_document/doc_185_inlegal/embedding'
Tensor saved to 'train_document/doc_185_inlegal/label'


187it [3:03:13, 55.48s/it]

X_train size: torch.Size([244, 1, 768])	Y_train size: torch.Size([244])
Tensor saved to 'train_document/doc_186_inlegal/embedding'
Tensor saved to 'train_document/doc_186_inlegal/label'


188it [3:03:38, 46.33s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_187_inlegal/embedding'
Tensor saved to 'train_document/doc_187_inlegal/label'


189it [3:04:11, 42.33s/it]

X_train size: torch.Size([78, 1, 768])	Y_train size: torch.Size([78])
Tensor saved to 'train_document/doc_188_inlegal/embedding'
Tensor saved to 'train_document/doc_188_inlegal/label'


190it [3:05:24, 51.43s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'train_document/doc_189_inlegal/embedding'
Tensor saved to 'train_document/doc_189_inlegal/label'


191it [3:05:39, 40.72s/it]

X_train size: torch.Size([36, 1, 768])	Y_train size: torch.Size([36])
Tensor saved to 'train_document/doc_190_inlegal/embedding'
Tensor saved to 'train_document/doc_190_inlegal/label'


192it [3:06:27, 42.86s/it]

X_train size: torch.Size([128, 1, 768])	Y_train size: torch.Size([128])
Tensor saved to 'train_document/doc_191_inlegal/embedding'
Tensor saved to 'train_document/doc_191_inlegal/label'


193it [3:07:56, 56.48s/it]

X_train size: torch.Size([154, 1, 768])	Y_train size: torch.Size([154])
Tensor saved to 'train_document/doc_192_inlegal/embedding'
Tensor saved to 'train_document/doc_192_inlegal/label'


194it [3:08:30, 49.78s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_193_inlegal/embedding'
Tensor saved to 'train_document/doc_193_inlegal/label'


195it [3:08:52, 41.61s/it]

X_train size: torch.Size([61, 1, 768])	Y_train size: torch.Size([61])
Tensor saved to 'train_document/doc_194_inlegal/embedding'
Tensor saved to 'train_document/doc_194_inlegal/label'


196it [3:11:34, 77.64s/it]

X_train size: torch.Size([245, 1, 768])	Y_train size: torch.Size([245])
Tensor saved to 'train_document/doc_195_inlegal/embedding'
Tensor saved to 'train_document/doc_195_inlegal/label'


197it [3:12:27, 70.23s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'train_document/doc_196_inlegal/embedding'
Tensor saved to 'train_document/doc_196_inlegal/label'


198it [3:13:03, 60.04s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_197_inlegal/embedding'
Tensor saved to 'train_document/doc_197_inlegal/label'


199it [3:13:46, 54.93s/it]

X_train size: torch.Size([120, 1, 768])	Y_train size: torch.Size([120])
Tensor saved to 'train_document/doc_198_inlegal/embedding'
Tensor saved to 'train_document/doc_198_inlegal/label'


200it [3:14:08, 44.90s/it]

X_train size: torch.Size([89, 1, 768])	Y_train size: torch.Size([89])
Tensor saved to 'train_document/doc_199_inlegal/embedding'
Tensor saved to 'train_document/doc_199_inlegal/label'


201it [3:14:39, 40.76s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_200_inlegal/embedding'
Tensor saved to 'train_document/doc_200_inlegal/label'


202it [3:15:09, 37.64s/it]

X_train size: torch.Size([74, 1, 768])	Y_train size: torch.Size([74])
Tensor saved to 'train_document/doc_201_inlegal/embedding'
Tensor saved to 'train_document/doc_201_inlegal/label'


203it [3:15:55, 40.20s/it]

X_train size: torch.Size([92, 1, 768])	Y_train size: torch.Size([92])
Tensor saved to 'train_document/doc_202_inlegal/embedding'
Tensor saved to 'train_document/doc_202_inlegal/label'


204it [3:17:33, 57.44s/it]

X_train size: torch.Size([240, 1, 768])	Y_train size: torch.Size([240])
Tensor saved to 'train_document/doc_203_inlegal/embedding'
Tensor saved to 'train_document/doc_203_inlegal/label'


205it [3:18:57, 65.40s/it]

X_train size: torch.Size([102, 1, 768])	Y_train size: torch.Size([102])
Tensor saved to 'train_document/doc_204_inlegal/embedding'
Tensor saved to 'train_document/doc_204_inlegal/label'


206it [3:19:48, 61.19s/it]

X_train size: torch.Size([107, 1, 768])	Y_train size: torch.Size([107])
Tensor saved to 'train_document/doc_205_inlegal/embedding'
Tensor saved to 'train_document/doc_205_inlegal/label'


207it [3:20:10, 49.33s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_206_inlegal/embedding'
Tensor saved to 'train_document/doc_206_inlegal/label'


208it [3:21:48, 63.97s/it]

X_train size: torch.Size([168, 1, 768])	Y_train size: torch.Size([168])
Tensor saved to 'train_document/doc_207_inlegal/embedding'
Tensor saved to 'train_document/doc_207_inlegal/label'


209it [3:23:58, 83.68s/it]

X_train size: torch.Size([181, 1, 768])	Y_train size: torch.Size([181])
Tensor saved to 'train_document/doc_208_inlegal/embedding'
Tensor saved to 'train_document/doc_208_inlegal/label'


210it [3:24:12, 62.78s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_209_inlegal/embedding'
Tensor saved to 'train_document/doc_209_inlegal/label'


211it [3:26:11, 79.64s/it]

X_train size: torch.Size([248, 1, 768])	Y_train size: torch.Size([248])
Tensor saved to 'train_document/doc_210_inlegal/embedding'
Tensor saved to 'train_document/doc_210_inlegal/label'


212it [3:27:40, 82.60s/it]

X_train size: torch.Size([141, 1, 768])	Y_train size: torch.Size([141])
Tensor saved to 'train_document/doc_211_inlegal/embedding'
Tensor saved to 'train_document/doc_211_inlegal/label'


213it [3:28:08, 66.05s/it]

X_train size: torch.Size([47, 1, 768])	Y_train size: torch.Size([47])
Tensor saved to 'train_document/doc_212_inlegal/embedding'
Tensor saved to 'train_document/doc_212_inlegal/label'


214it [3:28:43, 56.99s/it]

X_train size: torch.Size([81, 1, 768])	Y_train size: torch.Size([81])
Tensor saved to 'train_document/doc_213_inlegal/embedding'
Tensor saved to 'train_document/doc_213_inlegal/label'


215it [3:29:34, 55.11s/it]

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_214_inlegal/embedding'
Tensor saved to 'train_document/doc_214_inlegal/label'


216it [3:29:50, 43.46s/it]

X_train size: torch.Size([60, 1, 768])	Y_train size: torch.Size([60])
Tensor saved to 'train_document/doc_215_inlegal/embedding'
Tensor saved to 'train_document/doc_215_inlegal/label'


217it [3:31:11, 54.47s/it]

X_train size: torch.Size([179, 1, 768])	Y_train size: torch.Size([179])
Tensor saved to 'train_document/doc_216_inlegal/embedding'
Tensor saved to 'train_document/doc_216_inlegal/label'


218it [3:31:49, 49.58s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_217_inlegal/embedding'
Tensor saved to 'train_document/doc_217_inlegal/label'


219it [3:32:58, 55.43s/it]

X_train size: torch.Size([176, 1, 768])	Y_train size: torch.Size([176])
Tensor saved to 'train_document/doc_218_inlegal/embedding'
Tensor saved to 'train_document/doc_218_inlegal/label'


220it [3:35:29, 84.02s/it]

X_train size: torch.Size([283, 1, 768])	Y_train size: torch.Size([283])
Tensor saved to 'train_document/doc_219_inlegal/embedding'
Tensor saved to 'train_document/doc_219_inlegal/label'


221it [3:37:10, 89.38s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_220_inlegal/embedding'
Tensor saved to 'train_document/doc_220_inlegal/label'


222it [3:38:49, 92.11s/it]

X_train size: torch.Size([263, 1, 768])	Y_train size: torch.Size([263])
Tensor saved to 'train_document/doc_221_inlegal/embedding'
Tensor saved to 'train_document/doc_221_inlegal/label'


223it [3:39:10, 70.71s/it]

X_train size: torch.Size([58, 1, 768])	Y_train size: torch.Size([58])
Tensor saved to 'train_document/doc_222_inlegal/embedding'
Tensor saved to 'train_document/doc_222_inlegal/label'


224it [3:41:33, 92.59s/it]

X_train size: torch.Size([193, 1, 768])	Y_train size: torch.Size([193])
Tensor saved to 'train_document/doc_223_inlegal/embedding'
Tensor saved to 'train_document/doc_223_inlegal/label'


225it [3:42:25, 80.36s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'train_document/doc_224_inlegal/embedding'
Tensor saved to 'train_document/doc_224_inlegal/label'
X_train size: torch.Size([0, 1, 768])	Y_train size: torch.Size([0])
Tensor saved to 'train_document/doc_225_inlegal/embedding'
Tensor saved to 'train_document/doc_225_inlegal/label'


227it [3:43:08, 53.04s/it]

X_train size: torch.Size([148, 1, 768])	Y_train size: torch.Size([148])
Tensor saved to 'train_document/doc_226_inlegal/embedding'
Tensor saved to 'train_document/doc_226_inlegal/label'


228it [3:43:26, 44.39s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_227_inlegal/embedding'
Tensor saved to 'train_document/doc_227_inlegal/label'


229it [3:44:38, 51.59s/it]

X_train size: torch.Size([255, 1, 768])	Y_train size: torch.Size([255])
Tensor saved to 'train_document/doc_228_inlegal/embedding'
Tensor saved to 'train_document/doc_228_inlegal/label'


230it [3:45:03, 44.33s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'train_document/doc_229_inlegal/embedding'
Tensor saved to 'train_document/doc_229_inlegal/label'


231it [3:45:24, 37.92s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_230_inlegal/embedding'
Tensor saved to 'train_document/doc_230_inlegal/label'


232it [3:47:02, 55.04s/it]

X_train size: torch.Size([214, 1, 768])	Y_train size: torch.Size([214])
Tensor saved to 'train_document/doc_231_inlegal/embedding'
Tensor saved to 'train_document/doc_231_inlegal/label'


233it [3:48:35, 66.05s/it]

X_train size: torch.Size([246, 1, 768])	Y_train size: torch.Size([246])
Tensor saved to 'train_document/doc_232_inlegal/embedding'
Tensor saved to 'train_document/doc_232_inlegal/label'


234it [3:49:40, 65.73s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_233_inlegal/embedding'
Tensor saved to 'train_document/doc_233_inlegal/label'


235it [3:50:36, 62.78s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_234_inlegal/embedding'
Tensor saved to 'train_document/doc_234_inlegal/label'


236it [3:50:58, 50.91s/it]

X_train size: torch.Size([67, 1, 768])	Y_train size: torch.Size([67])
Tensor saved to 'train_document/doc_235_inlegal/embedding'
Tensor saved to 'train_document/doc_235_inlegal/label'


237it [3:51:15, 40.79s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_236_inlegal/embedding'
Tensor saved to 'train_document/doc_236_inlegal/label'


238it [3:51:40, 36.11s/it]

X_train size: torch.Size([56, 1, 768])	Y_train size: torch.Size([56])
Tensor saved to 'train_document/doc_237_inlegal/embedding'
Tensor saved to 'train_document/doc_237_inlegal/label'


239it [3:53:22, 55.58s/it]

X_train size: torch.Size([212, 1, 768])	Y_train size: torch.Size([212])
Tensor saved to 'train_document/doc_238_inlegal/embedding'
Tensor saved to 'train_document/doc_238_inlegal/label'


240it [3:55:17, 73.44s/it]

X_train size: torch.Size([215, 1, 768])	Y_train size: torch.Size([215])
Tensor saved to 'train_document/doc_239_inlegal/embedding'
Tensor saved to 'train_document/doc_239_inlegal/label'


241it [3:55:53, 62.15s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_240_inlegal/embedding'
Tensor saved to 'train_document/doc_240_inlegal/label'


242it [3:56:06, 47.53s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_241_inlegal/embedding'
Tensor saved to 'train_document/doc_241_inlegal/label'


243it [3:56:40, 43.47s/it]

X_train size: torch.Size([112, 1, 768])	Y_train size: torch.Size([112])
Tensor saved to 'train_document/doc_242_inlegal/embedding'
Tensor saved to 'train_document/doc_242_inlegal/label'


244it [3:57:03, 37.45s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'train_document/doc_243_inlegal/embedding'
Tensor saved to 'train_document/doc_243_inlegal/label'


245it [3:58:13, 47.13s/it]

X_train size: torch.Size([90, 1, 768])	Y_train size: torch.Size([90])
Tensor saved to 'train_document/doc_244_inlegal/embedding'
Tensor saved to 'train_document/doc_244_inlegal/label'


246it [3:59:33, 58.43s/it]


X_train size: torch.Size([122, 1, 768])	Y_train size: torch.Size([122])
Tensor saved to 'train_document/doc_245_inlegal/embedding'
Tensor saved to 'train_document/doc_245_inlegal/label'


1it [01:09, 69.86s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'test_document/doc_0_inlegal/embedding'
Tensor saved to 'test_document/doc_0_inlegal/label'


2it [02:05, 61.62s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'test_document/doc_1_inlegal/embedding'
Tensor saved to 'test_document/doc_1_inlegal/label'


3it [03:00, 58.62s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'test_document/doc_2_inlegal/embedding'
Tensor saved to 'test_document/doc_2_inlegal/label'


4it [03:24, 44.86s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'test_document/doc_3_inlegal/embedding'
Tensor saved to 'test_document/doc_3_inlegal/label'


5it [05:14, 68.28s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'test_document/doc_4_inlegal/embedding'
Tensor saved to 'test_document/doc_4_inlegal/label'


6it [05:33, 51.56s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'test_document/doc_5_inlegal/embedding'
Tensor saved to 'test_document/doc_5_inlegal/label'


7it [06:00, 43.40s/it]

X_train size: torch.Size([68, 1, 768])	Y_train size: torch.Size([68])
Tensor saved to 'test_document/doc_6_inlegal/embedding'
Tensor saved to 'test_document/doc_6_inlegal/label'


8it [06:31, 39.58s/it]

X_train size: torch.Size([113, 1, 768])	Y_train size: torch.Size([113])
Tensor saved to 'test_document/doc_7_inlegal/embedding'
Tensor saved to 'test_document/doc_7_inlegal/label'


9it [08:17, 60.29s/it]

X_train size: torch.Size([199, 1, 768])	Y_train size: torch.Size([199])
Tensor saved to 'test_document/doc_8_inlegal/embedding'
Tensor saved to 'test_document/doc_8_inlegal/label'


10it [09:02, 55.52s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'test_document/doc_9_inlegal/embedding'
Tensor saved to 'test_document/doc_9_inlegal/label'


11it [09:25, 45.63s/it]

X_train size: torch.Size([76, 1, 768])	Y_train size: torch.Size([76])
Tensor saved to 'test_document/doc_10_inlegal/embedding'
Tensor saved to 'test_document/doc_10_inlegal/label'


12it [09:53, 40.29s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'test_document/doc_11_inlegal/embedding'
Tensor saved to 'test_document/doc_11_inlegal/label'


13it [11:08, 50.71s/it]

X_train size: torch.Size([209, 1, 768])	Y_train size: torch.Size([209])
Tensor saved to 'test_document/doc_12_inlegal/embedding'
Tensor saved to 'test_document/doc_12_inlegal/label'


14it [11:45, 46.67s/it]

X_train size: torch.Size([135, 1, 768])	Y_train size: torch.Size([135])
Tensor saved to 'test_document/doc_13_inlegal/embedding'
Tensor saved to 'test_document/doc_13_inlegal/label'


15it [12:13, 40.95s/it]

X_train size: torch.Size([64, 1, 768])	Y_train size: torch.Size([64])
Tensor saved to 'test_document/doc_14_inlegal/embedding'
Tensor saved to 'test_document/doc_14_inlegal/label'


16it [12:57, 42.02s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'test_document/doc_15_inlegal/embedding'
Tensor saved to 'test_document/doc_15_inlegal/label'


17it [13:36, 41.11s/it]

X_train size: torch.Size([98, 1, 768])	Y_train size: torch.Size([98])
Tensor saved to 'test_document/doc_16_inlegal/embedding'
Tensor saved to 'test_document/doc_16_inlegal/label'


18it [14:28, 44.34s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'test_document/doc_17_inlegal/embedding'
Tensor saved to 'test_document/doc_17_inlegal/label'


19it [14:58, 40.00s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'test_document/doc_18_inlegal/embedding'
Tensor saved to 'test_document/doc_18_inlegal/label'


20it [16:14, 50.95s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'test_document/doc_19_inlegal/embedding'
Tensor saved to 'test_document/doc_19_inlegal/label'


21it [16:34, 41.52s/it]

X_train size: torch.Size([46, 1, 768])	Y_train size: torch.Size([46])
Tensor saved to 'test_document/doc_20_inlegal/embedding'
Tensor saved to 'test_document/doc_20_inlegal/label'


22it [17:04, 38.03s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'test_document/doc_21_inlegal/embedding'
Tensor saved to 'test_document/doc_21_inlegal/label'


23it [17:23, 32.49s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'test_document/doc_22_inlegal/embedding'
Tensor saved to 'test_document/doc_22_inlegal/label'


24it [17:34, 25.96s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'test_document/doc_23_inlegal/embedding'
Tensor saved to 'test_document/doc_23_inlegal/label'


25it [18:18, 31.28s/it]

X_train size: torch.Size([112, 1, 768])	Y_train size: torch.Size([112])
Tensor saved to 'test_document/doc_24_inlegal/embedding'
Tensor saved to 'test_document/doc_24_inlegal/label'


26it [19:37, 45.82s/it]

X_train size: torch.Size([186, 1, 768])	Y_train size: torch.Size([186])
Tensor saved to 'test_document/doc_25_inlegal/embedding'
Tensor saved to 'test_document/doc_25_inlegal/label'


27it [19:55, 37.24s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'test_document/doc_26_inlegal/embedding'
Tensor saved to 'test_document/doc_26_inlegal/label'


28it [20:09, 30.42s/it]

X_train size: torch.Size([43, 1, 768])	Y_train size: torch.Size([43])
Tensor saved to 'test_document/doc_27_inlegal/embedding'
Tensor saved to 'test_document/doc_27_inlegal/label'


29it [20:26, 42.30s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'test_document/doc_28_inlegal/embedding'
Tensor saved to 'test_document/doc_28_inlegal/label'





## BiLSTM

In [51]:
model5 = BiLSTM(hidden_size=128, dropout= 0.3, output_size= 11)
optimizer = torch.optim.Adam(model5.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model5.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/embedding")
        TRAIN_labels = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model5(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.05:
        break

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:19, 12.35it/s]

100%|██████████| 246/246 [00:11<00:00, 20.92it/s]


Epoch: 1 	 Loss: 1.42305


100%|██████████| 246/246 [00:10<00:00, 23.23it/s]


Epoch: 2 	 Loss: 0.84203


100%|██████████| 246/246 [00:08<00:00, 27.97it/s]


Epoch: 3 	 Loss: 0.66003


100%|██████████| 246/246 [00:11<00:00, 22.28it/s]


Epoch: 4 	 Loss: 0.55399


100%|██████████| 246/246 [00:09<00:00, 24.94it/s]


Epoch: 5 	 Loss: 0.46063


 70%|██████▉   | 171/246 [00:06<00:02, 34.84it/s]

In [None]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"test_document/doc_{i}_inlegal/embedding")
    TEST_labels = load_tensor(filepath=f"test_document/doc_{i}_inlegal/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model5, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.78586724 0.64       0.73699422 0.69090909 0.86       0.53896104
 0.9960396  0.47826087 0.59574468 0.88888889 0.38297872] 
 Average acccuracy: 0.6904222139891361
F1 Scores: [0.76898895 0.5363128  0.80251765 0.72380947 0.88205123 0.54248361
 0.99702671 0.46808506 0.34355824 0.85207095 0.47368416] 
 Average F1: 0.671871712604626


## CNN-BiLSTM

In [None]:
model6 = BiLSTM(hidden_size=128, dropout= 0.3, output_size= 11)
optimizer = torch.optim.Adam(model6.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model6.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/embedding")
        TRAIN_labels = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model6(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.05:
        break

In [None]:
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"test_document/doc_{i}_inlegal/embedding")
    TEST_labels = load_tensor(filepath=f"test_document/doc_{i}_inlegal/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model6, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))