In [7]:
import pandas as pd
from flair.models import TARSClassifier
from flair.data import Sentence
from tqdm import tqdm
from flair.trainers import ModelTrainer

In [8]:

# Load dataset
ag_test = pd.read_csv('../data/AG News/test.csv')
ag_train = pd.read_csv('../data/AG News/train.csv')

# Define class mapping
class_mapping = {1: "World", 2: "Sports", 3: "Business", 4: "Science"}

# Apply the mapping
ag_test['Class'] = ag_test['Class Index'].replace(class_mapping)
ag_train['Class'] = ag_train['Class Index'].replace(class_mapping)

# Load pre-trained TARS model
tars = TARSClassifier.load('tars-base')

# Define possible topic labels
topic_labels = list(class_mapping.values())
tars.add_and_switch_to_new_task("ag_news_zero_shot", label_dictionary=topic_labels, label_type="classification")

# Perform zero-shot classification
predictions = []
true_labels = ag_test['Class'].tolist()


2025-05-21 22:35:09,812 TARS initialized without a task. You need to call .add_and_switch_to_new_task() before training this model


In [9]:
for text in tqdm(ag_test['Description'].head(4)):
    sentence = Sentence(text)
    tars.predict_zero_shot(sentence, topic_labels)  # Zero-shot prediction

    # Ensure there's a predicted label before accessing it
    if sentence.labels:
        predictions.append(sentence.labels[0].value)
    else:
        predictions.append("Unknown")  # Fallback for empty predictions


100%|██████████| 4/4 [00:00<00:00, 17.90it/s]


In [10]:
predictions

['Business', 'Science', 'Science', 'Sports']

In [10]:

# Compute accuracy
accuracy = sum([1 for pred, true in zip(predictions, true_labels) if pred == true]) / len(true_labels)
print(f"TARS Zero-Shot Accuracy: {accuracy * 100:.2f}%")


TARS Zero-Shot Accuracy: 21.83%


In [12]:
# Convert predictions to pandas Series and show distribution
pd.Series(predictions).value_counts()

Business    2079
Sports      1892
World       1795
Unknown     1445
Science      394
Name: count, dtype: int64

In [13]:
# Convert predictions to pandas Series and show distribution
pd.Series(true_labels).value_counts()

Business    1900
Science     1900
Sports      1900
World       1900
Name: count, dtype: int64