In [22]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import cv2
import supervision as sv
import numpy as np

# Import the SAM Module

In [23]:
from segment_anything import SamPredictor, sam_model_registry

# Load in the current model

In [24]:
sam = sam_model_registry['vit_h'](checkpoint='../weights/sam_vit_h_4b8939.pth')
predictor = SamPredictor(sam)

In [7]:
os.listdir(os.getcwd())

['SAM_retraining_test.ipynb']

In [10]:
os.path.split(os.getcwd())

('/Users/ryantenbarge/code/sstollunderwood/solar_potential_map', 'notebooks')

# Define dataset with masks

In [35]:
class GisTrainingDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mask_dir, transform=None):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

            img_name = os.path.join(self.root_dir, self.images[idx])
            mask_name = os.path.join(self.mask_dir, self.images[idx])

            image = Image.open(img_name)
            mask = Image.open(mask_name)

            if self.transform:
                image = self.transform(image)
                mask = self.transform(mask)

            return image, mask

# Define Transformers (like an image preprocessing pipeline)

In [32]:
transform = transforms.Compose([
    transforms.Resize((286, 286)), # this may be unnecessary, but resizes images to 572x572 pixels
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Need to replace with mean and STD of all our org photos
])

# Load custom dataset

In [33]:
import IPython
from pathlib import Path
import sys
NBK_DIR = IPython.extract_module_locals()[1]["_dh"][0]

In [36]:
dataset = GisTrainingDataset(root_dir=NBK_DIR.parent/'data_for_ml'/'original',
                             mask_dir=NBK_DIR.parent/'data_for_ml'/'masked')

In [28]:
# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [29]:
# Finetuning the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor.sam.to(device)
predictor.sam.train()

AttributeError: 'SamPredictor' object has no attribute 'sam'