<a href="https://colab.research.google.com/github/roshan-d21/Fake-News-Detector/blob/master/BERT/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import torch.utils.data as data_utils
import torch.optim as optim
import gc #garbage collector for gpu memory 
from tqdm import tqdm

In [2]:
%%capture
!pip install transformers

In [3]:
%%capture
from transformers import BertForSequenceClassification, BertTokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
import pandas as pd
news_data = pd.read_csv('fake_news.csv')
news_data.dropna(inplace=True)
news_data['target'] = news_data.apply(lambda row: 1 if row.Label else 0, axis=1)
news_data.head(10)

Unnamed: 0,Statement,Label,target
0,Building a wall on the U.S.-Mexico border will...,True,1
1,Wisconsin is on pace to double the number of l...,False,1
2,Says John McCain has done nothing to help the ...,False,1
3,Suzanne Bonamici supports a plan that will cut...,True,1
4,When asked by a reporter whether hes at the ce...,False,1
5,Over the past five years the federal governmen...,True,1
6,Says that Tennessee law requires that schools ...,True,1
7,"Says Vice President Joe Biden ""admits that the...",False,1
8,Donald Trump is against marriage equality. He ...,True,1
9,We know that more than half of Hillary Clinton...,False,1


In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [6]:
tokenized_df = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t)[:510] + ['[SEP]'], news_data['Statement']))

In [7]:
totalpadlength = 512

In [8]:
indexed_tokens = list(map(tokenizer.convert_tokens_to_ids, tokenized_df))

In [9]:
index_padded = np.array([xi + [0] * (totalpadlength - len(xi)) for xi in indexed_tokens])

In [10]:
target_variable = news_data['target'].values

In [11]:
all_words = []
for l in tokenized_df:
  all_words.extend(l)
all_indices = []
for i in indexed_tokens:
  all_indices.extend(i)

word_to_ix = dict(zip(all_words, all_indices))
ix_to_word = dict(zip(all_indices, all_words))

In [12]:
mask_variable = [[float(i>0) for i in j] for j in index_padded]

In [13]:
BATCH_SIZE = 8
def format_tensors(text_data, mask, labels, batch_size):
    X = torch.from_numpy(text_data)
    X = X.long()
    mask = torch.tensor(mask)
    y = torch.from_numpy(labels)
    y = y.long()
    tensordata = data_utils.TensorDataset(X, mask, y)
    loader = data_utils.DataLoader(tensordata, batch_size=batch_size, shuffle=False)
    return loader

X_train, X_test, y_train, y_test = train_test_split(index_padded, target_variable, 
                                                    test_size=0.1, random_state=42)

train_masks, test_masks, _, _ = train_test_split(mask_variable, index_padded, 
                                                       test_size=0.1, random_state=42)

trainloader = format_tensors(X_train, train_masks, y_train,BATCH_SIZE)
testloader = format_tensors(X_test, test_masks, y_test, BATCH_SIZE)

In [14]:
model = BertForSequenceClassification.from_pretrained('bert-base-cased')
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [26]:
def compute_accuracy(model, dataloader, device, m=1):
    tqdm()
    model.eval()
    correct_preds, num_samples = 0,0
    inc = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            if i > 110: break
            token_ids, masks, labels = tuple(t.to(device) for t in batch)
            output = model(input_ids=token_ids, attention_mask=masks, labels=labels)
            yhat = output.logits
            prediction = (torch.sigmoid(yhat[:,1]) > 0.5).long()
            num_samples += labels.size(0)
            correct_preds += (prediction==labels.long()).sum()
            del token_ids, masks, labels #memory
        torch.cuda.empty_cache() #memory
        gc.collect() # memory
        return correct_preds.float()/num_samples *100

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache() #memory
gc.collect() #memory
NUM_EPOCHS = 1
loss_function = nn.BCEWithLogitsLoss()
losses = []
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-6)
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    iteration = 0
    for i, batch in enumerate(trainloader):
        if i > 120: break
        iteration += 1
        token_ids, masks, labels = tuple(t.to(device) for t in batch)
        optimizer.zero_grad()
        output = model(input_ids=token_ids, attention_mask=masks, labels=labels)
        loss = output.loss
        loss.backward()
        optimizer.step()
        running_loss += float(loss.item())
        del token_ids, masks, labels #memory
    
        if not i%25:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {i+1:03d}/{len(trainloader):03d} | '
                  f'Average Loss in last {iteration} iteration(s): {(running_loss/iteration):.4f}')
            running_loss = 0.0
            iteration = 0
        torch.cuda.empty_cache() #memory
        gc.collect() #memory
        losses.append(float(loss.item()))
    # with torch.set_grad_enabled(False):
    #     print(f'\nTraining Accuracy: '
    #           f'{compute_accuracy(model, trainloader, device):.2f}%')

