In [6]:
!pip install -q transformers

In [7]:
import pandas as pd
from tqdm import tqdm


In [8]:
train=pd.read_csv("/kaggle/input/tweeter/train.csv")
test=pd.read_csv("/kaggle/input/tweeter/test.csv")
val=pd.read_csv("/kaggle/input/tweeter/val.csv")

In [10]:
y_train = train['Category']
y_val = val['Category']
y_test = test['Category']
X_train = train['Text']
X_val = val['Text']
X_test = test['Text']

In [12]:
def text_preprocessing(text):
    text = re.sub(r'(@.*?)[\s]', ' ', text)
    text = re.sub(r'&amp;', '&', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [13]:
from transformers import AutoTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

def preprocessing_for_bert(data):
    input_ids = []
    attention_masks = []
    for sentence in tqdm(data):
        encoded_sent = bert_tokenizer.encode_plus(
            text=text_preprocessing(sentence),  
            add_special_tokens=True,        
            max_length=MAXIMUM_LENGTH,                  
            pad_to_max_length=True,         
            return_attention_mask=True      
            )

        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [15]:
import regex as re
MAXIMUM_LENGTH = 64

token_ids = list(preprocessing_for_bert([X_train[0]])[0].squeeze().numpy())
train_inputs, train_masks = preprocessing_for_bert(X_train)
val_inputs, val_masks = preprocessing_for_bert(X_val)

  0%|          | 0/1 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 1/1 [00:00<00:00, 91.18it/s]


Original:  Disnleyland ISN'T tha happiest place in tha world! Grad trip can suck my cock, well maybe not magic mountain &amp; its a small world ride 
Token IDs:  [101, 4487, 2015, 20554, 3240, 3122, 3475, 1005, 1056, 22794, 5292, 9397, 10458, 2173, 1999, 22794, 2088, 999, 24665, 4215, 4440, 2064, 11891, 2026, 10338, 1010, 2092, 2672, 2025, 3894, 3137, 1004, 2049, 1037, 2235, 2088, 4536, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Tokenizing data...


100%|██████████| 1120000/1120000 [03:34<00:00, 5210.61it/s]
100%|██████████| 240000/240000 [00:44<00:00, 5408.35it/s]


In [16]:
y_train.replace(4,1,inplace= True)
y_val.replace(4,1, inplace=True)

In [17]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
train_labels = torch.tensor(y_train.values)
val_labels = torch.tensor(y_val.values)
batch_size = 512

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)


validation_data = TensorDataset(val_inputs, val_masks, val_labels)
validation_sampler = SequentialSampler(validation_data)
vaidation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

In [18]:
import torch
import torch.nn as nn
from transformers import BertModel

class BERT_Classification(nn.Module):
    def __init__(self, bert_freezing=False):
        super(BERT_Classification, self).__init__()
        Dimensions_input, H, Dimensions_out = 768, 50, 2
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Sequential(
            nn.Linear(Dimensions_input, H),
            nn.ReLU(),
            nn.Linear(H, Dimensions_out)
        )

        if bert_freezing:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        last_hidden_state_cls = outputs[0][:, 0, :]

        logits = self.classifier(last_hidden_state_cls)

        return logits

In [19]:
from transformers import AdamW, get_linear_schedule_with_warmup

def initialize_model(num_epochs=4):
    bert_classification_model = BERT_Classification(bert_freezing=False)
    bert_classification_model= nn.DataParallel(bert_classification_model)
    bert_classification_model.to(device)
    optimizer = AdamW(bert_classification_model.parameters(),
                      learning_rate=5e-5,    
                      eps=1e-8   
                      )
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0, 
                                                num_training_steps=total_steps)
    return bert_classification_model, optimizer, scheduler

In [20]:
import random
import time
import numpy as np
loss_fn = nn.CrossEntropyLoss()
def set_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

def Train(model, train_dataloader, vaidation_dataloader=None, num_epochs=4, evaluation=False):
    for epoch_i in range(num_epochs):
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0
        for step, batch in enumerate(tqdm(train_dataloader)):
            batch_counts +=1 # Load batch to GPU
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)
            model.zero_grad()
            logits = model(b_input_ids, b_attn_mask)
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
            time_elapsed = time.time() - t0_batch
            batch_loss, batch_counts = 0, 0
            t0_batch = time.time()
            avg_train_loss = total_loss / len(train_dataloader)
    if evaluation == True:
        validation_loss, validation_accuracy = evaluate(model, vaidation_dataloader)
        time_elapsed = time.time() - t0_epoch
        
