In [31]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'..')))
from utils import load_tensor
from models import BiLSTM, CNN_BiLSTM
from evaluation import calculate_confusion_matrix, class_accuracy, class_f1_score

In [32]:
def remap_targets(target_tensor: torch.Tensor, old_le, new_le):
    """Re-encode class ids from the fine-grained label set to the merged one.

    The ids in ``target_tensor`` are decoded to label strings with ``old_le``,
    ``ARG_PETITIONER``/``ARG_RESPONDENT`` are merged into ``"ARG"`` and
    ``PRE_RELIED``/``PRE_NOT_RELIED`` into ``"PRE"``, and the resulting strings
    are encoded again with ``new_le``. Every other label is passed through
    unchanged.

    Args:
        target_tensor: Tensor of class ids understood by ``old_le``. It is cast
            to integers before decoding, so float-typed ids are accepted.
        old_le: Fitted encoder exposing ``inverse_transform`` (ids -> labels).
        new_le: Fitted encoder exposing ``transform`` (labels -> ids).

    Returns:
        A ``torch.long`` tensor holding the ids assigned by ``new_le``.
    """
    merged_label_for = {
        'ARG_RESPONDENT': 'ARG',
        'ARG_PETITIONER': 'ARG',
        'PRE_NOT_RELIED': 'PRE',
        'PRE_RELIED': 'PRE',
    }
    # Hand the encoder a plain NumPy array: this also works for tensors that
    # live on another device or are attached to an autograd graph.
    old_ids = target_tensor.detach().cpu().long().numpy()
    old_labels = old_le.inverse_transform(old_ids)
    new_labels = np.array([merged_label_for.get(label, label) for label in old_labels])
    # Pin the dtype so that an empty input still yields an integer tensor.
    return torch.as_tensor(new_le.transform(new_labels), dtype=torch.long)

In [33]:
# Fine-grained label set used by the stored label tensors.
list_of_targets_old = [
    'ISSUE', 'FAC', 'NONE', 'ARG_PETITIONER', 'PRE_NOT_RELIED', 'STA', 'RPC',
    'ARG_RESPONDENT', 'PREAMBLE', 'ANALYSIS', 'RLC', 'PRE_RELIED', 'RATIO',
]
label_encoder_old = LabelEncoder()
label_encoder_old.fit(list_of_targets_old)

# Merged label set: the ARG_* labels become ARG and the PRE_* labels become PRE.
list_of_targets_new = [
    'ISSUE', 'FAC', 'NONE', 'ARG', 'PRE', 'STA', 'RPC',
    'PREAMBLE', 'ANALYSIS', 'RLC', 'RATIO',
]
label_encoder_new = LabelEncoder()
label_encoder_new.fit(list_of_targets_new)

In [34]:
# Gather the embeddings and labels of every training document, then concatenate
# each list once. Concatenating inside the loop would copy the growing result
# on every iteration.
embedding_chunks, label_chunks = [], []
for idx in range(246):
    doc_dir = f"../train_document/doc_{idx}"
    embedding_chunks.append(load_tensor(filepath=f"{doc_dir}/embedding"))
    label_chunks.append(load_tensor(filepath=f"{doc_dir}/label"))

sample_input = torch.cat(embedding_chunks, dim=0)
sample_target = torch.cat(label_chunks, dim=0)

In [35]:
# Total number of labels across all concatenated training documents.
sample_target.shape

torch.Size([28864])

In [36]:
# Re-encode the concatenated training labels with the merged label set.
remapped_target = remap_targets(
    target_tensor=sample_target,
    old_le=label_encoder_old,
    new_le=label_encoder_new,
)

In [37]:
from sklearn.utils.class_weight import compute_class_weight

# "balanced" class weights computed from the remapped training labels,
# one weight per class id that occurs in them.
remapped_np = remapped_target.numpy()
balanced_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(remapped_np),
    y=remapped_np,
)
# Stored as a float32 tensor so it can be passed as the loss `weight`.
class_weights = torch.tensor(balanced_weights, dtype=torch.float32)
class_weights

tensor([0.2460, 1.3107, 0.4584, 7.1499, 1.8544, 1.6639, 0.6347, 3.8990, 3.5033,
        2.4341, 5.4895])

# BERT-Base

## BiLSTM

In [44]:
# Train a BiLSTM on the BERT-Base embeddings: one optimiser step per training document.
model1 = BiLSTM(hidden_size=128, dropout=0.25, output_size=11)
optimizer = torch.optim.Adam(model1.parameters(), lr=5e-4)
# Cross-entropy weighted per class by `class_weights`.
loss_function = nn.CrossEntropyLoss(weight=class_weights)

num_train_docs = 246   # documents are read from ../train_document/doc_<idx>/
max_epochs = 100
stop_loss = 0.05       # stop early once the mean epoch loss drops below this

