In [1]:
import os
import cv2
import tqdm
import glob
import random
import pickle
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from datetime import datetime

In [20]:
now = datetime.now()

# NOTE(review): USE_BEFORE is not referenced in the cells shown — confirm it is still needed.
USE_BEFORE = True
# Reuse the most recent stage-1 output directory (and its cached SAM patches) instead of
# starting a fresh timestamped one.
USE_GENERATED = True

EVAL_AMOUNT = 512
DATASET_DIR = "_data/plant_pathology"
#DATASET_DIR = "_data/plantdoc_csv"
S1_OUTPUT_ROOT = "_intermediate/stage1_pathology_masked"
INT_S1_DIR = f"{S1_OUTPUT_ROOT}/{now.strftime('%Y_%m_%d_%H_%M_%S')}"
if USE_GENERATED:
    # Timestamped names sort chronologically, so the last entry is the newest run. When no
    # previous run exists, keep the fresh directory (indexing [-1] on [] used to raise).
    previous_runs = sorted(p for p in glob.glob(f"{S1_OUTPUT_ROOT}/*") if os.path.isdir(p))
    if previous_runs:
        INT_S1_DIR = previous_runs[-1]
PATCHES_DIR = os.path.join(INT_S1_DIR, "patches")

In [3]:
# Create the stage-1 run directory and its patch cache (no-op when they already exist).
for required_dir in (INT_S1_DIR, PATCHES_DIR):
    os.makedirs(required_dir, exist_ok=True)

In [4]:
# Per-image annotations; later cells read the "image_id" and "healthy" columns.
train_csv_path = os.path.join(DATASET_DIR, "data.csv")
train_data = pd.read_csv(train_csv_path)

In [5]:
# Fixed seed so the evaluated subset is identical on every run (and cached patches are reused).
SAMPLE_SEED = 42
all_indices = list(train_data.index)
if EVAL_AMOUNT > len(all_indices):
    indices = all_indices
else:
    indices = random.Random(SAMPLE_SEED).sample(all_indices, k=EVAL_AMOUNT)

In [6]:
def get_patches(masks, image, apply_mask=False, padding=0):
    """Crop one patch per mask from ``image`` using each mask's bounding box.

    Args:
        masks: iterable of mask dicts. ``mask["bbox"]`` is read as ``[col, row, width, height]``
            (element 1/3 index rows, element 0/2 index columns). When ``apply_mask`` is set,
            ``mask["segmentation"]`` must be a 2-D array with the image's height and width.
        image: array indexed as ``image[row, col, channel]``.
        apply_mask: if True, pixels outside the mask's segmentation are zeroed in the crop.
        padding: extra pixels added on every side of the box, clamped to the image bounds.

    Returns:
        List of crops in mask order. Crops with a zero-size dimension are skipped, so the list
        can be shorter than ``masks`` and is then NOT index-aligned with it.
    """
    height, width = image.shape[:2]
    result = []

    for mask in masks:
        col, row, box_w, box_h = mask["bbox"][:4]
        # Slice ends are exclusive, so clamp to height/width (not height-1/width-1); the old
        # clamp dropped the last row/column whenever a box touched the bottom/right edge.
        row0 = int(max(row - padding, 0))
        row1 = int(min(row + box_h + padding, height))
        col0 = int(max(col - padding, 0))
        col1 = int(min(col + box_w + padding, width))

        patch = image[row0:row1, col0:col1]
        if apply_mask:
            # Same values as masking the full image and then cropping, but only the crop
            # is multiplied.
            patch = patch * mask["segmentation"][row0:row1, col0:col1, np.newaxis]

        if 0 in patch.shape:
            continue
        result.append(patch)

    return result

In [7]:
def get_patches_file(image_id):
    """Load every ``*.png`` saved under ``PATCHES_DIR/<image_id>/`` with ``cv2.imread``.

    Files are read in sorted filename order. ``glob`` alone returns them in arbitrary order,
    which made the patch order differ between runs. The sort is lexicographic, so
    ``patch10.png`` sorts before ``patch2.png``.
    """
    pattern = os.path.join(PATCHES_DIR, image_id, "*.png")
    return [cv2.imread(path) for path in sorted(glob.glob(pattern))]

