<a href="https://colab.research.google.com/github/seonmia/NLP/blob/main/XLNet_SentimentClassification(baseline_code).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperparameters

In [24]:
used_model = 'xlnet-base-cased'
cased = False if 'uncased' in used_model else True

train_batch_size = 32
eval_batch_size = 32
test_batch_size = 32

learning_rate = 2e-5
train_epoch = 4
# weight_decay = 0.001

wandb_project = "final_project1" # WandB에 넣어둘 프로젝트 이름을 적절히 설정해주세요
wandb_team = "seonmia" # WandB 팀명

# Import requirements

In [1]:
!pip install sentencepiece  # xlnet import 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 5.1 MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers

In [None]:
!pip install wandb

In [6]:
import os
import pdb
import argparse
from dataclasses import dataclass, field
from typing import Optional
from collections import defaultdict
import wandb
from time import time

import torch
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm, trange
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    AutoConfig,
    AdamW,
)

In [None]:
wandb.login() #추가 

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [8]:
from transformers import XLNetTokenizer, XLNetForSequenceClassification   # xlnet import 

# 1. Preprocess

In [7]:
def make_id_file(task, tokenizer, cased):
    def make_data_strings(file_name, cased):
        data_strings = []
        with open(os.path.join(file_name), 'r', encoding='utf-8') as f:
            id_file_data = [tokenizer.encode(line if cased else line.lower()) for line in f.readlines()]
        for item in id_file_data:
            data_strings.append(' '.join([str(k) for k in item]))
        return data_strings
  
    print('it will take some times...')
    train_pos = make_data_strings('sentiment.train.1', cased)
    train_neg = make_data_strings('sentiment.train.0', cased)
    dev_pos = make_data_strings('sentiment.dev.1', cased)
    dev_neg = make_data_strings('sentiment.dev.0', cased)

    print('make id file finished!')
    return train_pos, train_neg, dev_pos, dev_neg

In [10]:
# tokenizer = BertTokenizer.from_pretrained(used_model)
tokenizer = XLNetTokenizer.from_pretrained(used_model)

In [11]:
%cd /content/drive/MyDrive/GoormProject1/goorm-project-1-text-classification

/content/drive/.shortcut-targets-by-id/1ovgSHdL_LDsDV-KWBQ2NNEs2v8Mpi0fm/GoormProject1/goorm-project-1-text-classification


In [None]:
!ls

In [13]:
train_pos, train_neg, dev_pos, dev_neg = make_id_file('yelp', tokenizer, cased)

it will take some times...
make id file finished!


In [14]:
class SentimentDataset(object):
    #  def __init__(self, pos, neg):
    def __init__(self, tokenizer, pos, neg):
        self.tokenizer = tokenizer
        self.data = []
        self.label = []

        for pos_sent in pos:
            self.data += [self._cast_to_int(pos_sent.strip().split())]
            self.label += [[1]]
        for neg_sent in neg:
            self.data += [self._cast_to_int(neg_sent.strip().split())]
            self.label += [[0]]
    def _cast_to_int(self, sample):
        return [int(word_id) for word_id in sample]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
         sample = self.data[index]
         return np.array(sample), np.array(self.label[index])         

In [15]:
train_dataset = SentimentDataset(tokenizer, train_pos, train_neg)
dev_dataset = SentimentDataset(tokenizer, dev_pos, dev_neg)

In [16]:
for i, item in enumerate(train_dataset):
    print(item)
    if i == 10:
        break

(array([2712,  626,   17,    9,    4,    3]), array([1]))
(array([11647,  2211,   348,    17,     9,     4,     3]), array([1]))
(array([  63,   77,   47, 1362,  632,   23,   21, 2528, 5145,   59,   27,
        343,  195,   17,    9,    4,    3]), array([1]))
(array([   36,    17,    26,    23,    24,   195, 15294,    68,  6938,
         101, 14378,    17,     9,     4,     3]), array([1]))
(array([  18,  891,   27, 3667,   17,    9,    4,    3]), array([1]))
(array([ 195, 1808,  626,   17,    9,    4,    3]), array([1]))
(array([195, 348,  17,   9,   4,   3]), array([1]))
(array([10480,    20,   191,    27, 18749,    21,  3895,    20,   632,
          23,    17,     9,     4,     3]), array([1]))
(array([  312,   250,    28,  3953,    49,  1808, 21929,    21,  5751,
          17,     9,     4,     3]), array([1]))
