In [None]:
# Download the PyTorch/XLA environment setup script, then run it with
# `--version nightly` (presumably installs the nightly torch_xla build) and the
# libomp5 / libopenblas-dev apt packages passed via --apt-packages.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
# import torch_xla.utils.serialization as xser

In [None]:
import re
import os
import torch
import pandas as pd
from scipy import stats
import numpy as np

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib

import logging
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule, BertModel, XLMRobertaModel, XLMRobertaTokenizer, DistilBertModel, DistilBertTokenizer
import sys
from sklearn import metrics, model_selection

import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
import time

warnings.filterwarnings("ignore")

class AverageMeter:
    """
    Keep the latest value together with a count-weighted running mean.

    Attributes:
        val: value passed to the most recent ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` (0 until the first update).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        # Start every statistic at zero.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # `n` is the weight of `val`, e.g. the batch size for a per-batch mean loss.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def onehot(size, target):
    """Return a float32 zero tensor of shape ``size`` with 1.0 written at ``target``.

    ``target`` can be anything tensor indexing accepts (an int, a list of ints, ...).
    """
    encoded = torch.zeros(size, dtype=torch.float32)
    encoded[target] = 1.0
    return encoded
        
class BERTDatasetTraining(torch.utils.data.TensorDataset):
    """Map-style dataset that tokenizes one comment per item on the fly.

    Args:
        comment_text: indexable sequence of raw comment strings; defines the length.
        targets: indexable sequence of labels, read only when ``test`` is False
            (may be None in test mode).
        tokenizer: object exposing ``encode_plus`` (e.g. a transformers tokenizer).
        idxs: indexable sequence of row ids, read only when ``test`` is True.
        max_length: ``max_length`` passed to ``encode_plus`` (with ``pad_to_max_length=True``).
        test: selects the item layout returned by ``__getitem__``.

    Note: ``TensorDataset.__init__`` is not called; only the ``Dataset`` protocol
    (``__getitem__`` / ``__len__``) is relied on.
    """

    def __init__(self, comment_text, targets, tokenizer, idxs=None, max_length=200, test=False):
        self.comment_text = comment_text
        self.test = test
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.targets = targets
        self.idxs = idxs

    def get_tokens(self, text):
        """Return ``(input_ids, attention_mask)`` lists for one text."""
        # NOTE(review): newer transformers releases deprecate pad_to_max_length in favour of
        # padding='max_length', truncation=True; kept for compatibility with the pinned version.
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            pad_to_max_length=True
        )
        return encoded['input_ids'], encoded['attention_mask']

    def __getitem__(self, item):
        """Return ``(input_ids, attention_mask, target)`` as long/long/float tensors,
        or ``(idx, input_ids, attention_mask)`` when ``test`` is True."""
        text = self.comment_text[item]
        encoded = self.get_tokens(text)

        if not self.test:
            return torch.tensor(encoded[0], dtype=torch.long), torch.tensor(encoded[1], dtype=torch.long), torch.tensor(self.targets[item], dtype=torch.float)
        return torch.tensor(self.idxs[item]), torch.tensor(encoded[0], dtype=torch.long), torch.tensor(encoded[1], dtype=torch.long)

    def __len__(self):
        # Size is driven by the texts, so an unlabeled test set (targets=None) still works.
        return len(self.comment_text)

In [None]:
def _join_list_literal(text):
    """Turn a stringified list of quoted sentences into one '.'-joined string.

    Example: "['hello world', 'foo bar']" -> "hello world.foo bar".

    NOTE(review): the text is split on every ',' so a sentence that itself contains a
    comma is cut and loses the characters next to the cut; if the column is always a
    Python list repr, ast.literal_eval would be more robust.
    """
    parts = re.split(',', text)
    parts[0] = parts[0][1:]     # drop the leading '['
    parts[-1] = parts[-1][:-1]  # drop the trailing ']'
    # Collapse runs of whitespace, then strip the surrounding quote characters.
    return '.'.join(" ".join(part.split())[1:-1] for part in parts)


def clean(text):
    """Clean a training/validation cell: an empty list literal ("['']") becomes ''."""
    if re.findall(r'\[\'\'\]', text):
        return ''
    return _join_list_literal(text)


def clean_test(text):
    """Clean a test cell: an empty list literal ("['']") becomes '[UNK]' so no row is lost."""
    if re.findall(r'\[\'\'\]', text):
        return '[UNK]'
    return _join_list_literal(text)

In [None]:
class ToxicSimpleNNModel(nn.Module):
    """Binary toxicity classifier: BERT encoder -> mean+max pooling -> 3 linear layers -> 1 logit.

    Args:
        bert_path: location passed to ``BertModel.from_pretrained``; the default is the
            Kaggle dataset path this notebook was written against.
        dropout: dropout probability applied before each linear layer.
    """

    def __init__(self, bert_path="../input/bert-base-multilingual-uncased/", dropout=0.3):
        super(ToxicSimpleNNModel, self).__init__()
        self.encoder = BertModel.from_pretrained(bert_path)
        # Width of the encoder output, read from its pooler layer.
        hidden = self.encoder.pooler.dense.out_features
        self.dropout = nn.Dropout(dropout)
        # Input is the concatenation of mean- and max-pooled features, hence 2 * hidden.
        self.linear_1 = nn.Linear(in_features=hidden * 2, out_features=hidden * 2)
        self.linear_2 = nn.Linear(in_features=hidden * 2, out_features=hidden)
        self.linear_3 = nn.Linear(in_features=hidden, out_features=1)
        # NOTE(review): there is no activation between the linear layers, so the head is an
        # affine map (plus dropout). Adding one would change outputs of existing checkpoints.

    def forward(self, input_ids, attention_mask):
        """Return raw logits of shape (batch, 1)."""
        seq_x = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # seq_x[0] is presumably (batch, seq_len, hidden) token states; pool over dim 1.
        # NOTE(review): pooling ignores attention_mask, so padding positions are included.
        token_states = seq_x[0]
        apool = torch.mean(token_states, 1)
        mpool, _ = torch.max(token_states, 1)
        x = torch.cat((apool, mpool), 1)
        x = self.dropout(x)
        x = self.linear_1(x)
        x = self.dropout(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return self.linear_3(x)

In [None]:
# Build the shared model and the cleaned train/validation frames. These module-level
# globals are read inside `_run` by every process started by xmp.spawn (fork).
SEED = 42
torch.manual_seed(SEED)  # deterministic init of the classifier head
mx = ToxicSimpleNNModel()

df_train1 = pd.read_csv("../input/jigsaw-dataset/train_sent_prep.csv", usecols=["comment_text", "toxic"], converters={'comment_text': clean}).fillna("none")
# `clean` maps empty list literals to '', which are dropped here.
df_train1 = df_train1[df_train1.comment_text != '']
df_train_full = df_train1.reset_index(drop=True)
# Seeded shuffle so the per-core shards in `_run` are reproducible across runs.
train = df_train_full.sample(frac=1, random_state=SEED).reset_index(drop=True)
del df_train1, df_train_full

df_valid = pd.read_csv("../input/jigsaw-dataset/validation_sent_pr.csv", usecols=["comment_text", "toxic"], converters={'comment_text': clean})
valid = df_valid[df_valid.comment_text != ''].reset_index(drop=True)
del df_valid


In [None]:
# Load the unlabeled test comments; `clean_test` maps empty rows to '[UNK]' so no row is dropped
# and the row positions used as ids in `_run` stay aligned with the file.
df_test = pd.read_csv("../input/jigsaw-dataset/test_sent_pr.csv", usecols=["content"], converters={'content': clean_test})
df_test = df_test.assign(comment_text=df_test['content'])
test = df_test.loc[:, ['comment_text']].reset_index(drop=True)

In [None]:
class RocAucMeter(object):
    """Accumulate binary labels and logits and report ROC AUC over everything seen.

    The AUC is computed lazily when ``avg`` / ``score`` is read instead of after every
    ``update``. Recomputing it on the full history at every step made an epoch quadratic
    in the number of seen samples.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Seed with one negative and one positive, both scored 0.5, so roc_auc_score never
        # sees a single class. This slightly biases the reported AUC.
        self._true_chunks = [np.array([0, 1])]
        self._pred_chunks = [np.array([0.5, 0.5])]
        self._score = 0  # reported until the first update
        self._dirty = False

    def update(self, y_true, y_pred):
        """Record a batch: ``y_true`` labels and ``y_pred`` logits (first column is used)."""
        self._true_chunks.append(y_true.cpu().numpy())
        self._pred_chunks.append(torch.sigmoid(y_pred).cpu().detach().numpy()[:, 0])
        self._dirty = True

    @property
    def y_true(self):
        return np.hstack(self._true_chunks)

    @property
    def y_pred(self):
        return np.hstack(self._pred_chunks)

    @property
    def score(self):
        if self._dirty:
            self._score = metrics.roc_auc_score(self.y_true, self.y_pred, labels=np.array([0, 1]))
            self._dirty = False
        return self._score

    @property
    def avg(self):
        return self.score