In [8]:
def get_masks_file(image_id):
    """Return the object pickled at ``PATCHES_DIR/<image_id>/data.pkl``.

    The patch-generation loop writes the SAM mask list there. This uses ``pickle.load``, so
    only point it at files this notebook produced (unpickling untrusted data can run code).
    """
    pkl_path = os.path.join(PATCHES_DIR, image_id, "data.pkl")
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)

In [9]:
# Use the GPU when one is available; fall back to CPU instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

# Build the ViT-H SAM model from its checkpoint and move it to `device`
# (nn.Module.to returns the module itself).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

def sam_generate_mask(image):
    """Run SAM's automatic mask generator on ``image`` and return its masks.

    A new SamAutomaticMaskGenerator is wrapped around the global ``sam`` on every call.
    """
    return SamAutomaticMaskGenerator(sam).generate(image)

  state_dict = torch.load(f)


In [11]:
class BinaryResnetClassifier(nn.Module):
    """ImageNet-pretrained ResNet-50 whose final ``fc`` is replaced by a ``num_classes``-way
    linear head with an Xavier-normal initialised weight.

    ``forward`` returns raw logits. pred_resnet in this notebook applies ``torch.sigmoid``
    to the output of the loaded classifier.
    """

    def __init__(self, num_classes=1):
        super(BinaryResnetClassifier, self).__init__()
        # resnet50 / ResNet50_Weights are not imported anywhere else in this notebook, so
        # building the class raised NameError. Import them here (torchvision is already used).
        from torchvision.models import resnet50, ResNet50_Weights
        # `weights` is passed by keyword; positional use is deprecated since torchvision 0.13.
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Swap the ImageNet head for a fresh linear layer with num_classes outputs.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        nn.init.xavier_normal_(self.resnet.fc.weight)

    def forward(self, x):
        # No sigmoid here: callers apply it to the returned logits.
        return self.resnet(x)

