In [0]:
# !pip install transformers
# !pip install --upgrade wandb
# !wandb login <>

In [74]:
import wandb
wandb.init(project="dpl", name='bert_')

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertForSequenceClassification

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext  import data

from tqdm import tqdm_notebook

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data

In [0]:
def open_file(file):
    with open(file, 'r', encoding='utf-8') as f:
        text_list = [line for line in f.readlines()]
    return text_list

In [0]:
# uncomment if google colab:
import os 
from google.colab import drive
drive.mount('/content/drive')
os.chdir('/content/drive/My Drive/')

fake = open_file("data/fake.txt")
real = open_file("data/real.txt")
df = pd.read_csv("data/dataset.csv")

In [5]:
print(len(fake), len(real), df.shape)
print(fake[:2])
print(real[:2])
df.head()

37366 37366 (74730, 2)
['Spinach has terrorized generations of veggie-phobic kids, and many grownups don\'t much like it, either.."I think it\'s a little bit of a shock to see that he\'s been able to do this,"\n', 'All day, every day, Cheryl Bernstein thanks her 16-month-old son. the boy is a little boy.\n']
["Spinach has terrorized generations of veggie-phobic kids, and many grownups don't much like it, either. But when it's combined with seasonings and feta cheese and wrapped in a golden crisp phyllo dough crust, even those who despise Popeye's Â\xadfavorite food ask for seconds.\n", 'All day, every day, Cheryl Bernstein thanks her 16-month-old son. "I gave life to Reid, but he gave me life - a reason to get clean and go on," she said yesterday after graduating from the Manhattan Family Treatment Court program.\n']


Unnamed: 0,text,label
0,"Is a skull from Petralona Cave, Greece, the ol...",real
1,The Network Readiness Index published by the W...,fake
2,Now they've got Justin Bieber too. He was just...,real
3,"NOGALES, Arizona — Jessica Elizabeth Orellana ...",real
4,Many companies that are using cloud computing ...,fake


# BERT

In [75]:
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
model = BertForSequenceClassification.from_pretrained(pretrained_weights)

embeddings_pretrained = model.get_input_embeddings()
embeddings_pretrained

Embedding(30522, 768, padding_idx=0)

In [76]:
def tokenize(text, tokenizer=tokenizer):
    return tokenizer.encode(text)

MAX_VOCAB_SIZE = 50000
classes={'fake': 0, 'real': 1}


TEXT = data.Field(sequential=True, include_lengths=False, batch_first=True, tokenize=tokenize, 
             pad_first=True, lower=False) 
LABEL = data.LabelField(dtype=torch.long, use_vocab=True, preprocessing=lambda x: classes[x])

dataset = data.TabularDataset('data/dataset.csv', 
                                format='csv', fields=[('text', TEXT), ('label',LABEL),], 
                                skip_header=True)

TEXT.build_vocab(dataset,  max_size=MAX_VOCAB_SIZE, min_freq=2)
LABEL.build_vocab(dataset)
vocab = TEXT.vocab
print('Vocab size:', len(TEXT.vocab.itos))

train, test = dataset.split(0.8, stratified=True)
train, valid = train.split(0.8, stratified=True)

Vocab size: 24776


In [0]:
batch_size = 256
num_epochs = 10

model.to(device)
wandb.watch(model)

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train, valid, test),
    batch_sizes=(batch_size, batch_size, batch_size),
    shuffle=True,
    device=device,
    sort_key=lambda x: len(x.text),
    sort_within_batch=True,
)
optimizer = optim.Adam(model.parameters())

In [78]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

109483778


In [0]:
# for p in model.named_parameters():
#     print(p[0])

In [0]:
for p in model.bert.encoder.parameters(): 
    p.requires_grad = False 

for p in model.bert.pooler.parameters():
    p.requires_grad = False

for p in model.bert.embeddings.parameters(): 
    p.requires_grad = False 

# for p in model.bert.encoder.layer[-1].parameters():
#     p.requires_grad = True

In [80]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

1538


