In [2]:
#!git clone https://github.com/IndoNLP/indonlu.git

In [3]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.1-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 4.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 10.8 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 63.2 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 61.4 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstal

In [4]:
import os, sys
sys.path.append('../')
os.chdir('../')

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm

from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
from nltk.tokenize import TweetTokenizer

from content.drive.MyDrive.TA.indonlu.utils.forward_fn import forward_sequence_classification
from content.drive.MyDrive.TA.indonlu.utils.metrics import hadits_classification_metrics_fn
from content.drive.MyDrive.TA.indonlu.utils.data_utils import HaditsClassificationDataset, HaditsClassificationDataLoader


In [5]:
###
# common functions
###
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
    
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def metrics_to_string(metric_dict):
    string_list = []
    for key, value in metric_dict.items():
        string_list.append('{}:{:.2f}'.format(key, value))
    return ' '.join(string_list)

In [6]:
# Set random seed
set_seed(26092020)

In [7]:
# Load Tokenizer and Config
tokenizer = BertTokenizer.from_pretrained('indobenchmark/indobert-base-p1')
config = BertConfig.from_pretrained('indobenchmark/indobert-base-p1')
config.num_labels = HaditsClassificationDataset.NUM_LABELS

# Instantiate model
model = BertForSequenceClassification.from_pretrained('indobenchmark/indobert-base-p1', config=config)

Downloading vocab.txt:   0%|          | 0.00/224k [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/475M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at indobenchmark/indobert-base-p1 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# model.summary()

In [9]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [10]:
count_param(model)

124500557

In [11]:
train_dataset_path = '/content/drive/MyDrive/TA/bert_train/nr/bukhari_train.csv'
valid_dataset_path = '/content/drive/MyDrive/TA/bert_train/nr/bukhari_valid.csv'
test_dataset_path = '/content/drive/MyDrive/TA/bert_train/nr/bukhari_test.csv'

In [12]:
train_dataset = HaditsClassificationDataset(train_dataset_path, tokenizer, lowercase=True)
valid_dataset = HaditsClassificationDataset(valid_dataset_path, tokenizer, lowercase=True)
test_dataset = HaditsClassificationDataset(test_dataset_path, tokenizer, lowercase=True)

train_loader = HaditsClassificationDataLoader(dataset=train_dataset, max_seq_len=512, batch_size=8, num_workers=4, shuffle=True)  
valid_loader = HaditsClassificationDataLoader(dataset=valid_dataset, max_seq_len=512, batch_size=8, num_workers=4, shuffle=False)  
test_loader = HaditsClassificationDataLoader(dataset=test_dataset, max_seq_len=512, batch_size=8, num_workers=4, shuffle=False)

In [13]:
w2i, i2w = HaditsClassificationDataset.LABEL2INDEX, HaditsClassificationDataset.INDEX2LABEL
print(w2i)
print(i2w)

{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 26: 25, 27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 32: 31, 33: 32, 34: 33, 35: 34, 36: 35, 37: 36, 38: 37, 39: 38, 40: 39, 41: 40, 42: 41, 43: 42, 44: 43, 45: 44, 46: 45, 47: 46, 48: 47, 49: 48, 50: 49, 51: 50, 52: 51, 53: 52, 54: 53, 55: 54, 56: 55, 57: 56, 58: 57, 59: 58, 60: 59, 61: 60, 62: 61, 63: 62, 64: 63, 65: 64, 66: 65, 67: 66, 68: 67, 69: 68, 70: 69, 71: 70, 72: 71, 73: 72, 74: 73, 75: 74, 76: 75, 77: 76}
{0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 13, 13: 14, 14: 15, 15: 16, 16: 17, 17: 18, 18: 19, 19: 20, 20: 21, 21: 22, 22: 23, 23: 24, 24: 25, 25: 26, 26: 27, 27: 28, 28: 29, 29: 30, 30: 31, 31: 32, 32: 33, 33: 34, 34: 35, 35: 36, 36: 37, 37: 38, 38: 39, 39: 40, 40: 41, 41: 42, 42: 43, 43: 44, 44: 45, 45: 46, 46: 47, 47: 48, 48: 49, 49: 50, 50: 51, 51: 52, 52: 

In [14]:
# ___________________________________________-

In [15]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Fine Tuning & Evaluation

In [16]:
import torch
torch.cuda.empty_cache()

In [17]:
optimizer = optim.Adam(model.parameters(), lr=5e-6)
model = model.cuda()

In [18]:
!torch.cuda.empty_cache()
!nvidia-smi

/bin/bash: -c: line 1: syntax error: unexpected end of file
Sun Aug  7 18:49:35 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    26W /  70W |   1138MiB / 15109MiB |      6%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------

In [19]:
torch.no_grad()

<torch.autograd.grad_mode.no_grad at 0x7f803cafafd0>

In [20]:
# Train
loss_list = []
accuracy_list = []
precision_list = []
recall_list = []
f1_list = []
n_epochs = 20
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = hadits_classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = hadits_classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = hadits_classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    loss_list.append(total_loss/(i+1))
    accuracy_list.append(metrics['ACC'])
    precision_list.append(metrics['PRE'])
    recall_list.append(metrics['REC'])
    f1_list.append(metrics['F1'])

(Epoch 1) TRAIN LOSS:4.0641 LR:0.00000500: 100%|██████████| 701/701 [04:37<00:00,  2.53it/s]


(Epoch 1) TRAIN LOSS:4.0641 ACC:0.13 F1:0.11 REC:0.13 PRE:0.15 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 1) VALID LOSS:3.3806 ACC:0.34 F1:0.28 REC:0.34 PRE:0.33


(Epoch 2) TRAIN LOSS:3.0065 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]
  _warn_prf(average, modifier, msg_start, len(result))