def evaluate(model, vaidation_dataloader):
    model.eval()
    validation_accuracy = []
    validation_loss = []
    for batch in vaidation_dataloader:
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)
            loss = loss_fn(logits, b_labels)
            validation_loss.append(loss.item())
            preds = torch.argmax(logits, dim=1).flatten()
            accuracy = (preds == b_labels).cpu().numpy().mean() * 100
            validation_accuracy.append(accuracy)
    validation_loss = np.mean(validation_loss)
    validation_accuracy = np.mean(validation_accuracy)
    return validation_loss, validation_accuracy

In [21]:
set_seed(42) 
bert_classification_model, optimizer, scheduler = initialize_model(num_epochs=1)
Train(bert_classification_model, train_dataloader, vaidation_dataloader, num_epochs=1, evaluation=True)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]



 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


  1%|          | 21/2188 [00:50<1:21:03,  2.24s/it]

   1    |   20    |   0.548140   |     -      |     -     |   50.22  
----------------------------------------------------------------------


  2%|▏         | 41/2188 [01:36<1:24:26,  2.36s/it]

   1    |   40    |   0.403959   |     -      |     -     |   46.27  
----------------------------------------------------------------------


  3%|▎         | 61/2188 [02:25<1:26:19,  2.44s/it]

   1    |   60    |   0.398806   |     -      |     -     |   48.67  
----------------------------------------------------------------------


  4%|▎         | 81/2188 [03:13<1:24:16,  2.40s/it]

   1    |   80    |   0.380845   |     -      |     -     |   47.92  
----------------------------------------------------------------------


  5%|▍         | 101/2188 [04:01<1:23:40,  2.41s/it]

   1    |   100   |   0.372116   |     -      |     -     |   48.38  
----------------------------------------------------------------------


  6%|▌         | 121/2188 [04:49<1:22:43,  2.40s/it]

   1    |   120   |   0.369729   |     -      |     -     |   47.88  
----------------------------------------------------------------------


  6%|▋         | 141/2188 [05:37<1:21:35,  2.39s/it]

   1    |   140   |   0.349518   |     -      |     -     |   47.96  
----------------------------------------------------------------------


  7%|▋         | 161/2188 [06:25<1:21:04,  2.40s/it]

   1    |   160   |   0.364483   |     -      |     -     |   47.97  
----------------------------------------------------------------------


  8%|▊         | 181/2188 [07:13<1:20:09,  2.40s/it]

   1    |   180   |   0.349124   |     -      |     -     |   48.09  
----------------------------------------------------------------------


  9%|▉         | 201/2188 [08:01<1:19:37,  2.40s/it]

   1    |   200   |   0.357527   |     -      |     -     |   48.04  
----------------------------------------------------------------------


 10%|█         | 221/2188 [08:49<1:18:44,  2.40s/it]

   1    |   220   |   0.343327   |     -      |     -     |   48.02  
----------------------------------------------------------------------


 11%|█         | 241/2188 [09:37<1:17:59,  2.40s/it]

   1    |   240   |   0.346932   |     -      |     -     |   48.13  
----------------------------------------------------------------------


 12%|█▏        | 261/2188 [10:25<1:17:17,  2.41s/it]

   1    |   260   |   0.344584   |     -      |     -     |   48.10  
----------------------------------------------------------------------


 13%|█▎        | 281/2188 [11:13<1:16:33,  2.41s/it]

   1    |   280   |   0.342927   |     -      |     -     |   48.10  
----------------------------------------------------------------------


 14%|█▍        | 301/2188 [12:01<1:15:31,  2.40s/it]

   1    |   300   |   0.340757   |     -      |     -     |   48.06  
----------------------------------------------------------------------


 15%|█▍        | 321/2188 [12:49<1:14:45,  2.40s/it]

   1    |   320   |   0.334586   |     -      |     -     |   48.05  
----------------------------------------------------------------------


 16%|█▌        | 341/2188 [13:38<1:14:09,  2.41s/it]

   1    |   340   |   0.344353   |     -      |     -     |   48.27  
----------------------------------------------------------------------


 16%|█▋        | 361/2188 [14:26<1:13:02,  2.40s/it]

   1    |   360   |   0.340962   |     -      |     -     |   48.02  
----------------------------------------------------------------------


 17%|█▋        | 381/2188 [15:14<1:12:20,  2.40s/it]

   1    |   380   |   0.340870   |     -      |     -     |   48.07  