(array([  18,  109,  944, 1898, 3704,   17,    9,    4,    3]), array([1]))
(array([ 52, 250,  30, 172, 195,  17,   9,   4,   3]), array([1]))


In [17]:
def collate_fn_style(samples):
    input_ids, labels = zip(*samples)
    max_len = max(len(input_id) for input_id in input_ids)
    attention_mask = torch.tensor([[1] * len(input_id) + [0] * (max_len - len(input_id)) for input_id in input_ids])
    input_ids = pad_sequence([torch.tensor(input_id) for input_id in input_ids],
                             batch_first=True)
    
    token_type_ids = torch.tensor([[0] * len(input_id) for input_id in input_ids])
    position_ids = torch.tensor([list(range(len(input_id))) for input_id in input_ids])
    labels = torch.tensor(np.stack(labels, axis=0))

    return input_ids, attention_mask, token_type_ids, position_ids, labels

In [18]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=train_batch_size,
                                           shuffle=True, collate_fn=collate_fn_style,
                                           pin_memory=True, num_workers=2)
dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=eval_batch_size,
                                         shuffle=False, collate_fn=collate_fn_style,
                                         num_workers=2)

# 2. Train

In [19]:
random_seed=42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = BertForSequenceClassification.from_pretrained(used_model)
model = XLNetForSequenceClassification.from_pretrained(used_model)
model.to(device)

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias', 'logits_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

XLNetForSequenceClassification(
  (transformer): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation_function): GELUActivation()
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward

In [20]:
model.train()

# optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = AdamW(model.parameters(), lr=learning_rate)



In [23]:
def compute_acc(predictions, target_labels):
    return (np.array(predictions) == np.array(target_labels)).mean()

In [22]:
wandb.init(project=wandb_project, name=used_model+' '+str(int(time()))[-3:], entity=wandb_team)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [25]:

init_time=time()
lowest_valid_loss = 9999.


# 코드 추가
train_acc = []
train_loss = []
valid_acc = []
valid_loss = []


curr_train_loss = [] 
curr_train_acc = [] 
########

report_to ="wandb" 



for epoch in range(train_epoch):
    with tqdm(train_loader, unit="batch") as tepoch:
        for iteration, (input_ids, attention_mask, token_type_ids, position_ids, labels) in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            position_ids = position_ids.to(device)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            
            output = model(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids,
                           position_ids=position_ids, 
                           labels=labels)
            loss = output.loss

            # 추가
            logits = output.logits
            batch_predictions = [0 if example[0] > example[1] else 1 for example in logits]
            batch_labels = [int(example) for example in labels]
            
            acc = compute_acc(batch_predictions, batch_labels)
            
            ##########



            loss.backward()

            optimizer.step()


            ###### 추가 ####
            curr_train_loss.append(loss.item())
            curr_train_acc.append(acc)

            ##############


            tepoch.set_postfix(loss=loss.item())
            if iteration != 0 and iteration % int(len(train_loader) / 5) == 0:
                # Evaluate the model five times per epoch
                with torch.no_grad():
                    model.eval()
                    curr_valid_loss = []   # valid_losses 수정 

                    # 추가
                    curr_valid_acc = []  
                    #####  

                    # predictions = []
                    # target_labels = []
                    for input_ids, attention_mask, token_type_ids, position_ids, labels in tqdm(dev_loader,
                                                                                                desc='Eval',
                                                                                                position=1,
                                                                                                leave=None):
                        input_ids = input_ids.to(device)
                        attention_mask = attention_mask.to(device)
                        token_type_ids = token_type_ids.to(device)
                        position_ids = position_ids.to(device)
                        labels = labels.to(device, dtype=torch.long)

                        output = model(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       labels=labels)

                        logits = output.logits
                        loss = output.loss
                        # valid_losses.append(loss.item())

                        batch_predictions = [0 if example[0] > example[1] else 1 for example in logits]
                        batch_labels = [int(example) for example in labels]


                        curr_valid_loss.append(loss.item())
                        curr_valid_acc.append(compute_acc(batch_predictions, batch_labels))

                
                # 코드추가

                # loss /acc 계산
                mean_train_acc = sum(curr_train_acc) / len(curr_train_acc)
                mean_train_loss = sum(curr_train_loss) / len(curr_train_loss)
                mean_valid_acc = sum(curr_valid_acc) / len(curr_valid_acc)
                mean_valid_loss = sum(curr_valid_loss) / len(curr_valid_loss)

                train_acc.append(mean_train_acc)
                train_loss.append(mean_train_loss)
                valid_acc.append(mean_valid_acc)
                valid_loss.append(mean_valid_loss)
                
                curr_train_acc = [] 
                curr_train_loss = [] 

                # wandb log 수집 

                wandb.log({ 
                        "Train Loss": mean_train_loss,
                        "Train Accuracy": mean_train_acc,
                        "Valid Loss" : mean_valid_loss, 
                        "Valid Accuracy" : mean_valid_acc

                        })


                ###############

                if lowest_valid_loss > mean_valid_loss:
                    lowest_valid_loss = mean_valid_loss
                    print('Acc for model which have lower valid loss: ', mean_valid_acc)
                    torch.save(model.state_dict(), "./pytorch_model.bin")