print(f'{"Starting Training":-^100}')
model1.train()
loss_list = []  # mean training loss of every completed epoch
for epoch in range(max_epochs):
    running_loss = 0.0
    trained_docs = 0
    for idx in tqdm(range(num_train_docs)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        # Skip empty documents before loading and remapping their labels.
        if TRAIN_emb.size(0) == 0:
            continue
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)

        output = model1(TRAIN_emb)
        loss = loss_function(output, TRAIN_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        trained_docs += 1

    # Average over the documents that actually produced a loss. Dividing by the
    # fixed document count would understate the mean whenever documents are skipped.
    epoch_loss = running_loss / max(trained_docs, 1)
    loss_list.append(epoch_loss)
    print(f"Epoch: {epoch+1} \t Loss: {epoch_loss:.5f}")
    if epoch_loss < stop_loss:
        break


-----------------------------------------Starting Training------------------------------------------


  0%|          | 0/246 [00:00<?, ?it/s]

100%|██████████| 246/246 [00:08<00:00, 27.83it/s]


Epoch: 1 	 Loss: 1.43475


100%|██████████| 246/246 [00:06<00:00, 40.76it/s]


Epoch: 2 	 Loss: 0.93312


100%|██████████| 246/246 [00:07<00:00, 34.07it/s]


Epoch: 3 	 Loss: 0.79179


100%|██████████| 246/246 [00:06<00:00, 37.67it/s]


Epoch: 4 	 Loss: 0.70181


100%|██████████| 246/246 [00:05<00:00, 42.72it/s]


Epoch: 5 	 Loss: 0.64911


100%|██████████| 246/246 [00:06<00:00, 38.52it/s]


Epoch: 6 	 Loss: 0.57823


100%|██████████| 246/246 [00:06<00:00, 38.34it/s]


Epoch: 7 	 Loss: 0.52952


100%|██████████| 246/246 [00:05<00:00, 44.64it/s]


Epoch: 8 	 Loss: 0.48770


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 9 	 Loss: 0.44855


100%|██████████| 246/246 [00:06<00:00, 38.05it/s]


Epoch: 10 	 Loss: 0.40444


100%|██████████| 246/246 [00:06<00:00, 35.71it/s]


Epoch: 11 	 Loss: 0.37040


100%|██████████| 246/246 [00:06<00:00, 40.90it/s]


Epoch: 12 	 Loss: 0.34482


100%|██████████| 246/246 [00:06<00:00, 39.34it/s]


Epoch: 13 	 Loss: 0.31127


100%|██████████| 246/246 [00:06<00:00, 38.86it/s]


Epoch: 14 	 Loss: 0.28966


100%|██████████| 246/246 [00:06<00:00, 38.85it/s]


Epoch: 15 	 Loss: 0.27008


100%|██████████| 246/246 [00:05<00:00, 41.98it/s]


Epoch: 16 	 Loss: 0.25314


100%|██████████| 246/246 [00:06<00:00, 35.79it/s]


Epoch: 17 	 Loss: 0.22750


100%|██████████| 246/246 [00:06<00:00, 37.25it/s]


Epoch: 18 	 Loss: 0.21079


100%|██████████| 246/246 [00:06<00:00, 39.78it/s]


Epoch: 19 	 Loss: 0.20896


100%|██████████| 246/246 [00:06<00:00, 38.35it/s]


Epoch: 20 	 Loss: 0.19060


100%|██████████| 246/246 [00:07<00:00, 33.94it/s]


Epoch: 21 	 Loss: 0.21203


100%|██████████| 246/246 [00:07<00:00, 34.20it/s]


Epoch: 22 	 Loss: 0.19578


100%|██████████| 246/246 [00:07<00:00, 32.72it/s]


Epoch: 23 	 Loss: 0.18171


100%|██████████| 246/246 [00:07<00:00, 34.95it/s]


Epoch: 24 	 Loss: 0.16755


100%|██████████| 246/246 [00:09<00:00, 25.85it/s]


Epoch: 25 	 Loss: 0.16802


100%|██████████| 246/246 [00:07<00:00, 32.82it/s]


Epoch: 26 	 Loss: 0.15102


100%|██████████| 246/246 [00:12<00:00, 19.75it/s]


Epoch: 27 	 Loss: 0.13811


100%|██████████| 246/246 [00:07<00:00, 33.66it/s]


Epoch: 28 	 Loss: 0.12462


100%|██████████| 246/246 [00:07<00:00, 32.52it/s]


Epoch: 29 	 Loss: 0.11645


100%|██████████| 246/246 [00:07<00:00, 33.01it/s]


Epoch: 30 	 Loss: 0.11505


100%|██████████| 246/246 [00:08<00:00, 29.56it/s]


Epoch: 31 	 Loss: 0.11638


100%|██████████| 246/246 [00:07<00:00, 32.87it/s]


Epoch: 32 	 Loss: 0.10205


100%|██████████| 246/246 [00:08<00:00, 28.31it/s]


Epoch: 33 	 Loss: 0.09478


100%|██████████| 246/246 [00:09<00:00, 25.58it/s]


Epoch: 34 	 Loss: 0.10039


100%|██████████| 246/246 [00:10<00:00, 24.50it/s]


Epoch: 35 	 Loss: 0.14272


100%|██████████| 246/246 [00:09<00:00, 26.57it/s]


Epoch: 36 	 Loss: 0.16838


100%|██████████| 246/246 [00:09<00:00, 25.05it/s]


Epoch: 37 	 Loss: 0.09507


100%|██████████| 246/246 [00:09<00:00, 25.06it/s]


Epoch: 38 	 Loss: 0.08410


100%|██████████| 246/246 [00:09<00:00, 24.89it/s]


Epoch: 39 	 Loss: 0.07366


100%|██████████| 246/246 [00:09<00:00, 25.59it/s]


Epoch: 40 	 Loss: 0.06558


100%|██████████| 246/246 [00:09<00:00, 25.37it/s]


Epoch: 41 	 Loss: 0.07380


100%|██████████| 246/246 [00:10<00:00, 24.42it/s]


Epoch: 42 	 Loss: 0.06507


100%|██████████| 246/246 [00:10<00:00, 24.47it/s]


Epoch: 43 	 Loss: 0.05514


100%|██████████| 246/246 [00:08<00:00, 29.49it/s]


Epoch: 44 	 Loss: 0.06345


100%|██████████| 246/246 [00:09<00:00, 25.40it/s]


Epoch: 45 	 Loss: 0.05806


100%|██████████| 246/246 [00:08<00:00, 29.00it/s]


Epoch: 46 	 Loss: 0.07255


100%|██████████| 246/246 [00:08<00:00, 28.72it/s]


Epoch: 47 	 Loss: 0.05761


100%|██████████| 246/246 [00:08<00:00, 27.39it/s]


Epoch: 48 	 Loss: 0.07663


100%|██████████| 246/246 [00:08<00:00, 28.58it/s]


Epoch: 49 	 Loss: 0.08222


100%|██████████| 246/246 [00:09<00:00, 24.92it/s]


Epoch: 50 	 Loss: 0.05691


100%|██████████| 246/246 [00:09<00:00, 26.51it/s]

Epoch: 51 	 Loss: 0.04955





In [45]:
# Evaluate model1 on the BERT-Base test documents and report per-class metrics.
# calculate_confusion_matrix is defined outside this notebook, so it is not visible
# here whether it takes the model out of training mode. Do it explicitly so that
# dropout layers are disabled and no autograd graph is built while scoring.
model1.eval()
cm = None
with torch.no_grad():
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
        # Mirror the training loop: an empty document has nothing to score.
        if TEST_emb.size(0) == 0:
            continue
        TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model1, num_labels=11)
        # Sum the per-document confusion matrices into one matrix.
        cm = conf_matrix_helper if cm is None else np.add(cm, conf_matrix_helper)

accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
# Unweighted means over the per-class values.
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.84980237 0.39393939 0.7751938  0.78       0.92857143 0.45962733
 0.98434442 0.54878049 0.60416667 0.85555556 0.47368421] 
 Average acccuracy: 0.6957877877225896
F1 Scores: [0.74394459 0.44067792 0.81699341 0.77999995 0.9086021  0.62447253
 0.99113295 0.58441553 0.35365849 0.86516849 0.53731338] 
 Average F1: 0.6951253954753968


## CNN-BiLSTM

In [46]:
# Train a CNN-BiLSTM on the BERT-Base embeddings: one optimiser step per training document.
model2 = CNN_BiLSTM(hidden_size=128, dropout=0.25, output_size=11)
optimizer = torch.optim.Adam(model2.parameters(), lr=5e-4)
# Cross-entropy weighted per class by `class_weights`.
loss_function = nn.CrossEntropyLoss(weight=class_weights)