----------------------------------------------------------------------


 18%|█▊        | 401/2188 [16:02<1:11:19,  2.40s/it]

   1    |   400   |   0.332976   |     -      |     -     |   47.99  
----------------------------------------------------------------------


 19%|█▉        | 421/2188 [16:50<1:10:25,  2.39s/it]

   1    |   420   |   0.324703   |     -      |     -     |   47.85  
----------------------------------------------------------------------


 20%|██        | 441/2188 [17:37<1:09:51,  2.40s/it]

   1    |   440   |   0.336414   |     -      |     -     |   47.91  
----------------------------------------------------------------------


 21%|██        | 461/2188 [18:25<1:08:51,  2.39s/it]

   1    |   460   |   0.326176   |     -      |     -     |   47.88  
----------------------------------------------------------------------


 22%|██▏       | 481/2188 [19:13<1:08:21,  2.40s/it]

   1    |   480   |   0.329784   |     -      |     -     |   47.89  
----------------------------------------------------------------------


 23%|██▎       | 501/2188 [20:01<1:07:08,  2.39s/it]

   1    |   500   |   0.336087   |     -      |     -     |   47.74  
----------------------------------------------------------------------


 24%|██▍       | 521/2188 [20:49<1:06:33,  2.40s/it]

   1    |   520   |   0.331335   |     -      |     -     |   47.87  
----------------------------------------------------------------------


 25%|██▍       | 541/2188 [21:37<1:05:33,  2.39s/it]

   1    |   540   |   0.325649   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 26%|██▌       | 561/2188 [22:25<1:04:48,  2.39s/it]

   1    |   560   |   0.333206   |     -      |     -     |   47.81  
----------------------------------------------------------------------


 27%|██▋       | 581/2188 [23:12<1:03:58,  2.39s/it]

   1    |   580   |   0.326229   |     -      |     -     |   47.81  
----------------------------------------------------------------------


 27%|██▋       | 601/2188 [24:00<1:03:17,  2.39s/it]

   1    |   600   |   0.313256   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 28%|██▊       | 621/2188 [24:48<1:02:19,  2.39s/it]

   1    |   620   |   0.319264   |     -      |     -     |   47.77  
----------------------------------------------------------------------


 29%|██▉       | 641/2188 [25:36<1:01:35,  2.39s/it]

   1    |   640   |   0.325841   |     -      |     -     |   47.73  
----------------------------------------------------------------------


 30%|███       | 661/2188 [26:23<1:00:53,  2.39s/it]

   1    |   660   |   0.322865   |     -      |     -     |   47.82  
----------------------------------------------------------------------


 31%|███       | 681/2188 [27:11<1:00:02,  2.39s/it]

   1    |   680   |   0.314056   |     -      |     -     |   47.83  
----------------------------------------------------------------------


 32%|███▏      | 701/2188 [27:59<59:10,  2.39s/it]  

   1    |   700   |   0.324403   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 33%|███▎      | 721/2188 [28:47<59:00,  2.41s/it]

   1    |   720   |   0.319979   |     -      |     -     |   47.87  
----------------------------------------------------------------------


 34%|███▍      | 741/2188 [29:35<57:36,  2.39s/it]

   1    |   740   |   0.320265   |     -      |     -     |   47.79  
----------------------------------------------------------------------


 35%|███▍      | 761/2188 [30:23<56:46,  2.39s/it]

   1    |   760   |   0.325663   |     -      |     -     |   47.79  
----------------------------------------------------------------------


 36%|███▌      | 781/2188 [31:10<55:57,  2.39s/it]

   1    |   780   |   0.329599   |     -      |     -     |   47.74  
----------------------------------------------------------------------


 37%|███▋      | 801/2188 [31:58<55:07,  2.38s/it]

   1    |   800   |   0.314263   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 38%|███▊      | 821/2188 [32:46<54:23,  2.39s/it]

   1    |   820   |   0.324987   |     -      |     -     |   47.72  
----------------------------------------------------------------------


 38%|███▊      | 841/2188 [33:34<53:38,  2.39s/it]

   1    |   840   |   0.331840   |     -      |     -     |   47.72  
----------------------------------------------------------------------


 39%|███▉      | 861/2188 [34:21<52:49,  2.39s/it]

   1    |   860   |   0.318112   |     -      |     -     |   47.79  