fin_time=time()
print('Time:',fin_time-init_time)

Epoch 0:  20%|█▉        | 2770/13852 [06:09<27:06,  6.82batch/s, loss=0.00691]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:15,  8.22it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:07, 16.91it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 20.46it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.95it/s][A
Eval:  10%|█         | 13/125 [00:00<00:04, 22.91it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 23.23it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 23.04it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 23.36it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 23.53it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:04, 24.23it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:03, 23.99it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 24.27it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 24.68it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03, 24.86it/s][A
Eval:  34%|███▍      | 43/125 [00:01<00:03, 24.32it/s][A
Eval:

Acc for model which have lower valid loss:  0.969


Epoch 0:  40%|███▉      | 5540/13852 [12:25<17:47,  7.79batch/s, loss=0.0983]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:16,  7.70it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:07, 17.28it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 20.45it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.85it/s][A
Eval:  10%|█         | 13/125 [00:00<00:04, 22.58it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 23.08it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 23.18it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 23.58it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 23.70it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:03, 24.38it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:03, 24.02it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 24.06it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 24.19it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03, 24.44it/s][A
Eval:  34%|███▍      | 43/125 [00:01<00:03, 24.24it/s][A
Eval: 

Acc for model which have lower valid loss:  0.9735


Epoch 0:  60%|█████▉    | 8310/13852 [18:35<12:03,  7.66batch/s, loss=0.0133]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:15,  7.98it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:06, 17.36it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 20.71it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.78it/s][A
Eval:  10%|█         | 13/125 [00:00<00:04, 22.71it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 23.17it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 23.42it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 23.38it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 23.63it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:04, 24.01it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:03, 23.85it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 23.88it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 24.44it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03, 24.54it/s][A
Eval:  34%|███▍      | 43/125 [00:01<00:03, 24.20it/s][A
Eval: 

Acc for model which have lower valid loss:  0.9795


Epoch 0: 100%|█████████▉| 13850/13852 [30:57<00:00,  7.18batch/s, loss=0.0154]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:14,  8.29it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:06, 17.38it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 20.47it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.75it/s][A
Eval:  10%|█         | 13/125 [00:00<00:04, 22.68it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 23.12it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 23.23it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 23.51it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 23.66it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:03, 24.29it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:03, 23.53it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 23.89it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 24.40it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03, 24.29it/s][A
Eval:  34%|███▍      | 43/125 [00:01<00:03, 24.04it/s][A
Eval:

Acc for model which have lower valid loss:  0.981


Epoch 0: 100%|██████████| 13852/13852 [31:05<00:00,  7.43batch/s, loss=0.0104]
Epoch 1:  20%|█▉        | 2770/13852 [06:06<24:38,  7.50batch/s, loss=0.00223]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:17,  7.25it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:07, 16.51it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 19.95it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.51it/s][A
Eval:  10%|█         | 13/125 [00:00<00:05, 22.31it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 22.84it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 23.07it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 23.08it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 23.42it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:03, 24.25it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:04, 23.32it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 23.66it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 24.26it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03

Acc for model which have lower valid loss:  0.98


Epoch 1: 100%|██████████| 13852/13852 [30:59<00:00,  7.45batch/s, loss=0.017]
Epoch 2:  20%|█▉        | 2770/13852 [06:08<24:07,  7.65batch/s, loss=0.0153]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:17,  7.27it/s][A
Eval:   3%|▎         | 4/125 [00:00<00:07, 16.43it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:05, 19.86it/s][A
Eval:   8%|▊         | 10/125 [00:00<00:05, 21.26it/s][A
Eval:  10%|█         | 13/125 [00:00<00:05, 22.06it/s][A
Eval:  13%|█▎        | 16/125 [00:00<00:04, 22.55it/s][A
Eval:  15%|█▌        | 19/125 [00:00<00:04, 22.71it/s][A
Eval:  18%|█▊        | 22/125 [00:01<00:04, 22.73it/s][A
Eval:  20%|██        | 25/125 [00:01<00:04, 22.46it/s][A
Eval:  22%|██▏       | 28/125 [00:01<00:04, 23.37it/s][A
Eval:  25%|██▍       | 31/125 [00:01<00:04, 23.06it/s][A
Eval:  27%|██▋       | 34/125 [00:01<00:03, 23.28it/s][A
Eval:  30%|██▉       | 37/125 [00:01<00:03, 23.76it/s][A
Eval:  32%|███▏      | 40/125 [00:01<00:03, 

Time: 7478.83292222023





# 3. Test

In [26]:
import pandas as pd
test_df = pd.read_csv('test_no_label.csv')

In [27]:
test_dataset = test_df['Id']

In [28]:
def make_id_file_test(tokenizer, test_dataset, cased):
    data_strings = []
    id_file_data = [tokenizer.encode(sent if cased else sent.lower()) for sent in test_dataset]
    for item in id_file_data:
        data_strings.append(' '.join([str(k) for k in item]))
    return data_strings

In [29]:
test = make_id_file_test(tokenizer, test_dataset, cased)

In [30]:
class SentimentTestDataset(object):
    def __init__(self, tokenizer, test):
        self.tokenizer = tokenizer
        self.data = []

        for sent in test:
            self.data += [self._cast_to_int(sent.strip().split())]

    def _cast_to_int(self, sample):
        return [int(word_id) for word_id in sample]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return np.array(sample)

In [31]:
test_dataset = SentimentTestDataset(tokenizer, test)

In [32]:
def collate_fn_style_test(samples):
    input_ids = samples
    max_len = max(len(input_id) for input_id in input_ids)
    attention_mask = torch.tensor(
        [[1] * len(input_id) + [0] * (max_len - len(input_id)) for input_id in
         input_ids])
    input_ids = pad_sequence([torch.tensor(input_id) for input_id in input_ids],
                             batch_first=True)
    
    
    token_type_ids = torch.tensor([[0] * len(input_id) for input_id in input_ids])
    position_ids = torch.tensor([list(range(len(input_id))) for input_id in input_ids])

    return input_ids, attention_mask, token_type_ids, position_ids

In [33]:
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,
                                          shuffle=False, collate_fn=collate_fn_style_test,
                                          num_workers=2)

In [34]:
with torch.no_grad():
    model.eval()
    predictions = []
    for input_ids, attention_mask, token_type_ids, position_ids in tqdm(test_loader,
                                                                        desc='Test',
                                                                        position=1,
                                                                        leave=None):

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        position_ids = position_ids.to(device)

        output = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids)

        logits = output.logits
        batch_predictions = [0 if example[0] > example[1] else 1 for example in logits]
        predictions += batch_predictions


Test:   0%|          | 0/32 [00:00<?, ?it/s][A
Test:   3%|▎         | 1/32 [00:00<00:04,  7.43it/s][A
Test:  12%|█▎        | 4/32 [00:00<00:01, 15.43it/s][A
Test:  22%|██▏       | 7/32 [00:00<00:01, 18.38it/s][A
Test:  31%|███▏      | 10/32 [00:00<00:01, 19.63it/s][A
Test:  41%|████      | 13/32 [00:00<00:00, 20.91it/s][A
Test:  50%|█████     | 16/32 [00:00<00:00, 21.88it/s][A
Test:  59%|█████▉    | 19/32 [00:00<00:00, 22.24it/s][A
Test:  69%|██████▉   | 22/32 [00:01<00:00, 22.46it/s][A
Test:  78%|███████▊  | 25/32 [00:01<00:00, 22.89it/s][A
Test:  88%|████████▊ | 28/32 [00:01<00:00, 22.73it/s][A
Test:  97%|█████████▋| 31/32 [00:01<00:00, 22.80it/s][A
                                                     [A

In [35]:
test_df['Category'] = predictions

In [36]:
test_df.to_csv('submission_xlnet.csv', index=False)

In [37]:
test_df['Category'].value_counts()

1    502
0    498
Name: Category, dtype: int64