# Train SAM on Google Cloud

## Set up Modelling Environment

In [1]:
!pip install rasterio transformers tqdm monai gcsfs datasets segment_anything git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-g7af9qoy
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-g7af9qoy
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting rasterio
  Obtaining dependency information for rasterio from https://files.pythonhosted.org/packages/5e/19/4617aaaf3166b06c520db50de38108bf069e63512712a7edda6710f4687b/rasterio-1.3.8.post2-cp310-cp310-manylinux2014_x86_64.whl.metadata
  Using cached rasterio-1.3.8.post2-cp310-cp310-manylinux2014_x86_64.whl.metadata (14 kB)
Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/1a/d1/3bba59606141ae808017f6fde91453882f931957f125009417b87a

In [2]:
# import os
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb=30'  # Set it to an appropriate value
# import torch

In [3]:
import sys
import os
import io
import argparse
import numpy as np
import rasterio
from datasets import Dataset
from PIL import Image
from torch.utils.data import Dataset as TorchDataset, DataLoader, random_split
from transformers import SamProcessor
from transformers import SamModel
import torch
from segment_anything import sam_model_registry
from tqdm import tqdm
from statistics import mean
import time
import monai
import gcsfs

## Define Various Functions

In [4]:
class args:
    """Static run configuration, accessed as class attributes (e.g. ``args.batch_size``)."""

    # --- Optimisation hyperparameters ---
    learning_rate = 1e-5
    weight_decay = 0
    batch_size = 4
    num_epochs = 10

    # --- Data / model locations (Google Cloud Storage) ---
    # For a local run, point these at 'satellite_images/train/stack/' and 'models/' instead.
    input_dir = 'gs://meter-sam/train/stack/'
    output_dir = 'gs://meter-sam/model/'

In [5]:
# def parse_arguments():
#     """Parse command-line arguments."""
#     parser = argparse.ArgumentParser(description="Train the SAM on Vertex AI.")

#     # Directories and File Paths
#     parser.add_argument('--input-dir', type=str, default='gs://meter-sam/stack/', help='Input data directory path.')
#     parser.add_argument('--output-dir', type=str, default='gs://meter-sam/model/', help='Output data directory path.')

#     # Hyperparameters
#     parser.add_argument('--num-epochs', type=int, default=5, help='Number of training epochs.')
#     parser.add_argument('--batch-size', type=int, default=2, help='Batch size for training.')
#     parser.add_argument('--learning-rate', type=float, default=1e-5, help='Learning rate for optimizer.')
#     parser.add_argument('--weight-decay', type=float, default=0, help='Weight decay for optimizer.')
    
#     return parser.parse_args()

In [6]:
def read_tiff(file_path):
    """Open ``file_path`` with rasterio and return everything ``read()`` yields.

    Any exception raised while opening or reading is reported on stdout and
    swallowed; in that case the function returns ``None`` instead of raising.
    """
    try:
        with rasterio.open(file_path) as dataset:
            contents = dataset.read()
    except Exception as err:
        # Best-effort read: log the problem and signal failure with None.
        print(f"Error reading file: {file_path}. Error: {err}")
        return None
    else:
        return contents

In [7]:
def get_bounding_box(ground_truth_map):
    """Compute a randomly padded bounding box around the foreground of a mask.

    Foreground is every pixel with a value greater than 0. The tight box around
    those pixels is grown on each side by an independent random amount drawn
    from ``np.random.randint(0, 20)`` and then clipped to the image extent.

    Args:
        ground_truth_map: 2-D array; pixels > 0 are treated as foreground.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain Python ints. If the mask has
        no foreground pixels, the full-image box ``[0, 0, W, H]`` is returned.
    """
    H, W = ground_truth_map.shape
    y_indices, x_indices = np.where(ground_truth_map > 0)

    # An all-background mask has no extent to measure: np.min/np.max would raise
    # on the empty index arrays, so fall back to a box covering the whole image.
    if x_indices.size == 0:
        return [0, 0, W, H]

    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Jitter each edge outward. The draw order (x_min, x_max, y_min, y_max) is kept
    # fixed so results are reproducible under a given NumPy seed.
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    # Normalise to built-in ints so the result is a uniform, serialisable list.
    return [int(x_min), int(y_min), int(x_max), int(y_max)]

