In [8]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'..')))
from utils import load_tensor, save_model
from models import BiLSTM, CNN_BiLSTM
from evaluation_functions import class_accuracy, class_f1_score

In [9]:
def remap_targets(target_tensor: torch.TensorType, old_le, new_le):
    inverse_tensor = old_le.inverse_transform(target_tensor.long())
    for idx, label in enumerate(inverse_tensor):
        if label == 'ARG_RESPONDENT' or label == 'ARG_PETITIONER':
            inverse_tensor[idx] = "ARG"
        elif label == 'PRE_NOT_RELIED' or label == 'PRE_RELIED':
            inverse_tensor[idx] = 'PRE'
    new_tensor = torch.tensor(new_le.transform(inverse_tensor))
    return new_tensor

In [10]:
list_of_targets_old = ['ISSUE', 'FAC', 'NONE', 'ARG_PETITIONER', 'PRE_NOT_RELIED', 'STA', 'RPC', 'ARG_RESPONDENT', 'PREAMBLE', 'ANALYSIS', 'RLC', 'PRE_RELIED', 'RATIO']
label_encoder_old = LabelEncoder().fit(list_of_targets_old)
list_of_targets_new = ['ISSUE', 'FAC', 'NONE', 'ARG', 'PRE', 'STA', 'RPC', 'PREAMBLE', 'ANALYSIS', 'RLC', 'RATIO']
label_encoder_new = LabelEncoder().fit(list_of_targets_new)

In [11]:
sample_input, sample_target = None, None
for idx in range(246):
    if sample_input is None:
        sample_input = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        sample_target = load_tensor(filepath=f"../train_document/doc_{idx}/label")
    else:
        sample_input = torch.cat((sample_input,load_tensor(filepath=f"../train_document/doc_{idx}/embedding")), dim=0)
        sample_target = torch.cat((sample_target,load_tensor(filepath=f"../train_document/doc_{idx}/label")), dim=0)

In [12]:
sample_target.size()

torch.Size([28864])

In [13]:
remapped_target = remap_targets(sample_target, label_encoder_old, label_encoder_new)

In [14]:
from sklearn.utils.class_weight import compute_class_weight
    
class_weights = compute_class_weight(class_weight = "balanced",
                                    classes = np.unique(remapped_target.numpy()),
                                    y = remapped_target.numpy())
class_weights = torch.FloatTensor(class_weights)
class_weights

tensor([0.2460, 1.3107, 0.4584, 7.1499, 1.8544, 1.6639, 0.6347, 3.8990, 3.5033,
        2.4341, 5.4895])

In [17]:
def calculate_confusion_matrix(test_emb, test_labels, model, num_labels = 11):
    """

    Parameters:

    Returns:
    """
    model.eval()
    output = model(test_emb)
    return confusion_matrix(output, test_labels, num_labels)

def confusion_matrix(y_pred, y_true, num_classes):
    """
    Create a confusion matrix for label encodings in PyTorch.

    Parameters:
    y_pred (torch.Tensor): Predicted labels tensor.
    y_true (torch.Tensor): True labels tensor.
    num_classes (int): Number of classes.

    Returns:
    numpy.ndarray: Confusion matrix.
    """ 
    conf_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    y_pred_np = y_pred.argmax(dim=1).cpu().numpy()
    y_true_np = y_true.cpu().numpy()

    for pred, true in zip(y_pred_np, y_true_np):
        conf_matrix[pred, true] += 1

    return conf_matrix

# BERT-Base

## BiLSTM

In [104]:
model1 = BiLSTM(hidden_size=128, dropout= 0.30, output_size= 11)
optimizer = torch.optim.Adam(model1.parameters(), lr= 2e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model1.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model1(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.03:
        break
# batch_loss.append(loss.item())

cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model1, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))
# result.append((config, (average_accuracy, average_f1)))

-----------------------------------------Starting Training------------------------------------------


 76%|███████▌  | 186/246 [00:05<00:01, 35.37it/s]


KeyboardInterrupt: 

RUN 1

Accuracies: [0.82905028 0.40425532 0.77134146 0.78723404 0.94886364 0.53488372
 0.99401198 0.51470588 0.63076923 0.79207921 0.43902439] 

Average acccuracy: 0.6951108317340918

F1 Scores: [0.79358284 0.38383833 0.81943315 0.76288655 0.91256826 0.62670295
 0.99104473 0.49999995 0.45303863 0.8465608  0.51428566] 

Average F1: 0.6912674393701633

RUN 2

Accuracies: [0.81752701 0.40449438 0.7826087  0.73584906 0.96089385 0.44256757
 0.99013807 0.53521127 0.60714286 0.87234043 0.51351351] 

Average acccuracy: 0.6965715180230886

F1 Scores: [0.75331853 0.37305694 0.82420273 0.7572815  0.93224927 0.58482138
 0.99307611 0.53146848 0.39534879 0.90109885 0.57575752] 

Average F1: 0.6928800115505303

RUN 3

Accuracies: [0.84358974 0.39361702 0.75362319 0.74545455 0.93513514 0.42598187
 0.99404762 0.56896552 0.54054054 0.86021505 0.5       ] 

Average acccuracy: 0.6873791125060527

F1 Scores: [0.7498575  0.37373732 0.8195429  0.78095233 0.92266662 0.58385089
 0.99404757 0.50769226 0.26143787 0.88397785 0.52459011] 

Average F1: 0.6729412013793659

In [106]:
avg_acc = (0.6951108317340918 + 0.6965715180230886 + 0.6873791125060527) / 3
avg_f1 = (0.6912674393701633 + 0.6928800115505303 + 0.6729412013793659) / 3
print(f"{avg_acc:.4f}")
print(f"{avg_f1:.4f}")

0.6930
0.6857


## CNN-BiLSTM

In [110]:
model2 = CNN_BiLSTM(hidden_size=128, dropout= 0.30, output_size= 11)
optimizer = torch.optim.Adam(model2.parameters(), lr= 5e-4)
loss_function = nn.CrossEntropyLoss(weight= class_weights)

print(f'{"Starting Training":-^100}')
model2.train()
loss_list = []
for epoch in range(100):
    running_loss = 0
    for idx in tqdm(range(246)):
        TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}/embedding")
        TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}/label")
        TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
        if TRAIN_emb.size(0) == 0:
            continue
        output = model2(TRAIN_emb)
        loss = loss_function(output,TRAIN_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step()
    # scheduler1.step()
    # scheduler2.step()
    # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
    loss_list.append(running_loss/246)
    print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
    if running_loss/246 < 0.03:
        break
# batch_loss.append(loss.item())
cm = None
for i in range(29):
    TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}/embedding")
    TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}/label")
    TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
    conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model2, num_labels= 11)
    if cm is None:
        cm = conf_matrix_helper
    else:
        cm = np.add(cm, conf_matrix_helper)
        
accuracies = class_accuracy(cm)
f1_scores = class_f1_score(cm)
average_accuracy = np.mean(accuracies)
average_f1 = np.mean(f1_scores)

print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

-----------------------------------------Starting Training------------------------------------------


  1%|          | 3/246 [00:00<00:10, 24.06it/s]

100%|██████████| 246/246 [00:07<00:00, 33.51it/s]


Epoch: 1 	 Loss: 1.60613


100%|██████████| 246/246 [00:07<00:00, 34.61it/s]


Epoch: 2 	 Loss: 1.03862


100%|██████████| 246/246 [00:07<00:00, 34.09it/s]


Epoch: 3 	 Loss: 0.89027


100%|██████████| 246/246 [00:08<00:00, 30.07it/s]


Epoch: 4 	 Loss: 0.78925


100%|██████████| 246/246 [00:07<00:00, 33.22it/s]


Epoch: 5 	 Loss: 0.71849


100%|██████████| 246/246 [00:06<00:00, 38.97it/s]


Epoch: 6 	 Loss: 0.71630


100%|██████████| 246/246 [00:06<00:00, 38.82it/s]


Epoch: 7 	 Loss: 0.62485


100%|██████████| 246/246 [00:07<00:00, 32.77it/s]


Epoch: 8 	 Loss: 0.58211


100%|██████████| 246/246 [00:07<00:00, 31.76it/s]


Epoch: 9 	 Loss: 0.55680


100%|██████████| 246/246 [00:06<00:00, 35.20it/s]


Epoch: 10 	 Loss: 0.51907


100%|██████████| 246/246 [00:08<00:00, 30.28it/s]


Epoch: 11 	 Loss: 0.48725


100%|██████████| 246/246 [00:07<00:00, 33.62it/s]


Epoch: 12 	 Loss: 0.45692


100%|██████████| 246/246 [00:09<00:00, 25.28it/s]


Epoch: 13 	 Loss: 0.42348


100%|██████████| 246/246 [00:10<00:00, 22.70it/s]


Epoch: 14 	 Loss: 0.42160


100%|██████████| 246/246 [00:08<00:00, 29.04it/s]


Epoch: 15 	 Loss: 0.36754


100%|██████████| 246/246 [00:08<00:00, 28.06it/s]


Epoch: 16 	 Loss: 0.38325


100%|██████████| 246/246 [00:07<00:00, 32.13it/s]


Epoch: 17 	 Loss: 0.35926


100%|██████████| 246/246 [00:07<00:00, 31.74it/s]


Epoch: 18 	 Loss: 0.32548


100%|██████████| 246/246 [00:07<00:00, 30.84it/s]


Epoch: 19 	 Loss: 0.29720


100%|██████████| 246/246 [00:09<00:00, 26.17it/s]


Epoch: 20 	 Loss: 0.27792


100%|██████████| 246/246 [00:07<00:00, 31.05it/s]


Epoch: 21 	 Loss: 0.29934


100%|██████████| 246/246 [00:08<00:00, 28.32it/s]


Epoch: 22 	 Loss: 0.26319


100%|██████████| 246/246 [00:09<00:00, 26.70it/s]


Epoch: 23 	 Loss: 0.25062


100%|██████████| 246/246 [00:07<00:00, 31.84it/s]


Epoch: 24 	 Loss: 0.21757


100%|██████████| 246/246 [00:08<00:00, 29.82it/s]


Epoch: 25 	 Loss: 0.20819


100%|██████████| 246/246 [00:08<00:00, 27.55it/s]


Epoch: 26 	 Loss: 0.21479


100%|██████████| 246/246 [00:08<00:00, 30.71it/s]


Epoch: 27 	 Loss: 0.20082


100%|██████████| 246/246 [00:07<00:00, 33.63it/s]


Epoch: 28 	 Loss: 0.19171


100%|██████████| 246/246 [00:08<00:00, 30.05it/s]


Epoch: 29 	 Loss: 0.17613


100%|██████████| 246/246 [00:07<00:00, 31.57it/s]


Epoch: 30 	 Loss: 0.24534


100%|██████████| 246/246 [00:08<00:00, 30.29it/s]


Epoch: 31 	 Loss: 0.18947


100%|██████████| 246/246 [00:08<00:00, 30.16it/s]


Epoch: 32 	 Loss: 0.15510


100%|██████████| 246/246 [00:08<00:00, 30.30it/s]


Epoch: 33 	 Loss: 0.14584


100%|██████████| 246/246 [00:07<00:00, 31.88it/s]


Epoch: 34 	 Loss: 0.12713


100%|██████████| 246/246 [00:07<00:00, 30.81it/s]


Epoch: 35 	 Loss: 0.12812


100%|██████████| 246/246 [00:08<00:00, 30.21it/s]


Epoch: 36 	 Loss: 0.13802


100%|██████████| 246/246 [00:07<00:00, 32.93it/s]


Epoch: 37 	 Loss: 0.16252


100%|██████████| 246/246 [00:07<00:00, 32.60it/s]


Epoch: 38 	 Loss: 0.11336


100%|██████████| 246/246 [00:07<00:00, 33.67it/s]


Epoch: 39 	 Loss: 0.11063


100%|██████████| 246/246 [00:07<00:00, 31.99it/s]


Epoch: 40 	 Loss: 0.21687


100%|██████████| 246/246 [00:06<00:00, 38.02it/s]


Epoch: 41 	 Loss: 0.15149


100%|██████████| 246/246 [00:06<00:00, 37.80it/s]


Epoch: 42 	 Loss: 0.10449


100%|██████████| 246/246 [00:07<00:00, 34.31it/s]


Epoch: 43 	 Loss: 0.09087


100%|██████████| 246/246 [00:07<00:00, 32.34it/s]


Epoch: 44 	 Loss: 0.08974


100%|██████████| 246/246 [00:07<00:00, 32.89it/s]


Epoch: 45 	 Loss: 0.08720


100%|██████████| 246/246 [00:08<00:00, 28.90it/s]


Epoch: 46 	 Loss: 0.08044


100%|██████████| 246/246 [00:07<00:00, 31.12it/s]


Epoch: 47 	 Loss: 0.07796


100%|██████████| 246/246 [00:07<00:00, 31.14it/s]


Epoch: 48 	 Loss: 0.08443


100%|██████████| 246/246 [00:07<00:00, 31.31it/s]


Epoch: 49 	 Loss: 0.16880


100%|██████████| 246/246 [00:07<00:00, 32.65it/s]


Epoch: 50 	 Loss: 0.33339


100%|██████████| 246/246 [00:09<00:00, 25.26it/s]


Epoch: 51 	 Loss: 0.12241


100%|██████████| 246/246 [00:10<00:00, 23.18it/s]


Epoch: 52 	 Loss: 0.09076


100%|██████████| 246/246 [00:08<00:00, 29.78it/s]


Epoch: 53 	 Loss: 0.07824


100%|██████████| 246/246 [00:09<00:00, 25.09it/s]


Epoch: 54 	 Loss: 0.07072


100%|██████████| 246/246 [00:09<00:00, 25.40it/s]


Epoch: 55 	 Loss: 0.06425


100%|██████████| 246/246 [00:09<00:00, 26.23it/s]


Epoch: 56 	 Loss: 0.05711


100%|██████████| 246/246 [00:09<00:00, 26.02it/s]


Epoch: 57 	 Loss: 0.06051


100%|██████████| 246/246 [00:08<00:00, 28.31it/s]


Epoch: 58 	 Loss: 0.05815


100%|██████████| 246/246 [00:08<00:00, 29.39it/s]


Epoch: 59 	 Loss: 0.05513


100%|██████████| 246/246 [00:08<00:00, 29.61it/s]


Epoch: 60 	 Loss: 0.05297


100%|██████████| 246/246 [00:08<00:00, 29.87it/s]


Epoch: 61 	 Loss: 0.07083


100%|██████████| 246/246 [00:07<00:00, 33.49it/s]


Epoch: 62 	 Loss: 0.06204


100%|██████████| 246/246 [00:06<00:00, 35.15it/s]


Epoch: 63 	 Loss: 0.05060


100%|██████████| 246/246 [00:07<00:00, 33.88it/s]


Epoch: 64 	 Loss: 0.05400


100%|██████████| 246/246 [00:07<00:00, 31.56it/s]


Epoch: 65 	 Loss: 0.05495


100%|██████████| 246/246 [00:09<00:00, 26.88it/s]


Epoch: 66 	 Loss: 0.05038


100%|██████████| 246/246 [00:09<00:00, 25.16it/s]


Epoch: 67 	 Loss: 0.04715


100%|██████████| 246/246 [00:08<00:00, 29.69it/s]


Epoch: 68 	 Loss: 0.04976


100%|██████████| 246/246 [00:07<00:00, 31.07it/s]


Epoch: 69 	 Loss: 0.16736


100%|██████████| 246/246 [00:08<00:00, 30.01it/s]


Epoch: 70 	 Loss: 0.08154


100%|██████████| 246/246 [00:08<00:00, 29.28it/s]


Epoch: 71 	 Loss: 0.05215


100%|██████████| 246/246 [00:08<00:00, 30.69it/s]


Epoch: 72 	 Loss: 0.04524


100%|██████████| 246/246 [00:08<00:00, 29.90it/s]


Epoch: 73 	 Loss: 0.03699


100%|██████████| 246/246 [00:07<00:00, 30.84it/s]


Epoch: 74 	 Loss: 0.03416


100%|██████████| 246/246 [00:08<00:00, 30.53it/s]


Epoch: 75 	 Loss: 0.03076


100%|██████████| 246/246 [00:07<00:00, 31.81it/s]


Epoch: 76 	 Loss: 0.02964
Accuracies: [0.83219955 0.61016949 0.75426136 0.77083333 0.88829787 0.57073171
 0.998      0.54054054 0.60377358 0.8        0.45652174] 
 Average acccuracy: 0.7113935617467742
F1 Scores: [0.7905223  0.44171774 0.82774742 0.75510199 0.88359783 0.65546214
 0.99402385 0.54794515 0.37869818 0.85106378 0.55999995] 
 Average F1: 0.6987163943896938


RUN 1

Accuracies: [0.82635342 0.59322034 0.78768233 0.77083333 0.92972973 0.65853659
 0.99405941 0.61842105 0.56338028 0.89010989 0.390625  ] 

Average acccuracy: 0.729359215774254

F1 Scores: [0.82804499 0.42944781 0.81270898 0.75510199 0.91733328 0.68354425
 0.99504455 0.63513508 0.42780744 0.90502788 0.53763436] 
 
Average F1: 0.7206209650612753

RUN 2

Accuracies: [0.82359192 0.53731343 0.77828746 0.7755102  0.94444444 0.66459627
 0.99013807 0.51807229 0.54794521 0.88888889 0.37037037] 

Average acccuracy: 0.7126507782608781

F1 Scores: [0.80897699 0.42105258 0.8256285  0.76767672 0.91891887 0.68370602
 0.99307611 0.55483866 0.42328038 0.89887635 0.48192766] 

Average F1: 0.7070871669091839

RUN 3

Accuracies: [0.83219955 0.61016949 0.75426136 0.77083333 0.88829787 0.57073171
 0.998      0.54054054 0.60377358 0.8        0.45652174] 

Average acccuracy: 0.7113935617467742

F1 Scores: [0.7905223  0.44171774 0.82774742 0.75510199 0.88359783 0.65546214
 0.99402385 0.54794515 0.37869818 0.85106378 0.55999995] 

Average F1: 0.6987163943896938


In [112]:
avg_acc = (0.729359215774254 + 0.7126507782608781 + 0.7113935617467742) / 3
avg_f1 = (0.7206209650612753 + 0.7070871669091839 + 0.6987163943896938) / 3
print(f"{avg_acc:.4f}")
print(f"{avg_f1:.4f}")

0.7178
0.7088


# LEGAL-BERT

## BiLSTM

In [114]:
while_i = 0
accs3 = []
f1s3 = []
micro_accs = []
micro_f1s = []
while while_i < 3:
    while_i += 1
    model3 = BiLSTM(hidden_size=128, dropout= 0.30, output_size= 11)
    optimizer = torch.optim.Adam(model3.parameters(), lr= 5e-4)
    loss_function = nn.CrossEntropyLoss(weight= class_weights)

    print(f'{"Starting Training":-^100}')
    model3.train()
    loss_list = []
    for epoch in range(100):
        running_loss = 0
        for idx in tqdm(range(246)):
            TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
            TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
            TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
            if TRAIN_emb.size(0) == 0:
                continue
            output = model3(TRAIN_emb)
            loss = loss_function(output,TRAIN_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # scheduler.step()
        # scheduler1.step()
        # scheduler2.step()
        # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
        loss_list.append(running_loss/246)
        print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
        if running_loss/246 < 0.1:
            break
    cm = None
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
        TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model3, num_labels= 11)
        if cm is None:
            cm = conf_matrix_helper
        else:
            cm = np.add(cm, conf_matrix_helper)
            
    accuracies = class_accuracy(cm)
    f1_scores = class_f1_score(cm)
    average_accuracy = np.mean(accuracies)
    average_f1 = np.mean(f1_scores)
    
    accs3.append(average_accuracy)
    f1s3.append(average_f1)
    micro_accs.append(accuracies)
    micro_f1s.append(f1_scores)

    print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
    print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

-----------------------------------------Starting Training------------------------------------------


  0%|          | 0/246 [00:00<?, ?it/s]

100%|██████████| 246/246 [00:07<00:00, 32.29it/s]


Epoch: 1 	 Loss: 2.25662


100%|██████████| 246/246 [00:07<00:00, 31.90it/s]


Epoch: 2 	 Loss: 1.94651


100%|██████████| 246/246 [00:05<00:00, 41.71it/s]


Epoch: 3 	 Loss: 1.76240


100%|██████████| 246/246 [00:05<00:00, 43.05it/s]


Epoch: 4 	 Loss: 1.64288


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 5 	 Loss: 1.57042


100%|██████████| 246/246 [00:06<00:00, 39.19it/s]


Epoch: 6 	 Loss: 1.50262


100%|██████████| 246/246 [00:05<00:00, 43.14it/s]


Epoch: 7 	 Loss: 1.46249


100%|██████████| 246/246 [00:05<00:00, 42.50it/s]


Epoch: 8 	 Loss: 1.41072


100%|██████████| 246/246 [00:05<00:00, 42.26it/s]


Epoch: 9 	 Loss: 1.35293


100%|██████████| 246/246 [00:05<00:00, 43.00it/s]


Epoch: 10 	 Loss: 1.31057


100%|██████████| 246/246 [00:05<00:00, 41.22it/s]


Epoch: 11 	 Loss: 1.25604


100%|██████████| 246/246 [00:05<00:00, 43.32it/s]


Epoch: 12 	 Loss: 1.22323


100%|██████████| 246/246 [00:05<00:00, 42.35it/s]


Epoch: 13 	 Loss: 1.19013


100%|██████████| 246/246 [00:05<00:00, 42.99it/s]


Epoch: 14 	 Loss: 1.16236


100%|██████████| 246/246 [00:05<00:00, 43.08it/s]


Epoch: 15 	 Loss: 1.11909


100%|██████████| 246/246 [00:06<00:00, 38.48it/s]


Epoch: 16 	 Loss: 1.09130


100%|██████████| 246/246 [00:05<00:00, 43.21it/s]


Epoch: 17 	 Loss: 1.05280


100%|██████████| 246/246 [00:05<00:00, 43.65it/s]


Epoch: 18 	 Loss: 1.03816


100%|██████████| 246/246 [00:05<00:00, 42.78it/s]


Epoch: 19 	 Loss: 1.01530


100%|██████████| 246/246 [00:05<00:00, 41.11it/s]


Epoch: 20 	 Loss: 0.97474


100%|██████████| 246/246 [00:05<00:00, 43.27it/s]


Epoch: 21 	 Loss: 0.93647


100%|██████████| 246/246 [00:05<00:00, 41.07it/s]


Epoch: 22 	 Loss: 0.90988


100%|██████████| 246/246 [00:05<00:00, 43.43it/s]


Epoch: 23 	 Loss: 0.88089


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 24 	 Loss: 0.88620


100%|██████████| 246/246 [00:05<00:00, 41.89it/s]


Epoch: 25 	 Loss: 0.84103


100%|██████████| 246/246 [00:05<00:00, 42.18it/s]


Epoch: 26 	 Loss: 0.79326


100%|██████████| 246/246 [00:05<00:00, 42.53it/s]


Epoch: 27 	 Loss: 0.77729


100%|██████████| 246/246 [00:05<00:00, 41.51it/s]


Epoch: 28 	 Loss: 0.74173


100%|██████████| 246/246 [00:06<00:00, 40.08it/s]


Epoch: 29 	 Loss: 0.72893


100%|██████████| 246/246 [00:05<00:00, 42.92it/s]


Epoch: 30 	 Loss: 0.70538


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 31 	 Loss: 0.68163


100%|██████████| 246/246 [00:05<00:00, 43.63it/s]


Epoch: 32 	 Loss: 0.66916


100%|██████████| 246/246 [00:05<00:00, 43.21it/s]


Epoch: 33 	 Loss: 0.65051


100%|██████████| 246/246 [00:05<00:00, 42.02it/s]


Epoch: 34 	 Loss: 0.64432


100%|██████████| 246/246 [00:05<00:00, 43.52it/s]


Epoch: 35 	 Loss: 0.61468


100%|██████████| 246/246 [00:05<00:00, 43.44it/s]


Epoch: 36 	 Loss: 0.61099


100%|██████████| 246/246 [00:05<00:00, 42.72it/s]


Epoch: 37 	 Loss: 0.60230


100%|██████████| 246/246 [00:05<00:00, 41.29it/s]


Epoch: 38 	 Loss: 0.58313


100%|██████████| 246/246 [00:06<00:00, 38.77it/s]


Epoch: 39 	 Loss: 0.55908


100%|██████████| 246/246 [00:05<00:00, 42.49it/s]


Epoch: 40 	 Loss: 0.57165


100%|██████████| 246/246 [00:05<00:00, 43.31it/s]


Epoch: 41 	 Loss: 0.54400


100%|██████████| 246/246 [00:05<00:00, 43.18it/s]


Epoch: 42 	 Loss: 0.51508


100%|██████████| 246/246 [00:05<00:00, 43.35it/s]


Epoch: 43 	 Loss: 0.50408


100%|██████████| 246/246 [00:05<00:00, 41.33it/s]


Epoch: 44 	 Loss: 0.48846


100%|██████████| 246/246 [00:05<00:00, 42.98it/s]


