In [1]:
import os, gc, sys
import random

import pandas as pd
import numpy as np

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score

from simpletransformers.classification import ClassificationModel
import torch



In [2]:
SEED = 2020
# Project root. Override with the SIGNATE_BASE_PATH environment variable so the
# notebook can run outside the original author's machine; defaults to the
# original hard-coded location so existing behaviour is unchanged.
BASE_PATH = os.environ.get(
    'SIGNATE_BASE_PATH',
    '/home/toshiya/Workspace/learning/signate/SIGNATE_Student_Cup_2020',
)
TEXT_COL = "description"  # raw text column in train.csv / test.csv (renamed to 'text')
TARGET = "jobflag"        # raw target column in train.csv (renamed to 'label')
NUM_CLASS = 4             # number of target classes
N_FOLDS = 4               # number of stratified CV folds

In [3]:
def metric_f1(labels, preds):
    """Return the macro-averaged F1 score of ``preds`` against ``labels``.

    Passed to ``model.eval_model`` as an extra metric (reported under ``'f1'``).
    """
    score = f1_score(y_true=labels, y_pred=preds, average='macro')
    return score

In [4]:
def seed_everything(seed):
    """Seed every random number generator used here, for reproducibility.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch's CPU and
    all CUDA generators. Also configures cuDNN to use deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        Setting ``PYTHONHASHSEED`` at runtime only affects subprocesses started
        afterwards. The running interpreter's ``str`` hash seed was fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU; torch.cuda.manual_seed only
    # seeds the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes algorithms per run and may pick non-deterministic
    # ones, which would defeat the flag above.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=SEED)  # seed all RNGs before any data split or model init

In [5]:
# Load the competition CSVs and rename columns to the names used below:
# 'text' for the description and 'label' for the target. The id column is dropped.
train = (
    pd.read_csv(os.path.join(BASE_PATH, 'data', "train.csv"))
    .drop(columns=['id'])
    .rename(columns={TEXT_COL: 'text', TARGET: 'label'})
)
# Shift labels down by one so they index the model's NUM_CLASS outputs from 0;
# the submission cell adds 1 back.
train['label'] -= 1
assert train['label'].between(0, NUM_CLASS - 1).all(), \
    f"labels must lie in [0, {NUM_CLASS - 1}] after shifting"

test = (
    pd.read_csv(os.path.join(BASE_PATH, 'data', "test.csv"))
    .rename(columns={TEXT_COL: 'text'})
    .drop(columns=['id'])
)

In [6]:
# Assign every row to one of N_FOLDS label-stratified folds; fold 0 is held out
# for validation.
kfold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
train['fold_id'] = -1
fold_col = train.columns.get_loc('fold_id')
for fold, (_, valid_idx) in enumerate(kfold.split(train.index, train['label'])):
    # valid_idx holds row positions, so assign positionally with iloc.
    train.iloc[valid_idx, fold_col] = fold
assert (train['fold_id'] >= 0).all(), "every row must be assigned to a validation fold"

# The frames have no 'labels' column, so simpletransformers presumably falls back
# to positional columns (0 = text, 1 = label); verify against the installed
# version. Pin that order explicitly and leave the helper fold_id column out.
X_train = train.loc[train['fold_id'] != 0, ['text', 'label']]
X_valid = train.loc[train['fold_id'] == 0, ['text', 'label']]

In [7]:
# Output directory for this experiment's model files and the submission CSV.
OUTPUT_DIR = os.path.join(BASE_PATH, "result", "03_NLP_baseline")
os.makedirs(OUTPUT_DIR, exist_ok=True)
params = {
    "output_dir": OUTPUT_DIR,
    "max_seq_length": 128,
    "train_batch_size": 32,
    "eval_batch_size": 64,
    "num_train_epochs": 70,
    "learning_rate": 1e-4,
    "manual_seed": SEED,
    "save_eval_checkpoints": False,
    # NOTE(review): the training progress bar shows 69 steps/epoch, so 70 epochs is
    # about 4830 steps. 6900 therefore appears to exceed the total, and presumably
    # no step checkpoints are ever written. Confirm this is intended.
    "save_steps": 6900,
    "save_model_every_epoch": False,
    "no_cache": True,
    "overwrite_output_dir": True,
}
model = ClassificationModel(
    'bert', 'bert-base-cased',
    num_labels=NUM_CLASS,
    args=params,
    # Use the GPU when one is visible; otherwise fall back to CPU instead of
    # failing (simpletransformers rejects use_cuda=True without CUDA).
    use_cuda=torch.cuda.is_available(),
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [8]:
model.train_model(train_df=X_train)  # fine-tune on every fold except the held-out one



HBox(children=(FloatProgress(value=0.0, max=2198.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=70.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 70', max=69.0, style=ProgressStyle(des…






HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 7 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 8 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 9 of 70', max=69.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 10 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 11 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 12 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 13 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 14 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 15 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 16 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 17 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 18 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 19 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 20 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 21 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 22 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 23 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 24 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 25 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 26 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 27 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 28 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 29 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 30 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 31 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 32 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 33 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 34 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 35 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 36 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 37 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 38 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 39 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 40 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 41 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 42 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 43 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 44 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 45 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 46 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 47 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 48 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 49 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 50 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 51 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 52 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 53 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 54 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 55 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 56 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 57 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 58 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 59 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 60 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 61 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 62 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 63 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 64 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 65 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 66 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 67 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 68 of 70', max=69.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 69 of 70', max=69.0, style=ProgressStyle(de…





In [9]:
# Score the held-out fold; the extra `f1=` keyword adds macro-F1 to `result`.
result, model_outputs, wrong_predictions = model.eval_model(
    eval_df=X_valid,
    f1=metric_f1,
)
print(result)



HBox(children=(FloatProgress(value=0.0, max=733.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=12.0, style=ProgressStyle(descri…


{'mcc': 0.49885132810886523, 'f1': 0.6188720998376489, 'eval_loss': 3.4990211526552835}


In [10]:
# predict() takes a list of strings. Passing a plain list avoids the library
# indexing a pandas Series by label (e.g. `to_predict[0]`), which breaks
# whenever the index does not start at 0.
y_pred, raw_outputs = model.predict(test['text'].tolist())
print(y_pred)

HBox(children=(FloatProgress(value=0.0, max=1743.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))


[3 2 2 ... 0 2 2]


In [11]:
# Read only the ids from the raw test file into a separate name. Overwriting
# `test` here would replace the preprocessed frame (which has the 'text' column)
# and make the prediction cell fail if re-run.
test_ids = pd.read_csv(os.path.join(BASE_PATH, "data", "test.csv"), usecols=['id'])['id']
assert len(test_ids) == len(y_pred), "prediction count must match the number of test rows"
# Labels were shifted to start at 0 at load time; add 1 to restore the original ids.
submit = pd.DataFrame({'index': test_ids, 'pred': y_pred + 1})
submit

Unnamed: 0,index,pred
0,2931,4
1,2932,3
2,2933,3
3,2934,1
4,2935,3
...,...,...
1738,4669,1
1739,4670,3
1740,4671,1
1741,4672,3


In [12]:
# Write the submission as a bare two-column CSV: no header row, no index column.
submit.to_csv(
    os.path.join(BASE_PATH, "result", "03_NLP_baseline", "03_NLP_baseline.csv"),
    header=False,
    index=False,
)