In [2]:
from typing import Dict, List

import numpy as np
import os
import pandas as pd
import argparse
import torch
import math

from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch.nn.functional as F
import sklearn
from sklearn.utils import class_weight
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import r2_score
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torch.utils.data import Dataset, DataLoader

SEED = 227


class DataPreparation(Dataset):
    
    def __init__(self, tokenizer, data, scale_init, intelligence='verb', max_length=None, if_scale=True):
        
        self.tokenizer = tokenizer
        self.data = data
        self.intell = intelligence
        self.scale = scale_init
        self.if_scale = if_scale
        
        if max_length == None:
            max_length_counted = data["text"].str.split(' ').str.len().max(axis=0)
            self.max_length = max_length_counted if max_length_counted < 512 else 512
        else:
            self.max_length = max_length


    def __len__(self):
        return len(self.data)


    def tokenize(self, text):

        tokens = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')

        return tokens


    def scaling(self, labels):
      
        scaled_target = self.scale.transform(np.array(labels).reshape(-1, 1))
        
        return scaled_target

     
    def __getitem__(self, index):
        
        source_text = self.data['text'].iloc[index]
        source = self.tokenize(source_text)

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        
        if self.if_scale:
            scaled_labels = self.scaling(self.data[self.intell])
            label = scaled_labels[index][0]
        else:
            label = self.data[self.intell].iloc[index]

        return {
            "source_ids": source_ids.to(dtype=torch.long),
            "source_mask": source_mask.to(dtype=torch.long),
            "labels":  label
        }


class BertBaseline_ResNet(nn.Module):
    
    def __init__(self, model_name, out_features, inner_feautes=256):
        super(BertBaseline_ResNet, self).__init__()

        self.bert = AutoModel.from_pretrained(model_name, return_dict=True)

        self.linear_modules = nn.ModuleList([torch.nn.Linear(self.bert.config.hidden_size, inner_feautes),
                                          torch.nn.Linear(inner_feautes, inner_feautes),
                                          torch.nn.Linear(inner_feautes, out_features)])
        
        self.dropout = nn.Dropout(0.4)
        self.layer_norm = nn.LayerNorm(inner_feautes)
        self.relu = nn.ReLU()   
    
    def forward(self, input_ids, attention_mask):

         bert = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
            )
#          x = bert[0][:, 0]  # last hidden state output
         token_embeddings = bert[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
         sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-8)
         x = sum_embeddings / sum_mask
        
         h = []
        
         for lin in self.linear_modules[:-1]:
            x = lin(x)
            h.append(x)
            x = self.layer_norm(x)
            x = self.relu(x)
            x = self.dropout(x)
        
         x = self.relu(h[-1] + h[-2])
         x = self.linear_modules[-1](x)
        
         return x
    

def initialize_scaling(data_org, intell):
    scale = StandardScaler().fit(np.array(data_org[intell]).reshape(-1, 1))
    return scale

def inverse_toorig(scaler, list_of_labels):
    inverse = scaler.inverse_transform(list_of_labels.reshape(-1, 1))
    return inverse


def train(model, data_loader, device, optimizer, criterion, n_epoch):

    print('Epoch #{}\n'.format(n_epoch+1))

    train_losses = []
    train_labels = []
    train_predictions = []   

    progress_bar = tqdm(total=math.ceil(len(data_loader.dataset)/data_loader.batch_size), 
                        desc='Epoch {}'.format(n_epoch + 1))

    model.train()

    for _, data in enumerate(data_loader, 0):


          input_ids = data["source_ids"].to(device)
          attention_mask = data["source_mask"].to(device)
          labels = data['labels'].to(device)

          optimizer.zero_grad()

          pred = model(input_ids=input_ids, attention_mask=attention_mask)
          loss = criterion(pred, labels)
              
          loss.backward()
              
          optimizer.step()

          _, predict = torch.max(pred.cpu().data, 1)
          train_losses.append(loss.item())
          train_labels.extend(labels.cpu().detach().numpy())
          train_predictions.extend(predict.cpu().detach().numpy())

          progress_bar.set_postfix(loss=np.mean(train_losses))
          progress_bar.update(1)
    
    progress_bar.update(1)
    progress_bar.close()
  
    
    print('\n\nMean Loss after epoch #{0} - {1}'.format(str(n_epoch + 1), np.mean(train_losses)))
    print('F1 score after epoch #{0} on train - {1}\n'.format(str(n_epoch + 1), f1_score(train_labels, train_predictions, average='macro')))
    print('Accuracy score after epoch #{0} on train - {1}\n'.format(str(n_epoch + 1), accuracy_score(train_labels, train_predictions)))

    print(classification_report(train_labels, train_predictions))
    
    return train_labels, train_predictions


