### Fine-tune BERT for classification

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset

# Load CSV files
train_df = pd.read_csv('data/collections/train.txt', sep='\t')
test_df = pd.read_csv('data/collections/test.txt', sep='\t')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
label_map = {'positive': 0, 'neutral': 1, 'negative': 2}
train_df['label'] = train_df['sentiment'].map(label_map)
test_df['label'] = test_df['sentiment'].map(label_map)

In [3]:
train_df.head()

Unnamed: 0,id,sentiment,tweet,label
0,638165350966669312,negative,Nicki's butt is just too big like c'mon that's...,2
1,640169120600862720,neutral,Haruna Lukmon may av just played himself out o...,1
2,635946254254624769,neutral,Zach Putnam will be unavailable for the White ...,1
3,667121920333258752,negative,"""""""@daithimckay what about the victims of IRA ...",2
4,628637519643586560,neutral,"""""""LHP Matt Boyd, traded to @tigers in David P...",1


In [4]:
test_dataset = Dataset.from_pandas(test_df[['tweet', 'label']])

train_dataset, val_dataset = train_test_split(train_df, test_size=0.1, random_state=42)
train_dataset = Dataset.from_pandas(train_dataset[['tweet', 'label']])
val_dataset = Dataset.from_pandas(val_dataset[['tweet', 'label']])

In [5]:
from transformers import DistilBertTokenizer

# Load the pre-trained DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples['tweet'], padding='max_length', truncation=True, max_length=128)

# Tokenize the datasets
train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

Map: 100%|██████████| 16781/16781 [00:05<00:00, 3290.33 examples/s]
Map: 100%|██████████| 1865/1865 [00:00<00:00, 3189.33 examples/s]
Map: 100%|██████████| 4662/4662 [00:01<00:00, 3235.61 examples/s]


In [6]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

# Load the pre-trained DistilBERT model for sequence classification
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
import torch

# Check if GPU is available and move model to GPU if so
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [8]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,  # Increase batch size
    per_device_eval_batch_size=64,   # Larger evaluation batch size
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    fp16=True,
    gradient_accumulation_steps=2,  # Accumulate gradients over 2 steps if batch size is too large
)

