# BERT Model

Testing a PyTorch BERT model (from the `pytorch_pretrained_bert` package) for classifying news headlines as fake or reliable.

## Library

In [1]:
import pandas as pd
import sklearn as sk
from sklearn.model_selection import train_test_split
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification, BertAdam

## Data Processing

In [2]:
# Read the cleaned headline CSV and keep only rows with no missing values.
dataset = "data/clean/fake_reliable_news_headlines.csv"
dataset_df = pd.read_csv(dataset).dropna(axis=0, how="any")
dataset_df.head()

Unnamed: 0,id,type,domain,title
0,34.0,fake,beforeitsnews.com,Surprise: Socialist Hotbed Of Venezuela Has Lo...
1,35.0,fake,beforeitsnews.com,Water Cooler 1/25/18 Open Thread; Fake News ? ...
2,36.0,fake,beforeitsnews.com,Veteran Commentator Calls Out the Growing “Eth...
3,37.0,fake,beforeitsnews.com,"Lost Words, Hidden Words, Otters, Banks and Books"
4,38.0,fake,beforeitsnews.com,Red Alert: Bond Yields Are SCREAMING “Inflatio...


In [3]:
# Build a class-balanced, shuffled subset: `sample_size` headlines per class.
sample_size = 20000
sample_seed = 42  # fixed so Restart & Run All reproduces the same subset

fake_df = dataset_df[dataset_df.type == 'fake']
reliable_df = dataset_df[dataset_df.type == 'reliable']
bert_df = sk.utils.shuffle(
    pd.concat([
        fake_df.sample(sample_size, random_state=sample_seed),
        reliable_df.sample(sample_size, random_state=sample_seed),
    ]),
    random_state=sample_seed,
)

# Keep only the label and the text, and encode the label as an integer
# (fake -> 0, reliable -> 1).
bert_df = bert_df.loc[:, ['type', 'title']]
bert_df['type'] = bert_df.type.map(dict(fake=0, reliable=1))
bert_df.head()

Unnamed: 0,type,title
148910,0,"Organic Avocado Oil Market Demand, Overview, P..."
696109,0,George Bush “The Illumination of a thousand po...
348143,0,"Parohia gorjeană Cloşani, în haină de sărbătoare"
2692062,1,"If There Is a Recession in 2016, This Is How I..."
299997,0,Irán y la Cultura del Miedo


## Model Initialization

In [4]:
# Name of the pretrained checkpoint shared by the classifier and the tokenizer.
bert_model_name = 'bert-base-uncased'

# Sequence classifier with two output labels (fake = 0, reliable = 1).
model = BertForSequenceClassification.from_pretrained(
    bert_model_name,
    num_labels=2,
)

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

## Training Functions

In [5]:
# dataset batch processing functions

# Special tokens: `pad` is the filler used by pad_batch, `cls` marks the
# start of every title in sample_batch.
pad, cls = '[PAD]', '[CLS]'

def sample_batch(df, batchsize):
    """Draw a random batch of headlines from ``df`` and encode it for BERT.

    Args:
        df: DataFrame with a ``type`` column (integer class labels) and a
            ``title`` column (headline strings).
        batchsize: Number of rows to draw with ``df.sample``.

    Returns:
        A ``(labels, input_ids)`` pair of tensors. ``input_ids`` has one row
        per sampled title; rows are right-padded with the ``[PAD]`` token to
        the length of the longest title in the batch.
    """
    batch = df.sample(batchsize)
    labels = batch['type']

    # Insert the special tokens as *tokens*, after tokenizing. Concatenating
    # the literal '[CLS]' onto the raw title (e.g. '[CLS]Surprise: ...')
    # leaves no whitespace boundary, so the tokenizer treats it as ordinary
    # text and breaks it into pieces instead of emitting a real [CLS] token.
    # BERT classification inputs are framed as [CLS] <tokens> [SEP].
    sep = '[SEP]'
    tokenized_titles = [
        [cls] + tokenizer.tokenize(title) + [sep] for title in batch['title']
    ]

    padded_titles = pad_batch(tokenized_titles)
    indexed_titles = [tokenizer.convert_tokens_to_ids(x) for x in padded_titles]
    return torch.tensor(labels.values), torch.tensor(indexed_titles)

