In [1]:
import os
import cv2
import tqdm
import glob
import random
import numpy as np
import pandas as pd
from datetime import datetime

In [2]:
# Run configuration ---------------------------------------------------------
EVAL_AMOUNT = 512  # images sampled for evaluation (all of them if the CSV has fewer)

# Active dataset; alternative: "_data/plant_pathology"
DATASET_DIR = "_data/plantdoc_csv"

# Each run writes into its own timestamped intermediate directory.
now = datetime.now()
INT_S1_DIR = "_intermediate/stage1_plantdoc/" + now.strftime('%Y_%m_%d_%H_%M_%S')
PATCHES_DIR = os.path.join(INT_S1_DIR, "patches")

In [3]:
# Create the run directory and the cache directory for SAM patches (idempotent).
for out_dir in (INT_S1_DIR, PATCHES_DIR):
    os.makedirs(out_dir, exist_ok=True)

In [4]:
train_data = pd.read_csv(os.path.join(DATASET_DIR, "data.csv"))
# Later cells read the "image_id" and "healthy" columns of each row; fail fast
# with a clear message if the CSV does not provide them.
missing_columns = {"image_id", "healthy"} - set(train_data.columns)
assert not missing_columns, f"data.csv is missing columns: {sorted(missing_columns)}"

In [5]:
# Fixed seed so the evaluation subset (and everything computed from it) is
# reproducible across kernel restarts.
SAMPLE_SEED = 0
all_indices = list(train_data.index)
if EVAL_AMOUNT > len(all_indices):
    # Fewer rows than requested: evaluate on the whole dataset, in file order.
    indices = all_indices
else:
    indices = random.Random(SAMPLE_SEED).sample(all_indices, k=EVAL_AMOUNT)

In [6]:
def get_patches(masks, image, apply_mask=False, padding=0):
    """Crop one patch per SAM mask record from ``image``.

    Args:
        masks: iterable of mask dicts. ``"bbox"`` is read as ``[x, y, w, h]``
            (x = column offset, y = row offset); when ``apply_mask`` is set,
            ``"segmentation"`` is read as a 2-D mask aligned with ``image``.
        image: ``(H, W)`` or ``(H, W, C)`` array.
        apply_mask: zero out pixels outside the segmentation inside each crop.
        padding: extra pixels added on every side of the bounding box; the
            crop is clipped to the image bounds.

    Returns:
        list of cropped arrays in mask order; empty crops are skipped.
    """
    height, width = image.shape[:2]
    result = []

    for mask in masks:
        bbox = mask["bbox"]
        col, row, box_w, box_h = bbox[0], bbox[1], bbox[2], bbox[3]

        # Clip to the image. Slice ends are exclusive, so the upper bound is the
        # full height/width (clipping to size - 1 would drop the last row/column).
        row0 = int(max(row - padding, 0))
        row1 = int(min(row + box_h + padding, height))
        col0 = int(max(col - padding, 0))
        col1 = int(min(col + box_w + padding, width))

        patch = image[row0:row1, col0:col1]
        if apply_mask:
            # Mask only the crop rather than the whole image.
            seg = mask["segmentation"][row0:row1, col0:col1]
            # Broadcast the 2-D mask over the channel axis for multi-channel images.
            patch = patch * (seg if patch.ndim == 2 else seg[:, :, np.newaxis])

        if 0 in patch.shape:
            continue
        result.append(patch)

    return result

In [7]:
def get_patches_file(image_id):
    """Load the cached patches for one image from ``PATCHES_DIR/<image_id>``.

    Patches are returned in numeric file order (``patch0.png``, ``patch1.png``,
    ..., ``patch10.png``). ``glob`` alone returns files in filesystem-dependent
    order, and a plain string sort would put ``patch10`` before ``patch2``.

    Args:
        image_id: dataset image id; used as the sub-directory name.

    Returns:
        list of images as returned by ``cv2.imread``.

    Raises:
        IOError: if a matched file cannot be decoded by OpenCV.
    """
    import re

    # Escape the directory so ids containing glob metacharacters ([, ], *, ?) match literally.
    pattern = os.path.join(glob.escape(os.path.join(PATCHES_DIR, str(image_id))), "*.png")

    def _numeric_order(path):
        # Sort key: trailing integer of the file name, then the path as a tie-breaker.
        match = re.search(r"(\d+)\.png$", os.path.basename(path))
        return (int(match.group(1)) if match else -1, path)

    patches = []
    for path in sorted(glob.glob(pattern), key=_numeric_order):
        patch = cv2.imread(path)
        if patch is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"Could not read patch image: {path}")
        patches.append(patch)
    return patches

In [8]:
import torch

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Build the automatic mask generator once and reuse it for every image instead
# of re-creating it on each call.
mask_generator = SamAutomaticMaskGenerator(sam)