def validating(model, data_loader, device, criterion, n_epoch):

    val_losses, val_labels, val_predictions = [], [], []

    progress_bar = tqdm(total=math.ceil(len(data_loader.dataset)/data_loader.batch_size),
                        desc='Epoch {}'.format(n_epoch + 1))

    model.eval()

    for _, data in enumerate(data_loader, 0):
          
          input_ids = data["source_ids"].to(device)
          attention_mask = data["source_mask"].to(device)
          labels = data['labels'].to(device)

          with torch.no_grad():
              pred = model(input_ids, attention_mask)

          loss = criterion(pred, labels)
          
          _, predict = torch.max(pred.cpu().data, 1)

          val_losses.append(loss.item())
          val_labels.extend(labels.cpu().detach().numpy())
          val_predictions.extend(predict.cpu().detach().numpy())

          progress_bar.set_postfix(loss=np.mean(val_losses))
          progress_bar.update(1)

    progress_bar.update(1)
    progress_bar.close()
    
    
    valid_stats.append(
        {
            'Val Loss': np.mean(val_losses)
        }
    )

    print('\n\nMean Loss after epoch #{0} - {1}'.format(str(n_epoch + 1), np.mean(val_losses)))
    print('F1 score after epoch #{0} on validation - {1}\n'.format(str(n_epoch + 1), f1_score(val_labels, val_predictions, average='macro')))
    print('Accuracy score after epoch #{0} on validation - {1}\n'.format(str(n_epoch + 1), accuracy_score(val_labels, val_predictions)))
    
    print(classification_report(val_labels, val_predictions))
    return valid_stats


def evaluate(model, train_dataset, val_dataset, device, epochs, target_value, weights):
    
    model = model.to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-6)
    criterion = nn.CrossEntropyLoss(weight=weights,reduction='mean').to(device)

    global valid_stats
    valid_stats = []
    best_valid_loss = float('inf')

    for epoch in range(epochs):
        # train
        try:
            train(model, train_dataset, device, optimizer, criterion,  epoch)
            # # validate
            validating(model, val_dataset, device, criterion, epoch)

            if valid_stats[epoch]['Val Loss'] < best_valid_loss:
                best_valid_loss = valid_stats[epoch]['Val Loss']

                name_to_save = f'model_baseline_basic_{target_value}'
                if os.path.isfile('results/'+name_to_save+'.pth'):
                    os.remove('results/'+name_to_save+'.pth')
                    torch.save(model.state_dict(), 'results/'+name_to_save+'.pth')
                else:
                    if not os.path.isdir('results'):
                        os.mkdir('results')
                    torch.save(model.state_dict(), 'results/'+name_to_save+'.pth')
#                     else:
#                         os.mkdir('results')
        except KeyboardInterrupt:
            break



In [None]:
path_to_data = '/kaggle/input/traits-no-naives/dataset_all_nlp_features_target_classes_no_naive (1).csv'
target_value = 'raven'
path_to_model = 'DeepPavlov/rubert-base-cased'
epochs = 15

dataset = pd.read_csv(path_to_data, sep='\t')

dataset = dataset[dataset.question_id != '129_Чтение текста - видео']

if target_value == 'raven':
    dataset = dataset[dataset["raven"] > 0]
if target_value == 'verb':
    dataset = dataset[dataset["verb"] > 0]
        
tokenizer = AutoTokenizer.from_pretrained(path_to_model)

intelligence = target_value+'_classes'

dataset[intelligence] = dataset[intelligence].astype(int)

dataset = dataset[dataset['N_words'] > 2]