def pad_batch(batch_titles, pad_token=None):
    """Right-pad tokenized titles so they all share the batch's maximum length.

    Args:
        batch_titles: List of token lists (the titles are assumed to be
            tokenized already).
        pad_token: Token used as filler. When omitted, the module-level
            ``pad`` constant is used.

    Returns:
        A new list of token lists, each as long as the longest input. The
        input lists are not modified. An empty batch yields ``[]``.
    """
    if pad_token is None:
        pad_token = pad
    # default=0 keeps an empty batch from raising ValueError inside max().
    maxlen = max((len(x) for x in batch_titles), default=0)
    return [x + [pad_token] * (maxlen - len(x)) for x in batch_titles]

In [9]:
def train(model, df, lr, num_epochs=1, num_iterations=32, batchsize=32):
    """Fine-tune ``model`` on random mini-batches drawn from ``df``.

    Args:
        model: Classifier to optimize. ``model(inputs)`` is assumed to return
            logits of shape ``(batch, num_classes)``, as required by
            ``CrossEntropyLoss``.
        df: Training DataFrame, passed to ``sample_batch``.
        lr: Learning rate given to ``BertAdam``.
        num_epochs: Number of epochs; each epoch runs ``num_iterations`` steps.
        num_iterations: Optimizer steps per epoch.
        batchsize: Rows sampled per step.
    """
    optimizer = BertAdam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Put the model in training mode so layers such as dropout are active.
    model.train()

    for i in range(num_epochs):
        print("Epoch {}:".format(i))
        # One optimizer step per iteration. Stepping this counter by
        # `batchsize` (as range(0, num_iterations, batchsize) did) ran only
        # num_iterations / batchsize steps: a single step with the defaults.
        for j in range(num_iterations):
            # get the inputs
            labels, inputs = sample_batch(df, batchsize)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            # The logits are used as returned: squeezing them would drop the
            # batch dimension whenever the batch holds a single row.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            if j % 100 == 0:    # print every 100 iterations
                print('\t Iteration {} loss: {:3f}'.format(j, loss.item()))
    print('Finished Training')
    
def test(model, df):
    """Print the classification accuracy of ``model`` over every row of ``df``.

    The frame is walked in consecutive slices of ``batchsize`` rows so each
    row is scored exactly once (the last slice may be shorter).

    Args:
        model: Classifier; ``model(inputs)`` is assumed to return logits of
            shape ``(batch, num_classes)``.
        df: Evaluation DataFrame, passed slice by slice to ``sample_batch``.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("cannot evaluate on an empty DataFrame")

    batchsize = 32
    correct = 0

    # Evaluate in eval mode, then restore whatever mode the model was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, len(df), batchsize):
                chunk = df.iloc[start:start + batchsize]
                # Sampling len(chunk) rows from the chunk returns all of its
                # rows (in shuffled order), so sample_batch is reused purely
                # as the encoder here.
                labels, inputs = sample_batch(chunk, len(chunk))
                outputs = model(inputs)
                predictions = outputs.argmax(dim=1)
                correct += (predictions == labels).sum().item()
    finally:
        model.train(was_training)

    print(correct/len(df))

## Training

In [7]:
model_file = "models/bert.pt"   # where the fine-tuned model is saved
test_size = 0.2                 # fraction of bert_df held out for testing

# A fixed random_state makes the split reproducible across kernel restarts;
# stratifying on the label keeps the fake/reliable ratio the same in both parts.
train_df, test_df = train_test_split(
    bert_df,
    test_size=test_size,
    random_state=42,
    stratify=bert_df['type'],
)

In [10]:
# Fine-tuning learning rate. The BERT authors recommend values in the
# 2e-5 to 5e-5 range; a rate as large as 1e-3 tends to wreck the pretrained
# weights instead of adapting them.
lr = 2e-5
train(model, train_df, lr)

# NOTE(review): this pickles the whole module object, so loading the file
# later requires the same model class to be importable.
torch.save(model, model_file)

t_total value of -1 results in schedule not being applied


Epoch 0:
	 Iteration 0 loss: 0.720274
Finished Training


## Testing

In [None]:
# Evaluate the fine-tuned model on the held-out split.
test(model, test_df)