Epoch: 45 	 Loss: 0.49326


100%|██████████| 246/246 [00:05<00:00, 42.53it/s]


Epoch: 46 	 Loss: 0.46342


100%|██████████| 246/246 [00:05<00:00, 43.39it/s]


Epoch: 47 	 Loss: 0.44962


100%|██████████| 246/246 [00:05<00:00, 41.77it/s]


Epoch: 48 	 Loss: 0.47173


100%|██████████| 246/246 [00:05<00:00, 43.07it/s]


Epoch: 49 	 Loss: 0.46420


100%|██████████| 246/246 [00:05<00:00, 41.59it/s]


Epoch: 50 	 Loss: 0.42582


100%|██████████| 246/246 [00:05<00:00, 42.66it/s]


Epoch: 51 	 Loss: 0.41738


100%|██████████| 246/246 [00:05<00:00, 41.40it/s]


Epoch: 52 	 Loss: 0.40542


100%|██████████| 246/246 [00:05<00:00, 41.21it/s]


Epoch: 53 	 Loss: 0.41384


100%|██████████| 246/246 [00:05<00:00, 43.30it/s]


Epoch: 54 	 Loss: 0.43695


100%|██████████| 246/246 [00:05<00:00, 42.06it/s]


Epoch: 55 	 Loss: 0.39510


100%|██████████| 246/246 [00:05<00:00, 41.05it/s]


Epoch: 56 	 Loss: 0.38309


100%|██████████| 246/246 [00:05<00:00, 43.13it/s]


Epoch: 57 	 Loss: 0.37235


100%|██████████| 246/246 [00:06<00:00, 40.98it/s]


Epoch: 58 	 Loss: 0.36525


100%|██████████| 246/246 [00:05<00:00, 42.76it/s]


Epoch: 59 	 Loss: 0.35385


100%|██████████| 246/246 [00:05<00:00, 43.41it/s]


Epoch: 60 	 Loss: 0.34839


100%|██████████| 246/246 [00:05<00:00, 42.52it/s]


Epoch: 61 	 Loss: 0.39656


100%|██████████| 246/246 [00:05<00:00, 43.03it/s]


Epoch: 62 	 Loss: 0.36348


100%|██████████| 246/246 [00:05<00:00, 41.47it/s]


Epoch: 63 	 Loss: 0.35251


100%|██████████| 246/246 [00:05<00:00, 43.36it/s]


Epoch: 64 	 Loss: 0.33821


100%|██████████| 246/246 [00:05<00:00, 42.54it/s]


Epoch: 65 	 Loss: 0.31791


100%|██████████| 246/246 [00:05<00:00, 43.55it/s]


Epoch: 66 	 Loss: 0.32667


100%|██████████| 246/246 [00:05<00:00, 42.16it/s]


Epoch: 67 	 Loss: 0.34800


100%|██████████| 246/246 [00:06<00:00, 40.24it/s]


Epoch: 68 	 Loss: 0.32332


100%|██████████| 246/246 [00:05<00:00, 43.16it/s]


Epoch: 69 	 Loss: 0.31518


100%|██████████| 246/246 [00:05<00:00, 43.13it/s]


Epoch: 70 	 Loss: 0.29286


100%|██████████| 246/246 [00:05<00:00, 42.19it/s]


Epoch: 71 	 Loss: 0.29569


100%|██████████| 246/246 [00:05<00:00, 43.26it/s]


Epoch: 72 	 Loss: 0.29098


100%|██████████| 246/246 [00:06<00:00, 39.97it/s]


Epoch: 73 	 Loss: 0.28240


100%|██████████| 246/246 [00:05<00:00, 43.16it/s]


Epoch: 74 	 Loss: 0.30120


100%|██████████| 246/246 [00:05<00:00, 42.66it/s]


Epoch: 75 	 Loss: 0.30531


100%|██████████| 246/246 [00:05<00:00, 42.92it/s]


Epoch: 76 	 Loss: 0.28144


100%|██████████| 246/246 [00:06<00:00, 40.94it/s]


Epoch: 77 	 Loss: 0.26735


100%|██████████| 246/246 [00:05<00:00, 43.01it/s]


Epoch: 78 	 Loss: 0.28241


100%|██████████| 246/246 [00:05<00:00, 42.44it/s]


Epoch: 79 	 Loss: 0.28904


100%|██████████| 246/246 [00:06<00:00, 40.54it/s]


Epoch: 80 	 Loss: 0.27171


100%|██████████| 246/246 [00:07<00:00, 35.14it/s]


Epoch: 81 	 Loss: 0.26093


100%|██████████| 246/246 [00:06<00:00, 38.32it/s]


Epoch: 82 	 Loss: 0.27965


100%|██████████| 246/246 [00:05<00:00, 42.34it/s]


Epoch: 83 	 Loss: 0.30071


100%|██████████| 246/246 [00:05<00:00, 42.92it/s]


Epoch: 84 	 Loss: 0.26841


100%|██████████| 246/246 [00:05<00:00, 41.48it/s]


Epoch: 85 	 Loss: 0.26212


100%|██████████| 246/246 [00:05<00:00, 43.49it/s]


Epoch: 86 	 Loss: 0.23686


100%|██████████| 246/246 [00:05<00:00, 43.22it/s]


Epoch: 87 	 Loss: 0.24971


100%|██████████| 246/246 [00:06<00:00, 38.24it/s]


Epoch: 88 	 Loss: 0.23284


100%|██████████| 246/246 [00:05<00:00, 42.68it/s]


Epoch: 89 	 Loss: 0.23949


100%|██████████| 246/246 [00:05<00:00, 41.52it/s]


Epoch: 90 	 Loss: 0.24646


100%|██████████| 246/246 [00:05<00:00, 42.25it/s]


Epoch: 91 	 Loss: 0.25287


100%|██████████| 246/246 [00:05<00:00, 43.38it/s]


Epoch: 92 	 Loss: 0.24314


100%|██████████| 246/246 [00:05<00:00, 43.07it/s]


Epoch: 93 	 Loss: 0.22542


100%|██████████| 246/246 [00:05<00:00, 43.29it/s]


Epoch: 94 	 Loss: 0.23485


100%|██████████| 246/246 [00:05<00:00, 41.34it/s]


Epoch: 95 	 Loss: 0.22189


100%|██████████| 246/246 [00:05<00:00, 42.59it/s]


Epoch: 96 	 Loss: 0.24895


100%|██████████| 246/246 [00:05<00:00, 43.50it/s]


Epoch: 97 	 Loss: 0.23202


100%|██████████| 246/246 [00:05<00:00, 42.74it/s]


Epoch: 98 	 Loss: 0.22352


100%|██████████| 246/246 [00:05<00:00, 43.05it/s]


Epoch: 99 	 Loss: 0.21386


100%|██████████| 246/246 [00:05<00:00, 41.47it/s]


Epoch: 100 	 Loss: 0.20842
Accuracies: [0.66202946 0.02702703 0.58013245 0.61538462 0.68269231 0.21052632
 0.80670611 0.36363636 0.         0.69014085 0.13207547] 
 Average acccuracy: 0.43366827008451553
F1 Scores: [0.73645876 0.01418436 0.65667161 0.25396822 0.48299315 0.04678361
 0.80909985 0.37583888 0.         0.61635215 0.17073166] 
 Average F1: 0.3784620225322724
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:05<00:00, 41.88it/s]


Epoch: 1 	 Loss: 2.22584


100%|██████████| 246/246 [00:06<00:00, 35.68it/s]


Epoch: 2 	 Loss: 1.90066


100%|██████████| 246/246 [00:06<00:00, 38.02it/s]


Epoch: 3 	 Loss: 1.77282


100%|██████████| 246/246 [00:05<00:00, 43.10it/s]


Epoch: 4 	 Loss: 1.64825


100%|██████████| 246/246 [00:07<00:00, 33.74it/s]


Epoch: 5 	 Loss: 1.57632


100%|██████████| 246/246 [00:08<00:00, 30.43it/s]


Epoch: 6 	 Loss: 1.50495


100%|██████████| 246/246 [00:08<00:00, 30.02it/s]


Epoch: 7 	 Loss: 1.44684


100%|██████████| 246/246 [00:06<00:00, 40.21it/s]


Epoch: 8 	 Loss: 1.42465


100%|██████████| 246/246 [00:06<00:00, 38.95it/s]


Epoch: 9 	 Loss: 1.37166


100%|██████████| 246/246 [00:07<00:00, 31.39it/s]


Epoch: 10 	 Loss: 1.32834


100%|██████████| 246/246 [00:06<00:00, 40.49it/s]


Epoch: 11 	 Loss: 1.26943


100%|██████████| 246/246 [00:06<00:00, 38.71it/s]


Epoch: 12 	 Loss: 1.23391


100%|██████████| 246/246 [00:06<00:00, 35.55it/s]


Epoch: 13 	 Loss: 1.20026


100%|██████████| 246/246 [00:05<00:00, 41.74it/s]


Epoch: 14 	 Loss: 1.14413


100%|██████████| 246/246 [00:06<00:00, 37.46it/s]


Epoch: 15 	 Loss: 1.12318


100%|██████████| 246/246 [00:06<00:00, 40.29it/s]


Epoch: 16 	 Loss: 1.08021


100%|██████████| 246/246 [00:07<00:00, 30.90it/s]


Epoch: 17 	 Loss: 1.05401


100%|██████████| 246/246 [00:07<00:00, 33.82it/s]


Epoch: 18 	 Loss: 1.01728


100%|██████████| 246/246 [00:07<00:00, 34.64it/s]


Epoch: 19 	 Loss: 0.97143


100%|██████████| 246/246 [00:06<00:00, 35.47it/s]


Epoch: 20 	 Loss: 0.96676


100%|██████████| 246/246 [00:07<00:00, 32.03it/s]


Epoch: 21 	 Loss: 0.92041


100%|██████████| 246/246 [00:09<00:00, 26.67it/s]


Epoch: 22 	 Loss: 0.88085


100%|██████████| 246/246 [00:05<00:00, 42.88it/s]


Epoch: 23 	 Loss: 0.87780


100%|██████████| 246/246 [00:05<00:00, 46.99it/s]


Epoch: 24 	 Loss: 0.87431


100%|██████████| 246/246 [00:05<00:00, 46.43it/s]


Epoch: 25 	 Loss: 0.82704


100%|██████████| 246/246 [00:05<00:00, 44.44it/s]


Epoch: 26 	 Loss: 0.79732


100%|██████████| 246/246 [00:05<00:00, 44.47it/s]


Epoch: 27 	 Loss: 0.77414


100%|██████████| 246/246 [00:05<00:00, 44.77it/s]


Epoch: 28 	 Loss: 0.75089


100%|██████████| 246/246 [00:05<00:00, 46.42it/s]


Epoch: 29 	 Loss: 0.73629


100%|██████████| 246/246 [00:05<00:00, 47.17it/s]


Epoch: 30 	 Loss: 0.72237


100%|██████████| 246/246 [00:05<00:00, 45.27it/s]


Epoch: 31 	 Loss: 0.69058


100%|██████████| 246/246 [00:05<00:00, 45.56it/s]


Epoch: 32 	 Loss: 0.66618


100%|██████████| 246/246 [00:05<00:00, 46.38it/s]


Epoch: 33 	 Loss: 0.65356


100%|██████████| 246/246 [00:05<00:00, 46.91it/s]


Epoch: 34 	 Loss: 0.66806


100%|██████████| 246/246 [00:05<00:00, 47.13it/s]


Epoch: 35 	 Loss: 0.64407


100%|██████████| 246/246 [00:05<00:00, 42.29it/s]


Epoch: 36 	 Loss: 0.61029


100%|██████████| 246/246 [00:06<00:00, 40.42it/s]


Epoch: 37 	 Loss: 0.57790


100%|██████████| 246/246 [00:08<00:00, 27.70it/s]


Epoch: 38 	 Loss: 0.55430


100%|██████████| 246/246 [00:06<00:00, 36.82it/s]


Epoch: 39 	 Loss: 0.55355


100%|██████████| 246/246 [00:06<00:00, 37.02it/s]


Epoch: 40 	 Loss: 0.53031


100%|██████████| 246/246 [00:07<00:00, 30.83it/s]


Epoch: 41 	 Loss: 0.50937


100%|██████████| 246/246 [00:06<00:00, 39.76it/s]


Epoch: 42 	 Loss: 0.51860


100%|██████████| 246/246 [00:06<00:00, 36.46it/s]


Epoch: 43 	 Loss: 0.49372


100%|██████████| 246/246 [00:06<00:00, 36.12it/s]


Epoch: 44 	 Loss: 0.50184


100%|██████████| 246/246 [00:07<00:00, 33.04it/s]


Epoch: 45 	 Loss: 0.46951


100%|██████████| 246/246 [00:08<00:00, 30.32it/s]


Epoch: 46 	 Loss: 0.46682


100%|██████████| 246/246 [00:07<00:00, 31.25it/s]


Epoch: 47 	 Loss: 0.44682


100%|██████████| 246/246 [00:09<00:00, 26.33it/s]


Epoch: 48 	 Loss: 0.44281


100%|██████████| 246/246 [00:07<00:00, 32.84it/s]


Epoch: 49 	 Loss: 0.44000


100%|██████████| 246/246 [00:07<00:00, 31.86it/s]


Epoch: 50 	 Loss: 0.45066


100%|██████████| 246/246 [00:07<00:00, 31.36it/s]


Epoch: 51 	 Loss: 0.41264


100%|██████████| 246/246 [00:06<00:00, 39.98it/s]


Epoch: 52 	 Loss: 0.41953


100%|██████████| 246/246 [00:05<00:00, 42.95it/s]


Epoch: 53 	 Loss: 0.44655


100%|██████████| 246/246 [00:05<00:00, 43.20it/s]


Epoch: 54 	 Loss: 0.44331


100%|██████████| 246/246 [00:05<00:00, 41.32it/s]


Epoch: 55 	 Loss: 0.41281


100%|██████████| 246/246 [00:05<00:00, 42.67it/s]


Epoch: 56 	 Loss: 0.38597


100%|██████████| 246/246 [00:05<00:00, 42.96it/s]


Epoch: 57 	 Loss: 0.37973


100%|██████████| 246/246 [00:05<00:00, 42.76it/s]


Epoch: 58 	 Loss: 0.36509


100%|██████████| 246/246 [00:05<00:00, 41.08it/s]


Epoch: 59 	 Loss: 0.35338


100%|██████████| 246/246 [00:05<00:00, 43.00it/s]


Epoch: 60 	 Loss: 0.35615


100%|██████████| 246/246 [00:05<00:00, 42.93it/s]


Epoch: 61 	 Loss: 0.35120


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 62 	 Loss: 0.33413


100%|██████████| 246/246 [00:05<00:00, 41.41it/s]


Epoch: 63 	 Loss: 0.34036


100%|██████████| 246/246 [00:06<00:00, 39.70it/s]


Epoch: 64 	 Loss: 0.34926


100%|██████████| 246/246 [00:05<00:00, 42.91it/s]


Epoch: 65 	 Loss: 0.32519


100%|██████████| 246/246 [00:05<00:00, 42.96it/s]


Epoch: 66 	 Loss: 0.30664


100%|██████████| 246/246 [00:05<00:00, 42.86it/s]


Epoch: 67 	 Loss: 0.30469


100%|██████████| 246/246 [00:06<00:00, 37.84it/s]


Epoch: 68 	 Loss: 0.30944


100%|██████████| 246/246 [00:05<00:00, 42.76it/s]


Epoch: 69 	 Loss: 0.33697


100%|██████████| 246/246 [00:06<00:00, 39.62it/s]


Epoch: 70 	 Loss: 0.32158


100%|██████████| 246/246 [00:05<00:00, 42.74it/s]


Epoch: 71 	 Loss: 0.30876


100%|██████████| 246/246 [00:05<00:00, 42.69it/s]


Epoch: 72 	 Loss: 0.28304


100%|██████████| 246/246 [00:06<00:00, 40.86it/s]


Epoch: 73 	 Loss: 0.28267


100%|██████████| 246/246 [00:05<00:00, 42.29it/s]


Epoch: 74 	 Loss: 0.28116


100%|██████████| 246/246 [00:05<00:00, 42.23it/s]


Epoch: 75 	 Loss: 0.28222


100%|██████████| 246/246 [00:06<00:00, 39.39it/s]


Epoch: 76 	 Loss: 0.26706


100%|██████████| 246/246 [00:05<00:00, 41.81it/s]


Epoch: 77 	 Loss: 0.26684


100%|██████████| 246/246 [00:06<00:00, 35.78it/s]


Epoch: 78 	 Loss: 0.25574


100%|██████████| 246/246 [00:07<00:00, 32.65it/s]


Epoch: 79 	 Loss: 0.26004


100%|██████████| 246/246 [00:05<00:00, 43.12it/s]


Epoch: 80 	 Loss: 0.28154


100%|██████████| 246/246 [00:06<00:00, 39.39it/s]


Epoch: 81 	 Loss: 0.27326


100%|██████████| 246/246 [00:05<00:00, 42.70it/s]


Epoch: 82 	 Loss: 0.26891


100%|██████████| 246/246 [00:05<00:00, 42.37it/s]


Epoch: 83 	 Loss: 0.29304


100%|██████████| 246/246 [00:05<00:00, 42.82it/s]


Epoch: 84 	 Loss: 0.25252


100%|██████████| 246/246 [00:06<00:00, 40.40it/s]


Epoch: 85 	 Loss: 0.24947


100%|██████████| 246/246 [00:08<00:00, 28.11it/s]


Epoch: 86 	 Loss: 0.25404


100%|██████████| 246/246 [00:05<00:00, 42.95it/s]


Epoch: 87 	 Loss: 0.27331


100%|██████████| 246/246 [00:05<00:00, 42.81it/s]


Epoch: 88 	 Loss: 0.27699


100%|██████████| 246/246 [00:05<00:00, 41.76it/s]


Epoch: 89 	 Loss: 0.24976


100%|██████████| 246/246 [00:06<00:00, 40.61it/s]


Epoch: 90 	 Loss: 0.22658


100%|██████████| 246/246 [00:05<00:00, 42.96it/s]


Epoch: 91 	 Loss: 0.23070


100%|██████████| 246/246 [00:05<00:00, 42.15it/s]


Epoch: 92 	 Loss: 0.23453


100%|██████████| 246/246 [00:05<00:00, 42.19it/s]


Epoch: 93 	 Loss: 0.21325


100%|██████████| 246/246 [00:06<00:00, 40.74it/s]


Epoch: 94 	 Loss: 0.21089


100%|██████████| 246/246 [00:05<00:00, 42.18it/s]


Epoch: 95 	 Loss: 0.22424


100%|██████████| 246/246 [00:05<00:00, 41.63it/s]


Epoch: 96 	 Loss: 0.25514


100%|██████████| 246/246 [00:05<00:00, 42.71it/s]


Epoch: 97 	 Loss: 0.22827


100%|██████████| 246/246 [00:05<00:00, 42.30it/s]


Epoch: 98 	 Loss: 0.21352


100%|██████████| 246/246 [00:06<00:00, 40.16it/s]


Epoch: 99 	 Loss: 0.30488


100%|██████████| 246/246 [00:05<00:00, 41.88it/s]


Epoch: 100 	 Loss: 0.23069
Accuracies: [0.68268156 0.09090909 0.47253886 0.5        0.81034483 0.22619048
 0.7628866  0.34782609 0.11764706 0.58333333 0.17948718] 
 Average acccuracy: 0.4339859159612654
F1 Scores: [0.65347589 0.05405401 0.59067353 0.16666664 0.37903222 0.1610169
 0.81767951 0.34042548 0.03007517 0.6086956  0.2058823 ] 
 Average F1: 0.3643342954433432
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:05<00:00, 42.31it/s]


Epoch: 1 	 Loss: 2.21800


100%|██████████| 246/246 [00:05<00:00, 42.70it/s]


Epoch: 2 	 Loss: 1.90333


100%|██████████| 246/246 [00:06<00:00, 39.91it/s]


Epoch: 3 	 Loss: 1.74036


100%|██████████| 246/246 [00:05<00:00, 42.38it/s]


Epoch: 4 	 Loss: 1.62440


100%|██████████| 246/246 [00:07<00:00, 33.13it/s]


Epoch: 5 	 Loss: 1.55564


100%|██████████| 246/246 [00:08<00:00, 27.71it/s]


Epoch: 6 	 Loss: 1.49227


100%|██████████| 246/246 [00:16<00:00, 14.52it/s]


Epoch: 7 	 Loss: 1.46216


100%|██████████| 246/246 [00:08<00:00, 28.57it/s]


Epoch: 8 	 Loss: 1.40156


100%|██████████| 246/246 [00:08<00:00, 29.68it/s]


Epoch: 9 	 Loss: 1.37178


100%|██████████| 246/246 [00:07<00:00, 34.61it/s]


Epoch: 10 	 Loss: 1.31669


100%|██████████| 246/246 [00:06<00:00, 40.85it/s]


Epoch: 11 	 Loss: 1.27326


100%|██████████| 246/246 [00:07<00:00, 30.80it/s]


Epoch: 12 	 Loss: 1.21506


100%|██████████| 246/246 [00:09<00:00, 26.93it/s]


Epoch: 13 	 Loss: 1.16609


100%|██████████| 246/246 [00:07<00:00, 32.52it/s]


Epoch: 14 	 Loss: 1.14901


100%|██████████| 246/246 [00:07<00:00, 32.52it/s]


Epoch: 15 	 Loss: 1.11119


100%|██████████| 246/246 [00:07<00:00, 32.98it/s]


Epoch: 16 	 Loss: 1.07018


100%|██████████| 246/246 [00:07<00:00, 34.05it/s]


Epoch: 17 	 Loss: 1.04509


100%|██████████| 246/246 [00:07<00:00, 34.60it/s]


Epoch: 18 	 Loss: 1.01283


100%|██████████| 246/246 [00:06<00:00, 35.85it/s]


Epoch: 19 	 Loss: 0.97454


100%|██████████| 246/246 [00:07<00:00, 34.63it/s]


Epoch: 20 	 Loss: 0.94853


100%|██████████| 246/246 [00:06<00:00, 36.51it/s]


Epoch: 21 	 Loss: 0.91058


100%|██████████| 246/246 [00:05<00:00, 41.76it/s]


Epoch: 22 	 Loss: 0.87812


100%|██████████| 246/246 [00:07<00:00, 33.34it/s]


Epoch: 23 	 Loss: 0.87577


100%|██████████| 246/246 [00:06<00:00, 40.58it/s]


Epoch: 24 	 Loss: 0.88226


100%|██████████| 246/246 [00:07<00:00, 34.47it/s]


Epoch: 25 	 Loss: 0.82498


100%|██████████| 246/246 [00:06<00:00, 36.98it/s]


Epoch: 26 	 Loss: 0.79245


100%|██████████| 246/246 [00:06<00:00, 36.74it/s]


Epoch: 27 	 Loss: 0.75436


100%|██████████| 246/246 [00:05<00:00, 42.34it/s]


Epoch: 28 	 Loss: 0.72891


100%|██████████| 246/246 [00:06<00:00, 40.81it/s]


Epoch: 29 	 Loss: 0.71379


100%|██████████| 246/246 [00:06<00:00, 40.72it/s]


Epoch: 30 	 Loss: 0.69376


100%|██████████| 246/246 [00:05<00:00, 41.65it/s]


Epoch: 31 	 Loss: 0.68337


100%|██████████| 246/246 [00:05<00:00, 43.02it/s]


Epoch: 32 	 Loss: 0.68545


100%|██████████| 246/246 [00:06<00:00, 39.21it/s]


Epoch: 33 	 Loss: 0.65102


100%|██████████| 246/246 [00:07<00:00, 34.92it/s]


Epoch: 34 	 Loss: 0.63442


100%|██████████| 246/246 [00:06<00:00, 36.88it/s]


Epoch: 35 	 Loss: 0.60449


100%|██████████| 246/246 [00:07<00:00, 31.98it/s]


Epoch: 36 	 Loss: 0.58991


100%|██████████| 246/246 [00:07<00:00, 31.68it/s]


Epoch: 37 	 Loss: 0.58203


100%|██████████| 246/246 [00:07<00:00, 31.55it/s]


Epoch: 38 	 Loss: 0.56859


100%|██████████| 246/246 [00:07<00:00, 34.06it/s]


Epoch: 39 	 Loss: 0.56015


100%|██████████| 246/246 [00:07<00:00, 31.49it/s]


Epoch: 40 	 Loss: 0.55426


100%|██████████| 246/246 [00:08<00:00, 29.36it/s]


Epoch: 41 	 Loss: 0.53442


100%|██████████| 246/246 [00:09<00:00, 25.77it/s]


Epoch: 42 	 Loss: 0.51987


100%|██████████| 246/246 [00:07<00:00, 34.24it/s]


Epoch: 43 	 Loss: 0.49015


100%|██████████| 246/246 [00:12<00:00, 20.46it/s]


Epoch: 44 	 Loss: 0.48692


100%|██████████| 246/246 [00:09<00:00, 27.01it/s]


Epoch: 45 	 Loss: 0.48608


100%|██████████| 246/246 [00:07<00:00, 31.44it/s]


