In [None]:
import pandas as pd
import numpy as np
import torch 
from tqdm import tqdm

In [None]:
df_train = pd.read_csv('/kaggle/input/vishd-comments-dataset/train.csv')
df_test = pd.read_csv('/kaggle/input/vishd-comments-dataset/test.csv')
df_valid = pd.read_csv('/kaggle/input/vishd-comments-dataset/dev.csv')

In [None]:
df_train.head()

# Data cleaning

### Remove na values

In [None]:
df_train.isna().sum()

In [None]:
df_train[df_train['free_text'].isna()]

In [None]:
# Just drop it 
df_train = df_train.dropna(subset=['free_text'])

### Clean out the emoji

In [None]:
import re

df_train['free_text'] = df_train['free_text'].apply(lambda x: re.sub(r'[^\w\s#@/:%.,_-]', '', x))

In [None]:
df_train['free_text'].head()

### Standardize the vietnamese text

In [None]:
# lowercase all 
df_train['free_text'] = df_train['free_text'].apply(lambda x: x.lower())

In [None]:
# separate punctuation from words
df_train['free_text'] = df_train['free_text'].apply(lambda x: re.sub(r'(?<=[^\s])\s*([^\w\s])', r' \1', x))

I choose not to remove punctuations in this case as it may represent a sentence structure that as a whole shapes an offensive or not sentence. Thus removing punctuation may disrupt the natural structure of the text and impact downstream classifying tasks

In [None]:
df_train.head()

In [None]:
# # optional, turn bad worlds into its original form
# # form the bad words dictionaries
# bad_words_txt = '../vn_offensive_words.txt'
# bad_words_dict = {}
# with open(bad_words_txt, 'r') as f:
#     bad_words = f.read().splitlines()
#     origin = ""

#     for sent in bad_words:
#         temp = sent.split(' ')

#         if (len(temp) > 1 and temp[0] == '#'):
#             origin = ' '.join(temp[1:])
#             continue
        
#         if (origin != ""):
#             bad_words_dict[sent] = origin

In [None]:
# # sorry for the bad words :(
# bad_words_dict

In [None]:
# # replace all bad words variants with its original form
# def replace_bad_words(text):
#     for bad, origin in bad_words_dict.items():
#         text = text.replace(bad, origin)
#     return text

The function might be useful later

### Check output distribution

In [None]:
df_train['label_id'].value_counts()

We have the following label:
*   0: non-offensive
*   1: Offensive
*   2: Hate 

We see here the data is imbalance.
1 and 2 are similar, differ only at its level of hate. Thus as 0s outnumber the other 2 labels, we shall merge 1 and 2

In [None]:
df_train['label_id'] = df_train['label_id'].apply(lambda x: 1 if x in [1, 2] else x)
df_train['label_id'].value_counts()

yet the data is still imbalance. If we predict 0 for all case we would have 0.82 accuracy!We shall counter it with the choice of metrics later

### Apply the same processing step for test and valid

In [None]:
# drop na
df_test = df_test.dropna(subset=['free_text'])
df_valid = df_valid.dropna(subset=['free_text'])

# clean the emoji 
df_valid['free_text'] = df_valid['free_text'].apply(lambda x: re.sub(r'[^\w\s#@/:%.,_-]', '', x))
df_test['free_text'] = df_test['free_text'].apply(lambda x: re.sub(r'[^\w\s#@/:%.,_-]', '', x))

# standardize the text
# lowercase all
df_valid['free_text'] = df_valid['free_text'].apply(lambda x: x.lower())
df_test['free_text'] = df_test['free_text'].apply(lambda x: x.lower())

# separate punctuation from words
df_valid['free_text'] = df_valid['free_text'].apply(lambda x: re.sub(r'(?<=[^\s])\s*([^\w\s])', r' \1', x))
df_test['free_text'] = df_test['free_text'].apply(lambda x: re.sub(r'(?<=[^\s])\s*([^\w\s])', r' \1', x))

In [None]:
# Merge the 1 and 2 labels
df_valid['label_id'] = df_valid['label_id'].apply(lambda x: 1 if x in [1, 2] else x)
df_test['label_id'] = df_test['label_id'].apply(lambda x: 1 if x in [1, 2] else x)

