In [None]:
import torch
from transformers import SamProcessor, SamModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
import os

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torch
import os
import numpy as np
# `SamProcessor` comes from `transformers` (imported in the first cell); the
# `segment_anything` package does not provide it, so it is not imported here.

class SegmentationDataset(Dataset):
    """Pairs each ``.png`` image in ``image_dir`` with the same-named mask in ``mask_dir``.

    Images without a matching mask are reported once, at construction, and
    excluded, so every index maps to a real (image, mask) pair.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing masks named exactly like their images.
        processor: callable used as ``processor(images=..., return_tensors="pt")``
            (e.g. ``SamProcessor``) that returns a mapping of tensors.
    """

    def __init__(self, image_dir, mask_dir, processor):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.processor = processor
        # Sort so the index -> file mapping is stable; os.listdir order is arbitrary.
        candidates = sorted(f for f in os.listdir(image_dir) if f.lower().endswith('.png'))
        self.image_files = []
        for image_file in candidates:
            if os.path.exists(os.path.join(mask_dir, image_file)):
                self.image_files.append(image_file)
            else:
                print(f"Warning: Mask file for {image_file} not found, skipping this image.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(inputs, mask)`` for sample ``idx``.

        ``inputs`` is the processor output with its leading size-1 batch axis
        removed; ``mask`` is a ``torch.long`` tensor of grayscale pixel values
        with the mask image's (H, W) shape.
        """
        image_file = self.image_files[idx]
        with Image.open(os.path.join(self.image_dir, image_file)) as img:
            image = img.convert("RGB")
        with Image.open(os.path.join(self.mask_dir, image_file)) as mask_img:
            mask = mask_img.convert("L")
        inputs = self.processor(images=image, return_tensors="pt")
        # The processor batches this single image as batch size 1; drop that axis
        # so DataLoader collation yields (B, ...) instead of (B, 1, ...).
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        mask = torch.tensor(np.array(mask), dtype=torch.long)
        return inputs, mask

In [None]:
# Dataset locations: images and their same-named masks, per split.
DATA_ROOT = "E:/sam-vit-data-all"
train_image_dir = f"{DATA_ROOT}/converted_images"
train_mask_dir = f"{DATA_ROOT}/Exclusive-multicolor"
test_image_dir = f"{DATA_ROOT}/converted_images_test"
test_mask_dir = f"{DATA_ROOT}/Exclusive-multicolor_test"
BATCH_SIZE = 2

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, processor)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, processor)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def dice_loss(preds, targets, smooth=1e-6):
    """Soft Dice loss, ``1 - Dice(preds, targets)``.

    Both tensors are flattened, so any shapes with the same number of elements
    work. ``preds`` are presumably per-pixel foreground probabilities in [0, 1]
    and ``targets`` a binary mask.

    Args:
        preds: predicted scores.
        targets: ground-truth values, same number of elements as ``preds``.
        smooth: small constant that keeps the ratio defined when both are empty.

    Returns:
        Scalar tensor; 0 for a perfect overlap, close to 1 for none.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. after
    # a transpose/permute.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum()
    dice_coefficient = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice_coefficient

In [None]:

# Load pretrained SAM ViT-B and place it on the GPU when one is available.
# `processor` is already built from this same checkpoint in the data-loading
# cell, so it is not reloaded here.
model = SamModel.from_pretrained("facebook/sam-vit-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The trailing ';' keeps Jupyter from echoing the very long module repr.
model.to(device);

In [None]:
import torch
import torch.nn as nn

# Adam over all model parameters at a fixed learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# Per-pixel multi-class classification loss.
criterion = nn.CrossEntropyLoss()

In [None]:
def train(model, train_loader, optimizer, device, criterion=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: model called as ``model(**inputs)`` whose output has a
            ``pred_masks`` tensor of mask logits (e.g. ``SamModel``).
        train_loader: iterable of ``(inputs, masks)`` batches, where ``inputs``
            is a dict of processor tensors and ``masks`` holds integer class
            ids of shape (B, H, W).
        optimizer: optimizer over the parameters of ``model``.
        device: device every batch tensor is moved to.
        criterion: loss taking ``(logits, masks)``; defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        The average loss over the batches of the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, masks in train_loader:
        # squeeze(1) removes only a size-1 per-sample axis left by the processor
        # (a no-op if the dataset already removed it); a bare squeeze() would
        # also drop the batch axis when a batch holds a single sample.
        inputs = {key: val.squeeze(1).to(device) for key, val in inputs.items()}
        masks = masks.to(device)
        optimizer.zero_grad()
        # No torch.no_grad() here: the forward pass must be recorded for backward().
        outputs = model(**inputs)
        # Merge every axis between batch and spatial dims (for SamModel, the
        # prompt and multimask axes) into one channel axis: (B, C, h, w).
        logits = outputs.pred_masks.flatten(1, -3)
        if logits.shape[-2:] != masks.shape[-2:]:
            # NOTE(review): SAM predicts masks in its padded square input frame;
            # a plain resize ignores that padding for non-square images.
            logits = nn.functional.interpolate(
                logits, size=tuple(masks.shape[-2:]), mode="bilinear", align_corners=False
            )
        # The loss is taken on raw logits; an argmax here would not be differentiable.
        # NOTE(review): mask ids must lie in [0, C) with C = logits.shape[1];
        # masks read as grayscale images may hold larger values — confirm encoding.
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    avg_loss = total_loss / num_batches
    print(f"Training Loss: {avg_loss:.4f}")
    return avg_loss

In [None]:
def evaluate(model, test_loader, device):
    """Compute pixel accuracy of ``model`` over ``test_loader``.

    Args:
        model: model called as ``model(**inputs)`` whose output has a
            ``pred_masks`` tensor of mask logits (e.g. ``SamModel``).
        test_loader: iterable of ``(inputs, masks)`` batches, where ``masks``
            holds integer class ids of shape (B, H, W).
        device: device every batch tensor is moved to.

    Returns:
        Fraction of pixels whose predicted class equals the mask value.

    Raises:
        ValueError: if ``test_loader`` yields no pixels.
    """
    model.eval()
    total_correct = 0
    total_pixels = 0
    with torch.no_grad():
        for inputs, masks in test_loader:
            # squeeze(1) drops only a size-1 per-sample axis, never the batch axis.
            inputs = {key: val.squeeze(1).to(device) for key, val in inputs.items()}
            masks = masks.to(device)
            outputs = model(**inputs)
            # (B, ..., h, w) -> (B, C, h, w): merge prompt/multimask axes into channels.
            logits = outputs.pred_masks.flatten(1, -3)
            if logits.shape[-2:] != masks.shape[-2:]:
                # Compare at mask resolution so every labelled pixel is scored.
                logits = nn.functional.interpolate(
                    logits, size=tuple(masks.shape[-2:]), mode="bilinear", align_corners=False
                )
            preds = logits.argmax(dim=1)
            total_correct += (preds == masks).sum().item()
            total_pixels += masks.numel()
    if total_pixels == 0:
        raise ValueError("test_loader yielded no pixels; cannot compute accuracy")
    accuracy = total_correct / total_pixels
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

In [None]:
# Train for `num_epochs`, evaluating after every epoch and checkpointing every
# `checkpoint_every` epochs, plus always after the final epoch.
num_epochs = 10
checkpoint_every = 5
history = []
for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")
    train_loss = train(model, train_loader, optimizer, device)
    test_accuracy = evaluate(model, test_loader, device)
    history.append({"epoch": epoch, "train_loss": train_loss, "test_accuracy": test_accuracy})
    if epoch % checkpoint_every == 0 or epoch == num_epochs:
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

In [None]:

# Run the model on one image and turn its mask logits into a per-pixel label map.
model.eval()
image_path = "E:/image.jpg"
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt")
inputs = {key: val.to(device) for key, val in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
# SamModel returns `pred_masks` (it has no `logits` field). Merge the axes between
# batch and spatial dims into one channel axis so argmax picks a mask per pixel.
logits = outputs.pred_masks.flatten(1, -3)
# Resize to the image's (H, W); PIL's `size` is (W, H).
# NOTE(review): SAM predicts in a padded square frame; consider
# processor.post_process_masks for exact alignment on non-square images.
logits = nn.functional.interpolate(
    logits, size=image.size[::-1], mode="bilinear", align_corners=False
)
predicted_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()
import numpy as np
def colorize_mask(mask):
    """Map a 2-D array of integer class ids to an RGB image for display.

    Args:
        mask: integer array of class ids, shape (H, W) (e.g. an argmax label map).

    Returns:
        ``np.uint8`` array of shape (H, W, 3); class ``k`` is drawn with
        palette entry ``k % len(palette)``.
    """
    colors = np.array([
        [0, 0, 0],        # black
        [255, 0, 0],      # red
        [0, 255, 0],      # green
        [0, 0, 255],      # blue
        [255, 255, 0],    # yellow
        [255, 0, 255],    # magenta
        [0, 255, 255],    # cyan
        # Light colors
        # NOTE(review): several entries below repeat (e.g. [255, 182, 193],
        # [255, 240, 245], [216, 191, 216]), so those classes share a color.
        [255, 182, 193],
        [255, 228, 196],
        [255, 240, 245],
        [255, 222, 173],
        [255, 255, 224],
        [216, 191, 216],
        [255, 240, 245],
        [255, 239, 196],
        [219, 112, 147],
        [176, 224, 230],
        [255, 250, 205],
        [152, 251, 152],
        [144, 238, 144],
        [173, 216, 230],
        [255, 182, 193],
        [240, 128, 128],
        [240, 230, 140],
        [250, 250, 210],
        [253, 253, 150],
        [236, 240, 252],
        [255, 228, 225],
        [255, 218, 185],
        [245, 245, 220],
        [255, 218, 185],
        [255, 248, 220],
        [255, 228, 181],
        [240, 230, 140],
        [255, 241, 199],
        [237, 249, 255],
        [216, 191, 216],
    ], dtype=np.uint8)

    # Index the palette with the label map; wrap ids beyond the palette so an
    # unexpected class count cannot raise IndexError.
    color_mask = colors[np.asarray(mask) % len(colors)]
    return color_mask

# Show the input image next to its colorized predicted label map.
colored_mask = colorize_mask(predicted_mask)

fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
panels = ((ax_image, image, "Original Image"), (ax_mask, colored_mask, "Predicted Mask"))
for ax, content, title in panels:
    ax.imshow(content)
    ax.set_title(title)
    ax.axis("off")

plt.show()