num_train_docs = 246   # documents are read from ../train_document/doc_<idx>/
max_epochs = 100
stop_loss = 0.05       # stop early once the mean epoch loss drops below this

print(f'{"Starting Training":-^100}')
model2.train()
loss_list = []  # mean training loss of every completed epoch
for epoch in range(max_epochs):
    running_loss = 0.0
    trained_docs = 0
    for idx in tqdm(range(num_train_docs)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        # Skip empty documents before loading and remapping their labels.
        if TRAIN_emb.size(0) == 0:
            continue
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)

        output = model2(TRAIN_emb)
        loss = loss_function(output, TRAIN_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        trained_docs += 1

    # Average over the documents that actually produced a loss. Dividing by the
    # fixed document count would understate the mean whenever documents are skipped.
    epoch_loss = running_loss / max(trained_docs, 1)
    loss_list.append(epoch_loss)
    print(f"Epoch: {epoch+1} \t Loss: {epoch_loss:.5f}")
    if epoch_loss < stop_loss:
        break

-----------------------------------------Starting Training------------------------------------------


  0%|          | 0/246 [00:00<?, ?it/s]

100%|██████████| 246/246 [00:18<00:00, 13.05it/s]


Epoch: 1 	 Loss: 1.56903


100%|██████████| 246/246 [00:09<00:00, 26.63it/s]


Epoch: 2 	 Loss: 1.05583


100%|██████████| 246/246 [00:08<00:00, 29.41it/s]


Epoch: 3 	 Loss: 0.87655


100%|██████████| 246/246 [00:09<00:00, 25.09it/s]


Epoch: 4 	 Loss: 0.80474


100%|██████████| 246/246 [00:09<00:00, 24.65it/s]


Epoch: 5 	 Loss: 0.77298


100%|██████████| 246/246 [00:06<00:00, 37.04it/s]


Epoch: 6 	 Loss: 0.68548


100%|██████████| 246/246 [00:10<00:00, 23.57it/s]


Epoch: 7 	 Loss: 0.63012


100%|██████████| 246/246 [00:10<00:00, 23.34it/s]


Epoch: 8 	 Loss: 0.59393


100%|██████████| 246/246 [00:08<00:00, 27.99it/s]


Epoch: 9 	 Loss: 0.55545


100%|██████████| 246/246 [00:10<00:00, 22.64it/s]


Epoch: 10 	 Loss: 0.51326


100%|██████████| 246/246 [00:07<00:00, 34.23it/s]


Epoch: 11 	 Loss: 0.47468


100%|██████████| 246/246 [00:06<00:00, 35.50it/s]


Epoch: 12 	 Loss: 0.46699


100%|██████████| 246/246 [00:06<00:00, 37.84it/s]


Epoch: 13 	 Loss: 0.42601


100%|██████████| 246/246 [00:06<00:00, 38.00it/s]


Epoch: 14 	 Loss: 0.41642


100%|██████████| 246/246 [00:06<00:00, 37.15it/s]


Epoch: 15 	 Loss: 0.43293


100%|██████████| 246/246 [00:06<00:00, 36.71it/s]


Epoch: 16 	 Loss: 0.36979


100%|██████████| 246/246 [00:07<00:00, 34.77it/s]


Epoch: 17 	 Loss: 0.33711


100%|██████████| 246/246 [00:06<00:00, 37.87it/s]


Epoch: 18 	 Loss: 0.37952


100%|██████████| 246/246 [00:06<00:00, 37.88it/s]


Epoch: 19 	 Loss: 0.34698


100%|██████████| 246/246 [00:06<00:00, 37.10it/s]


Epoch: 20 	 Loss: 0.29451


100%|██████████| 246/246 [00:06<00:00, 37.33it/s]


Epoch: 21 	 Loss: 0.27839


100%|██████████| 246/246 [00:06<00:00, 37.34it/s]


Epoch: 22 	 Loss: 0.24977


100%|██████████| 246/246 [00:06<00:00, 36.24it/s]


Epoch: 23 	 Loss: 0.24219


100%|██████████| 246/246 [00:06<00:00, 37.09it/s]


Epoch: 24 	 Loss: 0.22128


100%|██████████| 246/246 [00:06<00:00, 37.80it/s]


Epoch: 25 	 Loss: 0.23306


100%|██████████| 246/246 [00:06<00:00, 37.47it/s]


Epoch: 26 	 Loss: 0.20954


100%|██████████| 246/246 [00:06<00:00, 37.83it/s]


Epoch: 27 	 Loss: 0.20804


100%|██████████| 246/246 [00:06<00:00, 36.41it/s]


Epoch: 28 	 Loss: 0.19685


100%|██████████| 246/246 [00:06<00:00, 37.64it/s]


Epoch: 29 	 Loss: 0.18290


100%|██████████| 246/246 [00:06<00:00, 37.35it/s]


Epoch: 30 	 Loss: 0.17063


100%|██████████| 246/246 [00:06<00:00, 37.38it/s]


Epoch: 31 	 Loss: 0.16107


100%|██████████| 246/246 [00:06<00:00, 37.71it/s]


Epoch: 32 	 Loss: 0.30598


100%|██████████| 246/246 [00:06<00:00, 36.63it/s]


Epoch: 33 	 Loss: 0.21737


100%|██████████| 246/246 [00:06<00:00, 36.21it/s]


Epoch: 34 	 Loss: 0.16275


100%|██████████| 246/246 [00:06<00:00, 38.04it/s]


Epoch: 35 	 Loss: 0.14778


100%|██████████| 246/246 [00:06<00:00, 37.79it/s]


Epoch: 36 	 Loss: 0.13089


100%|██████████| 246/246 [00:06<00:00, 36.70it/s]


Epoch: 37 	 Loss: 0.13415


100%|██████████| 246/246 [00:06<00:00, 38.20it/s]


Epoch: 38 	 Loss: 0.13086


100%|██████████| 246/246 [00:06<00:00, 37.73it/s]


Epoch: 39 	 Loss: 0.11118


100%|██████████| 246/246 [00:06<00:00, 37.78it/s]


Epoch: 40 	 Loss: 0.10705


100%|██████████| 246/246 [00:07<00:00, 32.67it/s]


Epoch: 41 	 Loss: 0.10147


100%|██████████| 246/246 [00:06<00:00, 37.58it/s]


Epoch: 42 	 Loss: 0.10382


100%|██████████| 246/246 [00:06<00:00, 36.87it/s]


Epoch: 43 	 Loss: 0.10093


100%|██████████| 246/246 [00:06<00:00, 35.25it/s]


Epoch: 44 	 Loss: 0.11698


100%|██████████| 246/246 [00:06<00:00, 38.19it/s]


Epoch: 45 	 Loss: 0.19784


100%|██████████| 246/246 [00:06<00:00, 37.39it/s]


Epoch: 46 	 Loss: 0.16095


100%|██████████| 246/246 [00:07<00:00, 34.99it/s]


Epoch: 47 	 Loss: 0.09443


100%|██████████| 246/246 [00:06<00:00, 36.75it/s]


Epoch: 48 	 Loss: 0.07920


100%|██████████| 246/246 [00:06<00:00, 37.49it/s]


Epoch: 49 	 Loss: 0.07375


100%|██████████| 246/246 [00:06<00:00, 39.04it/s]


Epoch: 50 	 Loss: 0.06886


100%|██████████| 246/246 [00:06<00:00, 38.32it/s]


Epoch: 51 	 Loss: 0.06632


100%|██████████| 246/246 [00:06<00:00, 36.28it/s]


Epoch: 52 	 Loss: 0.07045


100%|██████████| 246/246 [00:06<00:00, 38.92it/s]


Epoch: 53 	 Loss: 0.06948


100%|██████████| 246/246 [00:06<00:00, 38.99it/s]


Epoch: 54 	 Loss: 0.07456


100%|██████████| 246/246 [00:06<00:00, 38.63it/s]


Epoch: 55 	 Loss: 0.07857


100%|██████████| 246/246 [00:06<00:00, 35.66it/s]


Epoch: 56 	 Loss: 0.07103


100%|██████████| 246/246 [00:06<00:00, 39.13it/s]


Epoch: 57 	 Loss: 0.05528


100%|██████████| 246/246 [00:06<00:00, 39.06it/s]


Epoch: 58 	 Loss: 0.05638


100%|██████████| 246/246 [00:06<00:00, 38.11it/s]


Epoch: 59 	 Loss: 0.13428


100%|██████████| 246/246 [00:06<00:00, 38.12it/s]


Epoch: 60 	 Loss: 0.17127


100%|██████████| 246/246 [00:06<00:00, 39.07it/s]


Epoch: 61 	 Loss: 0.07679


100%|██████████| 246/246 [00:06<00:00, 38.90it/s]


Epoch: 62 	 Loss: 0.05715


100%|██████████| 246/246 [00:06<00:00, 39.18it/s]

Epoch: 63 	 Loss: 0.04668





In [47]:
# Evaluate model2 on the BERT-Base test documents and report per-class metrics.
# calculate_confusion_matrix is defined outside this notebook, so it is not visible
# here whether it takes the model out of training mode. Do it explicitly so that
# dropout layers are disabled and no autograd graph is built while scoring.
model2.eval()
cm = None
with torch.no_grad():
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
        # Mirror the training loop: an empty document has nothing to score.
        if TEST_emb.size(0) == 0:
            continue
        TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model2, num_labels=11)
        # Sum the per-document confusion matrices into one matrix.
        cm = conf_matrix_helper if cm is None else np.add(cm, conf_matrix_helper)

accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
# Unweighted means over the per-class values.
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.83424808 0.51351351 0.78187404 0.73584906 0.94505495 0.58947368
 0.99405941 0.47524752 0.58928571 0.86170213 0.5       ] 
 Average acccuracy: 0.7109370991812456
F1 Scores: [0.80593844 0.42696624 0.82764223 0.7572815  0.92473113 0.65497071
 0.99504455 0.55491325 0.38372089 0.89010984 0.59154925] 
 Average F1: 0.7102607293865731


# LEGAL-BERT

## BiLSTM

In [48]:
# Train a BiLSTM on the LEGAL-BERT embeddings: one optimiser step per training document.
model3 = BiLSTM(hidden_size=128, dropout=0.25, output_size=11)
optimizer = torch.optim.Adam(model3.parameters(), lr=5e-4)
# Cross-entropy weighted per class by `class_weights`.
loss_function = nn.CrossEntropyLoss(weight=class_weights)

num_train_docs = 246   # documents are read from ../train_document/doc_<idx>_legal/
max_epochs = 100
stop_loss = 0.1        # stop early once the mean epoch loss drops below this

print(f'{"Starting Training":-^100}')
model3.train()
loss_list = []  # mean training loss of every completed epoch
for epoch in range(max_epochs):
    running_loss = 0.0
    trained_docs = 0
    for idx in tqdm(range(num_train_docs)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
        # Skip empty documents before loading and remapping their labels.
        if TRAIN_emb.size(0) == 0:
            continue
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)

        output = model3(TRAIN_emb)
        loss = loss_function(output, TRAIN_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        trained_docs += 1

    # Average over the documents that actually produced a loss. Dividing by the
    # fixed document count would understate the mean whenever documents are skipped.
    epoch_loss = running_loss / max(trained_docs, 1)
    loss_list.append(epoch_loss)
    print(f"Epoch: {epoch+1} \t Loss: {epoch_loss:.5f}")
    if epoch_loss < stop_loss:
        break

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:16, 14.64it/s]

100%|██████████| 246/246 [00:06<00:00, 40.87it/s]


Epoch: 1 	 Loss: 2.23895


100%|██████████| 246/246 [00:05<00:00, 46.04it/s]


Epoch: 2 	 Loss: 1.92811


100%|██████████| 246/246 [00:05<00:00, 46.17it/s]


Epoch: 3 	 Loss: 1.74216


100%|██████████| 246/246 [00:10<00:00, 24.01it/s]


Epoch: 4 	 Loss: 1.64417


100%|██████████| 246/246 [00:13<00:00, 18.84it/s]


Epoch: 5 	 Loss: 1.56349


100%|██████████| 246/246 [00:05<00:00, 42.69it/s]


Epoch: 6 	 Loss: 1.50654


100%|██████████| 246/246 [00:05<00:00, 42.75it/s]


Epoch: 7 	 Loss: 1.44931


100%|██████████| 246/246 [00:05<00:00, 42.88it/s]


Epoch: 8 	 Loss: 1.39359


100%|██████████| 246/246 [00:06<00:00, 40.72it/s]


Epoch: 9 	 Loss: 1.33977


100%|██████████| 246/246 [00:05<00:00, 42.74it/s]


Epoch: 10 	 Loss: 1.30155


100%|██████████| 246/246 [00:05<00:00, 42.59it/s]


Epoch: 11 	 Loss: 1.24400


100%|██████████| 246/246 [00:05<00:00, 43.17it/s]


Epoch: 12 	 Loss: 1.22424


100%|██████████| 246/246 [00:07<00:00, 35.14it/s]


Epoch: 13 	 Loss: 1.18072


100%|██████████| 246/246 [00:06<00:00, 36.53it/s]


Epoch: 14 	 Loss: 1.11589


100%|██████████| 246/246 [00:07<00:00, 34.73it/s]


Epoch: 15 	 Loss: 1.07232


100%|██████████| 246/246 [00:05<00:00, 43.39it/s]


Epoch: 16 	 Loss: 1.03613