In [None]:
df_valid['label_id'].value_counts()

# Model

### Benchmark : Bag-of-words with logistic regression

We use this basic model as a simple benchmark for out task

In [None]:
!pip install pyvi

In [None]:
# Apply k fold
from sklearn.model_selection import StratifiedKFold

df_train['kfold'] = -1

df_train = df_train.sample(frac=1).reset_index(drop=True)

y = df_train['label_id'].values

kf = StratifiedKFold(n_splits=5)
for f, (t_, v_) in enumerate(kf.split(X=df_train, y=y)):
    df_train.loc[v_, 'kfold'] = f

df_train['kfold'].value_counts()

In [22]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.linear_model import LogisticRegression
from pyvi import ViTokenizer

for fold_ in range(5):
    train_df = df_train[df_train.kfold != fold_]
    valid_df = df_train[df_train.kfold == fold_]
    
    vectorizer = CountVectorizer(tokenizer=ViTokenizer.tokenize)
    vectorizer.fit(train_df['free_text'])
    
    x_train = vectorizer.transform(train_df['free_text'])
    x_valid = vectorizer.transform(valid_df['free_text'])

    y_train = train_df['label_id']
    y_valid = valid_df['label_id']
    
    model = LogisticRegression()
    model.fit(x_train, y_train)

    # threshold currently 0.5
    preds = model.predict(x_valid)
    print(f'Fold {fold_}')
    print(accuracy_score(y_valid, preds))
    print(classification_report(y_valid, preds))

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Fold 0
0.8318087318087318
              precision    recall  f1-score   support

           0       0.84      0.98      0.91      3977
           1       0.56      0.15      0.23       833

    accuracy                           0.83      4810
   macro avg       0.70      0.56      0.57      4810
weighted avg       0.79      0.83      0.79      4810



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Fold 1
0.835308796007486
              precision    recall  f1-score   support

           0       0.85      0.98      0.91      3977
           1       0.60      0.15      0.24       832

    accuracy                           0.84      4809
   macro avg       0.72      0.56      0.57      4809
weighted avg       0.80      0.84      0.79      4809



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Fold 2
0.8405073819920982
              precision    recall  f1-score   support

           0       0.85      0.98      0.91      3977
           1       0.65      0.17      0.27       832

    accuracy                           0.84      4809
   macro avg       0.75      0.58      0.59      4809
weighted avg       0.81      0.84      0.80      4809



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Fold 3
0.8369723435225619
              precision    recall  f1-score   support

           0       0.85      0.98      0.91      3977
           1       0.63      0.14      0.23       832

    accuracy                           0.84      4809
   macro avg       0.74      0.56      0.57      4809
weighted avg       0.81      0.84      0.79      4809

Fold 4
0.8351008525681015
              precision    recall  f1-score   support

           0       0.85      0.97      0.91      3977
           1       0.58      0.17      0.26       832

    accuracy                           0.84      4809
   macro avg       0.71      0.57      0.58      4809
weighted avg       0.80      0.84      0.80      4809



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Say in this problem, I think we focus on recall better, as the cost of missing some hateful comments is high (may affect children). Our recall is already very high.

Average accuracy : ~0.84

Average F1 : ~0.91

Average precision : ~0.85

Average Recall : ~0.97

**Let's now test the performance on the valid set**

In [23]:
vectorizer = CountVectorizer(tokenizer=ViTokenizer.tokenize)
vectorizer.fit(df_train['free_text'])

x_train = vectorizer.transform(df_train['free_text'])
x_valid = vectorizer.transform(df_valid['free_text'])

y_train = df_train['label_id']
y_valid = df_valid['label_id']

model = LogisticRegression()
model.fit(x_train, y_train)

# threshold currently 0.5
preds = model.predict(x_valid)
print(accuracy_score(y_valid, preds))
print(classification_report(y_valid, preds))



0.8297155688622755
              precision    recall  f1-score   support

           0       0.84      0.98      0.90      2190
           1       0.60      0.16      0.26       482

    accuracy                           0.83      2672
   macro avg       0.72      0.57      0.58      2672
weighted avg       0.80      0.83      0.79      2672



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


#### Use n-gram


