In [2]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import transformers
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

def print_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total_params:          {total_params}\n"
          f"Trainable_params:      {trainable_params}\n"
          f"Non-trainable_params:  {total_params-trainable_params}")

device = torch.device('cuda:1')

train_df = pd.read_json("data/train.jsonl", lines=True)
train_df = pd.DataFrame(train_df, columns=['text', 'label'])
dev_df = pd.read_json("data/dev.jsonl", lines=True)
dev_df = pd.DataFrame(dev_df, columns=['text', 'label'])

In [3]:
class DataframeDataset(Dataset):

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.dataframe.iloc[idx]
        return {'text': row.text, 'label': row.label}

In [4]:
def test_model(model, dataframe):
    model.eval()
    preds = []
    targets = []
    loss = 0
    loop = tqdm(DataLoader(DataframeDataset(dataframe), batch_size=32))
    for batch in loop:
        output = tokenizer(batch['text'], padding=True, return_tensors='pt').to(device)
        input_ids = output['input_ids'].to(device)
        attention_mask = output['attention_mask'].to(device)
        labels = batch['label'].to(device)
        result = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss += result.loss.item()
        logits = result.logits
        preds.extend(torch.argmax(logits,axis=1).cpu().tolist())
        targets.extend(labels.cpu().tolist())
    
    loss /= len(loop)
    print("Mean loss: ", loss)
    print("Accuracy: ", accuracy_score(targets, preds))
    print(classification_report(targets, preds))
    return accuracy_score(targets, preds)

In [5]:
model_name = 'Hate-speech-CNERG/dehatebert-mono-english'
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
model = model.to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# trainable_params = ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']

# for param_name, param in model.named_parameters():
#     if param_name in trainable_params:
#         param.requires_grad = True
#     else:
#         param.requires_grad = False
        
print_parameters(model)

Total_params:          167357954
Trainable_params:      167357954
Non-trainable_params:  0


In [8]:
optimizer = transformers.AdamW(model.parameters(), lr=1e-6)

max_accuracy = test_model(model, dev_df)
max_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

for epoch in range(25):
    total_loss = 0
    loop = tqdm(DataLoader(DataframeDataset(train_df), shuffle=True, batch_size=64))
    model = model.train()
    for i, batch in enumerate(loop):
        output = tokenizer(batch['text'], padding=True, return_tensors='pt')
        input_ids = output['input_ids'].to(device)
        attention_mask = output['attention_mask'].to(device)
        labels = batch['label'].to(device)
        output = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = output.loss
        total_loss += loss.item()
        loop.set_postfix({'loss': total_loss / (i+1)})
        loss.backward()
        optimizer.step()
    
    print("Epoch: ", epoch+1)
    accuracy = test_model(model, dev_df)
    if accuracy > max_accuracy:
        max_accuracy = accuracy
        max_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

100%|██████████| 16/16 [00:00<00:00, 22.16it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.401]

Mean loss:  0.9192260634154081
Accuracy:  0.548
              precision    recall  f1-score   support

           0       0.53      0.89      0.66       250
           1       0.65      0.21      0.32       250

    accuracy                           0.55       500
   macro avg       0.59      0.55      0.49       500
weighted avg       0.59      0.55      0.49       500



100%|██████████| 133/133 [00:37<00:00,  3.54it/s, loss=0.472]
 19%|█▉        | 3/16 [00:00<00:00, 21.14it/s]

Epoch:  1


100%|██████████| 16/16 [00:00<00:00, 23.24it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.501]

Mean loss:  0.8415568955242634
Accuracy:  0.56
              precision    recall  f1-score   support

           0       0.54      0.80      0.65       250
           1       0.62      0.32      0.42       250

    accuracy                           0.56       500
   macro avg       0.58      0.56      0.53       500
weighted avg       0.58      0.56      0.53       500



100%|██████████| 133/133 [00:37<00:00,  3.55it/s, loss=0.469]
 19%|█▉        | 3/16 [00:00<00:00, 21.23it/s]

Epoch:  2


100%|██████████| 16/16 [00:00<00:00, 23.25it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.493]

Mean loss:  0.8086190409958363
Accuracy:  0.558
              precision    recall  f1-score   support

           0       0.55      0.70      0.61       250
           1       0.58      0.42      0.49       250

    accuracy                           0.56       500
   macro avg       0.56      0.56      0.55       500
weighted avg       0.56      0.56      0.55       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.481]
 19%|█▉        | 3/16 [00:00<00:00, 21.28it/s]

Epoch:  3


100%|██████████| 16/16 [00:00<00:00, 23.19it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.408]