----------------------------------------------------------------------


 40%|████      | 881/2188 [35:09<52:07,  2.39s/it]

   1    |   880   |   0.307551   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 41%|████      | 901/2188 [35:57<51:16,  2.39s/it]

   1    |   900   |   0.320007   |     -      |     -     |   47.98  
----------------------------------------------------------------------


 42%|████▏     | 921/2188 [36:45<50:26,  2.39s/it]

   1    |   920   |   0.312974   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 43%|████▎     | 941/2188 [37:33<49:44,  2.39s/it]

   1    |   940   |   0.315871   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 44%|████▍     | 961/2188 [38:21<48:54,  2.39s/it]

   1    |   960   |   0.315223   |     -      |     -     |   47.82  
----------------------------------------------------------------------


 45%|████▍     | 981/2188 [39:09<48:06,  2.39s/it]

   1    |   980   |   0.319284   |     -      |     -     |   47.98  
----------------------------------------------------------------------


 46%|████▌     | 1001/2188 [39:56<47:15,  2.39s/it]

   1    |  1000   |   0.316665   |     -      |     -     |   47.72  
----------------------------------------------------------------------


 47%|████▋     | 1021/2188 [40:44<46:25,  2.39s/it]

   1    |  1020   |   0.317339   |     -      |     -     |   47.81  
----------------------------------------------------------------------


 48%|████▊     | 1041/2188 [41:32<45:34,  2.38s/it]

   1    |  1040   |   0.312733   |     -      |     -     |   47.65  
----------------------------------------------------------------------


 48%|████▊     | 1061/2188 [42:20<44:54,  2.39s/it]

   1    |  1060   |   0.310849   |     -      |     -     |   47.94  
----------------------------------------------------------------------


 49%|████▉     | 1081/2188 [43:07<44:07,  2.39s/it]

   1    |  1080   |   0.311197   |     -      |     -     |   47.73  
----------------------------------------------------------------------


 50%|█████     | 1101/2188 [43:55<43:19,  2.39s/it]

   1    |  1100   |   0.303298   |     -      |     -     |   47.82  
----------------------------------------------------------------------


 51%|█████     | 1121/2188 [44:43<42:27,  2.39s/it]

   1    |  1120   |   0.312652   |     -      |     -     |   47.68  
----------------------------------------------------------------------


 52%|█████▏    | 1141/2188 [45:31<41:44,  2.39s/it]

   1    |  1140   |   0.306712   |     -      |     -     |   48.01  
----------------------------------------------------------------------


 53%|█████▎    | 1161/2188 [46:19<40:56,  2.39s/it]

   1    |  1160   |   0.312248   |     -      |     -     |   47.81  
----------------------------------------------------------------------


 54%|█████▍    | 1181/2188 [47:07<40:09,  2.39s/it]

   1    |  1180   |   0.311956   |     -      |     -     |   47.83  
----------------------------------------------------------------------


 55%|█████▍    | 1201/2188 [47:54<39:21,  2.39s/it]

   1    |  1200   |   0.314059   |     -      |     -     |   47.83  
----------------------------------------------------------------------


 56%|█████▌    | 1221/2188 [48:42<38:36,  2.40s/it]

   1    |  1220   |   0.311136   |     -      |     -     |   48.02  
----------------------------------------------------------------------


 57%|█████▋    | 1241/2188 [49:30<37:45,  2.39s/it]

   1    |  1240   |   0.307707   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 58%|█████▊    | 1261/2188 [50:18<36:53,  2.39s/it]

   1    |  1260   |   0.310110   |     -      |     -     |   47.75  
----------------------------------------------------------------------


 59%|█████▊    | 1281/2188 [51:06<36:06,  2.39s/it]

   1    |  1280   |   0.301060   |     -      |     -     |   47.75  
----------------------------------------------------------------------


 59%|█████▉    | 1301/2188 [51:53<35:12,  2.38s/it]

   1    |  1300   |   0.304270   |     -      |     -     |   47.69  
----------------------------------------------------------------------


 60%|██████    | 1321/2188 [52:41<34:31,  2.39s/it]

   1    |  1320   |   0.303691   |     -      |     -     |   47.88  
----------------------------------------------------------------------


 61%|██████▏   | 1341/2188 [53:29<33:42,  2.39s/it]

   1    |  1340   |   0.318521   |     -      |     -     |   47.78  
----------------------------------------------------------------------


 62%|██████▏   | 1361/2188 [54:17<32:52,  2.38s/it]

   1    |  1360   |   0.304007   |     -      |     -     |   47.71  
