# Контекст 

Использовать модель transformer-BERT для классификации тональности IMDB

In [2]:
import time
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification

# train - val - test

In [3]:
torch.backends.cudnn.deterministic = True
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
data = pd.concat([train, test], axis = 0, ignore_index = True).sample(frac = 1, random_state = RANDOM_SEED)
data.head()

Unnamed: 0,text,sentiment
33553,This is a clever and entertaining film about t...,pos
9427,The Woman In Black is fantastic in all aspects...,pos
199,"...about the importance of being young, having...",pos
12447,Recap: It's business as usual at Louche's casi...,pos
39489,There have been some harsh criticisms of Coman...,pos


In [5]:
data['sentiment'] = data['sentiment'].map(lambda x: np.where(x == 'neg', 0, 1))
data.head()

Unnamed: 0,text,sentiment
33553,This is a clever and entertaining film about t...,1
9427,The Woman In Black is fantastic in all aspects...,1
199,"...about the importance of being young, having...",1
12447,Recap: It's business as usual at Louche's casi...,1
39489,There have been some harsh criticisms of Coman...,1


In [6]:
data['sentiment'].value_counts()

sentiment
1    25000
0    25000
Name: count, dtype: int64

In [7]:
train_texts = data.iloc[:35000]['text'].values
train_labels = data.iloc[:35000]['sentiment'].values

valid_texts = data.iloc[35000:40000]['text'].values
valid_labels = data.iloc[35000:40000]['sentiment'].values

test_texts = data.iloc[40000:]['text'].values
test_labels = data.iloc[40000:]['sentiment'].values

# Токенизация

In [8]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(
    list(train_texts),
    truncation = True,
    padding = True
    )

valid_encodings = tokenizer(
    list(valid_texts),
    truncation = True,
    padding = True
    )

test_encodings = tokenizer(
    list(test_texts),
    truncation = True,
    padding = True
    )

# Dataset - Dataloader

In [9]:
class IMDbDataset(Dataset):

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
    
    def __len__(self):
        return len(self.labels)

In [10]:
train_dataset = IMDbDataset(train_encodings, train_labels)
valid_dataset = IMDbDataset(valid_encodings, valid_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) 
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False) 
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [11]:
train_loader.dataset[0]

{'input_ids': tensor([  101,  2023,  2003,  1037, 12266,  1998, 14036,  2143,  2055,  1996,
          2067,  9954, 17773,  2008,  2253,  2006,  2000,  5454,  1037,  6332,
          2000,  5206,  9806,  2006,  1996,  3892,  2265,  1012,  2174,  1010,
          1996,  3185,  1005,  1055,  7209,  9459,  1010,  2008,  2585, 26593,
          1005,  1055,  6020, 21699, 11725,  2020, 17092,  2138,  1997,  2010,
          6801,  3267,  1010,  1998,  1996,  2265, 15911,  2011,  1996,  4372,
          2705, 20793,  3672,  1997,  2019, 14092,  6108, 18798,  2080,  1999,
          1996,  2877,  4038,  2565,  2006,  2547,  2003,  3432,  4895, 16344,
          5657,  1010,  2004,  3087,  2040,  2038,  3427,  2119, 18798,  2080,
          1998, 26593,  4685,  2064,  2156,  2005,  3209,  1012,  1996,  2524,
          1998,  3722,  3606,  2003,  3599,  2004,  2028,  6788,  4654,  8586,
          3090,  1999,  1996,  2143,  1000,  6108, 18798,  2080,  2003,  1996,
          4569, 15580,  2102,  2158,  1

# Model

In [18]:
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased'
)

model.to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

# accuracy

In [22]:
def compute_accuracy(
        model,
        data_loader
):
    with torch.no_grad():

        correct_predict = 0
        full_samples = 0

        for _, batch in enumerate(data_loader):

            inputs = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(inputs, attention_mask = attention_mask)
            logits = outputs['logits']

            predicted_labels = torch.argmax(logits, 1)
            full_samples += labels.size(0)
            correct_predict += (predicted_labels == labels).sum()
    
    return correct_predict.float() / full_samples

# Тренировочный цикл

In [25]:
NUM_EPOCHS = 3
optim = torch.optim.Adam(model.parameters(), lr = 0.00005)

In [27]:
start_time = time.time()

for epoch in range(NUM_EPOCHS):
 
    model.train()
    
    for batch_idx, batch in enumerate(train_loader):
    
        ### Prepare data
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        ### Forward
        outputs = model(
            input_ids, 
            attention_mask=attention_mask,
            labels=labels
            )
        loss, logits = outputs['loss'], outputs['logits']
        
        ### Backward
        optim.zero_grad()
        loss.backward()
        optim.step()
    
        ### verbose
        if not batch_idx % 250:
            print(
                f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d}'
                f' | Batch '
                f'{batch_idx:04d}/'
                f'{len(train_loader):04d} | '
                f'Loss: {loss:.4f}')
    
    model.eval()
    with torch.set_grad_enabled(False):
        print(
            f'Training accuracy: '
            f'{compute_accuracy(model, train_loader):.2f}%'f'\nValid accuracy: '
            f'{compute_accuracy(model, valid_loader):.2f}%'
            )
    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')

print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')
print(f'Test accuracy: {compute_accuracy(model, test_loader):.2f}%')

Epoch: 0001/0003 | Batch 0000/2188 | Loss: 0.6453
Epoch: 0001/0003 | Batch 0250/2188 | Loss: 0.1840
Epoch: 0001/0003 | Batch 0500/2188 | Loss: 0.3302
Epoch: 0001/0003 | Batch 0750/2188 | Loss: 0.1972
Epoch: 0001/0003 | Batch 1000/2188 | Loss: 0.4093
Epoch: 0001/0003 | Batch 1250/2188 | Loss: 0.2997
Epoch: 0001/0003 | Batch 1500/2188 | Loss: 0.3943
Epoch: 0001/0003 | Batch 1750/2188 | Loss: 0.3038
Epoch: 0001/0003 | Batch 2000/2188 | Loss: 0.2379
Training accuracy: 0.96%
Valid accuracy: 0.93%
Time elapsed: 13.44 min
Epoch: 0002/0003 | Batch 0000/2188 | Loss: 0.1864
Epoch: 0002/0003 | Batch 0250/2188 | Loss: 0.0147
Epoch: 0002/0003 | Batch 0500/2188 | Loss: 0.0549
Epoch: 0002/0003 | Batch 0750/2188 | Loss: 0.0370
Epoch: 0002/0003 | Batch 1000/2188 | Loss: 0.0755
Epoch: 0002/0003 | Batch 1250/2188 | Loss: 0.1114
Epoch: 0002/0003 | Batch 1500/2188 | Loss: 0.2735
Epoch: 0002/0003 | Batch 1750/2188 | Loss: 0.0207
Epoch: 0002/0003 | Batch 2000/2188 | Loss: 0.1159
Training accuracy: 0.99%
Vali