(Epoch 2) TRAIN LOSS:3.0065 ACC:0.43 F1:0.39 REC:0.43 PRE:0.45 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 2) VALID LOSS:2.5364 ACC:0.52 F1:0.46 REC:0.51 PRE:0.51


(Epoch 3) TRAIN LOSS:2.2788 LR:0.00000500: 100%|██████████| 701/701 [04:49<00:00,  2.42it/s]


(Epoch 3) TRAIN LOSS:2.2788 ACC:0.60 F1:0.56 REC:0.59 PRE:0.61 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 3) VALID LOSS:2.0012 ACC:0.61 F1:0.56 REC:0.60 PRE:0.60


(Epoch 4) TRAIN LOSS:1.7549 LR:0.00000500: 100%|██████████| 701/701 [04:45<00:00,  2.45it/s]


(Epoch 4) TRAIN LOSS:1.7549 ACC:0.71 F1:0.69 REC:0.71 PRE:0.72 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 4) VALID LOSS:1.6255 ACC:0.69 F1:0.65 REC:0.69 PRE:0.67


(Epoch 5) TRAIN LOSS:1.3473 LR:0.00000500: 100%|██████████| 701/701 [04:47<00:00,  2.44it/s]


(Epoch 5) TRAIN LOSS:1.3473 ACC:0.79 F1:0.78 REC:0.79 PRE:0.80 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 5) VALID LOSS:1.3583 ACC:0.72 F1:0.69 REC:0.72 PRE:0.69


(Epoch 6) TRAIN LOSS:1.0115 LR:0.00000500: 100%|██████████| 701/701 [04:50<00:00,  2.42it/s]


(Epoch 6) TRAIN LOSS:1.0115 ACC:0.85 F1:0.84 REC:0.85 PRE:0.85 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 6) VALID LOSS:1.1583 ACC:0.74 F1:0.72 REC:0.74 PRE:0.71


(Epoch 7) TRAIN LOSS:0.7487 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]


(Epoch 7) TRAIN LOSS:0.7487 ACC:0.90 F1:0.89 REC:0.90 PRE:0.90 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 7) VALID LOSS:1.0211 ACC:0.76 F1:0.74 REC:0.76 PRE:0.75


(Epoch 8) TRAIN LOSS:0.5440 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]


(Epoch 8) TRAIN LOSS:0.5440 ACC:0.93 F1:0.93 REC:0.93 PRE:0.93 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 8) VALID LOSS:0.9142 ACC:0.79 F1:0.77 REC:0.79 PRE:0.78


(Epoch 9) TRAIN LOSS:0.3962 LR:0.00000500: 100%|██████████| 701/701 [04:51<00:00,  2.41it/s]


(Epoch 9) TRAIN LOSS:0.3962 ACC:0.96 F1:0.96 REC:0.96 PRE:0.96 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 9) VALID LOSS:0.8550 ACC:0.78 F1:0.77 REC:0.78 PRE:0.77


(Epoch 10) TRAIN LOSS:0.2804 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]


(Epoch 10) TRAIN LOSS:0.2804 ACC:0.97 F1:0.97 REC:0.97 PRE:0.97 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 10) VALID LOSS:0.8475 ACC:0.79 F1:0.77 REC:0.79 PRE:0.78