100%|██████████| 246/246 [00:05<00:00, 42.27it/s]


Epoch: 17 	 Loss: 1.03635


100%|██████████| 246/246 [00:05<00:00, 42.37it/s]


Epoch: 18 	 Loss: 0.96901


100%|██████████| 246/246 [00:05<00:00, 43.44it/s]


Epoch: 19 	 Loss: 0.93235


100%|██████████| 246/246 [00:05<00:00, 43.64it/s]


Epoch: 20 	 Loss: 0.91914


100%|██████████| 246/246 [00:05<00:00, 42.04it/s]


Epoch: 21 	 Loss: 0.90930


100%|██████████| 246/246 [00:06<00:00, 40.81it/s]


Epoch: 22 	 Loss: 0.85934


100%|██████████| 246/246 [00:08<00:00, 27.93it/s]


Epoch: 23 	 Loss: 0.84030


100%|██████████| 246/246 [00:05<00:00, 42.79it/s]


Epoch: 24 	 Loss: 0.82298


100%|██████████| 246/246 [00:06<00:00, 36.08it/s]


Epoch: 25 	 Loss: 0.78593


100%|██████████| 246/246 [00:08<00:00, 30.46it/s]


Epoch: 26 	 Loss: 0.75203


100%|██████████| 246/246 [00:07<00:00, 32.98it/s]


Epoch: 27 	 Loss: 0.70009


100%|██████████| 246/246 [00:05<00:00, 42.21it/s]


Epoch: 28 	 Loss: 0.69822


100%|██████████| 246/246 [00:05<00:00, 42.92it/s]


Epoch: 29 	 Loss: 0.68217


100%|██████████| 246/246 [00:05<00:00, 42.51it/s]


Epoch: 30 	 Loss: 0.66032


100%|██████████| 246/246 [00:06<00:00, 40.10it/s]


Epoch: 31 	 Loss: 0.64246


100%|██████████| 246/246 [00:05<00:00, 42.25it/s]


Epoch: 32 	 Loss: 0.61633


100%|██████████| 246/246 [00:05<00:00, 42.75it/s]


Epoch: 33 	 Loss: 0.62024


100%|██████████| 246/246 [00:05<00:00, 42.24it/s]


Epoch: 34 	 Loss: 0.59078


100%|██████████| 246/246 [00:06<00:00, 39.40it/s]


Epoch: 35 	 Loss: 0.59181


100%|██████████| 246/246 [00:05<00:00, 42.56it/s]


Epoch: 36 	 Loss: 0.55289


100%|██████████| 246/246 [00:05<00:00, 42.34it/s]


Epoch: 37 	 Loss: 0.55273


100%|██████████| 246/246 [00:06<00:00, 38.85it/s]


Epoch: 38 	 Loss: 0.53812


100%|██████████| 246/246 [00:05<00:00, 41.07it/s]


Epoch: 39 	 Loss: 0.52794


100%|██████████| 246/246 [00:05<00:00, 42.98it/s]


Epoch: 40 	 Loss: 0.49875


100%|██████████| 246/246 [00:05<00:00, 42.22it/s]


Epoch: 41 	 Loss: 0.48304


100%|██████████| 246/246 [00:05<00:00, 42.06it/s]


Epoch: 42 	 Loss: 0.47894


100%|██████████| 246/246 [00:06<00:00, 40.91it/s]


Epoch: 43 	 Loss: 0.45923


100%|██████████| 246/246 [00:05<00:00, 42.11it/s]


Epoch: 44 	 Loss: 0.45032


100%|██████████| 246/246 [00:05<00:00, 42.61it/s]


Epoch: 45 	 Loss: 0.43891


100%|██████████| 246/246 [00:05<00:00, 42.34it/s]


Epoch: 46 	 Loss: 0.43217


100%|██████████| 246/246 [00:06<00:00, 40.82it/s]


Epoch: 47 	 Loss: 0.43255


100%|██████████| 246/246 [00:05<00:00, 41.11it/s]


Epoch: 48 	 Loss: 0.41746


100%|██████████| 246/246 [00:05<00:00, 41.96it/s]


Epoch: 49 	 Loss: 0.38461


100%|██████████| 246/246 [00:05<00:00, 42.27it/s]


Epoch: 50 	 Loss: 0.38363


100%|██████████| 246/246 [00:06<00:00, 40.89it/s]


Epoch: 51 	 Loss: 0.38307


100%|██████████| 246/246 [00:05<00:00, 41.76it/s]


Epoch: 52 	 Loss: 0.37923


100%|██████████| 246/246 [00:05<00:00, 42.57it/s]


Epoch: 53 	 Loss: 0.37958


100%|██████████| 246/246 [00:05<00:00, 41.08it/s]


Epoch: 54 	 Loss: 0.36562


100%|██████████| 246/246 [00:05<00:00, 41.01it/s]


Epoch: 55 	 Loss: 0.35973


100%|██████████| 246/246 [00:05<00:00, 42.32it/s]


Epoch: 56 	 Loss: 0.34195


100%|██████████| 246/246 [00:05<00:00, 42.66it/s]


Epoch: 57 	 Loss: 0.34643


100%|██████████| 246/246 [00:05<00:00, 42.39it/s]


Epoch: 58 	 Loss: 0.35170


100%|██████████| 246/246 [00:06<00:00, 40.22it/s]


Epoch: 59 	 Loss: 0.33467


100%|██████████| 246/246 [00:05<00:00, 42.61it/s]


Epoch: 60 	 Loss: 0.32935


100%|██████████| 246/246 [00:05<00:00, 42.24it/s]


Epoch: 61 	 Loss: 0.32701


100%|██████████| 246/246 [00:05<00:00, 41.78it/s]


Epoch: 62 	 Loss: 0.30996


100%|██████████| 246/246 [00:06<00:00, 40.27it/s]


Epoch: 63 	 Loss: 0.29837


100%|██████████| 246/246 [00:05<00:00, 42.23it/s]


Epoch: 64 	 Loss: 0.29015


100%|██████████| 246/246 [00:06<00:00, 39.98it/s]


Epoch: 65 	 Loss: 0.27646


100%|██████████| 246/246 [00:05<00:00, 41.83it/s]


Epoch: 66 	 Loss: 0.28763


100%|██████████| 246/246 [00:06<00:00, 40.65it/s]


Epoch: 67 	 Loss: 0.25804


100%|██████████| 246/246 [00:05<00:00, 41.60it/s]


Epoch: 68 	 Loss: 0.27068


100%|██████████| 246/246 [00:05<00:00, 41.05it/s]


Epoch: 69 	 Loss: 0.28240


100%|██████████| 246/246 [00:05<00:00, 42.62it/s]


Epoch: 70 	 Loss: 0.28448


100%|██████████| 246/246 [00:06<00:00, 40.68it/s]