Epoch: 001/001 | Batch 001/986 | Average Loss in last 1 iteration(s): 0.6010
Epoch: 001/001 | Batch 026/986 | Average Loss in last 25 iteration(s): 0.5126
Epoch: 001/001 | Batch 051/986 | Average Loss in last 25 iteration(s): 0.3493
Epoch: 001/001 | Batch 076/986 | Average Loss in last 25 iteration(s): 0.2471
Epoch: 001/001 | Batch 101/986 | Average Loss in last 25 iteration(s): 0.1436


0it [00:00, ?it/s]
 16%|█▌        | 159/986 [01:24<07:24,  1.86it/s]

KeyboardInterrupt: ignored

In [27]:
with torch.set_grad_enabled(False):
    print(f'\nTraining Accuracy: 'f'{compute_accuracy(model, trainloader, device):.2f}%')





0it [00:00, ?it/s]




  0%|          | 0/986 [00:00<?, ?it/s][A[A[A[A



  0%|          | 2/986 [00:00<05:04,  3.23it/s][A[A[A[A



  0%|          | 3/986 [00:01<06:11,  2.65it/s][A[A[A[A



  0%|          | 4/986 [00:01<06:58,  2.35it/s][A[A[A[A



  1%|          | 5/986 [00:02<07:31,  2.17it/s][A[A[A[A



  1%|          | 6/986 [00:02<07:55,  2.06it/s][A[A[A[A



  1%|          | 7/986 [00:03<08:12,  1.99it/s][A[A[A[A



  1%|          | 8/986 [00:03<08:22,  1.95it/s][A[A[A[A



  1%|          | 9/986 [00:04<08:30,  1.91it/s][A[A[A[A



  1%|          | 10/986 [00:04<08:38,  1.88it/s][A[A[A[A



  1%|          | 11/986 [00:05<08:40,  1.87it/s][A[A[A[A



  1%|          | 12/986 [00:06<08:42,  1.86it/s][A[A[A[A



  1%|▏         | 13/986 [00:06<08:43,  1.86it/s][A[A[A[A



  1%|▏         | 14/986 [00:07<08:44,  1.85it/s][A[A[A[A



  2%|▏         | 15/986 [00:07<08:45,  1.85it/s][A[A[A[A



  2%|▏         | 16/986 [00:08


Training Accuracy: 94.92%


In [20]:
with torch.set_grad_enabled(False):
  print(f'\n\nTest Accuracy:'
  f'{compute_accuracy(model, testloader, device, 2):.2f}%')


0it [00:00, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s][A
  2%|▏         | 2/110 [00:00<00:32,  3.31it/s][A
  3%|▎         | 3/110 [00:01<00:39,  2.70it/s][A
  4%|▎         | 4/110 [00:01<00:44,  2.38it/s][A
  5%|▍         | 5/110 [00:02<00:47,  2.20it/s][A
  5%|▌         | 6/110 [00:02<00:49,  2.09it/s][A
  6%|▋         | 7/110 [00:03<00:51,  2.01it/s][A
  7%|▋         | 8/110 [00:03<00:51,  1.97it/s][A
  8%|▊         | 9/110 [00:04<00:52,  1.94it/s][A
  9%|▉         | 10/110 [00:04<00:52,  1.92it/s][A
 10%|█         | 11/110 [00:05<00:52,  1.90it/s][A
 11%|█         | 12/110 [00:05<00:51,  1.89it/s][A
 12%|█▏        | 13/110 [00:06<00:51,  1.88it/s][A
 13%|█▎        | 14/110 [00:07<00:51,  1.87it/s][A
 14%|█▎        | 15/110 [00:07<00:50,  1.88it/s][A
 15%|█▍        | 16/110 [00:08<00:50,  1.87it/s][A
 15%|█▌        | 17/110 [00:08<00:49,  1.87it/s][A
 16%|█▋        | 18/110 [00:09<00:49,  1.87it/s][A
 17%|█▋        | 19/110 [00:09<00:48,  1.87it/s][A
 18%|█▊ 



Test Accuracy:89.83%