# A second, behaviourally identical `AverageMeter` class used to be defined here and
# silently shadowed the definition in the first code cell. That first definition is
# now the only one.

In [None]:
def loss_fn(outputs, targets):
    """Mean binary cross-entropy on logits.

    ``targets`` is reshaped with ``view(-1, 1)`` into a column so it lines up with
    ``outputs``.
    """
    column_targets = targets.view(-1, 1)
    return nn.functional.binary_cross_entropy_with_logits(outputs, column_targets)

# class LabelSmoothing(nn.Module):
#     def __init__(self, smoothing = 0.1):
#         super(LabelSmoothing, self).__init__()
#         self.confidence = 1.0 - smoothing
#         self.smoothing = smoothing

#     def forward(self, x, target):
#         if self.training:
#             x = x.float()
#             target = target.float()
#             logprobs = torch.nn.functional.log_softmax(x, dim = -1)
#             nll_loss = -logprobs * target
#             nll_loss = nll_loss.sum(-1)
#             smooth_loss = -logprobs.mean(dim=-1)
#             loss = self.confidence * nll_loss + self.smoothing * smooth_loss
#             return loss.mean()
#         else:
#             return torch.nn.functional.cross_entropy(x, target)
        
# def loss_fn(outputs, targets):
#     return LabelSmoothing()(outputs, targets)
    