Epoch: 71 	 Loss: 0.27360


100%|██████████| 246/246 [00:05<00:00, 41.66it/s]


Epoch: 72 	 Loss: 0.25750


100%|██████████| 246/246 [00:05<00:00, 41.57it/s]


Epoch: 73 	 Loss: 0.25119


100%|██████████| 246/246 [00:05<00:00, 42.03it/s]


Epoch: 74 	 Loss: 0.24336


100%|██████████| 246/246 [00:06<00:00, 38.79it/s]


Epoch: 75 	 Loss: 0.22851


100%|██████████| 246/246 [00:05<00:00, 42.30it/s]


Epoch: 76 	 Loss: 0.25323


100%|██████████| 246/246 [00:06<00:00, 40.83it/s]


Epoch: 77 	 Loss: 0.24356


100%|██████████| 246/246 [00:05<00:00, 41.71it/s]


Epoch: 78 	 Loss: 0.23069


100%|██████████| 246/246 [00:05<00:00, 41.80it/s]


Epoch: 79 	 Loss: 0.22705


100%|██████████| 246/246 [00:06<00:00, 38.68it/s]


Epoch: 80 	 Loss: 0.23338


100%|██████████| 246/246 [00:06<00:00, 39.52it/s]


Epoch: 81 	 Loss: 0.23776


100%|██████████| 246/246 [00:05<00:00, 41.35it/s]


Epoch: 82 	 Loss: 0.22577


100%|██████████| 246/246 [00:05<00:00, 41.78it/s]


Epoch: 83 	 Loss: 0.22173


100%|██████████| 246/246 [00:05<00:00, 41.82it/s]


Epoch: 84 	 Loss: 0.22511


100%|██████████| 246/246 [00:06<00:00, 40.59it/s]


Epoch: 85 	 Loss: 0.24544


100%|██████████| 246/246 [00:05<00:00, 41.52it/s]


Epoch: 86 	 Loss: 0.22524


100%|██████████| 246/246 [00:05<00:00, 42.22it/s]


Epoch: 87 	 Loss: 0.22943


100%|██████████| 246/246 [00:06<00:00, 39.65it/s]


Epoch: 88 	 Loss: 0.21228


100%|██████████| 246/246 [00:06<00:00, 39.49it/s]


Epoch: 89 	 Loss: 0.26141


100%|██████████| 246/246 [00:05<00:00, 41.66it/s]


Epoch: 90 	 Loss: 0.20946


100%|██████████| 246/246 [00:05<00:00, 41.80it/s]


Epoch: 91 	 Loss: 0.18445


100%|██████████| 246/246 [00:06<00:00, 40.90it/s]


Epoch: 92 	 Loss: 0.18738


100%|██████████| 246/246 [00:06<00:00, 40.60it/s]


Epoch: 93 	 Loss: 0.20279


100%|██████████| 246/246 [00:05<00:00, 42.12it/s]


Epoch: 94 	 Loss: 0.18066


100%|██████████| 246/246 [00:06<00:00, 35.97it/s]


Epoch: 95 	 Loss: 0.18778


100%|██████████| 246/246 [00:07<00:00, 32.68it/s]


Epoch: 96 	 Loss: 0.21845


100%|██████████| 246/246 [00:05<00:00, 42.80it/s]


Epoch: 97 	 Loss: 0.19246


100%|██████████| 246/246 [00:05<00:00, 45.80it/s]


Epoch: 98 	 Loss: 0.18572


100%|██████████| 246/246 [00:05<00:00, 44.88it/s]


Epoch: 99 	 Loss: 0.17958


100%|██████████| 246/246 [00:05<00:00, 43.83it/s]

Epoch: 100 	 Loss: 0.18010





In [49]:
# Evaluate model3 on the LEGAL-BERT test documents and report per-class metrics.
# calculate_confusion_matrix is defined outside this notebook, so it is not visible
# here whether it takes the model out of training mode. Do it explicitly so that
# dropout layers are disabled and no autograd graph is built while scoring.
model3.eval()
cm = None
with torch.no_grad():
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
        # Mirror the training loop: an empty document has nothing to score.
        if TEST_emb.size(0) == 0:
            continue
        TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model3, num_labels=11)
        # Sum the per-document confusion matrices into one matrix.
        cm = conf_matrix_helper if cm is None else np.add(cm, conf_matrix_helper)

accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
# Unweighted means over the per-class values.
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.62295082 0.05357143 0.58757962 0.30909091 0.70833333 0.25
 0.75136116 0.25714286 0.25       0.4609375  0.5       ] 
 Average acccuracy: 0.43190614792450505
F1 Scores: [0.70744676 0.03749995 0.61143326 0.32380947 0.38931294 0.01282051
 0.78483407 0.25352108 0.03225805 0.54629625 0.17142854] 
 Average F1: 0.35187826173658254


## CNN-BiLSTM

In [50]:
# Train a CNN-BiLSTM on the LEGAL-BERT embeddings: one optimiser step per training document.
model4 = CNN_BiLSTM(hidden_size=128, dropout=0.25, output_size=11)
optimizer = torch.optim.Adam(model4.parameters(), lr=5e-4)
# Cross-entropy weighted per class by `class_weights`.
loss_function = nn.CrossEntropyLoss(weight=class_weights)

num_train_docs = 246   # documents are read from ../train_document/doc_<idx>_legal/
max_epochs = 100
stop_loss = 0.1        # stop early once the mean epoch loss drops below this

