In [None]:
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn
from PIL import Image
import tensorflow as tf
import os
from time import sleep
import requests
import json

In [None]:
def parse_gemini_response(payload):
    """Extract the reply text from a decoded Gemini ``generateContent`` response.

    Args:
        payload: The JSON-decoded response body (a dict).

    Returns:
        The text of all parts of the first candidate, concatenated, with Markdown
        bold markers (``**``) removed.

    Raises:
        RuntimeError: If the response contains no candidates (for example when the
            prompt was blocked); the payload is included in the message for debugging.
    """
    candidates = payload.get("candidates") or []
    if not candidates:
        raise RuntimeError(f"Gemini returned no candidates: {json.dumps(payload)}")
    parts = candidates[0].get("content", {}).get("parts", [])
    return "".join(part.get("text", "") for part in parts).replace("**", "")


def get_info(prompt, api_key=None, model="gemini-pro", timeout=60):
    """Ask the Gemini ``generateContent`` endpoint about ``prompt`` and print the reply.

    Args:
        prompt: Text sent as the single user part of the request.
        api_key: Gemini API key. Defaults to the ``GEMINI_API_KEY`` environment
            variable so that no secret is hard-coded in the notebook.
        model: Model name inserted into the endpoint URL.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The reply text with ``**`` markers stripped (it is also printed).

    Raises:
        RuntimeError: If no API key is available or the reply has no candidates.
        requests.HTTPError: If the API answers with an HTTP error status.
    """
    if api_key is None:
        api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("Set the GEMINI_API_KEY environment variable or pass api_key=...")

    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    # Send the key as a header rather than a query parameter so it never shows up
    # in URLs echoed by error messages (e.g. HTTPError) in the notebook output.
    headers = {"x-goog-api-key": api_key}
    data = {"contents": [{"parts": [{"text": prompt}]}]}

    # `json=` serializes the body and sets Content-Type: application/json.
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    # Surface auth / quota errors directly instead of a confusing KeyError later.
    response.raise_for_status()

    text = parse_gemini_response(response.json())
    print(text)
    return text

In [None]:
def image_transform(image_path, size=224):
    """Load an image file as a float tensor with a leading batch axis.

    Args:
        image_path: Path to an image file readable by PIL.
        size: Side length in pixels of the square resize (default 224).

    Returns:
        A contiguous ``torch.float32`` tensor of shape ``(1, 3, size, size)``
        (batch, channel, height, width) with pixel values scaled to ``[0, 1]``.
    """
    # `with` releases the file handle as soon as the pixels have been decoded.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB').resize((size, size))

    # NOTE(review): values are only scaled to [0, 1] (no mean/std normalization)
    # and laid out channels-first; confirm this matches how the model was exported.
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0  # (size, size, 3)
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).contiguous()

In [None]:
# Class labels of the classifier, in model output order: the inference loop
# indexes this list with np.argmax of the model output, so order is significant.
plant_disease_list = [
    {"Plant Name": plant, "Disease/Problem": problem}
    for plant, problem in (
        ("cassava", "Curled leaf symptom"),
        ("cassava", "Healthy Leaf"),
        ("cassava", "Cassava mosaic disease (CMD)"),
        ("maize", "Fall armyworm (FAW)"),
        ("cassava", "Young healthy leaf"),
        ("cassava", "Old plant"),
        ("cassava", "Nutritional deficiency"),
        ("cassava", "Bacterial leaf streak (BLS)"),
        ("cassava", "Sooty mold"),
        ("cassava", "Brown streak disease (BSD)"),
        ("cassava", "Young cassava mosaic disease (YCMD)"),
        ("cassava", "Cassava green mottle disease (CGM)"),
        ("cassava", "Cassava root rot disease (CRM)"),
    )
]

In [None]:
# Gemini question for each predicted class; the inference loop looks up
# prompts[idx] with the same index it uses for plant_disease_list.
prompts = list((
    "Explain cassava curled leaf symptom in detail and how to control it.",  # 0
    "Cassava healthy leaf is a good sign, but are there any preventive measures to avoid future diseases?",  # 1
    "How to identify and control cassava mosaic disease (CMD)?",  # 2
    "Explain the damage caused by fall armyworm (FAW) on maize and how to manage it.",  # 3
    "What differentiates a young healthy cassava leaf from a diseased one?",  # 4
    "What happens to a cassava plant as it gets old? Are there specific diseases more common at this stage?",  # 5
    "How to diagnose and address nutritional deficiencies in cassava plants?",  # 6
    "Explain the causes and control methods for bacterial leaf streak (BLS) on cassava.",  # 7
    "What is sooty mold on cassava and how to get rid of it?",  # 8
    "What are the management strategies for cassava brown streak disease (BSD)?",  # 9
    "How to differentiate between young cassava mosaic disease (YCMD) and regular cassava mosaic disease?",  # 10
    "What are the effects of cassava green mottle disease (CGM) and how to control it?",  # 11
    "Explain cassava root rot disease (CRM) and how to prevent it.",  # 12
))

In [None]:
# Load the TFLite classifier and allocate its tensor buffers; TFLite requires
# allocate_tensors() before any input can be set or the model invoked.
interpreter = tf.lite.Interpreter(model_path="resnet50.tflite")
interpreter.allocate_tensors()

# Per-tensor metadata dicts (index, shape, dtype, ...) for the model's inputs and outputs.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [None]:
# Classify each sample image with the TFLite model, then ask Gemini for
# background and control advice about the predicted class.
IMAGE_DIR = "/content/samples/"  # Colab upload folder; point elsewhere when running locally
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}
REQUEST_DELAY_S = 2  # pause between Gemini requests (presumably to stay under rate limits)

# The argmax index selects from both lists, so they must describe the same classes.
assert len(plant_disease_list) == len(prompts), "plant_disease_list and prompts are out of sync"
# Assumes an output of shape (1, num_classes); fail clearly instead of an IndexError later.
assert output_details[0]['shape'][-1] == len(plant_disease_list), (
    "model output size does not match the number of labels in plant_disease_list"
)

# Sorted for a reproducible order; the filter skips folders such as
# .ipynb_checkpoints and stray non-image files that PIL cannot open.
image_paths = sorted(
    os.path.join(IMAGE_DIR, name)
    for name in os.listdir(IMAGE_DIR)
    if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    and os.path.isfile(os.path.join(IMAGE_DIR, name))
)

for image_path in image_paths:
    print(image_path)
    images = image_transform(image_path)

    # Hand TFLite a NumPy array explicitly rather than relying on it converting a torch tensor.
    interpreter.set_tensor(input_details[0]['index'], images.numpy())
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Highest-scoring class over the flattened output (a single image per batch here).
    idx = int(np.argmax(output_data))
    print(idx)
    print(plant_disease_list[idx])
    get_info(prompts[idx])
    print("\n\n")

    sleep(REQUEST_DELAY_S)