Epoch: 46 	 Loss: 0.46718


100%|██████████| 246/246 [00:08<00:00, 27.34it/s]


Epoch: 47 	 Loss: 0.45628


100%|██████████| 246/246 [00:07<00:00, 32.87it/s]


Epoch: 48 	 Loss: 0.43740


100%|██████████| 246/246 [00:07<00:00, 30.86it/s]


Epoch: 49 	 Loss: 0.43815


100%|██████████| 246/246 [00:05<00:00, 41.41it/s]


Epoch: 50 	 Loss: 0.42548


100%|██████████| 246/246 [00:06<00:00, 38.62it/s]


Epoch: 51 	 Loss: 0.42857


100%|██████████| 246/246 [00:06<00:00, 38.36it/s]


Epoch: 52 	 Loss: 0.39779


100%|██████████| 246/246 [00:06<00:00, 38.43it/s]


Epoch: 53 	 Loss: 0.42061


100%|██████████| 246/246 [00:06<00:00, 40.30it/s]


Epoch: 54 	 Loss: 0.42193


100%|██████████| 246/246 [00:05<00:00, 45.88it/s]


Epoch: 55 	 Loss: 0.40756


100%|██████████| 246/246 [00:05<00:00, 46.31it/s]


Epoch: 56 	 Loss: 0.38923


100%|██████████| 246/246 [00:05<00:00, 46.43it/s]


Epoch: 57 	 Loss: 0.39409


100%|██████████| 246/246 [00:06<00:00, 37.32it/s]


Epoch: 58 	 Loss: 0.38772


100%|██████████| 246/246 [00:05<00:00, 41.48it/s]


Epoch: 59 	 Loss: 0.36966


100%|██████████| 246/246 [00:06<00:00, 40.14it/s]


Epoch: 60 	 Loss: 0.36157


100%|██████████| 246/246 [00:07<00:00, 32.33it/s]


Epoch: 61 	 Loss: 0.35525


100%|██████████| 246/246 [00:07<00:00, 34.92it/s]


Epoch: 62 	 Loss: 0.33892


100%|██████████| 246/246 [00:08<00:00, 30.14it/s]


Epoch: 63 	 Loss: 0.34665


100%|██████████| 246/246 [00:06<00:00, 35.72it/s]


Epoch: 64 	 Loss: 0.34064


100%|██████████| 246/246 [00:06<00:00, 36.63it/s]


Epoch: 65 	 Loss: 0.34758


100%|██████████| 246/246 [00:07<00:00, 33.87it/s]


Epoch: 66 	 Loss: 0.32218


100%|██████████| 246/246 [00:06<00:00, 38.81it/s]


Epoch: 67 	 Loss: 0.32744


100%|██████████| 246/246 [00:06<00:00, 38.83it/s]


Epoch: 68 	 Loss: 0.30612


100%|██████████| 246/246 [00:06<00:00, 40.38it/s]


Epoch: 69 	 Loss: 0.30531


100%|██████████| 246/246 [00:06<00:00, 38.05it/s]


Epoch: 70 	 Loss: 0.29394


100%|██████████| 246/246 [00:05<00:00, 44.72it/s]


Epoch: 71 	 Loss: 0.28614


100%|██████████| 246/246 [00:05<00:00, 43.45it/s]


Epoch: 72 	 Loss: 0.28593


100%|██████████| 246/246 [00:05<00:00, 42.07it/s]


Epoch: 73 	 Loss: 0.27796


100%|██████████| 246/246 [00:06<00:00, 40.03it/s]


Epoch: 74 	 Loss: 0.30417


100%|██████████| 246/246 [00:06<00:00, 37.65it/s]


Epoch: 75 	 Loss: 0.30422


100%|██████████| 246/246 [00:06<00:00, 38.82it/s]


Epoch: 76 	 Loss: 0.29084


100%|██████████| 246/246 [00:07<00:00, 34.38it/s]


Epoch: 77 	 Loss: 0.29665


100%|██████████| 246/246 [00:08<00:00, 30.37it/s]


Epoch: 78 	 Loss: 0.27348


100%|██████████| 246/246 [00:06<00:00, 36.47it/s]


Epoch: 79 	 Loss: 0.25628


100%|██████████| 246/246 [00:05<00:00, 41.96it/s]


Epoch: 80 	 Loss: 0.25088


100%|██████████| 246/246 [00:06<00:00, 38.49it/s]


Epoch: 81 	 Loss: 0.25238


100%|██████████| 246/246 [00:08<00:00, 28.66it/s]


Epoch: 82 	 Loss: 0.25768


100%|██████████| 246/246 [00:14<00:00, 16.70it/s]


Epoch: 83 	 Loss: 0.27270


100%|██████████| 246/246 [00:10<00:00, 23.50it/s]


Epoch: 84 	 Loss: 0.30125


100%|██████████| 246/246 [00:06<00:00, 38.69it/s]


Epoch: 85 	 Loss: 0.27149


100%|██████████| 246/246 [00:06<00:00, 40.00it/s]


Epoch: 86 	 Loss: 0.23567


100%|██████████| 246/246 [00:07<00:00, 33.86it/s]


Epoch: 87 	 Loss: 0.23950


100%|██████████| 246/246 [00:08<00:00, 30.66it/s]


Epoch: 88 	 Loss: 0.23912


100%|██████████| 246/246 [00:06<00:00, 36.46it/s]


Epoch: 89 	 Loss: 0.24315


100%|██████████| 246/246 [00:08<00:00, 29.33it/s]


Epoch: 90 	 Loss: 0.24651


100%|██████████| 246/246 [00:06<00:00, 38.16it/s]


Epoch: 91 	 Loss: 0.24493


100%|██████████| 246/246 [00:06<00:00, 40.23it/s]


Epoch: 92 	 Loss: 0.22315


100%|██████████| 246/246 [00:05<00:00, 42.12it/s]


Epoch: 93 	 Loss: 0.21431


100%|██████████| 246/246 [00:05<00:00, 41.08it/s]


Epoch: 94 	 Loss: 0.21476


100%|██████████| 246/246 [00:05<00:00, 42.81it/s]


Epoch: 95 	 Loss: 0.22934


100%|██████████| 246/246 [00:05<00:00, 41.25it/s]


Epoch: 96 	 Loss: 0.23472


100%|██████████| 246/246 [00:05<00:00, 41.85it/s]


Epoch: 97 	 Loss: 0.22868


100%|██████████| 246/246 [00:06<00:00, 40.36it/s]


Epoch: 98 	 Loss: 0.22095


100%|██████████| 246/246 [00:05<00:00, 41.54it/s]


Epoch: 99 	 Loss: 0.21726


100%|██████████| 246/246 [00:05<00:00, 43.15it/s]


Epoch: 100 	 Loss: 0.20940
Accuracies: [0.71106557 0.08125    0.53672316 0.61538462 0.74242424 0.16739447
 0.84140969 0.3        0.4        0.48854962 0.1038961 ] 
 Average acccuracy: 0.4534634070883946
F1 Scores: [0.47436769 0.0984848  0.59052054 0.25396822 0.38281246 0.27413585
 0.79749473 0.27272722 0.09160303 0.58447484 0.15094336] 
 Average F1: 0.3610484311527964


In [116]:
print(f"{np.mean(accs3):.4f}")
print(f"{np.mean(f1s3):.4f}")

0.4404
0.3679


## CNN-BiLSTM

In [117]:
while_i = 0
accs4 = []
f1s4 = []
micro_accs1 = []
micro_f1s1 = []
while while_i < 3:
    while_i += 1
    model4 = CNN_BiLSTM(hidden_size=128, dropout= 0.25, output_size= 11)
    optimizer = torch.optim.Adam(model4.parameters(), lr= 5e-4)
    loss_function = nn.CrossEntropyLoss(weight= class_weights)

    print(f'{"Starting Training":-^100}')
    model4.train()
    loss_list = []
    for epoch in range(100):
        running_loss = 0
        for idx in tqdm(range(246)):
            TRAIN_emb = load_tensor(filepath=f"../train_document/doc_{idx}_legal/embedding")
            TRAIN_labels = load_tensor(filepath=f"../train_document/doc_{idx}_legal/label")
            TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
            if TRAIN_emb.size(0) == 0:
                continue
            output = model4(TRAIN_emb)
            loss = loss_function(output,TRAIN_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # scheduler.step()
        # scheduler1.step()
        # scheduler2.step()
        # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
        loss_list.append(running_loss/246)
        print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
        if running_loss/246 < 0.1:
            break
    cm = None
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"../test_document/doc_{i}_legal/embedding")
        TEST_labels = load_tensor(filepath=f"../test_document/doc_{i}_legal/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model4, num_labels= 11)
        if cm is None:
            cm = conf_matrix_helper
        else:
            cm = np.add(cm, conf_matrix_helper)
            
    accuracies = class_accuracy(cm)
    f1_scores = class_f1_score(cm)
    average_accuracy = np.mean(accuracies)
    average_f1 = np.mean(f1_scores)
    
    accs4.append(average_accuracy)
    f1s4.append(average_f1)
    micro_accs1.append(accuracies)
    micro_f1s1.append(f1_scores)

    print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
    print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

-----------------------------------------Starting Training------------------------------------------


  1%|          | 3/246 [00:00<00:10, 22.54it/s]

100%|██████████| 246/246 [00:12<00:00, 19.91it/s]


Epoch: 1 	 Loss: 2.22142


100%|██████████| 246/246 [00:08<00:00, 28.41it/s]


Epoch: 2 	 Loss: 1.89837


100%|██████████| 246/246 [00:10<00:00, 24.23it/s]


Epoch: 3 	 Loss: 1.82966


100%|██████████| 246/246 [00:07<00:00, 34.34it/s]


Epoch: 4 	 Loss: 1.68003


100%|██████████| 246/246 [00:09<00:00, 27.03it/s]


Epoch: 5 	 Loss: 1.69514


100%|██████████| 246/246 [00:07<00:00, 32.74it/s]


Epoch: 6 	 Loss: 1.60166


100%|██████████| 246/246 [00:06<00:00, 35.95it/s]


Epoch: 7 	 Loss: 1.56711


100%|██████████| 246/246 [00:07<00:00, 32.75it/s]


Epoch: 8 	 Loss: 1.47113


100%|██████████| 246/246 [00:06<00:00, 35.38it/s]


Epoch: 9 	 Loss: 1.42910


100%|██████████| 246/246 [00:06<00:00, 36.13it/s]


Epoch: 10 	 Loss: 1.37812


100%|██████████| 246/246 [00:06<00:00, 36.56it/s]


Epoch: 11 	 Loss: 1.35992


100%|██████████| 246/246 [00:06<00:00, 35.36it/s]


Epoch: 12 	 Loss: 1.32959


100%|██████████| 246/246 [00:06<00:00, 35.46it/s]


Epoch: 13 	 Loss: 1.28132


100%|██████████| 246/246 [00:06<00:00, 36.76it/s]


Epoch: 14 	 Loss: 1.21998


100%|██████████| 246/246 [00:06<00:00, 35.45it/s]


Epoch: 15 	 Loss: 1.26984


100%|██████████| 246/246 [00:07<00:00, 33.52it/s]


Epoch: 16 	 Loss: 1.17579


100%|██████████| 246/246 [00:06<00:00, 35.65it/s]


Epoch: 17 	 Loss: 1.14806


100%|██████████| 246/246 [00:07<00:00, 34.92it/s]


Epoch: 18 	 Loss: 1.08548


100%|██████████| 246/246 [00:06<00:00, 35.53it/s]


Epoch: 19 	 Loss: 1.03094


100%|██████████| 246/246 [00:06<00:00, 36.55it/s]


Epoch: 20 	 Loss: 1.00031


100%|██████████| 246/246 [00:06<00:00, 36.47it/s]


Epoch: 21 	 Loss: 0.96066


100%|██████████| 246/246 [00:07<00:00, 33.56it/s]


Epoch: 22 	 Loss: 0.90778


100%|██████████| 246/246 [00:07<00:00, 35.00it/s]


Epoch: 23 	 Loss: 0.88116


100%|██████████| 246/246 [00:07<00:00, 34.86it/s]


Epoch: 24 	 Loss: 0.89221


100%|██████████| 246/246 [00:06<00:00, 36.66it/s]


Epoch: 25 	 Loss: 0.85708


100%|██████████| 246/246 [00:06<00:00, 35.49it/s]


Epoch: 26 	 Loss: 0.78231


100%|██████████| 246/246 [00:06<00:00, 36.19it/s]


Epoch: 27 	 Loss: 0.75353


100%|██████████| 246/246 [00:07<00:00, 34.21it/s]


Epoch: 28 	 Loss: 0.73324


100%|██████████| 246/246 [00:06<00:00, 36.62it/s]


Epoch: 29 	 Loss: 0.71866


100%|██████████| 246/246 [00:08<00:00, 28.89it/s]


Epoch: 30 	 Loss: 0.68225


100%|██████████| 246/246 [00:08<00:00, 30.08it/s]


Epoch: 31 	 Loss: 0.65934


100%|██████████| 246/246 [00:08<00:00, 27.72it/s]


Epoch: 32 	 Loss: 0.60813


100%|██████████| 246/246 [00:08<00:00, 30.70it/s]


Epoch: 33 	 Loss: 0.59295


100%|██████████| 246/246 [00:06<00:00, 36.79it/s]


Epoch: 34 	 Loss: 0.58214


100%|██████████| 246/246 [00:08<00:00, 28.55it/s]


Epoch: 35 	 Loss: 0.55911


100%|██████████| 246/246 [00:07<00:00, 31.06it/s]


Epoch: 36 	 Loss: 0.52518


100%|██████████| 246/246 [00:08<00:00, 28.18it/s]


Epoch: 37 	 Loss: 0.50388


100%|██████████| 246/246 [00:07<00:00, 30.92it/s]


Epoch: 38 	 Loss: 0.52711


100%|██████████| 246/246 [00:08<00:00, 30.44it/s]


Epoch: 39 	 Loss: 0.49318


100%|██████████| 246/246 [00:08<00:00, 29.19it/s]


Epoch: 40 	 Loss: 0.47564


100%|██████████| 246/246 [00:08<00:00, 28.84it/s]


Epoch: 41 	 Loss: 0.44701


100%|██████████| 246/246 [00:07<00:00, 33.26it/s]


Epoch: 42 	 Loss: 0.42291


100%|██████████| 246/246 [00:07<00:00, 33.45it/s]


Epoch: 43 	 Loss: 0.40612


100%|██████████| 246/246 [00:07<00:00, 33.23it/s]


Epoch: 44 	 Loss: 0.40061


100%|██████████| 246/246 [00:06<00:00, 36.59it/s]


Epoch: 45 	 Loss: 0.37905


100%|██████████| 246/246 [00:06<00:00, 39.09it/s]


Epoch: 46 	 Loss: 0.37308


100%|██████████| 246/246 [00:06<00:00, 39.25it/s]


Epoch: 47 	 Loss: 0.36238


100%|██████████| 246/246 [00:06<00:00, 38.39it/s]


Epoch: 48 	 Loss: 0.35669


100%|██████████| 246/246 [00:06<00:00, 37.48it/s]


Epoch: 49 	 Loss: 0.36288


100%|██████████| 246/246 [00:06<00:00, 37.85it/s]


Epoch: 50 	 Loss: 0.35352


100%|██████████| 246/246 [00:06<00:00, 39.16it/s]


Epoch: 51 	 Loss: 0.32627


100%|██████████| 246/246 [00:06<00:00, 38.87it/s]


Epoch: 52 	 Loss: 0.30695


100%|██████████| 246/246 [00:06<00:00, 38.47it/s]


Epoch: 53 	 Loss: 0.28905


100%|██████████| 246/246 [00:06<00:00, 38.89it/s]


Epoch: 54 	 Loss: 0.28760


100%|██████████| 246/246 [00:06<00:00, 37.12it/s]


Epoch: 55 	 Loss: 0.28044


100%|██████████| 246/246 [00:06<00:00, 38.71it/s]


Epoch: 56 	 Loss: 0.27754


100%|██████████| 246/246 [00:06<00:00, 38.86it/s]


Epoch: 57 	 Loss: 0.26495


100%|██████████| 246/246 [00:06<00:00, 39.13it/s]


Epoch: 58 	 Loss: 0.30428


100%|██████████| 246/246 [00:06<00:00, 38.82it/s]


Epoch: 59 	 Loss: 0.33499


100%|██████████| 246/246 [00:06<00:00, 37.89it/s]


Epoch: 60 	 Loss: 0.29765


100%|██████████| 246/246 [00:06<00:00, 37.51it/s]


Epoch: 61 	 Loss: 0.27975


100%|██████████| 246/246 [00:06<00:00, 38.74it/s]


Epoch: 62 	 Loss: 0.24676


100%|██████████| 246/246 [00:06<00:00, 39.18it/s]


Epoch: 63 	 Loss: 0.22769


100%|██████████| 246/246 [00:06<00:00, 38.78it/s]


Epoch: 64 	 Loss: 0.21263


100%|██████████| 246/246 [00:06<00:00, 39.09it/s]


Epoch: 65 	 Loss: 0.20178


100%|██████████| 246/246 [00:06<00:00, 36.15it/s]


Epoch: 66 	 Loss: 0.20509


100%|██████████| 246/246 [00:06<00:00, 38.49it/s]


Epoch: 67 	 Loss: 0.20328


100%|██████████| 246/246 [00:06<00:00, 38.70it/s]


Epoch: 68 	 Loss: 0.20013


100%|██████████| 246/246 [00:06<00:00, 38.99it/s]


Epoch: 69 	 Loss: 0.20607


100%|██████████| 246/246 [00:06<00:00, 39.00it/s]


Epoch: 70 	 Loss: 0.23553


100%|██████████| 246/246 [00:06<00:00, 37.42it/s]


Epoch: 71 	 Loss: 0.22586


100%|██████████| 246/246 [00:06<00:00, 39.26it/s]


Epoch: 72 	 Loss: 0.19516


100%|██████████| 246/246 [00:06<00:00, 35.67it/s]


Epoch: 73 	 Loss: 0.17698


100%|██████████| 246/246 [00:06<00:00, 35.48it/s]


Epoch: 74 	 Loss: 0.17140


100%|██████████| 246/246 [00:06<00:00, 39.07it/s]


Epoch: 75 	 Loss: 0.16565


100%|██████████| 246/246 [00:06<00:00, 35.86it/s]


Epoch: 76 	 Loss: 0.16451


100%|██████████| 246/246 [00:06<00:00, 39.10it/s]


Epoch: 77 	 Loss: 0.16673


100%|██████████| 246/246 [00:07<00:00, 33.24it/s]


Epoch: 78 	 Loss: 0.17670


100%|██████████| 246/246 [00:07<00:00, 34.96it/s]


Epoch: 79 	 Loss: 0.18415


100%|██████████| 246/246 [00:06<00:00, 38.69it/s]


Epoch: 80 	 Loss: 0.17813


100%|██████████| 246/246 [00:06<00:00, 39.00it/s]


Epoch: 81 	 Loss: 0.16570


100%|██████████| 246/246 [00:06<00:00, 37.24it/s]


Epoch: 82 	 Loss: 0.16678


100%|██████████| 246/246 [00:06<00:00, 37.83it/s]


Epoch: 83 	 Loss: 0.15459


100%|██████████| 246/246 [00:06<00:00, 38.95it/s]


Epoch: 84 	 Loss: 0.16174


100%|██████████| 246/246 [00:07<00:00, 31.14it/s]


Epoch: 85 	 Loss: 0.15091


100%|██████████| 246/246 [00:08<00:00, 29.92it/s]


Epoch: 86 	 Loss: 0.14705


100%|██████████| 246/246 [00:08<00:00, 29.20it/s]


Epoch: 87 	 Loss: 0.13726


100%|██████████| 246/246 [00:07<00:00, 30.81it/s]


Epoch: 88 	 Loss: 0.13555


100%|██████████| 246/246 [00:07<00:00, 33.73it/s]


Epoch: 89 	 Loss: 0.22233


100%|██████████| 246/246 [00:07<00:00, 32.02it/s]


Epoch: 90 	 Loss: 0.18323


100%|██████████| 246/246 [00:07<00:00, 32.89it/s]


Epoch: 91 	 Loss: 0.14608


100%|██████████| 246/246 [00:06<00:00, 36.78it/s]


Epoch: 92 	 Loss: 0.13440


100%|██████████| 246/246 [00:09<00:00, 27.21it/s]


Epoch: 93 	 Loss: 0.12084


100%|██████████| 246/246 [00:07<00:00, 34.44it/s]


Epoch: 94 	 Loss: 0.11882


100%|██████████| 246/246 [00:07<00:00, 33.25it/s]


Epoch: 95 	 Loss: 0.11615


100%|██████████| 246/246 [00:06<00:00, 36.07it/s]


Epoch: 96 	 Loss: 0.12019


100%|██████████| 246/246 [00:06<00:00, 36.71it/s]


Epoch: 97 	 Loss: 0.11109


100%|██████████| 246/246 [00:07<00:00, 34.63it/s]


Epoch: 98 	 Loss: 0.11716


100%|██████████| 246/246 [00:06<00:00, 36.60it/s]


Epoch: 99 	 Loss: 0.11582


100%|██████████| 246/246 [00:06<00:00, 35.17it/s]


Epoch: 100 	 Loss: 0.12647
Accuracies: [0.69194313 0.16666667 0.47075743 0.34210526 0.48717949 0.0617284
 0.80357143 0.33695652 0.         0.6        0.125     ] 
 Average acccuracy: 0.3714462109843087
F1 Scores: [0.642111   0.01818181 0.60542535 0.2954545  0.37133546 0.04291841
 0.80357138 0.37804873 0.         0.62295077 0.13114749] 
 Average F1: 0.3555586267686182
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:06<00:00, 37.48it/s]


Epoch: 1 	 Loss: 2.21626


100%|██████████| 246/246 [00:06<00:00, 38.77it/s]


Epoch: 2 	 Loss: 1.86637


100%|██████████| 246/246 [00:06<00:00, 38.59it/s]


Epoch: 3 	 Loss: 1.72766


100%|██████████| 246/246 [00:06<00:00, 37.62it/s]


Epoch: 4 	 Loss: 1.66140


100%|██████████| 246/246 [00:07<00:00, 34.95it/s]


Epoch: 5 	 Loss: 1.57917


100%|██████████| 246/246 [00:06<00:00, 36.42it/s]


Epoch: 6 	 Loss: 1.52429


100%|██████████| 246/246 [00:08<00:00, 27.64it/s]


Epoch: 7 	 Loss: 1.46795


100%|██████████| 246/246 [00:23<00:00, 10.44it/s]


Epoch: 8 	 Loss: 1.40844


100%|██████████| 246/246 [00:14<00:00, 17.35it/s]


Epoch: 9 	 Loss: 1.39937


100%|██████████| 246/246 [00:19<00:00, 12.67it/s]


Epoch: 10 	 Loss: 1.31784


100%|██████████| 246/246 [00:17<00:00, 13.74it/s]


Epoch: 11 	 Loss: 1.26425


100%|██████████| 246/246 [00:15<00:00, 15.67it/s]


Epoch: 12 	 Loss: 1.21189


100%|██████████| 246/246 [00:16<00:00, 14.96it/s]


Epoch: 13 	 Loss: 1.17320


100%|██████████| 246/246 [00:18<00:00, 13.60it/s]


Epoch: 14 	 Loss: 1.11480


100%|██████████| 246/246 [00:22<00:00, 11.10it/s]


Epoch: 15 	 Loss: 1.08026


100%|██████████| 246/246 [00:13<00:00, 17.96it/s]


Epoch: 16 	 Loss: 1.03238


100%|██████████| 246/246 [00:15<00:00, 16.17it/s]


Epoch: 17 	 Loss: 1.00555


100%|██████████| 246/246 [00:16<00:00, 14.69it/s]


Epoch: 18 	 Loss: 0.96752


100%|██████████| 246/246 [00:14<00:00, 17.46it/s]


Epoch: 19 	 Loss: 0.91671


100%|██████████| 246/246 [00:12<00:00, 19.61it/s]


Epoch: 20 	 Loss: 0.88460


100%|██████████| 246/246 [00:11<00:00, 20.81it/s]


Epoch: 21 	 Loss: 0.86698


100%|██████████| 246/246 [00:12<00:00, 20.43it/s]


Epoch: 22 	 Loss: 0.79288


100%|██████████| 246/246 [00:10<00:00, 22.52it/s]


Epoch: 23 	 Loss: 0.75682


100%|██████████| 246/246 [00:20<00:00, 12.24it/s]


Epoch: 24 	 Loss: 0.75678


100%|██████████| 246/246 [00:12<00:00, 19.55it/s]


Epoch: 25 	 Loss: 0.69871


100%|██████████| 246/246 [00:14<00:00, 17.13it/s]


Epoch: 26 	 Loss: 0.65276


100%|██████████| 246/246 [00:15<00:00, 16.36it/s]


Epoch: 27 	 Loss: 0.60262


100%|██████████| 246/246 [00:13<00:00, 18.79it/s]


Epoch: 28 	 Loss: 0.58984


