In [1]:
import os
import random
import json
import numpy as np
from PIL import Image
from pycocotools import mask as coco_mask
from transformers import SamModel, SamProcessor
from raffm import RaFFM
import torch
import torch.nn.functional as F

# Initialize the original SAM model and processor
original_model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# RaFFM configuration and submodel initialization
elastic_config = {
    "atten_out_space": [1280],
    "inter_hidden_space": [2048],
    "residual_hidden_space": [2048],
}
raffm_model = RaFFM(original_model, elastic_config=elastic_config)
submodel, params, config = raffm_model.random_resource_aware_model()
submodel = submodel.to("cuda")  # Move submodel to GPU

models = {"submodel": submodel, "original_model": original_model}

def get_image_info(dataset_directory, num_images=1):
    image_mask_pairs = []
    for filename in os.listdir(dataset_directory):
        if filename.endswith(".jpg"):
            image_path = os.path.join(dataset_directory, filename)
            mask_filename = filename.replace(".jpg", ".json")
            mask_path = os.path.join(dataset_directory, mask_filename)
            if os.path.exists(mask_path):
                image_mask_pairs.append((image_path, mask_path))
    selected_pairs = random.sample(image_mask_pairs, min(num_images, len(image_mask_pairs)))
    return selected_pairs

def get_ground_truth_masks(mask_path):
    binary_masks = []
    with open(mask_path, 'r') as json_file:
        mask_data = json.load(json_file)
    for annotation in mask_data['annotations']:
        rle_mask = annotation['segmentation']
        binary_mask = coco_mask.decode(rle_mask)
        binary_masks.append(binary_mask)
    return binary_masks

def calculate_metrics(pred_mask, gt_mask):
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    iou = intersection / union if union != 0 else 0
    return iou

def valid_points_from_masks(gt_masks):
    points = []
    for mask in gt_masks:
        ys, xs = np.where(mask > 0)
        points += [(x, y) for x, y in zip(xs, ys)]
    return points

def focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = torch.exp(-BCE_loss)  # Prevents nans when probability 0
    F_loss = alpha * (1-pt)**gamma * BCE_loss
    if reduction == 'mean':
        return F_loss.mean()
    elif reduction == 'sum':
        return F_loss.sum()
    else:
        return F_loss

def dice_loss(inputs, targets, smooth=1e-6):
    inputs = torch.sigmoid(inputs)
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

dataset_directory = "SA1B"
selected_images = get_image_info(dataset_directory, num_images=5)