(Epoch 11) TRAIN LOSS:0.1952 LR:0.00000500: 100%|██████████| 701/701 [04:49<00:00,  2.42it/s]


(Epoch 11) TRAIN LOSS:0.1952 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 11) VALID LOSS:0.8549 ACC:0.78 F1:0.77 REC:0.78 PRE:0.77


(Epoch 12) TRAIN LOSS:0.1385 LR:0.00000500: 100%|██████████| 701/701 [04:47<00:00,  2.44it/s]


(Epoch 12) TRAIN LOSS:0.1385 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 12) VALID LOSS:0.8293 ACC:0.79 F1:0.78 REC:0.79 PRE:0.79


(Epoch 13) TRAIN LOSS:0.0995 LR:0.00000500: 100%|██████████| 701/701 [04:47<00:00,  2.44it/s]


(Epoch 13) TRAIN LOSS:0.0995 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 13) VALID LOSS:0.8491 ACC:0.79 F1:0.77 REC:0.78 PRE:0.78


(Epoch 14) TRAIN LOSS:0.0692 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]


(Epoch 14) TRAIN LOSS:0.0692 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 14) VALID LOSS:0.8694 ACC:0.79 F1:0.77 REC:0.79 PRE:0.78


(Epoch 15) TRAIN LOSS:0.0570 LR:0.00000500: 100%|██████████| 701/701 [04:48<00:00,  2.43it/s]


(Epoch 15) TRAIN LOSS:0.0570 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 15) VALID LOSS:0.8447 ACC:0.80 F1:0.79 REC:0.80 PRE:0.80


(Epoch 16) TRAIN LOSS:0.0340 LR:0.00000500: 100%|██████████| 701/701 [04:49<00:00,  2.42it/s]


(Epoch 16) TRAIN LOSS:0.0340 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 16) VALID LOSS:0.8675 ACC:0.80 F1:0.79 REC:0.80 PRE:0.80


(Epoch 17) TRAIN LOSS:0.0316 LR:0.00000500: 100%|██████████| 701/701 [04:49<00:00,  2.42it/s]


(Epoch 17) TRAIN LOSS:0.0316 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 17) VALID LOSS:0.9016 ACC:0.80 F1:0.79 REC:0.80 PRE:0.79


(Epoch 18) TRAIN LOSS:0.0479 LR:0.00000500: 100%|██████████| 701/701 [04:49<00:00,  2.42it/s]


(Epoch 18) TRAIN LOSS:0.0479 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 18) VALID LOSS:0.8857 ACC:0.80 F1:0.79 REC:0.80 PRE:0.79


(Epoch 19) TRAIN LOSS:0.0235 LR:0.00000500: 100%|██████████| 701/701 [04:47<00:00,  2.44it/s]


(Epoch 19) TRAIN LOSS:0.0235 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 19) VALID LOSS:0.9537 ACC:0.80 F1:0.78 REC:0.80 PRE:0.79


(Epoch 20) TRAIN LOSS:0.0198 LR:0.00000500: 100%|██████████| 701/701 [04:50<00:00,  2.42it/s]


(Epoch 20) TRAIN LOSS:0.0198 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

(Epoch 20) VALID LOSS:0.9926 ACC:0.80 F1:0.78 REC:0.79 PRE:0.78





In [21]:
report = pd.DataFrame({'acc': accuracy_list, 'precision': precision_list, 'recall':recall_list, 'f1_score':f1_list, 'loss': loss_list}, columns=['acc', 'precision', 'recall','f1_score', 'loss'])

In [22]:
output = '/content/drive/MyDrive/TA/bert_result/metrics_nrr.csv'
report.to_csv(output, index=False)

In [24]:
# Evaluate on test
# model.eval()
# torch.set_grad_enabled(False)

# total_loss, total_correct, total_labels = 0, 0, 0
# list_hyp, list_label = [], []

# pbar = tqdm(test_loader, leave=True, total=len(test_loader))
# for i, batch_data in enumerate(pbar):
#     _, batch_hyp, _ = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
#     list_hyp += batch_hyp

# # Save prediction
# df = pd.DataFrame({'label':list_hyp}).reset_index()
# df.to_csv('pred.txt', index=False)

# print(df)

In [None]:
text = ' wahyu'
subwords = tokenizer.encode(text)
subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)

logits = model(subwords)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text:  | Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

In [None]:
text = "apa"
subwords = tokenizer.encode(text)
subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)

logits = model(subwords)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text:  | Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

In [None]:
import scipy.sparse as sparse