def train_loop_fn(loader, model, optimizer, device, scheduler):
    """Run one training epoch.

    Args:
        loader: iterable of ``(input_ids, attention_mask, labels)`` batches
            (a ParallelLoader per-device loader in this notebook).
        model: the classifier being trained.
        optimizer: stepped through ``xm.optimizer_step``.
        device: unused; kept so existing call sites keep working.
        scheduler: stepped once after every optimizer step.

    Returns:
        ``(losses, auc)``: an AverageMeter of the batch losses and a RocAucMeter of training AUC.
    """
    tracker = xm.RateTracker()
    losses = AverageMeter()
    auc = RocAucMeter()

    # NOTE(review): with values > 1 the loss is not divided by this factor and a trailing
    # partial accumulation is never applied; fine while it stays 1.
    gradient_accumulation_steps = 1
    model.train()
    for step, (input_ids, input_masks, labels) in enumerate(loader):
        pred = model(input_ids=input_ids, attention_mask=input_masks)
        loss = loss_fn(pred, labels.float())
        loss.backward()
        # .item() copies the loss to the host each step so the running mean can be printed.
        loss = loss.detach().item()

        auc.update(labels, pred)
        losses.update(loss, input_ids.size(0))
        if (step + 1) % gradient_accumulation_steps == 0:
            xm.optimizer_step(optimizer)
            scheduler.step()
            optimizer.zero_grad()

        # Feed the rate tracker; previously this call was commented out, so the printed
        # Rate/GlobalRate carried no information.
        tracker.add(input_ids.size(0))
        if (step + 1) % 20 == 0:
            xm.master_print('[xla:{}]({}) Rate={:.2f} Loss={:.4f} AUC={:.4f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), step+1, tracker.rate(), losses.avg, auc.avg,
                tracker.global_rate(), time.asctime()))
    # The old trailing `del loss / input_ids / labels` raised NameError on an empty loader
    # and freed nothing that leaving the function does not free anyway.
    return losses, auc