100%|██████████| 246/246 [00:13<00:00, 17.74it/s]


Epoch: 29 	 Loss: 0.57754


100%|██████████| 246/246 [00:14<00:00, 16.79it/s]


Epoch: 30 	 Loss: 0.55877


100%|██████████| 246/246 [00:13<00:00, 18.05it/s]


Epoch: 31 	 Loss: 0.53249


100%|██████████| 246/246 [00:11<00:00, 20.83it/s]


Epoch: 32 	 Loss: 0.50742


100%|██████████| 246/246 [00:11<00:00, 22.20it/s]


Epoch: 33 	 Loss: 0.49671


100%|██████████| 246/246 [00:11<00:00, 21.41it/s]


Epoch: 34 	 Loss: 0.45479


100%|██████████| 246/246 [00:11<00:00, 21.85it/s]


Epoch: 35 	 Loss: 0.42928


100%|██████████| 246/246 [00:12<00:00, 19.50it/s]


Epoch: 36 	 Loss: 0.41458


100%|██████████| 246/246 [00:10<00:00, 22.72it/s]


Epoch: 37 	 Loss: 0.39631


100%|██████████| 246/246 [00:12<00:00, 20.36it/s]


Epoch: 38 	 Loss: 0.58003


100%|██████████| 246/246 [00:13<00:00, 18.36it/s]


Epoch: 39 	 Loss: 0.46264


100%|██████████| 246/246 [00:11<00:00, 22.29it/s]


Epoch: 40 	 Loss: 0.37265


100%|██████████| 246/246 [00:12<00:00, 19.19it/s]


Epoch: 41 	 Loss: 0.35273


100%|██████████| 246/246 [00:12<00:00, 19.57it/s]


Epoch: 42 	 Loss: 0.33287


100%|██████████| 246/246 [00:14<00:00, 16.48it/s]


Epoch: 43 	 Loss: 0.32179


100%|██████████| 246/246 [00:14<00:00, 17.39it/s]


Epoch: 44 	 Loss: 0.30351


100%|██████████| 246/246 [00:12<00:00, 19.83it/s]


Epoch: 45 	 Loss: 0.30642


100%|██████████| 246/246 [00:11<00:00, 21.17it/s]


Epoch: 46 	 Loss: 0.29842


100%|██████████| 246/246 [00:12<00:00, 20.46it/s]


Epoch: 47 	 Loss: 0.29066


100%|██████████| 246/246 [00:11<00:00, 21.63it/s]


Epoch: 48 	 Loss: 0.29852


100%|██████████| 246/246 [00:11<00:00, 21.85it/s]


Epoch: 49 	 Loss: 0.29791


100%|██████████| 246/246 [00:11<00:00, 21.81it/s]


Epoch: 50 	 Loss: 0.27458


100%|██████████| 246/246 [00:11<00:00, 21.28it/s]


Epoch: 51 	 Loss: 0.25578


100%|██████████| 246/246 [00:11<00:00, 22.27it/s]


Epoch: 52 	 Loss: 0.25004


100%|██████████| 246/246 [00:11<00:00, 21.28it/s]


Epoch: 53 	 Loss: 0.24742


100%|██████████| 246/246 [00:11<00:00, 21.69it/s]


Epoch: 54 	 Loss: 0.24640


100%|██████████| 246/246 [00:10<00:00, 22.70it/s]


Epoch: 55 	 Loss: 0.24931


100%|██████████| 246/246 [00:11<00:00, 21.69it/s]


Epoch: 56 	 Loss: 0.25249


100%|██████████| 246/246 [00:14<00:00, 17.48it/s]


Epoch: 57 	 Loss: 0.22871


100%|██████████| 246/246 [00:12<00:00, 20.23it/s]


Epoch: 58 	 Loss: 0.22103


100%|██████████| 246/246 [00:09<00:00, 25.08it/s]


Epoch: 59 	 Loss: 0.21611


100%|██████████| 246/246 [00:09<00:00, 27.30it/s]


Epoch: 60 	 Loss: 0.20350


100%|██████████| 246/246 [00:11<00:00, 21.74it/s]


Epoch: 61 	 Loss: 0.20576


100%|██████████| 246/246 [00:10<00:00, 24.49it/s]


Epoch: 62 	 Loss: 0.20311


100%|██████████| 246/246 [00:09<00:00, 26.05it/s]


Epoch: 63 	 Loss: 0.18737


100%|██████████| 246/246 [00:10<00:00, 24.36it/s]


Epoch: 64 	 Loss: 0.18424


100%|██████████| 246/246 [00:09<00:00, 25.11it/s]


Epoch: 65 	 Loss: 0.19292


100%|██████████| 246/246 [00:09<00:00, 26.63it/s]


Epoch: 66 	 Loss: 0.19105


100%|██████████| 246/246 [00:10<00:00, 24.56it/s]


Epoch: 67 	 Loss: 0.17923


100%|██████████| 246/246 [00:10<00:00, 24.32it/s]


Epoch: 68 	 Loss: 0.17328


100%|██████████| 246/246 [00:09<00:00, 26.14it/s]


Epoch: 69 	 Loss: 0.16998


100%|██████████| 246/246 [00:09<00:00, 26.81it/s]


Epoch: 70 	 Loss: 0.15786


100%|██████████| 246/246 [00:09<00:00, 25.08it/s]


Epoch: 71 	 Loss: 0.17531


100%|██████████| 246/246 [00:09<00:00, 26.18it/s]


Epoch: 72 	 Loss: 0.17527


100%|██████████| 246/246 [00:09<00:00, 26.18it/s]


Epoch: 73 	 Loss: 0.16234


100%|██████████| 246/246 [00:11<00:00, 20.98it/s]


Epoch: 74 	 Loss: 0.15296


100%|██████████| 246/246 [00:10<00:00, 23.29it/s]


Epoch: 75 	 Loss: 0.14080


100%|██████████| 246/246 [00:09<00:00, 25.60it/s]


Epoch: 76 	 Loss: 0.13909


100%|██████████| 246/246 [00:10<00:00, 23.62it/s]


Epoch: 77 	 Loss: 0.14194


100%|██████████| 246/246 [00:10<00:00, 23.04it/s]


Epoch: 78 	 Loss: 0.14256


100%|██████████| 246/246 [00:09<00:00, 25.33it/s]


Epoch: 79 	 Loss: 0.15777


100%|██████████| 246/246 [00:09<00:00, 25.04it/s]


Epoch: 80 	 Loss: 0.14630


100%|██████████| 246/246 [00:09<00:00, 26.23it/s]


Epoch: 81 	 Loss: 0.13627


100%|██████████| 246/246 [00:09<00:00, 25.44it/s]


Epoch: 82 	 Loss: 0.13079


100%|██████████| 246/246 [00:10<00:00, 22.43it/s]


Epoch: 83 	 Loss: 0.13452


100%|██████████| 246/246 [00:09<00:00, 26.90it/s]


Epoch: 84 	 Loss: 0.13197


100%|██████████| 246/246 [00:09<00:00, 25.38it/s]


Epoch: 85 	 Loss: 0.14351


100%|██████████| 246/246 [00:09<00:00, 25.09it/s]


Epoch: 86 	 Loss: 0.12971


100%|██████████| 246/246 [00:10<00:00, 23.84it/s]


Epoch: 87 	 Loss: 0.11923


100%|██████████| 246/246 [00:10<00:00, 23.15it/s]


Epoch: 88 	 Loss: 0.11522


100%|██████████| 246/246 [00:11<00:00, 21.91it/s]


Epoch: 89 	 Loss: 0.11953


100%|██████████| 246/246 [00:09<00:00, 26.94it/s]


Epoch: 90 	 Loss: 0.11273


100%|██████████| 246/246 [00:11<00:00, 21.87it/s]


Epoch: 91 	 Loss: 0.11037


100%|██████████| 246/246 [00:09<00:00, 24.84it/s]


Epoch: 92 	 Loss: 0.12474


100%|██████████| 246/246 [00:09<00:00, 25.33it/s]


Epoch: 93 	 Loss: 0.12321


100%|██████████| 246/246 [00:09<00:00, 25.47it/s]


Epoch: 94 	 Loss: 0.11490


100%|██████████| 246/246 [00:10<00:00, 24.11it/s]


Epoch: 95 	 Loss: 0.12183


100%|██████████| 246/246 [00:09<00:00, 24.82it/s]


Epoch: 96 	 Loss: 0.11399


100%|██████████| 246/246 [00:10<00:00, 24.53it/s]


Epoch: 97 	 Loss: 0.11104


100%|██████████| 246/246 [00:10<00:00, 23.12it/s]


Epoch: 98 	 Loss: 0.11335


100%|██████████| 246/246 [00:09<00:00, 25.24it/s]


Epoch: 99 	 Loss: 0.10704


100%|██████████| 246/246 [00:09<00:00, 24.91it/s]


Epoch: 100 	 Loss: 0.09025
Accuracies: [0.6503006  0.08474576 0.55068493 0.38709677 0.54362416 0.16964286
 0.87723214 0.39285714 0.02884615 0.52066116 0.13043478] 
 Average acccuracy: 0.3941933151841162
F1 Scores: [0.65788135 0.06134965 0.61420927 0.29629625 0.47787606 0.14393934
 0.8256302  0.42307687 0.02727268 0.60287076 0.11538457] 
 Average F1: 0.38598063621948825
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:09<00:00, 26.70it/s]


Epoch: 1 	 Loss: 2.18662


100%|██████████| 246/246 [00:09<00:00, 27.16it/s]


Epoch: 2 	 Loss: 1.88087


100%|██████████| 246/246 [00:08<00:00, 27.86it/s]


Epoch: 3 	 Loss: 1.72150


100%|██████████| 246/246 [00:09<00:00, 26.53it/s]


Epoch: 4 	 Loss: 1.66432


100%|██████████| 246/246 [00:10<00:00, 23.39it/s]


Epoch: 5 	 Loss: 1.57037


100%|██████████| 246/246 [00:10<00:00, 22.40it/s]


Epoch: 6 	 Loss: 1.52428


100%|██████████| 246/246 [00:11<00:00, 20.87it/s]


Epoch: 7 	 Loss: 1.46637


100%|██████████| 246/246 [00:10<00:00, 23.80it/s]


Epoch: 8 	 Loss: 1.42389


100%|██████████| 246/246 [00:11<00:00, 20.96it/s]


Epoch: 9 	 Loss: 1.36535


100%|██████████| 246/246 [00:10<00:00, 23.66it/s]


Epoch: 10 	 Loss: 1.31118


100%|██████████| 246/246 [00:09<00:00, 24.88it/s]


Epoch: 11 	 Loss: 1.27823


100%|██████████| 246/246 [00:09<00:00, 26.89it/s]


Epoch: 12 	 Loss: 1.29799


100%|██████████| 246/246 [00:10<00:00, 22.74it/s]


Epoch: 13 	 Loss: 1.21011


100%|██████████| 246/246 [00:09<00:00, 26.45it/s]


Epoch: 14 	 Loss: 1.14918


100%|██████████| 246/246 [00:09<00:00, 27.14it/s]


Epoch: 15 	 Loss: 1.09099


100%|██████████| 246/246 [00:09<00:00, 26.28it/s]


Epoch: 16 	 Loss: 1.07701


100%|██████████| 246/246 [00:10<00:00, 24.37it/s]


Epoch: 17 	 Loss: 1.01707


100%|██████████| 246/246 [00:09<00:00, 25.81it/s]


Epoch: 18 	 Loss: 0.97070


100%|██████████| 246/246 [00:10<00:00, 22.50it/s]


Epoch: 19 	 Loss: 0.95810


100%|██████████| 246/246 [00:10<00:00, 23.28it/s]


Epoch: 20 	 Loss: 0.90475


100%|██████████| 246/246 [00:10<00:00, 23.41it/s]


Epoch: 21 	 Loss: 0.92416


100%|██████████| 246/246 [00:12<00:00, 19.36it/s]


Epoch: 22 	 Loss: 0.96203


100%|██████████| 246/246 [00:09<00:00, 26.13it/s]


Epoch: 23 	 Loss: 0.81817


100%|██████████| 246/246 [00:11<00:00, 20.74it/s]


Epoch: 24 	 Loss: 0.78363


100%|██████████| 246/246 [00:12<00:00, 19.74it/s]


Epoch: 25 	 Loss: 0.77249


100%|██████████| 246/246 [00:14<00:00, 17.32it/s]


Epoch: 26 	 Loss: 0.98007


100%|██████████| 246/246 [00:12<00:00, 19.67it/s]


Epoch: 27 	 Loss: 0.76823


100%|██████████| 246/246 [00:11<00:00, 20.74it/s]


Epoch: 28 	 Loss: 0.72596


100%|██████████| 246/246 [00:12<00:00, 19.29it/s]


Epoch: 29 	 Loss: 0.66719


100%|██████████| 246/246 [00:11<00:00, 21.41it/s]


Epoch: 30 	 Loss: 0.61722


100%|██████████| 246/246 [00:09<00:00, 25.08it/s]


Epoch: 31 	 Loss: 0.58224


100%|██████████| 246/246 [00:10<00:00, 24.57it/s]


Epoch: 32 	 Loss: 0.54593


100%|██████████| 246/246 [00:09<00:00, 25.81it/s]


Epoch: 33 	 Loss: 0.53501


100%|██████████| 246/246 [00:09<00:00, 26.72it/s]


Epoch: 34 	 Loss: 0.51518


100%|██████████| 246/246 [00:09<00:00, 26.71it/s]


Epoch: 35 	 Loss: 0.52018


100%|██████████| 246/246 [00:09<00:00, 25.71it/s]


Epoch: 36 	 Loss: 0.48666


100%|██████████| 246/246 [00:09<00:00, 27.15it/s]


Epoch: 37 	 Loss: 0.45820


100%|██████████| 246/246 [00:09<00:00, 24.78it/s]


Epoch: 38 	 Loss: 0.43661


100%|██████████| 246/246 [00:10<00:00, 23.58it/s]


Epoch: 39 	 Loss: 0.42391


100%|██████████| 246/246 [00:10<00:00, 23.10it/s]


Epoch: 40 	 Loss: 0.42031


100%|██████████| 246/246 [00:09<00:00, 24.84it/s]


Epoch: 41 	 Loss: 0.39418


100%|██████████| 246/246 [00:09<00:00, 25.39it/s]


Epoch: 42 	 Loss: 0.41091


100%|██████████| 246/246 [00:11<00:00, 21.42it/s]


Epoch: 43 	 Loss: 0.38920


100%|██████████| 246/246 [00:10<00:00, 24.07it/s]


Epoch: 44 	 Loss: 0.35705


100%|██████████| 246/246 [00:11<00:00, 21.79it/s]


Epoch: 45 	 Loss: 0.34221


100%|██████████| 246/246 [00:10<00:00, 24.12it/s]


Epoch: 46 	 Loss: 0.33677


100%|██████████| 246/246 [00:11<00:00, 21.08it/s]


Epoch: 47 	 Loss: 0.31655


100%|██████████| 246/246 [00:09<00:00, 24.87it/s]


Epoch: 48 	 Loss: 0.31637


100%|██████████| 246/246 [00:09<00:00, 24.77it/s]


Epoch: 49 	 Loss: 0.32929


100%|██████████| 246/246 [00:09<00:00, 26.66it/s]


Epoch: 50 	 Loss: 0.29580


100%|██████████| 246/246 [00:10<00:00, 23.01it/s]


Epoch: 51 	 Loss: 0.30732


100%|██████████| 246/246 [00:09<00:00, 25.12it/s]


Epoch: 52 	 Loss: 0.28748


100%|██████████| 246/246 [00:10<00:00, 22.74it/s]


Epoch: 53 	 Loss: 0.30214


100%|██████████| 246/246 [00:09<00:00, 26.72it/s]


Epoch: 54 	 Loss: 0.27831


100%|██████████| 246/246 [00:09<00:00, 25.82it/s]


Epoch: 55 	 Loss: 0.26979


100%|██████████| 246/246 [00:09<00:00, 24.91it/s]


Epoch: 56 	 Loss: 0.27132


100%|██████████| 246/246 [00:09<00:00, 25.19it/s]


Epoch: 57 	 Loss: 0.25925


100%|██████████| 246/246 [00:17<00:00, 14.16it/s]


Epoch: 58 	 Loss: 0.26270


100%|██████████| 246/246 [00:09<00:00, 25.44it/s]


Epoch: 59 	 Loss: 0.25900


100%|██████████| 246/246 [00:09<00:00, 25.26it/s]


Epoch: 60 	 Loss: 0.23629


100%|██████████| 246/246 [00:12<00:00, 20.46it/s]


Epoch: 61 	 Loss: 0.22216


100%|██████████| 246/246 [00:09<00:00, 26.08it/s]


Epoch: 62 	 Loss: 0.21847


100%|██████████| 246/246 [00:09<00:00, 25.73it/s]


Epoch: 63 	 Loss: 0.20459


100%|██████████| 246/246 [00:10<00:00, 24.43it/s]


Epoch: 64 	 Loss: 0.19863


100%|██████████| 246/246 [00:09<00:00, 24.82it/s]


Epoch: 65 	 Loss: 0.19708


100%|██████████| 246/246 [00:09<00:00, 26.73it/s]


Epoch: 66 	 Loss: 0.19923


100%|██████████| 246/246 [00:09<00:00, 25.57it/s]


Epoch: 67 	 Loss: 0.20731


100%|██████████| 246/246 [00:09<00:00, 24.62it/s]


Epoch: 68 	 Loss: 0.20352


100%|██████████| 246/246 [00:09<00:00, 26.18it/s]


Epoch: 69 	 Loss: 0.18412


100%|██████████| 246/246 [00:09<00:00, 25.80it/s]


Epoch: 70 	 Loss: 0.17059


100%|██████████| 246/246 [00:11<00:00, 21.04it/s]


Epoch: 71 	 Loss: 0.17604


100%|██████████| 246/246 [00:10<00:00, 22.96it/s]


Epoch: 72 	 Loss: 0.19465


100%|██████████| 246/246 [00:09<00:00, 27.15it/s]


Epoch: 73 	 Loss: 0.18953


100%|██████████| 246/246 [00:10<00:00, 24.55it/s]


Epoch: 74 	 Loss: 0.20358


100%|██████████| 246/246 [00:09<00:00, 24.87it/s]


Epoch: 75 	 Loss: 0.17155


100%|██████████| 246/246 [00:09<00:00, 26.20it/s]


Epoch: 76 	 Loss: 0.21632


100%|██████████| 246/246 [00:09<00:00, 26.62it/s]


Epoch: 77 	 Loss: 0.18055


100%|██████████| 246/246 [00:09<00:00, 25.62it/s]


Epoch: 78 	 Loss: 0.15148


100%|██████████| 246/246 [00:09<00:00, 26.91it/s]


Epoch: 79 	 Loss: 0.13897


100%|██████████| 246/246 [00:10<00:00, 24.41it/s]


Epoch: 80 	 Loss: 0.13707


100%|██████████| 246/246 [00:09<00:00, 26.89it/s]


Epoch: 81 	 Loss: 0.13842


100%|██████████| 246/246 [00:08<00:00, 27.44it/s]


Epoch: 82 	 Loss: 0.13482


100%|██████████| 246/246 [00:09<00:00, 25.09it/s]


Epoch: 83 	 Loss: 0.13850


100%|██████████| 246/246 [00:11<00:00, 21.56it/s]


Epoch: 84 	 Loss: 0.13460


100%|██████████| 246/246 [00:09<00:00, 25.82it/s]


Epoch: 85 	 Loss: 0.14006


100%|██████████| 246/246 [00:09<00:00, 24.69it/s]


Epoch: 86 	 Loss: 0.14006


100%|██████████| 246/246 [00:10<00:00, 24.55it/s]


Epoch: 87 	 Loss: 0.12808


100%|██████████| 246/246 [00:09<00:00, 25.25it/s]


Epoch: 88 	 Loss: 0.13302


100%|██████████| 246/246 [00:09<00:00, 25.12it/s]


Epoch: 89 	 Loss: 0.15278


100%|██████████| 246/246 [00:10<00:00, 22.88it/s]


Epoch: 90 	 Loss: 0.14381


100%|██████████| 246/246 [00:11<00:00, 22.36it/s]


Epoch: 91 	 Loss: 0.13020


100%|██████████| 246/246 [00:10<00:00, 23.69it/s]


Epoch: 92 	 Loss: 0.11911


100%|██████████| 246/246 [00:09<00:00, 26.03it/s]


Epoch: 93 	 Loss: 0.11887


100%|██████████| 246/246 [00:10<00:00, 24.35it/s]


Epoch: 94 	 Loss: 0.11531


100%|██████████| 246/246 [00:09<00:00, 26.04it/s]


Epoch: 95 	 Loss: 0.11379


100%|██████████| 246/246 [00:09<00:00, 25.46it/s]


Epoch: 96 	 Loss: 0.11160


100%|██████████| 246/246 [00:09<00:00, 25.71it/s]


Epoch: 97 	 Loss: 0.11654


100%|██████████| 246/246 [00:09<00:00, 24.89it/s]


Epoch: 98 	 Loss: 0.13259


100%|██████████| 246/246 [00:10<00:00, 24.24it/s]


Epoch: 99 	 Loss: 0.11272


100%|██████████| 246/246 [00:09<00:00, 26.60it/s]


Epoch: 100 	 Loss: 0.10305
Accuracies: [0.60479042 0.01075269 0.56730769 0.34210526 0.60714286 0.09821429
 0.88356164 0.26890756 0.15517241 0.46456693 0.04      ] 
 Average acccuracy: 0.367501977767658
F1 Scores: [0.65951488 0.01015223 0.58852863 0.2954545  0.27642273 0.08333328
 0.821656   0.33507849 0.10344823 0.54883716 0.03703699] 
 Average F1: 0.34176937400193425


In [118]:
print(f"{np.mean(accs4):.4f}")
print(f"{np.mean(f1s4):.4f}")

0.3777
0.3611


In [27]:
torch.save(model1, 'bert-base-bilstm.pth')
torch.save(model2, 'bert-base-cnnbilstm.pth')
torch.save(model3, 'bert-legal-bilstm.pth')
torch.save(model4, 'bert-legal-cnnbilstm.pth')

# InLEGAL BERT

In [28]:
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'..')))

from Dataset_Reader import Dataset_Reader
from utils import read_json, data_to_embeddings, save_tensor, label_encode, organize_data
from utils import document_max_length, write_dictionary_to_json

from main import TRAIN_DATA_PATH, TEST_DATA_PATH

In [29]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
model = AutoModel.from_pretrained("law-ai/InLegalBERT")

tokenizer_config.json: 100%|██████████| 516/516 [00:00<00:00, 419kB/s]
vocab.txt: 100%|██████████| 222k/222k [00:00<00:00, 1.06MB/s]
special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 115kB/s]
config.json: 100%|██████████| 671/671 [00:00<00:00, 492kB/s]
pytorch_model.bin: 100%|██████████| 534M/534M [00:51<00:00, 10.5MB/s] 
  return self.fget.__get__(instance, owner)()


In [32]:
train_data = Dataset_Reader('../data/train.json')
test_data = Dataset_Reader('../data/dev.json')

print(f"Number of sentences in training data: {len(train_data.texts)}")
print(f"Number of sentences in test data: {len(test_data.texts)}")

# Manually defining labels for convenience
list_of_targets = ['ISSUE', 'FAC', 'NONE', 'ARG_PETITIONER', 'PRE_NOT_RELIED',
                'STA', 'RPC', 'ARG_RESPONDENT', 'PREAMBLE', 'ANALYSIS', 'RLC', 'PRE_RELIED', 'RATIO']

# Numerically encode labels
label_encoder = label_encode(list_of_targets)


#Compute the maximum sentence length for each document in the training and test data (to ensure all embeddings will be the same size within a document)
max_length_dict_TRAIN = document_max_length(train_data, tokenizer=tokenizer)
max_length_dict_TEST = document_max_length(test_data, tokenizer=tokenizer)

# # To same time during training process, write these documents to json file
# write_dictionary_to_json(max_length_dict_TRAIN, 'max_length_dicts/max_length_train.json')
# write_dictionary_to_json(max_length_dict_TEST, 'max_length_dicts/max_length_test.json')


#retrieve max_length dictionaries to compute word embeddings
# max_length_dict_TRAIN = read_json('max_length_dicts/max_length_train.json', reading_max_length=True)
# max_length_dict_TEST = read_json('max_length_dicts/max_length_test.json', reading_max_length=True)


#organize and process data

train_doc_idxs, train_batched_texts, train_batched_labels = organize_data(train_data, batch_size= 1) 
test_doc_idxs, test_batched_texts, test_batched_labels = organize_data(test_data, batch_size= 1) 

for idx, train_idx in tqdm(enumerate(train_doc_idxs)):
    TRAIN_emb, TRAIN_labels = data_to_embeddings(train_idx, train_batched_texts[idx], train_batched_labels[idx],
                                                label_encoder,max_length_dict_TRAIN, tokenizer=tokenizer,
                                                emb_model=model)
    save_tensor(TRAIN_emb, 'train_document/doc_'+str(idx)+'_inlegal/',"embedding")
    save_tensor(TRAIN_labels, 'train_document/doc_'+str(idx)+'_inlegal/',"label")
    

