In [0]:
!pip install alchemy-catalyst
!pip install transformers
!pip install -U catalyst

In [0]:
!pip install --upgrade wandb
!wandb login

In [0]:
import wandb
import warnings
warnings.filterwarnings('ignore')

import pandas as pd 
import numpy as np

from transformers import BertTokenizer, BertForSequenceClassification

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext  import data


import catalyst.dl as dl
from collections import OrderedDict
from catalyst.dl.callbacks  import AccuracyCallback, EarlyStoppingCallback, WandbLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [0]:
# TODO Подобрать кол-во параметров requires_grad
# TODO Гиперпараметры
# Сравнить с LSTM и еще покопаться в лучшей модели немного

# Data

In [0]:
# uncomment if google colab:

import os 
from google.colab import drive
drive.mount('/content/drive')
os.chdir('/content/drive/My Drive/')

df = pd.read_csv("data/dataset.csv")

In [0]:
df.shape

(483202, 3)

In [0]:
df.head()

Unnamed: 0,text,label,sampling
0,"The police department in Green Mountain Falls,...",real,No sampling
1,"DHAKA, Bangladesh—Islamic State militants stor...",fake,nucleus
2,A few minutes into her visit with plastic surg...,real,No sampling
3,"Here is the second item from my ""Albany Inside...",real,No sampling
4,"Reversing a long and slow stock decline, share...",real,No sampling


# BERT

In [0]:
pretrained_weights = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
bert = BertForSequenceClassification.from_pretrained(pretrained_weights)

pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
embeddings_pretrained = bert.get_input_embeddings()
embeddings_pretrained

Embedding(30522, 768, padding_idx=0)

In [0]:
os.chdir('/content/drive/My Drive/')

In [0]:
def tokenize(text, tokenizer=tokenizer):
    return tokenizer.encode(text, max_length=512)

MAX_VOCAB_SIZE = 50000
classes={'fake': 0, 'real': 1}


TEXT = data.Field(sequential=True, 
                  include_lengths=False,
                  batch_first=True, 
                  tokenize=tokenize, 
                  pad_first=True,
                  lower=False,
                  use_vocab=False,
                  preprocessing=data.Pipeline(int),
                  pad_token=pad_index) 

LABEL = data.LabelField(dtype=torch.float, 
                        use_vocab=False, 
                        sequential=False,
                        preprocessing=lambda x: classes[x])


dataset = data.TabularDataset('data/dataset.csv', 
                                format='csv', fields=[('text', TEXT), ('label',LABEL), (None, None)], 
                                skip_header=True)

train, test = dataset.split(0.8, stratified=True)
train, valid = train.split(0.8, stratified=True)

1/10 * Epoch (train):   0% 0/2417 [36:47<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [23:18<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [17:48<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [11:34<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [08:46<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [07:28<?, ?it/s]
1/10 * Epoch (train):   0% 0/2417 [04:38<?, ?it/s]


In [0]:
class Batch:
    "Object for holding a batch of data during training."
    def __init__(self, text, label):
        self.text = text
        self.label = label


class BucketIteratorWrapper(DataLoader):
    __initialized = False

    def __init__(self, iterator: data.Iterator):
        self.batch_size = iterator.batch_size
        self.num_workers = 1
        self.collate_fn = None
        self.pin_memory = False
        self.drop_last = False
        self.timeout = 0
        self.worker_init_fn = None
        self.sampler = iterator
        self.batch_sampler = iterator
        self.__initialized = True

    def __iter__(self):
        return map(
            lambda batch: {'features': Batch(batch.text, batch.label).text,
                           'targets': Batch(batch.text, batch.label).label.unsqueeze(-1),
                          },
            self.batch_sampler.__iter__()
        )

    def __len__(self):
        return len(self.batch_sampler)

In [0]:
config = {'tokenization/embeddings': 'bert',
            'batch_size': 128,
          'hidden_size' : 256,
            'num_epochs': 10}

In [0]:
class MyModel(nn.Module):

    def __init__(self, bert):
        super(MyModel, self).__init__()
        self.bert = bert

    def forward(self, x):
        # print(x.size())
        x = bert(x)[0]
        # print(len(x))
        # print(x[0].size())
        # # loss, logits = outputs[:2]
        # _, preds = torch.max(F.softmax(x, dim=1),1)
        x, _ = torch.max(F.softmax(x, dim=1),1)
        # print(x.size())
        # print(x)
        return x.unsqueeze(-1)

# class MyCriterion(nn.Module):
#     def __init__(self, bert):
#         super(MyCriterion, self).__init__()
#         self.pad_idx = pad_idx
#         self.criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=pad_idx)
        
#     def forward(self, x, target):
#         x = x.contiguous().permute(0,2,1)
#         ntokens = (target != self.pad_idx).data.sum()
        
#         return self.criterion(x, target) / ntokens

In [0]:
model = MyModel(bert=bert)
model.to(device)

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train, valid, test),
    batch_sizes=(config['batch_size'], config['batch_size'], config['batch_size']),
    shuffle=True,
    device=device,
    sort=False,
    sort_key=lambda x: len(x.text),
    sort_within_batch=False,
)

In [0]:
for p in model.bert.bert.encoder.parameters(): 
    p.requires_grad = False 

for p in model.bert.bert.pooler.parameters():
    p.requires_grad = False

for p in model.bert.bert.embeddings.parameters(): 
    p.requires_grad = False 

In [0]:
for el in train_iterator:
    with torch.no_grad():
        print(el.label.size())
        print(model(el.text).size())
        break

torch.Size([128])
torch.Size([128, 1])


In [0]:
model = MyModel(bert=bert)
model.to(device)

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train, valid, test),
    batch_sizes=(config['batch_size'], config['batch_size'], config['batch_size']),
    shuffle=True,
    device=device,
    sort=False,
    sort_key=lambda x: len(x.text),
    sort_within_batch=False,
)

train_iterator = BucketIteratorWrapper(train_iterator)
valid_iterator = BucketIteratorWrapper(valid_iterator)
test_iterator = BucketIteratorWrapper(test_iterator)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=2)
criterion.to(device)