In [12]:
class BinaryInceptionClassifier(nn.Module):
    """Pretrained Inception-v3 (torch.hub, pytorch/vision v0.10.0) whose final ``fc`` is
    replaced by a ``num_classes``-way linear head with an Xavier-normal initialised weight.

    ``forward`` returns whatever the wrapped Inception returns; no sigmoid is applied.
    pred_inception in this notebook applies ``torch.sigmoid`` itself.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
        head = nn.Linear(backbone.fc.in_features, num_classes)
        nn.init.xavier_normal_(head.weight)
        backbone.fc = head
        self.inception = backbone

    def forward(self, x):
        return self.inception(x)

In [13]:
class Encoder(nn.Module):
    """Five stride-2 3x3 convolutions (channels doubling from ``out_channels`` to
    ``16*out_channels``), then a linear projection to ``latent_dim``.

    The ``16*out_channels*7*7`` input size of the linear layer means a 224x224 input is
    expected (224 / 2**5 = 7). The same ``act_fn`` module follows every layer.
    """

    def __init__(self, in_channels=3, out_channels=16, latent_dim=200, act_fn=nn.ReLU()):
        super().__init__()
        # Channel widths: in -> out -> 2*out -> 4*out -> 8*out -> 16*out.
        widths = [in_channels] + [out_channels * 2 ** k for k in range(5)]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1, stride=2), act_fn]  # halves H and W
        layers += [nn.Flatten(), nn.Linear(widths[-1] * 7 * 7, latent_dim), act_fn]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    """Mirror of Encoder: a linear layer to a ``16*out_channels x 7 x 7`` map, then five
    stride-2 transposed convolutions back up to ``in_channels x 224 x 224``.

    ``act_fn`` follows every layer except the final transposed convolution.
    """

    def __init__(self, in_channels=3, out_channels=16, latent_dim=200, act_fn=nn.ReLU()):
        super().__init__()
        self.out_channels = out_channels
        self.linear = nn.Sequential(nn.Linear(latent_dim, 16 * out_channels * 7 * 7), act_fn)
        # Channel widths: 16*out -> 8*out -> 4*out -> 2*out -> out -> in.
        widths = [out_channels * 2 ** k for k in range(4, -1, -1)] + [in_channels]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # Each layer doubles H and W (7 -> 14 -> 28 -> 56 -> 112 -> 224).
            layers += [nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1), act_fn]
        layers.pop()  # no activation after the output layer
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(self.linear(x).view(-1, 16 * self.out_channels, 7, 7))

#  defining autoencoder
class BigAutoencoder(nn.Module):
    """Encoder/decoder pair: ``forward(x)`` returns ``decoder(encoder(x))``.

    Args:
        encoder: module mapping inputs to latents; a new ``Encoder()`` when omitted.
        decoder: module mapping latents back to inputs; a new ``Decoder()`` when omitted.
    """

    def __init__(self, encoder=None, decoder=None):
        super().__init__()
        # The defaults used to be `Encoder()` / `Decoder()` in the signature. Those are built
        # once at class-definition time, so every BigAutoencoder() shared the same weights.
        self.encoder = encoder if encoder is not None else Encoder()
        self.decoder = decoder if decoder is not None else Decoder()

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [14]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from ultralytics import YOLO
import torchvision.transforms.v2 as transforms

# NOTE(review): `yolo` / `yolo_syn` are not referenced by later cells (they load the same
# weights as `model` / `model_seg`) — confirm they are still needed.
yolo = YOLO("../leaf_segmentation/out/yolo_urban_street/train/weights/best.pt")
yolo_syn = YOLO("../leaf_segmentation/out/yolo_synthetic/train4/weights/best.pt")

# The classifier checkpoints are whole pickled nn.Modules (the result is used directly as a
# model). Loading them needs weights_only=False, because torch>=2.6 defaults to True. This
# unpickles arbitrary objects, so only load trusted checkpoints. map_location lets a
# GPU-saved checkpoint load onto whatever `device` is.
resnet = torch.load("../leaf_segmentation/out/leaf_classifier/resnet_masked_cos/resnet_best.pth",
                    map_location=device, weights_only=False)
resnet = resnet.to(device)
resnet.eval()

resnet_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize((224, 224)),
    transforms.ToDtype(torch.float32, scale=True),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  
])

inception = torch.load("../leaf_segmentation/out/leaf_classifier/inception/inception_best.pth",
                       map_location=device, weights_only=False)
inception = inception.to(device)
inception.eval()

inception_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize((320, 320)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  
])

# Reuse the ViT-H SAM model (`sam`) and `device` set up in earlier cells instead of loading
# a second copy of the checkpoint.
mask_generator = SamAutomaticMaskGenerator(sam)

In [22]:
def predict_sam(img, pred, image_id, threshold):
    """Classify the cached SAM masks of ``image_id`` with a patch classifier.

    Args:
        img: not used here; accepted so stage-1 functions share the
            ``(image, image_id)`` call shape.
        pred: callable mapping a patch to ``(include, leaf_probability)``.
        image_id: id whose masks are loaded with ``get_masks_file``.
        threshold: minimum ``leaf_probability`` for a mask to be kept in ``mask_result``.

    Returns:
        ``(result, mask_result)``:
        - ``result``: boolean segmentations of the masks ``pred`` marked as included.
        - ``mask_result``: the mask dicts whose ``leaf_probability`` is at least ``threshold``.

        Masks with an empty or missing patch are skipped. Every classified mask gets
        ``mask['leaf_probability']`` set.
    """
    masks = get_masks_file(image_id)
    result = []
    mask_result = []

    # Collect the kept masks instead of deleting from `masks` while iterating. `del masks[i]`
    # inside enumerate() shifted the next mask into slot i, so it was never classified.
    for mask in masks:
        patch = mask.get('patch')
        if patch is None or 0 in patch.shape:
            continue
        include, prob = pred(patch)
        if include:
            result.append(mask["segmentation"] > 0)

        mask['leaf_probability'] = prob
        if prob >= threshold:
            mask_result.append(mask)
    return result, mask_result
    

def pred_resnet(x):
    """Score one patch with the ResNet leaf classifier.

    Returns ``(include, leaf_probability)``. ``leaf_probability`` is ``1 - sigmoid(logit)``,
    and ``include`` is True when the sigmoid score is below 0.025.
    """
    batch = resnet_transform(x).to(device).unsqueeze(0)
    with torch.no_grad():
        score = torch.sigmoid(resnet(batch)).item()
    return score < 0.025, 1 - score

def s1_sam_resnet(img, image_id):
    """Stage-1 extractor: cached SAM masks filtered by the ResNet classifier (keep at >= 0.9)."""
    return predict_sam(img, pred_resnet, image_id, threshold=0.9)


def pred_inception(x):
    """Score one patch with the Inception leaf classifier.

    Returns ``(include, leaf_probability)``. ``leaf_probability`` is ``1 - sigmoid(logit)``,
    and ``include`` is True when the sigmoid score is below 0.01.
    """
    batch = inception_transform(x).to(device).unsqueeze(0)
    with torch.no_grad():
        score = torch.sigmoid(inception(batch)).item()
    return score < 0.01, 1 - score

def s1_sam_inception(img, image_id):
    """Stage-1 extractor: cached SAM masks filtered by the Inception classifier (keep at >= 0.95)."""
    return predict_sam(img, pred_inception, image_id, threshold=0.95)

In [16]:
from ultralytics import YOLO, checks
model = YOLO("../leaf_segmentation/out/yolo_urban_street/train/weights/best.pt")

def s1_sam_yolo(image, image_id):
    """Score each cached SAM mask of ``image_id`` with the YOLO leaf model ``model``.

    Sets ``mask['leaf_probability']`` on every mask and returns ALL masks (unfiltered).
    ``image`` is unused; each mask's cached ``'patch'`` is what gets scored.
    """
    masks = get_masks_file(image_id)
    for mask in masks:
        patch = mask.get('patch')
        if patch is None or 0 in patch.shape:
            # Nothing to run YOLO on (predict_sam skips such masks too); score it 0 instead
            # of passing an empty array or raising KeyError.
            mask['leaf_probability'] = 0
            continue
        result = model.predict(patch, verbose=False)
        confidences = result[0].boxes.conf
        # Exactly one detection -> its confidence; zero or several detections -> 0.
        # NOTE(review): the detected class is not checked — confirm that is intended.
        mask['leaf_probability'] = confidences.item() if len(confidences) == 1 else 0
    # NOTE(review): the original also built a list of masks with leaf_probability > .8 but
    # returned the unfiltered list; unfiltered is kept here — filter downstream if needed.
    return masks

In [17]:
model_seg = YOLO("../leaf_segmentation/out/yolo_synthetic/train4/weights/best.pt")

def s1_yolo(image, image_id):
    """Segment leaves in ``image`` directly with the YOLO segmentation model ``model_seg``.

    Args:
        image: ``H x W x C`` image array.
        image_id: unused; kept for the common stage-1 ``(image, image_id)`` signature.

    Returns:
        List of dicts, one per non-empty detected instance (grouped by class id). Each dict has:
        - ``"segmentation"``: boolean ``H x W`` mask.
        - ``"leaf_probability"``: that detection's box confidence.
        - ``"patch"``: the masked image cropped to the mask's bounding box.

        An empty list is returned when nothing is detected.
    """
    # Use the model loaded for this cell. The original called `model` (the urban-street model
    # used by s1_sam_yolo), leaving `model_seg` unused.
    result = model_seg.predict(image, verbose=False, retina_masks=True)[0]
    if result.masks is None:
        return []

    height, width = image.shape[:2]
    masks = result.masks.data
    if tuple(masks.shape[-2:]) != (height, width):
        masks = torch.nn.functional.interpolate(
            masks.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False
        ).squeeze(0)
    classes = result.boxes.cls
    confidences = result.boxes.conf.cpu().numpy()

    masks_result = []
    for class_id in range(len(result.names)):
        # Iterate global detection indices so each mask gets its own confidence. The original
        # indexed confidences by the position inside the per-class subset.
        for det_idx in torch.where(classes == class_id)[0].tolist():
            # Binarise to a single-channel 0/1 uint8 mask for cv2.findNonZero and the multiply.
            seg = (masks[det_idx] > 0.5).cpu().numpy().astype(np.uint8)
            coords = cv2.findNonZero(seg)
            if coords is None:  # empty mask: no bounding box to crop
                continue
            x, y, w, h = cv2.boundingRect(coords)
            patch = image[y:y+h, x:x+w] * seg[y:y+h, x:x+w, np.newaxis]
            masks_result.append({
                "segmentation": seg.astype(bool),
                "leaf_probability": confidences[det_idx],
                "patch": patch,
            })
    return masks_result

In [26]:
for index in tqdm.tqdm(indices, desc="Generating patches"):
    img_id = train_data.loc[index]["image_id"]
    image_dir = os.path.join(PATCHES_DIR, img_id)
    masks_path = os.path.join(image_dir, "data.pkl")
    # Cache check: skip images whose masks are already pickled. Checking for the pickle rather
    # than the directory means a run interrupted between makedirs and dump gets redone.
    if os.path.exists(masks_path):
        continue
    os.makedirs(image_dir, exist_ok=True)

    img = cv2.imread(os.path.join(DATASET_DIR, "images", img_id + ".jpg"))
    if img is None:
        print(f"Could not read image for {img_id}; skipping")
        continue
    img = cv2.resize(img, (640, 640))

    # SAM's mask generator expects RGB, while cv2 loads BGR. Patches are still cut from the BGR
    # image, so the colour order the leaf classifiers receive is unchanged.
    masks = sam_generate_mask(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # get_patches drops zero-size crops, so zipping its output with `masks` could pair masks
    # with the wrong patch and leave trailing masks without one. Crop mask by mask instead.
    for mask in masks:
        crops = get_patches([mask], img, apply_mask=True)
        mask['patch'] = crops[0] if crops else img[:0, :0]  # empty crop keeps dtype/channels

    with open(masks_path, 'wb') as file:
        pickle.dump(masks, file, protocol=pickle.HIGHEST_PROTOCOL)

Generating patches: 100%|██████████| 512/512 [11:53<00:00,  1.39s/it]


In [27]:
# Stage-1 leaf extractors to evaluate, keyed by display name; insertion order is run order.
# Disabled: "YOLOv8" (s1_yolo) and "Mask R-CNN" (s1_mask_rcnn).
stage1_dict = dict([
    ("SAM + ResNet", s1_sam_resnet),
    ("SAM + YOLOv8", s1_sam_yolo),
    ("SAM + Inception", s1_sam_inception),
])

In [None]:
# stage1_results[extractor_name][dataset_index] = {'healthy': ground-truth flag, 'masks': output}
stage1_results = {}

for stage1_name, stage1_model in stage1_dict.items():
    print(f"Running model {stage1_name}")
    per_image = stage1_results[stage1_name] = {}
    for index in tqdm.tqdm(indices, desc=stage1_name):
        row = train_data.loc[index]
        image_id = row["image_id"]
        per_image[index] = {'healthy': bool(row["healthy"]), 'masks': []}
        img = cv2.imread(os.path.join(DATASET_DIR, "images", image_id + ".jpg"))
        with torch.no_grad():
            per_image[index]['masks'] = stage1_model(img, image_id)
        # Release cached GPU memory between images.
        torch.cuda.empty_cache()

Running model SAM + ResNet


SAM + ResNet: 100%|██████████| 512/512 [02:07<00:00,  4.02it/s]


Running model SAM + YOLOv8


SAM + YOLOv8:  62%|██████▏   | 319/512 [02:31<01:13,  2.62it/s]

In [17]:
# Persist every extractor's results in a single pickle.
total_data_path = os.path.join(INT_S1_DIR, "total_data.pkl")
with open(total_data_path, "wb+") as out_file:
    pickle.dump(stage1_results, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [18]:
# One pickle per extractor, written to <INT_S1_DIR>/<extractor name>/data.pkl.
for extractor_name, extractor_result in stage1_results.items():
    extractor_dir = os.path.join(INT_S1_DIR, extractor_name)
    os.makedirs(extractor_dir, exist_ok=True)
    with open(os.path.join(extractor_dir, "data.pkl"), "wb+") as out_file:
        pickle.dump(extractor_result, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
# Write each extractor's leaf patches as PNGs: <INT_S1_DIR>/<extractor>/patches/<index>/patch_<i>.png
for stage1_name, stage1_result in stage1_results.items():
    patches_dir = os.path.join(INT_S1_DIR, stage1_name, "patches")
    for index, data in stage1_result.items():
        image_dir = os.path.join(patches_dir, str(index))
        os.makedirs(image_dir, exist_ok=True)
        leaf_masks = data['masks']
        # predict_sam-based extractors return a (segmentations, mask_dicts) tuple rather than
        # a list of mask dicts. The mask dicts are the second element.
        if isinstance(leaf_masks, tuple):
            leaf_masks = leaf_masks[1]
        for i, leaf_mask in enumerate(leaf_masks):
            patch = leaf_mask.get('patch')
            # Skip missing or zero-size crops (cv2.imwrite rejects empty images).
            if patch is None or 0 in patch.shape:
                continue
            cv2.imwrite(os.path.join(image_dir, f"patch_{i}.png"), patch)