Usually bad word phrases in Vietnamese goes in pair or group of 3, I have an intuition that using n-gram with bag-of-words can be useful. 

It is also easy to implement

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.linear_model import LogisticRegression
from pyvi import ViTokenizer

for fold_ in range(5):
    train_df = df_train[df_train.kfold != fold_]
    valid_df = df_train[df_train.kfold == fold_]

    vectorizer = CountVectorizer(tokenizer=ViTokenizer.tokenize, ngram_range=(1, 3))
    vectorizer.fit(train_df['free_text'])

    x_train = vectorizer.transform(train_df['free_text'])
    x_valid = vectorizer.transform(valid_df['free_text'])

    y_train = train_df['label_id']
    y_valid = valid_df['label_id']

    model = LogisticRegression()
    model.fit(x_train, y_train)

    # threshold currently 0.5
    preds = model.predict(x_valid)
    print(f'Fold {fold_}')
    print(accuracy_score(y_valid, preds))
    print(classification_report(y_valid, preds))

Quite an improvement for only a little change.  


Average accuracy : ~0.88

Average F1 : ~0.93

Average precision : ~0.91

Average Recall : ~0.95


**Let's save the model**

In [None]:
import pickle

# Save the vectorizer
with open('vectorizer.pkl', 'wb') as f:
    pickle.dump(vectorizer, f)

with open ('model.pkl', 'wb') as f:
    pickle.dump(model, f)

**Let's see how it perform on unseen examples**

In [None]:
vectorizer = CountVectorizer(tokenizer=ViTokenizer.tokenize,ngram_range=(1, 3))
vectorizer.fit(df_train['free_text'])

x_train = vectorizer.transform(df_train['free_text'])
x_valid = vectorizer.transform(df_valid['free_text'])

y_train = df_train['label_id']
y_valid = df_valid['label_id']

model = LogisticRegression()
model.fit(x_train, y_train)

preds = model.predict(x_valid)
print(accuracy_score(y_valid, preds))
print(classification_report(y_valid, preds))

Quite impressive already. Perhaps this is due to, most of the comments that are labeled offensive or hate speech, are actually based on certain bad word phrases.

### Benchmark: Stacked-LSTM

Next let's try a deep model. See how it performs on the same dataset.

**Cuda check!!!**

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#### init model

In [89]:
import torch


class LSTM(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, emb_matrix=None):
        super(LSTM, self).__init__()
        self.emb = torch.nn.Embedding(num_embeddings= len(vocab), embedding_dim= input_size,
                                      padding_idx=1, _weight=emb_matrix)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = torch.nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=True,dropout=0.1)
        self.fc = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(
            0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(
            0), self.hidden_size).to(x.device)
        
        emb = self.emb(x)
        out, _ = self.lstm(emb, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

#### Utils for training

In [21]:
def save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss
    }
    torch.save(checkpoint, filename)

def load_checkpoint(model, optimizer, filename="checkpoint.pth"):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    return model, optimizer, epoch, loss

In [22]:
def train(model, optimizer, loss_fn, train_loader, valid_loader, epochs=5, file_name="checkpoint.pth"):
    train_losses = []
    valid_losses = [] 
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        train_loss_batch =[]
        for data in pbar:
            x, y = data
            
            optimizer.zero_grad()
            y_pred = model(x)
            
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()


            train_loss += loss.item()
            pbar.set_postfix({'Train Loss': train_loss / len(train_loader)})
            train_loss_batch.append(train_loss / len(train_loader))
        
        train_losses.append(np.mean(train_loss_batch))
        
        model.eval()
        valid_loss = 0.0
        valid_loss_batch = []
        with torch.no_grad():
            for data in valid_loader:
                x, y = data
                y_pred = model(x)
                loss = loss_fn(y_pred, y)
                valid_loss += loss.item()
            valid_loss_batch.append(valid_loss / len(valid_loader))
        valid_losses.append(np.mean(valid_loss_batch))
        
        print(f'Epoch {epoch}, Train Loss: {train_losses[-1]}, Valid Loss: {valid_losses[-1]}')

    save_checkpoint(model, optimizer, epoch, loss, filename=file_name)
    return train_losses, valid_losses

    

