In [2]:
import torch
import numpy as np

from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm import tqdm
import argparse
import os
import pandas as pd

2025-02-01 15:29:19.534910: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Tab-separated news corpus stored in S3. Column labels are passed via
# `names=`, so the first line of the file is read as data, not as a header.
s3_path = 's3://hugging-face-multiclass-textclassification-bucket369/training_data/newsCorpora.csv'
df = pd.read_csv(
    s3_path,
    sep='\t',
    names=["ID", "TITLE", "URL", "PUBLISHER", "CATEGORY", "STORY", "HOSTNAME", "TIMESTAMP"],
)

# Keep only the headline text and its category label.
df = df[['TITLE', 'CATEGORY']]

# One-letter category code (as stored in the CATEGORY column) -> readable name.
my_dict = {
    'e': 'Entertainment',
    'b': 'Business',
    't': 'Science',
    'm': 'Health',
}


def update_cat(x):
    """Translate a one-letter category code into its readable name.

    Args:
        x: A key of ``my_dict`` ('e', 'b', 't' or 'm').

    Returns:
        The corresponding value from ``my_dict``.

    Raises:
        KeyError: If ``x`` is not a key of ``my_dict``.
    """
    return my_dict[x]

# Replace every one-letter code in CATEGORY with its readable name.
# update_cat is passed directly; no lambda wrapper is needed.
df['CATEGORY'] = df['CATEGORY'].apply(update_cat)

# Plain print (rather than rich display): per the original author's note,
# this output is expected to show up in the CloudWatch logs.
print(df)

                                                    TITLE  CATEGORY
0       Fed official says weak data caused by weather,...  Business
1       Fed's Charles Plosser sees high bar for change...  Business
2       US open: Stocks fall after Fed official hints ...  Business
3       Fed risks falling 'behind the curve', Charles ...  Business
4       Fed's Plosser: Nasty Weather Has Curbed Job Gr...  Business
...                                                   ...       ...
422414  Surgeons to remove 4-year-old's rib to rebuild...    Health
422415  Boy to have surgery on esophagus after battery...    Health
422416  Child who swallowed battery to have reconstruc...    Health
422417  Phoenix boy undergoes surgery to repair throat...    Health
422418  Phoenix boy undergoes surgery to repair throat...    Health

[422419 rows x 2 columns]


In [6]:
# Tip: continue with a 5% random sample of the rows instead of the full frame.
# The fixed random_state makes the draw repeatable, and the index is rebuilt
# as 0..n-1 with the old index discarded.
# Note: `df` is overwritten, so re-running this cell samples the sample again.
df = df.sample(frac=0.05, random_state=369).reset_index(drop=True)
# End of tip.
df

Unnamed: 0,TITLE,CATEGORY
0,Facebook treats you like a lab rat,Science
1,News in Breif,Business
2,"Regulators shine light on shark-infested ""dark...",Business
3,Blood Moon 2014 Dates; Plus Why Does The Moon ...,Science
4,Why “OK Google” Should be Available on Google ...,Science
...,...,...
21116,Medicare reveals physician payment data,Health
21117,Capital One profit rises 10 pct due to lower p...,Business
21118,Mortgage Interest Rates Move Up at Wells Fargo...,Business
21119,Texas cheerleader and critics face off over co...,Entertainment


In [9]:
# Category value -> integer id. Starts empty and is filled in by encode_cat.
encode_dict = {}


def encode_cat(x):
    """Return the integer id of ``x``, registering it on first sight.

    A value that is not yet in ``encode_dict`` is assigned the current size
    of the dict, so ids are 0, 1, 2, ... in order of first appearance.
    Later calls with the same value return the stored id.
    """
    # len(encode_dict) is evaluated before the insertion, so a new value
    # receives the next unused id; an existing value keeps its stored id.
    return encode_dict.setdefault(x, len(encode_dict))


In [10]:
# Add an integer label column derived from CATEGORY. encode_cat hands out ids
# in order of first appearance, so the mapping depends on the row order.
df['ENCODE_CAT'] = df['CATEGORY'].apply(encode_cat)
df

Unnamed: 0,TITLE,CATEGORY,ENCODE_CAT
0,Facebook treats you like a lab rat,Science,0
1,News in Breif,Business,1
2,"Regulators shine light on shark-infested ""dark...",Business,1
3,Blood Moon 2014 Dates; Plus Why Does The Moon ...,Science,0
4,Why “OK Google” Should be Available on Google ...,Science,0
...,...,...,...
21116,Medicare reveals physician payment data,Health,3
21117,Capital One profit rises 10 pct due to lower p...,Business,1
21118,Mortgage Interest Rates Move Up at Wells Fargo...,Business,1
21119,Texas cheerleader and critics face off over co...,Entertainment,2


In [11]:
# Display the category -> integer id mapping that encode_cat has built so far.
encode_dict

{'Science': 0, 'Business': 1, 'Entertainment': 2, 'Health': 3}

In [12]:
# DistilBertTokenizer is already imported in the notebook's first cell; the
# import is repeated here so this demo cell can be run on its own.
from transformers import DistilBertTokenizer

# Load the pretrained vocabulary/tokenizer for distilbert-base-uncased.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Encode a sentence pair as a quick demonstration of what the tokenizer returns.
# The tokenizer is called directly: `encode_plus` is deprecated in favour of
# `__call__`, which accepts the same arguments for a single text / text pair.
inputs = tokenizer(
    "I love soccer and mixed martial arts",
    "I love roses! Let's have some fun and watch a few movies",
    add_special_tokens=True,       # include the model's special tokens
    max_length=20,                 # target sequence length for this demo
    padding='max_length',          # pad shorter inputs up to max_length
    truncation=True,               # cut longer inputs down to max_length
    return_token_type_ids=True,    # segment ids: 0 = first text, 1 = second text
    return_attention_mask=True
)

# Labels fixed: "Inputs IDs" -> "Input IDs", missing colon added after
# "Attention Mask", and "tyoe" -> "type".
print("Input IDs:", inputs['input_ids'])
print("Attention Mask:", inputs['attention_mask'])
print("Token type IDs:", inputs['token_type_ids'])


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


Inputs IDs: [101, 1045, 2293, 4715, 1998, 3816, 7761, 2840, 102, 1045, 2293, 10529, 999, 2292, 1005, 1055, 2031, 2070, 4569, 102]
Attention Mask [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Token tyoe IDs: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