for idx, test_idx in tqdm(enumerate(test_doc_idxs)):
    TEST_emb, TEST_labels = data_to_embeddings(test_idx, test_batched_texts[idx], test_batched_labels[idx],
                                            label_encoder,max_length_dict_TEST, tokenizer=tokenizer,
                                            emb_model=model)
    save_tensor(TEST_emb, 'test_document/doc_'+str(idx)+'_inlegal/',"embedding")
    save_tensor(TEST_labels, 'test_document/doc_'+str(idx)+'_inlegal/',"label")
    

Number of sentences in training data: 28986
Number of sentences in test data: 2890


1it [00:32, 32.08s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_0_inlegal/embedding'
Tensor saved to 'train_document/doc_0_inlegal/label'


2it [01:07, 33.86s/it]

X_train size: torch.Size([72, 1, 768])	Y_train size: torch.Size([72])
Tensor saved to 'train_document/doc_1_inlegal/embedding'
Tensor saved to 'train_document/doc_1_inlegal/label'


3it [03:33, 85.00s/it]

X_train size: torch.Size([200, 1, 768])	Y_train size: torch.Size([200])
Tensor saved to 'train_document/doc_2_inlegal/embedding'
Tensor saved to 'train_document/doc_2_inlegal/label'


4it [04:49, 81.65s/it]

X_train size: torch.Size([119, 1, 768])	Y_train size: torch.Size([119])
Tensor saved to 'train_document/doc_3_inlegal/embedding'
Tensor saved to 'train_document/doc_3_inlegal/label'


5it [06:58, 98.67s/it]

X_train size: torch.Size([184, 1, 768])	Y_train size: torch.Size([184])
Tensor saved to 'train_document/doc_4_inlegal/embedding'
Tensor saved to 'train_document/doc_4_inlegal/label'


6it [08:03, 87.17s/it]

X_train size: torch.Size([211, 1, 768])	Y_train size: torch.Size([211])
Tensor saved to 'train_document/doc_5_inlegal/embedding'
Tensor saved to 'train_document/doc_5_inlegal/label'


7it [09:26, 85.89s/it]

X_train size: torch.Size([140, 1, 768])	Y_train size: torch.Size([140])
Tensor saved to 'train_document/doc_6_inlegal/embedding'
Tensor saved to 'train_document/doc_6_inlegal/label'


8it [09:52, 66.69s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_7_inlegal/embedding'
Tensor saved to 'train_document/doc_7_inlegal/label'


9it [11:18, 72.84s/it]

X_train size: torch.Size([228, 1, 768])	Y_train size: torch.Size([228])
Tensor saved to 'train_document/doc_8_inlegal/embedding'
Tensor saved to 'train_document/doc_8_inlegal/label'


10it [11:57, 62.53s/it]

X_train size: torch.Size([99, 1, 768])	Y_train size: torch.Size([99])
Tensor saved to 'train_document/doc_9_inlegal/embedding'
Tensor saved to 'train_document/doc_9_inlegal/label'


11it [12:14, 48.52s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_10_inlegal/embedding'
Tensor saved to 'train_document/doc_10_inlegal/label'


12it [13:40, 59.98s/it]

X_train size: torch.Size([213, 1, 768])	Y_train size: torch.Size([213])
Tensor saved to 'train_document/doc_11_inlegal/embedding'
Tensor saved to 'train_document/doc_11_inlegal/label'


13it [14:25, 55.28s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_12_inlegal/embedding'
Tensor saved to 'train_document/doc_12_inlegal/label'


14it [16:39, 79.20s/it]

X_train size: torch.Size([199, 1, 768])	Y_train size: torch.Size([199])
Tensor saved to 'train_document/doc_13_inlegal/embedding'
Tensor saved to 'train_document/doc_13_inlegal/label'


15it [18:23, 86.63s/it]

X_train size: torch.Size([188, 1, 768])	Y_train size: torch.Size([188])
Tensor saved to 'train_document/doc_14_inlegal/embedding'
Tensor saved to 'train_document/doc_14_inlegal/label'


16it [23:44, 157.16s/it]

X_train size: torch.Size([271, 1, 768])	Y_train size: torch.Size([271])
Tensor saved to 'train_document/doc_15_inlegal/embedding'
Tensor saved to 'train_document/doc_15_inlegal/label'


17it [23:58, 114.12s/it]

X_train size: torch.Size([43, 1, 768])	Y_train size: torch.Size([43])
Tensor saved to 'train_document/doc_16_inlegal/embedding'
Tensor saved to 'train_document/doc_16_inlegal/label'


18it [24:25, 87.84s/it] 

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_17_inlegal/embedding'
Tensor saved to 'train_document/doc_17_inlegal/label'


19it [25:24, 79.24s/it]

X_train size: torch.Size([171, 1, 768])	Y_train size: torch.Size([171])
Tensor saved to 'train_document/doc_18_inlegal/embedding'
Tensor saved to 'train_document/doc_18_inlegal/label'


20it [28:05, 103.88s/it]

X_train size: torch.Size([149, 1, 768])	Y_train size: torch.Size([149])
Tensor saved to 'train_document/doc_19_inlegal/embedding'
Tensor saved to 'train_document/doc_19_inlegal/label'


21it [28:51, 86.31s/it] 

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_20_inlegal/embedding'
Tensor saved to 'train_document/doc_20_inlegal/label'


22it [29:15, 67.81s/it]

X_train size: torch.Size([56, 1, 768])	Y_train size: torch.Size([56])
Tensor saved to 'train_document/doc_21_inlegal/embedding'
Tensor saved to 'train_document/doc_21_inlegal/label'


23it [29:28, 51.33s/it]

X_train size: torch.Size([47, 1, 768])	Y_train size: torch.Size([47])
Tensor saved to 'train_document/doc_22_inlegal/embedding'
Tensor saved to 'train_document/doc_22_inlegal/label'


24it [30:07, 47.49s/it]

X_train size: torch.Size([116, 1, 768])	Y_train size: torch.Size([116])
Tensor saved to 'train_document/doc_23_inlegal/embedding'
Tensor saved to 'train_document/doc_23_inlegal/label'


25it [30:38, 42.56s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_24_inlegal/embedding'
Tensor saved to 'train_document/doc_24_inlegal/label'


26it [30:59, 36.20s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_25_inlegal/embedding'
Tensor saved to 'train_document/doc_25_inlegal/label'


27it [32:14, 47.70s/it]

X_train size: torch.Size([109, 1, 768])	Y_train size: torch.Size([109])
Tensor saved to 'train_document/doc_26_inlegal/embedding'
Tensor saved to 'train_document/doc_26_inlegal/label'


28it [32:58, 46.78s/it]

X_train size: torch.Size([155, 1, 768])	Y_train size: torch.Size([155])
Tensor saved to 'train_document/doc_27_inlegal/embedding'
Tensor saved to 'train_document/doc_27_inlegal/label'


29it [33:58, 50.57s/it]

X_train size: torch.Size([198, 1, 768])	Y_train size: torch.Size([198])
Tensor saved to 'train_document/doc_28_inlegal/embedding'
Tensor saved to 'train_document/doc_28_inlegal/label'


30it [35:37, 65.26s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_29_inlegal/embedding'
Tensor saved to 'train_document/doc_29_inlegal/label'


31it [35:59, 52.09s/it]

X_train size: torch.Size([50, 1, 768])	Y_train size: torch.Size([50])
Tensor saved to 'train_document/doc_30_inlegal/embedding'
Tensor saved to 'train_document/doc_30_inlegal/label'


32it [36:12, 40.58s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_31_inlegal/embedding'
Tensor saved to 'train_document/doc_31_inlegal/label'


33it [37:47, 56.95s/it]

X_train size: torch.Size([264, 1, 768])	Y_train size: torch.Size([264])
Tensor saved to 'train_document/doc_32_inlegal/embedding'
Tensor saved to 'train_document/doc_32_inlegal/label'


34it [38:31, 52.83s/it]

X_train size: torch.Size([114, 1, 768])	Y_train size: torch.Size([114])
Tensor saved to 'train_document/doc_33_inlegal/embedding'
Tensor saved to 'train_document/doc_33_inlegal/label'


35it [39:01, 46.06s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_34_inlegal/embedding'
Tensor saved to 'train_document/doc_34_inlegal/label'


36it [40:39, 61.68s/it]

X_train size: torch.Size([243, 1, 768])	Y_train size: torch.Size([243])
Tensor saved to 'train_document/doc_35_inlegal/embedding'
Tensor saved to 'train_document/doc_35_inlegal/label'


37it [41:58, 66.78s/it]

X_train size: torch.Size([107, 1, 768])	Y_train size: torch.Size([107])
Tensor saved to 'train_document/doc_36_inlegal/embedding'
Tensor saved to 'train_document/doc_36_inlegal/label'


38it [43:11, 68.86s/it]

X_train size: torch.Size([127, 1, 768])	Y_train size: torch.Size([127])
Tensor saved to 'train_document/doc_37_inlegal/embedding'
Tensor saved to 'train_document/doc_37_inlegal/label'


39it [45:00, 80.62s/it]

X_train size: torch.Size([188, 1, 768])	Y_train size: torch.Size([188])
Tensor saved to 'train_document/doc_38_inlegal/embedding'
Tensor saved to 'train_document/doc_38_inlegal/label'


40it [46:06, 76.44s/it]

X_train size: torch.Size([189, 1, 768])	Y_train size: torch.Size([189])
Tensor saved to 'train_document/doc_39_inlegal/embedding'
Tensor saved to 'train_document/doc_39_inlegal/label'


41it [47:13, 73.67s/it]

X_train size: torch.Size([157, 1, 768])	Y_train size: torch.Size([157])
Tensor saved to 'train_document/doc_40_inlegal/embedding'
Tensor saved to 'train_document/doc_40_inlegal/label'


42it [47:32, 57.04s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_41_inlegal/embedding'
Tensor saved to 'train_document/doc_41_inlegal/label'


43it [48:17, 53.41s/it]

X_train size: torch.Size([120, 1, 768])	Y_train size: torch.Size([120])
Tensor saved to 'train_document/doc_42_inlegal/embedding'
Tensor saved to 'train_document/doc_42_inlegal/label'


44it [49:32, 60.00s/it]

X_train size: torch.Size([142, 1, 768])	Y_train size: torch.Size([142])
Tensor saved to 'train_document/doc_43_inlegal/embedding'
Tensor saved to 'train_document/doc_43_inlegal/label'


45it [50:10, 53.30s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_44_inlegal/embedding'
Tensor saved to 'train_document/doc_44_inlegal/label'


46it [51:11, 55.71s/it]

X_train size: torch.Size([146, 1, 768])	Y_train size: torch.Size([146])
Tensor saved to 'train_document/doc_45_inlegal/embedding'
Tensor saved to 'train_document/doc_45_inlegal/label'


47it [52:33, 63.54s/it]

X_train size: torch.Size([185, 1, 768])	Y_train size: torch.Size([185])
Tensor saved to 'train_document/doc_46_inlegal/embedding'
Tensor saved to 'train_document/doc_46_inlegal/label'


48it [52:54, 50.83s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_47_inlegal/embedding'
Tensor saved to 'train_document/doc_47_inlegal/label'


49it [53:12, 40.93s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_48_inlegal/embedding'
Tensor saved to 'train_document/doc_48_inlegal/label'


50it [53:44, 38.24s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_49_inlegal/embedding'
Tensor saved to 'train_document/doc_49_inlegal/label'


51it [54:49, 46.42s/it]

X_train size: torch.Size([187, 1, 768])	Y_train size: torch.Size([187])
Tensor saved to 'train_document/doc_50_inlegal/embedding'
Tensor saved to 'train_document/doc_50_inlegal/label'


52it [55:22, 42.41s/it]

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_51_inlegal/embedding'
Tensor saved to 'train_document/doc_51_inlegal/label'


53it [55:56, 39.74s/it]

X_train size: torch.Size([88, 1, 768])	Y_train size: torch.Size([88])
Tensor saved to 'train_document/doc_52_inlegal/embedding'
Tensor saved to 'train_document/doc_52_inlegal/label'


54it [56:24, 36.24s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'train_document/doc_53_inlegal/embedding'
Tensor saved to 'train_document/doc_53_inlegal/label'


55it [56:46, 32.06s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'train_document/doc_54_inlegal/embedding'
Tensor saved to 'train_document/doc_54_inlegal/label'


56it [57:28, 34.84s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_55_inlegal/embedding'
Tensor saved to 'train_document/doc_55_inlegal/label'


57it [58:37, 45.22s/it]

X_train size: torch.Size([141, 1, 768])	Y_train size: torch.Size([141])
Tensor saved to 'train_document/doc_56_inlegal/embedding'
Tensor saved to 'train_document/doc_56_inlegal/label'


58it [59:00, 38.69s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_57_inlegal/embedding'
Tensor saved to 'train_document/doc_57_inlegal/label'


59it [59:18, 32.21s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_58_inlegal/embedding'
Tensor saved to 'train_document/doc_58_inlegal/label'


60it [59:41, 29.70s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_59_inlegal/embedding'
Tensor saved to 'train_document/doc_59_inlegal/label'


61it [1:00:02, 27.04s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_60_inlegal/embedding'
Tensor saved to 'train_document/doc_60_inlegal/label'


62it [1:00:15, 22.82s/it]

X_train size: torch.Size([39, 1, 768])	Y_train size: torch.Size([39])
Tensor saved to 'train_document/doc_61_inlegal/embedding'
Tensor saved to 'train_document/doc_61_inlegal/label'


63it [1:02:50, 62.49s/it]

X_train size: torch.Size([208, 1, 768])	Y_train size: torch.Size([208])
Tensor saved to 'train_document/doc_62_inlegal/embedding'
Tensor saved to 'train_document/doc_62_inlegal/label'


64it [1:03:40, 58.81s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_63_inlegal/embedding'
Tensor saved to 'train_document/doc_63_inlegal/label'


65it [1:05:33, 75.04s/it]

X_train size: torch.Size([155, 1, 768])	Y_train size: torch.Size([155])
Tensor saved to 'train_document/doc_64_inlegal/embedding'
Tensor saved to 'train_document/doc_64_inlegal/label'


66it [1:06:36, 71.40s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_65_inlegal/embedding'
Tensor saved to 'train_document/doc_65_inlegal/label'


67it [1:07:11, 60.35s/it]

X_train size: torch.Size([52, 1, 768])	Y_train size: torch.Size([52])
Tensor saved to 'train_document/doc_66_inlegal/embedding'
Tensor saved to 'train_document/doc_66_inlegal/label'


68it [1:08:15, 61.39s/it]

X_train size: torch.Size([143, 1, 768])	Y_train size: torch.Size([143])
Tensor saved to 'train_document/doc_67_inlegal/embedding'
Tensor saved to 'train_document/doc_67_inlegal/label'


69it [1:10:03, 75.56s/it]

X_train size: torch.Size([142, 1, 768])	Y_train size: torch.Size([142])
Tensor saved to 'train_document/doc_68_inlegal/embedding'
Tensor saved to 'train_document/doc_68_inlegal/label'
X_train size: torch.Size([0, 1, 768])	Y_train size: torch.Size([0])
Tensor saved to 'train_document/doc_69_inlegal/embedding'
Tensor saved to 'train_document/doc_69_inlegal/label'


71it [1:10:24, 45.49s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'train_document/doc_70_inlegal/embedding'
Tensor saved to 'train_document/doc_70_inlegal/label'


72it [1:11:01, 43.31s/it]

X_train size: torch.Size([116, 1, 768])	Y_train size: torch.Size([116])
Tensor saved to 'train_document/doc_71_inlegal/embedding'
Tensor saved to 'train_document/doc_71_inlegal/label'


73it [1:12:02, 47.87s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_72_inlegal/embedding'
Tensor saved to 'train_document/doc_72_inlegal/label'


74it [1:13:15, 54.95s/it]

X_train size: torch.Size([122, 1, 768])	Y_train size: torch.Size([122])
Tensor saved to 'train_document/doc_73_inlegal/embedding'
Tensor saved to 'train_document/doc_73_inlegal/label'


75it [1:14:01, 52.44s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_74_inlegal/embedding'
Tensor saved to 'train_document/doc_74_inlegal/label'


76it [1:14:50, 51.30s/it]

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_75_inlegal/embedding'
Tensor saved to 'train_document/doc_75_inlegal/label'


77it [1:16:27, 64.73s/it]

X_train size: torch.Size([184, 1, 768])	Y_train size: torch.Size([184])
Tensor saved to 'train_document/doc_76_inlegal/embedding'
Tensor saved to 'train_document/doc_76_inlegal/label'


78it [1:17:16, 60.04s/it]

X_train size: torch.Size([144, 1, 768])	Y_train size: torch.Size([144])
Tensor saved to 'train_document/doc_77_inlegal/embedding'
Tensor saved to 'train_document/doc_77_inlegal/label'


79it [1:17:32, 47.00s/it]

X_train size: torch.Size([35, 1, 768])	Y_train size: torch.Size([35])
Tensor saved to 'train_document/doc_78_inlegal/embedding'
Tensor saved to 'train_document/doc_78_inlegal/label'


80it [1:19:15, 63.56s/it]

X_train size: torch.Size([140, 1, 768])	Y_train size: torch.Size([140])
Tensor saved to 'train_document/doc_79_inlegal/embedding'
Tensor saved to 'train_document/doc_79_inlegal/label'


81it [1:19:36, 51.02s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_80_inlegal/embedding'
Tensor saved to 'train_document/doc_80_inlegal/label'


82it [1:22:14, 82.73s/it]

X_train size: torch.Size([137, 1, 768])	Y_train size: torch.Size([137])
Tensor saved to 'train_document/doc_81_inlegal/embedding'
Tensor saved to 'train_document/doc_81_inlegal/label'


83it [1:22:26, 61.67s/it]

X_train size: torch.Size([24, 1, 768])	Y_train size: torch.Size([24])
Tensor saved to 'train_document/doc_82_inlegal/embedding'
Tensor saved to 'train_document/doc_82_inlegal/label'


84it [1:24:00, 71.24s/it]

X_train size: torch.Size([113, 1, 768])	Y_train size: torch.Size([113])
Tensor saved to 'train_document/doc_83_inlegal/embedding'
Tensor saved to 'train_document/doc_83_inlegal/label'


85it [1:24:29, 58.69s/it]

X_train size: torch.Size([105, 1, 768])	Y_train size: torch.Size([105])
Tensor saved to 'train_document/doc_84_inlegal/embedding'
Tensor saved to 'train_document/doc_84_inlegal/label'


86it [1:24:51, 47.77s/it]

X_train size: torch.Size([80, 1, 768])	Y_train size: torch.Size([80])
Tensor saved to 'train_document/doc_85_inlegal/embedding'
Tensor saved to 'train_document/doc_85_inlegal/label'


87it [1:25:37, 47.17s/it]

X_train size: torch.Size([106, 1, 768])	Y_train size: torch.Size([106])
Tensor saved to 'train_document/doc_86_inlegal/embedding'
Tensor saved to 'train_document/doc_86_inlegal/label'


88it [1:25:55, 38.51s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_87_inlegal/embedding'
Tensor saved to 'train_document/doc_87_inlegal/label'


89it [1:26:41, 40.57s/it]

X_train size: torch.Size([105, 1, 768])	Y_train size: torch.Size([105])
Tensor saved to 'train_document/doc_88_inlegal/embedding'
Tensor saved to 'train_document/doc_88_inlegal/label'


90it [1:26:58, 33.56s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_89_inlegal/embedding'
Tensor saved to 'train_document/doc_89_inlegal/label'


91it [1:28:13, 46.15s/it]

X_train size: torch.Size([136, 1, 768])	Y_train size: torch.Size([136])
Tensor saved to 'train_document/doc_90_inlegal/embedding'
Tensor saved to 'train_document/doc_90_inlegal/label'


92it [1:28:47, 42.36s/it]

X_train size: torch.Size([83, 1, 768])	Y_train size: torch.Size([83])
Tensor saved to 'train_document/doc_91_inlegal/embedding'
Tensor saved to 'train_document/doc_91_inlegal/label'


93it [1:30:36, 62.51s/it]

X_train size: torch.Size([221, 1, 768])	Y_train size: torch.Size([221])
Tensor saved to 'train_document/doc_92_inlegal/embedding'
Tensor saved to 'train_document/doc_92_inlegal/label'


94it [1:31:29, 59.66s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'train_document/doc_93_inlegal/embedding'
Tensor saved to 'train_document/doc_93_inlegal/label'


95it [1:32:41, 63.32s/it]

X_train size: torch.Size([114, 1, 768])	Y_train size: torch.Size([114])
Tensor saved to 'train_document/doc_94_inlegal/embedding'
Tensor saved to 'train_document/doc_94_inlegal/label'


96it [1:33:11, 53.38s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'train_document/doc_95_inlegal/embedding'
Tensor saved to 'train_document/doc_95_inlegal/label'


97it [1:33:31, 43.37s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'train_document/doc_96_inlegal/embedding'
Tensor saved to 'train_document/doc_96_inlegal/label'


98it [1:34:12, 42.63s/it]

X_train size: torch.Size([70, 1, 768])	Y_train size: torch.Size([70])
Tensor saved to 'train_document/doc_97_inlegal/embedding'
Tensor saved to 'train_document/doc_97_inlegal/label'


99it [1:35:57, 61.11s/it]

X_train size: torch.Size([264, 1, 768])	Y_train size: torch.Size([264])
Tensor saved to 'train_document/doc_98_inlegal/embedding'
Tensor saved to 'train_document/doc_98_inlegal/label'


100it [1:37:59, 79.51s/it]

X_train size: torch.Size([167, 1, 768])	Y_train size: torch.Size([167])
Tensor saved to 'train_document/doc_99_inlegal/embedding'
Tensor saved to 'train_document/doc_99_inlegal/label'


101it [1:38:15, 60.59s/it]

X_train size: torch.Size([49, 1, 768])	Y_train size: torch.Size([49])
Tensor saved to 'train_document/doc_100_inlegal/embedding'
Tensor saved to 'train_document/doc_100_inlegal/label'


102it [1:38:48, 52.31s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_101_inlegal/embedding'
Tensor saved to 'train_document/doc_101_inlegal/label'


103it [1:39:06, 41.76s/it]

X_train size: torch.Size([74, 1, 768])	Y_train size: torch.Size([74])
Tensor saved to 'train_document/doc_102_inlegal/embedding'
Tensor saved to 'train_document/doc_102_inlegal/label'


104it [1:40:25, 53.03s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_103_inlegal/embedding'
Tensor saved to 'train_document/doc_103_inlegal/label'


105it [1:41:18, 53.09s/it]

X_train size: torch.Size([147, 1, 768])	Y_train size: torch.Size([147])
Tensor saved to 'train_document/doc_104_inlegal/embedding'
Tensor saved to 'train_document/doc_104_inlegal/label'


106it [1:42:58, 67.10s/it]

X_train size: torch.Size([180, 1, 768])	Y_train size: torch.Size([180])
Tensor saved to 'train_document/doc_105_inlegal/embedding'
Tensor saved to 'train_document/doc_105_inlegal/label'


107it [1:43:27, 55.64s/it]

X_train size: torch.Size([71, 1, 768])	Y_train size: torch.Size([71])
Tensor saved to 'train_document/doc_106_inlegal/embedding'
Tensor saved to 'train_document/doc_106_inlegal/label'


108it [1:44:25, 56.28s/it]

X_train size: torch.Size([174, 1, 768])	Y_train size: torch.Size([174])
Tensor saved to 'train_document/doc_107_inlegal/embedding'
Tensor saved to 'train_document/doc_107_inlegal/label'


109it [1:45:14, 54.10s/it]

X_train size: torch.Size([126, 1, 768])	Y_train size: torch.Size([126])
Tensor saved to 'train_document/doc_108_inlegal/embedding'
Tensor saved to 'train_document/doc_108_inlegal/label'


110it [1:45:53, 49.79s/it]

X_train size: torch.Size([129, 1, 768])	Y_train size: torch.Size([129])
Tensor saved to 'train_document/doc_109_inlegal/embedding'
Tensor saved to 'train_document/doc_109_inlegal/label'


111it [1:46:06, 38.81s/it]

X_train size: torch.Size([32, 1, 768])	Y_train size: torch.Size([32])
Tensor saved to 'train_document/doc_110_inlegal/embedding'
Tensor saved to 'train_document/doc_110_inlegal/label'


112it [1:46:43, 38.12s/it]

X_train size: torch.Size([121, 1, 768])	Y_train size: torch.Size([121])
Tensor saved to 'train_document/doc_111_inlegal/embedding'
Tensor saved to 'train_document/doc_111_inlegal/label'


113it [1:47:09, 34.59s/it]

X_train size: torch.Size([60, 1, 768])	Y_train size: torch.Size([60])
Tensor saved to 'train_document/doc_112_inlegal/embedding'
Tensor saved to 'train_document/doc_112_inlegal/label'


114it [1:48:19, 45.19s/it]

X_train size: torch.Size([94, 1, 768])	Y_train size: torch.Size([94])
Tensor saved to 'train_document/doc_113_inlegal/embedding'
Tensor saved to 'train_document/doc_113_inlegal/label'


115it [1:49:03, 44.67s/it]

X_train size: torch.Size([121, 1, 768])	Y_train size: torch.Size([121])
Tensor saved to 'train_document/doc_114_inlegal/embedding'
Tensor saved to 'train_document/doc_114_inlegal/label'


116it [1:49:49, 45.13s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_115_inlegal/embedding'
Tensor saved to 'train_document/doc_115_inlegal/label'


117it [1:51:40, 64.94s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_116_inlegal/embedding'
Tensor saved to 'train_document/doc_116_inlegal/label'


118it [1:53:29, 78.27s/it]

X_train size: torch.Size([308, 1, 768])	Y_train size: torch.Size([308])
Tensor saved to 'train_document/doc_117_inlegal/embedding'
Tensor saved to 'train_document/doc_117_inlegal/label'


119it [1:53:52, 61.44s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_118_inlegal/embedding'
Tensor saved to 'train_document/doc_118_inlegal/label'


120it [1:54:42, 58.19s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'train_document/doc_119_inlegal/embedding'
Tensor saved to 'train_document/doc_119_inlegal/label'


121it [1:55:10, 48.92s/it]

X_train size: torch.Size([71, 1, 768])	Y_train size: torch.Size([71])
Tensor saved to 'train_document/doc_120_inlegal/embedding'
Tensor saved to 'train_document/doc_120_inlegal/label'


122it [1:55:36, 42.05s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_121_inlegal/embedding'
Tensor saved to 'train_document/doc_121_inlegal/label'


123it [1:56:56, 53.70s/it]

X_train size: torch.Size([201, 1, 768])	Y_train size: torch.Size([201])
Tensor saved to 'train_document/doc_122_inlegal/embedding'
Tensor saved to 'train_document/doc_122_inlegal/label'


124it [1:57:11, 41.98s/it]

X_train size: torch.Size([31, 1, 768])	Y_train size: torch.Size([31])
Tensor saved to 'train_document/doc_123_inlegal/embedding'
Tensor saved to 'train_document/doc_123_inlegal/label'


125it [1:58:52, 59.69s/it]

X_train size: torch.Size([168, 1, 768])	Y_train size: torch.Size([168])
Tensor saved to 'train_document/doc_124_inlegal/embedding'
Tensor saved to 'train_document/doc_124_inlegal/label'


126it [2:00:15, 66.66s/it]

X_train size: torch.Size([213, 1, 768])	Y_train size: torch.Size([213])
Tensor saved to 'train_document/doc_125_inlegal/embedding'
Tensor saved to 'train_document/doc_125_inlegal/label'


127it [2:00:54, 58.31s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'train_document/doc_126_inlegal/embedding'
Tensor saved to 'train_document/doc_126_inlegal/label'


128it [2:01:48, 57.19s/it]

X_train size: torch.Size([85, 1, 768])	Y_train size: torch.Size([85])
Tensor saved to 'train_document/doc_127_inlegal/embedding'
Tensor saved to 'train_document/doc_127_inlegal/label'


129it [2:02:57, 60.54s/it]

X_train size: torch.Size([174, 1, 768])	Y_train size: torch.Size([174])
Tensor saved to 'train_document/doc_128_inlegal/embedding'
Tensor saved to 'train_document/doc_128_inlegal/label'


130it [2:03:18, 48.71s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_129_inlegal/embedding'
Tensor saved to 'train_document/doc_129_inlegal/label'


131it [2:04:02, 47.31s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'train_document/doc_130_inlegal/embedding'
Tensor saved to 'train_document/doc_130_inlegal/label'


132it [2:05:17, 55.69s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'train_document/doc_131_inlegal/embedding'
Tensor saved to 'train_document/doc_131_inlegal/label'


133it [2:06:19, 57.55s/it]

X_train size: torch.Size([126, 1, 768])	Y_train size: torch.Size([126])
Tensor saved to 'train_document/doc_132_inlegal/embedding'
Tensor saved to 'train_document/doc_132_inlegal/label'


134it [2:07:06, 54.36s/it]

X_train size: torch.Size([158, 1, 768])	Y_train size: torch.Size([158])
Tensor saved to 'train_document/doc_133_inlegal/embedding'
Tensor saved to 'train_document/doc_133_inlegal/label'


135it [2:07:48, 50.80s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_134_inlegal/embedding'
Tensor saved to 'train_document/doc_134_inlegal/label'


136it [2:08:47, 53.10s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'train_document/doc_135_inlegal/embedding'
Tensor saved to 'train_document/doc_135_inlegal/label'


137it [2:09:27, 49.14s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_136_inlegal/embedding'
Tensor saved to 'train_document/doc_136_inlegal/label'


138it [2:10:24, 51.47s/it]

X_train size: torch.Size([98, 1, 768])	Y_train size: torch.Size([98])
Tensor saved to 'train_document/doc_137_inlegal/embedding'
Tensor saved to 'train_document/doc_137_inlegal/label'


139it [2:11:37, 57.98s/it]

X_train size: torch.Size([161, 1, 768])	Y_train size: torch.Size([161])
Tensor saved to 'train_document/doc_138_inlegal/embedding'
Tensor saved to 'train_document/doc_138_inlegal/label'


140it [2:12:11, 50.96s/it]

X_train size: torch.Size([93, 1, 768])	Y_train size: torch.Size([93])
Tensor saved to 'train_document/doc_139_inlegal/embedding'
Tensor saved to 'train_document/doc_139_inlegal/label'


141it [2:12:37, 43.34s/it]

X_train size: torch.Size([67, 1, 768])	Y_train size: torch.Size([67])
Tensor saved to 'train_document/doc_140_inlegal/embedding'
Tensor saved to 'train_document/doc_140_inlegal/label'


142it [2:14:02, 55.89s/it]

X_train size: torch.Size([173, 1, 768])	Y_train size: torch.Size([173])
Tensor saved to 'train_document/doc_141_inlegal/embedding'
Tensor saved to 'train_document/doc_141_inlegal/label'


143it [2:17:02, 93.03s/it]

X_train size: torch.Size([234, 1, 768])	Y_train size: torch.Size([234])
Tensor saved to 'train_document/doc_142_inlegal/embedding'
Tensor saved to 'train_document/doc_142_inlegal/label'


144it [2:20:55, 135.13s/it]

X_train size: torch.Size([153, 1, 768])	Y_train size: torch.Size([153])
Tensor saved to 'train_document/doc_143_inlegal/embedding'
Tensor saved to 'train_document/doc_143_inlegal/label'


145it [2:21:56, 112.77s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_144_inlegal/embedding'
Tensor saved to 'train_document/doc_144_inlegal/label'


146it [2:22:15, 84.69s/it] 

X_train size: torch.Size([42, 1, 768])	Y_train size: torch.Size([42])
Tensor saved to 'train_document/doc_145_inlegal/embedding'
Tensor saved to 'train_document/doc_145_inlegal/label'


147it [2:23:13, 76.76s/it]

X_train size: torch.Size([164, 1, 768])	Y_train size: torch.Size([164])
Tensor saved to 'train_document/doc_146_inlegal/embedding'
Tensor saved to 'train_document/doc_146_inlegal/label'


148it [2:23:32, 59.28s/it]

X_train size: torch.Size([36, 1, 768])	Y_train size: torch.Size([36])
Tensor saved to 'train_document/doc_147_inlegal/embedding'
Tensor saved to 'train_document/doc_147_inlegal/label'


149it [2:23:54, 48.28s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'train_document/doc_148_inlegal/embedding'
Tensor saved to 'train_document/doc_148_inlegal/label'


150it [2:27:15, 93.92s/it]

X_train size: torch.Size([386, 1, 768])	Y_train size: torch.Size([386])
Tensor saved to 'train_document/doc_149_inlegal/embedding'
Tensor saved to 'train_document/doc_149_inlegal/label'


151it [2:28:24, 86.62s/it]

X_train size: torch.Size([119, 1, 768])	Y_train size: torch.Size([119])
Tensor saved to 'train_document/doc_150_inlegal/embedding'
Tensor saved to 'train_document/doc_150_inlegal/label'


152it [2:29:09, 74.00s/it]

X_train size: torch.Size([83, 1, 768])	Y_train size: torch.Size([83])
Tensor saved to 'train_document/doc_151_inlegal/embedding'
Tensor saved to 'train_document/doc_151_inlegal/label'


153it [2:30:58, 84.57s/it]

X_train size: torch.Size([200, 1, 768])	Y_train size: torch.Size([200])
Tensor saved to 'train_document/doc_152_inlegal/embedding'
Tensor saved to 'train_document/doc_152_inlegal/label'


154it [2:33:45, 109.29s/it]

X_train size: torch.Size([159, 1, 768])	Y_train size: torch.Size([159])
Tensor saved to 'train_document/doc_153_inlegal/embedding'
Tensor saved to 'train_document/doc_153_inlegal/label'


155it [2:34:12, 84.42s/it] 

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_154_inlegal/embedding'
Tensor saved to 'train_document/doc_154_inlegal/label'


156it [2:34:42, 68.16s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_155_inlegal/embedding'
Tensor saved to 'train_document/doc_155_inlegal/label'


157it [2:35:09, 55.97s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_156_inlegal/embedding'
Tensor saved to 'train_document/doc_156_inlegal/label'


158it [2:36:04, 55.58s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_157_inlegal/embedding'
Tensor saved to 'train_document/doc_157_inlegal/label'


159it [2:37:07, 57.70s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'train_document/doc_158_inlegal/embedding'
Tensor saved to 'train_document/doc_158_inlegal/label'


160it [2:37:38, 49.94s/it]

X_train size: torch.Size([75, 1, 768])	Y_train size: torch.Size([75])
Tensor saved to 'train_document/doc_159_inlegal/embedding'
Tensor saved to 'train_document/doc_159_inlegal/label'


161it [2:38:40, 53.52s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'train_document/doc_160_inlegal/embedding'
Tensor saved to 'train_document/doc_160_inlegal/label'


162it [2:38:55, 41.78s/it]

X_train size: torch.Size([44, 1, 768])	Y_train size: torch.Size([44])
Tensor saved to 'train_document/doc_161_inlegal/embedding'
Tensor saved to 'train_document/doc_161_inlegal/label'


163it [2:39:35, 41.28s/it]

X_train size: torch.Size([63, 1, 768])	Y_train size: torch.Size([63])
Tensor saved to 'train_document/doc_162_inlegal/embedding'
Tensor saved to 'train_document/doc_162_inlegal/label'


164it [2:40:16, 41.33s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_163_inlegal/embedding'
Tensor saved to 'train_document/doc_163_inlegal/label'


165it [2:41:29, 50.60s/it]

X_train size: torch.Size([190, 1, 768])	Y_train size: torch.Size([190])
Tensor saved to 'train_document/doc_164_inlegal/embedding'
Tensor saved to 'train_document/doc_164_inlegal/label'


166it [2:41:45, 40.36s/it]

X_train size: torch.Size([33, 1, 768])	Y_train size: torch.Size([33])
Tensor saved to 'train_document/doc_165_inlegal/embedding'
Tensor saved to 'train_document/doc_165_inlegal/label'


167it [2:44:03, 69.81s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_166_inlegal/embedding'
Tensor saved to 'train_document/doc_166_inlegal/label'


168it [2:44:51, 63.18s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'train_document/doc_167_inlegal/embedding'
Tensor saved to 'train_document/doc_167_inlegal/label'


169it [2:45:30, 55.74s/it]

X_train size: torch.Size([52, 1, 768])	Y_train size: torch.Size([52])
Tensor saved to 'train_document/doc_168_inlegal/embedding'
Tensor saved to 'train_document/doc_168_inlegal/label'


170it [2:46:16, 52.88s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_169_inlegal/embedding'
Tensor saved to 'train_document/doc_169_inlegal/label'


171it [2:47:48, 64.53s/it]

X_train size: torch.Size([251, 1, 768])	Y_train size: torch.Size([251])
Tensor saved to 'train_document/doc_170_inlegal/embedding'
Tensor saved to 'train_document/doc_170_inlegal/label'


172it [2:48:41, 61.15s/it]

X_train size: torch.Size([117, 1, 768])	Y_train size: torch.Size([117])
Tensor saved to 'train_document/doc_171_inlegal/embedding'
Tensor saved to 'train_document/doc_171_inlegal/label'


173it [2:49:32, 58.19s/it]

X_train size: torch.Size([146, 1, 768])	Y_train size: torch.Size([146])
Tensor saved to 'train_document/doc_172_inlegal/embedding'
Tensor saved to 'train_document/doc_172_inlegal/label'


174it [2:50:03, 49.94s/it]

X_train size: torch.Size([84, 1, 768])	Y_train size: torch.Size([84])
Tensor saved to 'train_document/doc_173_inlegal/embedding'
Tensor saved to 'train_document/doc_173_inlegal/label'


175it [2:52:23, 77.15s/it]

X_train size: torch.Size([223, 1, 768])	Y_train size: torch.Size([223])
Tensor saved to 'train_document/doc_174_inlegal/embedding'
Tensor saved to 'train_document/doc_174_inlegal/label'


176it [2:52:50, 62.11s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'train_document/doc_175_inlegal/embedding'
Tensor saved to 'train_document/doc_175_inlegal/label'


177it [2:55:55, 98.97s/it]

X_train size: torch.Size([248, 1, 768])	Y_train size: torch.Size([248])
Tensor saved to 'train_document/doc_176_inlegal/embedding'
Tensor saved to 'train_document/doc_176_inlegal/label'


178it [2:56:31, 80.09s/it]

X_train size: torch.Size([69, 1, 768])	Y_train size: torch.Size([69])
Tensor saved to 'train_document/doc_177_inlegal/embedding'
Tensor saved to 'train_document/doc_177_inlegal/label'


179it [2:57:31, 73.86s/it]

X_train size: torch.Size([89, 1, 768])	Y_train size: torch.Size([89])
Tensor saved to 'train_document/doc_178_inlegal/embedding'
Tensor saved to 'train_document/doc_178_inlegal/label'


180it [2:58:13, 64.29s/it]

X_train size: torch.Size([124, 1, 768])	Y_train size: torch.Size([124])
Tensor saved to 'train_document/doc_179_inlegal/embedding'
Tensor saved to 'train_document/doc_179_inlegal/label'


181it [2:58:29, 49.89s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_180_inlegal/embedding'
Tensor saved to 'train_document/doc_180_inlegal/label'


182it [2:59:06, 46.01s/it]

X_train size: torch.Size([95, 1, 768])	Y_train size: torch.Size([95])
Tensor saved to 'train_document/doc_181_inlegal/embedding'
Tensor saved to 'train_document/doc_181_inlegal/label'


183it [2:59:48, 44.88s/it]

X_train size: torch.Size([131, 1, 768])	Y_train size: torch.Size([131])
Tensor saved to 'train_document/doc_182_inlegal/embedding'
Tensor saved to 'train_document/doc_182_inlegal/label'


184it [3:00:43, 47.89s/it]

X_train size: torch.Size([117, 1, 768])	Y_train size: torch.Size([117])
Tensor saved to 'train_document/doc_183_inlegal/embedding'
Tensor saved to 'train_document/doc_183_inlegal/label'


185it [3:00:55, 37.04s/it]

X_train size: torch.Size([27, 1, 768])	Y_train size: torch.Size([27])
Tensor saved to 'train_document/doc_184_inlegal/embedding'
Tensor saved to 'train_document/doc_184_inlegal/label'


186it [3:01:41, 39.68s/it]

X_train size: torch.Size([106, 1, 768])	Y_train size: torch.Size([106])
Tensor saved to 'train_document/doc_185_inlegal/embedding'
Tensor saved to 'train_document/doc_185_inlegal/label'


187it [3:03:13, 55.48s/it]

X_train size: torch.Size([244, 1, 768])	Y_train size: torch.Size([244])
Tensor saved to 'train_document/doc_186_inlegal/embedding'
Tensor saved to 'train_document/doc_186_inlegal/label'


188it [3:03:38, 46.33s/it]

X_train size: torch.Size([87, 1, 768])	Y_train size: torch.Size([87])
Tensor saved to 'train_document/doc_187_inlegal/embedding'
Tensor saved to 'train_document/doc_187_inlegal/label'


189it [3:04:11, 42.33s/it]

X_train size: torch.Size([78, 1, 768])	Y_train size: torch.Size([78])
Tensor saved to 'train_document/doc_188_inlegal/embedding'
Tensor saved to 'train_document/doc_188_inlegal/label'


190it [3:05:24, 51.43s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'train_document/doc_189_inlegal/embedding'
Tensor saved to 'train_document/doc_189_inlegal/label'


191it [3:05:39, 40.72s/it]

X_train size: torch.Size([36, 1, 768])	Y_train size: torch.Size([36])
Tensor saved to 'train_document/doc_190_inlegal/embedding'
Tensor saved to 'train_document/doc_190_inlegal/label'


192it [3:06:27, 42.86s/it]

X_train size: torch.Size([128, 1, 768])	Y_train size: torch.Size([128])
Tensor saved to 'train_document/doc_191_inlegal/embedding'
Tensor saved to 'train_document/doc_191_inlegal/label'


193it [3:07:56, 56.48s/it]

X_train size: torch.Size([154, 1, 768])	Y_train size: torch.Size([154])
Tensor saved to 'train_document/doc_192_inlegal/embedding'
Tensor saved to 'train_document/doc_192_inlegal/label'


194it [3:08:30, 49.78s/it]

X_train size: torch.Size([115, 1, 768])	Y_train size: torch.Size([115])
Tensor saved to 'train_document/doc_193_inlegal/embedding'
Tensor saved to 'train_document/doc_193_inlegal/label'


195it [3:08:52, 41.61s/it]

X_train size: torch.Size([61, 1, 768])	Y_train size: torch.Size([61])
Tensor saved to 'train_document/doc_194_inlegal/embedding'
Tensor saved to 'train_document/doc_194_inlegal/label'


196it [3:11:34, 77.64s/it]

X_train size: torch.Size([245, 1, 768])	Y_train size: torch.Size([245])
Tensor saved to 'train_document/doc_195_inlegal/embedding'
Tensor saved to 'train_document/doc_195_inlegal/label'


197it [3:12:27, 70.23s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'train_document/doc_196_inlegal/embedding'
Tensor saved to 'train_document/doc_196_inlegal/label'


198it [3:13:03, 60.04s/it]

X_train size: torch.Size([91, 1, 768])	Y_train size: torch.Size([91])
Tensor saved to 'train_document/doc_197_inlegal/embedding'
Tensor saved to 'train_document/doc_197_inlegal/label'


199it [3:13:46, 54.93s/it]

X_train size: torch.Size([120, 1, 768])	Y_train size: torch.Size([120])
Tensor saved to 'train_document/doc_198_inlegal/embedding'
Tensor saved to 'train_document/doc_198_inlegal/label'


200it [3:14:08, 44.90s/it]

X_train size: torch.Size([89, 1, 768])	Y_train size: torch.Size([89])
Tensor saved to 'train_document/doc_199_inlegal/embedding'
Tensor saved to 'train_document/doc_199_inlegal/label'


201it [3:14:39, 40.76s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_200_inlegal/embedding'
Tensor saved to 'train_document/doc_200_inlegal/label'


202it [3:15:09, 37.64s/it]

X_train size: torch.Size([74, 1, 768])	Y_train size: torch.Size([74])
Tensor saved to 'train_document/doc_201_inlegal/embedding'
Tensor saved to 'train_document/doc_201_inlegal/label'


203it [3:15:55, 40.20s/it]

X_train size: torch.Size([92, 1, 768])	Y_train size: torch.Size([92])
Tensor saved to 'train_document/doc_202_inlegal/embedding'
Tensor saved to 'train_document/doc_202_inlegal/label'


204it [3:17:33, 57.44s/it]

X_train size: torch.Size([240, 1, 768])	Y_train size: torch.Size([240])
Tensor saved to 'train_document/doc_203_inlegal/embedding'
Tensor saved to 'train_document/doc_203_inlegal/label'


205it [3:18:57, 65.40s/it]

X_train size: torch.Size([102, 1, 768])	Y_train size: torch.Size([102])
Tensor saved to 'train_document/doc_204_inlegal/embedding'
Tensor saved to 'train_document/doc_204_inlegal/label'


206it [3:19:48, 61.19s/it]

X_train size: torch.Size([107, 1, 768])	Y_train size: torch.Size([107])
Tensor saved to 'train_document/doc_205_inlegal/embedding'
Tensor saved to 'train_document/doc_205_inlegal/label'


207it [3:20:10, 49.33s/it]

X_train size: torch.Size([79, 1, 768])	Y_train size: torch.Size([79])
Tensor saved to 'train_document/doc_206_inlegal/embedding'
Tensor saved to 'train_document/doc_206_inlegal/label'


208it [3:21:48, 63.97s/it]

X_train size: torch.Size([168, 1, 768])	Y_train size: torch.Size([168])
Tensor saved to 'train_document/doc_207_inlegal/embedding'
Tensor saved to 'train_document/doc_207_inlegal/label'


209it [3:23:58, 83.68s/it]

X_train size: torch.Size([181, 1, 768])	Y_train size: torch.Size([181])
Tensor saved to 'train_document/doc_208_inlegal/embedding'
Tensor saved to 'train_document/doc_208_inlegal/label'


210it [3:24:12, 62.78s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'train_document/doc_209_inlegal/embedding'
Tensor saved to 'train_document/doc_209_inlegal/label'


211it [3:26:11, 79.64s/it]

X_train size: torch.Size([248, 1, 768])	Y_train size: torch.Size([248])
Tensor saved to 'train_document/doc_210_inlegal/embedding'
Tensor saved to 'train_document/doc_210_inlegal/label'


212it [3:27:40, 82.60s/it]

X_train size: torch.Size([141, 1, 768])	Y_train size: torch.Size([141])
Tensor saved to 'train_document/doc_211_inlegal/embedding'
Tensor saved to 'train_document/doc_211_inlegal/label'


213it [3:28:08, 66.05s/it]

X_train size: torch.Size([47, 1, 768])	Y_train size: torch.Size([47])
Tensor saved to 'train_document/doc_212_inlegal/embedding'
Tensor saved to 'train_document/doc_212_inlegal/label'


214it [3:28:43, 56.99s/it]

X_train size: torch.Size([81, 1, 768])	Y_train size: torch.Size([81])
Tensor saved to 'train_document/doc_213_inlegal/embedding'
Tensor saved to 'train_document/doc_213_inlegal/label'


215it [3:29:34, 55.11s/it]

X_train size: torch.Size([82, 1, 768])	Y_train size: torch.Size([82])
Tensor saved to 'train_document/doc_214_inlegal/embedding'
Tensor saved to 'train_document/doc_214_inlegal/label'


216it [3:29:50, 43.46s/it]

X_train size: torch.Size([60, 1, 768])	Y_train size: torch.Size([60])
Tensor saved to 'train_document/doc_215_inlegal/embedding'
Tensor saved to 'train_document/doc_215_inlegal/label'


217it [3:31:11, 54.47s/it]

X_train size: torch.Size([179, 1, 768])	Y_train size: torch.Size([179])
Tensor saved to 'train_document/doc_216_inlegal/embedding'
Tensor saved to 'train_document/doc_216_inlegal/label'


218it [3:31:49, 49.58s/it]

X_train size: torch.Size([65, 1, 768])	Y_train size: torch.Size([65])
Tensor saved to 'train_document/doc_217_inlegal/embedding'
Tensor saved to 'train_document/doc_217_inlegal/label'


219it [3:32:58, 55.43s/it]

X_train size: torch.Size([176, 1, 768])	Y_train size: torch.Size([176])
Tensor saved to 'train_document/doc_218_inlegal/embedding'
Tensor saved to 'train_document/doc_218_inlegal/label'


220it [3:35:29, 84.02s/it]

X_train size: torch.Size([283, 1, 768])	Y_train size: torch.Size([283])
Tensor saved to 'train_document/doc_219_inlegal/embedding'
Tensor saved to 'train_document/doc_219_inlegal/label'


221it [3:37:10, 89.38s/it]

X_train size: torch.Size([123, 1, 768])	Y_train size: torch.Size([123])
Tensor saved to 'train_document/doc_220_inlegal/embedding'
Tensor saved to 'train_document/doc_220_inlegal/label'


222it [3:38:49, 92.11s/it]

X_train size: torch.Size([263, 1, 768])	Y_train size: torch.Size([263])
Tensor saved to 'train_document/doc_221_inlegal/embedding'
Tensor saved to 'train_document/doc_221_inlegal/label'


223it [3:39:10, 70.71s/it]

X_train size: torch.Size([58, 1, 768])	Y_train size: torch.Size([58])
Tensor saved to 'train_document/doc_222_inlegal/embedding'
Tensor saved to 'train_document/doc_222_inlegal/label'


224it [3:41:33, 92.59s/it]

X_train size: torch.Size([193, 1, 768])	Y_train size: torch.Size([193])
Tensor saved to 'train_document/doc_223_inlegal/embedding'
Tensor saved to 'train_document/doc_223_inlegal/label'


225it [3:42:25, 80.36s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'train_document/doc_224_inlegal/embedding'
Tensor saved to 'train_document/doc_224_inlegal/label'
X_train size: torch.Size([0, 1, 768])	Y_train size: torch.Size([0])
Tensor saved to 'train_document/doc_225_inlegal/embedding'
Tensor saved to 'train_document/doc_225_inlegal/label'


227it [3:43:08, 53.04s/it]

X_train size: torch.Size([148, 1, 768])	Y_train size: torch.Size([148])
Tensor saved to 'train_document/doc_226_inlegal/embedding'
Tensor saved to 'train_document/doc_226_inlegal/label'


228it [3:43:26, 44.39s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_227_inlegal/embedding'
Tensor saved to 'train_document/doc_227_inlegal/label'


229it [3:44:38, 51.59s/it]

X_train size: torch.Size([255, 1, 768])	Y_train size: torch.Size([255])
Tensor saved to 'train_document/doc_228_inlegal/embedding'
Tensor saved to 'train_document/doc_228_inlegal/label'


230it [3:45:03, 44.33s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'train_document/doc_229_inlegal/embedding'
Tensor saved to 'train_document/doc_229_inlegal/label'


231it [3:45:24, 37.92s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'train_document/doc_230_inlegal/embedding'
Tensor saved to 'train_document/doc_230_inlegal/label'


232it [3:47:02, 55.04s/it]

X_train size: torch.Size([214, 1, 768])	Y_train size: torch.Size([214])
Tensor saved to 'train_document/doc_231_inlegal/embedding'
Tensor saved to 'train_document/doc_231_inlegal/label'


233it [3:48:35, 66.05s/it]

X_train size: torch.Size([246, 1, 768])	Y_train size: torch.Size([246])
Tensor saved to 'train_document/doc_232_inlegal/embedding'
Tensor saved to 'train_document/doc_232_inlegal/label'


234it [3:49:40, 65.73s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_233_inlegal/embedding'
Tensor saved to 'train_document/doc_233_inlegal/label'


235it [3:50:36, 62.78s/it]

X_train size: torch.Size([110, 1, 768])	Y_train size: torch.Size([110])
Tensor saved to 'train_document/doc_234_inlegal/embedding'
Tensor saved to 'train_document/doc_234_inlegal/label'


236it [3:50:58, 50.91s/it]

X_train size: torch.Size([67, 1, 768])	Y_train size: torch.Size([67])
Tensor saved to 'train_document/doc_235_inlegal/embedding'
Tensor saved to 'train_document/doc_235_inlegal/label'


237it [3:51:15, 40.79s/it]

X_train size: torch.Size([73, 1, 768])	Y_train size: torch.Size([73])
Tensor saved to 'train_document/doc_236_inlegal/embedding'
Tensor saved to 'train_document/doc_236_inlegal/label'


238it [3:51:40, 36.11s/it]

X_train size: torch.Size([56, 1, 768])	Y_train size: torch.Size([56])
Tensor saved to 'train_document/doc_237_inlegal/embedding'
Tensor saved to 'train_document/doc_237_inlegal/label'


239it [3:53:22, 55.58s/it]

X_train size: torch.Size([212, 1, 768])	Y_train size: torch.Size([212])
Tensor saved to 'train_document/doc_238_inlegal/embedding'
Tensor saved to 'train_document/doc_238_inlegal/label'


240it [3:55:17, 73.44s/it]

X_train size: torch.Size([215, 1, 768])	Y_train size: torch.Size([215])
Tensor saved to 'train_document/doc_239_inlegal/embedding'
Tensor saved to 'train_document/doc_239_inlegal/label'


241it [3:55:53, 62.15s/it]

X_train size: torch.Size([103, 1, 768])	Y_train size: torch.Size([103])
Tensor saved to 'train_document/doc_240_inlegal/embedding'
Tensor saved to 'train_document/doc_240_inlegal/label'


242it [3:56:06, 47.53s/it]

X_train size: torch.Size([45, 1, 768])	Y_train size: torch.Size([45])
Tensor saved to 'train_document/doc_241_inlegal/embedding'
Tensor saved to 'train_document/doc_241_inlegal/label'


243it [3:56:40, 43.47s/it]

X_train size: torch.Size([112, 1, 768])	Y_train size: torch.Size([112])
Tensor saved to 'train_document/doc_242_inlegal/embedding'
Tensor saved to 'train_document/doc_242_inlegal/label'


244it [3:57:03, 37.45s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'train_document/doc_243_inlegal/embedding'
Tensor saved to 'train_document/doc_243_inlegal/label'


245it [3:58:13, 47.13s/it]

X_train size: torch.Size([90, 1, 768])	Y_train size: torch.Size([90])
Tensor saved to 'train_document/doc_244_inlegal/embedding'
Tensor saved to 'train_document/doc_244_inlegal/label'


246it [3:59:33, 58.43s/it]


X_train size: torch.Size([122, 1, 768])	Y_train size: torch.Size([122])
Tensor saved to 'train_document/doc_245_inlegal/embedding'
Tensor saved to 'train_document/doc_245_inlegal/label'


1it [01:09, 69.86s/it]

X_train size: torch.Size([96, 1, 768])	Y_train size: torch.Size([96])
Tensor saved to 'test_document/doc_0_inlegal/embedding'
Tensor saved to 'test_document/doc_0_inlegal/label'


2it [02:05, 61.62s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'test_document/doc_1_inlegal/embedding'
Tensor saved to 'test_document/doc_1_inlegal/label'


3it [03:00, 58.62s/it]

X_train size: torch.Size([150, 1, 768])	Y_train size: torch.Size([150])
Tensor saved to 'test_document/doc_2_inlegal/embedding'
Tensor saved to 'test_document/doc_2_inlegal/label'


4it [03:24, 44.86s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'test_document/doc_3_inlegal/embedding'
Tensor saved to 'test_document/doc_3_inlegal/label'


5it [05:14, 68.28s/it]

X_train size: torch.Size([97, 1, 768])	Y_train size: torch.Size([97])
Tensor saved to 'test_document/doc_4_inlegal/embedding'
Tensor saved to 'test_document/doc_4_inlegal/label'


6it [05:33, 51.56s/it]

X_train size: torch.Size([57, 1, 768])	Y_train size: torch.Size([57])
Tensor saved to 'test_document/doc_5_inlegal/embedding'
Tensor saved to 'test_document/doc_5_inlegal/label'


7it [06:00, 43.40s/it]

X_train size: torch.Size([68, 1, 768])	Y_train size: torch.Size([68])
Tensor saved to 'test_document/doc_6_inlegal/embedding'
Tensor saved to 'test_document/doc_6_inlegal/label'


8it [06:31, 39.58s/it]

X_train size: torch.Size([113, 1, 768])	Y_train size: torch.Size([113])
Tensor saved to 'test_document/doc_7_inlegal/embedding'
Tensor saved to 'test_document/doc_7_inlegal/label'


9it [08:17, 60.29s/it]

X_train size: torch.Size([199, 1, 768])	Y_train size: torch.Size([199])
Tensor saved to 'test_document/doc_8_inlegal/embedding'
Tensor saved to 'test_document/doc_8_inlegal/label'


10it [09:02, 55.52s/it]

X_train size: torch.Size([139, 1, 768])	Y_train size: torch.Size([139])
Tensor saved to 'test_document/doc_9_inlegal/embedding'
Tensor saved to 'test_document/doc_9_inlegal/label'


11it [09:25, 45.63s/it]

X_train size: torch.Size([76, 1, 768])	Y_train size: torch.Size([76])
Tensor saved to 'test_document/doc_10_inlegal/embedding'
Tensor saved to 'test_document/doc_10_inlegal/label'


12it [09:53, 40.29s/it]

X_train size: torch.Size([104, 1, 768])	Y_train size: torch.Size([104])
Tensor saved to 'test_document/doc_11_inlegal/embedding'
Tensor saved to 'test_document/doc_11_inlegal/label'


13it [11:08, 50.71s/it]

X_train size: torch.Size([209, 1, 768])	Y_train size: torch.Size([209])
Tensor saved to 'test_document/doc_12_inlegal/embedding'
Tensor saved to 'test_document/doc_12_inlegal/label'


14it [11:45, 46.67s/it]

X_train size: torch.Size([135, 1, 768])	Y_train size: torch.Size([135])
Tensor saved to 'test_document/doc_13_inlegal/embedding'
Tensor saved to 'test_document/doc_13_inlegal/label'


15it [12:13, 40.95s/it]

X_train size: torch.Size([64, 1, 768])	Y_train size: torch.Size([64])
Tensor saved to 'test_document/doc_14_inlegal/embedding'
Tensor saved to 'test_document/doc_14_inlegal/label'


16it [12:57, 42.02s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'test_document/doc_15_inlegal/embedding'
Tensor saved to 'test_document/doc_15_inlegal/label'


17it [13:36, 41.11s/it]

X_train size: torch.Size([98, 1, 768])	Y_train size: torch.Size([98])
Tensor saved to 'test_document/doc_16_inlegal/embedding'
Tensor saved to 'test_document/doc_16_inlegal/label'


18it [14:28, 44.34s/it]

X_train size: torch.Size([111, 1, 768])	Y_train size: torch.Size([111])
Tensor saved to 'test_document/doc_17_inlegal/embedding'
Tensor saved to 'test_document/doc_17_inlegal/label'


19it [14:58, 40.00s/it]

X_train size: torch.Size([62, 1, 768])	Y_train size: torch.Size([62])
Tensor saved to 'test_document/doc_18_inlegal/embedding'
Tensor saved to 'test_document/doc_18_inlegal/label'


20it [16:14, 50.95s/it]

X_train size: torch.Size([130, 1, 768])	Y_train size: torch.Size([130])
Tensor saved to 'test_document/doc_19_inlegal/embedding'
Tensor saved to 'test_document/doc_19_inlegal/label'


21it [16:34, 41.52s/it]

X_train size: torch.Size([46, 1, 768])	Y_train size: torch.Size([46])
Tensor saved to 'test_document/doc_20_inlegal/embedding'
Tensor saved to 'test_document/doc_20_inlegal/label'


22it [17:04, 38.03s/it]

X_train size: torch.Size([66, 1, 768])	Y_train size: torch.Size([66])
Tensor saved to 'test_document/doc_21_inlegal/embedding'
Tensor saved to 'test_document/doc_21_inlegal/label'


23it [17:23, 32.49s/it]

X_train size: torch.Size([77, 1, 768])	Y_train size: torch.Size([77])
Tensor saved to 'test_document/doc_22_inlegal/embedding'
Tensor saved to 'test_document/doc_22_inlegal/label'


24it [17:34, 25.96s/it]

X_train size: torch.Size([53, 1, 768])	Y_train size: torch.Size([53])
Tensor saved to 'test_document/doc_23_inlegal/embedding'
Tensor saved to 'test_document/doc_23_inlegal/label'


25it [18:18, 31.28s/it]

X_train size: torch.Size([112, 1, 768])	Y_train size: torch.Size([112])
Tensor saved to 'test_document/doc_24_inlegal/embedding'
Tensor saved to 'test_document/doc_24_inlegal/label'


26it [19:37, 45.82s/it]

X_train size: torch.Size([186, 1, 768])	Y_train size: torch.Size([186])
Tensor saved to 'test_document/doc_25_inlegal/embedding'
Tensor saved to 'test_document/doc_25_inlegal/label'


27it [19:55, 37.24s/it]

X_train size: torch.Size([54, 1, 768])	Y_train size: torch.Size([54])
Tensor saved to 'test_document/doc_26_inlegal/embedding'
Tensor saved to 'test_document/doc_26_inlegal/label'


28it [20:09, 30.42s/it]

X_train size: torch.Size([43, 1, 768])	Y_train size: torch.Size([43])
Tensor saved to 'test_document/doc_27_inlegal/embedding'
Tensor saved to 'test_document/doc_27_inlegal/label'


29it [20:26, 42.30s/it]

X_train size: torch.Size([59, 1, 768])	Y_train size: torch.Size([59])
Tensor saved to 'test_document/doc_28_inlegal/embedding'
Tensor saved to 'test_document/doc_28_inlegal/label'





## BiLSTM

In [126]:
while_i = 0
accs5 = []
f1s5 = []
micro_accs2 = []
micro_f1s2 = []
while while_i < 3:
    while_i += 1
    model5 = BiLSTM(hidden_size=128, dropout= 0.3, output_size= 11)
    optimizer = torch.optim.Adam(model5.parameters(), lr= 5e-4)
    loss_function = nn.CrossEntropyLoss(weight= class_weights)

    print(f'{"Starting Training":-^100}')
    model5.train()
    loss_list = []
    for epoch in range(100):
        running_loss = 0
        for idx in tqdm(range(246)):
            TRAIN_emb = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/embedding")
            TRAIN_labels = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/label")
            TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
            if TRAIN_emb.size(0) == 0:
                continue
            output = model5(TRAIN_emb)
            loss = loss_function(output,TRAIN_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # scheduler.step()
        # scheduler1.step()
        # scheduler2.step()
        # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
        loss_list.append(running_loss/246)
        print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
        if running_loss/246 < 0.025:
            break
    cm = None
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"test_document/doc_{i}_inlegal/embedding")
        TEST_labels = load_tensor(filepath=f"test_document/doc_{i}_inlegal/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model5, num_labels= 11)
        if cm is None:
            cm = conf_matrix_helper
        else:
            cm = np.add(cm, conf_matrix_helper)
            
    accuracies = class_accuracy(cm)
    f1_scores = class_f1_score(cm)
    average_accuracy = np.mean(accuracies)
    average_f1 = np.mean(f1_scores)
    
    accs5.append(average_accuracy)
    f1s5.append(average_f1)
    micro_accs2.append(accuracies)
    micro_f1s2.append(f1_scores)

    print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
    print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

-----------------------------------------Starting Training------------------------------------------


  1%|          | 3/246 [00:00<00:11, 21.93it/s]

100%|██████████| 246/246 [00:07<00:00, 31.68it/s]


Epoch: 1 	 Loss: 1.43463


100%|██████████| 246/246 [00:08<00:00, 30.32it/s]


Epoch: 2 	 Loss: 0.83848


100%|██████████| 246/246 [00:12<00:00, 20.43it/s]


Epoch: 3 	 Loss: 0.69454


100%|██████████| 246/246 [00:10<00:00, 22.79it/s]


Epoch: 4 	 Loss: 0.57944


100%|██████████| 246/246 [00:08<00:00, 30.04it/s]


Epoch: 5 	 Loss: 0.48422


100%|██████████| 246/246 [00:09<00:00, 25.74it/s]


Epoch: 6 	 Loss: 0.41423


100%|██████████| 246/246 [00:08<00:00, 27.46it/s]


Epoch: 7 	 Loss: 0.36253


100%|██████████| 246/246 [00:08<00:00, 29.80it/s]


Epoch: 8 	 Loss: 0.30634


100%|██████████| 246/246 [00:10<00:00, 24.56it/s]


Epoch: 9 	 Loss: 0.26962


100%|██████████| 246/246 [00:09<00:00, 26.74it/s]


Epoch: 10 	 Loss: 0.24427


100%|██████████| 246/246 [00:09<00:00, 26.14it/s]


Epoch: 11 	 Loss: 0.21338


100%|██████████| 246/246 [00:09<00:00, 26.50it/s]


Epoch: 12 	 Loss: 0.19964


100%|██████████| 246/246 [00:11<00:00, 22.21it/s]


Epoch: 13 	 Loss: 0.17547


100%|██████████| 246/246 [00:10<00:00, 22.52it/s]


Epoch: 14 	 Loss: 0.15063


100%|██████████| 246/246 [00:10<00:00, 24.45it/s]


Epoch: 15 	 Loss: 0.13587


100%|██████████| 246/246 [00:08<00:00, 27.64it/s]


Epoch: 16 	 Loss: 0.13062


100%|██████████| 246/246 [00:09<00:00, 27.02it/s]


Epoch: 17 	 Loss: 0.12694


100%|██████████| 246/246 [00:07<00:00, 31.02it/s]


Epoch: 18 	 Loss: 0.11018


100%|██████████| 246/246 [00:07<00:00, 32.83it/s]


Epoch: 19 	 Loss: 0.10678


100%|██████████| 246/246 [00:10<00:00, 24.54it/s]


Epoch: 20 	 Loss: 0.08950


100%|██████████| 246/246 [00:08<00:00, 27.56it/s]


Epoch: 21 	 Loss: 0.07653


100%|██████████| 246/246 [00:09<00:00, 27.23it/s]


Epoch: 22 	 Loss: 0.07522


100%|██████████| 246/246 [00:09<00:00, 26.47it/s]


Epoch: 23 	 Loss: 0.08844


100%|██████████| 246/246 [00:10<00:00, 22.49it/s]


Epoch: 24 	 Loss: 0.07098


100%|██████████| 246/246 [00:08<00:00, 27.52it/s]


Epoch: 25 	 Loss: 0.06252


100%|██████████| 246/246 [00:09<00:00, 26.13it/s]


Epoch: 26 	 Loss: 0.07693


100%|██████████| 246/246 [00:09<00:00, 24.83it/s]


Epoch: 27 	 Loss: 0.07376


100%|██████████| 246/246 [00:08<00:00, 29.82it/s]


Epoch: 28 	 Loss: 0.06012


100%|██████████| 246/246 [00:08<00:00, 29.44it/s]


Epoch: 29 	 Loss: 0.05444


100%|██████████| 246/246 [00:09<00:00, 27.05it/s]


Epoch: 30 	 Loss: 0.04575


100%|██████████| 246/246 [00:07<00:00, 30.92it/s]


Epoch: 31 	 Loss: 0.05145


100%|██████████| 246/246 [00:08<00:00, 28.24it/s]


Epoch: 32 	 Loss: 0.04648


100%|██████████| 246/246 [00:08<00:00, 27.36it/s]


Epoch: 33 	 Loss: 0.03861


100%|██████████| 246/246 [00:09<00:00, 26.50it/s]


Epoch: 34 	 Loss: 0.03314


100%|██████████| 246/246 [00:07<00:00, 31.01it/s]


Epoch: 35 	 Loss: 0.03307


100%|██████████| 246/246 [00:10<00:00, 24.52it/s]


Epoch: 36 	 Loss: 0.03962


100%|██████████| 246/246 [00:09<00:00, 26.19it/s]


Epoch: 37 	 Loss: 0.05198


100%|██████████| 246/246 [00:08<00:00, 30.04it/s]


Epoch: 38 	 Loss: 0.04068


100%|██████████| 246/246 [00:08<00:00, 30.54it/s]


Epoch: 39 	 Loss: 0.03742


100%|██████████| 246/246 [00:09<00:00, 25.50it/s]


Epoch: 40 	 Loss: 0.09739


100%|██████████| 246/246 [00:08<00:00, 28.89it/s]


Epoch: 41 	 Loss: 0.05223


100%|██████████| 246/246 [00:08<00:00, 29.70it/s]


Epoch: 42 	 Loss: 0.03150


100%|██████████| 246/246 [00:09<00:00, 27.07it/s]


Epoch: 43 	 Loss: 0.02177
Accuracies: [0.83435583 0.5106383  0.77325581 0.7037037  0.91758242 0.42955326
 0.98821218 0.5        0.62962963 0.81632653 0.51428571] 
 Average acccuracy: 0.6925039437465247
F1 Scores: [0.75977649 0.48484843 0.83977896 0.73076918 0.89784941 0.56433404
 0.99308978 0.47826082 0.23776221 0.860215   0.56249995] 
 Average F1: 0.6735622061892729
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:08<00:00, 27.35it/s]


Epoch: 1 	 Loss: 1.39625


100%|██████████| 246/246 [00:08<00:00, 29.57it/s]


Epoch: 2 	 Loss: 0.85253


100%|██████████| 246/246 [00:08<00:00, 29.53it/s]


Epoch: 3 	 Loss: 0.65761


100%|██████████| 246/246 [00:09<00:00, 26.96it/s]


Epoch: 4 	 Loss: 0.55686


100%|██████████| 246/246 [00:08<00:00, 29.73it/s]


Epoch: 5 	 Loss: 0.47331


100%|██████████| 246/246 [00:08<00:00, 29.22it/s]


Epoch: 6 	 Loss: 0.39728


100%|██████████| 246/246 [00:10<00:00, 23.75it/s]


Epoch: 7 	 Loss: 0.33891


100%|██████████| 246/246 [00:09<00:00, 25.17it/s]


Epoch: 8 	 Loss: 0.30355


100%|██████████| 246/246 [00:10<00:00, 22.88it/s]


Epoch: 9 	 Loss: 0.27347


100%|██████████| 246/246 [00:08<00:00, 28.63it/s]


Epoch: 10 	 Loss: 0.25531


100%|██████████| 246/246 [00:08<00:00, 27.58it/s]


Epoch: 11 	 Loss: 0.23079


100%|██████████| 246/246 [00:08<00:00, 28.07it/s]


Epoch: 12 	 Loss: 0.19990


100%|██████████| 246/246 [00:08<00:00, 28.88it/s]


Epoch: 13 	 Loss: 0.18247


100%|██████████| 246/246 [00:08<00:00, 28.78it/s]


Epoch: 14 	 Loss: 0.15681


100%|██████████| 246/246 [00:09<00:00, 25.04it/s]


Epoch: 15 	 Loss: 0.14849


100%|██████████| 246/246 [00:09<00:00, 26.52it/s]


Epoch: 16 	 Loss: 0.12116


100%|██████████| 246/246 [00:09<00:00, 24.63it/s]


Epoch: 17 	 Loss: 0.11698


100%|██████████| 246/246 [00:08<00:00, 29.61it/s]


Epoch: 18 	 Loss: 0.10153


100%|██████████| 246/246 [00:08<00:00, 28.06it/s]


Epoch: 19 	 Loss: 0.09634


100%|██████████| 246/246 [00:08<00:00, 30.25it/s]


Epoch: 20 	 Loss: 0.08888


100%|██████████| 246/246 [00:09<00:00, 24.64it/s]


Epoch: 21 	 Loss: 0.08048


100%|██████████| 246/246 [00:09<00:00, 26.18it/s]


Epoch: 22 	 Loss: 0.07972


100%|██████████| 246/246 [00:09<00:00, 27.05it/s]


Epoch: 23 	 Loss: 0.07894


100%|██████████| 246/246 [00:09<00:00, 26.42it/s]


Epoch: 24 	 Loss: 0.07762


100%|██████████| 246/246 [00:08<00:00, 28.96it/s]


Epoch: 25 	 Loss: 0.08795


100%|██████████| 246/246 [00:09<00:00, 27.25it/s]


Epoch: 26 	 Loss: 0.11676


100%|██████████| 246/246 [00:08<00:00, 27.87it/s]


Epoch: 27 	 Loss: 0.08426


100%|██████████| 246/246 [00:08<00:00, 28.41it/s]


Epoch: 28 	 Loss: 0.06535


100%|██████████| 246/246 [00:11<00:00, 21.29it/s]


Epoch: 29 	 Loss: 0.05025


100%|██████████| 246/246 [00:09<00:00, 25.03it/s]


Epoch: 30 	 Loss: 0.04622


100%|██████████| 246/246 [00:08<00:00, 28.79it/s]


Epoch: 31 	 Loss: 0.03886


100%|██████████| 246/246 [00:08<00:00, 28.52it/s]


Epoch: 32 	 Loss: 0.03580


100%|██████████| 246/246 [00:10<00:00, 23.97it/s]


Epoch: 33 	 Loss: 0.03649


100%|██████████| 246/246 [00:09<00:00, 26.66it/s]


Epoch: 34 	 Loss: 0.03963


100%|██████████| 246/246 [00:07<00:00, 33.47it/s]


Epoch: 35 	 Loss: 0.03184


100%|██████████| 246/246 [00:09<00:00, 25.44it/s]


Epoch: 36 	 Loss: 0.03641


100%|██████████| 246/246 [00:10<00:00, 23.88it/s]


Epoch: 37 	 Loss: 0.06050


100%|██████████| 246/246 [00:09<00:00, 27.14it/s]


Epoch: 38 	 Loss: 0.04982


100%|██████████| 246/246 [00:09<00:00, 26.92it/s]


Epoch: 39 	 Loss: 0.03602


100%|██████████| 246/246 [00:08<00:00, 29.75it/s]


Epoch: 40 	 Loss: 0.02197
Accuracies: [0.80673181 0.625      0.80713128 0.76470588 0.90425532 0.46218487
 0.99801587 0.43       0.7027027  0.88235294 0.5       ] 
 Average acccuracy: 0.7166436987256322
F1 Scores: [0.78375522 0.57291662 0.83277587 0.77227718 0.89947085 0.56410252
 0.99801582 0.49999995 0.33986924 0.86705197 0.50847452] 
 Average F1: 0.6944281605371763
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:08<00:00, 27.77it/s]


Epoch: 1 	 Loss: 1.42025


100%|██████████| 246/246 [00:08<00:00, 28.10it/s]


Epoch: 2 	 Loss: 0.81254


100%|██████████| 246/246 [00:08<00:00, 29.03it/s]


Epoch: 3 	 Loss: 0.65107


100%|██████████| 246/246 [00:09<00:00, 26.72it/s]


Epoch: 4 	 Loss: 0.55486


100%|██████████| 246/246 [00:09<00:00, 26.60it/s]


Epoch: 5 	 Loss: 0.46468


100%|██████████| 246/246 [00:08<00:00, 29.97it/s]


Epoch: 6 	 Loss: 0.38884


100%|██████████| 246/246 [00:08<00:00, 28.80it/s]


Epoch: 7 	 Loss: 0.34290


100%|██████████| 246/246 [00:10<00:00, 22.41it/s]


Epoch: 8 	 Loss: 0.29308


100%|██████████| 246/246 [00:09<00:00, 25.42it/s]


Epoch: 9 	 Loss: 0.26298


100%|██████████| 246/246 [00:11<00:00, 21.39it/s]


Epoch: 10 	 Loss: 0.23795


100%|██████████| 246/246 [00:10<00:00, 24.17it/s]


Epoch: 11 	 Loss: 0.21585


100%|██████████| 246/246 [00:10<00:00, 23.35it/s]


Epoch: 12 	 Loss: 0.20403


100%|██████████| 246/246 [00:11<00:00, 22.12it/s]


Epoch: 13 	 Loss: 0.19043


100%|██████████| 246/246 [00:08<00:00, 28.11it/s]


Epoch: 14 	 Loss: 0.21080


100%|██████████| 246/246 [00:08<00:00, 29.79it/s]


Epoch: 15 	 Loss: 0.17225


100%|██████████| 246/246 [00:08<00:00, 30.62it/s]


Epoch: 16 	 Loss: 0.14995


100%|██████████| 246/246 [00:08<00:00, 28.66it/s]


Epoch: 17 	 Loss: 0.12851


100%|██████████| 246/246 [00:08<00:00, 28.64it/s]


Epoch: 18 	 Loss: 0.10637


100%|██████████| 246/246 [00:08<00:00, 30.38it/s]


Epoch: 19 	 Loss: 0.09186


100%|██████████| 246/246 [00:09<00:00, 25.43it/s]


Epoch: 20 	 Loss: 0.08741


100%|██████████| 246/246 [00:08<00:00, 27.50it/s]


Epoch: 21 	 Loss: 0.08366


100%|██████████| 246/246 [00:08<00:00, 29.19it/s]


Epoch: 22 	 Loss: 0.08354


100%|██████████| 246/246 [00:10<00:00, 22.66it/s]


Epoch: 23 	 Loss: 0.08520


100%|██████████| 246/246 [00:08<00:00, 27.52it/s]


Epoch: 24 	 Loss: 0.06802


100%|██████████| 246/246 [00:09<00:00, 26.65it/s]


Epoch: 25 	 Loss: 0.06804


100%|██████████| 246/246 [00:11<00:00, 21.08it/s]


Epoch: 26 	 Loss: 0.06170


100%|██████████| 246/246 [00:09<00:00, 25.32it/s]


Epoch: 27 	 Loss: 0.05808


100%|██████████| 246/246 [00:09<00:00, 27.19it/s]


Epoch: 28 	 Loss: 0.05605


100%|██████████| 246/246 [00:10<00:00, 24.10it/s]


Epoch: 29 	 Loss: 0.05442


100%|██████████| 246/246 [00:09<00:00, 25.04it/s]


Epoch: 30 	 Loss: 0.06597


100%|██████████| 246/246 [00:09<00:00, 24.77it/s]


Epoch: 31 	 Loss: 0.08921


100%|██████████| 246/246 [00:08<00:00, 27.57it/s]


Epoch: 32 	 Loss: 0.05037


100%|██████████| 246/246 [00:08<00:00, 28.59it/s]


Epoch: 33 	 Loss: 0.04095


100%|██████████| 246/246 [00:08<00:00, 29.39it/s]


Epoch: 34 	 Loss: 0.03961


100%|██████████| 246/246 [00:08<00:00, 28.75it/s]


Epoch: 35 	 Loss: 0.03372


100%|██████████| 246/246 [00:07<00:00, 33.21it/s]


Epoch: 36 	 Loss: 0.02844


100%|██████████| 246/246 [00:09<00:00, 26.45it/s]


Epoch: 37 	 Loss: 0.02529


100%|██████████| 246/246 [00:08<00:00, 29.31it/s]


Epoch: 38 	 Loss: 0.02477
Accuracies: [0.85866261 0.41447368 0.78517398 0.86666667 0.87894737 0.32949309
 0.99015748 0.55769231 0.71428571 0.85393258 0.45238095] 
 Average acccuracy: 0.700169676236471
F1 Scores: [0.69197791 0.49218745 0.83709672 0.82105258 0.87894732 0.48805457
 0.9940711  0.46774189 0.27777775 0.85875701 0.53521122] 
 Average F1: 0.6675341369181542


In [127]:
print(f"{np.mean(accs5):.4f}")
print(f"{np.mean(f1s5):.4f}")

0.7031
0.6785


## CNN-BiLSTM

In [125]:
while_i = 0
accs6 = []
f1s6 = []
micro_accs3 = []
micro_f1s3 = []
while while_i < 3:
    while_i += 1
    model6 = BiLSTM(hidden_size=128, dropout= 0.3, output_size= 11)
    optimizer = torch.optim.Adam(model6.parameters(), lr= 5e-4)
    loss_function = nn.CrossEntropyLoss(weight= class_weights)

    print(f'{"Starting Training":-^100}')
    model6.train()
    loss_list = []
    for epoch in range(100):
        running_loss = 0
        for idx in tqdm(range(246)):
            TRAIN_emb = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/embedding")
            TRAIN_labels = load_tensor(filepath=f"train_document/doc_{idx}_inlegal/label")
            TRAIN_labels = remap_targets(TRAIN_labels, label_encoder_old, label_encoder_new)
            if TRAIN_emb.size(0) == 0:
                continue
            output = model6(TRAIN_emb)
            loss = loss_function(output,TRAIN_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # scheduler.step()
        # scheduler1.step()
        # scheduler2.step()
        # running_lr.append(model_opt.state_dict()['param_groups'][0]['lr'])
        loss_list.append(running_loss/246)
        print(f"Epoch: {epoch+1} \t Loss: {running_loss/246:.5f}")
        if running_loss/246 < 0.025:
            break
    cm = None
    for i in range(29):
        TEST_emb = load_tensor(filepath=f"test_document/doc_{i}_inlegal/embedding")
        TEST_labels = load_tensor(filepath=f"test_document/doc_{i}_inlegal/label")
        TEST_labels = remap_targets(TEST_labels, label_encoder_old, label_encoder_new)
        conf_matrix_helper = calculate_confusion_matrix(TEST_emb, TEST_labels, model6, num_labels= 11)
        if cm is None:
            cm = conf_matrix_helper
        else:
            cm = np.add(cm, conf_matrix_helper)
            
    accuracies = class_accuracy(cm)
    f1_scores = class_f1_score(cm)
    average_accuracy = np.mean(accuracies)
    average_f1 = np.mean(f1_scores)

    accs6.append(average_accuracy)
    f1s6.append(average_f1)
    micro_accs3.append(accuracies)
    micro_f1s3.append(f1_scores)
    
    print("Accuracies: {} \n Average acccuracy: {}".format(accuracies, average_accuracy))
    print("F1 Scores: {} \n Average F1: {}".format(f1_scores, average_f1))

-----------------------------------------Starting Training------------------------------------------


  0%|          | 0/246 [00:00<?, ?it/s]

100%|██████████| 246/246 [00:20<00:00, 12.01it/s]


Epoch: 1 	 Loss: 1.44473


100%|██████████| 246/246 [00:12<00:00, 19.77it/s]


Epoch: 2 	 Loss: 0.84586


100%|██████████| 246/246 [00:14<00:00, 16.92it/s]


Epoch: 3 	 Loss: 0.67129


100%|██████████| 246/246 [00:11<00:00, 21.38it/s]


Epoch: 4 	 Loss: 0.56737


100%|██████████| 246/246 [00:16<00:00, 15.08it/s]


Epoch: 5 	 Loss: 0.48665


100%|██████████| 246/246 [00:11<00:00, 21.24it/s]


Epoch: 6 	 Loss: 0.40411


100%|██████████| 246/246 [00:10<00:00, 22.81it/s]


Epoch: 7 	 Loss: 0.33909


100%|██████████| 246/246 [00:07<00:00, 33.61it/s]


Epoch: 8 	 Loss: 0.30008


100%|██████████| 246/246 [00:09<00:00, 26.79it/s]


Epoch: 9 	 Loss: 0.25802


100%|██████████| 246/246 [00:09<00:00, 26.00it/s]


Epoch: 10 	 Loss: 0.23479


100%|██████████| 246/246 [00:13<00:00, 18.16it/s]


Epoch: 11 	 Loss: 0.21121


100%|██████████| 246/246 [00:12<00:00, 19.74it/s]


Epoch: 12 	 Loss: 0.19641


100%|██████████| 246/246 [00:14<00:00, 16.66it/s]


Epoch: 13 	 Loss: 0.18526


100%|██████████| 246/246 [00:08<00:00, 29.60it/s]


Epoch: 14 	 Loss: 0.16046


100%|██████████| 246/246 [00:16<00:00, 14.75it/s]


Epoch: 15 	 Loss: 0.14318


100%|██████████| 246/246 [00:10<00:00, 23.47it/s]


Epoch: 16 	 Loss: 0.13447


100%|██████████| 246/246 [00:09<00:00, 24.79it/s]


Epoch: 17 	 Loss: 0.13112


100%|██████████| 246/246 [00:14<00:00, 16.67it/s]


Epoch: 18 	 Loss: 0.11700


100%|██████████| 246/246 [00:13<00:00, 18.87it/s]


Epoch: 19 	 Loss: 0.10494


100%|██████████| 246/246 [00:13<00:00, 18.80it/s]


Epoch: 20 	 Loss: 0.09401


100%|██████████| 246/246 [00:11<00:00, 20.58it/s]


Epoch: 21 	 Loss: 0.08596


100%|██████████| 246/246 [00:09<00:00, 27.03it/s]


Epoch: 22 	 Loss: 0.07769


100%|██████████| 246/246 [00:13<00:00, 18.08it/s]


Epoch: 23 	 Loss: 0.06880


100%|██████████| 246/246 [00:09<00:00, 25.14it/s]


Epoch: 24 	 Loss: 0.08152


100%|██████████| 246/246 [00:10<00:00, 24.39it/s]


Epoch: 25 	 Loss: 0.08182


100%|██████████| 246/246 [00:10<00:00, 23.00it/s]


Epoch: 26 	 Loss: 0.11143


100%|██████████| 246/246 [00:14<00:00, 16.57it/s]


Epoch: 27 	 Loss: 0.07351


100%|██████████| 246/246 [00:16<00:00, 14.78it/s]


Epoch: 28 	 Loss: 0.05078


100%|██████████| 246/246 [00:14<00:00, 17.23it/s]


Epoch: 29 	 Loss: 0.04318


100%|██████████| 246/246 [00:12<00:00, 20.43it/s]


Epoch: 30 	 Loss: 0.03716


100%|██████████| 246/246 [00:14<00:00, 17.02it/s]


Epoch: 31 	 Loss: 0.03954


100%|██████████| 246/246 [00:15<00:00, 16.29it/s]


Epoch: 32 	 Loss: 0.03805


100%|██████████| 246/246 [00:11<00:00, 21.87it/s]


Epoch: 33 	 Loss: 0.03681


100%|██████████| 246/246 [00:13<00:00, 18.41it/s]


Epoch: 34 	 Loss: 0.04064


100%|██████████| 246/246 [00:14<00:00, 16.41it/s]


Epoch: 35 	 Loss: 0.04938


100%|██████████| 246/246 [00:09<00:00, 26.53it/s]


Epoch: 36 	 Loss: 0.04517


100%|██████████| 246/246 [00:13<00:00, 18.57it/s]


Epoch: 37 	 Loss: 0.03650


100%|██████████| 246/246 [00:12<00:00, 19.78it/s]


Epoch: 38 	 Loss: 0.04742


100%|██████████| 246/246 [00:09<00:00, 26.56it/s]


Epoch: 39 	 Loss: 0.05229


100%|██████████| 246/246 [00:11<00:00, 20.84it/s]


Epoch: 40 	 Loss: 0.03650


100%|██████████| 246/246 [00:10<00:00, 23.69it/s]


Epoch: 41 	 Loss: 0.04874


100%|██████████| 246/246 [00:10<00:00, 22.43it/s]


Epoch: 42 	 Loss: 0.03937


100%|██████████| 246/246 [00:13<00:00, 18.56it/s]


Epoch: 43 	 Loss: 0.02509


100%|██████████| 246/246 [00:11<00:00, 22.16it/s]


Epoch: 44 	 Loss: 0.02485
Accuracies: [0.84247539 0.34615385 0.73756906 0.68965517 0.87958115 0.50819672
 0.99407115 0.47761194 0.53333333 0.85869565 0.51282051] 
 Average acccuracy: 0.6709239931032334
F1 Scores: [0.71055748 0.44055939 0.81964692 0.74074069 0.88188971 0.62626258
 0.99603955 0.4604316  0.29813661 0.87777773 0.58823524] 
 Average F1: 0.6763888647109512
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:09<00:00, 24.74it/s]


Epoch: 1 	 Loss: 1.43614


100%|██████████| 246/246 [00:12<00:00, 19.86it/s]


Epoch: 2 	 Loss: 0.85970


100%|██████████| 246/246 [00:11<00:00, 21.72it/s]


Epoch: 3 	 Loss: 0.68171


100%|██████████| 246/246 [00:10<00:00, 24.34it/s]


Epoch: 4 	 Loss: 0.56971


100%|██████████| 246/246 [00:11<00:00, 20.90it/s]


Epoch: 5 	 Loss: 0.47999


100%|██████████| 246/246 [00:10<00:00, 22.59it/s]


Epoch: 6 	 Loss: 0.41593


100%|██████████| 246/246 [00:13<00:00, 18.28it/s]


Epoch: 7 	 Loss: 0.36689


100%|██████████| 246/246 [00:10<00:00, 22.56it/s]


Epoch: 8 	 Loss: 0.31856


100%|██████████| 246/246 [00:11<00:00, 21.00it/s]


Epoch: 9 	 Loss: 0.27226


100%|██████████| 246/246 [00:12<00:00, 19.46it/s]


Epoch: 10 	 Loss: 0.24615


100%|██████████| 246/246 [00:17<00:00, 14.13it/s]


Epoch: 11 	 Loss: 0.22422


100%|██████████| 246/246 [00:08<00:00, 28.21it/s]


Epoch: 12 	 Loss: 0.19261


100%|██████████| 246/246 [00:12<00:00, 20.48it/s]


Epoch: 13 	 Loss: 0.17265


100%|██████████| 246/246 [00:07<00:00, 31.54it/s]


Epoch: 14 	 Loss: 0.16441


100%|██████████| 246/246 [00:09<00:00, 26.17it/s]


Epoch: 15 	 Loss: 0.15655


100%|██████████| 246/246 [00:09<00:00, 24.94it/s]


Epoch: 16 	 Loss: 0.13202


100%|██████████| 246/246 [00:09<00:00, 24.87it/s]


Epoch: 17 	 Loss: 0.13053


100%|██████████| 246/246 [00:10<00:00, 24.59it/s]


Epoch: 18 	 Loss: 0.12742


100%|██████████| 246/246 [00:11<00:00, 21.19it/s]


Epoch: 19 	 Loss: 0.11196


100%|██████████| 246/246 [00:08<00:00, 27.96it/s]


Epoch: 20 	 Loss: 0.09718


100%|██████████| 246/246 [00:08<00:00, 27.99it/s]


Epoch: 21 	 Loss: 0.08960


100%|██████████| 246/246 [00:11<00:00, 21.18it/s]


Epoch: 22 	 Loss: 0.09587


100%|██████████| 246/246 [00:07<00:00, 30.82it/s]


Epoch: 23 	 Loss: 0.07929


100%|██████████| 246/246 [00:09<00:00, 25.28it/s]


Epoch: 24 	 Loss: 0.06564


100%|██████████| 246/246 [00:09<00:00, 26.31it/s]


Epoch: 25 	 Loss: 0.08486


100%|██████████| 246/246 [00:08<00:00, 30.41it/s]


Epoch: 26 	 Loss: 0.07306


100%|██████████| 246/246 [00:07<00:00, 32.04it/s]


Epoch: 27 	 Loss: 0.07671


100%|██████████| 246/246 [00:08<00:00, 30.36it/s]


Epoch: 28 	 Loss: 0.06151


100%|██████████| 246/246 [00:08<00:00, 28.09it/s]


Epoch: 29 	 Loss: 0.05474


100%|██████████| 246/246 [00:09<00:00, 25.34it/s]


Epoch: 30 	 Loss: 0.05076


100%|██████████| 246/246 [00:09<00:00, 25.73it/s]


Epoch: 31 	 Loss: 0.04847


100%|██████████| 246/246 [00:08<00:00, 27.84it/s]


Epoch: 32 	 Loss: 0.04578


100%|██████████| 246/246 [00:09<00:00, 27.26it/s]


Epoch: 33 	 Loss: 0.03997


100%|██████████| 246/246 [00:09<00:00, 26.00it/s]


Epoch: 34 	 Loss: 0.03969


100%|██████████| 246/246 [00:10<00:00, 23.88it/s]


Epoch: 35 	 Loss: 0.03777


100%|██████████| 246/246 [00:10<00:00, 23.19it/s]


Epoch: 36 	 Loss: 0.04216


100%|██████████| 246/246 [00:11<00:00, 21.67it/s]


Epoch: 37 	 Loss: 0.04778


100%|██████████| 246/246 [00:08<00:00, 27.99it/s]


Epoch: 38 	 Loss: 0.04918


100%|██████████| 246/246 [00:09<00:00, 25.91it/s]


Epoch: 39 	 Loss: 0.03315


100%|██████████| 246/246 [00:10<00:00, 23.21it/s]


Epoch: 40 	 Loss: 0.03436


100%|██████████| 246/246 [00:11<00:00, 22.19it/s]


Epoch: 41 	 Loss: 0.02731


100%|██████████| 246/246 [00:11<00:00, 22.34it/s]


Epoch: 42 	 Loss: 0.02491
Accuracies: [0.83313033 0.63333333 0.81372549 0.80434783 0.85294118 0.38505747
 0.9960396  0.49275362 0.68888889 0.87058824 0.5       ] 
 Average acccuracy: 0.7155278161409426
F1 Scores: [0.7616926  0.58762882 0.83627199 0.77083328 0.88324868 0.53599996
 0.99702671 0.48226945 0.38509313 0.85549128 0.53968249] 
 Average F1: 0.6941125801732024
-----------------------------------------Starting Training------------------------------------------


100%|██████████| 246/246 [00:14<00:00, 16.59it/s]


Epoch: 1 	 Loss: 1.40374


100%|██████████| 246/246 [00:11<00:00, 20.57it/s]


Epoch: 2 	 Loss: 0.83564


100%|██████████| 246/246 [00:11<00:00, 22.29it/s]


Epoch: 3 	 Loss: 0.67287


100%|██████████| 246/246 [00:10<00:00, 22.88it/s]


Epoch: 4 	 Loss: 0.55693


100%|██████████| 246/246 [00:12<00:00, 19.71it/s]


Epoch: 5 	 Loss: 0.49111


100%|██████████| 246/246 [00:11<00:00, 20.60it/s]


Epoch: 6 	 Loss: 0.40389


100%|██████████| 246/246 [00:11<00:00, 21.06it/s]


Epoch: 7 	 Loss: 0.35165


100%|██████████| 246/246 [00:10<00:00, 22.80it/s]


Epoch: 8 	 Loss: 0.30944


100%|██████████| 246/246 [00:10<00:00, 22.99it/s]


Epoch: 9 	 Loss: 0.27112


100%|██████████| 246/246 [00:11<00:00, 21.06it/s]


Epoch: 10 	 Loss: 0.24498


100%|██████████| 246/246 [00:10<00:00, 22.89it/s]


Epoch: 11 	 Loss: 0.22079


100%|██████████| 246/246 [00:09<00:00, 25.65it/s]


Epoch: 12 	 Loss: 0.20098


100%|██████████| 246/246 [00:09<00:00, 24.90it/s]


Epoch: 13 	 Loss: 0.18046


100%|██████████| 246/246 [00:09<00:00, 24.98it/s]


Epoch: 14 	 Loss: 0.17283


100%|██████████| 246/246 [00:09<00:00, 24.93it/s]


Epoch: 15 	 Loss: 0.14912


100%|██████████| 246/246 [00:10<00:00, 24.44it/s]


Epoch: 16 	 Loss: 0.14582


100%|██████████| 246/246 [00:10<00:00, 23.73it/s]


Epoch: 17 	 Loss: 0.12628


100%|██████████| 246/246 [00:11<00:00, 22.32it/s]


Epoch: 18 	 Loss: 0.12172


100%|██████████| 246/246 [00:09<00:00, 25.21it/s]


Epoch: 19 	 Loss: 0.11203


100%|██████████| 246/246 [00:11<00:00, 21.59it/s]


Epoch: 20 	 Loss: 0.09584


100%|██████████| 246/246 [00:12<00:00, 20.31it/s]


Epoch: 21 	 Loss: 0.08318


100%|██████████| 246/246 [00:09<00:00, 26.04it/s]


Epoch: 22 	 Loss: 0.08171


100%|██████████| 246/246 [00:09<00:00, 25.49it/s]


Epoch: 23 	 Loss: 0.08006


100%|██████████| 246/246 [00:09<00:00, 26.15it/s]


Epoch: 24 	 Loss: 0.07914


100%|██████████| 246/246 [00:09<00:00, 25.30it/s]


Epoch: 25 	 Loss: 0.07340


100%|██████████| 246/246 [00:09<00:00, 25.25it/s]


Epoch: 26 	 Loss: 0.07496


100%|██████████| 246/246 [00:10<00:00, 23.49it/s]


Epoch: 27 	 Loss: 0.06701


100%|██████████| 246/246 [00:09<00:00, 27.27it/s]


Epoch: 28 	 Loss: 0.05895


100%|██████████| 246/246 [00:10<00:00, 24.53it/s]


Epoch: 29 	 Loss: 0.06065


100%|██████████| 246/246 [00:08<00:00, 28.59it/s]


Epoch: 30 	 Loss: 0.05071


100%|██████████| 246/246 [00:08<00:00, 30.31it/s]


Epoch: 31 	 Loss: 0.05444


100%|██████████| 246/246 [00:08<00:00, 29.67it/s]


Epoch: 32 	 Loss: 0.04260


100%|██████████| 246/246 [00:09<00:00, 26.93it/s]


Epoch: 33 	 Loss: 0.03891


100%|██████████| 246/246 [00:07<00:00, 31.07it/s]


Epoch: 34 	 Loss: 0.03573


100%|██████████| 246/246 [00:08<00:00, 28.29it/s]


Epoch: 35 	 Loss: 0.04068


100%|██████████| 246/246 [00:09<00:00, 25.87it/s]


Epoch: 36 	 Loss: 0.03517


100%|██████████| 246/246 [00:08<00:00, 28.16it/s]


Epoch: 37 	 Loss: 0.03514


100%|██████████| 246/246 [00:09<00:00, 25.56it/s]


Epoch: 38 	 Loss: 0.03596


100%|██████████| 246/246 [00:08<00:00, 30.09it/s]


Epoch: 39 	 Loss: 0.03315


100%|██████████| 246/246 [00:10<00:00, 24.09it/s]


Epoch: 40 	 Loss: 0.07980


100%|██████████| 246/246 [00:07<00:00, 30.77it/s]


Epoch: 41 	 Loss: 0.06907


100%|██████████| 246/246 [00:09<00:00, 25.27it/s]


Epoch: 42 	 Loss: 0.03253


100%|██████████| 246/246 [00:09<00:00, 27.11it/s]


Epoch: 43 	 Loss: 0.02215
Accuracies: [0.84010152 0.39156627 0.75247525 0.69230769 0.89637306 0.5
 0.98821218 0.45588235 0.625      0.83157895 0.53658537] 
 Average acccuracy: 0.6827347846945421
F1 Scores: [0.75099258 0.48148143 0.82737165 0.7058823  0.90339421 0.58695647
 0.99308978 0.44285709 0.21428569 0.86338793 0.62857138] 
 Average F1: 0.6725700455676673


In [129]:
print(f"{np.mean(accs6):.4f}")
print(f"{np.mean(f1s6):.4f}")

0.6897
0.6810


## HOW DOES THE BILSTM WORK?

In [55]:
import torch.nn as nn

In [91]:
sample_input = load_tensor(filepath=f"train_document/doc_0_inlegal/embedding")

In [92]:
sample_input.size()

torch.Size([91, 1, 768])

In [93]:
# sample_input = sample_input.permute(1,2,0)
# sample_input.size()

In [94]:
bilstm = nn.LSTM(input_size=768, hidden_size=128, bidirectional= True)

In [95]:
output, (hidden, cell) = bilstm(sample_input)

In [96]:
output.size()

torch.Size([91, 1, 256])

In [97]:
hidden.size()

torch.Size([2, 1, 128])

In [98]:
cell.size()

torch.Size([2, 1, 128])