# Direct Finetuning Yejin
FakeNewsAAAI is a Fake News dataset with 2 possible labels: `real` and `fake`

In [1]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from tqdm import tqdm
import pickle
from copy import deepcopy
from multiprocessing import Pool                                                

from transformers import BertForSequenceClassification, RobertaForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from utils.forward_fn import forward_mask_sequence_classification
from utils.metrics import classification_metrics_fn
from utils.data_utils import FakeNewsDataset, FakeNewsDataLoader
from utils.utils import generate_random_mask

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
###
# common functions
###
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
    
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def metrics_to_string(metric_dict):
    string_list = []
    for key, value in metric_dict.items():
        string_list.append('{}:{:.4f}'.format(key, value))
    return ' '.join(string_list)

In [4]:
def influence_score(model, id, subword, mask, label, device='cpu'):
    loss_fct = CrossEntropyLoss(reduction='none')
    with torch.no_grad():
        # Prepare input & label
        subword = torch.LongTensor(subword)
        mask = torch.FloatTensor(mask)
        label = torch.LongTensor(label)

        if device == "cuda":
            subword = subword.cuda()
            mask = mask.cuda()
            label = label.cuda()

        if isinstance(model, BertForSequenceClassification):
            # Apply mask
            weight, bias = model.classifier.weight, model.classifier.bias
            dropout_mask = generate_random_mask([id], weight.shape[0], weight.shape[1], device=device).repeat(subword.shape[0],1,1)
            masked_weight = weight.expand_as(dropout_mask) * dropout_mask

            # Calculate latents
            latents = model.bert(subword, attention_mask=mask)[1]
            latents = model.dropout(latents)            
        elif isinstance(model, RobertaForSequenceClassification):
            # Apply mask
            weight, bias = model.classifier.out_proj.weight, model.classifier.out_proj.bias
            dropout_mask = generate_random_mask([id], weight.shape[0], weight.shape[1], device=device).repeat(subword.shape[0],1,1)
            masked_weight = weight.expand_as(dropout_mask) * dropout_mask

            # Calculate latents
            latents = model.roberta(subword, attention_mask=mask)[0][:,0,:]
            latents = model.classifier.dense(latents)
            latents = model.classifier.dropout(latents)
        else:
            raise ValueError(f'Model class `{type(model)}` is not implemented yet')

        # Compute loss with mask
        logits = torch.einsum('bd,bcd->bc', latents, masked_weight) + bias
        mask_loss = loss_fct(logits.view(-1, model.num_labels), label.view(-1))

        # Compute loss with flipped mask
        logits = torch.einsum('bd,bcd->bc', latents, (masked_weight.max() - masked_weight)) + bias
        flipped_mask_loss = loss_fct(logits.view(-1, model.num_labels), label.view(-1))
                              
        return flipped_mask_loss - mask_loss
                              
def build_influence_matrix(model, data_loader, train_size, device='cpu'):
    test_size, batch_size = len(data_loader.dataset), data_loader.batch_size
    influence_mat = torch.zeros(test_size, train_size, device=device)
    for i, batch_data in enumerate(data_loader):
        print(f'Processing batch {i+1}/{len(data_loader)}')
        (ids, subword_batch, mask_batch, label_batch, seq_list) = batch_data
        token_type_batch = None

        for train_idx in tqdm(range(train_size)):
            train_id = train_idx + 1
            scores = influence_score(model, train_id, subword_batch, mask_batch, label_batch, device=device)
            for j, id in enumerate(ids):
                influence_mat[(i * batch_size) + j, train_idx] = scores[j]
    return influence_mat

def get_inference_result(model, data_loader, device='cpu'):
    results = {}
    with torch.no_grad():
        pbar = tqdm(data_loader, leave=True, total=len(data_loader))
        for i, batch_data in enumerate(pbar):
            batch_id = batch_data[0]
            batch_seq = batch_data[-1]
            outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=True, device='cuda')
            loss, batch_hyp, batch_label, logits, label_batch = outputs

            for i, id in enumerate(batch_id):
                results[id] = batch_hyp[i] == batch_label[i]
    return results

def get_filtered_dataloader(data_loader, id_list, inclusive=True, batch_size=2, shuffle=False):
    df = data_loader.dataset.data
    if inclusive:
        filt_df = df.loc[df['id'].isin(id_list),:].reset_index(drop=True)
    else:
        filt_df = df.loc[~df['id'].isin(id_list),:].reset_index(drop=True)
    dataset = FakeNewsDataset(dataset_path=None, dataset=filt_df, tokenizer=tokenizer, lowercase=False)
    data_loader = FakeNewsDataLoader(dataset=dataset, max_seq_len=512, batch_size=batch_size, num_workers=2, shuffle=shuffle)  
    return data_loader