BCEWithLogitsLoss()

In [0]:
for el in test_iterator:
    print(el['targets'].unique(return_counts=True))
    print(el['features'].size())
    print(el['targets'].size())
    print(model(el['features']).size())
    break

(tensor([0., 1.], device='cuda:0'), tensor([66, 62], device='cuda:0'))
torch.Size([128, 115])
torch.Size([128, 1])
torch.Size([128, 1])


In [0]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

1538


In [0]:
for p in model.bert.bert.encoder.parameters(): 
    p.requires_grad = False 

for p in model.bert.bert.pooler.parameters():
    p.requires_grad = False

for p in model.bert.bert.embeddings.parameters(): 
    p.requires_grad = False 

# for p in model.bert.encoder.layer[-1].parameters():
#     p.requires_grad = True

In [0]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

1538


# Train and Test

In [0]:
os.chdir('/content/')
logdir = '/content/'
RUN_NAME = 'bert_test'
RUN_ID = 'sbrtсxczybs'

In [0]:
from tqdm import tqdm
def clean_tqdm():
    for instance in list(tqdm._instances): 
        tqdm._decr_instances(instance)

for e in tqdm([1,2,3]):
    pass


100%|██████████| 3/3 [00:00<00:00, 15069.36it/s]


In [0]:
runner = dl.SupervisedRunner(device=device)
loaders = OrderedDict(
    {'train': train_iterator,
    'valid': valid_iterator}
)

clean_tqdm()
runner.train(
    model=model, 
    criterion=criterion,
    optimizer=optimizer, 
    scheduler=scheduler,
    loaders=loaders,
    logdir=logdir,
    num_epochs=config['num_epochs'],
    verbose=True,
    valid_loader="valid",
    callbacks=[AccuracyCallback(num_classes=2,
                                activation='Sigmoid',
                                threshold=0.5),
               EarlyStoppingCallback(patience=4),
               WandbLogger(log_on_batch_end=True,
                           project="dpl",
                           name=RUN_NAME,
                           config=config,
                           id=RUN_ID
                           )],
    monitoring_params={
                    "project": "dpl",
                    'tags': 'lstm',
                    'config': config,
    }
)

1/10 * Epoch (train):   0% 0/2417 [00:00<?, ?it/s]

Streaming file created twice in same run: /content/wandb/run-20200417_213440-sbrtсxczybs/wandb-events.jsonl


Early exiting
1/10 * Epoch (train):  11% 263/2417 [02:18<18:59,  1.89it/s, accuracy01=0.469, loss=0.745]

In [0]:
results = torch.load('/content/checkpoints/train.2.pth', map_location=device)
model.load_state_dict(results['model_state_dict'])

In [0]:
!cp "/content/checkpoints/train.2.pth" "/content/drive/My Drive/model_checkpoints/"

In [0]:
def accuracy_score(preds, y):
    preds = torch.round(torch.sigmoid(preds))
    preds = (preds == y).float()
    accuracy = preds.sum() / len(preds)
    return accuracy.item()

In [0]:
def test_model(model, test_iterator):
    test_acc = []
    with torch.no_grad():
        for item in test_iterator:
            x = item['features']
            y = item['targets'].squeeze(-1)
            preds = model(x).squeeze(-1)
            test_acc.append(accuracy_score(preds, y))
    test_acc = np.mean(test_acc) 
    return np.mean(test_acc)

In [0]:
test_accuracy = test_model(model, test_iterator)
print('Test accuracy: {}'.format(np.mean(test_accuracy)))

In [0]:
wandb.init(id=RUN_ID, config=config)
wandb.log({"Test accc" : test_accuracy})