train_data, extra_data = train_test_split(dataset, test_size=0.25,
                                        stratify=dataset[intelligence],
                                        random_state=SEED)

vaild_data, test_data = train_test_split(extra_data, test_size=0.4,
                                        stratify=extra_data[intelligence],
                                        random_state=SEED)
        
scaler = initialize_scaling(train_data, target_value)

train_dataset_data = DataPreparation(
            tokenizer=tokenizer,
            data = train_data,
            scale_init = scaler,
            intelligence = intelligence,
            max_length = 120,
            if_scale = False
        )

val_dataset_data = DataPreparation(
            tokenizer=tokenizer,
            data = vaild_data,
            scale_init = scaler,
            intelligence = intelligence,
            max_length = 120,
            if_scale = False
        )

test_dataset_data = DataPreparation(
            tokenizer=tokenizer,
            data = test_data,
            scale_init = scaler,
            intelligence = intelligence,
            max_length = 120,
            if_scale = False
        )

weights = class_weight.compute_class_weight(class_weight='balanced',classes=np.unique(dataset[intelligence]), y=dataset[intelligence].to_numpy())
wights_tensor = torch.tensor(weights,dtype=torch.float)

train_dataset = DataLoader(train_dataset_data, batch_size=8, drop_last=True, shuffle=True)
val_dataset = DataLoader(val_dataset_data, batch_size=8)
test_dataset = DataLoader(test_dataset_data, batch_size=8)
        
model = BertBaseline_ResNet(model_name=path_to_model, out_features=len(dataset[intelligence].unique()))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

evaluate(model=model, train_dataset=train_dataset, val_dataset=val_dataset, device=device, epochs=epochs, target_value=target_value, weights=wights_tensor)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch #1



Epoch 1:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #1 - 1.599271120453203
F1 score after epoch #1 on train - 0.21059650597777974

Accuracy score after epoch #1 on train - 0.31946790540540543

              precision    recall  f1-score   support

           0       0.23      0.13      0.17       810
           1       0.19      0.15      0.17       899
           2       0.46      0.50      0.48      2080
           3       0.18      0.30      0.23       759
           4       0.17      0.01      0.01       188

    accuracy                           0.32      4736
   macro avg       0.24      0.22      0.21      4736
weighted avg       0.31      0.32      0.31      4736



Epoch 1:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #1 - 1.5722229671077568
F1 score after epoch #1 on validation - 0.23500646300332306

Accuracy score after epoch #1 on validation - 0.35232067510548526

              precision    recall  f1-score   support

           0       0.30      0.36      0.33       162
           1       0.22      0.08      0.12       180
           2       0.46      0.52      0.49       416
           3       0.21      0.28      0.24       152
           4       0.00      0.00      0.00        38

    accuracy                           0.35       948
   macro avg       0.24      0.25      0.24       948
weighted avg       0.33      0.35      0.33       948



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch #2



Epoch 2:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #2 - 1.563358065445681
F1 score after epoch #2 on train - 0.231134786087318

Accuracy score after epoch #2 on train - 0.32538006756756754

              precision    recall  f1-score   support

           0       0.26      0.38      0.31       809
           1       0.22      0.12      0.15       900
           2       0.46      0.43      0.44      2081
           3       0.21      0.32      0.25       758
           4       0.00      0.00      0.00       188

    accuracy                           0.33      4736
   macro avg       0.23      0.25      0.23      4736
weighted avg       0.32      0.33      0.32      4736



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #2 - 1.5584655980102153
F1 score after epoch #2 on validation - 0.21379410711979507

Accuracy score after epoch #2 on validation - 0.3059071729957806

              precision    recall  f1-score   support

           0       0.24      0.67      0.36       162
           1       0.27      0.04      0.07       180
           2       0.48      0.29      0.36       416
           3       0.24      0.34      0.28       152
           4       0.00      0.00      0.00        38

    accuracy                           0.31       948
   macro avg       0.25      0.27      0.21       948
weighted avg       0.34      0.31      0.28       948



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch #3



Epoch 3:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #3 - 1.5452891342543267
F1 score after epoch #3 on train - 0.2333568248957958

