# Direct Finetuning Yejin
FakeNewsAAAI is a Fake News dataset with 2 possible labels: `real` and `fake`

In [1]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.nn as nn
from tqdm import tqdm
import pickle
from copy import deepcopy
from multiprocessing import Pool                                                

from transformers import BertForSequenceClassification, RobertaForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from utils.forward_fn import forward_mask_sequence_classification
from utils.metrics import classification_metrics_fn
from utils.data_utils import FakeNewsDataset, FakeNewsDataLoader
from utils.utils import generate_random_mask

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
###
# common functions
###
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
    
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def metrics_to_string(metric_dict):
    string_list = []
    for key, value in metric_dict.items():
        string_list.append('{}:{:.4f}'.format(key, value))
    return ' '.join(string_list)

In [4]:
def influence_score(model, id, subword, mask, label, device='cpu'):
    loss_fct = CrossEntropyLoss(reduction='none')
    with torch.no_grad():
        # Prepare input & label
        subword = torch.LongTensor(subword)
        mask = torch.FloatTensor(mask)
        label = torch.LongTensor(label)

        if device == "cuda":
            subword = subword.cuda()
            mask = mask.cuda()
            label = label.cuda()

        if isinstance(model, BertForSequenceClassification):
            # Apply mask
            weight, bias = model.classifier.weight, model.classifier.bias
            dropout_mask = generate_random_mask([id], weight.shape[0], weight.shape[1], device=device).repeat(subword.shape[0],1,1)
            masked_weight = weight.expand_as(dropout_mask) * dropout_mask

            # Calculate latents
            latents = model.bert(subword, attention_mask=mask)[1]
            latents = model.dropout(latents)            
        elif isinstance(model, RobertaForSequenceClassification):
            # Apply mask
            weight, bias = model.classifier.out_proj.weight, model.classifier.out_proj.bias
            dropout_mask = generate_random_mask([id], weight.shape[0], weight.shape[1], device=device).repeat(subword.shape[0],1,1)
            masked_weight = weight.expand_as(dropout_mask) * dropout_mask

            # Calculate latents
            latents = model.roberta(subword, attention_mask=mask)[0][:,0,:]
            latents = model.classifier.dense(latents)
            latents = model.classifier.dropout(latents)
        else:
            raise ValueError(f'Model class `{type(model)}` is not implemented yet')

        # Compute loss with mask
        logits = torch.einsum('bd,bcd->bc', latents, masked_weight) + bias
        mask_loss = loss_fct(logits.view(-1, model.num_labels), label.view(-1))

        # Compute loss with flipped mask
        logits = torch.einsum('bd,bcd->bc', latents, (masked_weight.max() - masked_weight)) + bias
        flipped_mask_loss = loss_fct(logits.view(-1, model.num_labels), label.view(-1))
                              
        return flipped_mask_loss - mask_loss
                              
def build_influence_matrix(model, data_loader, train_size, device='cpu'):
    test_size, batch_size = len(data_loader.dataset), data_loader.batch_size
    influence_mat = torch.zeros(test_size, train_size, device=device)
    for i, batch_data in enumerate(data_loader):
        print(f'Processing batch {i+1}/{len(data_loader)}')
        (ids, subword_batch, mask_batch, label_batch, seq_list) = batch_data
        token_type_batch = None

        for train_idx in tqdm(range(train_size)):
            train_id = train_idx + 1
            scores = influence_score(model, train_id, subword_batch, mask_batch, label_batch, device=device)
            for j, id in enumerate(ids):
                influence_mat[(i * batch_size) + j, train_idx] = scores[j]
    return influence_mat

def get_inference_result(model, data_loader, device='cpu'):
    results = {}
    with torch.no_grad():
        pbar = tqdm(data_loader, leave=True, total=len(data_loader))
        for i, batch_data in enumerate(pbar):
            batch_id = batch_data[0]
            batch_seq = batch_data[-1]
            outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=True, device='cuda')
            loss, batch_hyp, batch_label, logits, label_batch = outputs

            for i, id in enumerate(batch_id):
                results[id] = batch_hyp[i] == batch_label[i]
    return results

def get_filtered_dataloader(data_loader, id_list, inclusive=True, batch_size=2, shuffle=False):
    df = data_loader.dataset.data
    if inclusive:
        filt_df = df.loc[df['id'].isin(id_list),:].reset_index(drop=True)
    else:
        filt_df = df.loc[~df['id'].isin(id_list),:].reset_index(drop=True)
    dataset = FakeNewsDataset(dataset_path=None, dataset=filt_df, tokenizer=tokenizer, lowercase=False)
    data_loader = FakeNewsDataLoader(dataset=dataset, max_seq_len=512, batch_size=batch_size, num_workers=2, shuffle=shuffle)  
    return data_loader

In [5]:
# Set random seed
set_seed(26092020)

# Load Model

In [6]:
# Load Tokenizer and Config
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
config = AutoConfig.from_pretrained('roberta-base')
config.num_labels = FakeNewsDataset.NUM_LABELS

