This is the inference notebook.
The training notebook can be found [here](https://www.kaggle.com/palash97/bert-pytorch-starter-code-training/).

## Import Libraries

In [None]:
import numpy as np
import pandas as pd

from scipy.stats import rankdata
from tqdm import tqdm
from transformers import BertTokenizer, BertModel, logging
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Limit transformers' log output to error-level messages.
logging.set_verbosity_error()

## Load test data

In [None]:
# Competition input files: the comments that must be scored, and the sample
# submission file.
test_csv_path = '../input/jigsaw-toxic-severity-rating/comments_to_score.csv'
sample_sub_path = '../input/jigsaw-toxic-severity-rating/sample_submission.csv'

In [None]:
# Load the comments to score. The bare expression on the last line is the
# cell's display value (first rows of the frame).
test_df = pd.read_csv(test_csv_path)
test_df.head()

In [None]:
# Number of rows (comments) that will be scored.
print(len(test_df))

In [None]:
# Load the sample submission; this frame later receives the 'score' column
# and is written out as the submission file.
sample_sub_df = pd.read_csv(sample_sub_path)
sample_sub_df.head()

## Load pre-trained models

In [None]:
# Directories holding the saved tokenizer and BERT encoder.
tokenizer_pretrained = '../input/jigsaw-v13/tokenizer_pretrained'
bert_pretrained = '../input/jigsaw-v13/bert_model_pretrained'

# local_files_only=True: load strictly from the directories above.
tokenizer = BertTokenizer.from_pretrained(tokenizer_pretrained, local_files_only=True)
bert_model = BertModel.from_pretrained(bert_pretrained, local_files_only=True)

In [None]:
class Net(nn.Module):
    """BERT encoder topped with a single linear unit that emits one score per input.

    The attribute names ``bert_model`` and ``fcdense`` become part of the
    state-dict keys, so they must stay unchanged for saved weights to load.
    """

    def __init__(self, bert_model):
        """Wrap ``bert_model`` and attach a ``hidden_size -> 1`` linear head."""
        super().__init__()
        self.bert_model = bert_model
        hidden_size = self.bert_model.config.hidden_size
        self.fcdense = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Score a batch: encoder pooled output -> linear head.

        Returns the head's output, one value per batch row: (batch_size, 1).
        """
        encoded = self.bert_model(input_ids, attention_mask, return_dict=True)
        pooled = encoded['pooler_output']    # (batch_size, hidden_size)
        return self.fcdense(pooled)

In [None]:
# Build the scoring network around the loaded BERT encoder.
model = Net(bert_model=bert_model)
# NOTE(review): DataParallel wrapping is left disabled; presumably the
# checkpoint was saved from an unwrapped model — confirm the key names match.
# model = nn.DataParallel(model)

In [None]:
# Fine-tuned weights for the whole Net (encoder + linear head).
best_model = '../input/jigsaw-model-v11/toxicity_best_model.pth.tar'
# Deserialize onto the CPU rather than a hard-coded 'cuda': torch.load raises
# when asked to map to CUDA on a machine without a GPU. load_state_dict copies
# the values into the model's existing parameters, and the model is moved to
# the selected device afterwards, so the result is the same on GPU machines.
model.load_state_dict(torch.load(best_model, map_location='cpu'))

In [None]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Test the model

In [None]:
class CustomTestDataset(Dataset):
    """Map-style dataset that tokenizes one comment per item for inference.

    Args:
        df: DataFrame whose column at position 1 holds the comment text.
        tokenizer: Object exposing ``encode_plus`` (e.g. a ``BertTokenizer``).
        max_len: Length every sequence is padded or truncated to.
    """

    def __init__(self, df, tokenizer, max_len):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` for the row at ``index``.

        Both tensors are 1-D with shape ``(max_len,)``.
        """
        # The text is read by position (second column), not by column name.
        comment_text = self.df.iloc[index, 1]

        encoding = self.tokenizer.encode_plus(
            comment_text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # The tensors come back with a leading batch axis of size 1. Drop only
        # that axis: a bare squeeze() would also collapse the sequence axis
        # when max_len == 1 and yield a 0-dim tensor.
        input_ids = encoding['input_ids'].squeeze(0)    # Shape: (max_length)
        attention_mask = encoding['attention_mask'].squeeze(0)    # Shape: (max_length)

        return input_ids, attention_mask

In [None]:
# Wrap the test comments; every comment is padded/truncated to 400 tokens.
test_dataset = CustomTestDataset(test_df, tokenizer, max_len=400)

In [None]:
# Batches of 32. No shuffle argument is passed, so the loader's default
# ordering applies — the predictions are later assigned row-by-row, so this
# must stay sequential (NOTE(review): confirm the default is unshuffled).
test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:
def test_epoch(model, test_loader, DEVICE):
    """Run inference over ``test_loader`` and return all predictions.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        test_loader: Iterable yielding ``(input_ids, attention_mask)`` batches.
        DEVICE: Device the batches are moved to before the forward pass.

    Returns:
        The per-batch outputs concatenated along dim 0. For a model emitting
        ``(batch_size, 1)`` this is a 1-D tensor with one value per sample.
        An empty loader yields an empty tensor.
    """
    model.eval()

    pred = []

    with torch.no_grad():
        for input_ids, attention_mask in tqdm(test_loader):
            input_ids = input_ids.to(DEVICE)   # (batch_size, seq_len)
            attention_mask = attention_mask.to(DEVICE)   # (batch_size, seq_len)

            output = model(input_ids, attention_mask)   # (batch_size, 1)

            # Drop only the trailing size-1 axis. A bare squeeze() would turn a
            # final batch of size 1 into a 0-dim tensor, which torch.cat
            # refuses to concatenate.
            pred.append(output.squeeze(-1))

    if not pred:
        # torch.cat raises on an empty list; return an empty result instead.
        return torch.empty(0, device=DEVICE)

    return torch.cat(pred)

In [None]:
# Move the model to the selected device, score every test comment, and show
# the shape of the resulting prediction tensor as the cell output.
model.to(device)
pred = test_epoch(model, test_loader, device)
pred.shape

In [None]:
# Bring the predictions back to host memory as a NumPy array.
predictions = pred.detach().cpu().numpy()

In [None]:
# Store the rank of each prediction (rankdata's default tie handling) rather
# than the raw model output.
# NOTE(review): this assumes sample_sub_df rows are in the same order as
# test_df — confirm, otherwise scores are attached to the wrong comments.
sample_sub_df['score'] = rankdata(predictions)

In [None]:
# Write the submission file to the working directory, without the index column.
sample_sub_df.to_csv('submission.csv', index=False)

In [None]:
# Completion marker for the notebook run.
print('Done!')