Accuracy score after epoch #3 on train - 0.31777871621621623

              precision    recall  f1-score   support

           0       0.26      0.53      0.35       810
           1       0.19      0.08      0.12       900
           2       0.48      0.34      0.40      2078
           3       0.23      0.37      0.28       759
           4       0.08      0.01      0.01       189

    accuracy                           0.32      4736
   macro avg       0.25      0.27      0.23      4736
weighted avg       0.33      0.32      0.31      4736



Epoch 3:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #3 - 1.549458850331667
F1 score after epoch #3 on validation - 0.19512127009881106

Accuracy score after epoch #3 on validation - 0.26371308016877637

              precision    recall  f1-score   support

           0       0.26      0.60      0.36       162
           1       0.19      0.06      0.09       180
           2       0.42      0.17      0.24       416
           3       0.21      0.48      0.29       152
           4       0.00      0.00      0.00        38

    accuracy                           0.26       948
   macro avg       0.21      0.26      0.20       948
weighted avg       0.30      0.26      0.23       948



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch #4



Epoch 4:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #4 - 1.5312168491450515
F1 score after epoch #4 on train - 0.24182747999639714

Accuracy score after epoch #4 on train - 0.32263513513513514

              precision    recall  f1-score   support

           0       0.27      0.54      0.36       810
           1       0.23      0.11      0.15       899
           2       0.49      0.32      0.38      2079
           3       0.25      0.45      0.32       759
           4       0.00      0.00      0.00       189

    accuracy                           0.32      4736
   macro avg       0.25      0.28      0.24      4736
weighted avg       0.35      0.32      0.31      4736



Epoch 4:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #4 - 1.5415188164270224
F1 score after epoch #4 on validation - 0.2292765402919413

Accuracy score after epoch #4 on validation - 0.3059071729957806

              precision    recall  f1-score   support

           0       0.26      0.66      0.38       162
           1       0.20      0.10      0.13       180
           2       0.47      0.26      0.34       416
           3       0.25      0.36      0.30       152
           4       0.00      0.00      0.00        38

    accuracy                           0.31       948
   macro avg       0.24      0.28      0.23       948
weighted avg       0.33      0.31      0.29       948

Epoch #5



Epoch 5:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #5 - 1.5148441644536483
F1 score after epoch #5 on train - 0.24807635649689325

Accuracy score after epoch #5 on train - 0.32538006756756754

              precision    recall  f1-score   support

           0       0.28      0.59      0.38       809
           1       0.24      0.12      0.16       898
           2       0.48      0.30      0.37      2081
           3       0.26      0.43      0.32       759
           4       0.06      0.01      0.01       189

    accuracy                           0.33      4736
   macro avg       0.26      0.29      0.25      4736
weighted avg       0.35      0.33      0.31      4736



Epoch 5:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #5 - 1.5397659349842232
F1 score after epoch #5 on validation - 0.24259612618960605

Accuracy score after epoch #5 on validation - 0.31329113924050633

              precision    recall  f1-score   support

           0       0.27      0.65      0.38       162
           1       0.21      0.17      0.19       180
           2       0.49      0.27      0.35       416
           3       0.27      0.33      0.29       152
           4       0.00      0.00      0.00        38

    accuracy                           0.31       948
   macro avg       0.25      0.28      0.24       948
weighted avg       0.35      0.31      0.30       948

Epoch #6



Epoch 6:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #6 - 1.5015217194686066
F1 score after epoch #6 on train - 0.2727541805590225

Accuracy score after epoch #6 on train - 0.33614864864864863

              precision    recall  f1-score   support

           0       0.29      0.59      0.39       810
           1       0.26      0.16      0.19       900
           2       0.51      0.29      0.37      2080
           3       0.26      0.47      0.34       757
           4       0.29      0.04      0.07       189

    accuracy                           0.34      4736
   macro avg       0.32      0.31      0.27      4736
weighted avg       0.38      0.34      0.32      4736



Epoch 6:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #6 - 1.529074730993319
F1 score after epoch #6 on validation - 0.25327629759411074