In [5]:
# Set random seed
set_seed(26092020)

# Load Model

In [6]:
# Load Tokenizer and Config
tokenizer = AutoTokenizer.from_pretrained('roberta-large')
config = AutoConfig.from_pretrained('roberta-large')
config.num_labels = FakeNewsDataset.NUM_LABELS

# Instantiate model
model = AutoModelForSequenceClassification.from_pretrained('roberta-large', config=config)

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classif

In [7]:
count_param(model)

355361794

# Prepare Dataset

In [8]:
train_dataset_path = './data/train.tsv'
valid_dataset_path = './data/valid.tsv'
test_dataset_path = './data/test.tsv'

In [9]:
train_dataset = FakeNewsDataset(dataset_path=train_dataset_path, tokenizer=tokenizer, lowercase=False)
valid_dataset = FakeNewsDataset(dataset_path=valid_dataset_path, tokenizer=tokenizer, lowercase=False)
test_dataset = FakeNewsDataset(dataset_path=test_dataset_path, tokenizer=tokenizer, lowercase=False)

train_loader = FakeNewsDataLoader(dataset=train_dataset, max_seq_len=512, batch_size=2, num_workers=2, shuffle=True)  
valid_loader = FakeNewsDataLoader(dataset=valid_dataset, max_seq_len=512, batch_size=2, num_workers=2, shuffle=False)  
test_loader = FakeNewsDataLoader(dataset=test_dataset, max_seq_len=512, batch_size=2, num_workers=2, shuffle=False)

In [10]:
w2i, i2w = FakeNewsDataset.LABEL2INDEX, FakeNewsDataset.INDEX2LABEL
print(w2i)
print(i2w)

{'fake': 0, 'real': 1}
{0: 'fake', 1: 'real'}


# Fine Tuning & Evaluation

In [11]:
optimizer = optim.Adam(model.parameters(), lr=3e-6)
model = model.cuda()