# Define the Trainer
trainer = Trainer(
    model=model,                         # the model to be trained
    args=training_args,                  # training arguments
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
    tokenizer=tokenizer,                 # tokenizer to handle the tokenization
)

  trainer = Trainer(


In [9]:
trainer.train()

  1%|▏         | 10/786 [01:23<1:47:59,  8.35s/it]

{'loss': 1.0591, 'grad_norm': 2.382056951522827, 'learning_rate': 1.974554707379135e-05, 'epoch': 0.04}


  3%|▎         | 20/786 [02:50<1:50:01,  8.62s/it]

{'loss': 1.0353, 'grad_norm': 1.9395687580108643, 'learning_rate': 1.9491094147582698e-05, 'epoch': 0.08}


  4%|▍         | 30/786 [04:15<1:45:50,  8.40s/it]

{'loss': 0.9672, 'grad_norm': 3.399838447570801, 'learning_rate': 1.923664122137405e-05, 'epoch': 0.11}


  5%|▌         | 40/786 [05:39<1:44:39,  8.42s/it]

{'loss': 0.8599, 'grad_norm': 4.703188419342041, 'learning_rate': 1.8982188295165395e-05, 'epoch': 0.15}


  6%|▋         | 50/786 [07:04<1:44:19,  8.51s/it]

{'loss': 0.855, 'grad_norm': 8.077014923095703, 'learning_rate': 1.8727735368956746e-05, 'epoch': 0.19}


  8%|▊         | 60/786 [08:28<1:40:56,  8.34s/it]

{'loss': 0.8457, 'grad_norm': 5.116960525512695, 'learning_rate': 1.847328244274809e-05, 'epoch': 0.23}


  9%|▉         | 70/786 [09:52<1:40:05,  8.39s/it]

{'loss': 0.7874, 'grad_norm': 10.23357105255127, 'learning_rate': 1.8218829516539443e-05, 'epoch': 0.27}


 10%|█         | 80/786 [11:16<1:38:17,  8.35s/it]

{'loss': 0.734, 'grad_norm': 7.499925136566162, 'learning_rate': 1.796437659033079e-05, 'epoch': 0.3}


 11%|█▏        | 90/786 [12:39<1:36:40,  8.33s/it]

{'loss': 0.7396, 'grad_norm': 8.147540092468262, 'learning_rate': 1.770992366412214e-05, 'epoch': 0.34}


 13%|█▎        | 100/786 [14:02<1:34:23,  8.26s/it]

{'loss': 0.8002, 'grad_norm': 7.381656646728516, 'learning_rate': 1.7455470737913488e-05, 'epoch': 0.38}


 14%|█▍        | 110/786 [15:25<1:33:34,  8.31s/it]

{'loss': 0.6927, 'grad_norm': 6.926806449890137, 'learning_rate': 1.7201017811704836e-05, 'epoch': 0.42}


 15%|█▌        | 120/786 [16:46<1:30:29,  8.15s/it]

{'loss': 0.7194, 'grad_norm': 11.106562614440918, 'learning_rate': 1.6946564885496184e-05, 'epoch': 0.46}


 17%|█▋        | 130/786 [18:08<1:30:13,  8.25s/it]

{'loss': 0.7143, 'grad_norm': 6.6118879318237305, 'learning_rate': 1.6692111959287533e-05, 'epoch': 0.5}


 18%|█▊        | 140/786 [19:30<1:27:44,  8.15s/it]

{'loss': 0.7298, 'grad_norm': 8.503707885742188, 'learning_rate': 1.643765903307888e-05, 'epoch': 0.53}


 19%|█▉        | 150/786 [20:51<1:26:06,  8.12s/it]

{'loss': 0.6555, 'grad_norm': 5.410696029663086, 'learning_rate': 1.618320610687023e-05, 'epoch': 0.57}


 20%|██        | 160/786 [22:13<1:25:03,  8.15s/it]

{'loss': 0.6984, 'grad_norm': 9.54676628112793, 'learning_rate': 1.5928753180661577e-05, 'epoch': 0.61}


 22%|██▏       | 170/786 [23:34<1:23:43,  8.15s/it]

{'loss': 0.652, 'grad_norm': 8.367912292480469, 'learning_rate': 1.567430025445293e-05, 'epoch': 0.65}


 23%|██▎       | 180/786 [24:56<1:22:51,  8.20s/it]

{'loss': 0.6721, 'grad_norm': 5.939520835876465, 'learning_rate': 1.5419847328244274e-05, 'epoch': 0.69}


 24%|██▍       | 190/786 [26:18<1:21:27,  8.20s/it]

{'loss': 0.6866, 'grad_norm': 8.819945335388184, 'learning_rate': 1.5165394402035624e-05, 'epoch': 0.72}


 25%|██▌       | 200/786 [27:39<1:19:18,  8.12s/it]

{'loss': 0.6571, 'grad_norm': 13.823393821716309, 'learning_rate': 1.4910941475826972e-05, 'epoch': 0.76}


 27%|██▋       | 210/786 [29:00<1:17:37,  8.09s/it]

{'loss': 0.6893, 'grad_norm': 5.979026794433594, 'learning_rate': 1.4656488549618322e-05, 'epoch': 0.8}


 28%|██▊       | 220/786 [30:22<1:17:01,  8.16s/it]

{'loss': 0.6627, 'grad_norm': 6.348121166229248, 'learning_rate': 1.4402035623409672e-05, 'epoch': 0.84}


 29%|██▉       | 230/786 [31:43<1:15:07,  8.11s/it]

{'loss': 0.7104, 'grad_norm': 8.661341667175293, 'learning_rate': 1.4147582697201019e-05, 'epoch': 0.88}


 31%|███       | 240/786 [33:03<1:12:16,  7.94s/it]

{'loss': 0.6576, 'grad_norm': 6.143622875213623, 'learning_rate': 1.3893129770992369e-05, 'epoch': 0.91}


 32%|███▏      | 250/786 [34:31<1:16:55,  8.61s/it]

{'loss': 0.6857, 'grad_norm': 5.898812770843506, 'learning_rate': 1.3638676844783715e-05, 'epoch': 0.95}


 33%|███▎      | 260/786 [35:58<1:15:59,  8.67s/it]

{'loss': 0.6435, 'grad_norm': 8.719658851623535, 'learning_rate': 1.3384223918575065e-05, 'epoch': 0.99}


                                                   
 33%|███▎      | 262/786 [37:31<1:15:44,  8.67s/it]

{'eval_loss': 0.6444394588470459, 'eval_runtime': 74.2664, 'eval_samples_per_second': 25.112, 'eval_steps_per_second': 0.404, 'epoch': 1.0}


 34%|███▍      | 270/786 [38:35<1:26:40, 10.08s/it]

{'loss': 0.6149, 'grad_norm': 15.69955062866211, 'learning_rate': 1.3129770992366414e-05, 'epoch': 1.03}


 36%|███▌      | 280/786 [39:58<1:09:23,  8.23s/it]

{'loss': 0.5938, 'grad_norm': 6.204497337341309, 'learning_rate': 1.2875318066157762e-05, 'epoch': 1.07}


 37%|███▋      | 290/786 [41:19<1:07:06,  8.12s/it]

{'loss': 0.6066, 'grad_norm': 9.734415054321289, 'learning_rate': 1.262086513994911e-05, 'epoch': 1.1}


 38%|███▊      | 300/786 [42:40<1:05:41,  8.11s/it]

{'loss': 0.6098, 'grad_norm': 7.9809041023254395, 'learning_rate': 1.236641221374046e-05, 'epoch': 1.14}


 39%|███▉      | 310/786 [44:00<1:03:58,  8.06s/it]

{'loss': 0.6029, 'grad_norm': 10.132811546325684, 'learning_rate': 1.2111959287531807e-05, 'epoch': 1.18}


 41%|████      | 320/786 [45:22<1:03:04,  8.12s/it]

{'loss': 0.6153, 'grad_norm': 7.208290100097656, 'learning_rate': 1.1857506361323157e-05, 'epoch': 1.22}


 42%|████▏     | 330/786 [46:43<1:01:13,  8.05s/it]

{'loss': 0.5653, 'grad_norm': 6.845522403717041, 'learning_rate': 1.1603053435114503e-05, 'epoch': 1.26}


 43%|████▎     | 340/786 [48:04<59:54,  8.06s/it]  

{'loss': 0.5676, 'grad_norm': 7.617427825927734, 'learning_rate': 1.1348600508905853e-05, 'epoch': 1.3}


 45%|████▍     | 350/786 [49:25<58:53,  8.11s/it]  

{'loss': 0.5858, 'grad_norm': 8.889451026916504, 'learning_rate': 1.1094147582697202e-05, 'epoch': 1.33}


 46%|████▌     | 360/786 [50:50<1:00:45,  8.56s/it]

{'loss': 0.5698, 'grad_norm': 8.123685836791992, 'learning_rate': 1.0839694656488552e-05, 'epoch': 1.37}


 47%|████▋     | 370/786 [52:12<57:38,  8.31s/it]  

{'loss': 0.6562, 'grad_norm': 8.75717544555664, 'learning_rate': 1.0585241730279898e-05, 'epoch': 1.41}


 48%|████▊     | 380/786 [53:36<56:30,  8.35s/it]

{'loss': 0.6168, 'grad_norm': 9.699156761169434, 'learning_rate': 1.0330788804071248e-05, 'epoch': 1.45}


 50%|████▉     | 390/786 [55:00<55:32,  8.42s/it]

{'loss': 0.5791, 'grad_norm': 7.472648620605469, 'learning_rate': 1.0076335877862595e-05, 'epoch': 1.49}


 51%|█████     | 400/786 [56:21<52:00,  8.08s/it]

{'loss': 0.5953, 'grad_norm': 8.680340766906738, 'learning_rate': 9.821882951653945e-06, 'epoch': 1.52}


 52%|█████▏    | 410/786 [57:42<50:04,  7.99s/it]

{'loss': 0.5398, 'grad_norm': 6.5267462730407715, 'learning_rate': 9.567430025445293e-06, 'epoch': 1.56}


 53%|█████▎    | 420/786 [59:03<51:04,  8.37s/it]

{'loss': 0.6124, 'grad_norm': 10.26608943939209, 'learning_rate': 9.312977099236641e-06, 'epoch': 1.6}


 55%|█████▍    | 430/786 [1:00:25<47:30,  8.01s/it]

{'loss': 0.6004, 'grad_norm': 9.702875137329102, 'learning_rate': 9.058524173027991e-06, 'epoch': 1.64}


 56%|█████▌    | 440/786 [1:01:46<46:23,  8.05s/it]

{'loss': 0.5827, 'grad_norm': 8.34211254119873, 'learning_rate': 8.80407124681934e-06, 'epoch': 1.68}


 57%|█████▋    | 450/786 [1:03:05<43:07,  7.70s/it]

{'loss': 0.5489, 'grad_norm': 7.30422306060791, 'learning_rate': 8.549618320610688e-06, 'epoch': 1.71}


 59%|█████▊    | 460/786 [1:04:26<44:30,  8.19s/it]

{'loss': 0.5644, 'grad_norm': 7.2086896896362305, 'learning_rate': 8.295165394402036e-06, 'epoch': 1.75}


 60%|█████▉    | 470/786 [1:05:47<42:10,  8.01s/it]

{'loss': 0.5654, 'grad_norm': 8.922070503234863, 'learning_rate': 8.040712468193384e-06, 'epoch': 1.79}


 61%|██████    | 480/786 [1:07:12<43:10,  8.47s/it]

{'loss': 0.5735, 'grad_norm': 8.3564453125, 'learning_rate': 7.786259541984733e-06, 'epoch': 1.83}


 62%|██████▏   | 490/786 [1:08:39<42:41,  8.65s/it]

{'loss': 0.5689, 'grad_norm': 9.377689361572266, 'learning_rate': 7.531806615776082e-06, 'epoch': 1.87}


 64%|██████▎   | 500/786 [1:10:04<40:10,  8.43s/it]

{'loss': 0.5389, 'grad_norm': 15.032976150512695, 'learning_rate': 7.27735368956743e-06, 'epoch': 1.9}


 65%|██████▍   | 510/786 [1:11:31<39:00,  8.48s/it]

{'loss': 0.5746, 'grad_norm': 13.062674522399902, 'learning_rate': 7.022900763358779e-06, 'epoch': 1.94}


 66%|██████▌   | 520/786 [1:12:56<37:44,  8.51s/it]

{'loss': 0.5999, 'grad_norm': 17.32914161682129, 'learning_rate': 6.768447837150128e-06, 'epoch': 1.98}


                                                   
 67%|██████▋   | 525/786 [1:14:45<33:18,  7.66s/it]

{'eval_loss': 0.6486327648162842, 'eval_runtime': 69.1423, 'eval_samples_per_second': 26.973, 'eval_steps_per_second': 0.434, 'epoch': 2.0}


 67%|██████▋   | 530/786 [1:15:26<55:28, 13.00s/it]  

{'loss': 0.5442, 'grad_norm': 7.668583393096924, 'learning_rate': 6.5139949109414765e-06, 'epoch': 2.02}


 69%|██████▊   | 540/786 [1:16:46<33:32,  8.18s/it]

{'loss': 0.5338, 'grad_norm': 7.029555797576904, 'learning_rate': 6.259541984732826e-06, 'epoch': 2.06}


 70%|██████▉   | 550/786 [1:18:07<31:36,  8.04s/it]

{'loss': 0.4933, 'grad_norm': 6.667935848236084, 'learning_rate': 6.005089058524174e-06, 'epoch': 2.1}


 71%|███████   | 560/786 [1:19:27<30:27,  8.09s/it]

{'loss': 0.5302, 'grad_norm': 10.470666885375977, 'learning_rate': 5.750636132315522e-06, 'epoch': 2.13}


 73%|███████▎  | 570/786 [1:20:48<28:59,  8.05s/it]

{'loss': 0.4919, 'grad_norm': 9.969108581542969, 'learning_rate': 5.496183206106871e-06, 'epoch': 2.17}


 74%|███████▍  | 580/786 [1:22:08<27:32,  8.02s/it]

{'loss': 0.5294, 'grad_norm': 10.026636123657227, 'learning_rate': 5.2417302798982195e-06, 'epoch': 2.21}


 75%|███████▌  | 590/786 [1:23:28<26:08,  8.00s/it]

{'loss': 0.5247, 'grad_norm': 7.41146993637085, 'learning_rate': 4.987277353689568e-06, 'epoch': 2.25}


 76%|███████▋  | 600/786 [1:24:49<24:55,  8.04s/it]

{'loss': 0.4992, 'grad_norm': 6.894505500793457, 'learning_rate': 4.732824427480917e-06, 'epoch': 2.29}


 78%|███████▊  | 610/786 [1:26:09<23:34,  8.04s/it]

{'loss': 0.5528, 'grad_norm': 10.676773071289062, 'learning_rate': 4.478371501272265e-06, 'epoch': 2.32}


 79%|███████▉  | 620/786 [1:27:30<22:14,  8.04s/it]

{'loss': 0.4662, 'grad_norm': 12.244253158569336, 'learning_rate': 4.2239185750636135e-06, 'epoch': 2.36}


 80%|████████  | 630/786 [1:28:50<20:50,  8.01s/it]

{'loss': 0.484, 'grad_norm': 9.889323234558105, 'learning_rate': 3.969465648854962e-06, 'epoch': 2.4}


 81%|████████▏ | 640/786 [1:30:11<19:28,  8.01s/it]

{'loss': 0.5485, 'grad_norm': 10.018312454223633, 'learning_rate': 3.7150127226463105e-06, 'epoch': 2.44}


 83%|████████▎ | 650/786 [1:31:31<18:08,  8.00s/it]

{'loss': 0.4884, 'grad_norm': 9.947417259216309, 'learning_rate': 3.460559796437659e-06, 'epoch': 2.48}


 84%|████████▍ | 660/786 [1:32:51<16:46,  7.99s/it]

{'loss': 0.4671, 'grad_norm': 7.274786949157715, 'learning_rate': 3.206106870229008e-06, 'epoch': 2.51}


 85%|████████▌ | 670/786 [1:34:12<15:39,  8.10s/it]

{'loss': 0.4981, 'grad_norm': 7.524981498718262, 'learning_rate': 2.951653944020356e-06, 'epoch': 2.55}


 87%|████████▋ | 680/786 [1:35:33<14:13,  8.05s/it]

{'loss': 0.5069, 'grad_norm': 11.14845085144043, 'learning_rate': 2.6972010178117053e-06, 'epoch': 2.59}


 88%|████████▊ | 690/786 [1:36:53<12:51,  8.04s/it]

{'loss': 0.5323, 'grad_norm': 9.247360229492188, 'learning_rate': 2.4427480916030536e-06, 'epoch': 2.63}


 89%|████████▉ | 700/786 [1:38:14<11:40,  8.14s/it]

{'loss': 0.4854, 'grad_norm': 9.285196304321289, 'learning_rate': 2.1882951653944023e-06, 'epoch': 2.67}


 90%|█████████ | 710/786 [1:39:36<10:19,  8.15s/it]

{'loss': 0.4918, 'grad_norm': 10.760952949523926, 'learning_rate': 1.933842239185751e-06, 'epoch': 2.7}


 92%|█████████▏| 720/786 [1:40:56<08:50,  8.03s/it]

{'loss': 0.5368, 'grad_norm': 10.877464294433594, 'learning_rate': 1.6793893129770995e-06, 'epoch': 2.74}


 93%|█████████▎| 730/786 [1:42:16<07:28,  8.02s/it]

{'loss': 0.4939, 'grad_norm': 8.963440895080566, 'learning_rate': 1.424936386768448e-06, 'epoch': 2.78}


 94%|█████████▍| 740/786 [1:43:37<06:12,  8.10s/it]

{'loss': 0.5401, 'grad_norm': 10.394055366516113, 'learning_rate': 1.1704834605597967e-06, 'epoch': 2.82}


 95%|█████████▌| 750/786 [1:44:58<04:49,  8.04s/it]

{'loss': 0.5225, 'grad_norm': 8.99874210357666, 'learning_rate': 9.160305343511451e-07, 'epoch': 2.86}


 97%|█████████▋| 760/786 [1:46:19<03:30,  8.10s/it]

{'loss': 0.4916, 'grad_norm': 7.463350296020508, 'learning_rate': 6.615776081424936e-07, 'epoch': 2.9}


 98%|█████████▊| 770/786 [1:47:39<02:09,  8.09s/it]

{'loss': 0.5047, 'grad_norm': 8.169717788696289, 'learning_rate': 4.071246819338423e-07, 'epoch': 2.93}


 99%|█████████▉| 780/786 [1:48:59<00:48,  8.01s/it]

{'loss': 0.4997, 'grad_norm': 8.768998146057129, 'learning_rate': 1.5267175572519085e-07, 'epoch': 2.97}


                                                   
100%|██████████| 786/786 [1:50:58<00:00,  8.47s/it]

{'eval_loss': 0.6499510407447815, 'eval_runtime': 69.2581, 'eval_samples_per_second': 26.928, 'eval_steps_per_second': 0.433, 'epoch': 2.99}
{'train_runtime': 6658.3078, 'train_samples_per_second': 7.561, 'train_steps_per_second': 0.118, 'train_loss': 0.6156937730221348, 'epoch': 2.99}





TrainOutput(global_step=786, training_loss=0.6156937730221348, metrics={'train_runtime': 6658.3078, 'train_samples_per_second': 7.561, 'train_steps_per_second': 0.118, 'total_flos': 1664681251908096.0, 'train_loss': 0.6156937730221348, 'epoch': 2.994285714285714})

In [10]:
eval_results = trainer.evaluate()
print(eval_results)

100%|██████████| 30/30 [01:07<00:00,  2.26s/it]

{'eval_loss': 0.6499510407447815, 'eval_runtime': 70.1647, 'eval_samples_per_second': 26.58, 'eval_steps_per_second': 0.428, 'epoch': 2.994285714285714}





In [11]:
# Save the model and tokenizer
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")


('./fine_tuned_model\\tokenizer_config.json',
 './fine_tuned_model\\special_tokens_map.json',
 './fine_tuned_model\\vocab.txt',
 './fine_tuned_model\\added_tokens.json')

### Run test on transformer

In [15]:
# Load the fine-tuned model and tokenizer
model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model")

# Example inference
input_text = "Example positive tweet. Very good comments"
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

# Get predictions
outputs = model(**inputs)
logits = outputs.logits
prediction = logits.argmax(dim=-1).item()
print(f"Predicted class: {prediction}")


Predicted class: 0
