In [70]:
import os
import cv2
import tqdm
import random
import pandas as pd

In [57]:
# Run configuration.
EVAL_AMOUNT = 10_000  # upper bound on the number of train.csv rows that get evaluated
DATASET_DIR = "_data/plant_pathology"  # dataset root: train.csv and images/ are read from here
INT_S1_DIR = "_intermediate/stage1"  # stage-1 outputs (pickles, patch images) are written here

In [58]:
# Make sure the stage-1 output directory exists; nothing happens if it is already there.
os.makedirs(name=INT_S1_DIR, exist_ok=True)

In [59]:
# Read the annotation table (train.csv) from the dataset directory.
train_csv_path = os.path.join(DATASET_DIR, "train.csv")
train_data = pd.read_csv(train_csv_path)

In [62]:
# Choose which rows of train_data are evaluated: every row when EVAL_AMOUNT
# exceeds the table size, otherwise a random subset of EVAL_AMOUNT rows.
# The subset is drawn from a dedicated, seeded generator so that it is the
# same after every kernel restart and results from different runs refer to
# the same images.
SAMPLE_SEED = 0
all_indices = list(train_data.index)
if EVAL_AMOUNT > len(all_indices):
    indices = all_indices
else:
    indices = random.Random(SAMPLE_SEED).sample(all_indices, k=EVAL_AMOUNT)

In [63]:
def get_patches(masks, image, apply_mask=False, padding=0):
    """Crop one patch out of ``image`` per mask, using the mask's bounding box.

    Args:
        masks: Iterable of dicts. Each needs a ``"bbox"`` entry indexed as
            ``[col_start, row_start, col_extent, row_extent]`` and, when
            ``apply_mask`` is true, a ``"segmentation"`` 2-D array matching
            the first two dimensions of ``image``.
        image: Array indexed as ``image[row, col, ...]``.
        apply_mask: When true, the image is multiplied by the mask's
            segmentation before cropping, zeroing pixels outside the mask.
        padding: Extra pixels kept around the box on every side; the crop is
            clamped to the image bounds.

    Returns:
        A list of cropped arrays. Masks whose crop has a zero-sized dimension
        are skipped, so the list may be shorter than ``masks`` and positions
        in it do not necessarily line up with positions in ``masks``.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    patches = []

    for mask in masks:
        if apply_mask:
            # Indexing with None appends a channel axis so the 2-D mask
            # broadcasts over the image channels (same as np.newaxis, but
            # without needing numpy in scope).
            source = image * mask["segmentation"][:, :, None]
        else:
            source = image

        bbox = mask["bbox"]

        # Slice ends are exclusive, so the upper bounds are clamped to the
        # image size itself; clamping to size - 1 would cut off the last
        # row/column of any box that touches the image border.
        row_start = max(bbox[1] - padding, 0)
        row_stop = min(bbox[1] + bbox[3] + padding, n_rows)
        col_start = max(bbox[0] - padding, 0)
        col_stop = min(bbox[0] + bbox[2] + padding, n_cols)

        patch = source[row_start:row_stop, col_start:col_stop]

        # Degenerate boxes (zero width or height) produce empty crops; drop them.
        if 0 in patch.shape:
            continue
        patches.append(patch)

    return patches

In [64]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(device)`.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

In [79]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Build the Segment Anything model from the "vit_h" registry entry using the
# local checkpoint file, then move it onto the configured device.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

sam_builder = sam_model_registry[model_type]
sam = sam_builder(checkpoint=sam_checkpoint)
sam.to(device=device)

def sam_generate_mask(image, image_format="BGR"):
    """Run SAM's automatic mask generator on a single image.

    Args:
        image: HxWx3 uint8 array.
        image_format: Channel order of ``image``: ``"BGR"`` (the default,
            which is what ``cv2.imread`` returns and what the evaluation loop
            in this notebook passes in) or ``"RGB"``.

    Returns:
        The value returned by ``SamAutomaticMaskGenerator.generate``: a list
        of mask dicts (the rest of the notebook reads their ``"bbox"`` and
        ``"segmentation"`` entries).

    Raises:
        ValueError: If ``image_format`` is neither ``"BGR"`` nor ``"RGB"``.
    """
    if image_format not in ("BGR", "RGB"):
        raise ValueError(f"image_format must be 'BGR' or 'RGB', got {image_format!r}")

    # Build the generator once and reuse it for every image instead of
    # constructing a new one per call. The cache lives on the function
    # object, so re-running this cell (which also rebuilds `sam`) resets it.
    mask_generator = getattr(sam_generate_mask, "_mask_generator", None)
    if mask_generator is None:
        mask_generator = SamAutomaticMaskGenerator(sam)
        sam_generate_mask._mask_generator = mask_generator

    # SAM's predictor assumes RGB input by default; OpenCV images are BGR,
    # so swap the channels before segmentation.
    if image_format == "BGR":
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return mask_generator.generate(image)

In [80]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
class BinaryResnetClassifier(nn.Module):
    """ResNet-50 whose final fully connected layer is replaced by a new linear head.

    ``forward`` returns the raw outputs of that head; no activation is applied
    inside the model (``s1_sam_resnet`` applies the sigmoid itself).

    NOTE(review): presumably this definition is kept in the notebook so that
    ``torch.load`` can unpickle the saved classifier, which would require the
    class name and attribute layout to stay as they are; confirm before
    renaming anything.

    Args:
        num_classes: Number of outputs of the replacement head (default 1).
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Start from the ImageNet-pretrained backbone. `weights` is passed by
        # keyword because torchvision deprecated the positional form.
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Replace the last fully connected layer with a `num_classes`-way
        # linear layer and give its weights a Xavier-normal initialisation.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        nn.init.xavier_normal_(self.resnet.fc.weight)

    def forward(self, x):
        """Return the ResNet outputs (raw head values) for the batch ``x``."""
        return self.resnet(x)
    