In [12]:
# Train without mask
n_epochs = 25
best_val_metric, best_metrics, best_state_dict = 0, None, None
early_stop, count_stop = 3, 0
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    if best_val_metric <= val_metric:
        torch.save(model.state_dict(), './tmp_india/model_direct_ft.pt')
        best_val_metric = val_metric
        best_metrics = metrics
        count_stop = 0
    else:
        count_stop += 1
        if count_stop == early_stop:
            break
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.5230 LR:0.00000300:   7%|▋         | 234/3210 [00:40<08:16,  6.00it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 1) TRAIN LOSS:0.4223 LR:0.00000300:  12%|█▏        | 401/3210 [01:08<07:53,  5.94it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 1) TRAIN LOSS:0.1606 LR:0.00000300: 100%|██████████| 3210/3210 [09:05<00:00,  5.89it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.1606 ACC:0.9366 F1:0.9364 REC:0.9361 PRE:0.9368 LR:0.00000300


VALID LOSS:0.0673 ACC:0.9739 F1:0.9738 REC:0.9746 PRE:0.9733:  80%|████████  | 861/1070 [00:31<00:08, 23.73it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0699 ACC:0.9743 F1:0.9743 REC:0.9747 PRE:0.9740: 100%|██████████| 1070/1070 [00:40<00:00, 26.48it/s]


(Epoch 1) VALID LOSS:0.0699 ACC:0.9743 F1:0.9743 REC:0.9747 PRE:0.9740


(Epoch 2) TRAIN LOSS:0.0501 LR:0.00000300:  27%|██▋       | 861/3210 [02:25<06:21,  6.16it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0543 LR:0.00000300:  54%|█████▍    | 1732/3210 [04:53<04:20,  5.68it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0532 LR:0.00000300: 100%|██████████| 3210/3210 [09:05<00:00,  5.88it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0532 ACC:0.9807 F1:0.9806 REC:0.9806 PRE:0.9807 LR:0.00000300


  _warn_prf(average, modifier, msg_start, len(result))
VALID LOSS:0.0782 ACC:0.9704 F1:0.9703 REC:0.9712 PRE:0.9698:  80%|████████  | 860/1070 [00:32<00:08, 23.93it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0825 ACC:0.9706 F1:0.9705 REC:0.9710 PRE:0.9703: 100%|██████████| 1070/1070 [00:41<00:00, 25.56it/s]
  0%|          | 0/3210 [00:00<?, ?it/s]

(Epoch 2) VALID LOSS:0.0825 ACC:0.9706 F1:0.9705 REC:0.9710 PRE:0.9703


(Epoch 3) TRAIN LOSS:0.0310 LR:0.00000300:  16%|█▌        | 509/3210 [01:27<07:54,  5.69it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0267 LR:0.00000300:  92%|█████████▏| 2954/3210 [08:21<00:42,  5.99it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0293 LR:0.00000300: 100%|██████████| 3210/3210 [09:04<00:00,  5.90it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0293 ACC:0.9896 F1:0.9895 REC:0.9895 PRE:0.9896 LR:0.00000300


VALID LOSS:0.0674 ACC:0.9762 F1:0.9761 REC:0.9770 PRE:0.9756:  81%|████████  | 862/1070 [00:31<00:08, 23.95it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0688 ACC:0.9757 F1:0.9757 REC:0.9760 PRE:0.9754: 100%|██████████| 1070/1070 [00:40<00:00, 26.24it/s]


(Epoch 3) VALID LOSS:0.0688 ACC:0.9757 F1:0.9757 REC:0.9760 PRE:0.9754


(Epoch 4) TRAIN LOSS:0.0236 LR:0.00000300:  40%|████      | 1291/3210 [03:36<05:15,  6.09it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 4) TRAIN LOSS:0.0211 LR:0.00000300: 100%|██████████| 3210/3210 [09:02<00:00,  5.92it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0211 ACC:0.9933 F1:0.9933 REC:0.9933 PRE:0.9933 LR:0.00000300


VALID LOSS:0.0597 ACC:0.9820 F1:0.9819 REC:0.9819 PRE:0.9820:  80%|████████  | 861/1070 [00:31<00:08, 23.78it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0658 ACC:0.9804 F1:0.9803 REC:0.9802 PRE:0.9805: 100%|██████████| 1070/1070 [00:40<00:00, 26.39it/s]


(Epoch 4) VALID LOSS:0.0658 ACC:0.9804 F1:0.9803 REC:0.9802 PRE:0.9805


(Epoch 5) TRAIN LOSS:0.0206 LR:0.00000300:  24%|██▎       | 755/3210 [02:07<07:02,  5.81it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 5) TRAIN LOSS:0.0164 LR:0.00000300: 100%|██████████| 3210/3210 [09:03<00:00,  5.91it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0164 ACC:0.9947 F1:0.9947 REC:0.9947 PRE:0.9947 LR:0.00000300


VALID LOSS:0.0767 ACC:0.9797 F1:0.9796 REC:0.9800 PRE:0.9793:  81%|████████  | 862/1070 [00:32<00:09, 22.80it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0910 ACC:0.9776 F1:0.9775 REC:0.9777 PRE:0.9774: 100%|██████████| 1070/1070 [00:41<00:00, 25.81it/s]
  0%|          | 0/3210 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:0.0910 ACC:0.9776 F1:0.9775 REC:0.9777 PRE:0.9774


(Epoch 6) TRAIN LOSS:0.0075 LR:0.00000300:  38%|███▊      | 1219/3210 [03:27<05:24,  6.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 6) TRAIN LOSS:0.0130 LR:0.00000300: 100%|██████████| 3210/3210 [09:07<00:00,  5.86it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0130 ACC:0.9960 F1:0.9959 REC:0.9960 PRE:0.9959 LR:0.00000300


VALID LOSS:0.0861 ACC:0.9803 F1:0.9802 REC:0.9798 PRE:0.9805:  80%|████████  | 861/1070 [00:31<00:08, 23.67it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0909 ACC:0.9780 F1:0.9780 REC:0.9777 PRE:0.9783: 100%|██████████| 1070/1070 [00:40<00:00, 26.50it/s]
  0%|          | 0/3210 [00:00<?, ?it/s]

(Epoch 6) VALID LOSS:0.0909 ACC:0.9780 F1:0.9780 REC:0.9777 PRE:0.9783


(Epoch 7) TRAIN LOSS:0.0008 LR:0.00000300:   4%|▍         | 141/3210 [00:23<08:33,  5.97it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 7) TRAIN LOSS:0.0151 LR:0.00000300:  77%|███████▋  | 2474/3210 [06:55<02:00,  6.12it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 7) TRAIN LOSS:0.0130 LR:0.00000300: 100%|██████████| 3210/3210 [09:01<00:00,  5.93it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.0130 ACC:0.9955 F1:0.9955 REC:0.9955 PRE:0.9955 LR:0.00000300


VALID LOSS:0.0973 ACC:0.9751 F1:0.9749 REC:0.9741 PRE:0.9759:  80%|████████  | 860/1070 [00:30<00:08, 23.98it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1024 ACC:0.9743 F1:0.9742 REC:0.9737 PRE:0.9750: 100%|██████████| 1070/1070 [00:40<00:00, 26.39it/s]

(Epoch 7) VALID LOSS:0.1024 ACC:0.9743 F1:0.9742 REC:0.9737 PRE:0.9750
== BEST METRICS ==
ACC:0.9804 F1:0.9803 REC:0.9802 PRE:0.9805





In [13]:
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    valid_loss = loss.item()
    total_loss = total_loss + valid_loss

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label
    metrics = classification_metrics_fn(list_hyp, list_label)

    pbar.set_description("TEST LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print('== EVAL METRICS ==')
print(metrics_to_string(eval_metrics))

TEST LOSS:0.3857 ACC:0.9738 F1:0.9738 REC:0.9733 PRE:0.9744:  41%|████      | 440/1070 [00:21<00:36, 17.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors
TEST LOSS:0.2220 ACC:0.9748 F1:0.9747 REC:0.9743 PRE:0.9753: 100%|██████████| 1070/1070 [00:58<00:00, 18.16it/s]


== EVAL METRICS ==
ACC:0.9748 F1:0.9747 REC:0.9743 PRE:0.9753 FAKE_ACC:0.9748 FAKE_F1:0.9733 FAKE_REC:0.9632 FAKE_PRE:0.9835 REAL_ACC:0.9748 REAL_F1:0.9761 REAL_REC:0.9853 REAL_PRE:0.9671


In [12]:
# Load Tokenizer and Config
tokenizer = AutoTokenizer.from_pretrained('roberta-large')
config = AutoConfig.from_pretrained('roberta-large')
config.num_labels = FakeNewsDataset.NUM_LABELS

# Instantiate model
model = AutoModelForSequenceClassification.from_pretrained('roberta-large', config=config)

optimizer = optim.Adam(model.parameters(), lr=3e-6)
model = model.cuda()

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classif

In [13]:
# Train with mask
n_epochs = 25
best_val_metric, best_metrics, best_state_dict = 0, None, None
early_stop, count_stop = 3, 0
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=True, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=True, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    if best_val_metric <= val_metric:
        torch.save(model.state_dict(), './tmp_india/model_weight.pt')
        best_val_metric = val_metric
        best_metrics = metrics
        count_stop = 0
    else:
        count_stop += 1
        if count_stop == early_stop:
            break
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.3705 LR:0.00000300:  29%|██▊       | 916/3210 [02:26<06:06,  6.27it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 1) TRAIN LOSS:0.1977 LR:0.00000300: 100%|██████████| 3210/3210 [08:38<00:00,  6.19it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.1977 ACC:0.9174 F1:0.9172 REC:0.9172 PRE:0.9173 LR:0.00000300


VALID LOSS:0.0708 ACC:0.9751 F1:0.9749 REC:0.9746 PRE:0.9753:  80%|████████  | 861/1070 [00:33<00:08, 23.85it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0743 ACC:0.9743 F1:0.9742 REC:0.9740 PRE:0.9746: 100%|██████████| 1070/1070 [00:42<00:00, 25.24it/s]


(Epoch 1) VALID LOSS:0.0743 ACC:0.9743 F1:0.9742 REC:0.9740 PRE:0.9746


(Epoch 2) TRAIN LOSS:0.0572 LR:0.00000300:  68%|██████▊   | 2190/3210 [06:05<02:55,  5.81it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0570 LR:0.00000300:  69%|██████▊   | 2205/3210 [06:07<02:53,  5.78it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0581 LR:0.00000300: 100%|██████████| 3210/3210 [08:51<00:00,  6.03it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0581 ACC:0.9788 F1:0.9788 REC:0.9786 PRE:0.9789 LR:0.00000300


VALID LOSS:0.0600 ACC:0.9751 F1:0.9749 REC:0.9753 PRE:0.9746:  81%|████████  | 862/1070 [00:31<00:08, 24.03it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0633 ACC:0.9748 F1:0.9747 REC:0.9749 PRE:0.9746: 100%|██████████| 1070/1070 [00:40<00:00, 26.14it/s]


(Epoch 2) VALID LOSS:0.0633 ACC:0.9748 F1:0.9747 REC:0.9749 PRE:0.9746


(Epoch 3) TRAIN LOSS:0.0106 LR:0.00000300:  68%|██████▊   | 2190/3210 [05:48<03:00,  5.65it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0106 LR:0.00000300:  69%|██████▊   | 2205/3210 [05:51<02:46,  6.04it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0108 LR:0.00000300: 100%|██████████| 3210/3210 [08:33<00:00,  6.25it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0108 ACC:0.9963 F1:0.9963 REC:0.9962 PRE:0.9963 LR:0.00000300


VALID LOSS:0.0954 ACC:0.9762 F1:0.9761 REC:0.9770 PRE:0.9756:  81%|████████  | 862/1070 [00:31<00:09, 21.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0962 ACC:0.9771 F1:0.9771 REC:0.9775 PRE:0.9768: 100%|██████████| 1070/1070 [00:41<00:00, 25.66it/s]


(Epoch 3) VALID LOSS:0.0962 ACC:0.9771 F1:0.9771 REC:0.9775 PRE:0.9768


(Epoch 4) TRAIN LOSS:0.0085 LR:0.00000300:  68%|██████▊   | 2190/3210 [05:53<02:58,  5.72it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 4) TRAIN LOSS:0.0085 LR:0.00000300:  69%|██████▊   | 2205/3210 [05:56<02:55,  5.73it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 4) TRAIN LOSS:0.0108 LR:0.00000300: 100%|██████████| 3210/3210 [08:44<00:00,  6.12it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0108 ACC:0.9967 F1:0.9967 REC:0.9967 PRE:0.9967 LR:0.00000300


  _warn_prf(average, modifier, msg_start, len(result))
VALID LOSS:0.0814 ACC:0.9745 F1:0.9743 REC:0.9738 PRE:0.9750:  81%|████████  | 862/1070 [00:31<00:08, 24.47it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.0878 ACC:0.9748 F1:0.9747 REC:0.9744 PRE:0.9751: 100%|██████████| 1070/1070 [00:40<00:00, 26.20it/s]
  0%|          | 0/3210 [00:00<?, ?it/s]

(Epoch 4) VALID LOSS:0.0878 ACC:0.9748 F1:0.9747 REC:0.9744 PRE:0.9751


(Epoch 5) TRAIN LOSS:0.0113 LR:0.00000300:  68%|██████▊   | 2190/3210 [05:49<02:49,  6.02it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 5) TRAIN LOSS:0.0112 LR:0.00000300:  69%|██████▊   | 2205/3210 [05:51<02:45,  6.07it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 5) TRAIN LOSS:0.0092 LR:0.00000300: 100%|██████████| 3210/3210 [08:34<00:00,  6.24it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0092 ACC:0.9975 F1:0.9975 REC:0.9975 PRE:0.9975 LR:0.00000300


VALID LOSS:0.1160 ACC:0.9751 F1:0.9749 REC:0.9738 PRE:0.9764:  80%|████████  | 860/1070 [00:33<00:08, 23.71it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1238 ACC:0.9734 F1:0.9733 REC:0.9725 PRE:0.9747: 100%|██████████| 1070/1070 [00:43<00:00, 24.82it/s]
  0%|          | 0/3210 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:0.1238 ACC:0.9734 F1:0.9733 REC:0.9725 PRE:0.9747


(Epoch 6) TRAIN LOSS:0.0048 LR:0.00000300:  68%|██████▊   | 2190/3210 [05:52<02:59,  5.69it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 6) TRAIN LOSS:0.0048 LR:0.00000300:  69%|██████▊   | 2205/3210 [05:55<02:57,  5.67it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 6) TRAIN LOSS:0.0081 LR:0.00000300: 100%|██████████| 3210/3210 [08:39<00:00,  6.18it/s]
  0%|          | 0/1070 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0081 ACC:0.9967 F1:0.9967 REC:0.9967 PRE:0.9967 LR:0.00000300


VALID LOSS:0.1046 ACC:0.9768 F1:0.9767 REC:0.9769 PRE:0.9765:  81%|████████  | 862/1070 [00:31<00:09, 21.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1093 ACC:0.9762 F1:0.9761 REC:0.9763 PRE:0.9760: 100%|██████████| 1070/1070 [00:40<00:00, 26.13it/s]


(Epoch 6) VALID LOSS:0.1093 ACC:0.9762 F1:0.9761 REC:0.9763 PRE:0.9760
== BEST METRICS ==
ACC:0.9771 F1:0.9771 REC:0.9775 PRE:0.9768
