In [None]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Directory holding the saved classifier weights/config and its tokenizer files
# (presumably written by `save_pretrained` during training — confirm).
model_path = "Models/ClassificationModel"

# Fine-tuned DistilBERT sequence classifier restored from `model_path`.
model = DistilBertForSequenceClassification.from_pretrained(model_path)
# Tokenizer read from the very same directory as the model above.
tokenizer = DistilBertTokenizer.from_pretrained(model_path)

print("Model and tokenizer loaded successfully!")

  from .autonotebook import tqdm as notebook_tqdm


Model and tokenizer loaded successfully!


In [3]:
import torch
# Label names indexed by the integer class id the classifier emits.
CATEGORIES = ("Name", "Phone Number", "Amount", "Account Number")


def predict(text, categories=None, max_length=64, *, classifier=None, text_tokenizer=None):
    """Classify one piece of text into a category name.

    Args:
        text: Sentence to classify.
        categories: Label names indexed by class id; defaults to ``CATEGORIES``.
        max_length: Token length the input is padded/truncated to (default 64).
        classifier: Sequence-classification model to run; defaults to the
            notebook-level ``model``, looked up at call time.
        text_tokenizer: Tokenizer paired with ``classifier``; defaults to the
            notebook-level ``tokenizer``, looked up at call time.

    Returns:
        The category name whose logit is highest.

    Raises:
        ValueError: If the classifier's number of output logits differs from
            ``len(categories)`` (e.g. an untrained 2-label base checkpoint),
            instead of crashing with IndexError or silently mislabelling.
    """
    labels = CATEGORIES if categories is None else categories
    # Resolve globals lazily so injected objects never touch the notebook state.
    classifier = model if classifier is None else classifier
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer

    encoded = text_tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # Put the inputs on the same device as the weights so a GPU-moved model works.
    first_param = next(classifier.parameters(), None)
    if first_param is not None:
        encoded = {name: tensor.to(first_param.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = classifier(**encoded).logits

    if logits.shape[-1] != len(labels):
        raise ValueError(
            f"classifier produces {logits.shape[-1]} logits but "
            f"{len(labels)} categories were given"
        )
    return labels[torch.argmax(logits, dim=1).item()]

# Example Prediction: run the currently loaded classifier on a sample instruction.
example_text = "Transfer 0 rupees to nitin"
predicted_category = predict(example_text)
print("Predicted Category:", predicted_category)

Predicted Category: Amount


In [5]:
# Show the module hierarchy of the classifier currently bound to `model`.
print(repr(model))

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.3, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.3, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.3, inplace=False)


In [6]:
%pip install torchinfo

Collecting torchinfoNote: you may need to restart the kernel to use updated packages.

  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0



[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
from torchinfo import summary
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

# Inspect the stock DistilBERT architecture with torchinfo.
# Dedicated `base_*` names are used on purpose: re-binding `model` / `tokenizer` here
# would replace the fine-tuned classifier loaded at the top of the notebook, and
# `predict` would then silently run this checkpoint's freshly initialised 2-label head.
BASE_CHECKPOINT = "distilbert-base-uncased"
base_model = DistilBertForSequenceClassification.from_pretrained(BASE_CHECKPOINT)
base_tokenizer = DistilBertTokenizerFast.from_pretrained(BASE_CHECKPOINT)

# Short dummy sentence: only the shapes matter for the summary.
dummy_inputs = base_tokenizer("This is a test sentence", return_tensors="pt")

# Layer-by-layer summary, three levels deep (last expression -> rich display).
summary(base_model, input_data=(dummy_inputs["input_ids"],), depth=3)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Layer (type:depth-idx)                                  Output Shape              Param #
DistilBertForSequenceClassification                     [1, 2]                    --
├─DistilBertModel: 1-1                                  [1, 7, 768]               --
│    └─Embeddings: 2-1                                  [1, 7, 768]               --
│    │    └─Embedding: 3-1                              [1, 7, 768]               23,440,896
│    │    └─Embedding: 3-2                              [1, 7, 768]               393,216
│    │    └─LayerNorm: 3-3                              [1, 7, 768]               1,536
│    │    └─Dropout: 3-4                                [1, 7, 768]               --
│    └─Transformer: 2-2                                 [1, 7, 768]               --
│    │    └─ModuleList: 3-5                             --                        42,527,232
├─Linear: 1-2                                           [1, 768]                  590,592
├─Dropout: 1-3                 