Mean loss:  0.9952378012239933
Accuracy:  0.548
              precision    recall  f1-score   support

           0       0.53      0.91      0.67       250
           1       0.68      0.18      0.29       250

    accuracy                           0.55       500
   macro avg       0.60      0.55      0.48       500
weighted avg       0.60      0.55      0.48       500



100%|██████████| 133/133 [00:38<00:00,  3.49it/s, loss=0.494]
 19%|█▉        | 3/16 [00:00<00:00, 20.77it/s]

Epoch:  4


100%|██████████| 16/16 [00:00<00:00, 22.97it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.482]

Mean loss:  0.8403174225240946
Accuracy:  0.568
              precision    recall  f1-score   support

           0       0.55      0.79      0.65       250
           1       0.62      0.35      0.45       250

    accuracy                           0.57       500
   macro avg       0.58      0.57      0.55       500
weighted avg       0.58      0.57      0.55       500



100%|██████████| 133/133 [00:38<00:00,  3.48it/s, loss=0.52] 
 19%|█▉        | 3/16 [00:00<00:00, 21.04it/s]

Epoch:  5


100%|██████████| 16/16 [00:00<00:00, 23.17it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.352]

Mean loss:  0.8609078228473663
Accuracy:  0.546
              precision    recall  f1-score   support

           0       0.53      0.89      0.66       250
           1       0.65      0.20      0.31       250

    accuracy                           0.55       500
   macro avg       0.59      0.55      0.48       500
weighted avg       0.59      0.55      0.48       500



100%|██████████| 133/133 [00:38<00:00,  3.45it/s, loss=0.516]
 19%|█▉        | 3/16 [00:00<00:00, 21.02it/s]

Epoch:  6


100%|██████████| 16/16 [00:00<00:00, 23.08it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.451]

Mean loss:  0.7073245905339718
Accuracy:  0.572
              precision    recall  f1-score   support

           0       0.55      0.74      0.63       250
           1       0.61      0.40      0.48       250

    accuracy                           0.57       500
   macro avg       0.58      0.57      0.56       500
weighted avg       0.58      0.57      0.56       500



100%|██████████| 133/133 [00:38<00:00,  3.48it/s, loss=0.501]
 19%|█▉        | 3/16 [00:00<00:00, 20.81it/s]

Epoch:  7


100%|██████████| 16/16 [00:00<00:00, 23.19it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.514]

Mean loss:  0.9740365017205477
Accuracy:  0.512
              precision    recall  f1-score   support

           0       0.51      0.98      0.67       250
           1       0.71      0.04      0.08       250

    accuracy                           0.51       500
   macro avg       0.61      0.51      0.37       500
weighted avg       0.61      0.51      0.37       500



100%|██████████| 133/133 [00:37<00:00,  3.52it/s, loss=0.478]
 19%|█▉        | 3/16 [00:00<00:00, 20.39it/s]

Epoch:  8


100%|██████████| 16/16 [00:00<00:00, 22.91it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.486]

Mean loss:  0.8273804374039173
Accuracy:  0.596
              precision    recall  f1-score   support

           0       0.59      0.62      0.61       250
           1       0.60      0.57      0.59       250

    accuracy                           0.60       500
   macro avg       0.60      0.60      0.60       500
weighted avg       0.60      0.60      0.60       500



100%|██████████| 133/133 [00:37<00:00,  3.55it/s, loss=0.488]
 19%|█▉        | 3/16 [00:00<00:00, 20.73it/s]

Epoch:  9


100%|██████████| 16/16 [00:00<00:00, 22.98it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.48]

Mean loss:  1.0892210602760315
Accuracy:  0.576
              precision    recall  f1-score   support

           0       0.56      0.70      0.62       250
           1       0.60      0.46      0.52       250

    accuracy                           0.58       500
   macro avg       0.58      0.58      0.57       500
weighted avg       0.58      0.58      0.57       500



100%|██████████| 133/133 [00:37<00:00,  3.53it/s, loss=0.479]
 19%|█▉        | 3/16 [00:00<00:00, 21.05it/s]

Epoch:  10


100%|██████████| 16/16 [00:00<00:00, 23.13it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.356]

Mean loss:  0.9603852517902851
Accuracy:  0.554
              precision    recall  f1-score   support

           0       0.54      0.78      0.64       250
           1       0.60      0.32      0.42       250

    accuracy                           0.55       500
   macro avg       0.57      0.55      0.53       500
weighted avg       0.57      0.55      0.53       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.463]
 19%|█▉        | 3/16 [00:00<00:00, 21.02it/s]