Accuracy score after epoch #6 on validation - 0.31856540084388185

              precision    recall  f1-score   support

           0       0.29      0.58      0.38       162
           1       0.25      0.12      0.16       180
           2       0.50      0.27      0.35       416
           3       0.25      0.49      0.33       152
           4       0.20      0.03      0.05        38

    accuracy                           0.32       948
   macro avg       0.30      0.30      0.25       948
weighted avg       0.36      0.32      0.30       948

Epoch #7



Epoch 7:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #7 - 1.4810221036140983
F1 score after epoch #7 on train - 0.2762011636628466

Accuracy score after epoch #7 on train - 0.34375

              precision    recall  f1-score   support

           0       0.30      0.60      0.40       809
           1       0.27      0.16      0.20       898
           2       0.52      0.29      0.37      2082
           3       0.28      0.53      0.37       759
           4       0.15      0.03      0.05       188

    accuracy                           0.34      4736
   macro avg       0.30      0.32      0.28      4736
weighted avg       0.38      0.34      0.33      4736



Epoch 7:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #7 - 1.522920296973541
F1 score after epoch #7 on validation - 0.2608441707453277

Accuracy score after epoch #7 on validation - 0.3438818565400844

              precision    recall  f1-score   support

           0       0.28      0.56      0.37       162
           1       0.17      0.04      0.06       180
           2       0.46      0.41      0.43       416
           3       0.27      0.37      0.31       152
           4       0.27      0.08      0.12        38

    accuracy                           0.34       948
   macro avg       0.29      0.29      0.26       948
weighted avg       0.34      0.34      0.32       948

Epoch #8



Epoch 8:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #8 - 1.4623801112174988
F1 score after epoch #8 on train - 0.2954765784340211

Accuracy score after epoch #8 on train - 0.3597972972972973

              precision    recall  f1-score   support

           0       0.30      0.60      0.40       810
           1       0.30      0.16      0.21       898
           2       0.53      0.33      0.41      2080
           3       0.29      0.50      0.37       759
           4       0.15      0.06      0.09       189

    accuracy                           0.36      4736
   macro avg       0.31      0.33      0.30      4736
weighted avg       0.39      0.36      0.35      4736



Epoch 8:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #8 - 1.5230726254086535
F1 score after epoch #8 on validation - 0.25459864373557256

Accuracy score after epoch #8 on validation - 0.3048523206751055

              precision    recall  f1-score   support

           0       0.26      0.73      0.38       162
           1       0.25      0.12      0.16       180
           2       0.50      0.21      0.30       416
           3       0.27      0.38      0.31       152
           4       0.25      0.08      0.12        38

    accuracy                           0.30       948
   macro avg       0.31      0.30      0.25       948
weighted avg       0.37      0.30      0.28       948

Epoch #9



Epoch 9:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #9 - 1.4457948447482005
F1 score after epoch #9 on train - 0.31860686056159826

Accuracy score after epoch #9 on train - 0.3733108108108108

              precision    recall  f1-score   support

           0       0.32      0.61      0.42       808
           1       0.31      0.20      0.24       899
           2       0.56      0.32      0.41      2082
           3       0.32      0.53      0.40       758
           4       0.17      0.11      0.13       189

    accuracy                           0.37      4736
   macro avg       0.33      0.35      0.32      4736
weighted avg       0.41      0.37      0.37      4736



Epoch 9:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #9 - 1.5135859481426848
F1 score after epoch #9 on validation - 0.2800816817243937

Accuracy score after epoch #9 on validation - 0.3322784810126582

              precision    recall  f1-score   support

           0       0.28      0.65      0.39       162
           1       0.22      0.09      0.13       180
           2       0.51      0.31      0.38       416
           3       0.29      0.39      0.34       152
           4       0.17      0.16      0.16        38

    accuracy                           0.33       948
   macro avg       0.29      0.32      0.28       948
weighted avg       0.37      0.33      0.32       948

Epoch #10



Epoch 10:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #10 - 1.420368719644643
F1 score after epoch #10 on train - 0.32934958487685434

Accuracy score after epoch #10 on train - 0.3745777027027027

              precision    recall  f1-score   support

           0       0.32      0.63      0.43       810
           1       0.31      0.18      0.23       899
           2       0.57      0.31      0.40      2079
           3       0.32      0.56      0.40       759
           4       0.21      0.16      0.18       189

    accuracy                           0.37      4736
   macro avg       0.35      0.37      0.33      4736
