In [1]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch.utils.data import DataLoader
from datasets import Dataset  # Import Dataset here
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load and prepare data.
# Forward slashes work on Windows as well as Linux/macOS; a backslash in a raw
# string is only a path separator on Windows.
TRAINING_DATA_PATH = "input-files/training_data.csv"
N_SAMPLES = 1000  # keep only the first N usable rows for a quick run; None = full file

raw_df = pd.read_csv(TRAINING_DATA_PATH)

# Rows missing the text or the category cannot be used: a NaN text breaks
# tokenization and a NaN category could be encoded as its own class.
df = raw_df.dropna(subset=['input_text', 'Category'])
if N_SAMPLES is not None:
    df = df.head(N_SAMPLES)

texts = df['input_text'].tolist()
labels = df['Category'].tolist()
assert texts, "no usable training rows after dropping missing values"

In [5]:
# Map each string category to an integer class id; `le` retains the mapping so
# ids can later be turned back into category names (le.inverse_transform).
le = LabelEncoder()
encoded_labels = le.fit_transform(labels)

# The SoftmaxLoss used for training takes two texts per example, so every
# input text is paired with itself and tagged with its encoded class id.
train_examples = [
    InputExample(texts=[sentence] * 2, label=class_id)
    for sentence, class_id in zip(texts, encoded_labels)
]


In [6]:
# Hyperparameters for this fine-tuning run (tune here rather than inline).
SEED = 42
BATCH_SIZE = 16
EPOCHS = 1
WARMUP_STEPS = 10

# Seed torch before building anything stochastic: the SoftmaxLoss classifier
# head is randomly initialised, and the DataLoader shuffles using torch's
# default generator.
torch.manual_seed(SEED)

# Transformer encoder + Pooling (mean pooling is the default mode), assembled
# by hand instead of loading a complete SentenceTransformer checkpoint.
# NOTE(review): building the pipeline this way does not carry over any extra
# modules/settings the published checkpoint may define (e.g. its max_seq_length;
# this pipeline reports 512) — confirm that is intended.
word_embedding_model = models.Transformer('all-MiniLM-L6-v2')
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# SoftmaxLoss is a sentence-*pair* classification loss: it concatenates
# (u, v, |u - v|) and feeds that into a linear classifier. Because every training
# example pairs a text with itself, this effectively trains a classifier over a
# single sentence embedding, which fine-tunes the encoder towards the categories.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
train_loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=len(set(encoded_labels)),
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=EPOCHS,
    warmup_steps=WARMUP_STEPS,
    show_progress_bar=True,
)



Step,Training Loss


In [7]:
# Persist the embedding model only (its Transformer + Pooling modules). The
# SoftmaxLoss classifier head belongs to `train_loss`, not `model`, so it is not
# part of this save.
model.save(
    r"finetuned-model/finetuned_embed_model"
)

In [29]:
# Display the loss module: it wraps `model` together with the linear classifier
# head and the cross-entropy loss that were trained alongside it.
train_loss

SoftmaxLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertModel'})
    (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  )
  (classifier): Linear(in_features=1152, out_features=3, bias=True)
  (loss_fct): CrossEntropyLoss()
)