In [1]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNet9(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(ResNet9, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.res1 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1))

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.res2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(512, 512, kernel_size=3, padding=1))

        self.classifier = nn.Sequential(nn.MaxPool2d(4), nn.Flatten(), nn.Linear(512, num_classes))

    def forward(self, xb):
        out = F.relu(self.conv1(xb))
        out = F.relu(self.conv2(out))
        out = self.res1(out) + out
        out = F.relu(self.conv3(out))
        out = F.relu(self.conv4(out))
        out = self.res2(out) + out
        out = self.classifier(out)
        return out


In [6]:
model = torch.load('./plant-disease-model-complete.pth', map_location=torch.device('cpu'))
model.eval() 


  model = torch.load('./plant-disease-model-complete.pth', map_location=torch.device('cpu'))


ResNet9(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  )
  (res1): Sequential(
    (0): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=Tr

In [7]:
classes=['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_spider_mite',
 'Tomato___Target_Spot',
 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
 'Tomato___Tomato_mosaic_virus',
 'Tomato___healthy']

In [8]:
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),  
])

def predict_image(image_path, model):
    image = Image.open(image_path)  
    image = transform(image).unsqueeze(0)
    model.eval()

    with torch.no_grad():  
        output = model(image)
    
    print(output)
    # Get the predicted class (the index of the highest value in the output)
    _, predicted_class = torch.max(output, 1)
    print(_)
    print(predicted_class)
    return predicted_class.item()

# Example usage
image_path = './test/CornCommonRust3.jpg'
predicted_class = predict_image(image_path, model)
print(f"Predicted class: {predicted_class}")
print('predicted class',classes[predicted_class])
print('real class',image_path.split('/')[2])


tensor([[-11.0651, -10.9376,  -9.5871, -11.1030, -14.0394, -17.2442, -13.2983,
          -3.3063,   8.5075, -12.0222, -10.1560, -22.9238, -18.9456, -18.2947,
         -22.4203, -15.4425, -11.5877, -10.9310, -10.1139, -13.8660, -11.2583,
         -22.8433, -11.9410, -21.3866, -21.3792, -19.5197, -13.4042, -23.5856,
         -18.9905,  -8.6215,  -4.1754, -19.4976, -16.4698, -20.1809, -17.2937,
         -25.0801, -21.6495, -17.7910]])
tensor([8.5075])
tensor([8])
Predicted class: 8
predicted class Corn_(maize)___Common_rust_
real class CornCommonRust3.jpg


In [9]:
# Assuming model is an instance of ResNet9
scripted_model = torch.jit.script(model)  # Script the model
torch.jit.save(scripted_model, 'scripted-plant-disease-model.pth')  # Save scripted model


In [10]:
# Loading the model using TorchScript
loaded_model = torch.jit.load('scripted-plant-disease-model.pth')
loaded_model.eval()


RecursiveScriptModule(
  original_name=ResNet9
  (conv1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
  )
  (conv2): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (res1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Conv2d)
      (1): RecursiveScriptModule(original_name=BatchNorm2d)
      (2): RecursiveScriptModule(original_name=ReLU)
    )
    (1): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Conv2d)
      (1): RecursiveSc

In [11]:
# Avoid this if you used torch.save()
# model = torch.load('plant-disease-model.pth')

# Use this instead
model = torch.jit.load('scripted-plant-disease-model.pth')


In [14]:
from torchvision import transforms
from PIL import Image
import torch

transform = transforms.Compose([
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),  
])

def predict_image(image_path, model, threshold=0.5):
    image = Image.open(image_path)  
    image = transform(image).unsqueeze(0)  # Add batch dimension
    model.eval()

    with torch.no_grad():  
        output = model(image)
    
    # Print all the scores for each class
    print("All class scores (logits):", output)

    # Determine the class with the maximum score
    scores = output.squeeze().tolist()
    print(scores)
    max_score = max(scores)
    print(max(scores))
    predicted_class = scores.index(max_score)


    # Check if the maximum score exceeds the threshold
    if max_score < threshold:
        predicted_class = 500  # Set to zero class if below threshold
    
    print("Predicted class index:", predicted_class)
    return predicted_class

# Example usage
image_path = 'test\AppleCedarRust1.JPG'
predicted_class = predict_image(image_path, model)
print(f"Predicted class: {predicted_class}")


All class scores (logits): tensor([[-13.2536,  -7.0521,   6.9100, -18.6356,  -5.7853, -20.9146, -12.0162,
         -12.5028,  -8.1652, -16.2039, -19.2399, -11.1580,  -9.5031, -19.0430,
         -12.5575, -16.1077, -14.6961, -20.1847, -10.6525, -11.1227, -15.6926,
         -21.7511, -12.7993, -10.0117, -23.8417, -26.7067, -14.7377, -14.3273,
         -17.2078, -13.6712, -11.7254, -17.5820, -11.8973, -20.3368, -13.6747,
         -19.2308, -16.5484, -21.1036]])
[-13.253632545471191, -7.052058696746826, 6.909963607788086, -18.63562774658203, -5.785250186920166, -20.914567947387695, -12.016183853149414, -12.502751350402832, -8.165199279785156, -16.20386505126953, -19.239946365356445, -11.158018112182617, -9.503124237060547, -19.042964935302734, -12.557503700256348, -16.10773468017578, -14.696085929870605, -20.184720993041992, -10.652541160583496, -11.122721672058105, -15.692585945129395, -21.751129150390625, -12.799331665039062, -10.011682510375977, -23.8416805267334, -26.706724166870117, -