#### Preprocess 

In [23]:
!pip install underthesea

Collecting underthesea
  Downloading underthesea-6.8.4-py3-none-any.whl.metadata (15 kB)
Collecting python-crfsuite>=0.9.6 (from underthesea)
  Downloading python_crfsuite-0.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting underthesea-core==1.0.4 (from underthesea)
  Downloading underthesea_core-1.0.4-cp310-cp310-manylinux2010_x86_64.whl.metadata (1.7 kB)
Downloading underthesea-6.8.4-py3-none-any.whl (20.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.9/20.9 MB[0m [31m62.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading underthesea_core-1.0.4-cp310-cp310-manylinux2010_x86_64.whl (657 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m657.8/657.8 kB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading python_crfsuite-0.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m45.0

In [24]:
MAX_LENGTH = 150
pad_token = "<pad>"
unk_token = "<unk>"
def pad_tokens(tokens):
    if (len(tokens) >= MAX_LENGTH):
        return tokens[:MAX_LENGTH]
    else:
        return tokens + [pad_token] * (MAX_LENGTH - len(tokens))

In [25]:
from underthesea import word_tokenize
from torchtext.vocab import build_vocab_from_iterator


def yield_tokens(df_series):
    for text in df_series:
        yield word_tokenize(text)

vocab = build_vocab_from_iterator(yield_tokens(df_train['free_text']), specials=[pad_token, unk_token])

vocab.set_default_index(vocab[unk_token])


#### Create train vs valid loader

In [26]:
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    def __init__(self, df, vocab):
        self.df = df
        self.vocab = vocab

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = row['free_text']
        padded_tokens = pad_tokens(word_tokenize(text))
        ids = torch.tensor(vocab.lookup_indices(padded_tokens))
        y = row['label_id']
        return ids, torch.tensor([y], dtype=torch.float32)

In [27]:
train_ds = TextDataset(df_train, vocab)
valid_ds = TextDataset(df_valid, vocab)

In [28]:
from torch.utils.data.dataloader import default_collate

BATCH_SIZE = 8
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
                      shuffle=True,
                      collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
val_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE*2,
                    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

#### The training process

In [33]:
model = LSTM(input_size=128, hidden_size=256, num_layers=2, num_classes=1).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

train_losses, valid_losses = train(model, optimizer, loss_fn, train_dl, val_dl, epochs=50, file_name="checkpoint.pth")

Epoch 0:   5%|▍         | 146/3006 [00:03<01:15, 38.03it/s, Train Loss=0.0224]


KeyboardInterrupt: 

#### graph out

In [None]:
import matplotlib.plt as plt

# Create the plot
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')   

plt.ylabel('Loss')
plt.legend()
plt.show()   

The valid loss is quite high.

I believe this is  due to the large amount of unknown words in the valid sets not seen during training, and we are using fairly simple technique of embedding and thus not really prepare for unknown words.

#### Using FastText for better text representation

Based on the performance of the original lstm model, let's try using fasttext, which is a static text representation model that has support for Vietnamese,to see if it will improve the performance.

In [29]:
!pip install fasttext



In [32]:
import fasttext 
ft = fasttext.load_model('/kaggle/input/fasttext-vietnamese-word-vectors-full/cc.vi.300.bin')
ft.get_dimension()

300

In [69]:
vocab_size = len(vocab)
print(vocab_size)
embedding_matrix = np.random.random((vocab_size, 300))
embedding_vector = np.zeros(300)
for voc in tqdm(range(len(vocab))):
    word = vocab.get_itos()[voc]
    i = vocab.get_stoi()[word]
    try:
        embedding_vector = ft.get_word_vector(word)
    except:
        print(word, 'not found')
    if embedding_vector is not None:
        embedding_matrix[i, :] = embedding_vector

24346


100%|██████████| 24346/24346 [08:46<00:00, 46.22it/s]


In [77]:
embedding_matrix_tensor= torch.from_numpy(embedding_matrix).float()

In [None]:
model = LSTM(input_size=300, hidden_size=256, num_layers=2, num_classes=1, emb_matrix=embedding_matrix_tensor).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

train_losses, valid_losses = train(model, optimizer, loss_fn, train_dl, val_dl, epochs=50, file_name="checkpoint.pth")

In [67]:
len(vocab)

24346

In [68]:
embedding_matrix.shape

(24347, 300)

### PhoBert finetuning

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#### Prepare data

In [None]:
!pip install transformers

**Load Model**

In [None]:
from transformers import (
    AutoModel, AutoConfig, XLMRobertaModel,
    AutoTokenizer, AutoModelForSequenceClassification
)

input_model = AutoModelForSequenceClassification.from_pretrained("vinai/phobert-base")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

input_model.resize_token_embeddings(len(tokenizer))


In [None]:
# for name, param in input_model.named_parameters():
#     if 'classifier' not in name: # classifier layer
#         param.requires_grad = False

As we have very little data, it is important that we only train the classifier head, or the last few layer of the model to prevent overfitting. thus in this step I freeze the whole bert-model and train the classifier layer only

#### Prepare dataset

In [None]:
def tokenize(my_str, tokenizer):
    mapped_tokenize = tokenizer(my_str)

    ids = mapped_tokenize['input_ids']
    att_mask = mapped_tokenize['attention_mask']
    return ids, att_mask

In [None]:
from torch.utils.data import Dataset, DataLoader

class bert_dataset_from_df(Dataset):
    def __init__(self, df, tokenizer,max_len=150):
        self.df = df
        self.max_len = max_len
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = row['free_text']
        map_tokenized = self.tokenizer(text,padding='max_length',
                                max_length = 64, truncation=True,
                                return_tensors="pt")
        y = row['label_id']
        target = torch.tensor([1,0], dtype=torch.float32)
        if y == 0:
            target = torch.tensor([1,0], dtype=torch.float32)
        else:
            target = torch.tensor([0,1], dtype=torch.float32)
        return map_tokenized, target

        

In [None]:
# create dataloader
train_ds = bert_dataset_from_df(df_train, tokenizer)
val_ds = bert_dataset_from_df(df_valid,tokenizer)

train_dl = DataLoader(train_ds,batch_size=8, shuffle=True)
valid_dl = DataLoader(val_ds,batch_size=16)

#### Utils for Bert model

In [None]:
def bert_train(model, train_dataloader, dev_dataloader, criterion_span, optimizer_spans, device, num_epochs):
    train_losses = []
    val_losses = []
    model.to(device)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        print('Epoch: ', epoch+1)
        for texts, target in tqdm(train_dataloader):
            input_ids = texts['input_ids'].squeeze(1).to(device)
            attention_mask = texts['attention_mask'].to(device)
            targets = target.to(device)

            optimizer_spans.zero_grad()
#             preds = model(input_ids, attention_mask)
            preds = model(input_ids,token_type_ids=None,
                          attention_mask=attention_mask, labels= targets)
#             loss_span = criterion_span(preds, targets)
            loss = preds.loss
            loss.backward()

            optimizer_spans.step()
            total_loss += loss.item()
            
        
        train_losses.append(total_loss/len(train_dataloader))

        # Calculate validation loss and macro F1-score
        val_loss = 0
        for texts, target in tqdm(dev_dataloader):
            input_ids = texts['input_ids'].squeeze(1).to(device)
            attention_mask = texts['attention_mask'].to(device)
            targets = target.to(device)
            with torch.no_grad():
                preds = model(input_ids,token_type_ids=None,
                          attention_mask=attention_mask, labels= targets)
                
#                 loss_span = criterion_span(preds.squeeze(), targets)
                loss = preds.loss
                val_loss += loss #+ loss_label

        val_losses.append(val_loss/len(dev_dataloader))
    
        print(f'Epoch {epoch}, Train Loss: {train_losses[-1]}, Valid Loss: {val_losses[-1]}')

    save_checkpoint(model, optimizer_spans, epoch, train_losses, filename="bert_ckpt.pth")
    return train_losses, val_losses

In [None]:
# try training
model = input_model
criterion_span = torch.nn.BCELoss()
optimizer_spans = torch.optim.Adam(list(model.parameters()), lr=5e-6, weight_decay=1e-5)
bert_train(model, train_dl, valid_dl, criterion_span, optimizer_spans, device, 5)