In [None]:
# imports
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

In [2]:
# Load the labelled data: each row is one sentence tagged with a 'category'.
df = pd.read_csv('argumentdetection7.csv')

# Summarise the dataset size and how the rows are spread across categories.
n_total = len(df)
print("Total sentences: {}".format(n_total))
print(df['category'].value_counts())


Total sentences: 860866
evidence    833983
non_info     25221
claim         1662
Name: category, dtype: int64


In [8]:
# Build a class-balanced sample by undersampling the larger categories.
# NOTE: this is *not* a proportional stratified sample. Each category is capped
# at sample_size // <number of categories> rows, so a small category (e.g.
# 'claim') is kept in full while the large ones are heavily undersampled.
# Any metrics computed on this sample therefore reflect a balanced class mix,
# not the class distribution of the full dataset.
sample_size = 10000
sample_seed = 11

if len(df) > sample_size:
    # Cap per category: all rows of a category OR sample_size / n_categories,
    # whichever is smaller (with 3 categories this is sample_size // 3).
    n_categories = df['category'].nunique()
    per_class_cap = sample_size // max(n_categories, 1)

    # Sample each category independently and concatenate in sorted-key order
    # (the same order groupby().apply would produce). Iterating over the groups
    # explicitly keeps the 'category' column in the result; pandas >= 2.2
    # deprecates groupby().apply operating on the grouping column.
    df_sample = pd.concat(
        [
            group.sample(min(len(group), per_class_cap), random_state=sample_seed)
            for _, group in df.groupby('category')
        ]
    )
    assert df_sample['category'].value_counts().max() <= per_class_cap
else:
    # Small enough already: use every row, without any per-category capping.
    df_sample = df

print(df_sample['category'].value_counts())


evidence    3333
non_info    3333
claim       1662
Name: category, dtype: int64


In [9]:
# Encode every sampled sentence with a pretrained sentence-embedding model, then
# train and evaluate a logistic-regression classifier on those embeddings.
# NOTE(review): the held-out test split is drawn from the class-balanced
# df_sample, so these metrics describe performance on balanced data, not on the
# original (overwhelmingly 'evidence') class distribution of the full dataset.

# Load sentence transformer
model = SentenceTransformer('all-MiniLM-L6-v2')
# Generate embeddings: one vector per sentence, in df_sample row order
embeddings = model.encode(df_sample['sentence'].tolist(),
                          show_progress_bar=True,
                          batch_size=32)

# Split into train/test, keeping df_sample's class proportions in both parts
X_train, X_test, y_train, y_test = train_test_split(
    embeddings,
    df_sample['category'],
    test_size=0.2,
    random_state=42,
    stratify=df_sample['category']
)

# Train simple classifier
classifier = LogisticRegression(max_iter=1000, random_state=11)
classifier.fit(X_train, y_train)

# Evaluate
y_pred = classifier.predict(X_test)

print("Classification Report")
print(classification_report(y_test, y_pred))

# Take the label order from the fitted classifier instead of a hand-written
# list, so the matrix can never silently drift out of sync with the labels in
# the data. `cm` stays a plain array; the labelled frame is only for display.
print("Confusion Matrix")
class_labels = classifier.classes_
cm = confusion_matrix(y_test, y_pred, labels=class_labels)
print(pd.DataFrame(
    cm,
    index=pd.Index(class_labels, name='true'),
    columns=pd.Index(class_labels, name='predicted'),
))

Batches:   0%|          | 0/261 [00:00<?, ?it/s]

Classification Report
              precision    recall  f1-score   support

       claim       0.87      0.92      0.89       332
    evidence       0.83      0.85      0.84       667
    non_info       0.87      0.81      0.84       667

    accuracy                           0.85      1666
   macro avg       0.85      0.86      0.86      1666
weighted avg       0.85      0.85      0.85      1666

Confusion Matrix
[[305  17  10]
 [ 23 570  74]
 [ 23 101 543]]