# Instantiate model
model = AutoModelForSequenceClassification.from_pretrained('roberta-base', config=config)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

In [7]:
count_param(model)

124647170

# Prepare Dataset

In [8]:
train_dataset_path = './data/train.tsv'
valid_dataset_path = './data/valid.tsv'
test_dataset_path = './data/test.tsv'
valid_zero_dataset_path = './data/covid19_infodemic_english_data/processed_valid_data.tsv'
test_zero_dataset_path = './data/covid19_infodemic_english_data/processed_test_data.tsv'

In [9]:
train_dataset = FakeNewsDataset(dataset_path=train_dataset_path, tokenizer=tokenizer, lowercase=False)
valid_dataset = FakeNewsDataset(dataset_path=valid_dataset_path, tokenizer=tokenizer, lowercase=False)
test_dataset = FakeNewsDataset(dataset_path=test_dataset_path, tokenizer=tokenizer, lowercase=False)
valid_zero_dataset = FakeNewsDataset(dataset_path=valid_zero_dataset_path, tokenizer=tokenizer, lowercase=False)
test_zero_dataset = FakeNewsDataset(dataset_path=test_zero_dataset_path, tokenizer=tokenizer, lowercase=False)

train_loader = FakeNewsDataLoader(dataset=train_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=True)  
valid_loader = FakeNewsDataLoader(dataset=valid_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=False)  
test_loader = FakeNewsDataLoader(dataset=test_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=False)  
valid_zero_loader = FakeNewsDataLoader(dataset=valid_zero_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=False)
test_zero_loader = FakeNewsDataLoader(dataset=test_zero_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=False)

In [10]:
from sklearn.model_selection import train_test_split
_, valid_id = train_test_split(valid_zero_dataset.data['id'], test_size=15, stratify=valid_zero_dataset.data['label'], random_state=12345)

In [11]:
valid_id = valid_id.tolist()
new_train_loader = get_filtered_dataloader(valid_zero_loader, valid_id, inclusive=False, batch_size=8, shuffle=False)
new_valid_loader = get_filtered_dataloader(valid_zero_loader, valid_id, inclusive=True, batch_size=8, shuffle=False)

In [12]:
w2i, i2w = FakeNewsDataset.LABEL2INDEX, FakeNewsDataset.INDEX2LABEL
print(w2i)
print(i2w)

{'fake': 0, 'real': 1}
{0: 'fake', 1: 'real'}


# Fine Tuning & Evaluation

In [13]:
optimizer = optim.Adam(model.parameters(), lr=3e-6)
model = model.cuda()

In [14]:
# Train without mask
n_epochs = 25
best_val_metric, best_metrics, best_state_dict = 0, None, None
early_stop, count_stop = 5, 0
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(new_train_loader, leave=True, total=len(new_train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(new_valid_loader, leave=True, total=len(new_valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    if best_val_metric <= val_metric:
        torch.save(model.state_dict(), './tmp_hessian/model_direct_ft.pt')
        best_val_metric = val_metric
        best_metrics = metrics
        count_stop = 0
    else:
        count_stop += 1
        if count_stop == early_stop:
            break
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.6746 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.01it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.6746 ACC:0.7333 F1:0.4925 REC:0.5160 PRE:0.5476 LR:0.00000300


  _warn_prf(average, modifier, msg_start, len(result))
VALID LOSS:0.6643 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.41it/s]


(Epoch 1) VALID LOSS:0.6643 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 2) TRAIN LOSS:0.6574 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.47it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.6574 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6539 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.34it/s]


(Epoch 2) VALID LOSS:0.6539 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 3) TRAIN LOSS:0.6210 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.38it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.6210 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6417 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.14it/s]


(Epoch 3) VALID LOSS:0.6417 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 4) TRAIN LOSS:0.6106 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.48it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.6106 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6296 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.80it/s]


(Epoch 4) VALID LOSS:0.6296 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 5) TRAIN LOSS:0.6128 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.64it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.6128 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6176 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.46it/s]


(Epoch 5) VALID LOSS:0.6176 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 6) TRAIN LOSS:0.6141 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.49it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.6141 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6058 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.09it/s]


(Epoch 6) VALID LOSS:0.6058 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 7) TRAIN LOSS:0.5722 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.63it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.5722 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5922 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.86it/s]


(Epoch 7) VALID LOSS:0.5922 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 8) TRAIN LOSS:0.5326 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.65it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.5326 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5763 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.15it/s]


(Epoch 8) VALID LOSS:0.5763 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 9) TRAIN LOSS:0.4988 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.50it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 9) TRAIN LOSS:0.4988 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5588 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.37it/s]


(Epoch 9) VALID LOSS:0.5588 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 10) TRAIN LOSS:0.4665 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.19it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 10) TRAIN LOSS:0.4665 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5434 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.23it/s]


(Epoch 10) VALID LOSS:0.5434 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 11) TRAIN LOSS:0.4402 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.03it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 11) TRAIN LOSS:0.4402 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5399 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.30it/s]