----------------------------------------------------------------------


 63%|██████▎   | 1381/2188 [55:04<32:08,  2.39s/it]

   1    |  1380   |   0.313638   |     -      |     -     |   47.74  
----------------------------------------------------------------------


 64%|██████▍   | 1401/2188 [55:52<31:22,  2.39s/it]

   1    |  1400   |   0.301033   |     -      |     -     |   47.86  
----------------------------------------------------------------------


 65%|██████▍   | 1421/2188 [56:40<30:27,  2.38s/it]

   1    |  1420   |   0.301169   |     -      |     -     |   47.71  
----------------------------------------------------------------------


 66%|██████▌   | 1441/2188 [57:28<29:46,  2.39s/it]

   1    |  1440   |   0.302504   |     -      |     -     |   47.77  
----------------------------------------------------------------------


 67%|██████▋   | 1461/2188 [58:16<28:56,  2.39s/it]

   1    |  1460   |   0.308320   |     -      |     -     |   47.82  
----------------------------------------------------------------------


 68%|██████▊   | 1481/2188 [59:04<28:11,  2.39s/it]

   1    |  1480   |   0.306319   |     -      |     -     |   47.89  
----------------------------------------------------------------------


 69%|██████▊   | 1501/2188 [59:51<27:19,  2.39s/it]

   1    |  1500   |   0.303572   |     -      |     -     |   47.76  
----------------------------------------------------------------------


 70%|██████▉   | 1521/2188 [1:00:39<26:34,  2.39s/it]

   1    |  1520   |   0.307240   |     -      |     -     |   47.69  
----------------------------------------------------------------------


 70%|███████   | 1541/2188 [1:01:27<25:47,  2.39s/it]

   1    |  1540   |   0.300851   |     -      |     -     |   47.85  
----------------------------------------------------------------------


 71%|███████▏  | 1561/2188 [1:02:15<25:07,  2.40s/it]

   1    |  1560   |   0.299036   |     -      |     -     |   47.95  
----------------------------------------------------------------------


 72%|███████▏  | 1581/2188 [1:03:03<24:13,  2.39s/it]

   1    |  1580   |   0.301548   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 73%|███████▎  | 1601/2188 [1:03:50<23:23,  2.39s/it]

   1    |  1600   |   0.309014   |     -      |     -     |   47.83  
----------------------------------------------------------------------


 74%|███████▍  | 1621/2188 [1:04:38<22:37,  2.39s/it]

   1    |  1620   |   0.299205   |     -      |     -     |   47.83  
----------------------------------------------------------------------


 75%|███████▌  | 1641/2188 [1:05:26<21:47,  2.39s/it]

   1    |  1640   |   0.300328   |     -      |     -     |   48.00  
----------------------------------------------------------------------


 76%|███████▌  | 1661/2188 [1:06:14<21:01,  2.39s/it]

   1    |  1660   |   0.305908   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 77%|███████▋  | 1681/2188 [1:07:02<20:18,  2.40s/it]

   1    |  1680   |   0.298802   |     -      |     -     |   48.02  
----------------------------------------------------------------------


 78%|███████▊  | 1701/2188 [1:07:50<19:24,  2.39s/it]

   1    |  1700   |   0.291174   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 79%|███████▊  | 1721/2188 [1:08:38<18:42,  2.40s/it]

   1    |  1720   |   0.307118   |     -      |     -     |   48.02  
----------------------------------------------------------------------


 80%|███████▉  | 1741/2188 [1:09:26<17:49,  2.39s/it]

   1    |  1740   |   0.300934   |     -      |     -     |   47.91  
----------------------------------------------------------------------


 80%|████████  | 1761/2188 [1:10:14<17:06,  2.40s/it]

   1    |  1760   |   0.293771   |     -      |     -     |   47.99  
----------------------------------------------------------------------


 81%|████████▏ | 1781/2188 [1:11:02<16:20,  2.41s/it]

   1    |  1780   |   0.300839   |     -      |     -     |   48.10  
----------------------------------------------------------------------


 82%|████████▏ | 1801/2188 [1:11:50<15:29,  2.40s/it]

   1    |  1800   |   0.296138   |     -      |     -     |   48.07  
----------------------------------------------------------------------


 83%|████████▎ | 1821/2188 [1:12:38<14:41,  2.40s/it]

   1    |  1820   |   0.303438   |     -      |     -     |   48.21  
