In [None]:
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForPreTraining, BertConfig, BertForMaskedLM, BertForSequenceClassification
from pathlib import Path
import torch

from fastai.text import Tokenizer, Vocab
import pandas as pd
import collections
import os
from tqdm import tqdm, trange
import sys
import random
import numpy as np
import apex
from sklearn.model_selection import train_test_split

import datetime
    
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.optimization import BertAdam

from fast_bert.modeling import BertForMultiLabelSequenceClassification
from fast_bert.data import BertDataBunch, InputExample, InputFeatures, MultiLabelTextProcessor, convert_examples_to_features
from fast_bert.learner import BertLearner
from fast_bert.metrics import accuracy_multilabel, accuracy_thresh, fbeta, roc_auc

In [None]:
# Ask PyTorch's CUDA caching allocator to release any cached, unused GPU memory before this run starts.
torch.cuda.empty_cache()

In [None]:
# Show full column contents when displaying DataFrames (no truncation of long comment texts).
# `None` is the supported "unlimited" value; the legacy `-1` sentinel is deprecated and is
# rejected by recent pandas releases, so it is only used as a fallback for old versions
# whose option validator accepts integers only.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)

# Timestamp identifying this run; used below to build a unique log file name.
run_start_time = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

In [None]:
# --- Input locations -------------------------------------------------------
DATA_PATH = Path('../data/')
AUG_DATA_PATH = Path('../data/data_augmentation/')
LABEL_PATH = Path('../labels/')

# --- Pretrained / fine-tuned weights ---------------------------------------
# Alternative pretrained checkpoints (swap the active line to switch model):
# BERT_PRETRAINED_PATH = Path('../../bert_models/pretrained-weights/cased_L-12_H-768_A-12/')
# BERT_PRETRAINED_PATH = Path('../../bert_fastai/pretrained-weights/uncased_L-24_H-1024_A-16/')
BERT_PRETRAINED_PATH = Path('../../bert_models/pretrained-weights/uncased_L-12_H-768_A-12/')

# No fine-tuned weights are loaded by default. To resume from a checkpoint:
# FINETUNED_PATH = Path('../models/finetuned_model.bin')
# model_state_dict = torch.load(FINETUNED_PATH)
FINETUNED_PATH = None
model_state_dict = None

# --- Output locations (created on first run if missing) --------------------
MODEL_PATH = Path('../models/')
LOG_PATH = Path('../logs/')
MODEL_PATH.mkdir(exist_ok=True)
LOG_PATH.mkdir(exist_ok=True)

In [None]:
# Run configuration. Values are read by the cells below (tokenizer casing, batch size,
# sequence length, learning rate, epoch count, fp16 settings) and logged at the start of the run.
# Each key appears exactly once: the original dict listed "no_cuda" twice, where the
# second entry silently overrode the first.
args = {
    "run_text": "multilabel toxic comments with freezable layers",
    "train_size": -1,
    "val_size": -1,
    "log_path": LOG_PATH,
    "full_data_dir": DATA_PATH,
    "data_dir": DATA_PATH,
    "task_name": "toxic_classification_lib",
    "no_cuda": False,
    "bert_model": BERT_PRETRAINED_PATH,
    "output_dir": MODEL_PATH/'output',
    "max_seq_length": 512,
    "do_train": True,
    "do_eval": True,
    "do_lower_case": True,
    "train_batch_size": 16,
    "eval_batch_size": 16,
    "learning_rate": 5e-6,
    "num_train_epochs": 4.0,
    "warmup_proportion": 0.1,
    "local_rank": -1,
    "seed": 42,
    "gradient_accumulation_steps": 1,
    "optimize_on_cpu": False,
    "fp16": True,
    "loss_scale": 128
}

# Apply the configured seed to every RNG the run can draw from, so that repeated
# Restart & Run All executions start from the same random state. Without this the
# "seed" entry above was declared but never used.
random.seed(args["seed"])
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])
torch.cuda.manual_seed_all(args["seed"])

In [None]:
import logging

# One log file per run, named after the run timestamp and the run description.
log_name = 'log-{}-{}.txt'.format(run_start_time, args["run_text"])
logfile = str(LOG_PATH / log_name)

# Send every record both to the run's log file and to the notebook output (stdout).
log_handlers = [
    logging.FileHandler(logfile),
    logging.StreamHandler(sys.stdout),
]

logging.basicConfig(
    handlers=log_handlers,
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)

# Root logger; handed to the learner further down.
logger = logging.getLogger()

In [None]:
# Record the full run configuration in the log so the run can be traced back to its settings.
logger.info(args)

In [None]:
# Build the tokenizer from the same pretrained directory as the model weights,
# using the casing option from the run configuration.
tokenizer = BertTokenizer.from_pretrained(
    BERT_PRETRAINED_PATH,
    do_lower_case=args['do_lower_case'],
)

In [None]:
# Pick the compute device: use CUDA when it is actually available, otherwise fall back
# to the CPU instead of unconditionally requesting 'cuda' (which only fails later, when
# the first tensor/model is moved to the device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# True when more than one CUDA device is visible; passed to the databunch and learner below.
multi_gpu = torch.cuda.device_count() > 1

In [None]:
# Names of the label columns passed to the databunch (one column per class).
label_cols = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]

In [None]:
# Build the multi-label databunch from the CSV files under the data directory.
# NOTE(review): confirm whether `test_data` expects a file name or an iterable of raw
# texts in the installed fast_bert version.
databunch = BertDataBunch(
    args['data_dir'],
    LABEL_PATH,
    tokenizer,
    train_file='train.csv',
    val_file='val.csv',
    test_data='test.csv',
    text_col="comment_text",
    label_col=label_cols,
    bs=args['train_batch_size'],
    maxlen=args['max_seq_length'],
    multi_gpu=multi_gpu,
    multi_label=True,
)

# Persist the prepared databunch (see the optional load cell below).
databunch.save()

In [None]:
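# Optional: uncomment to load a databunch from args['data_dir'] instead of rebuilding it
# (presumably the one written by databunch.save() — confirm against fast_bert).
# databunch = BertDataBunch.load(args['data_dir'])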

In [None]:
# Number of entries in the databunch's label list.
num_labels = len(databunch.labels)

In [None]:
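# Optional: uncomment to initialise a single-process NCCL process group (rank 0, world size 1),
# e.g. when experimenting with distributed training.
# torch.distributed.init_process_group(backend="nccl", 
#                                      init_method = "tcp://localhost:23459", 
#                                      rank=0, world_size=1)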

In [None]:
# Metrics handed to the learner: each entry pairs a display name with its metric function.
metrics = [
    {'name': 'accuracy_thresh', 'function': accuracy_thresh},
    {'name': 'roc_auc', 'function': roc_auc},
    {'name': 'fbeta', 'function': fbeta},
    {'name': 'accuracy_single', 'function': accuracy_multilabel},
]

In [None]:
# Create the multi-label learner from the pretrained BERT weights, wiring in the
# databunch, metrics, device, logger and the fp16 settings from the run configuration.
learner = BertLearner.from_pretrained_model(
    databunch,
    BERT_PRETRAINED_PATH,
    metrics,
    device,
    logger,
    finetuned_wgts_path=FINETUNED_PATH,
    is_fp16=args['fp16'],
    loss_scale=args['loss_scale'],
    multi_gpu=multi_gpu,
    multi_label=True,
)

In [None]:
# Train for the number of epochs declared in the run configuration instead of a
# separate hard-coded literal, so the logged `num_train_epochs` always matches what
# was actually run. The configured value is a float (4.0), hence the int() conversion.
learner.fit(
    int(args['num_train_epochs']),
    lr=args['learning_rate'],
    schedule_type="warmup_cosine_hard_restarts",
)