(Epoch 11) VALID LOSS:0.5399 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 12) TRAIN LOSS:0.4073 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.55it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 12) TRAIN LOSS:0.4073 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5532 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.10it/s]


(Epoch 12) VALID LOSS:0.5532 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 13) TRAIN LOSS:0.4335 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.69it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 13) TRAIN LOSS:0.4335 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5557 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.20it/s]


(Epoch 13) VALID LOSS:0.5557 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 14) TRAIN LOSS:0.3837 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.33it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 14) TRAIN LOSS:0.3837 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5474 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.04it/s]


(Epoch 14) VALID LOSS:0.5474 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 15) TRAIN LOSS:0.3755 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.52it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 15) TRAIN LOSS:0.3755 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5416 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.02it/s]


(Epoch 15) VALID LOSS:0.5416 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 16) TRAIN LOSS:0.3995 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.31it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 16) TRAIN LOSS:0.3995 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5474 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.18it/s]


(Epoch 16) VALID LOSS:0.5474 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 17) TRAIN LOSS:0.3483 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.42it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 17) TRAIN LOSS:0.3483 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5610 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.10it/s]


(Epoch 17) VALID LOSS:0.5610 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 18) TRAIN LOSS:0.3609 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 18) TRAIN LOSS:0.3609 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5747 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.32it/s]


(Epoch 18) VALID LOSS:0.5747 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 19) TRAIN LOSS:0.3255 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.75it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 19) TRAIN LOSS:0.3255 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5756 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.48it/s]


(Epoch 19) VALID LOSS:0.5756 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 20) TRAIN LOSS:0.2946 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.40it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 20) TRAIN LOSS:0.2946 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5779 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.95it/s]


(Epoch 20) VALID LOSS:0.5779 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 21) TRAIN LOSS:0.2868 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.47it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 21) TRAIN LOSS:0.2868 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.5893 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.17it/s]


(Epoch 21) VALID LOSS:0.5893 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 22) TRAIN LOSS:0.2726 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 22) TRAIN LOSS:0.2726 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6033 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.05it/s]


(Epoch 22) VALID LOSS:0.6033 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 23) TRAIN LOSS:0.2709 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.52it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 23) TRAIN LOSS:0.2709 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6074 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.70it/s]


(Epoch 23) VALID LOSS:0.6074 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 24) TRAIN LOSS:0.2582 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.39it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 24) TRAIN LOSS:0.2582 ACC:0.7556 F1:0.4304 REC:0.5000 PRE:0.3778 LR:0.00000300


VALID LOSS:0.6255 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  7.01it/s]


(Epoch 24) VALID LOSS:0.6255 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


(Epoch 25) TRAIN LOSS:0.2461 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.37it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 25) TRAIN LOSS:0.2461 ACC:0.7778 F1:0.5192 REC:0.5455 PRE:0.8864 LR:0.00000300


VALID LOSS:0.6447 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667: 100%|██████████| 2/2 [00:00<00:00,  6.81it/s]


(Epoch 25) VALID LOSS:0.6447 ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667
== BEST METRICS ==
ACC:0.7333 F1:0.4231 REC:0.5000 PRE:0.3667


In [15]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_ft.pt'))

<All keys matched successfully>