def valid_loop_fn(loader, model, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Batches are ``(input_ids, attention_mask, labels)``. Returns ``(losses, final_scores)``:
    an AverageMeter of the batch losses and a RocAucMeter of the AUC. ``device`` is unused.
    """
    losses = AverageMeter()
    final_scores = RocAucMeter()

    model.eval()
    with torch.no_grad():
        for input_ids, input_masks, labels in loader:
            pred = model(input_ids=input_ids.long(), attention_mask=input_masks)
            loss_val = loss_fn(pred, labels.float())
            final_scores.update(labels, pred)
            losses.update(loss_val.detach().item(), input_ids.size(0))
            # Drop batch references before the next batch is fetched.
            del loss_val, pred, input_ids, input_masks, labels

    return losses, final_scores

# def valid_loop_fn(loader,model,device):
#     losses = AverageMeter()
#     final_scores = RocAucMeter()
#     fin_targets = []
#     fin_outputs = []

#     model.eval()
#     with torch.no_grad():
#         for j, data in enumerate(loader):

#           # get the inputs
#             input_ids, input_masks, labels = data
#             pred = model(input_ids = input_ids.long(),
#                        attention_mask = input_masks
#                       )

#             loss_val = loss_fn(pred, labels.float())
#             final_scores.update(labels, pred)
#             losses.update(loss_val.detach().item(), input_ids.size(0))
#             targets_np = labels.cpu().detach().numpy().tolist()
#             outputs_np = pred.cpu().detach().numpy()[:,1].tolist()
#             fin_targets.extend(targets_np)
#             fin_outputs.extend(outputs_np)
#             del loss_val,pred,input_ids,input_masks,labels

#     return losses, final_scores, fin_outputs, fin_targets

In [None]:
def test_loop_fn(loader,model,device): 

    model.eval()
    result = {'id': [], 'toxic': []}
    with torch.no_grad():
        for j, data in enumerate(loader):
        
          # get the inputs
            idxs, input_ids, input_masks = data
            pred = model(input_ids = input_ids.long(),
                       attention_mask = input_masks
                      )
            y_pred = torch.sigmoid(pred).cpu().detach().numpy()[:,0]
            result['id'].extend(idxs.cpu().numpy())
            result['toxic'].extend(y_pred)
            del pred,input_ids,input_masks,idxs

    return result

In [None]:
# Warm-start from a previous run's epoch-0 checkpoint (written by `_run` via xm.save) when
# one exists. On a fresh run the file is not there yet, so skip instead of failing
# Restart & Run All.
CHECKPOINT_PATH = 'model_0.pt'
if os.path.exists(CHECKPOINT_PATH):
    mx.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
else:
    print(f'{CHECKPOINT_PATH} not found; training starts from the pretrained encoder')

In [None]:
# Directory that receives one prediction CSV per core from `_run`.
# exist_ok makes re-running this cell safe (plain `mkdir` errors if it already exists).
os.makedirs('./node_submissions/', exist_ok=True)

In [None]:
import datetime
import glob
import random

In [None]:
def _shard_rows(frame, rank, world_size, cover_all):
    """Return the contiguous slice of ``frame`` owned by process ``rank``.

    cover_all=False: every rank gets exactly ``len(frame) // world_size`` rows (the
    remainder is dropped) so all ranks run the same number of batches.
    cover_all=True: shard sizes differ by at most one row and together cover every row.
    """
    n_rows = len(frame)
    if cover_all:
        start = n_rows * rank // world_size
        end = n_rows * (rank + 1) // world_size
    else:
        size = n_rows // world_size
        start, end = rank * size, (rank + 1) * size
    return frame.iloc[start:end]


def _mean_reduce(vals):
    """Arithmetic mean; reduce function for xm.mesh_reduce."""
    return sum(vals) / len(vals)


def _write_test_predictions(test_data_loader, model, device):
    """Predict this process's test shard and write it to a uniquely named CSV in node_submissions/."""
    para_loader = pl.ParallelLoader(test_data_loader, [device])
    result = pd.DataFrame(test_loop_fn(para_loader.per_device_loader(device), model, device))
    node_count = len(glob.glob('node_submissions/*.csv'))
    result.to_csv(f'node_submissions/submission_{node_count}_{xm.get_ordinal()}_{random.random()}.csv', index=False)
    del para_loader, result