import torchvision.transforms.v2 as transforms

# Preprocessing applied to each patch before it is fed to the ResNet
# classifier: resize to 224x224, convert to a float32 tensor scaled to [0, 1],
# then normalise with the ImageNet per-channel mean/std (RGB order).
# ToImage() followed by ToDtype(..., scale=True) is the replacement that
# torchvision documents for the deprecated v2.ToTensor().
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Load the trained classifier. The file holds a whole pickled module (the
# loaded object is called directly as a model below), not just a state_dict.
# - map_location lets a checkpoint saved on a GPU load on the current device.
# - weights_only=False is stated explicitly because newer PyTorch releases
#   default to weights_only=True, which refuses to unpickle a full module.
# SECURITY: unpickling executes arbitrary code; only load checkpoints that
# come from a trusted source.
resnet = torch.load(
    "../leaf_segmentation/out/leaf_classifier/resnet/resnet_latest.pth",
    map_location=device,
    weights_only=False,
)
# Inference only: eval() makes BatchNorm use its running statistics instead of
# the statistics of the single-patch batches that are fed in below.
resnet = resnet.to(device).eval()

def s1_sam_resnet(image, probability_threshold=0.5):
    """Stage 1 (SAM + ResNet): segment ``image`` and keep the masks classified as leaves.

    Args:
        image: HxWx3 uint8 image in OpenCV's BGR channel order, as returned
            by ``cv2.imread`` in the evaluation loop.
        probability_threshold: A mask is kept only when the classifier's
            sigmoid output is strictly greater than this value.

    Returns:
        The list of kept mask dicts. Every mask that yielded a non-empty crop
        (kept or not) is extended in place with ``"patch"`` (the BGR crop) and
        ``"leaf_probability"`` (float).
    """
    from PIL import Image

    leaf_masks = []
    with torch.no_grad():
        for mask in sam_generate_mask(image):
            # Crop one mask at a time: get_patches drops empty crops, so
            # cropping the whole list at once and pairing results by position
            # would attach patches and scores to the wrong masks.
            patches = get_patches([mask], image)
            if not patches:
                continue
            patch = patches[0]

            # PIL interprets a 3-channel array as RGB and the normalisation in
            # `tf` is in RGB order, so convert the BGR crop for the classifier.
            # NOTE(review): assumes the classifier was trained on RGB images
            # (e.g. loaded through PIL/ImageFolder) — confirm.
            rgb_patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
            batch = tf(Image.fromarray(rgb_patch)).unsqueeze(0).to(device)
            probability = torch.sigmoid(resnet(batch)).cpu().item()

            # The stored patch stays in BGR so it can be written with cv2.imwrite.
            mask["patch"] = patch
            mask["leaf_probability"] = probability
            if probability > probability_threshold:
                leaf_masks.append(mask)

    return leaf_masks



In [81]:
from ultralytics import YOLO, checks
# Load the YOLO model from the trained weights file.
yolo_weights_path = "../leaf_segmentation/out/yolo_urban_street/train/weights/best.pt"
model = YOLO(yolo_weights_path)

def s1_sam_yolo(image, confidence_threshold=0.8):
    """Stage 1 (SAM + YOLO): segment ``image`` and keep the masks YOLO scores as leaves.

    Args:
        image: HxWx3 uint8 image as returned by ``cv2.imread``.
        confidence_threshold: A mask is kept only when its score is strictly
            greater than this value.

    Returns:
        The list of kept mask dicts. Every mask that yielded a non-empty crop
        (kept or not) is extended in place with ``"patch"`` (the crop) and
        ``"leaf_probability"`` (the score described below).
    """
    leaf_masks = []
    for mask in sam_generate_mask(image):
        # Crop one mask at a time: get_patches drops empty crops, so cropping
        # the whole list at once and pairing results by position would attach
        # patches and scores to the wrong masks.
        patches = get_patches([mask], image)
        if not patches:
            continue
        patch = patches[0]

        detections = model.predict(patch, verbose=False)
        confidences = detections[0].boxes.conf

        # The score is the confidence of the detected box when YOLO finds
        # exactly one box in the patch, and 0 otherwise (no box, or several).
        # No class filter is applied here.
        # TODO: update probability assignment
        if len(confidences) == 1:
            probability = confidences.item()
        else:
            probability = 0

        mask["patch"] = patch
        mask["leaf_probability"] = probability
        if probability > confidence_threshold:
            leaf_masks.append(mask)

    return leaf_masks

In [82]:
# Registry of stage-1 pipelines, keyed by display name, in evaluation order.
stage1_dict = {}
stage1_dict["SAM + YOLOv8"] = s1_sam_yolo
stage1_dict["SAM + ResNet"] = s1_sam_resnet
# Disabled entry: "Mask R-CNN" -> s1_mask_rcnn

In [78]:
# Run every registered stage-1 pipeline over the selected rows.
# Result layout: stage1_results[pipeline name][row index] =
#     {'healthy': ground-truth flag from train.csv, 'masks': masks returned by the pipeline}
stage1_results = {}
for stage1_name, stage1_model in stage1_dict.items():
    print(f"Running model {stage1_name}")
    model_results = {}
    stage1_results[stage1_name] = model_results
    for index in tqdm.tqdm(indices, desc=stage1_name):
        row = train_data.loc[index]
        image_path = os.path.join(DATASET_DIR, "images", row["image_id"] + ".jpg")

        # cv2.imread returns None for a missing or unreadable file instead of
        # raising; fail here with the offending path rather than deep inside
        # the model code.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        with torch.no_grad():
            leaf_masks = stage1_model(img)

        # Record the entry only after the pipeline succeeded, so an image that
        # failed midway is not left behind with an empty mask list.
        model_results[index] = {
            'healthy': bool(row["healthy"]),
            'masks': leaf_masks,
        }
        # Hand cached GPU blocks back between images.
        torch.cuda.empty_cache()

Running model SAM + YOLOv8


SAM + YOLOv8:   6%|▌         | 105/1821 [13:54<3:47:19,  7.95s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 39.50 GiB of which 18.06 MiB is free. Including non-PyTorch memory, this process has 6.88 GiB memory in use. Process 2607513 has 32.58 GiB memory in use. Of the allocated memory 6.20 GiB is allocated by PyTorch, and 170.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [55]:
import pickle
# Persist the stage-1 results: one data.pkl per pipeline inside its own
# sub-directory, plus a single total_data.pkl holding all pipelines together.
pickle_protocol = pickle.HIGHEST_PROTOCOL

for stage1_name in stage1_results:
    stage1_result = stage1_results[stage1_name]
    pipeline_dir = os.path.join(INT_S1_DIR, stage1_name)
    os.makedirs(pipeline_dir, exist_ok=True)
    pipeline_pickle_path = os.path.join(pipeline_dir, "data.pkl")
    with open(pipeline_pickle_path, "wb+") as file:
        pickle.dump(stage1_result, file, protocol=pickle_protocol)

total_pickle_path = os.path.join(INT_S1_DIR, "total_data.pkl")
with open(total_pickle_path, "wb+") as file:
    pickle.dump(stage1_results, file, protocol=pickle_protocol)

In [56]:
# Export the patch of every kept mask as a PNG named patch_<row index>_<mask number>.png
# inside the pipeline's "patches" sub-directory.
for stage1_name, stage1_result in stage1_results.items():
    patches_dir = os.path.join(INT_S1_DIR, stage1_name, "patches")
    os.makedirs(patches_dir, exist_ok=True)
    for index, data in stage1_result.items():
        for i, leaf_mask in enumerate(data['masks']):
            patch_path = os.path.join(patches_dir, f"patch_{index}_{i}.png")
            # cv2.imwrite reports failure through its boolean return value
            # instead of raising, so check it to avoid silently missing files.
            if not cv2.imwrite(patch_path, leaf_mask['patch']):
                raise OSError(f"Failed to write patch image: {patch_path}")