In [16]:
# Test on Indian dataset
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    valid_loss = loss.item()
    total_loss = total_loss + valid_loss

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label
    metrics = classification_metrics_fn(list_hyp, list_label)

    pbar.set_description("TEST LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print('== EVAL METRICS ==')
print(metrics_to_string(eval_metrics))

TEST LOSS:0.8498 ACC:0.5084 F1:0.3370 REC:0.5000 PRE:0.2542:  34%|███▍      | 92/268 [00:03<00:06, 28.12it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors
TEST LOSS:0.8313 ACC:0.5248 F1:0.3442 REC:0.5000 PRE:0.2624: 100%|██████████| 268/268 [00:10<00:00, 24.89it/s]


== EVAL METRICS ==
ACC:0.5248 F1:0.3442 REC:0.5000 PRE:0.2624 FAKE_ACC:0.5248 FAKE_F1:0.0000 FAKE_REC:0.0000 FAKE_PRE:0.0000 REAL_ACC:0.5248 REAL_F1:0.6884 REAL_REC:1.0000 REAL_PRE:0.5248


  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
# Test on Yejin test set
model.eval()
torch.set_grad_enabled(False)

total_loss = 0
list_hyp, list_label = [], []
pbar = tqdm(test_zero_loader, leave=True, total=len(test_zero_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    total_loss += loss.item()

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label

    pbar.set_description("TEST LOSS:{:.4f}".format(total_loss/(i+1)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro', pos_label='fake')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print(f'TEST RESULT: {metrics_to_string(eval_metrics)}')

TEST LOSS:0.5722: 100%|██████████| 30/30 [00:01<00:00, 22.27it/s]

TEST RESULT: ACC:0.7511 F1:0.4289 REC:0.5000 PRE:0.3755 FAKE_ACC:0.7511 FAKE_F1:0.0000 FAKE_REC:0.0000 FAKE_PRE:0.0000 REAL_ACC:0.7511 REAL_F1:0.8578 REAL_REC:1.0000 REAL_PRE:0.7511





# Combine data

In [18]:
combined_train_df = pd.concat([train_loader.dataset.data, new_train_loader.dataset.data]).reset_index(drop=True)
combined_valid_df = pd.concat([valid_loader.dataset.data, new_valid_loader.dataset.data]).reset_index(drop=True)

comb_train_dataset = FakeNewsDataset(dataset_path=None, dataset=combined_train_df, tokenizer=tokenizer, lowercase=False)
comb_valid_dataset = FakeNewsDataset(dataset_path=None, dataset=combined_valid_df, tokenizer=tokenizer, lowercase=False)

comb_train_loader = FakeNewsDataLoader(dataset=comb_train_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=True)
comb_valid_loader = FakeNewsDataLoader(dataset=comb_valid_dataset, max_seq_len=512, batch_size=8, num_workers=8, shuffle=False)

In [19]:
# Train without mask
n_epochs = 25
best_val_metric, best_metrics, best_state_dict = 0, None, None
early_stop, count_stop = 5, 0
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(comb_train_loader, leave=True, total=len(comb_train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(comb_valid_loader, leave=True, total=len(comb_valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    if best_val_metric <= val_metric:
        torch.save(model.state_dict(), './tmp_hessian/model_direct_combine_ft.pt')
        best_val_metric = val_metric
        best_metrics = metrics
        count_stop = 0
    else:
        count_stop += 1
        if count_stop == early_stop:
            break
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.3335 LR:0.00000300:  45%|████▌     | 368/809 [00:38<00:46,  9.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 1) TRAIN LOSS:0.3089 LR:0.00000300:  54%|█████▍    | 438/809 [00:46<00:36, 10.30it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 1) TRAIN LOSS:0.2251 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.55it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.2251 ACC:0.9117 F1:0.9111 REC:0.9099 PRE:0.9140 LR:0.00000300


VALID LOSS:0.1189 ACC:0.9537 F1:0.9537 REC:0.9547 PRE:0.9532:  73%|███████▎  | 198/270 [00:07<00:03, 23.62it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1458 ACC:0.9476 F1:0.9475 REC:0.9486 PRE:0.9475: 100%|██████████| 270/270 [00:10<00:00, 25.27it/s]


(Epoch 1) VALID LOSS:0.1458 ACC:0.9476 F1:0.9475 REC:0.9486 PRE:0.9475


(Epoch 2) TRAIN LOSS:0.0937 LR:0.00000300:  46%|████▌     | 371/809 [00:38<00:46,  9.49it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0910 LR:0.00000300:  74%|███████▍  | 597/809 [01:02<00:20, 10.21it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 2) TRAIN LOSS:0.0929 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.53it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0929 ACC:0.9678 F1:0.9677 REC:0.9676 PRE:0.9678 LR:0.00000300


VALID LOSS:0.1128 ACC:0.9606 F1:0.9603 REC:0.9589 PRE:0.9629:  73%|███████▎  | 198/270 [00:07<00:03, 23.62it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1278 ACC:0.9578 F1:0.9576 REC:0.9566 PRE:0.9595: 100%|██████████| 270/270 [00:10<00:00, 24.73it/s]


(Epoch 2) VALID LOSS:0.1278 ACC:0.9578 F1:0.9576 REC:0.9566 PRE:0.9595


(Epoch 3) TRAIN LOSS:0.0773 LR:0.00000300:   2%|▏         | 16/809 [00:02<01:31,  8.70it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0616 LR:0.00000300:  71%|███████   | 575/809 [01:00<00:22, 10.29it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 3) TRAIN LOSS:0.0585 LR:0.00000300: 100%|██████████| 809/809 [01:25<00:00,  9.49it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0585 ACC:0.9813 F1:0.9812 REC:0.9812 PRE:0.9813 LR:0.00000300


VALID LOSS:0.1177 ACC:0.9637 F1:0.9634 REC:0.9618 PRE:0.9667:  74%|███████▍  | 200/270 [00:07<00:02, 24.50it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1449 ACC:0.9578 F1:0.9575 REC:0.9564 PRE:0.9599: 100%|██████████| 270/270 [00:10<00:00, 25.38it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 3) VALID LOSS:0.1449 ACC:0.9578 F1:0.9575 REC:0.9564 PRE:0.9599


(Epoch 4) TRAIN LOSS:0.0402 LR:0.00000300:  41%|████      | 329/809 [00:34<00:48,  9.99it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 4) TRAIN LOSS:0.0420 LR:0.00000300:  78%|███████▊  | 631/809 [01:06<00:17, 10.08it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 4) TRAIN LOSS:0.0451 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.56it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0451 ACC:0.9856 F1:0.9856 REC:0.9856 PRE:0.9856 LR:0.00000300


VALID LOSS:0.0933 ACC:0.9700 F1:0.9698 REC:0.9688 PRE:0.9714:  73%|███████▎  | 198/270 [00:07<00:03, 23.84it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1162 ACC:0.9633 F1:0.9632 REC:0.9624 PRE:0.9645: 100%|██████████| 270/270 [00:10<00:00, 25.49it/s]


(Epoch 4) VALID LOSS:0.1162 ACC:0.9633 F1:0.9632 REC:0.9624 PRE:0.9645


(Epoch 5) TRAIN LOSS:0.0282 LR:0.00000300:  42%|████▏     | 342/809 [00:35<00:50,  9.31it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 5) TRAIN LOSS:0.0301 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.57it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0301 ACC:0.9909 F1:0.9909 REC:0.9908 PRE:0.9909 LR:0.00000300


VALID LOSS:0.1042 ACC:0.9712 F1:0.9711 REC:0.9699 PRE:0.9730:  73%|███████▎  | 198/270 [00:07<00:03, 22.83it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1456 ACC:0.9624 F1:0.9622 REC:0.9614 PRE:0.9638: 100%|██████████| 270/270 [00:10<00:00, 24.70it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:0.1456 ACC:0.9624 F1:0.9622 REC:0.9614 PRE:0.9638


(Epoch 6) TRAIN LOSS:0.0155 LR:0.00000300:   6%|▋         | 51/809 [00:05<01:14, 10.16it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 6) TRAIN LOSS:0.0233 LR:0.00000300:  43%|████▎     | 344/809 [00:36<00:48,  9.65it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 6) TRAIN LOSS:0.0260 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.57it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0260 ACC:0.9909 F1:0.9909 REC:0.9909 PRE:0.9908 LR:0.00000300


VALID LOSS:0.0747 ACC:0.9769 F1:0.9768 REC:0.9766 PRE:0.9770:  74%|███████▎  | 199/270 [00:07<00:02, 24.40it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1107 ACC:0.9689 F1:0.9688 REC:0.9688 PRE:0.9689: 100%|██████████| 270/270 [00:10<00:00, 25.54it/s]


(Epoch 6) VALID LOSS:0.1107 ACC:0.9689 F1:0.9688 REC:0.9688 PRE:0.9689


(Epoch 7) TRAIN LOSS:0.0155 LR:0.00000300:  17%|█▋        | 141/809 [00:15<01:05, 10.18it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 7) TRAIN LOSS:0.0162 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.60it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.0162 ACC:0.9951 F1:0.9950 REC:0.9951 PRE:0.9950 LR:0.00000300


VALID LOSS:0.1075 ACC:0.9706 F1:0.9704 REC:0.9694 PRE:0.9721:  74%|███████▍  | 200/270 [00:07<00:02, 24.47it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1433 ACC:0.9624 F1:0.9622 REC:0.9615 PRE:0.9636: 100%|██████████| 270/270 [00:10<00:00, 25.92it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 7) VALID LOSS:0.1433 ACC:0.9624 F1:0.9622 REC:0.9615 PRE:0.9636


(Epoch 8) TRAIN LOSS:0.0141 LR:0.00000300:  36%|███▌      | 291/809 [00:30<00:54,  9.57it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 8) TRAIN LOSS:0.0137 LR:0.00000300:  37%|███▋      | 299/809 [00:31<00:53,  9.59it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 8) TRAIN LOSS:0.0162 LR:0.00000300: 100%|██████████| 809/809 [01:23<00:00,  9.65it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.0162 ACC:0.9952 F1:0.9952 REC:0.9952 PRE:0.9952 LR:0.00000300


VALID LOSS:0.1144 ACC:0.9756 F1:0.9755 REC:0.9742 PRE:0.9775:  74%|███████▍  | 200/270 [00:07<00:02, 24.12it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1531 ACC:0.9652 F1:0.9650 REC:0.9642 PRE:0.9666: 100%|██████████| 270/270 [00:10<00:00, 25.57it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 8) VALID LOSS:0.1531 ACC:0.9652 F1:0.9650 REC:0.9642 PRE:0.9666


(Epoch 9) TRAIN LOSS:0.0060 LR:0.00000300:  35%|███▍      | 280/809 [00:29<01:00,  8.79it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 9) TRAIN LOSS:0.0127 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.59it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 9) TRAIN LOSS:0.0127 ACC:0.9949 F1:0.9949 REC:0.9949 PRE:0.9949 LR:0.00000300


VALID LOSS:0.0972 ACC:0.9775 F1:0.9774 REC:0.9768 PRE:0.9782:  74%|███████▎  | 199/270 [00:07<00:02, 25.26it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1411 ACC:0.9689 F1:0.9688 REC:0.9684 PRE:0.9694: 100%|██████████| 270/270 [00:10<00:00, 25.86it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 9) VALID LOSS:0.1411 ACC:0.9689 F1:0.9688 REC:0.9684 PRE:0.9694


(Epoch 10) TRAIN LOSS:0.0204 LR:0.00000300:  12%|█▏        | 100/809 [00:10<01:14,  9.47it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 10) TRAIN LOSS:0.0091 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.56it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 10) TRAIN LOSS:0.0091 ACC:0.9968 F1:0.9967 REC:0.9967 PRE:0.9968 LR:0.00000300


VALID LOSS:0.1123 ACC:0.9769 F1:0.9767 REC:0.9759 PRE:0.9780:  74%|███████▎  | 199/270 [00:07<00:02, 24.40it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1574 ACC:0.9675 F1:0.9674 REC:0.9669 PRE:0.9682: 100%|██████████| 270/270 [00:10<00:00, 25.72it/s]
  0%|          | 0/809 [00:00<?, ?it/s]

(Epoch 10) VALID LOSS:0.1574 ACC:0.9675 F1:0.9674 REC:0.9669 PRE:0.9682


(Epoch 11) TRAIN LOSS:0.0013 LR:0.00000300:  16%|█▌        | 131/809 [00:13<01:07, 10.01it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1925 > 512). Running this sequence through the model will result in indexing errors
(Epoch 11) TRAIN LOSS:0.0065 LR:0.00000300:  77%|███████▋  | 625/809 [01:05<00:19,  9.24it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 512). Running this sequence through the model will result in indexing errors
(Epoch 11) TRAIN LOSS:0.0085 LR:0.00000300: 100%|██████████| 809/809 [01:24<00:00,  9.59it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

(Epoch 11) TRAIN LOSS:0.0085 ACC:0.9969 F1:0.9969 REC:0.9969 PRE:0.9969 LR:0.00000300


VALID LOSS:0.1184 ACC:0.9738 F1:0.9736 REC:0.9725 PRE:0.9754:  73%|███████▎  | 198/270 [00:07<00:02, 24.41it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
VALID LOSS:0.1635 ACC:0.9657 F1:0.9655 REC:0.9647 PRE:0.9669: 100%|██████████| 270/270 [00:10<00:00, 25.93it/s]


(Epoch 11) VALID LOSS:0.1635 ACC:0.9657 F1:0.9655 REC:0.9647 PRE:0.9669
== BEST METRICS ==
ACC:0.9689 F1:0.9688 REC:0.9688 PRE:0.9689


In [21]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_combine_ft.pt'))

<All keys matched successfully>

In [22]:
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    valid_loss = loss.item()
    total_loss = total_loss + valid_loss

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label
    metrics = classification_metrics_fn(list_hyp, list_label)

    pbar.set_description("TEST LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print('== EVAL METRICS ==')
print(metrics_to_string(eval_metrics))

TEST LOSS:0.5615 ACC:0.9674 F1:0.9673 REC:0.9668 PRE:0.9681:  35%|███▌      | 94/268 [00:05<00:08, 19.88it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors
TEST LOSS:0.2385 ACC:0.9721 F1:0.9720 REC:0.9715 PRE:0.9726: 100%|██████████| 268/268 [00:15<00:00, 17.22it/s]


== EVAL METRICS ==
ACC:0.9721 F1:0.9720 REC:0.9715 PRE:0.9726 FAKE_ACC:0.9721 FAKE_F1:0.9703 FAKE_REC:0.9604 FAKE_PRE:0.9805 REAL_ACC:0.9721 REAL_F1:0.9736 REAL_REC:0.9827 REAL_PRE:0.9647


In [23]:
# Test on Yejin test set
model.eval()
torch.set_grad_enabled(False)

total_loss = 0
list_hyp, list_label = [], []
pbar = tqdm(test_zero_loader, leave=True, total=len(test_zero_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    total_loss += loss.item()

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label

    pbar.set_description("TEST LOSS:{:.4f}".format(total_loss/(i+1)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro', pos_label='fake')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print(f'TEST RESULT: {metrics_to_string(eval_metrics)}')

TEST LOSS:2.5691: 100%|██████████| 30/30 [00:01<00:00, 23.20it/s]

TEST RESULT: ACC:0.4895 F1:0.4894 REC:0.6601 PRE:0.6639 FAKE_ACC:0.4895 FAKE_F1:0.4937 FAKE_REC:1.0000 FAKE_PRE:0.3278 REAL_ACC:0.4895 REAL_F1:0.4851 REAL_REC:0.3202 REAL_PRE:1.0000





# Additional Training on Yejin Dataset

In [18]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_combine_ft.pt'))

<All keys matched successfully>

In [19]:
optimizer = optim.Adam(model.parameters(), lr=3e-6)
model = model.cuda()

In [20]:
# Train without mask
n_epochs = 25
best_val_metric, best_metrics, best_state_dict = 0, None, None
early_stop, count_stop = 5, 0
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(new_train_loader, leave=True, total=len(new_train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(new_valid_loader, leave=True, total=len(new_valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    if best_val_metric <= val_metric:
        torch.save(model.state_dict(), './tmp_hessian/model_direct_addi_ft.pt')
        best_val_metric = val_metric
        best_metrics = metrics
        count_stop = 0
    else:
        count_stop += 1
        if count_stop == early_stop:
            break
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.1539 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.29it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.1539 ACC:0.9556 F1:0.9432 REC:0.9706 PRE:0.9231 LR:0.00000300


VALID LOSS:1.5242 ACC:0.6667 F1:0.6606 REC:0.7727 PRE:0.7222: 100%|██████████| 2/2 [00:00<00:00,  7.08it/s]


(Epoch 1) VALID LOSS:1.5242 ACC:0.6667 F1:0.6606 REC:0.7727 PRE:0.7222


(Epoch 2) TRAIN LOSS:0.0965 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.28it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0965 ACC:0.9778 F1:0.9689 REC:0.9545 PRE:0.9857 LR:0.00000300


VALID LOSS:1.1694 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.00it/s]
  0%|          | 0/6 [00:00<?, ?it/s]

(Epoch 2) VALID LOSS:1.1694 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 3) TRAIN LOSS:0.0238 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.46it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0238 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.2959 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.39it/s]
  0%|          | 0/6 [00:00<?, ?it/s]

(Epoch 3) VALID LOSS:1.2959 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 4) TRAIN LOSS:0.0071 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.74it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0071 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.4431 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.10it/s]
  0%|          | 0/6 [00:00<?, ?it/s]

(Epoch 4) VALID LOSS:1.4431 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 5) TRAIN LOSS:0.0095 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.51it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0095 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.4993 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.01it/s]
  0%|          | 0/6 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:1.4993 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 6) TRAIN LOSS:0.0050 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.14it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0050 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.5138 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.01it/s]

(Epoch 6) VALID LOSS:1.5138 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161
== BEST METRICS ==
ACC:0.6667 F1:0.6606 REC:0.7727 PRE:0.7222





In [21]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_addi_ft.pt'))

<All keys matched successfully>

In [22]:
# Test on Indian dataset
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    valid_loss = loss.item()
    total_loss = total_loss + valid_loss

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label
    metrics = classification_metrics_fn(list_hyp, list_label)

    pbar.set_description("TEST LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print('== EVAL METRICS ==')
print(metrics_to_string(eval_metrics))

TEST LOSS:0.2029 ACC:0.9497 F1:0.9496 REC:0.9492 PRE:0.9517:  35%|███▌      | 95/268 [00:03<00:05, 30.88it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors
TEST LOSS:0.1680 ACC:0.9555 F1:0.9551 REC:0.9536 PRE:0.9591: 100%|██████████| 268/268 [00:10<00:00, 25.90it/s]


== EVAL METRICS ==
ACC:0.9555 F1:0.9551 REC:0.9536 PRE:0.9591 FAKE_ACC:0.9555 FAKE_F1:0.9513 FAKE_REC:0.9160 FAKE_PRE:0.9895 REAL_ACC:0.9555 REAL_F1:0.9589 REAL_REC:0.9912 REAL_PRE:0.9287


In [23]:
# Test on Yejin test set
model.eval()
torch.set_grad_enabled(False)

total_loss = 0
list_hyp, list_label = [], []
pbar = tqdm(test_zero_loader, leave=True, total=len(test_zero_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    total_loss += loss.item()

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label

    pbar.set_description("TEST LOSS:{:.4f}".format(total_loss/(i+1)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro', pos_label='fake')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print(f'TEST RESULT: {metrics_to_string(eval_metrics)}')

TEST LOSS:1.3612: 100%|██████████| 30/30 [00:01<00:00, 22.27it/s]

TEST RESULT: ACC:0.6456 F1:0.6297 REC:0.7244 PRE:0.6691 FAKE_ACC:0.6456 FAKE_F1:0.5532 FAKE_REC:0.8814 FAKE_PRE:0.4031 REAL_ACC:0.6456 REAL_F1:0.7063 REAL_REC:0.5674 REAL_PRE:0.9352





# Second Additional Training on Yejin Dataset

In [34]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_addi_ft.pt'))

<All keys matched successfully>

In [35]:
optimizer = optim.Adam(model.parameters(), lr=3e-6)
model = model.cuda()

In [36]:
# Train without mask
n_epochs = 10
best_val_metric, best_metrics, best_state_dict = 0, None, None
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(new_train_loader, leave=True, total=len(new_train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(new_valid_loader, leave=True, total=len(new_valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
        loss, batch_hyp, batch_label, logits, label_batch = outputs
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = classification_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = classification_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))
    
    # Early stopping
    val_metric = metrics['F1']
    torch.save(model.state_dict(), './tmp_hessian/model_direct_addi_ft_2.pt')
    best_val_metric = val_metric
    best_metrics = metrics
            
print('== BEST METRICS ==')
print(metrics_to_string(best_metrics))

(Epoch 1) TRAIN LOSS:0.1056 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.29it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.1056 ACC:0.9778 F1:0.9689 REC:0.9545 PRE:0.9857 LR:0.00000300


VALID LOSS:1.1364 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179: 100%|██████████| 2/2 [00:00<00:00,  7.33it/s]


(Epoch 1) VALID LOSS:1.1364 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179


(Epoch 2) TRAIN LOSS:0.0109 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.39it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0109 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.3561 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.59it/s]


(Epoch 2) VALID LOSS:1.3561 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 3) TRAIN LOSS:0.0067 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.64it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0067 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.5694 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.18it/s]


(Epoch 3) VALID LOSS:1.5694 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 4) TRAIN LOSS:0.0071 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.35it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0071 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.6214 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  6.94it/s]


(Epoch 4) VALID LOSS:1.6214 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 5) TRAIN LOSS:0.0032 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.74it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0032 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.6076 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.28it/s]


(Epoch 5) VALID LOSS:1.6076 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 6) TRAIN LOSS:0.0034 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.46it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0034 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.5743 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.01it/s]


(Epoch 6) VALID LOSS:1.5743 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 7) TRAIN LOSS:0.0022 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.48it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.0022 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.5339 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161: 100%|██████████| 2/2 [00:00<00:00,  7.25it/s]


(Epoch 7) VALID LOSS:1.5339 ACC:0.6000 F1:0.5833 REC:0.6477 PRE:0.6161


(Epoch 8) TRAIN LOSS:0.0020 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.57it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.0020 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.4888 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179: 100%|██████████| 2/2 [00:00<00:00,  7.17it/s]


(Epoch 8) VALID LOSS:1.4888 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179


(Epoch 9) TRAIN LOSS:0.0021 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.42it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 9) TRAIN LOSS:0.0021 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.4677 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179: 100%|██████████| 2/2 [00:00<00:00,  6.79it/s]


(Epoch 9) VALID LOSS:1.4677 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179


(Epoch 10) TRAIN LOSS:0.0012 LR:0.00000300: 100%|██████████| 6/6 [00:00<00:00,  7.54it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 10) TRAIN LOSS:0.0012 ACC:1.0000 F1:1.0000 REC:1.0000 PRE:1.0000 LR:0.00000300


VALID LOSS:1.4609 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179: 100%|██████████| 2/2 [00:00<00:00,  6.85it/s]


(Epoch 10) VALID LOSS:1.4609 ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179
== BEST METRICS ==
ACC:0.5333 F1:0.4976 REC:0.5227 PRE:0.5179


In [37]:
# Load best model
model.load_state_dict(torch.load('./tmp_hessian/model_direct_addi_ft_2.pt'))

<All keys matched successfully>

In [38]:
# Test on Indian dataset
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    valid_loss = loss.item()
    total_loss = total_loss + valid_loss

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label
    metrics = classification_metrics_fn(list_hyp, list_label)

    pbar.set_description("TEST LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print('== EVAL METRICS ==')
print(metrics_to_string(eval_metrics))

TEST LOSS:0.3027 ACC:0.9368 F1:0.9366 REC:0.9360 PRE:0.9405:  35%|███▌      | 95/268 [00:03<00:05, 29.71it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors
TEST LOSS:0.2840 ACC:0.9336 F1:0.9329 REC:0.9307 PRE:0.9412: 100%|██████████| 268/268 [00:10<00:00, 25.25it/s]


== EVAL METRICS ==
ACC:0.9336 F1:0.9329 REC:0.9307 PRE:0.9412 FAKE_ACC:0.9336 FAKE_F1:0.9258 FAKE_REC:0.8711 FAKE_PRE:0.9878 REAL_ACC:0.9336 REAL_F1:0.9400 REAL_REC:0.9903 REAL_PRE:0.8946


In [39]:
# Test on Yejin test set
model.eval()
torch.set_grad_enabled(False)

total_loss = 0
list_hyp, list_label = [], []
pbar = tqdm(test_zero_loader, leave=True, total=len(test_zero_loader))
for i, batch_data in enumerate(pbar):
    batch_seq = batch_data[-1]        
    outputs = forward_mask_sequence_classification(model, batch_data[:-1], i2w=i2w, apply_mask=False, device='cuda')
    loss, batch_hyp, batch_label, logits, label_batch = outputs

    # Calculate total loss
    total_loss += loss.item()

    # Calculate evaluation metrics
    list_hyp += batch_hyp
    list_label += batch_label

    pbar.set_description("TEST LOSS:{:.4f}".format(total_loss/(i+1)))

eval_metrics = classification_metrics_fn(list_hyp, list_label, average='macro', pos_label='fake')
fake_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='fake')
for key in fake_metrics.keys():
    eval_metrics[f'FAKE_{key}'] = fake_metrics[key]
real_metrics = classification_metrics_fn(list_hyp, list_label, average='binary', pos_label='real')
for key in real_metrics.keys():
    eval_metrics[f'REAL_{key}'] = real_metrics[key]

print(f'TEST RESULT: {metrics_to_string(eval_metrics)}')

TEST LOSS:1.2427: 100%|██████████| 30/30 [00:01<00:00, 21.34it/s]


TEST RESULT: ACC:0.7004 F1:0.6769 REC:0.7552 PRE:0.6912 FAKE_ACC:0.7004 FAKE_F1:0.5896 FAKE_REC:0.8644 FAKE_PRE:0.4474 REAL_ACC:0.7004 REAL_F1:0.7641 REAL_REC:0.6461 REAL_PRE:0.9350
