In [6]:
import os

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer

# Crop class texts, keyed by crop label index (0-13), listed alphabetically.
# NOTE(review): keys are presumably the `crop_label` ids used in the evaluation
# CSV — verify against the dataset's label map.
plant_crop_descriptions = dict(enumerate((
    "Apple Leaf",
    "Blueberry Leaf",
    "Cherry Leaf",
    "Corn Leaf",
    "Grape Leaf",
    "Orange Leaf",
    "Peach Leaf",
    "Pepper Leaf",
    "Potato Leaf",
    "Raspberry Leaf",
    "Soybean Leaf",
    "Squash Leaf",
    "Strawberry Leaf",
    "Tomato Leaf",
)))


# Disease class texts, keyed by disease label index (0-20). Every entry follows
# the template "<disease name> - <visual symptoms>." so the text encoder sees a
# consistent phrasing across classes.
# NOTE(review): keys are presumably the `disease_label` ids used in the
# evaluation CSV — verify the ordering against the dataset's label map.
plant_disease_descriptions = {
    0: "Healthy Leaf - vibrant, disease-free, normal leaf appearance.",
    1: "Apple Leaf with Scab - dark spots, rough, scaly surface.",
    2: "Black Rot - blackened lesions, wilted tissue, spreading damage.",
    3: "Cedar Apple Rust - orange spots, thread-like fungal growth visible.",
    4: "Powdery Mildew - white powdery patches, fungal overgrowth on surface.",
    5: "Cercospora Leaf Spot (Gray Leaf Spot) - gray, circular, small necrotic lesions.",
    6: "Common Rust - brown pustules, raised fungal spores on leaf.",
    7: "Northern Leaf Blight - elongated gray-brown lesions, spreading necrotic areas.",
    8: "Esca (Black Measles) - dark streaks, cracked, dying leaf tissue.",
    9: "Leaf Blight (Isariopsis Leaf Spot) - brown patches, soft, rotting edges.",
    10: "Huanglongbing (Citrus Greening) - yellow veins, deformed fruit, stunted growth.",
    11: "Bacterial Spot - wet-looking spots, irregular, gradually turning black.",
    12: "Early Blight - brown concentric rings, yellowing surrounding tissue areas.",
    13: "Late Blight - large black patches, water-soaked lesions on leaves.",
    14: "Leaf Scorch - browning tips, drying edges, desiccated appearance.",
    15: "Leaf Mold - fuzzy growth, pale yellow spots, leaf decay.",
    16: "Septoria Leaf Spot - small brown spots, spreading, surrounded by yellow.",
    17: "Spider Mites (Two-Spotted Spider Mite) - speckled leaves, fine webbing visible.",
    18: "Target Spot - dark circular spots, lighter centers, extensive necrosis.",
    19: "Tomato Leaf with Yellow Leaf Curl Virus - curled leaves, stunted yellow growth.",
    20: "Tomato Leaf with Mosaic Virus - mottled appearance, discolored green-yellow patterns."
}


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed so the randomly initialised projection layer below (and therefore
# every prediction and accuracy computed through it) is identical across
# kernel restarts.
SEED = 42

# Load models and move them to the correct device
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

# Linear map from CLIP image features (512-d) into the sentence-embedding
# space (384-d) so images can be compared with the class-text embeddings.
# NOTE(review): these weights are freshly (randomly) initialised and nothing
# visible in this notebook trains them, so similarities computed through this
# layer carry no learned image-text alignment — expect roughly chance-level
# accuracy (about 1/14 for crops, 1/21 for diseases). Train this layer, or score
# against CLIP's own text features, before interpreting results.
torch.manual_seed(SEED)  # seed right before the only random initialisation
projection_layer = nn.Linear(512, 384, bias=False).to(device)

# Class texts in label-index order (the dicts' insertion order follows keys 0..N)
crop_texts = list(plant_crop_descriptions.values())
disease_texts = list(plant_disease_descriptions.values())

# One sentence embedding per class text, placed on `device`
crop_embeddings = text_model.encode(crop_texts, convert_to_tensor=True).to(device)
disease_embeddings = text_model.encode(disease_texts, convert_to_tensor=True).to(device)

# L2-normalise each row so dot products equal cosine similarity
crop_embeddings = torch.nn.functional.normalize(crop_embeddings, p=2, dim=-1)
disease_embeddings = torch.nn.functional.normalize(disease_embeddings, p=2, dim=-1)

def predict_image(image_path):
    """Predict crop and disease label indices for a single image.

    The image is embedded with CLIP, L2-normalised, mapped through
    ``projection_layer`` into the sentence-embedding space, and scored by
    cosine similarity against ``crop_embeddings`` and ``disease_embeddings``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        ``(predicted_crop_idx, predicted_disease_idx)`` as Python ints: the row
        index of the best-matching entry in ``crop_embeddings`` and
        ``disease_embeddings`` respectively.
    """
    # The context manager closes the file handle PIL holds open; convert()
    # returns a new, fully loaded image that remains valid after the block.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # `padding` only applies to text tokenisation, so it is omitted for
    # image-only input.
    inputs = clip_processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        image_embedding = clip_model.get_image_features(**inputs)
        image_embedding = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
        image_embedding = projection_layer(image_embedding)  # project to 384 dimensions

        # image_embedding is (1, D) for the single image; unsqueeze(0) makes each
        # class table (1, C, D), so the broadcast similarity has shape (1, C).
        crop_similarities = torch.nn.functional.cosine_similarity(
            image_embedding, crop_embeddings.unsqueeze(0), dim=-1
        )
        disease_similarities = torch.nn.functional.cosine_similarity(
            image_embedding, disease_embeddings.unsqueeze(0), dim=-1
        )

    predicted_crop_idx = torch.argmax(crop_similarities).item()
    predicted_disease_idx = torch.argmax(disease_similarities).item()
    return predicted_crop_idx, predicted_disease_idx



In [7]:
# Example usage: run the classifier on one image.
# Replace the path below with your own image to try another sample.
image_path = "plantvillage/686a65da-b541-4b08-94bf-631cf6d27449___RS_HL-7584.JPG"
predicted_crop_idx, predicted_disease_idx = predict_image(image_path)

# Raw label indices; look them up in plant_crop_descriptions /
# plant_disease_descriptions for human-readable names.
for label_kind, label_idx in (("Crop", predicted_crop_idx), ("Disease", predicted_disease_idx)):
    print(f"Predicted {label_kind} Index: {label_idx}")

Predicted Crop Index: 5
Predicted Disease Index: 16


In [8]:
import pandas as pd
from tqdm import tqdm

def evaluate_model(csv_file, image_dir="plantvillage", predict_fn=None):
    """Return crop and disease top-1 accuracy (in percent) over a labelled CSV.

    Args:
        csv_file: Path to a header-less CSV with three columns: image file
            name, integer crop label, integer disease label.
        image_dir: Directory the image file names are relative to. Defaults to
            ``"plantvillage"``, the prefix that was previously hard-coded.
        predict_fn: Callable mapping an image path to
            ``(crop_idx, disease_idx)``. Defaults to ``predict_image``; it is
            resolved at call time so a re-defined ``predict_image`` is used.

    Returns:
        ``(crop_accuracy, disease_accuracy)`` as floats in ``[0, 100]``.

    Raises:
        ValueError: If ``csv_file`` contains no data rows.
    """
    if predict_fn is None:
        predict_fn = predict_image

    # Bug fix: the CSV path now comes from `csv_file`; it was previously
    # ignored in favour of a hard-coded "PV test seen.csv".
    column_names = ['image_name', 'crop_label', 'disease_label']
    data = pd.read_csv(csv_file, header=None, names=column_names)
    if data.empty:
        # Guard against ZeroDivisionError when computing the accuracies below.
        raise ValueError(f"{csv_file!r} contains no rows; cannot compute accuracy.")

    correct_crop_predictions = 0
    correct_disease_predictions = 0

    # itertuples avoids building a pandas Series per row (as iterrows does)
    # and exposes the columns as attributes.
    for row in tqdm(data.itertuples(index=False), total=len(data)):
        image_path = os.path.join(image_dir, row.image_name)
        predicted_crop_idx, predicted_disease_idx = predict_fn(image_path)

        if predicted_crop_idx == int(row.crop_label):
            correct_crop_predictions += 1
        if predicted_disease_idx == int(row.disease_label):
            correct_disease_predictions += 1

    total_images = len(data)
    crop_accuracy = correct_crop_predictions / total_images * 100
    disease_accuracy = correct_disease_predictions / total_images * 100
    return crop_accuracy, disease_accuracy



In [10]:
# Evaluate on the labelled test CSV and report both accuracies as percentages.
csv_file = "PV test seen.csv"
crop_acc, disease_acc = evaluate_model(csv_file)

for metric_name, metric_value in (("Crop", crop_acc), ("Disease", disease_acc)):
    print(f"{metric_name} Accuracy: {metric_value:.2f}%")


  0%|          | 0/10279 [00:00<?, ?it/s]

100%|██████████| 10279/10279 [09:29<00:00, 18.03it/s]

Crop Accuracy: 4.03%
Disease Accuracy: 2.90%



