In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW
import torch
from torch.utils.data import DataLoader, TensorDataset
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs on machines without CUDA instead of failing at the first `.to(DEVICE)`.
print(torch.cuda.is_available())
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


True
cuda


In [8]:
# Load the dataset. Column names are supplied explicitly, so pandas treats the
# first row of the file as data rather than as a header.
df = pd.read_csv('../datasets/sentiment_dataset.csv',
                names=['sentiment', 'text'],
                encoding='utf-8', encoding_errors='replace')

# Hold out 20% of the rows for validation; the fixed random_state keeps the
# split identical across runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'], df['sentiment'], test_size=0.2, random_state=42
)

# Load pretrained DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Tokenize the texts: sequences longer than 128 tokens are truncated and the
# rest are padded, so every row in a split has the same length and can be
# stacked into one tensor below.
train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, max_length=128)
val_encodings = tokenizer(list(val_texts), truncation=True, padding=True, max_length=128)

# Map the string class labels to the integer ids expected by the classifier.
# (Labels are class ids, not text, so they must NOT be passed through the
# tokenizer.) A label outside this mapping raises KeyError.
label_to_index = {'negative': 0, 'neutral': 1, 'positive': 2}
train_labels_encoded = [label_to_index[label] for label in train_labels]
val_labels_encoded = [label_to_index[label] for label in val_labels]

# Each dataset item is an (input_ids, attention_mask, label) triple.
train_dataset = TensorDataset(
    torch.tensor(train_encodings['input_ids']),
    torch.tensor(train_encodings['attention_mask']),
    torch.tensor(train_labels_encoded)
)

val_dataset = TensorDataset(
    torch.tensor(val_encodings['input_ids']),
    torch.tensor(val_encodings['attention_mask']),
    torch.tensor(val_labels_encoded)
)

# Load pretrained DistilBERT with a 3-class sequence-classification head
# (one output per entry in the label mapping).
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
# Move the model to the target device once, BEFORE building the optimizer, so
# the optimizer is created over the parameters it will actually update.
model = model.to(DEVICE)

# Training parameters
# NOTE(review): the AdamW exported by `transformers` has been flagged as
# deprecated upstream; consider torch.optim.AdamW if the import starts failing
# (its defaults differ, so confirm hyper-parameters when switching).
optimizer = AdamW(model.parameters(), lr=5e-5)
epochs = 6
batch_size = 32

# DataLoaders: only the training data is shuffled; validation order is fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# Per-epoch metric history, appended to by the training loop.
train_loss_list = []
val_loss_list = []
val_accuracy_list = []
train_accuracy_list = []
# Fine-tuning loop
def _run_epoch(model, dataloader, device, optimizer=None):
    """Run one full pass over `dataloader` and return (mean batch loss, accuracy).

    When `optimizer` is given the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    evaluation mode and gradients are disabled.

    Accuracy is computed as correct predictions / total samples over the whole
    pass, so a smaller final batch does not get extra weight.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(is_training):
        for input_ids, attention_mask, labels in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()
            # Passing `labels` makes the model return the loss alongside the logits.
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            if is_training:
                outputs.loss.backward()
                optimizer.step()

            loss_sum += outputs.loss.item()
            predicted_labels = torch.argmax(outputs.logits, dim=1)
            n_correct += (predicted_labels == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / len(dataloader), n_correct / n_seen


# Move the model once up front instead of on every batch.
model = model.to(DEVICE)

for epoch in range(epochs):
    # Training
    avg_loss, avg_train_accuracy = _run_epoch(model, train_dataloader, DEVICE, optimizer)
    train_loss_list.append(avg_loss)
    train_accuracy_list.append(avg_train_accuracy)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}, Train Accuracy: {100 * avg_train_accuracy:.2f}%')

    # Validation
    avg_val_loss, avg_val_accuracy = _run_epoch(model, val_dataloader, DEVICE)
    val_loss_list.append(avg_val_loss)
    val_accuracy_list.append(avg_val_accuracy)
    print(f'Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss}, Validation Accuracy: {100*avg_val_accuracy:.2f}%')

# Evaluation (optional)
# After training, you can evaluate the model on a separate test set if available
# Note: the weights saved here are those from the final epoch.
model.save_pretrained("../flair_model")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/6, Loss: 0.5638235094117336, Train Accuracy: 0.7566598360655737
Epoch 1/6, Validation Loss: 0.3572952949231671, Validation Accuracy: 0.8576612903225806
Epoch 2/6, Loss: 0.2673838506834429, Train Accuracy: 0.8985655737704918
Epoch 2/6, Validation Loss: 0.36078875103304464, Validation Accuracy: 0.8568548387096774
Epoch 3/6, Loss: 0.1458199705745353, Train Accuracy: 0.9492827868852459
Epoch 3/6, Validation Loss: 0.4350225581757484, Validation Accuracy: 0.8649193548387096
Epoch 4/6, Loss: 0.08256977375932649, Train Accuracy: 0.9723360655737705
Epoch 4/6, Validation Loss: 0.54813377096528, Validation Accuracy: 0.8528225806451613
Epoch 5/6, Loss: 0.04871327719330543, Train Accuracy: 0.9864241803278688
Epoch 5/6, Validation Loss: 0.644670850326938, Validation Accuracy: 0.8352822580645162
Epoch 6/6, Loss: 0.03029251960297039, Train Accuracy: 0.9910348360655737
Epoch 6/6, Validation Loss: 0.6924170566001727, Validation Accuracy: 0.8596774193548387
