In [1]:
!pip install transformers datasets torch torchvision



In [2]:

from datasets import load_dataset

# SemEval-2018 Task 1, English "subtask5" config: tweets paired with one boolean flag per emotion.
dataset = load_dataset("sem_eval_2018_task_1", "subtask5.english")

# Show a single raw record (the last validation example) to reveal the column schema.
validation_split = dataset["validation"]
print(validation_split[-1])


{'ID': '2018-En-03386', 'Tweet': 'I am really flattered and happy to hear those complements for my blog! You guys motivates me to write more for my blog. Thank you! sml 💞', 'anger': False, 'anticipation': False, 'disgust': False, 'fear': False, 'joy': True, 'love': False, 'optimism': True, 'pessimism': False, 'sadness': False, 'surprise': False, 'trust': False}


In [7]:
from transformers import BertTokenizer
import torch

def tokenize_data(
    dataset,
    model_name="bert-base-uncased",
    max_length=128,
    tokenizer=None,
    text_column="Tweet",
):
    """Tokenize every split of a multi-label tweet dataset and attach float label vectors.

    Each example's text is tokenized with fixed-length padding/truncation. Every
    column other than ``"ID"`` and the text column is collected, in feature
    order, into a float32 ``labels`` tensor. The mapped dataset is then
    formatted for PyTorch.

    Args:
        dataset: A ``DatasetDict``-like mapping of split name -> split (as
            returned by ``datasets.load_dataset``). Each split exposes
            ``.features``, and the mapping supports ``.map`` and ``.set_format``.
        model_name: Checkpoint passed to ``BertTokenizer.from_pretrained`` when
            ``tokenizer`` is not supplied.
        max_length: Token length every example is padded/truncated to.
        tokenizer: Optional pre-built tokenizer, which avoids reloading one on
            each call.
        text_column: Name of the column holding the raw text.

    Returns:
        The mapped dataset, formatted as torch tensors and restricted to the
        model-input columns the tokenizer produced plus ``labels``. An empty
        mapping is returned unchanged.
    """
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained(model_name)

    if not dataset:
        # Nothing to tokenize.
        return dataset

    # Label columns are shared by all splits. Prefer "train" (the original
    # behavior) but fall back to any split, so dicts without "train" also work.
    # Computed once here instead of once per example.
    reference_split = dataset["train"] if "train" in dataset else next(iter(dataset.values()))
    non_label_columns = {"ID", text_column}
    label_names = [
        name for name in reference_split.features.keys() if name not in non_label_columns
    ]

    def preprocess(example):
        encoding = tokenizer(
            example[text_column],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        # Multi-hot target vector. Boolean flags become 0.0/1.0 in float32.
        encoding["labels"] = torch.tensor(
            [example[name] for name in label_names], dtype=torch.float
        )
        return encoding

    encoded_dataset = dataset.map(preprocess, batched=False)

    # token_type_ids is emitted by BERT-style tokenizers but not by all of them,
    # so only request the model-input columns that actually exist.
    first_split = next(iter(encoded_dataset.values()))
    model_columns = [
        column
        for column in ("input_ids", "token_type_ids", "attention_mask")
        if column in first_split.features
    ]
    encoded_dataset.set_format("torch", columns=model_columns + ["labels"])

    return encoded_dataset


# Tokenize the dataset
# Encode all splits using the helper's default settings.
encoded_dataset = tokenize_data(dataset)

# Quick schema check: list which fields the final validation example exposes after formatting.
last_validation_encoded = encoded_dataset["validation"][-1]
print(last_validation_encoded.keys())


Map:   0%|          | 0/6838 [00:00<?, ? examples/s]

Map:   0%|          | 0/3259 [00:00<?, ? examples/s]

Map:   0%|          | 0/886 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