print(f'{"Starting Training":-^100}')
model4.train()
loss_list = []  # mean training loss of every completed epoch
for epoch in range(max_epochs):
    running_loss = 0.0
    trained_docs = 0
    for idx in tqdm(range(num_train_docs)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
        # Skip empty documents before loading and remapping their labels.
        if TRAIN_emb.size(0) == 0:
            continue
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)

        output = model4(TRAIN_emb)
        loss = loss_function(output, TRAIN_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        trained_docs += 1

    # Average over the documents that actually produced a loss. Dividing by the
    # fixed document count would understate the mean whenever documents are skipped.
    epoch_loss = running_loss / max(trained_docs, 1)
    loss_list.append(epoch_loss)
    print(f"Epoch: {epoch+1} \t Loss: {epoch_loss:.5f}")
    if epoch_loss < stop_loss:
        break

-----------------------------------------Starting Training------------------------------------------


  1%|          | 2/246 [00:00<00:14, 17.37it/s]

100%|██████████| 246/246 [00:11<00:00, 22.25it/s]


Epoch: 1 	 Loss: 2.23016


100%|██████████| 246/246 [00:08<00:00, 28.18it/s]


Epoch: 2 	 Loss: 1.82783


100%|██████████| 246/246 [00:07<00:00, 31.20it/s]


Epoch: 3 	 Loss: 1.70305


100%|██████████| 246/246 [00:06<00:00, 40.32it/s]


Epoch: 4 	 Loss: 1.65398


100%|██████████| 246/246 [00:06<00:00, 39.69it/s]


Epoch: 5 	 Loss: 1.55114


100%|██████████| 246/246 [00:06<00:00, 38.77it/s]


Epoch: 6 	 Loss: 1.52771


100%|██████████| 246/246 [00:06<00:00, 39.59it/s]


Epoch: 7 	 Loss: 1.44216


100%|██████████| 246/246 [00:06<00:00, 39.97it/s]


Epoch: 8 	 Loss: 1.38309


100%|██████████| 246/246 [00:06<00:00, 39.19it/s]


Epoch: 9 	 Loss: 1.33259


100%|██████████| 246/246 [00:06<00:00, 39.60it/s]


Epoch: 10 	 Loss: 1.28982


100%|██████████| 246/246 [00:07<00:00, 32.96it/s]


Epoch: 11 	 Loss: 1.25569


100%|██████████| 246/246 [00:09<00:00, 26.87it/s]


Epoch: 12 	 Loss: 1.22596


100%|██████████| 246/246 [00:08<00:00, 28.76it/s]


Epoch: 13 	 Loss: 1.16082


100%|██████████| 246/246 [00:06<00:00, 37.37it/s]


Epoch: 14 	 Loss: 1.11026


100%|██████████| 246/246 [00:06<00:00, 36.61it/s]


Epoch: 15 	 Loss: 1.09190


100%|██████████| 246/246 [00:06<00:00, 36.24it/s]


Epoch: 16 	 Loss: 1.05987


100%|██████████| 246/246 [00:06<00:00, 35.70it/s]


Epoch: 17 	 Loss: 1.04109


100%|██████████| 246/246 [00:12<00:00, 20.09it/s]


Epoch: 18 	 Loss: 0.97349


100%|██████████| 246/246 [00:12<00:00, 19.22it/s]


Epoch: 19 	 Loss: 0.93753


100%|██████████| 246/246 [00:07<00:00, 31.34it/s]


Epoch: 20 	 Loss: 0.87788


100%|██████████| 246/246 [00:07<00:00, 31.68it/s]


Epoch: 21 	 Loss: 0.83817


100%|██████████| 246/246 [00:06<00:00, 35.76it/s]


Epoch: 22 	 Loss: 0.80332


100%|██████████| 246/246 [00:06<00:00, 35.86it/s]


Epoch: 23 	 Loss: 0.76806


100%|██████████| 246/246 [00:06<00:00, 37.10it/s]


Epoch: 24 	 Loss: 0.73375


100%|██████████| 246/246 [00:06<00:00, 35.52it/s]


Epoch: 25 	 Loss: 0.72809


100%|██████████| 246/246 [00:06<00:00, 36.55it/s]


Epoch: 26 	 Loss: 0.69432


100%|██████████| 246/246 [00:08<00:00, 29.35it/s]


Epoch: 27 	 Loss: 0.64843


100%|██████████| 246/246 [00:11<00:00, 20.86it/s]


Epoch: 28 	 Loss: 0.62447


100%|██████████| 246/246 [00:06<00:00, 36.81it/s]


Epoch: 29 	 Loss: 0.59264


100%|██████████| 246/246 [00:06<00:00, 36.21it/s]


Epoch: 30 	 Loss: 0.57366


100%|██████████| 246/246 [00:06<00:00, 37.85it/s]


Epoch: 31 	 Loss: 0.55848


100%|██████████| 246/246 [00:06<00:00, 37.61it/s]


Epoch: 32 	 Loss: 0.51971


100%|██████████| 246/246 [00:06<00:00, 37.13it/s]


Epoch: 33 	 Loss: 0.48456


100%|██████████| 246/246 [00:06<00:00, 35.93it/s]


Epoch: 34 	 Loss: 0.46976


100%|██████████| 246/246 [00:06<00:00, 37.32it/s]


Epoch: 35 	 Loss: 0.50405


100%|██████████| 246/246 [00:06<00:00, 37.58it/s]


Epoch: 36 	 Loss: 0.46265


100%|██████████| 246/246 [00:06<00:00, 37.42it/s]


Epoch: 37 	 Loss: 0.42403


100%|██████████| 246/246 [00:07<00:00, 33.62it/s]


Epoch: 38 	 Loss: 0.41553


100%|██████████| 246/246 [00:06<00:00, 37.21it/s]


Epoch: 39 	 Loss: 0.39802


100%|██████████| 246/246 [00:08<00:00, 27.74it/s]


Epoch: 40 	 Loss: 0.38532


100%|██████████| 246/246 [00:07<00:00, 33.55it/s]


Epoch: 41 	 Loss: 0.42109


100%|██████████| 246/246 [00:07<00:00, 34.36it/s]


Epoch: 42 	 Loss: 0.39670


100%|██████████| 246/246 [00:06<00:00, 35.95it/s]


Epoch: 43 	 Loss: 0.36449


100%|██████████| 246/246 [00:07<00:00, 34.63it/s]


Epoch: 44 	 Loss: 0.33632


100%|██████████| 246/246 [00:08<00:00, 29.83it/s]


Epoch: 45 	 Loss: 0.31946


100%|██████████| 246/246 [00:07<00:00, 35.02it/s]


Epoch: 46 	 Loss: 0.31166


100%|██████████| 246/246 [00:06<00:00, 36.91it/s]


Epoch: 47 	 Loss: 0.30993


100%|██████████| 246/246 [00:06<00:00, 37.34it/s]


Epoch: 48 	 Loss: 0.29542


100%|██████████| 246/246 [00:06<00:00, 36.68it/s]


Epoch: 49 	 Loss: 0.28753


100%|██████████| 246/246 [00:06<00:00, 35.53it/s]


Epoch: 50 	 Loss: 0.27431


100%|██████████| 246/246 [00:06<00:00, 37.11it/s]


Epoch: 51 	 Loss: 0.29131


100%|██████████| 246/246 [00:06<00:00, 35.96it/s]


Epoch: 52 	 Loss: 0.27400


100%|██████████| 246/246 [00:07<00:00, 33.64it/s]


Epoch: 53 	 Loss: 0.26121


100%|██████████| 246/246 [00:06<00:00, 36.47it/s]


Epoch: 54 	 Loss: 0.24936


100%|██████████| 246/246 [00:06<00:00, 38.10it/s]


Epoch: 55 	 Loss: 0.25465


100%|██████████| 246/246 [00:06<00:00, 37.10it/s]


Epoch: 56 	 Loss: 0.26221


100%|██████████| 246/246 [00:06<00:00, 39.80it/s]


Epoch: 57 	 Loss: 0.24459


100%|██████████| 246/246 [00:06<00:00, 38.09it/s]


Epoch: 58 	 Loss: 0.22608


100%|██████████| 246/246 [00:06<00:00, 38.83it/s]


Epoch: 59 	 Loss: 0.21138


100%|██████████| 246/246 [00:06<00:00, 39.69it/s]


Epoch: 60 	 Loss: 0.19502


100%|██████████| 246/246 [00:06<00:00, 38.25it/s]


Epoch: 61 	 Loss: 0.20184


100%|██████████| 246/246 [00:06<00:00, 37.56it/s]


Epoch: 62 	 Loss: 0.19503


100%|██████████| 246/246 [00:06<00:00, 38.67it/s]


Epoch: 63 	 Loss: 0.21230


100%|██████████| 246/246 [00:06<00:00, 38.39it/s]


Epoch: 64 	 Loss: 0.27440


100%|██████████| 246/246 [00:06<00:00, 37.05it/s]


Epoch: 65 	 Loss: 0.20984


100%|██████████| 246/246 [00:06<00:00, 39.37it/s]


Epoch: 66 	 Loss: 0.19070


100%|██████████| 246/246 [00:07<00:00, 35.14it/s]


Epoch: 67 	 Loss: 0.18538


100%|██████████| 246/246 [00:06<00:00, 35.68it/s]


Epoch: 68 	 Loss: 0.17133


100%|██████████| 246/246 [00:06<00:00, 37.06it/s]


Epoch: 69 	 Loss: 0.16424


100%|██████████| 246/246 [00:06<00:00, 37.31it/s]


Epoch: 70 	 Loss: 0.16703


100%|██████████| 246/246 [00:06<00:00, 36.81it/s]


Epoch: 71 	 Loss: 0.16285


100%|██████████| 246/246 [00:06<00:00, 35.51it/s]


Epoch: 72 	 Loss: 0.16380


100%|██████████| 246/246 [00:06<00:00, 37.34it/s]


Epoch: 73 	 Loss: 0.15749


100%|██████████| 246/246 [00:07<00:00, 34.42it/s]


Epoch: 74 	 Loss: 0.19536


100%|██████████| 246/246 [00:07<00:00, 31.43it/s]


Epoch: 75 	 Loss: 0.18533


100%|██████████| 246/246 [00:08<00:00, 29.89it/s]


Epoch: 76 	 Loss: 0.15734


100%|██████████| 246/246 [00:07<00:00, 32.48it/s]


Epoch: 77 	 Loss: 0.14052


100%|██████████| 246/246 [00:07<00:00, 34.03it/s]


Epoch: 78 	 Loss: 0.13425


100%|██████████| 246/246 [00:07<00:00, 34.91it/s]


Epoch: 79 	 Loss: 0.14590


100%|██████████| 246/246 [00:08<00:00, 30.15it/s]


Epoch: 80 	 Loss: 0.14511


100%|██████████| 246/246 [00:08<00:00, 27.47it/s]


Epoch: 81 	 Loss: 0.15181


100%|██████████| 246/246 [00:07<00:00, 33.01it/s]


Epoch: 82 	 Loss: 0.15033


100%|██████████| 246/246 [00:06<00:00, 36.51it/s]


Epoch: 83 	 Loss: 0.15032


100%|██████████| 246/246 [00:07<00:00, 34.54it/s]


Epoch: 84 	 Loss: 0.14110


100%|██████████| 246/246 [00:06<00:00, 35.90it/s]


Epoch: 85 	 Loss: 0.13455


100%|██████████| 246/246 [00:06<00:00, 36.68it/s]


Epoch: 86 	 Loss: 0.12295


100%|██████████| 246/246 [00:06<00:00, 37.28it/s]


Epoch: 87 	 Loss: 0.11941


100%|██████████| 246/246 [00:06<00:00, 36.24it/s]


Epoch: 88 	 Loss: 0.11215


100%|██████████| 246/246 [00:06<00:00, 36.81it/s]


Epoch: 89 	 Loss: 0.11334


100%|██████████| 246/246 [00:07<00:00, 34.99it/s]


Epoch: 90 	 Loss: 0.13012


100%|██████████| 246/246 [00:07<00:00, 33.82it/s]


Epoch: 91 	 Loss: 0.12559


100%|██████████| 246/246 [00:06<00:00, 37.28it/s]


Epoch: 92 	 Loss: 0.12931


100%|██████████| 246/246 [00:06<00:00, 36.94it/s]


Epoch: 93 	 Loss: 0.12529


100%|██████████| 246/246 [00:06<00:00, 36.11it/s]


Epoch: 94 	 Loss: 0.12122


100%|██████████| 246/246 [00:06<00:00, 36.06it/s]


Epoch: 95 	 Loss: 0.11331


100%|██████████| 246/246 [00:06<00:00, 36.55it/s]


Epoch: 96 	 Loss: 0.11554


100%|██████████| 246/246 [00:06<00:00, 37.31it/s]


Epoch: 97 	 Loss: 0.11335


100%|██████████| 246/246 [00:06<00:00, 35.57it/s]


Epoch: 98 	 Loss: 0.10374


100%|██████████| 246/246 [00:06<00:00, 37.36it/s]


Epoch: 99 	 Loss: 0.10235


100%|██████████| 246/246 [00:07<00:00, 31.69it/s]

Epoch: 100 	 Loss: 0.10181





In [51]:
# Evaluate model4 on the test documents: build one confusion matrix per
# document, sum them into `cm`, then report per-class and class-averaged metrics.
cm = None
for i in range(29):  # test documents are read from doc_0_legal .. doc_28_legal
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
    # Stored labels use the old 13-class encoding; remap_targets merges the
    # ARG_* and PRE_* variants and re-encodes with the new label encoder.
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    # Take the class count from the encoder that produced the targets instead of
    # hard-coding it, so the matrix size cannot drift from the label set.
    conf_matrix_helper = calculate_confusion_matrix(
        TEST_emb, TEST_labels, model4, num_labels=len(label_encoder_new.classes_)
    )
    # The first document seeds the running total; later ones are added element-wise.
    cm = conf_matrix_helper if cm is None else np.add(cm, conf_matrix_helper)

accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
# Plain (unweighted) mean over the per-class values.
# NOTE(review): class order presumably follows label_encoder_new.classes_ — confirm
# against calculate_confusion_matrix before labelling the printed arrays.
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average accuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

Accuracies: [0.68943436 0.09677419 0.52981651 0.35294118 0.51908397 0.10769231
 0.84485981 0.30909091 0.20689655 0.55208333 0.        ] 
 Average acccuracy: 0.382606648469596
F1 Scores: [0.67573217 0.04444441 0.63680216 0.28571424 0.42367596 0.09929073
 0.87006732 0.26771649 0.08275859 0.57608691 0.        ] 
 Average F1: 0.36020808798010634