weighted avg       0.42      0.37      0.36      4736



Epoch 10:   0%|          | 0/119 [00:00<?, ?it/s]



Mean Loss after epoch #10 - 1.5094299867373555
F1 score after epoch #10 on validation - 0.2873726371970286

Accuracy score after epoch #10 on validation - 0.32489451476793246

              precision    recall  f1-score   support

           0       0.28      0.62      0.38       162
           1       0.21      0.08      0.12       180
           2       0.49      0.28      0.35       416
           3       0.29      0.45      0.35       152
           4       0.23      0.24      0.23        38

    accuracy                           0.32       948
   macro avg       0.30      0.33      0.29       948
weighted avg       0.36      0.32      0.31       948

Epoch #11



Epoch 11:   0%|          | 0/593 [00:00<?, ?it/s]



Mean Loss after epoch #11 - 1.3968709302512374
F1 score after epoch #11 on train - 0.34988943594642086

Accuracy score after epoch #11 on train - 0.38640202702702703

              precision    recall  f1-score   support

           0       0.34      0.63      0.44       809
           1       0.33      0.22      0.27       898
           2       0.56      0.32      0.41      2081
           3       0.34      0.54      0.42       759
           4       0.19      0.25      0.22       189

    accuracy                           0.39      4736
   macro avg       0.35      0.39      0.35      4736
weighted avg       0.43      0.39      0.38      4736



Epoch 11:   0%|          | 0/119 [00:00<?, ?it/s]

In [11]:
def test(data_loader, device):
    
    model.load_state_dict(torch.load('/kaggle/working/results/model_baseline_basic_GP.pth'))
    
    test_labels, test_predictions = [], []

    model.eval()

    for _, data in enumerate(data_loader, 0):
          
          input_ids = data["source_ids"].to(device)
          attention_mask = data["source_mask"].to(device)
          labels = data['labels'].to(device)

          with torch.no_grad():
              pred = model(input_ids, attention_mask)
          
          _, predict = torch.max(pred.cpu().data, 1)

          test_labels.extend(labels.cpu().detach().numpy())
          test_predictions.extend(predict.cpu().detach().numpy())

    print('F1 macro score on test - {0}\n'.format(f1_score(test_labels, test_predictions, average='macro')))
    print('F1 score on test - {0}\n'.format(f1_score(test_labels, test_predictions, average='weighted')))
    print('Accuracy score on test - {0}\n'.format(accuracy_score(test_labels, test_predictions)))
    
    print(classification_report(test_labels, test_predictions))

In [12]:
test(train_dataset, device)

F1 macro score on test - 0.454995683567936

F1 score on test - 0.4539857512974617

Accuracy score on test - 0.4592257001647446

              precision    recall  f1-score   support

           0       0.47      0.39      0.43      1430
           1       0.54      0.38      0.44      1956
           2       0.41      0.64      0.50      1470

    accuracy                           0.46      4856
   macro avg       0.47      0.47      0.45      4856
weighted avg       0.48      0.46      0.45      4856



In [13]:
test(val_dataset, device)

F1 macro score on test - 0.37768992180184613

F1 score on test - 0.3721391099125639

Accuracy score on test - 0.37962962962962965

              precision    recall  f1-score   support

           0       0.41      0.37      0.39       286
           1       0.41      0.27      0.32       391
           2       0.35      0.54      0.42       295

    accuracy                           0.38       972
   macro avg       0.39      0.39      0.38       972
weighted avg       0.39      0.38      0.37       972



In [14]:
test(test_dataset, device)

F1 macro score on test - 0.33272413369477605

F1 score on test - 0.3332291179285965

Accuracy score on test - 0.33487654320987653

              precision    recall  f1-score   support

           0       0.35      0.28      0.31       191
           1       0.39      0.30      0.34       261
           2       0.29      0.44      0.35       196

    accuracy                           0.33       648
   macro avg       0.34      0.34      0.33       648
weighted avg       0.35      0.33      0.33       648