In [0]:
def accuracy_score(preds, y):
    preds = (preds == y).float()
    accuracy = preds.sum() / len(preds)
    return accuracy.item()

In [0]:
def _train_epoch(model, iterator, optimizer, curr_epoch):
    model.train()

    running_loss = 0
    losses = []
    train_acc = []

    n_batches = len(iterator)    
    iterator = tqdm_notebook(iterator, total=n_batches, desc='epoch %d' % (curr_epoch), leave=True)
    
    for i, batch in enumerate(iterator):
        x = batch.text
        y = batch.label
        outputs = model(x, labels=y)
        loss, logits = outputs[:2]
        _, preds = torch.max(F.softmax(logits, dim=1),1)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()        
        curr_loss = loss.data.detach().item()
        losses.append(loss.item())
        
        acc = accuracy_score(preds, y)
        train_acc.append(acc)
        iterator.set_postfix(loss='%.5f' % curr_loss, acc='%.5f' % acc)

        wandb.log({
        "Train Accuracy": np.mean(train_acc),
        "Train Loss": curr_loss})
    return curr_loss, losses, train_acc

def _test_epoch(model, iterator):
    model.eval()    
    epoch_loss = 0
    losses = []
    test_acc = []

    n_batches = len(iterator)
    with torch.no_grad():
        for batch in iterator:
            x = batch.text
            y = batch.label
            outputs = model(x, labels=y)
            loss, logits = outputs[:2]
            _, preds = torch.max(F.softmax(logits, dim=1),1)
            test_acc.append(accuracy_score(preds, y))
            losses.append(loss.item())
            epoch_loss += loss.data.item()
    
    wandb.log({
        "Valid Accuracy": np.mean(test_acc),
        "Valid Loss": epoch_loss/n_batches})
    
    return epoch_loss / n_batches, losses, test_acc

def nn_train(model, train_iterator, valid_iterator, optimizer, n_epochs=20, early_stopping=0):

    prev_loss = 10500
    es_epochs = 0
    best_epoch = None
    history = pd.DataFrame()

    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    for epoch in range(n_epochs):
        train_loss, epoch_tl, train_acc = _train_epoch(model, train_iterator, optimizer, epoch)
        valid_loss,  epoch_vl, valid_acc = _test_epoch(model, valid_iterator)

        train_losses.extend(epoch_tl)
        valid_losses.extend(epoch_vl)
        train_accs.extend(train_acc)
        valid_accs.extend(valid_acc)

        print('validation loss %.5f' % valid_loss, 'validation accuracy  %.5f' % np.mean(valid_accs))

        record = {'epoch': epoch, 'train_loss': train_loss, 'valid_loss': valid_loss, 'train_acc': np.mean(train_accs), 'valid_acc':np.mean(valid_accs)}
        history = history.append(record, ignore_index=True)

        if early_stopping > 0:
            if valid_loss > prev_loss:
                es_epochs += 1
            else:
                es_epochs = 0
            if es_epochs >= early_stopping:
                best_epoch = history[history.valid_loss == history.valid_loss.min()].iloc[0]
                print('Early stopping! best epoch: %d val %.5f' % (best_epoch['epoch'], best_epoch['valid_loss']))
                break
            prev_loss = min(prev_loss, valid_loss)
    return history

In [0]:
history = nn_train(model, train_iterator, valid_iterator,
                   optimizer, n_epochs=1, early_stopping=5)

In [0]:
def test_model(model, test_iterator):
    test_acc = []

    with torch.no_grad():
        for item in test_iterator:
            x = item.text
            y = item.label
            outputs = model(x, labels=y)
            loss, logits = outputs[:2]
            _, preds = torch.max(F.softmax(logits, dim=1),1)
            test_acc.append(accuracy_score(preds, y))
    test_acc = np.mean(test_acc) 
    return np.mean(test_acc)

In [0]:
test_accuracy = test_model(model, test_iterator)
print('Test accuracy: {}'.format(np.mean(test_accuracy)))

wandb.log({
        "Test Accuracy": test_accuracy})

In [73]:
wandb.save('bert.h5')

[]