----------------------------------------------------------------------


 84%|████████▍ | 1841/2188 [1:13:26<13:51,  2.40s/it]

   1    |  1840   |   0.305125   |     -      |     -     |   47.94  
----------------------------------------------------------------------


 85%|████████▌ | 1861/2188 [1:14:14<13:07,  2.41s/it]

   1    |  1860   |   0.306316   |     -      |     -     |   48.08  
----------------------------------------------------------------------


 86%|████████▌ | 1881/2188 [1:15:02<12:17,  2.40s/it]

   1    |  1880   |   0.301002   |     -      |     -     |   48.07  
----------------------------------------------------------------------


 87%|████████▋ | 1901/2188 [1:15:50<11:27,  2.40s/it]

   1    |  1900   |   0.299478   |     -      |     -     |   48.04  
----------------------------------------------------------------------


 88%|████████▊ | 1921/2188 [1:16:38<10:39,  2.39s/it]

   1    |  1920   |   0.295977   |     -      |     -     |   47.86  
----------------------------------------------------------------------


 89%|████████▊ | 1941/2188 [1:17:26<09:50,  2.39s/it]

   1    |  1940   |   0.294775   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 90%|████████▉ | 1961/2188 [1:18:14<09:03,  2.39s/it]

   1    |  1960   |   0.293907   |     -      |     -     |   47.88  
----------------------------------------------------------------------


 91%|█████████ | 1981/2188 [1:19:02<08:16,  2.40s/it]

   1    |  1980   |   0.304040   |     -      |     -     |   48.04  
----------------------------------------------------------------------


 91%|█████████▏| 2001/2188 [1:19:50<07:26,  2.39s/it]

   1    |  2000   |   0.298416   |     -      |     -     |   47.77  
----------------------------------------------------------------------


 92%|█████████▏| 2021/2188 [1:20:38<06:39,  2.39s/it]

   1    |  2020   |   0.291928   |     -      |     -     |   47.80  
----------------------------------------------------------------------


 93%|█████████▎| 2041/2188 [1:21:25<05:51,  2.39s/it]

   1    |  2040   |   0.303884   |     -      |     -     |   47.84  
----------------------------------------------------------------------


 94%|█████████▍| 2061/2188 [1:22:13<05:03,  2.39s/it]

   1    |  2060   |   0.293720   |     -      |     -     |   47.81  
----------------------------------------------------------------------


 95%|█████████▌| 2081/2188 [1:23:01<04:15,  2.39s/it]

   1    |  2080   |   0.291963   |     -      |     -     |   47.98  
----------------------------------------------------------------------


 96%|█████████▌| 2101/2188 [1:23:49<03:27,  2.39s/it]

   1    |  2100   |   0.302981   |     -      |     -     |   47.76  
----------------------------------------------------------------------


 97%|█████████▋| 2121/2188 [1:24:37<02:40,  2.39s/it]

   1    |  2120   |   0.296548   |     -      |     -     |   47.72  
----------------------------------------------------------------------


 98%|█████████▊| 2141/2188 [1:25:25<01:52,  2.39s/it]

   1    |  2140   |   0.297940   |     -      |     -     |   47.98  
----------------------------------------------------------------------


 99%|█████████▉| 2161/2188 [1:26:12<01:04,  2.39s/it]

   1    |  2160   |   0.298414   |     -      |     -     |   47.83  
----------------------------------------------------------------------


100%|█████████▉| 2181/2188 [1:27:00<00:16,  2.39s/it]

   1    |  2180   |   0.297026   |     -      |     -     |   47.81  
----------------------------------------------------------------------


100%|██████████| 2188/2188 [1:27:16<00:00,  2.39s/it]

   1    |  2187   |   0.296606   |     -      |     -     |   15.70  
----------------------------------------------------------------------





   1    |    -    |   0.320093   |  0.297008  |   87.30   |  5657.63 
----------------------------------------------------------------------


Training complete!


In [22]:
torch.save(bert_classification_model, './bert.pt')

In [23]:
y_test.replace(4,1,inplace=True)
test_labels = torch.tensor(y_test.values)
test_inputs, test_masks = preprocessing_for_bert(X_test)
test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

100%|██████████| 240000/240000 [00:46<00:00, 5107.75it/s]


In [24]:

loss,acc = evaluate(bert_classification_model, test_dataloader)

0.29630521765904133 87.28969549573561


In [26]:
print("The Loss Value is "+str(loss)+" .The accuracy values are "+str(acc))

The Loss Value is 0.29630521765904133 .The accuracy values are 87.28969549573561