Epoch:  11


100%|██████████| 16/16 [00:00<00:00, 22.65it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.408]

Mean loss:  0.7818790413439274
Accuracy:  0.556
              precision    recall  f1-score   support

           0       0.54      0.84      0.65       250
           1       0.63      0.28      0.38       250

    accuracy                           0.56       500
   macro avg       0.58      0.56      0.52       500
weighted avg       0.58      0.56      0.52       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.456]
 19%|█▉        | 3/16 [00:00<00:00, 20.91it/s]

Epoch:  12


100%|██████████| 16/16 [00:00<00:00, 23.14it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.415]

Mean loss:  1.0739568900316954
Accuracy:  0.552
              precision    recall  f1-score   support

           0       0.53      0.85      0.66       250
           1       0.63      0.25      0.36       250

    accuracy                           0.55       500
   macro avg       0.58      0.55      0.51       500
weighted avg       0.58      0.55      0.51       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.445]
 19%|█▉        | 3/16 [00:00<00:00, 20.92it/s]

Epoch:  13


100%|██████████| 16/16 [00:00<00:00, 23.12it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.468]

Mean loss:  1.1035843379795551
Accuracy:  0.562
              precision    recall  f1-score   support

           0       0.55      0.66      0.60       250
           1       0.58      0.46      0.51       250

    accuracy                           0.56       500
   macro avg       0.56      0.56      0.56       500
weighted avg       0.56      0.56      0.56       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.455]
 19%|█▉        | 3/16 [00:00<00:00, 21.22it/s]

Epoch:  14


100%|██████████| 16/16 [00:00<00:00, 23.25it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.62]

Mean loss:  1.1481051761657
Accuracy:  0.566
              precision    recall  f1-score   support

           0       0.54      0.84      0.66       250
           1       0.65      0.29      0.40       250

    accuracy                           0.57       500
   macro avg       0.60      0.57      0.53       500
weighted avg       0.60      0.57      0.53       500



100%|██████████| 133/133 [00:38<00:00,  3.45it/s, loss=0.419]
 19%|█▉        | 3/16 [00:00<00:00, 21.06it/s]

Epoch:  15


100%|██████████| 16/16 [00:00<00:00, 23.06it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.449]

Mean loss:  0.8206413239240646
Accuracy:  0.582
              precision    recall  f1-score   support

           0       0.59      0.56      0.57       250
           1       0.58      0.60      0.59       250

    accuracy                           0.58       500
   macro avg       0.58      0.58      0.58       500
weighted avg       0.58      0.58      0.58       500



100%|██████████| 133/133 [00:38<00:00,  3.50it/s, loss=0.432]
 19%|█▉        | 3/16 [00:00<00:00, 20.71it/s]

Epoch:  16


100%|██████████| 16/16 [00:00<00:00, 22.80it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.389]

Mean loss:  0.9060435313731432
Accuracy:  0.558
              precision    recall  f1-score   support

           0       0.54      0.84      0.65       250
           1       0.63      0.28      0.39       250

    accuracy                           0.56       500
   macro avg       0.58      0.56      0.52       500
weighted avg       0.58      0.56      0.52       500



100%|██████████| 133/133 [00:38<00:00,  3.50it/s, loss=0.427]
 19%|█▉        | 3/16 [00:00<00:00, 21.04it/s]

Epoch:  17


100%|██████████| 16/16 [00:00<00:00, 22.92it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.423]

Mean loss:  1.1133309304714203
Accuracy:  0.542
              precision    recall  f1-score   support

           0       0.52      0.92      0.67       250
           1       0.68      0.16      0.26       250

    accuracy                           0.54       500
   macro avg       0.60      0.54      0.46       500
weighted avg       0.60      0.54      0.46       500



100%|██████████| 133/133 [00:38<00:00,  3.46it/s, loss=0.422]
 19%|█▉        | 3/16 [00:00<00:00, 21.04it/s]

Epoch:  18


100%|██████████| 16/16 [00:00<00:00, 23.11it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.46]

Mean loss:  1.0514733009040356
Accuracy:  0.57
              precision    recall  f1-score   support

           0       0.56      0.66      0.61       250
           1       0.59      0.48      0.53       250

    accuracy                           0.57       500
   macro avg       0.57      0.57      0.57       500
weighted avg       0.57      0.57      0.57       500



100%|██████████| 133/133 [00:37<00:00,  3.51it/s, loss=0.419]
 19%|█▉        | 3/16 [00:00<00:00, 20.72it/s]

Epoch:  19


100%|██████████| 16/16 [00:00<00:00, 22.97it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.364]

