In [1]:
import os
import torch
import joblib
import pandas as pd
from transformers import DistilBertTokenizerFast
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizerFast, get_linear_schedule_with_warmup
from datasets import Dataset
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# ─── 1) Define your model class (identical to training) ──────────────────
class ChatModerationModel(nn.Module):
    def __init__(self, num_labels, num_statuses):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        
        # Freeze first 3 transformer layers
        for layer in self.bert.transformer.layer[:3]:
            for param in layer.parameters():
                param.requires_grad = False
                
        self.dropout = nn.Dropout(0.6)  # Increased dropout
        self.label_head = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.4),  # Additional dropout
            nn.Linear(256, num_labels)
        )
        self.status_head = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.4),  # Additional dropout
            nn.Linear(256, num_statuses)
        )

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.last_hidden_state[:, 0])
        return self.label_head(pooled), self.status_head(pooled)


In [3]:

# ─── 2) Load tokenizer and encoders ───────────────────────────────────────
model_dir = r"U:\N\save"

tokenizer  = DistilBertTokenizerFast.from_pretrained(model_dir)
label_enc  = joblib.load(os.path.join(model_dir, "label_encoder.pkl"))
status_enc = joblib.load(os.path.join(model_dir, "status_encoder.pkl"))

# Reconstruct the status_map exactly as in training
status_map = {0: 'accepted', 1: 'accepted', 2: 'accepted',
              3: 'pending',  4: 'pending',
              5: 'blocked',  6: 'blocked', 7: 'blocked'}


In [4]:

# ─── 3) Instantiate model and load weights ───────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ChatModerationModel(
    num_labels   = len(label_enc.classes_),
    num_statuses = len(status_enc.classes_)
).to(device)

model.load_state_dict(torch.load(os.path.join(model_dir, "model_state.pt"), map_location=device))
model.eval()


ChatModerationModel(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Li

In [5]:

# ─── 4) Prediction helper ────────────────────────────────────────────────
def predict(text: str):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True, padding="max_length", max_length=128
    ).to(device)

    with torch.no_grad():
        logits_label, logits_status = model(**inputs)

    lbl_idx = logits_label.argmax(dim=1).item()
    stt_idx = logits_status.argmax(dim=1).item()

    print(f"Text: {text!r}")
    print(f"→ Label:    {label_enc.inverse_transform([lbl_idx])[0]}")
    print(f"→ Status:   {status_enc.inverse_transform([stt_idx])[0]}")
    print(f"→ Decision: {status_map[stt_idx]}")

# ─── 5) Try it out! ───────────────────────────────────────────────────────



In [None]:
predict("")



Text: 'OKAY'
→ Label:    Nothing Wrong
→ Status:   0
→ Decision: accepted