def sam_generate_mask(image):
    """Run SAM's automatic mask generator on a single image.

    Args:
        image: HWC uint8 image. SAM's predictor interprets it in RGB channel
            order by default.

    Returns:
        The list of mask records from ``SamAutomaticMaskGenerator.generate``.
        This notebook reads the ``"bbox"`` and ``"segmentation"`` entries.
    """
    return mask_generator.generate(image)

  state_dict = torch.load(f)


In [10]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
class BinaryResnetClassifier(nn.Module):
    """ResNet-50 backbone whose final fully connected layer is replaced by a new head.

    With the default ``num_classes=1`` the model emits one raw logit per image;
    this notebook turns it into a probability with ``torch.sigmoid``.

    NOTE(review): the checkpoint loaded in this notebook is presumably a pickled
    instance of this class, so keep the class name and the ``resnet`` attribute
    stable or unpickling will break.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # ImageNet-pretrained backbone; weights passed by keyword (positional use is deprecated).
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Replace the ImageNet classification head with a num_classes-way linear layer.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        nn.init.xavier_normal_(self.resnet.fc.weight)

    def forward(self, x):
        """Return the raw logits (one column per class) for an input batch."""
        return self.resnet(x)
    
import torchvision.transforms.v2 as transforms

# Classifier preprocessing: resize to 224x224, scale to [0, 1], ImageNet normalisation.
# v2.ToTensor is deprecated; ToImage + ToDtype(float32, scale=True) is its
# documented replacement and produces the same values before Normalize.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The checkpoint is a whole pickled model (the result has .to()), not a state_dict.
# Unpickling it therefore needs weights_only=False, whose default became True in
# newer PyTorch releases. Only load checkpoints you trust: this runs arbitrary pickle code.
resnet = torch.load(
    "../leaf_segmentation/out/leaf_classifier/resnet/resnet_latest.pth",
    map_location=device,
    weights_only=False,
)
resnet = resnet.to(device)
# Inference mode: BatchNorm must use its running statistics, not per-patch batch statistics.
resnet.eval()

def s1_sam_resnet(image, image_id, threshold=.5):
    """Stage 1 (SAM + ResNet): keep the cached SAM patches the ResNet classifies as leaves.

    Args:
        image: unused; kept so every stage-1 function shares the
            ``(image, image_id)`` signature (patches come from ``get_patches_file``).
        image_id: id whose cached patches are scored.
        threshold: a patch is kept when its sigmoid probability is above this value.

    Returns:
        list of ``{"patch": ndarray, "leaf_probability": float}`` dicts for the
        kept patches, in patch order. Each record is a separate dict.
    """
    from PIL import Image

    masks_filtered = []
    with torch.no_grad():
        for patch in get_patches_file(image_id):
            # Patches are read with cv2.imread (BGR order), while the preprocessing
            # normalises with RGB ImageNet statistics, so convert before building
            # the PIL image. NOTE(review): assumes the classifier was trained on
            # RGB images — confirm against the training pipeline.
            rgb_patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
            batch = tf(Image.fromarray(rgb_patch)).unsqueeze(0).to(device)
            prob = torch.sigmoid(resnet(batch)).cpu().item()
            if prob > threshold:
                # Store the original BGR patch; later cells write it with cv2.imwrite.
                masks_filtered.append({"patch": patch, "leaf_probability": prob})
    return masks_filtered

  resnet = torch.load("../leaf_segmentation/out/leaf_classifier/resnet/resnet_latest.pth")


In [11]:
from ultralytics import YOLO, checks
model = YOLO("../leaf_segmentation/out/yolo_urban_street/train/weights/best.pt")


def s1_sam_yolo(image, image_id, threshold=.8):
    """Stage 1 (SAM + YOLOv8): keep cached SAM patches that the YOLO detector scores highly.

    Args:
        image: unused; kept for the common stage-1 ``(image, image_id)`` signature
            (patches come from ``get_patches_file``).
        image_id: id whose cached patches are scored.
        threshold: a patch is kept when its score is above this value.

    Returns:
        list of ``{"patch": ndarray, "leaf_probability": score}`` dicts for the
        kept patches, in patch order.
    """
    masks = []
    for patch in get_patches_file(image_id):
        boxes = model.predict(patch, verbose=False)[0].boxes
        # The score is the detection confidence only when the detector finds
        # exactly one box (of any class) in the patch. No detections, or several,
        # score 0, so with a non-negative threshold such patches are rejected.
        prob = boxes.conf.item() if len(boxes.conf) == 1 else 0
        masks.append({'patch': patch, 'leaf_probability': prob})
    return [mask for mask in masks if mask['leaf_probability'] > threshold]

In [12]:
model_seg = YOLO("../leaf_segmentation/out/yolo_synthetic/train4/weights/best.pt")


def s1_yolo(image, image_id, threshold=.5):
    """Stage 1 (YOLOv8 only): detect leaves on the full image with ``model_seg``.

    Unlike the SAM-based variants this does not use cached patches. Each
    detection whose confidence is above ``threshold`` is cropped from ``image``.

    Args:
        image: image array passed to ``model_seg.predict``; crops are cut from it
            using the predicted ``xyxy`` boxes.
        image_id: unused; kept for the common stage-1 ``(image, image_id)`` signature.
        threshold: minimum detection confidence (exclusive) for a box to be kept.

    Returns:
        list of ``{"patch": ndarray, "leaf_probability": float}`` dicts, the same
        record shape as the other stage-1 functions.
    """
    result = model_seg.predict(image, verbose=False)[0]
    masks_filtered = []
    if result.boxes is None:
        return masks_filtered

    height, width = image.shape[:2]
    boxes = result.boxes.xyxy.cpu().numpy()
    confs = result.boxes.conf.cpu().numpy()
    for (x0, y0, x1, y1), conf in zip(boxes, confs):
        if conf <= threshold:
            continue
        # Clip the box to the image before slicing.
        patch = image[max(int(y0), 0):min(int(y1), height),
                      max(int(x0), 0):min(int(x1), width)]
        if 0 in patch.shape:
            continue
        masks_filtered.append({'patch': patch, 'leaf_probability': float(conf)})
    return masks_filtered

In [19]:
for index in tqdm.tqdm(indices, desc="Generating patches"):
    img_id = train_data.at[index, "image_id"]
    img_path = os.path.join(DATASET_DIR, "images", img_id + ".jpg")
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (instead of raising) for missing or unreadable files.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.resize(img, (640, 640))
    # cv2 loads BGR, but SAM's mask generator treats its input as RGB
    # (SamPredictor.set_image defaults to image_format="RGB"), so convert for SAM
    # only. Patches are still cut from the BGR image because cv2.imwrite expects BGR.
    masks = sam_generate_mask(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    patches = get_patches(masks, img, apply_mask=False)
    patch_dir = os.path.join(PATCHES_DIR, img_id)
    os.makedirs(patch_dir, exist_ok=True)
    for i, patch in enumerate(patches):
        cv2.imwrite(os.path.join(patch_dir, f"patch{i}.png"), patch)

Generating patches: 100%|██████████| 512/512 [23:23<00:00,  2.74s/it]


In [20]:
# Stage-1 pipelines to compare: display name -> callable(image, image_id).
# Variants currently disabled:
#   "YOLOv8": s1_yolo
#   "Mask R-CNN": s1_mask_rcnn  (not defined in this part of the notebook)
stage1_dict = dict(
    [
        ("SAM + YOLOv8", s1_sam_yolo),
        ("SAM + ResNet", s1_sam_resnet),
    ]
)

In [None]:
stage1_results = {}

for stage1_name, stage1_model in stage1_dict.items():
    print(f"Running model {stage1_name}")
    per_image = stage1_results[stage1_name] = {}
    for index in tqdm.tqdm(indices, desc=stage1_name):
        row = train_data.loc[index]
        image_id = row["image_id"]
        # The entry is created before the model runs, so an interrupted run
        # leaves it in the results with an empty mask list.
        record = per_image[index] = {'healthy': bool(row["healthy"]), 'masks': []}
        img = cv2.imread(os.path.join(DATASET_DIR, "images", image_id + ".jpg"))
        with torch.no_grad():
            record['masks'] = stage1_model(img, image_id)
        # Release cached GPU memory between images.
        torch.cuda.empty_cache()

Running model SAM + YOLOv8


SAM + YOLOv8: 100%|██████████| 512/512 [04:29<00:00,  1.90it/s]


Running model SAM + ResNet


SAM + ResNet:  40%|███▉      | 204/512 [02:34<05:58,  1.16s/it]

In [22]:
import pickle

# Persist the results of all stage-1 models in a single file.
total_path = os.path.join(INT_S1_DIR, "total_data.pkl")
with open(total_path, "wb+") as file:
    pickle.dump(stage1_results, file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle

# One pickle per stage-1 model: <INT_S1_DIR>/<model name>/data.pkl
for stage1_name, stage1_result in stage1_results.items():
    model_dir = os.path.join(INT_S1_DIR, stage1_name)
    os.makedirs(model_dir, exist_ok=True)
    pkl_path = os.path.join(model_dir, "data.pkl")
    with open(pkl_path, "wb+") as file:
        pickle.dump(stage1_result, file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Write every kept patch to <INT_S1_DIR>/<model name>/patches/<dataset index>/patch_<i>.png
for stage1_name, stage1_result in stage1_results.items():
    patches_dir = os.path.join(INT_S1_DIR, stage1_name, "patches")
    for index, data in stage1_result.items():
        image_dir = os.path.join(patches_dir, str(index))
        os.makedirs(image_dir, exist_ok=True)
        kept_patches = (leaf_mask['patch'] for leaf_mask in data['masks'])
        for i, patch in enumerate(kept_patches):
            cv2.imwrite(os.path.join(image_dir, f"patch_{i}.png"), patch)