def predict_hadits(hadits):
  result = []
  for i in tqdm(range(0, len(hadits))):
    text = hadits['indo'][i]
    subwords = tokenizer.encode(text)
    subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)

    # Mengurangi panjang words, jika lebih dari 512
    arr = subwords[0]    
    if len(arr) > 512:
      arr = arr[:512]
      arr = np.reshape(arr,(1,512))
      subwords = arr
    logits = model(subwords)[0]
    label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()
    result.append(label)
  return result

def reading_data_result(data_file):
    data_file = pd.read_json(data_file)
    data_file = data_file.sort_values(by=['haditsId'])
    data_file = data_file.reset_index()
    data_file = data_file[["haditsId" , "kitabId","indo", "arab" ]]
    return data_file


In [None]:
hadits_data_loc = [
    '/content/drive/MyDrive/TA/json_data/abudaud.json',
    '/content/drive/MyDrive/TA/json_data/darimi.json',
    '/content/drive/MyDrive/TA/json_data/ibnumajah.json',
    '/content/drive/MyDrive/TA/json_data/malik.json',
    '/content/drive/MyDrive/TA/json_data/muslim.json',
    '/content/drive/MyDrive/TA/json_data/nasai.json',
    '/content/drive/MyDrive/TA/json_data/tirmidzi.json'
]
hadits_ns_loc = [
    '/content/drive/MyDrive/TA/noSanad/abudaud.csv',
    '/content/drive/MyDrive/TA/noSanad/darimi.csv',
    '/content/drive/MyDrive/TA/noSanad/ibnumajah.csv',
    '/content/drive/MyDrive/TA/noSanad/malik.csv',
    '/content/drive/MyDrive/TA/noSanad/muslim.csv',
    '/content/drive/MyDrive/TA/noSanad/nasai.csv',
    '/content/drive/MyDrive/TA/noSanad/tirmidzi.csv'
]
hadits_nsn_loc = [
    '/content/drive/MyDrive/TA/noSanad_name/abudaud.csv',
    '/content/drive/MyDrive/TA/noSanad_name/darimi.csv',
    '/content/drive/MyDrive/TA/noSanad_name/ibnumajah.csv',
    '/content/drive/MyDrive/TA/noSanad_name/malik.csv',
    '/content/drive/MyDrive/TA/noSanad_name/muslim.csv',
    '/content/drive/MyDrive/TA/noSanad_name/nasai.csv',
    '/content/drive/MyDrive/TA/noSanad_name/tirmidzi.csv'
]
hadits_ns_save_loc = [
    '/content/drive/MyDrive/TA/bert_result/model4/ns/abudaud.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/darimi.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/ibnumajah.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/malik.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/muslim.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/nasai.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/ns/tirmidzi.csv'
]
hadits_nsn_save_loc = [
    '/content/drive/MyDrive/TA/bert_result/model4/nr/abudaud.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/darimi.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/ibnumajah.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/malik.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/muslim.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/nasai.csv',
    '/content/drive/MyDrive/TA/bert_result/model4/nr/tirmidzi.csv'
]

In [None]:
for i in range(len(hadits_nsn_loc)):
  hadits = pd.read_csv(hadits_nsn_loc[i])
  hadits = predict_hadits(hadits)
  final = reading_data_result(hadits_data_loc[i])
  final['label'] = hadits
  file_output = hadits_nsn_save_loc[i]
  final.to_csv(file_output, index=False)

In [None]:
hadits = pd.read_csv('/content/drive/MyDrive/TA/bert_result/model4/ns/muslim.csv')

In [None]:
result = hadits.groupby('label').describe()
file_output = '/content/drive/MyDrive/TA/bert_result/model4/log_ns/muslim.csv'
result.to_csv(file_output, index=False)

In [None]:
# hadits = pd.read_csv('/content/drive/MyDrive/TA/noSanad/tirmidzi.csv')
# hadits = predict_hadits(hadits)
# final = reading_data_result('/content/drive/MyDrive/TA/json_data/tirmidzi.json')
# final['label'] = hadits
# final.head()
# file_output = '/content/drive/MyDrive/TA/bert_result/model4/ns/tirmidzi.csv'
# final.to_csv(file_output, index=False)

In [None]:
final.groupby('label').describe()

In [None]:
model.save_pretrained('/content/drive/MyDrive/TA/model/4')

In [None]:
model = BertForSequenceClassification.from_pretrained('/content/drive/MyDrive/TA/model/4')
model.bert.load_state_dict(model.bert.state_dict())