Mean loss:  1.145182203501463
Accuracy:  0.56
              precision    recall  f1-score   support

           0       0.54      0.84      0.66       250
           1       0.64      0.28      0.39       250

    accuracy                           0.56       500
   macro avg       0.59      0.56      0.52       500
weighted avg       0.59      0.56      0.52       500



100%|██████████| 133/133 [00:37<00:00,  3.53it/s, loss=0.414]
 19%|█▉        | 3/16 [00:00<00:00, 20.71it/s]

Epoch:  20


100%|██████████| 16/16 [00:00<00:00, 22.89it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.432]

Mean loss:  1.1014114506542683
Accuracy:  0.55
              precision    recall  f1-score   support

           0       0.53      0.94      0.68       250
           1       0.72      0.16      0.27       250

    accuracy                           0.55       500
   macro avg       0.62      0.55      0.47       500
weighted avg       0.62      0.55      0.47       500



100%|██████████| 133/133 [00:38<00:00,  3.50it/s, loss=0.409]
 19%|█▉        | 3/16 [00:00<00:00, 20.99it/s]

Epoch:  21


100%|██████████| 16/16 [00:00<00:00, 23.08it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.478]

Mean loss:  0.8047413006424904
Accuracy:  0.566
              precision    recall  f1-score   support

           0       0.55      0.71      0.62       250
           1       0.59      0.42      0.49       250

    accuracy                           0.57       500
   macro avg       0.57      0.57      0.56       500
weighted avg       0.57      0.57      0.56       500



100%|██████████| 133/133 [00:38<00:00,  3.49it/s, loss=0.401]
 19%|█▉        | 3/16 [00:00<00:00, 21.00it/s]

Epoch:  22


100%|██████████| 16/16 [00:00<00:00, 23.19it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.484]

Mean loss:  1.1312827244400978
Accuracy:  0.558
              precision    recall  f1-score   support

           0       0.54      0.82      0.65       250
           1       0.62      0.29      0.40       250

    accuracy                           0.56       500
   macro avg       0.58      0.56      0.52       500
weighted avg       0.58      0.56      0.52       500



100%|██████████| 133/133 [00:38<00:00,  3.48it/s, loss=0.4]  
 19%|█▉        | 3/16 [00:00<00:00, 21.08it/s]

Epoch:  23


100%|██████████| 16/16 [00:00<00:00, 22.97it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.431]

Mean loss:  1.1484573259949684
Accuracy:  0.546
              precision    recall  f1-score   support

           0       0.53      0.84      0.65       250
           1       0.61      0.26      0.36       250

    accuracy                           0.55       500
   macro avg       0.57      0.55      0.50       500
weighted avg       0.57      0.55      0.50       500



100%|██████████| 133/133 [00:37<00:00,  3.50it/s, loss=0.375]
 19%|█▉        | 3/16 [00:00<00:00, 21.07it/s]

Epoch:  24


100%|██████████| 16/16 [00:00<00:00, 23.06it/s]
  0%|          | 0/133 [00:00<?, ?it/s, loss=0.349]

Mean loss:  0.9514499381184578
Accuracy:  0.558
              precision    recall  f1-score   support

           0       0.55      0.70      0.61       250
           1       0.58      0.42      0.48       250

    accuracy                           0.56       500
   macro avg       0.56      0.56      0.55       500
weighted avg       0.56      0.56      0.55       500



100%|██████████| 133/133 [00:37<00:00,  3.50it/s, loss=0.377]
 19%|█▉        | 3/16 [00:00<00:00, 20.79it/s]

Epoch:  25


100%|██████████| 16/16 [00:00<00:00, 22.66it/s]

Mean loss:  1.0706781912595034
Accuracy:  0.536
              precision    recall  f1-score   support

           0       0.52      0.90      0.66       250
           1       0.63      0.18      0.28       250

    accuracy                           0.54       500
   macro avg       0.57      0.54      0.47       500
weighted avg       0.57      0.54      0.47       500






In [9]:
model.load_state_dict(max_state)

<All keys matched successfully>

In [13]:
test_model(model, dev_df)

100%|██████████| 16/16 [00:00<00:00, 23.15it/s]

Mean loss:  0.8273804374039173
Accuracy:  0.596
              precision    recall  f1-score   support

           0       0.59      0.62      0.61       250
           1       0.60      0.57      0.59       250

    accuracy                           0.60       500
   macro avg       0.60      0.60      0.60       500
weighted avg       0.60      0.60      0.60       500






0.596

In [14]:
model.save_pretrained("text_model_best")