In [1]:
# !pip install tensorboard
# !pip install ipywidgets widgetsnbextension pandas-profiling
import os
import math
import random
import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import StratifiedKFold
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import copy


In [2]:
!nvidia-smi

Tue Aug 31 05:58:08 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 455.23.05    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN X (Pascal)    On   | 00000000:04:00.0 Off |                  N/A |
| 41%   62C    P0    80W / 250W |      1MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN X (Pascal)    On   | 00000000:06:00.0 Off |                  N/A |
| 48%   72C    P0    77W / 250W |      1MiB / 12196MiB |      0%      Default |
|       

In [3]:
torch.cuda.is_available()

True

In [4]:
train_df = pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')

In [5]:
train_df.head()

Unnamed: 0,id,title,abstract,judgement
0,0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0
1,1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0
2,2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0
3,3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0
4,4,Prolonged shedding of SARS-CoV-2 in an elderly...,,0


In [6]:
border = len(train_df[train_df["judgement"] == 1]) / len(train_df["judgement"])
print(border)

0.023282372444280715


In [7]:
len(train_df)

27145

In [8]:
train_df.isnull().sum()

id              0
title           0
abstract     4390
judgement       0
dtype: int64

In [9]:
def clean_data(train_df):
    train_df.loc[train_df['abstract'].isnull(), 'abstract'] = train_df['title']
    return train_df

def get_texts(df):
    titles = df['title'].values.tolist()
    abstracts = df['abstract'].values.tolist()
    return titles, abstracts

def get_labels(df):
    labels = df.iloc[:, 3].values
    return labels
    
train_df = clean_data(train_df)
test_df= clean_data(test_df)
train_df.head()

Unnamed: 0,id,title,abstract,judgement
0,0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0
1,1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0
2,2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0
3,3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0
4,4,Prolonged shedding of SARS-CoV-2 in an elderly...,Prolonged shedding of SARS-CoV-2 in an elderly...,0


In [10]:
train_df.isnull().sum()
test_df.isnull().sum()

id          0
title       0
abstract    0
dtype: int64

In [11]:
# ax = sns.countplot(x='judgement', data=train_df)?

In [12]:
title = train_df['title'].tolist()
# y = [len(t.split()) for t in title]
# x = range(0, len(y))
# plt.bar(x, y)

In [13]:
abstract = train_df['abstract'].tolist()
# y = [len(a.split()) for a in abstract]
# x = range(0, len(y))
# plt.bar(x, y)

In [14]:
class Config:
    def __init__(self):
        super(Config, self).__init__()
        
        self.SEED = 42
        self.MODEL_PATH = 'allenai/scibert_scivocab_uncased'
        self.NUM_LABELS = 1
        
        self.TOKENIZER = AutoTokenizer.from_pretrained(self.MODEL_PATH)
        self.MAX_LENGTH1 = 50
        self.MAX_LENGTH2 = 350
        self.BATCH_SIZE = 32
        self.VALIDATION_SPLIT = 0.3
        
        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.FULL_FINETUNING = True
        self.LR = 2e-5
        self.OPTIMIZER = 'AdamW'
        self.CRITERION = 'BCEWithLogitsLoss'
        self.SAVE_BEST_ONLY = True
        self.N_VALIDATE_DUR_TRAIN = 1
        self.EPOCHS = 3
        
config = Config()