def _run():
    """Fine-tune, validate and predict on one XLA core; called once per process by `_mp_fn`.

    Reads the module-level ``train``, ``valid`` and ``test`` frames and the shared model
    ``mx``. Each process works on its own contiguous shard of every frame, saves
    checkpoints with xm.save, and writes its test predictions into node_submissions/.
    """
    MAX_LEN = 192
    TRAIN_BATCH_SIZE = 64
    EPOCHS = 1

    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-multilingual-uncased", do_lower_case=True)

    count_proc = xm.xrt_world_size()
    num_proc = xm.get_ordinal()
    print('Всего ядер {}'.format(count_proc))  # "Total cores"

    # NOTE(review): training shards are forced to equal size. xm.optimizer_step presumably
    # reduces gradients across cores, so a core with fewer batches could leave the others
    # waiting. Validation/test shards cover every row; the old round()-based slicing could
    # silently drop rows, leaving test ids missing from the submission.
    train_part = _shard_rows(train, num_proc, count_proc, cover_all=False)
    valid_part = _shard_rows(valid, num_proc, count_proc, cover_all=True)
    test_part = _shard_rows(test, num_proc, count_proc, cover_all=True)

    def make_loader(dataset):
        return torch.utils.data.DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=False,
                                           num_workers=4, drop_last=False)

    train_data_loader = make_loader(BERTDatasetTraining(
        comment_text=train_part.comment_text.values,
        targets=train_part.toxic.values,
        max_length=MAX_LEN, tokenizer=tokenizer))
    valid_data_loader = make_loader(BERTDatasetTraining(
        comment_text=valid_part.comment_text.values,
        targets=valid_part.toxic.values,
        max_length=MAX_LEN, tokenizer=tokenizer))
    # Bug fix: this previously read comment_text/targets from `valid`, so predictions were
    # made on validation comments but saved under test ids. The test set has no labels;
    # the zero placeholder only sizes the dataset (test-mode items never return targets).
    test_data_loader = make_loader(BERTDatasetTraining(
        comment_text=test_part.comment_text.values,
        targets=np.zeros(len(test_part)),
        idxs=test_part.index.values,
        max_length=MAX_LEN, tokenizer=tokenizer, test=True))

    device = xm.xla_device()
    model = mx.to(device)

    lr = 0.4 * 1e-5 * xm.xrt_world_size()
    xm.master_print('Количество батчей {}'.format(len(train_data_loader)))  # "Number of batches"
    num_train_steps = int(len(train_data_loader) * EPOCHS)
    warmup_proportion = 0.01
    num_warmup_steps = round(num_train_steps * warmup_proportion)
    xm.master_print(f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()},  num_warmup_steps={num_warmup_steps}')

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_train_steps
    )

    best_valid_loss = float("Inf")

    for epoch in range(EPOCHS):
        start_time = time.time()
        xm.master_print("Start training epoch {}".format(epoch))
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loss, train_score = train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
        del para_loader
        xm.master_print(f'[RESULT]: Train. Epoch: {epoch}, loss: {train_loss.avg:.5f}, final_score: {train_score.avg:.5f}, time: {(time.time() - start_time):.5f}')

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        valid_loss, valid_score = valid_loop_fn(para_loader.per_device_loader(device), model, device)
        del para_loader
        xm.master_print('Epoch {}/{} \t loss={:.4f}\t val_loss={:.4f} \t train_score={:.4f}\t val_score={:.4f}\t time={:.2f}s'
                        .format(epoch+1, EPOCHS, train_loss.avg, valid_loss.avg, train_score.avg, valid_score.avg, (time.time() - start_time)))
        print(f'[xla:{xm.get_ordinal()}] AUC = {valid_score.avg}')

        # Average the per-core values. The AUC is a mean of per-core AUCs, not the AUC of
        # the pooled predictions. Distinct tags keep the two rendezvous apart.
        avgloss = xm.mesh_reduce('loss_reduce', valid_loss.avg, _mean_reduce)
        auc = xm.mesh_reduce('auc_reduce', valid_score.avg, _mean_reduce)
        xm.master_print(f'AUC AVG = {auc}')

        xm.save(model.state_dict(), "model_{}.pt".format(epoch))
        # avgloss comes from mesh_reduce, presumably identical on every core, so all cores
        # take the same branch.
        if best_valid_loss > avgloss:
            best_valid_loss = avgloss
            xm.save(model.state_dict(), "model.pt")
            _write_test_predictions(test_data_loader, model, device)
        del train_loss, train_score, valid_loss, valid_score

    # Final evaluation and prediction with the last-epoch weights. Previously this block ran
    # on the *train* loader, passed train_loop_fn-style arguments to valid_loop_fn, used
    # para_loader after deleting it and called the glob module directly; each step raised.
    # NOTE(review): confirm that evaluation (not an extra training pass) was intended here.
    para_loader = pl.ParallelLoader(valid_data_loader, [device])
    valid_loss, valid_score = valid_loop_fn(para_loader.per_device_loader(device), model, device)
    del para_loader
    xm.master_print('val_loss={:.4f} \t  val_score={:.4f}\t'
                    .format(valid_loss.avg, valid_score.avg))
    # These files are averaged per id with any earlier ones by the submission-merge cell.
    _write_test_predictions(test_data_loader, model, device)


In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point for xmp.spawn.

    ``rank`` and ``flags`` are unused: ``_run`` gets its ordinal from ``xm.get_ordinal()``.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    _run()


FLAGS = {}
# 'fork' lets every process reuse the module-level model and data frames built above.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

In [None]:
# Merge the per-core prediction files. An id can appear in several files when predictions
# were written more than once, so average them per id, then save the merged submission.
# NOTE(review): ids are row positions of `test` (its index after reset_index); confirm they
# match the id column the competition expects.
submission_parts = [pd.read_csv(path) for path in glob.glob('node_submissions/*.csv')]
submission = pd.concat(submission_parts).groupby('id').mean()
submission.to_csv('submission.csv')
submission.head()