In [8]:
class SAMDataset(TorchDataset):
    """
    Torch dataset that pairs each image with a box prompt and its mask.

    Wraps an indexable ``dataset`` whose items expose ``"image"`` and ``"label"``
    entries, and a ``processor`` callable used to encode the image together
    with the box prompt derived from the label.
    """
    def __init__(self, dataset, processor):
        self.processor = processor
        self.dataset = dataset

    def __len__(self):
        # One sample per entry of the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        mask = np.array(sample["label"])

        # The box prompt is derived from the mask itself (with random padding).
        box = get_bounding_box(mask)
        encoded = self.processor(sample["image"], input_boxes=[[box]], return_tensors="pt")

        # Squeeze dimension 0 of every processor output before returning it.
        example = {}
        for key, tensor in encoded.items():
            example[key] = tensor.squeeze(0)
        example["ground_truth_mask"] = mask
        return example

In [9]:
def prepare_datasets(args):
    """Prepare datasets for training.

    Reads ``images.tif`` and ``masks.tif`` from ``args.input_dir`` (the directory
    string is concatenated directly with the file name, so it must end with a
    path separator), turns every entry along the first axis of each array into
    a PIL image, and wraps the result in a ``SAMDataset``.

    Args:
        args: Object exposing an ``input_dir`` string attribute.

    Returns:
        Tuple ``(train_dataset, processor)``.

    Raises:
        RuntimeError: If either TIFF stack could not be read.
    """
    images = read_tiff(args.input_dir + 'images.tif')
    masks = read_tiff(args.input_dir + 'masks.tif')

    # read_tiff signals failure by returning None; fail here with an actionable
    # message instead of an opaque "'NoneType' object is not iterable" below.
    unreadable = [
        name
        for name, data in (('images.tif', images), ('masks.tif', masks))
        if data is None
    ]
    if unreadable:
        raise RuntimeError(
            f"Could not read {', '.join(unreadable)} from input_dir={args.input_dir!r}"
        )

    # Convert the NumPy arrays to Pillow images and store them in a dictionary
    dataset_dict = {
        "image": [Image.fromarray(img) for img in images],
        "label": [Image.fromarray(mask) for mask in masks],
    }
    # Create the dataset using the datasets.Dataset class
    dataset = Dataset.from_dict(dataset_dict)
    # Initialize the processor
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    # Create an instance of the SAMDataset
    train_dataset = SAMDataset(dataset=dataset, processor=processor)
    return train_dataset, processor

In [10]:
def load_model():
    """Load the pretrained "facebook/sam-vit-base" model with frozen encoders.

    Every parameter whose name starts with ``vision_encoder`` or
    ``prompt_encoder`` has gradient tracking switched off; all other
    parameters are left as loaded.
    """
    frozen_prefixes = ("vision_encoder", "prompt_encoder")
    sam = SamModel.from_pretrained("facebook/sam-vit-base")
    for param_name, weight in sam.named_parameters():
        if param_name.startswith(frozen_prefixes):
            weight.requires_grad_(False)
    return sam