# Process each selected image and its corresponding mask
for image_path, mask_path in selected_images:
    original_image = Image.open(image_path).convert("RGB")
    ground_truth_masks = get_ground_truth_masks(mask_path)
    valid_points = valid_points_from_masks(ground_truth_masks)
    
    for _ in range(5):  # Process 5 random points from valid areas
        if not valid_points:
            continue
        input_point = random.choice(valid_points)
        raw_image = np.array(original_image)
        relevant_gt_mask = next((mask for mask in ground_truth_masks if mask[input_point[1], input_point[0]] > 0), None)
        if relevant_gt_mask is None:
            continue
        
        input_points = [[input_point]]
        inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
        
        with torch.no_grad():
            submodel_outputs = models["submodel"](**inputs)
            submodel_masks = processor.image_processor.post_process_masks(
                submodel_outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
            )
        
        submodel_predicted_mask = submodel_masks[0].squeeze(0).squeeze(0).cpu()[1] if submodel_masks else None
        
        # Calculate IoU, focal loss, and Dice loss
        if submodel_predicted_mask is not None and relevant_gt_mask is not None:
            iou_submodel = calculate_metrics(submodel_predicted_mask, relevant_gt_mask)
            submodel_predicted_mask = torch.sigmoid(torch.tensor(submodel_predicted_mask)).unsqueeze(0)  # Add batch dimension and apply sigmoid
            relevant_gt_mask_tensor = torch.tensor(relevant_gt_mask, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
            fl_loss = focal_loss(submodel_predicted_mask, relevant_gt_mask_tensor)
            dl_loss = dice_loss(submodel_predicted_mask, relevant_gt_mask_tensor)
            print(f"Point: {input_point}, IoU (Submodel): {iou_submodel:.4f}, Focal Loss: {fl_loss.item()}, Dice Loss: {dl_loss.item()}")
        else:
            print(f"No metrics calculated for the point {input_point}")


2024-02-19 11:42:42.904412: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-19 11:42:42.942654: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  submodel_predicted_mask = torch.sigmoid(torch.tensor(submodel_predicted_mask)).unsqueeze(0)  # Add batch dimension and apply sigmoid


Point: (173, 264), IoU (Submodel): 0.8678, Focal Loss: 0.07372962683439255, Dice Loss: 0.6223226189613342
Point: (655, 1036), IoU (Submodel): 0.2203, Focal Loss: 0.0745801329612732, Dice Loss: 0.629086971282959
Point: (717, 1233), IoU (Submodel): 0.7989, Focal Loss: 0.0741911455988884, Dice Loss: 0.6124395728111267
Point: (1155, 639), IoU (Submodel): 0.7939, Focal Loss: 0.09358198195695877, Dice Loss: 0.9802486896514893
Point: (1944, 1453), IoU (Submodel): 0.7668, Focal Loss: 0.0748417004942894, Dice Loss: 0.6123348474502563
Point: (1685, 430), IoU (Submodel): 0.8852, Focal Loss: 0.0689886137843132, Dice Loss: 0.5665760636329651
Point: (1584, 327), IoU (Submodel): 0.4256, Focal Loss: 0.06988122314214706, Dice Loss: 0.5789762735366821
Point: (735, 768), IoU (Submodel): 0.0063, Focal Loss: 0.07077154517173767, Dice Loss: 0.5904623866081238
Point: (2147, 1375), IoU (Submodel): 0.1109, Focal Loss: 0.07686154544353485, Dice Loss: 0.6679875254631042
Point: (1674, 433), IoU (Submodel): 0.8853

In [3]:
import os
import random
import json
import numpy as np
from PIL import Image
from pycocotools import mask as coco_mask
from transformers import SamModel, SamProcessor
from raffm import RaFFM
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'


# Initialize the original SAM model and processor
original_model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# RaFFM configuration and submodel initialization
elastic_config = {
    "atten_out_space": [1280],
    "inter_hidden_space": [2048],
    "residual_hidden_space": [2048],
}
raffm_model = RaFFM(original_model, elastic_config=elastic_config)
submodel, params, config = raffm_model.random_resource_aware_model()
submodel = submodel.to("cuda")  # Move submodel to GPU

def get_image_info(dataset_directory, num_images=1):
    image_mask_pairs = []
    for filename in os.listdir(dataset_directory):
        if filename.endswith(".jpg"):
            image_path = os.path.join(dataset_directory, filename)
            mask_filename = filename.replace(".jpg", ".json")
            mask_path = os.path.join(dataset_directory, mask_filename)
            if os.path.exists(mask_path):
                image_mask_pairs.append((image_path, mask_path))
    selected_pairs = random.sample(image_mask_pairs, min(num_images, len(image_mask_pairs)))
    return selected_pairs

def get_ground_truth_masks(mask_path):
    binary_masks = []
    with open(mask_path, 'r') as json_file:
        mask_data = json.load(json_file)
    for annotation in mask_data['annotations']:
        rle_mask = annotation['segmentation']
        binary_mask = coco_mask.decode(rle_mask)
        binary_masks.append(binary_mask)
    return binary_masks

def calculate_metrics(pred_mask, gt_mask):
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    iou = intersection / union if union != 0 else 0
    return iou

def valid_points_from_masks(gt_masks):
    points = []
    for mask in gt_masks:
        ys, xs = np.where(mask > 0)
        points += [(x, y) for x, y in zip(xs, ys)]
    return points

def focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = torch.exp(-BCE_loss)  # Prevents nans when probability 0
    F_loss = alpha * (1-pt)**gamma * BCE_loss
    if reduction == 'mean':
        return F_loss.mean()
    elif reduction == 'sum':
        return F_loss.sum()
    else:
        return F_loss

def dice_loss(inputs, targets, smooth=1e-6):
    inputs = torch.sigmoid(inputs)
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

# Define loss functions
def focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = torch.exp(-BCE_loss)  # Prevents nans when probability 0
    F_loss = alpha * (1 - pt) ** gamma * BCE_loss
    return F_loss.mean() if reduction == 'mean' else F_loss.sum()

def dice_loss(inputs, targets, smooth=1e-6):
    inputs = torch.sigmoid(inputs)
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

# Define dataset and dataloader
class SA1BDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_directory):
        self.image_mask_pairs = get_image_info(dataset_directory, num_images=20)  # Adjust num_images as needed
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),  # Resize images to 1024x1024
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard normalization for ImageNet
        ])

    def __len__(self):
        return len(self.image_mask_pairs)

    def __getitem__(self, idx):
        image_path, mask_path = self.image_mask_pairs[idx]
        image = Image.open(image_path).convert("RGB")
        mask = get_ground_truth_masks(mask_path)[0]  # Assuming there is at least one mask per image

        image = self.transform(image)
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        mask = transforms.functional.resize(mask, (1024, 1024), interpolation=transforms.InterpolationMode.NEAREST)  # Resize mask

        return image, mask

dataset = SA1BDataset("SA1B")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # Batch size set to 1 for simplicity

# Training loop
optimizer = torch.optim.Adam(submodel.parameters(), lr=0.001)
num_epochs = 10
accumulation_steps = 4  # Adjust based on your memory constraints

for epoch in range(num_epochs):
    submodel.train()
    running_loss = 0.0

    for images, masks in dataloader:
        images = images.to("cuda")
        masks = masks.to("cuda")

        optimizer.zero_grad()

        # Forward pass
        outputs = submodel(images)['pred_masks']
        pred_masks = outputs.squeeze(1)  # Remove the extra dimension

        # Select the appropriate channel and squeeze the channel dimension
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1)  # Assuming the first channel is the class of interest

        # Resize predicted masks to match target masks
        pred_masks_resized = torch.nn.functional.interpolate(pred_masks, size=(1024, 1024), mode='bilinear', align_corners=False)

        # Compute the loss
        loss_focal = focal_loss(pred_masks_resized, masks)
        loss_dice = dice_loss(pred_masks_resized, masks)
        loss = loss_focal + loss_dice

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

print('Training complete')

Epoch 1/10, Loss: 1.1567
Epoch 2/10, Loss: 1.0052
Epoch 3/10, Loss: 0.9932
Epoch 4/10, Loss: 0.9910
Epoch 5/10, Loss: 0.9888
Epoch 6/10, Loss: 0.9898
Epoch 7/10, Loss: 0.9897
Epoch 8/10, Loss: 0.9927
Epoch 9/10, Loss: 0.9899
Epoch 10/10, Loss: 0.9895
Training complete


In [4]:
# Save the trained model
model_path = "trained_submodel.pth"
torch.save(submodel.state_dict(), model_path)
print(f"Model saved to {model_path}")


Model saved to trained_submodel.pth
