In [1]:
import os
import sys
from importlib import reload

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import tqdm

import torch
from torch.utils import data as D

# Local imports
sys.path.append('../src')
import dataset
import trainer
import models
import utils
import preprocessing

# Transformers
import transformers
from transformers import XLMRobertaModel, XLMRobertaTokenizer, XLMRobertaConfig
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule

# Setup device
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed 
utils.seed_everything()

print('use', device)

[nltk_data] Downloading package punkt to
[nltk_data]     /gpfs/hpc/home/papkov/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


INFO: Pandarallel will run on 2 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
use cuda


## Datasets

In [2]:
reload(dataset)

<module 'dataset' from '../src/dataset.py'>

In [3]:
debug = False

In [4]:
%%time
if debug:
    valid = dataset.Dataset('../input/validatio_debug_32.npz')
else:
    valid = dataset.Dataset('../input/validation.npz')
valid.x.shape, valid.y.shape

CPU times: user 145 ms, sys: 43.2 ms, total: 188 ms
Wall time: 261 ms


((8000, 512), (8000,))

In [5]:
%%time
if debug:
    test = dataset.Dataset('../input/test_debug_32.npz')
else:
    test = dataset.Dataset('../input/test.npz')
test.x.shape, test.y.shape

CPU times: user 731 ms, sys: 205 ms, total: 936 ms
Wall time: 1 s


((63812, 512), (63812,))

In [6]:
%%time
if debug:
    train = dataset.Dataset('../input/jigsaw-toxic-comment-trai_debug_32.npz')
else:
    train = dataset.Dataset('../input/jigsaw-toxic-comment-train.npz')
train.x.shape, train.y.shape

CPU times: user 2.14 s, sys: 700 ms, total: 2.84 s
Wall time: 2.89 s


((223549, 512), (223549,))

## Model

In [2]:
reload(models)

<module 'models' from '../src/models.py'>

In [3]:
backbone = XLMRobertaModel(XLMRobertaConfig.from_pretrained('xlm-roberta-large'))

We can turn of regularization to make debugging easier

In [100]:
# to reload the module and not overload gpu
del model

In [38]:
# model = models.Model(backbone, mix=False, dropout=0)

In [11]:
model = models.Model(backbone, mix=True, dropout=0.25)

## Feature extraction

In [6]:
reload(preprocessing)

INFO: Pandarallel will run on 2 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


[nltk_data] Downloading package punkt to
[nltk_data]     /gpfs/hpc/home/papkov/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


<module 'preprocessing' from '../src/preprocessing.py'>

In [7]:
preprocessing.extract_roberta_features_to_file('../input/validation.npz', backbone=backbone, device=device)

feature extraction: 100%|##########| 63/63 [04:35<00:00,  4.38s/it]


In [None]:
preprocessing.extract_roberta_features_to_file('../input/test.npz', backbone=backbone, device=device)

feature extraction:  30%|###       | 152/499 [11:12<25:35,  4.43s/it]

In [None]:
preprocessing.extract_roberta_features_to_file('../input/jigsaw-toxic-comment-train.npz', backbone=backbone, device=device, batch_size=128)

## Data loaders

In [7]:
batch_size = 2
num_workers = 4

loader_train = D.DataLoader(train, 
                            sampler=train.weighted_sampler(), 
                            batch_size=batch_size, num_workers=num_workers)
loader_valid = D.DataLoader(valid, 
                            batch_size=batch_size, num_workers=num_workers)
loader_test = D.DataLoader(test, 
                           batch_size=batch_size, num_workers=num_workers)

In [8]:
len(loader_train), len(loader_valid), len(loader_test)

(21384, 4000, 31906)

## Trainer

In [12]:
reload(trainer)

<module 'trainer' from '../src/trainer.py'>

In [13]:
# we may optimize only head (with encoder pretrained)
optimizer = AdamW(model.head.parameters(), lr=1e-4)
# scheduler = get_linear_schedule_with_warmup()

In [14]:
trnr = trainer.Trainer('base', model, 
                       loader_train, loader_valid, loader_test,
                       epochs=5,
                       monitor='val_loss',
                       optimizer=optimizer,
                      )

Sanity check for output

In [8]:
#x, y, am = next(iter(loader_train))

In [21]:
#out, loss = trnr(x, y, am)

In [22]:
#out

tensor([[1.0928, 1.2950],
        [0.8992, 1.2395],
        [1.1048, 1.3800],
        [1.1402, 1.3978]], device='cuda:0', grad_fn=<AddmmBackward>)

In [23]:
#loss

tensor(0.6357, device='cuda:0', grad_fn=<MeanBackward0>)

## Training

In [105]:
trnr.fit()

ep. 0000 (lr 1.00e-04): 100%|##########| 16/16 [00:03<00:00,  4.20it/s, loss=0.556, acc=0.781]
valid: 100%|##########| 16/16 [00:01<00:00, 11.97it/s]


Epoch 0 complete. val loss (avg): 0.5137, val acc: 0.8125
Saved model to ../checkpoints//base_last.pth


ep. 0001 (lr 8.18e-05): 100%|##########| 16/16 [00:03<00:00,  4.24it/s, loss=0.311, acc=0.906]
valid: 100%|##########| 16/16 [00:01<00:00, 11.12it/s]


Epoch 1 complete. val loss (avg): 0.4835, val acc: 0.8125
Saved model to ../checkpoints//base.pth


ep. 0002 (lr 4.74e-05): 100%|##########| 16/16 [00:03<00:00,  4.22it/s, loss=0.29, acc=0.906] 
valid: 100%|##########| 16/16 [00:01<00:00, 11.89it/s]


Epoch 2 complete. val loss (avg): 0.4961, val acc: 0.8125
Saved model to ../checkpoints//base_last.pth


ep. 0003 (lr 1.82e-05): 100%|##########| 16/16 [00:03<00:00,  4.22it/s, loss=0.325, acc=0.906]
valid: 100%|##########| 16/16 [00:01<00:00, 11.52it/s]


Epoch 3 complete. val loss (avg): 0.5018, val acc: 0.8125
Saved model to ../checkpoints//base_last.pth


ep. 0004 (lr 2.64e-06): 100%|##########| 16/16 [00:03<00:00,  4.22it/s, loss=0.314, acc=0.906]
valid: 100%|##########| 16/16 [00:01<00:00, 11.95it/s]


Epoch 4 complete. val loss (avg): 0.5025, val acc: 0.8125
Saved model to ../checkpoints//base_last.pth


## Prediction

In [15]:
pred, loss, acc = trnr.validate()

valid: 100%|##########| 4000/4000 [04:53<00:00, 13.63it/s]


In [18]:
acc

0.15375

In [19]:
pred, loss, acc = trnr.test()

test: 100%|##########| 31906/31906 [38:58<00:00, 13.65it/s]


In [22]:
loss

2.320265071657932