In [6]:
import os, argparse
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import sys
sys.path.append("..")

import segmentation_models_pytorch as smp
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from dataset import PhenotypeDataset, train_augmentation, validation_augmentation, get_preprocessing

In [7]:
def load_sam(sam_checkpoint, model_type):
	#for now
	device = "cpu"

	sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
	sam.to(device=device)

	return SamAutomaticMaskGenerator(sam)

In [8]:
model_path = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"
root_dir = "/Users/amankumar/Library/CloudStorage/OneDrive-UniversityofWaterloo/uWaterloo Courses/Winter '23/CS 679/Project/leaf_counting/Plant"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

#Loading dataset

preprocess_input = smp.encoders.get_preprocessing_fn('resnet18', pretrained='imagenet')
mask_generator = load_sam(model_path, model_type)

In [9]:
#!rm "/Users/amankumar/Library/CloudStorage/OneDrive-UniversityofWaterloo/uWaterloo Courses/Winter '23/CS 679/Project/leaf_counting/Plant/train/.DS_Store"
phenotype_dataset_train = PhenotypeDataset(root_dir=os.path.join(root_dir, 'train'), transform=train_augmentation())#, preprocessing=get_preprocessing(preprocessing_fn))
phenotype_dataset_test = PhenotypeDataset(root_dir=os.path.join(root_dir, 'test'), transform=validation_augmentation())#, preprocessing=get_preprocessing(preprocessing_fn))

batch_size = 4
train_dataloader = DataLoader(phenotype_dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(phenotype_dataset_test, batch_size=32, shuffle=False)

In [None]:
save_path = "Processed_Plant/train"
j = 0
for i, batch in tqdm(enumerate(train_dataloader, 0)):
    for img in batch["image"]:
        img = img.numpy()
        masks = mask_generator.generate(img)
        sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=False)
        
        total_sum = sum([i["area"] for i in sorted_masks])
        if sorted_masks[-1]["area"]/total_sum > 0.05:
            del_1 = sorted_masks[-1]["area"]
            sorted_masks.pop(-1)
        if sorted_masks[-1]["area"]/total_sum > 0.05:
            del_2 = sorted_masks[-1]["area"]
            sorted_masks.pop(-1)
        print("{}: Deleted sizes: {}, {} | Last element now = {}".format(i, del_1, del_2, sorted_masks[-1]["area"]))
        superimposed = np.zeros((img.shape[0], img.shape[1], 3))          
        for mask in sorted_masks:
            m = mask['segmentation']
            m = np.repeat(m[:, :, np.newaxis], 3, axis=2)*1.0
            #print(mask["area"])
            color_mask = np.random.random((1, 3)).tolist()[0]
            for i in range(3):
                m[:,:,i] = m[:,:,i]*color_mask[i]
            superimposed = superimposed + m
        
        im = Image.fromarray((superimposed * 255).astype(np.uint8))
        im.save(os.path.join(save_path,"plant_{}.png".format(j)))
        j = j + 1



0it [00:00, ?it/s]

0: Deleted sizes: 78418, 23773 | Last element now = 23220
2: Deleted sizes: 83850, 9021 | Last element now = 659
2: Deleted sizes: 71057, 12371 | Last element now = 174


1it [02:55, 175.50s/it]

2: Deleted sizes: 98449, 12371 | Last element now = 3676
1: Deleted sizes: 86811, 14605 | Last element now = 944
2: Deleted sizes: 101981, 14605 | Last element now = 2367
2: Deleted sizes: 101303, 14605 | Last element now = 2828


2it [05:51, 176.08s/it]

2: Deleted sizes: 13103, 8856 | Last element now = 7761
