In [4]:
import pandas as pd
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import accuracy_score
import os
from tqdm import tqdm  # 导入tqdm库来添加进度条

# Check if a GPU is available and use it, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load CLIP model and processor from Hugging Face
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Move the model to GPU or CPU
model.to(device)

# Plant crop descriptions (for zero-shot classification)
plant_crop_descriptions = {
    0: "Apple Leaf",
    1: "Blueberry Leaf",
    2: "Cherry Leaf",
    3: "Corn Leaf",
    4: "Grape Leaf",
    5: "Orange Leaf",
    6: "Peach Leaf",
    7: "Pepper Leaf",
    8: "Potato Leaf",
    9: "Raspberry Leaf",
    10: "Soybean Leaf",
    11: "Squash Leaf",
    12: "Strawberry Leaf",
    13: "Tomato Leaf",
}

# Plant disease descriptions (for zero-shot classification)
plant_disease_descriptions = {
    0: "Healthy Leaf",
    1: "Apple Leaf with Scab",
    2: "Black Rot",
    3: "Cedar apple rust",
    4: "Powdery Mildew",
    5: "Cercospora leaf spot (Gray leaf spot)",
    6: "Common rust",
    7: "Northern Leaf Blight",
    8: "Esca (Black Measles)",
    9: "Leaf blight (Isariopsis Leaf Spot)",
    10: "Huanglongbing (Citrus Greening)",
    11: "Bacterial spot",
    12: "Early blight",
    13: "Late blight",
    14: "Leaf scorch",
    15: "Leaf Mold",
    16: "Septoria leaf spot",
    17: "Spider mites (Two-spotted spider mite)",
    18: "Target Spot",
    19: "Tomato Leaf with Yellow Leaf Curl Virus",
    20: "Tomato Leaf with Mosaic Virus",
}

# Load test data from the CSV file (no column names)
test_data = pd.read_csv("PV test seen.csv", header=None)

# List to store ground truth labels and predicted labels
ground_truth = []
predictions = []

# Set base directory for images
base_dir = "plantvillage"  # Base directory for images

# Loop through the test images in the folder using tqdm for progress bar
for index, row in tqdm(test_data.iterrows(), total=test_data.shape[0], desc="Testing Progress"):
    image_path = os.path.join(base_dir, row[0])  # Add base directory to image path
    true_crop = row[1]  # Plant crop type (this might not be used for classification directly)
    true_disease = row[2]  # True disease label

    # Load and process the image
    try:
        image = Image.open(image_path)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        continue

    # Combine plant crop descriptions and plant disease descriptions
    class_descriptions = list(plant_crop_descriptions.values()) + list(plant_disease_descriptions.values())

    inputs = processor(text=class_descriptions, images=image, return_tensors="pt", padding=True)

    # Move inputs to the same device as the model (GPU or CPU)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Get model outputs
    with torch.no_grad():
        outputs = model(**inputs)

    # Calculate similarity between the image and each class description (crop + disease)
    logits_per_image = outputs.logits_per_image  # shape: [batch_size, num_classes]
    probs = logits_per_image.softmax(dim=1)  # softmax to get probabilities

    # Get the predicted class index
    predicted_class_idx = torch.argmax(probs, dim=1).item()
    predicted_class_label = class_descriptions[predicted_class_idx]

    # Add the true disease and prediction to the lists
    ground_truth.append(true_disease)
    predictions.append(predicted_class_label)

# Calculate the accuracy
accuracy = accuracy_score(ground_truth, predictions)
print(f"Accuracy: {accuracy * 100:.2f}%")


Using device: cuda


Testing Progress: 100%|██████████| 10279/10279 [07:56<00:00, 21.57it/s]

Accuracy: 0.00%