In [15]:
def seed_init(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
seed = config.SEED
seed_init(seed)

In [16]:
class TransformerDataset(Dataset):
    def __init__(self, df, indices, set_type=None):
        super(TransformerDataset, self).__init__()

        df = df.iloc[indices]
        self.titles, self.abstracts = get_texts(df)
        self.set_type = set_type
        if self.set_type != 'test':
            self.labels = get_labels(df)

        self.max_length1 = config.MAX_LENGTH1
        self.max_length2 = config.MAX_LENGTH2
        self.tokenizer = config.TOKENIZER

    def __len__(self):
        return len(self.titles)
    
    def __getitem__(self, index):
        tokenized_titles = self.tokenizer.encode_plus(
            self.titles[index], 
            max_length=self.max_length1,
            pad_to_max_length=True,
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )
        input_ids_titles = tokenized_titles['input_ids'].squeeze()
        attention_mask_titles = tokenized_titles['attention_mask'].squeeze()
        
        tokenized_abstracts = self.tokenizer.encode_plus(
            self.abstracts[index], 
            max_length=self.max_length2,
            pad_to_max_length=True,
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )
        input_ids_abstracts = tokenized_abstracts['input_ids'].squeeze()
        attention_mask_abstracts = tokenized_abstracts['attention_mask'].squeeze()

        if self.set_type != 'test':
            return {
                'titles': {
                    'input_ids': input_ids_titles.long(),
                    'attention_mask': attention_mask_titles.long(),
                },
                'abstracts': {
                    'input_ids': input_ids_abstracts.long(),
                    'attention_mask': attention_mask_abstracts.long(),
                },
                'labels': torch.Tensor([self.labels[index]]).float(),
            }

        return {
            'titles': {
                'input_ids': input_ids_titles.long(),
                'attention_mask': attention_mask_titles.long(),
            },
            'abstracts': {
                'input_ids': input_ids_abstracts.long(),
                'attention_mask': attention_mask_abstracts.long(),
            }
        }

In [17]:
class DualSciBert(nn.Module):
    def __init__(self):
        super(DualSciBert, self).__init__()
        self.titles_model = AutoModel.from_pretrained(config.MODEL_PATH)
        self.abstracts_model = AutoModel.from_pretrained(config.MODEL_PATH)
        self.dropout = nn.Dropout(0.25)
        self.avgpool = nn.AvgPool1d(2, 2)
        self.output = nn.Linear(768, config.NUM_LABELS)
    
    def forward(self, input_ids_titles, attention_mask_titles=None, input_ids_abstracts=None, attention_mask_abstracts=None):
        output = self.titles_model(input_ids=input_ids_titles, attention_mask=attention_mask_titles)
        titles_features = output.pooler_output
        titles_features = titles_features.unsqueeze(1)
        titles_features_pooled = self.avgpool(titles_features)
        titles_features_pooled = titles_features_pooled.squeeze(1)
        
        output = self.abstracts_model(input_ids=input_ids_abstracts, attention_mask=attention_mask_abstracts)
        abstracts_features = output.pooler_output
        abstracts_features = abstracts_features.unsqueeze(1)
        abstracts_features_pooled = self.avgpool(abstracts_features)
        abstracts_features_pooled = abstracts_features_pooled.squeeze(1)
        
        combined_features = torch.cat((titles_features_pooled, abstracts_features_pooled), dim=1)
        x = self.dropout(combined_features)
        x = self.output(x)
        
        return x

In [18]:
def val(model, val_dataloader, criterion):
    val_loss = 0
    true, pred = [], []
    
    model.eval()
    
    for step, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
        b_input_ids_titles = batch['titles']['input_ids'].to(device)
        b_attention_mask_titles = batch['titles']['attention_mask'].to(device)
        b_input_ids_abstracts = batch['abstracts']['input_ids'].to(device)
        b_attention_mask_abstracts = batch['abstracts']['attention_mask'].to(device)
        b_labels = batch['labels'].to(device)
        
        with torch.no_grad():
            logits = model(b_input_ids_titles, b_attention_mask_titles, b_input_ids_abstracts, b_attention_mask_abstracts)
            loss = criterion(logits, b_labels)
            val_loss += loss.item()
            
            logits = torch.sigmoid(logits)
            logits = np.where(logits.to('cpu').detach().numpy().copy() < border, 0, 1)
            labels = b_labels.to('cpu').detach().numpy().copy()
            
            pred.extend(logits)
            true.extend(labels)
        
    avg_val_loss = val_loss / len(val_dataloader)
    print('Val loss:', avg_val_loss)
    print('Val accuracy:', accuracy_score(true, pred))
    
    val_micro_f1_score = f1_score(true, pred, average='micro')
    print('Val micro f1 score:', val_micro_f1_score)
    return val_micro_f1_score

In [19]:
def train(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, epoch):
    nv = config.N_VALIDATE_DUR_TRAIN
    temp = len(train_dataloader) // nv
    temp = temp - (temp%100)
    validate_at_steps = [temp * x for x in range(1, nv+1)]
    
    train_loss = 0
    true, pred = [], []
    for step, batch in tqdm(enumerate(train_dataloader), desc='Epoch ' + str(epoch), total=len(train_dataloader)):
        
        model.train()
        b_input_ids_titles = batch['titles']['input_ids'].to(device)
        b_attention_mask_titles = batch['titles']['attention_mask'].to(device)
        b_input_ids_abstracts = batch['abstracts']['input_ids'].to(device)
        b_attention_mask_abstracts = batch['abstracts']['attention_mask'].to(device)
        b_labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        logits = model(b_input_ids_titles, b_attention_mask_titles, b_input_ids_abstracts, b_attention_mask_abstracts)
#         print("logits=",logits)
        loss = criterion(logits, b_labels)
        train_loss += loss.item()
#         print(train_loss)
        
        logits = torch.sigmoid(logits)
        logits = np.where(logits.to('cpu').detach().numpy().copy() < border, 0, 1)
        labels = b_labels.to('cpu').detach().numpy().copy()
            
        pred.extend(logits)
        true.extend(labels)
            
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        if step in validate_at_steps:
            print(f'-- Step: {step}')
            _ = val(model, val_dataloader, criterion)
            
    avg_train_loss = train_loss / len(train_dataloader)
    accuracy = accuracy_score(true, pred)
    print('Training loss:', avg_train_loss)
    print('Training accuracy:', accuracy_score)
    return avg_train_loss, accuracy

In [20]:
def run(train_dataloader, val_dataloader, writer):
    criterion = nn.BCEWithLogitsLoss()
    
    if config.FULL_FINETUNING:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.001,
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = optim.AdamW(optimizer_parameters, lr=config.LR)
    
    num_training_steps = len(train_dataloader) * config.EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
    
    max_val_micro_f1_score = float('-inf')
    for epoch in range(config.EPOCHS):
        avg_train_loss, accuracy = train(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, epoch)
        val_micro_f1_score = val(model, val_dataloader, criterion)
        
        writer.add_scalar('train_loss', avg_train_loss, epoch+1)
        writer.add_scalar('accuracy', accuracy)
        
        if config.SAVE_BEST_ONLY:
            if val_micro_f1_score > max_val_micro_f1_score:
                best_model = copy.deepcopy(model)
                best_val_micro_f1_score = val_micro_f1_score
                
    return best_model, best_val_micro_f1_score

In [21]:
def cross_val():
    Fold = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
    max_val_micro_f1_score = float('-inf')
    
    for n, (train_indices, val_indices) in enumerate(Fold.split(train_df, train_df['judgement'])):
        print(f'========= fold: {n} training =========')
        
        log_dir = 'logs/fold'+str(n)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        
        writer = SummaryWriter(log_dir=log_dir)

        train_data = TransformerDataset(train_df, train_indices)
        val_data = TransformerDataset(train_df, val_indices)
        
        train_dataloader = DataLoader(train_data, batch_size=config.BATCH_SIZE)
        val_dataloader = DataLoader(val_data, batch_size=config.BATCH_SIZE)
                
        fold_best_model, fold_best_val_micro_f1_score = run(train_dataloader, val_dataloader, writer)
        
        if config.SAVE_BEST_ONLY:
            if fold_best_val_micro_f1_score > max_val_micro_f1_score:
                best_model = fold_best_model
                best_val_micro_f1_score = fold_best_val_micro_f1_score
                
                model_name = 'scibert_dualinput_best_model'
                torch.save(best_model.state_dict(), model_name+'.pt')
        
        writer.close()
                
    return best_model, best_val_micro_f1_score

In [22]:
device = config.DEVICE
device

device(type='cuda')

In [23]:
!jupyter nbextension enable --py --sys-prefix widgetsnbextension
torch.cuda.empty_cache()
from tqdm import tqdm
model = DualSciBert()
model=nn.DataParallel(model)
model.to(device)
best_model, best_val_micro_f1_score = cross_val()

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing Ber



Epoch 0:  94%|██████████████████████████████  | 400/425 [05:46<00:20,  1.23it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:02,  2.32it/s][A
  0%|▏                                          | 2/425 [00:00<02:59,  2.35it/s][A
  1%|▎                                          | 3/425 [00:01<03:10,  2.22it/s][A
  1%|▍                                          | 4/425 [00:01<03:14,  2.17it/s][A
  1%|▌                                          | 5/425 [00:02<03:13,  2.17it/s][A
  1%|▌                                          | 6/425 [00:02<03:12,  2.18it/s][A
  2%|▋                                          | 7/425 [00:03<03:23,  2.05it/s][A
  2%|▊                                          | 8/425 [00:03<03:19,  2.09it/s][A
  2%|▉                                          | 9/425 [00:04<03:17,  2.11it/s][A
  2%|▉                                         | 10/425 [00:04<03:14,  2.13it/s][A
  3%|█                                         | 11/425 [00:05<03:23,  2.03

 23%|█████████▌                                | 97/425 [00:46<02:41,  2.03it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:37,  2.08it/s][A
 23%|█████████▊                                | 99/425 [00:47<02:35,  2.10it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:43,  1.99it/s][A
 24%|█████████▋                               | 101/425 [00:48<02:40,  2.02it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:35,  2.08it/s][A
 24%|█████████▉                               | 103/425 [00:49<02:32,  2.12it/s][A
 24%|██████████                               | 104/425 [00:50<02:39,  2.02it/s][A
 25%|██████████▏                              | 105/425 [00:50<02:35,  2.06it/s][A
 25%|██████████▏                              | 106/425 [00:50<02:32,  2.09it/s][A
 25%|██████████▎                              | 107/425 [00:51<02:29,  2.12it/s][A
 25%|██████████▍                              | 108/425 [00:51<02:35,  2.03i

 46%|██████████████████▋                      | 194/425 [01:33<01:50,  2.09it/s][A
 46%|██████████████████▊                      | 195/425 [01:33<01:48,  2.11it/s][A
 46%|██████████████████▉                      | 196/425 [01:34<01:47,  2.13it/s][A
 46%|███████████████████                      | 197/425 [01:34<01:54,  2.00it/s][A
 47%|███████████████████                      | 198/425 [01:35<01:50,  2.06it/s][A
 47%|███████████████████▏                     | 199/425 [01:35<01:47,  2.11it/s][A
 47%|███████████████████▎                     | 200/425 [01:36<01:46,  2.12it/s][A
 47%|███████████████████▍                     | 201/425 [01:36<01:50,  2.03it/s][A
 48%|███████████████████▍                     | 202/425 [01:37<01:47,  2.08it/s][A
 48%|███████████████████▌                     | 203/425 [01:37<01:46,  2.09it/s][A
 48%|███████████████████▋                     | 204/425 [01:38<01:45,  2.10it/s][A
 48%|███████████████████▊                     | 205/425 [01:38<01:49,  2.01i

 68%|████████████████████████████             | 291/425 [02:20<01:03,  2.10it/s][A
 69%|████████████████████████████▏            | 292/425 [02:20<01:02,  2.12it/s][A
 69%|████████████████████████████▎            | 293/425 [02:21<01:02,  2.12it/s][A
 69%|████████████████████████████▎            | 294/425 [02:21<01:04,  2.02it/s][A
 69%|████████████████████████████▍            | 295/425 [02:22<01:03,  2.05it/s][A
 70%|████████████████████████████▌            | 296/425 [02:22<01:01,  2.08it/s][A
 70%|████████████████████████████▋            | 297/425 [02:23<01:00,  2.13it/s][A
 70%|████████████████████████████▋            | 298/425 [02:23<01:02,  2.02it/s][A
 70%|████████████████████████████▊            | 299/425 [02:24<01:01,  2.05it/s][A
 71%|████████████████████████████▉            | 300/425 [02:24<00:59,  2.11it/s][A
 71%|█████████████████████████████            | 301/425 [02:25<00:57,  2.14it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:25<01:00,  2.04i

 91%|█████████████████████████████████████▍   | 388/425 [03:07<00:18,  2.05it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:07<00:17,  2.09it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:08<00:16,  2.10it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:08<00:16,  2.00it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:09<00:16,  2.05it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:09<00:15,  2.08it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:10<00:14,  2.09it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:10<00:14,  2.01it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:11<00:14,  2.05it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:11<00:13,  2.09it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:11<00:12,  2.11it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:12<00:12,  2.02i

Val loss: 0.04729852896238513
Val accuracy: 0.9045163191630443
Val micro f1 score: 0.9045163191630443


Epoch 0: 100%|████████████████████████████████| 425/425 [09:30<00:00,  1.34s/it]


Training loss: 0.07876459090696539
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:24<00:00,  2.08it/s]


Val loss: 0.04857351439950221
Val accuracy: 0.8692993442864511
Val micro f1 score: 0.8692993442864511


Epoch 1:  94%|██████████████████████████████  | 400/425 [05:30<00:20,  1.22it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:11,  2.22it/s][A
  0%|▏                                          | 2/425 [00:00<03:16,  2.15it/s][A
  1%|▎                                          | 3/425 [00:01<03:15,  2.16it/s][A
  1%|▍                                          | 4/425 [00:01<03:34,  1.96it/s][A
  1%|▌                                          | 5/425 [00:02<03:27,  2.03it/s][A
  1%|▌                                          | 6/425 [00:02<03:22,  2.07it/s][A
  2%|▋                                          | 7/425 [00:03<03:20,  2.09it/s][A
  2%|▊                                          | 8/425 [00:03<03:30,  1.98it/s][A
  2%|▉                                          | 9/425 [00:04<03:25,  2.03it/s][A
  2%|▉                                         | 10/425 [00:04<03:20,  2.07it/s][A
  3%|█                                         | 11/425 [00:05<03:17,  2.10

 23%|█████████▌                                | 97/425 [00:47<02:38,  2.07it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:44,  1.99it/s][A
 23%|█████████▊                                | 99/425 [00:48<02:34,  2.11it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:34,  2.11it/s][A
 24%|█████████▋                               | 101/425 [00:49<02:42,  2.00it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:38,  2.03it/s][A
 24%|█████████▉                               | 103/425 [00:50<02:35,  2.08it/s][A
 24%|██████████                               | 104/425 [00:50<02:32,  2.11it/s][A
 25%|██████████▏                              | 105/425 [00:51<02:39,  2.01it/s][A
 25%|██████████▏                              | 106/425 [00:51<02:34,  2.06it/s][A
 25%|██████████▎                              | 107/425 [00:52<02:32,  2.09it/s][A
 25%|██████████▍                              | 108/425 [00:52<02:29,  2.13i

 46%|██████████████████▋                      | 194/425 [01:34<01:47,  2.14it/s][A
 46%|██████████████████▊                      | 195/425 [01:35<01:53,  2.04it/s][A
 46%|██████████████████▉                      | 196/425 [01:35<01:50,  2.07it/s][A
 46%|███████████████████                      | 197/425 [01:35<01:47,  2.12it/s][A
 47%|███████████████████                      | 198/425 [01:36<01:53,  2.00it/s][A
 47%|███████████████████▏                     | 199/425 [01:36<01:48,  2.08it/s][A
 47%|███████████████████▎                     | 200/425 [01:37<01:47,  2.09it/s][A
 47%|███████████████████▍                     | 201/425 [01:37<01:45,  2.13it/s][A
 48%|███████████████████▍                     | 202/425 [01:38<01:50,  2.02it/s][A
 48%|███████████████████▌                     | 203/425 [01:38<01:47,  2.06it/s][A
 48%|███████████████████▋                     | 204/425 [01:39<01:45,  2.09it/s][A
 48%|███████████████████▊                     | 205/425 [01:39<01:44,  2.11i

 68%|████████████████████████████             | 291/425 [02:21<01:03,  2.11it/s][A
 69%|████████████████████████████▏            | 292/425 [02:22<01:06,  2.00it/s][A
 69%|████████████████████████████▎            | 293/425 [02:22<01:04,  2.05it/s][A
 69%|████████████████████████████▎            | 294/425 [02:23<01:03,  2.08it/s][A
 69%|████████████████████████████▍            | 295/425 [02:23<01:01,  2.11it/s][A
 70%|████████████████████████████▌            | 296/425 [02:24<01:04,  2.01it/s][A
 70%|████████████████████████████▋            | 297/425 [02:24<01:02,  2.06it/s][A
 70%|████████████████████████████▋            | 298/425 [02:25<01:00,  2.10it/s][A
 70%|████████████████████████████▊            | 299/425 [02:25<01:03,  1.99it/s][A
 71%|████████████████████████████▉            | 300/425 [02:26<01:00,  2.05it/s][A
 71%|█████████████████████████████            | 301/425 [02:26<00:59,  2.08it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:26<00:58,  2.11i

 91%|█████████████████████████████████████▍   | 388/425 [03:09<00:17,  2.08it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:09<00:18,  1.98it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:10<00:17,  2.02it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:10<00:16,  2.05it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:11<00:15,  2.09it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:11<00:16,  2.00it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:12<00:15,  2.03it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:12<00:14,  2.07it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:13<00:14,  1.97it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:13<00:13,  2.03it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:14<00:13,  2.06it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:14<00:12,  2.08i

Val loss: 0.04626357858096633
Val accuracy: 0.9657408089589626
Val micro f1 score: 0.9657408089589626


Epoch 1: 100%|████████████████████████████████| 425/425 [09:17<00:00,  1.31s/it]


Training loss: 0.03863668737248244
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:25<00:00,  2.07it/s]


Val loss: 0.04349890100057511
Val accuracy: 0.950637294629043
Val micro f1 score: 0.950637294629043


Epoch 2:  94%|██████████████████████████████  | 400/425 [05:31<00:20,  1.21it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:02,  2.32it/s][A
  0%|▏                                          | 2/425 [00:00<03:33,  1.98it/s][A
  1%|▎                                          | 3/425 [00:01<03:26,  2.04it/s][A
  1%|▍                                          | 4/425 [00:01<03:24,  2.06it/s][A
  1%|▌                                          | 5/425 [00:02<03:20,  2.10it/s][A
  1%|▌                                          | 6/425 [00:02<03:32,  1.97it/s][A
  2%|▋                                          | 7/425 [00:03<03:26,  2.03it/s][A
  2%|▊                                          | 8/425 [00:03<03:21,  2.07it/s][A
  2%|▉                                          | 9/425 [00:04<03:21,  2.06it/s][A
  2%|▉                                         | 10/425 [00:04<03:31,  1.96it/s][A
  3%|█                                         | 11/425 [00:05<03:26,  2.00

 23%|█████████▌                                | 97/425 [00:47<02:35,  2.11it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:32,  2.14it/s][A
 23%|█████████▊                                | 99/425 [00:48<02:41,  2.02it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:37,  2.06it/s][A
 24%|█████████▋                               | 101/425 [00:49<02:35,  2.09it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:32,  2.11it/s][A
 24%|█████████▉                               | 103/425 [00:50<02:41,  1.99it/s][A
 24%|██████████                               | 104/425 [00:50<02:38,  2.02it/s][A
 25%|██████████▏                              | 105/425 [00:51<02:34,  2.07it/s][A
 25%|██████████▏                              | 106/425 [00:51<02:31,  2.10it/s][A
 25%|██████████▎                              | 107/425 [00:52<02:37,  2.02it/s][A
 25%|██████████▍                              | 108/425 [00:52<02:33,  2.06i

 46%|██████████████████▋                      | 194/425 [01:34<01:51,  2.07it/s][A
 46%|██████████████████▊                      | 195/425 [01:34<01:49,  2.11it/s][A
 46%|██████████████████▉                      | 196/425 [01:35<01:54,  2.01it/s][A
 46%|███████████████████                      | 197/425 [01:35<01:50,  2.06it/s][A
 47%|███████████████████                      | 198/425 [01:36<01:46,  2.14it/s][A
 47%|███████████████████▏                     | 199/425 [01:36<01:42,  2.20it/s][A
 47%|███████████████████▎                     | 200/425 [01:37<01:48,  2.07it/s][A
 47%|███████████████████▍                     | 201/425 [01:37<01:46,  2.10it/s][A
 48%|███████████████████▍                     | 202/425 [01:38<01:45,  2.12it/s][A
 48%|███████████████████▌                     | 203/425 [01:38<01:44,  2.13it/s][A
 48%|███████████████████▋                     | 204/425 [01:39<01:49,  2.01it/s][A
 48%|███████████████████▊                     | 205/425 [01:39<01:47,  2.05i

 68%|████████████████████████████             | 291/425 [02:21<01:04,  2.09it/s][A
 69%|████████████████████████████▏            | 292/425 [02:21<01:03,  2.11it/s][A
 69%|████████████████████████████▎            | 293/425 [02:22<01:06,  2.00it/s][A
 69%|████████████████████████████▎            | 294/425 [02:22<01:04,  2.04it/s][A
 69%|████████████████████████████▍            | 295/425 [02:23<01:03,  2.06it/s][A
 70%|████████████████████████████▌            | 296/425 [02:23<01:01,  2.10it/s][A
 70%|████████████████████████████▋            | 297/425 [02:24<01:03,  2.00it/s][A
 70%|████████████████████████████▋            | 298/425 [02:24<01:02,  2.04it/s][A
 70%|████████████████████████████▊            | 299/425 [02:25<01:00,  2.08it/s][A
 71%|████████████████████████████▉            | 300/425 [02:25<00:59,  2.12it/s][A
 71%|█████████████████████████████            | 301/425 [02:26<01:01,  2.01it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:26<01:00,  2.05i

 91%|█████████████████████████████████████▍   | 388/425 [03:08<00:17,  2.06it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:09<00:17,  2.11it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:09<00:17,  2.00it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:10<00:16,  2.04it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:10<00:15,  2.08it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:11<00:15,  2.08it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:11<00:15,  1.97it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:12<00:14,  2.03it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:12<00:14,  2.05it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:13<00:13,  2.09it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:13<00:13,  1.99it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:14<00:12,  2.03i

Val loss: 0.04454499962547904
Val accuracy: 0.9541737272526339
Val micro f1 score: 0.9541737272526339


Epoch 2: 100%|████████████████████████████████| 425/425 [09:17<00:00,  1.31s/it]


Training loss: 0.019156174409504543
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:25<00:00,  2.06it/s]


Val loss: 0.044332074772648736
Val accuracy: 0.9560892949237457
Val micro f1 score: 0.9560892949237457


Epoch 0:  94%|██████████████████████████████  | 400/425 [05:29<00:19,  1.28it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:20,  2.12it/s][A
  0%|▏                                          | 2/425 [00:00<03:17,  2.14it/s][A
  1%|▎                                          | 3/425 [00:01<03:36,  1.95it/s][A
  1%|▍                                          | 4/425 [00:01<03:29,  2.01it/s][A
  1%|▌                                          | 5/425 [00:02<03:22,  2.07it/s][A
  1%|▌                                          | 6/425 [00:02<03:19,  2.10it/s][A
  2%|▋                                          | 7/425 [00:03<03:31,  1.97it/s][A
  2%|▊                                          | 8/425 [00:03<03:25,  2.03it/s][A
  2%|▉                                          | 9/425 [00:04<03:22,  2.06it/s][A
  2%|▉                                         | 10/425 [00:04<03:18,  2.09it/s][A
  3%|█                                         | 11/425 [00:05<03:30,  1.96

 23%|█████████▌                                | 97/425 [00:47<02:43,  2.00it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:40,  2.04it/s][A
 23%|█████████▊                                | 99/425 [00:48<02:37,  2.07it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:45,  1.96it/s][A
 24%|█████████▋                               | 101/425 [00:49<02:41,  2.01it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:39,  2.03it/s][A
 24%|█████████▉                               | 103/425 [00:50<02:36,  2.05it/s][A
 24%|██████████                               | 104/425 [00:50<02:44,  1.96it/s][A
 25%|██████████▏                              | 105/425 [00:51<02:38,  2.01it/s][A
 25%|██████████▏                              | 106/425 [00:51<02:35,  2.05it/s][A
 25%|██████████▎                              | 107/425 [00:52<02:34,  2.06it/s][A
 25%|██████████▍                              | 108/425 [00:52<02:44,  1.93i

 46%|██████████████████▋                      | 194/425 [01:34<01:54,  2.01it/s][A
 46%|██████████████████▊                      | 195/425 [01:35<01:52,  2.04it/s][A
 46%|██████████████████▉                      | 196/425 [01:35<01:50,  2.08it/s][A
 46%|███████████████████                      | 197/425 [01:36<01:55,  1.97it/s][A
 47%|███████████████████                      | 198/425 [01:36<01:52,  2.02it/s][A
 47%|███████████████████▏                     | 199/425 [01:37<01:49,  2.06it/s][A
 47%|███████████████████▎                     | 200/425 [01:37<01:47,  2.08it/s][A
 47%|███████████████████▍                     | 201/425 [01:38<01:52,  1.99it/s][A
 48%|███████████████████▍                     | 202/425 [01:38<01:49,  2.03it/s][A
 48%|███████████████████▌                     | 203/425 [01:39<01:47,  2.07it/s][A
 48%|███████████████████▋                     | 204/425 [01:39<01:44,  2.11it/s][A
 48%|███████████████████▊                     | 205/425 [01:40<01:49,  2.01i

 68%|████████████████████████████             | 291/425 [02:21<01:06,  2.00it/s][A
 69%|████████████████████████████▏            | 292/425 [02:22<01:05,  2.04it/s][A
 69%|████████████████████████████▎            | 293/425 [02:22<01:03,  2.06it/s][A
 69%|████████████████████████████▎            | 294/425 [02:23<01:07,  1.93it/s][A
 69%|████████████████████████████▍            | 295/425 [02:23<01:05,  2.00it/s][A
 70%|████████████████████████████▌            | 296/425 [02:24<01:02,  2.05it/s][A
 70%|████████████████████████████▋            | 297/425 [02:24<01:01,  2.09it/s][A
 70%|████████████████████████████▋            | 298/425 [02:25<01:03,  1.99it/s][A
 70%|████████████████████████████▊            | 299/425 [02:25<01:02,  2.03it/s][A
 71%|████████████████████████████▉            | 300/425 [02:26<01:00,  2.06it/s][A
 71%|█████████████████████████████            | 301/425 [02:26<00:59,  2.10it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:27<01:01,  2.00i

 91%|█████████████████████████████████████▍   | 388/425 [03:09<00:18,  2.03it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:09<00:17,  2.07it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:09<00:16,  2.10it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:10<00:16,  2.00it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:10<00:16,  2.05it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:11<00:15,  2.09it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:11<00:14,  2.11it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:12<00:14,  2.01it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:12<00:14,  2.06it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:13<00:13,  2.09it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:13<00:12,  2.12it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:14<00:12,  2.00i

Val loss: 0.024854432925143662
Val accuracy: 0.976790450928382
Val micro f1 score: 0.976790450928382


Epoch 0: 100%|████████████████████████████████| 425/425 [09:16<00:00,  1.31s/it]


Training loss: 0.04460967604906353
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:25<00:00,  2.06it/s]


Val loss: 0.025915951235219836
Val accuracy: 0.9249189507810197
Val micro f1 score: 0.9249189507810197


Epoch 1:  94%|██████████████████████████████  | 400/425 [05:30<00:20,  1.23it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:41,  1.92it/s][A
  0%|▏                                          | 2/425 [00:00<03:20,  2.11it/s][A
  1%|▎                                          | 3/425 [00:01<03:12,  2.19it/s][A
  1%|▍                                          | 4/425 [00:01<03:06,  2.26it/s][A
  1%|▌                                          | 5/425 [00:02<03:12,  2.18it/s][A
  1%|▌                                          | 6/425 [00:02<03:13,  2.17it/s][A
  2%|▋                                          | 7/425 [00:03<03:14,  2.14it/s][A
  2%|▊                                          | 8/425 [00:03<03:15,  2.14it/s][A
  2%|▉                                          | 9/425 [00:04<03:26,  2.01it/s][A
  2%|▉                                         | 10/425 [00:04<03:23,  2.04it/s][A
  3%|█                                         | 11/425 [00:05<03:21,  2.06

 23%|█████████▌                                | 97/425 [00:47<02:35,  2.11it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:44,  1.99it/s][A
 23%|█████████▊                                | 99/425 [00:48<02:40,  2.03it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:37,  2.07it/s][A
 24%|█████████▋                               | 101/425 [00:48<02:35,  2.08it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:43,  1.98it/s][A
 24%|█████████▉                               | 103/425 [00:50<02:39,  2.02it/s][A
 24%|██████████                               | 104/425 [00:50<02:35,  2.07it/s][A
 25%|██████████▏                              | 105/425 [00:50<02:33,  2.08it/s][A
 25%|██████████▏                              | 106/425 [00:51<02:40,  1.99it/s][A
 25%|██████████▎                              | 107/425 [00:51<02:36,  2.03it/s][A
 25%|██████████▍                              | 108/425 [00:52<02:33,  2.06i

 46%|██████████████████▋                      | 194/425 [01:34<01:50,  2.09it/s][A
 46%|██████████████████▊                      | 195/425 [01:35<01:56,  1.97it/s][A
 46%|██████████████████▉                      | 196/425 [01:35<01:53,  2.02it/s][A
 46%|███████████████████                      | 197/425 [01:35<01:50,  2.06it/s][A
 47%|███████████████████                      | 198/425 [01:36<01:48,  2.09it/s][A
 47%|███████████████████▏                     | 199/425 [01:36<01:53,  1.99it/s][A
 47%|███████████████████▎                     | 200/425 [01:37<01:50,  2.04it/s][A
 47%|███████████████████▍                     | 201/425 [01:37<01:47,  2.09it/s][A
 48%|███████████████████▍                     | 202/425 [01:38<01:45,  2.11it/s][A
 48%|███████████████████▌                     | 203/425 [01:38<01:51,  1.99it/s][A
 48%|███████████████████▋                     | 204/425 [01:39<01:48,  2.05it/s][A
 48%|███████████████████▊                     | 205/425 [01:39<01:46,  2.07i

 68%|████████████████████████████             | 291/425 [02:22<01:04,  2.08it/s][A
 69%|████████████████████████████▏            | 292/425 [02:22<01:06,  1.99it/s][A
 69%|████████████████████████████▎            | 293/425 [02:23<01:04,  2.05it/s][A
 69%|████████████████████████████▎            | 294/425 [02:23<01:02,  2.08it/s][A
 69%|████████████████████████████▍            | 295/425 [02:24<01:01,  2.11it/s][A
 70%|████████████████████████████▌            | 296/425 [02:24<01:03,  2.02it/s][A
 70%|████████████████████████████▋            | 297/425 [02:25<01:02,  2.06it/s][A
 70%|████████████████████████████▋            | 298/425 [02:25<01:01,  2.07it/s][A
 70%|████████████████████████████▊            | 299/425 [02:25<01:00,  2.10it/s][A
 71%|████████████████████████████▉            | 300/425 [02:26<01:02,  1.99it/s][A
 71%|█████████████████████████████            | 301/425 [02:26<00:59,  2.10it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:27<00:58,  2.10i

 91%|█████████████████████████████████████▍   | 388/425 [03:09<00:17,  2.08it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:09<00:18,  1.98it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:10<00:17,  2.03it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:10<00:16,  2.06it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:11<00:15,  2.10it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:11<00:15,  2.07it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:12<00:14,  2.16it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:12<00:13,  2.16it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:13<00:13,  2.17it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:13<00:13,  2.04it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:14<00:13,  2.07it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:14<00:12,  2.10i

Val loss: 0.024172737658818198
Val accuracy: 0.9725906277630415
Val micro f1 score: 0.9725906277630415


Epoch 1: 100%|████████████████████████████████| 425/425 [09:17<00:00,  1.31s/it]


Training loss: 0.02175387929433354
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:25<00:00,  2.07it/s]


Val loss: 0.023866149242214092
Val accuracy: 0.9715590922487474
Val micro f1 score: 0.9715590922487474


Epoch 2:  94%|██████████████████████████████  | 400/425 [05:28<00:21,  1.16it/s]

-- Step: 400



  0%|                                                   | 0/425 [00:00<?, ?it/s][A
  0%|                                           | 1/425 [00:00<03:12,  2.20it/s][A
  0%|▏                                          | 2/425 [00:00<03:14,  2.17it/s][A
  1%|▎                                          | 3/425 [00:01<03:37,  1.94it/s][A
  1%|▍                                          | 4/425 [00:01<03:28,  2.02it/s][A
  1%|▌                                          | 5/425 [00:02<03:23,  2.06it/s][A
  1%|▌                                          | 6/425 [00:02<03:19,  2.10it/s][A
  2%|▋                                          | 7/425 [00:03<03:32,  1.96it/s][A
  2%|▊                                          | 8/425 [00:03<03:24,  2.04it/s][A
  2%|▉                                          | 9/425 [00:04<03:22,  2.05it/s][A
  2%|▉                                         | 10/425 [00:04<03:31,  1.97it/s][A
  3%|█                                         | 11/425 [00:05<03:27,  2.00

 23%|█████████▌                                | 97/425 [00:47<02:41,  2.03it/s][A
 23%|█████████▋                                | 98/425 [00:47<02:39,  2.05it/s][A
 23%|█████████▊                                | 99/425 [00:48<02:36,  2.08it/s][A
 24%|█████████▋                               | 100/425 [00:48<02:43,  1.99it/s][A
 24%|█████████▋                               | 101/425 [00:49<02:39,  2.03it/s][A
 24%|█████████▊                               | 102/425 [00:49<02:36,  2.06it/s][A
 24%|█████████▉                               | 103/425 [00:50<02:34,  2.08it/s][A
 24%|██████████                               | 104/425 [00:50<02:40,  2.00it/s][A
 25%|██████████▏                              | 105/425 [00:51<02:37,  2.03it/s][A
 25%|██████████▏                              | 106/425 [00:51<02:33,  2.07it/s][A
 25%|██████████▎                              | 107/425 [00:52<02:41,  1.97it/s][A
 25%|██████████▍                              | 108/425 [00:52<02:37,  2.02i

 46%|██████████████████▋                      | 194/425 [01:34<01:50,  2.10it/s][A
 46%|██████████████████▊                      | 195/425 [01:35<01:48,  2.12it/s][A
 46%|██████████████████▉                      | 196/425 [01:35<01:47,  2.14it/s][A
 46%|███████████████████                      | 197/425 [01:36<01:52,  2.02it/s][A
 47%|███████████████████                      | 198/425 [01:36<01:51,  2.04it/s][A
 47%|███████████████████▏                     | 199/425 [01:37<01:48,  2.08it/s][A
 47%|███████████████████▎                     | 200/425 [01:37<01:47,  2.10it/s][A
 47%|███████████████████▍                     | 201/425 [01:38<01:51,  2.01it/s][A
 48%|███████████████████▍                     | 202/425 [01:38<01:49,  2.04it/s][A
 48%|███████████████████▌                     | 203/425 [01:39<01:47,  2.06it/s][A
 48%|███████████████████▋                     | 204/425 [01:39<01:45,  2.10it/s][A
 48%|███████████████████▊                     | 205/425 [01:40<01:49,  2.00i

 68%|████████████████████████████             | 291/425 [02:22<01:06,  2.01it/s][A
 69%|████████████████████████████▏            | 292/425 [02:22<01:04,  2.05it/s][A
 69%|████████████████████████████▎            | 293/425 [02:23<01:03,  2.08it/s][A
 69%|████████████████████████████▎            | 294/425 [02:23<01:07,  1.95it/s][A
 69%|████████████████████████████▍            | 295/425 [02:24<01:04,  2.01it/s][A
 70%|████████████████████████████▌            | 296/425 [02:24<01:02,  2.06it/s][A
 70%|████████████████████████████▋            | 297/425 [02:25<01:01,  2.08it/s][A
 70%|████████████████████████████▋            | 298/425 [02:25<01:04,  1.97it/s][A
 70%|████████████████████████████▊            | 299/425 [02:26<01:02,  2.03it/s][A
 71%|████████████████████████████▉            | 300/425 [02:26<01:00,  2.05it/s][A
 71%|█████████████████████████████            | 301/425 [02:27<00:59,  2.09it/s][A
 71%|█████████████████████████████▏           | 302/425 [02:27<01:01,  1.99i

 91%|█████████████████████████████████████▍   | 388/425 [03:09<00:18,  2.02it/s][A
 92%|█████████████████████████████████████▌   | 389/425 [03:09<00:17,  2.05it/s][A
 92%|█████████████████████████████████████▌   | 390/425 [03:10<00:16,  2.09it/s][A
 92%|█████████████████████████████████████▋   | 391/425 [03:10<00:17,  1.99it/s][A
 92%|█████████████████████████████████████▊   | 392/425 [03:11<00:16,  2.04it/s][A
 92%|█████████████████████████████████████▉   | 393/425 [03:11<00:15,  2.08it/s][A
 93%|██████████████████████████████████████   | 394/425 [03:12<00:14,  2.14it/s][A
 93%|██████████████████████████████████████   | 395/425 [03:12<00:15,  1.98it/s][A
 93%|██████████████████████████████████████▏  | 396/425 [03:13<00:14,  2.02it/s][A
 93%|██████████████████████████████████████▎  | 397/425 [03:13<00:13,  2.08it/s][A
 94%|██████████████████████████████████████▍  | 398/425 [03:14<00:12,  2.11it/s][A
 94%|██████████████████████████████████████▍  | 399/425 [03:14<00:13,  2.00i

Val loss: 0.019876387711402082
Val accuracy: 0.9846743295019157
Val micro f1 score: 0.9846743295019157


Epoch 2: 100%|████████████████████████████████| 425/425 [09:15<00:00,  1.31s/it]


Training loss: 0.008827155504225042
Training accuracy: <function accuracy_score at 0x7fd71e6d1670>


100%|█████████████████████████████████████████| 425/425 [03:26<00:00,  2.06it/s]


Val loss: 0.01979384132068577
Val accuracy: 0.9846743295019157
Val micro f1 score: 0.9846743295019157


In [None]:
dataset_size = len(test_df)
test_indices = list(range(dataset_size))
test_data = TransformerDataset(test_df, test_indices, set_type='test')
test_dataloader = DataLoader(test_data, batch_size=config.BATCH_SIZE)


def predict(model):
    val_loss = 0
    test_pred = []
    model.eval()
    for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
#         batch = batch[0]
        b_input_ids_titles = batch['titles']['input_ids'].to(device)
        b_attention_mask_titles = batch['titles']['attention_mask'].to(device)
        b_input_ids_abstracts = batch['abstracts']['input_ids'].to(device)
        b_attention_mask_abstracts = batch['abstracts']['attention_mask'].to(device)
        
        with torch.no_grad():
            logits = model(b_input_ids_titles, b_attention_mask_titles, b_input_ids_abstracts, b_attention_mask_abstracts)
            logits = torch.sigmoid(logits)
            logits = np.where(logits.to('cpu').detach().numpy().copy() < border, 0, 1)
            test_pred.extend(logits)
    
    test_pred = np.array(test_pred)
    return test_pred

test_pred = predict(best_model)

 16%|██████▌                                 | 208/1277 [01:41<08:27,  2.11it/s]

In [None]:
def submit():
    sample_submission = pd.read_csv('data/sample_submit.csv', names=('id', 'judgement'))
    ids = sample_submission['id'].values.reshape(-1,1)
    
    merged = np.concatenate((ids, test_pred), axis=1)
    submission = pd.DataFrame(merged, columns=sample_submission.columns).astype(int)
    return submission

submission = submit()

In [None]:
submission.to_csv('output/baseline.csv', index=False, header=False)

In [None]:
#validation accuracyをtensorboardで管理
#k_foldのscoreの平均値を出す
#f1からbetaスコアに変える