In [11]:
def train_sam(train_dataset, model, processor, args):
    """Run the training loop and return the trained model.

    Only ``model.mask_decoder`` parameters are handed to the Adam optimizer.
    The whole of ``train_dataset`` is used for training; no validation split
    is made. Per-epoch mean loss and wall-clock time are printed, followed by
    the total training time.

    Args:
        train_dataset: Dataset whose batches provide ``pixel_values``,
            ``input_boxes`` and ``ground_truth_mask`` entries.
        model: Model to train; moved to CUDA when available, otherwise CPU.
        processor: Not used inside this function; kept in the signature.
        args: Object exposing ``batch_size``, ``learning_rate``,
            ``weight_decay`` and ``num_epochs``.

    Returns:
        The same ``model`` object after training.
    """
    loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    optimizer = torch.optim.Adam(
        model.mask_decoder.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    criterion = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

    # Pick the training device and move the model onto it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    run_started = time.time()
    model.train()
    for epoch in range(args.num_epochs):
        epoch_started = time.time()
        batch_losses = []
        for batch in tqdm(loader):
            # Forward pass: one mask per box prompt.
            pixel_values = batch["pixel_values"].to(device)
            input_boxes = batch["input_boxes"].to(device)
            outputs = model(
                pixel_values=pixel_values,
                input_boxes=input_boxes,
                multimask_output=False,
            )

            # Loss between predicted masks and the ground truth.
            # NOTE(review): assumes the masks' spatial size matches pred_masks — confirm.
            predictions = outputs.pred_masks.squeeze(1)
            targets = batch["ground_truth_mask"].float().to(device)
            loss = criterion(predictions, targets.unsqueeze(1))

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        epoch_minutes = (time.time() - epoch_started) / 60
        print(f'EPOCH: {epoch}')
        print(f'Mean loss: {mean(batch_losses)}')
        print(f'Time taken for the epoch: {epoch_minutes:.2f} minutes\n')

    # Wall-clock time for the whole run, in minutes.
    total_minutes = (time.time() - run_started) / 60
    print(f'Total training time: {total_minutes:.2f} minutes')
    return model

In [12]:
def save_model(model, output_dir, project='imposing-mind-398223', filename="model_base.pth"):
    """Save the trained model's ``state_dict`` to Google Cloud Storage.

    Args:
        model: Model whose ``state_dict()`` is serialised with ``torch.save``.
        output_dir: Destination prefix. It is concatenated directly with
            ``filename``, so it must end with a path separator.
        project: Google Cloud project passed to ``gcsfs.GCSFileSystem``.
            Defaults to the project this notebook was originally run in.
        filename: Name of the checkpoint file written under ``output_dir``.
    """
    fs = gcsfs.GCSFileSystem(project=project)
    with fs.open(output_dir + filename, 'wb') as f:
        torch.save(model.state_dict(), f)
    # For a local output directory, torch.save(model.state_dict(), output_dir + filename)
    # can be used directly instead of going through gcsfs.

In [13]:
# def main():
#     """Main function to orchestrate model training."""
#     train_dataset, processor = prepare_datasets(args)
#     model = load_model()
#     model = train_sam(train_dataset, model, processor, args)
#     save_model(model, args.output_dir)

## Load, Train, and Save the Model

In [14]:
# Build the training dataset (and its processor) from the TIFF stacks under args.input_dir.
train_dataset, processor = prepare_datasets(args)

In [15]:
# Load the pretrained SAM model with its vision and prompt encoders frozen.
model = load_model()

In [16]:
# Release cached, currently unused GPU memory held by PyTorch before training starts.
torch.cuda.empty_cache()

In [17]:
# Run training for args.num_epochs epochs; train_sam returns the same model object it was given.
model = train_sam(train_dataset, model, processor, args)

100%|██████████| 8750/8750 [1:12:33<00:00,  2.01it/s]


EPOCH: 0
Mean loss: 0.6678856571742466
Time taken for the epoch: 72.57 minutes



100%|██████████| 8750/8750 [1:11:30<00:00,  2.04it/s]


EPOCH: 1
Mean loss: 0.647871703345435
Time taken for the epoch: 71.50 minutes



100%|██████████| 8750/8750 [1:11:35<00:00,  2.04it/s]


EPOCH: 2
Mean loss: 0.6412016136407852
Time taken for the epoch: 71.60 minutes



100%|██████████| 8750/8750 [1:11:37<00:00,  2.04it/s]


EPOCH: 3
Mean loss: 0.6364234496559416
Time taken for the epoch: 71.63 minutes



100%|██████████| 8750/8750 [1:11:34<00:00,  2.04it/s]


EPOCH: 4
Mean loss: 0.6328254467146738
Time taken for the epoch: 71.58 minutes



100%|██████████| 8750/8750 [1:11:33<00:00,  2.04it/s]


EPOCH: 5
Mean loss: 0.6293139483554022
Time taken for the epoch: 71.55 minutes



100%|██████████| 8750/8750 [1:11:28<00:00,  2.04it/s]


EPOCH: 6
Mean loss: 0.6268264428002494
Time taken for the epoch: 71.47 minutes



100%|██████████| 8750/8750 [1:11:30<00:00,  2.04it/s]


EPOCH: 7
Mean loss: 0.6238799877234867
Time taken for the epoch: 71.51 minutes



100%|██████████| 8750/8750 [1:11:27<00:00,  2.04it/s]


EPOCH: 8
Mean loss: 0.621107563996315
Time taken for the epoch: 71.46 minutes



100%|██████████| 8750/8750 [1:11:22<00:00,  2.04it/s]

EPOCH: 9
Mean loss: 0.6186118943384715
Time taken for the epoch: 71.38 minutes

Total training time: 71.38 minutes





In [18]:
# Write the trained weights (state_dict) under args.output_dir.
save_model(model, args.output_dir)

In [19]:
# if __name__ == "__main__":
#     main()