In [4]:
import sys
print(sys.executable)

# Install required libraries
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers", "datasets", "torch", "accelerate"])


/usr/bin/python3


/usr/bin/python3


0

In [14]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load AG News dataset from Hugging Face
dataset = load_dataset("ag_news")
test_data = dataset["test"]
print(f"Test dataset loaded: {len(test_data)} examples")

Test dataset loaded: 7600 examples


In [8]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
cache_dir = "./model_cache"

# Check if cache has valid model files
def is_cache_valid(cache_path):
    if not os.path.exists(cache_path):
        return False
    required_files = ['pytorch_model.bin', 'model.safetensors', 'config.json']
    return any(os.path.exists(os.path.join(cache_path, f)) for f in required_files)

# Load or download model
if is_cache_valid(cache_dir):
    print("Loading from cache...")
    tokenizer = AutoTokenizer.from_pretrained(cache_dir)
    model = AutoModelForCausalLM.from_pretrained(cache_dir, torch_dtype=torch.float16, device_map="auto")
else:
    print("Downloading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

print("Model and tokenizer loaded successfully")

Downloading model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Model and tokenizer loaded successfully


In [15]:
# Define label mapping and create classification prompt
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

def create_prompt(text):
    return f"""Classify the following news headline into one of the categories: World, Sports, Business, Sci/Tech.

Text: "{text}"
Label:"""

print("Label mapping and prompt function created")

Label mapping and prompt function created


In [17]:
# Perform inference on test set
predictions = []
true_labels = []

# First, let's check the structure of test_data
print("Sample example:", test_data[0])

for i in range(min(100, len(test_data))):  # Start with 100 examples for testing
    example = test_data[i]
    
    # Access the text and label from the example
    text = example.get("text") or example.get("content")
    label = example.get("label")
    
    if text is None:
        continue
    
    prompt = create_prompt(text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
    
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract only the generated part (after the prompt)
    prediction_text = full_output[len(prompt):].strip()
    
    # Try to match prediction to a label
    matched_label = "Unknown"
    for lbl in label_map.values():
        if lbl.lower() in prediction_text.lower():
            matched_label = lbl
            break
    
    predictions.append(matched_label)
    true_labels.append(label_map[label])

print(f"Inference completed on {len(predictions)} examples")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Sample example: {'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 'label': 2}


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

Inference completed on 100 examples


In [18]:
# Compute accuracy and evaluate baseline
from sklearn.metrics import accuracy_score, classification_report

accuracy = accuracy_score(true_labels, predictions)
print(f"Baseline Accuracy: {accuracy:.4f}")
print("\nClassification Report:")
print(classification_report(true_labels, predictions))

Baseline Accuracy: 0.7600

Classification Report:
              precision    recall  f1-score   support

    Business       0.44      1.00      0.62        12
    Sci/Tech       0.77      0.65      0.71        37
      Sports       0.89      0.81      0.85        21
       World       1.00      0.77      0.87        30

    accuracy                           0.76       100
   macro avg       0.78      0.81      0.76       100
weighted avg       0.83      0.76      0.77       100



In [None]:
# Generate confusion matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Get unique labels
labels_list = sorted(label_map.values())

# Compute confusion matrix
cm = confusion_matrix(true_labels, predictions, labels=labels_list)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=labels_list, 
            yticklabels=labels_list,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - LLM Baseline Predictions')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

print("\nConfusion Matrix:")
print(cm)
print("\nLabels order:", labels_list)