In [27]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [28]:
class UnderwaterDataset(Dataset):
    """Image dataset with YOLO-style text annotations.

    Images are read from ``<root_dir>/<split>`` and annotations from
    ``<root_dir>/<split>_labels``.  Each ``*.jpg`` image is paired with the
    ``*.txt`` file that has the same stem.  Every non-blank annotation line is
    parsed as ``class x_center y_center width height`` and converted to corner
    format ``[xmin, ymin, xmax, ymax]`` (same units as stored in the file).

    Args:
        root_dir: Dataset root directory.
        split: Name of the split sub-directory (default ``'train'``).
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.img_dir = os.path.join(root_dir, split)
        self.label_dir = os.path.join(root_dir, f"{split}_labels")
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same file across runs and machines.
        self.img_files = sorted(f for f in os.listdir(self.img_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Return the number of ``.jpg`` images found in the split directory."""
        return len(self.img_files)

    @staticmethod
    def _parse_label_file(label_path):
        """Parse one annotation file into ``(boxes, labels)`` tensors.

        Args:
            label_path: Path to a text file with one object per line.

        Returns:
            ``boxes`` as a float32 tensor of shape ``(num_objects, 4)`` in
            ``[xmin, ymin, xmax, ymax]`` order and ``labels`` as an int64
            tensor of shape ``(num_objects,)``.  A file without objects yields
            shapes ``(0, 4)`` and ``(0,)``.

        Raises:
            ValueError: If a non-blank line does not have exactly five fields
                or a field cannot be converted to a number.
        """
        boxes = []
        labels = []
        with open(label_path, 'r') as file:
            for line_no, line in enumerate(file, start=1):
                data = line.split()
                if not data:
                    # Blank/trailing lines carry no object.
                    continue
                if len(data) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 fields "
                        f"(class x_center y_center width height), got {len(data)}"
                    )
                label = int(data[0])
                x_center, y_center, width, height = map(float, data[1:])
                boxes.append([
                    x_center - width / 2,
                    y_center - height / 2,
                    x_center + width / 2,
                    y_center + height / 2,
                ])
                labels.append(label)

        # reshape keeps the trailing dimension of 4 even when there are no objects.
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.long)
        return boxes, labels

    def __getitem__(self, idx):
        """Load the image at ``idx`` together with its boxes and labels.

        Returns:
            ``(image, {'boxes': Tensor, 'labels': Tensor})`` where ``image`` is
            the RGB image after ``transform`` (if one was given).
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # Swap only the extension; str.replace would also rewrite a '.jpg'
        # occurring in the middle of the file name.
        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + '.txt')

        image = Image.open(img_path).convert("RGB")
        boxes, labels = self._parse_label_file(label_path)

        if self.transform:
            image = self.transform(image)

        return image, {'boxes': boxes, 'labels': labels}

In [29]:
class LightViT(nn.Module):
    """Small Vision-Transformer that predicts one box and one class per image.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly embedded, a learned positional encoding is
    added, and the token sequence is processed by a Transformer encoder.  The
    mean over all tokens feeds two heads.

    Args:
        num_classes: Number of class logits produced by the classification head.
        image_size: Expected height and width of the (square) input image.
        patch_size: Side length of a patch; must divide ``image_size``.
        dim: Token embedding size.
        depth: Number of Transformer encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden size of the feed-forward sub-layer.
    """

    def __init__(self, num_classes=7, image_size=256, patch_size=16, dim=128, depth=4, heads=4, mlp_dim=256):
        super(LightViT, self).__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."

        # Remembered so forward() can patchify with the configured size instead
        # of a hard-coded 16.
        self.image_size = image_size
        self.patch_size = patch_size

        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        # Patch Embedding
        self.patch_embedding = nn.Linear(patch_dim, dim)

        # Positional Encoding
        self.positional_encoding = nn.Parameter(torch.randn(1, num_patches, dim))

        # Transformer Encoder Layers.
        # batch_first=True: tokens are laid out as (batch, patches, dim).  With
        # the default (False) the encoder would treat the batch axis as the
        # sequence and attend across different images.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, batch_first=True),
            num_layers=depth
        )

        # MLP Head for Bounding Box Prediction
        self.to_bbox = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 4)
        )

        # MLP Head for Classification
        self.to_class = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def _to_patches(self, x):
        """Rearrange ``(B, C, H, W)`` images into ``(B, num_patches, C*p*p)`` patches.

        Patches are ordered row-major over the image grid; inside a patch the
        values are ordered channel, row, column.

        Raises:
            ValueError: If the spatial size differs from ``image_size``.
        """
        batch_size, channels, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(
                f"Expected input of spatial size {self.image_size}x{self.image_size}, "
                f"got {height}x{width}"
            )
        p = self.patch_size
        # (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, C, p, p): each grid cell now
        # owns one contiguous C x p x p spatial patch.
        x = x.reshape(batch_size, channels, height // p, p, width // p, p)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(batch_size, -1, channels * p * p)

    def forward(self, x):
        """Return ``(bbox_pred, class_pred)`` of shapes ``(B, 4)`` and ``(B, num_classes)``."""
        # Split image into patches
        x = self._to_patches(x)

        # Patch Embedding
        x = self.patch_embedding(x)

        # Add positional encoding (out-of-place to leave the embedding output intact)
        x = x + self.positional_encoding

        # Transformer encoding
        x = self.transformer(x)

        # Bounding box prediction and classification from the mean token
        pooled = x.mean(dim=1)
        bbox_pred = self.to_bbox(pooled)
        class_pred = self.to_class(pooled)

        return bbox_pred, class_pred

# Example Usage
# Smoke test: run a random batch through LightViT and print the output shapes.
# `sample_image` is also reused by the later example cells in this notebook.
# torch.rand draws from the global RNG (no seed is set here), so values differ
# between runs; only the printed shapes are stable.
model_vit = LightViT(num_classes=7)
sample_image = torch.rand((2, 3, 256, 256))  # Batch of 2 images
bbox_pred, class_pred = model_vit(sample_image)
print(f"BBox Prediction: {bbox_pred.shape}, Class Prediction: {class_pred.shape}")


BBox Prediction: torch.Size([2, 4]), Class Prediction: torch.Size([2, 7])


In [30]:
class LightYOLO(nn.Module):
    """Small convolutional network that predicts one box and one class per image.

    Four stride-2 convolutions downsample the input by a factor of 16; the
    resulting feature map is pooled to 16x16, flattened, and fed to two linear
    heads.

    Args:
        num_classes: Number of class logits produced by the classification head.
    """

    def __init__(self, num_classes=7):
        super(LightYOLO, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),  # Downsample
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # The heads expect a 128 x 16 x 16 feature map, which the convolutions
        # only produce for 256x256 inputs.  Adaptive pooling maps any other
        # input size to 16x16; for 256x256 inputs it is an identity (each
        # output cell averages exactly one value).  It has no parameters, so
        # existing checkpoints still load.
        self.pool = nn.AdaptiveAvgPool2d((16, 16))

        # Prediction heads
        self.bbox_head = nn.Linear(128 * 16 * 16, 4)
        self.class_head = nn.Linear(128 * 16 * 16, num_classes)

    def forward(self, x):
        """Return ``(bbox_pred, class_pred)`` of shapes ``(B, 4)`` and ``(B, num_classes)``."""
        x = self.features(x)
        x = self.pool(x)
        x = x.flatten(1)  # (B, 128 * 16 * 16)

        bbox_pred = self.bbox_head(x)
        class_pred = self.class_head(x)

        return bbox_pred, class_pred

# Example Usage
# Smoke test: run LightYOLO on `sample_image` and print the output shapes.
# `sample_image` is not defined in this cell; it comes from the LightViT
# example cell, which therefore has to be executed first.
model_yolo = LightYOLO(num_classes=7)
bbox_pred, class_pred = model_yolo(sample_image)
print(f"BBox Prediction: {bbox_pred.shape}, Class Prediction: {class_pred.shape}")


BBox Prediction: torch.Size([2, 4]), Class Prediction: torch.Size([2, 7])


In [31]:
class LightViT_YOLO(nn.Module):
    """Fuse the LightViT and LightYOLO predictions into per-object outputs.

    Both sub-models predict a single box (4 values) and a single class-logit
    vector per image.  Their outputs are concatenated and mapped by two linear
    layers to ``max_objects`` boxes and ``max_objects`` class-logit vectors.

    Args:
        num_classes: Number of classes predicted per object slot.
        max_objects: Number of object slots predicted per image.
    """

    def __init__(self, num_classes=7, max_objects=10):
        super(LightViT_YOLO, self).__init__()
        self.num_classes = num_classes
        self.max_objects = max_objects
        self.vit = LightViT(num_classes)
        self.yolo = LightYOLO(num_classes)
        # Input widths follow from concatenating the two sub-model outputs:
        # 2 * 4 box values and 2 * num_classes logits (previously hard-coded
        # as 8 and 14, which only worked for num_classes == 7).
        self.final_bbox_head = nn.Linear(2 * 4, 4 * max_objects)
        self.final_class_head = nn.Linear(2 * num_classes, num_classes * max_objects)

    def forward(self, x):
        """Return ``(boxes, class_logits)``.

        Shapes are ``(B, max_objects, 4)`` and ``(B, max_objects, num_classes)``.
        """
        vit_bbox, vit_class = self.vit(x)
        yolo_bbox, yolo_class = self.yolo(x)
        combined_bbox = torch.cat((vit_bbox, yolo_bbox), dim=1)
        combined_class = torch.cat((vit_class, yolo_class), dim=1)
        final_bbox = self.final_bbox_head(combined_bbox).view(-1, self.max_objects, 4)
        final_class = self.final_class_head(combined_class).view(-1, self.max_objects, self.num_classes)
        return final_bbox, final_class
    
    # Example Usage
# Smoke test: run the combined model on `sample_image` and print the output
# shapes.  `sample_image` is not defined in this cell; it comes from the
# LightViT example cell, which therefore has to be executed first.
model_combined = LightViT_YOLO(num_classes=7)
bbox_pred, class_pred = model_combined(sample_image)
print(f"BBox Prediction: {bbox_pred.shape}, Class Prediction: {class_pred.shape}")

BBox Prediction: torch.Size([2, 10, 4]), Class Prediction: torch.Size([2, 10, 7])


In [41]:
def collate_fn(batch):
    """Collate ``(image, target)`` samples with a variable number of objects.

    Images are stacked; boxes and labels are zero-padded up to the largest
    object count found in the batch.

    Args:
        batch: Sequence of ``(image, {'boxes': Tensor, 'labels': Tensor})``
            items.  ``boxes`` holds 4 values per object, ``labels`` one value
            per object.

    Returns:
        ``(images, {'boxes': Tensor, 'labels': Tensor})`` with shapes
        ``(B, ...)``, ``(B, max_objects, 4)`` and ``(B, max_objects)``.
        Padded slots contain all-zero boxes and label ``0``.
    """
    images = torch.stack([item[0] for item in batch])

    # reshape(-1, 4) normalises the "no objects" case: an empty box tensor of
    # shape (0,) becomes (0, 4) so every sample has the same trailing shape.
    boxes = [item[1]['boxes'].reshape(-1, 4) for item in batch]
    labels = [item[1]['labels'].reshape(-1) for item in batch]

    max_objects = max(len(b) for b in boxes)

    # Pre-allocate zero-filled outputs and copy each sample's objects into the
    # leading slots; the remaining slots stay zero (padding).
    padded_boxes = boxes[0].new_zeros((len(boxes), max_objects, 4))
    padded_labels = labels[0].new_zeros((len(labels), max_objects))
    for row, (sample_boxes, sample_labels) in enumerate(zip(boxes, labels)):
        padded_boxes[row, :len(sample_boxes)] = sample_boxes
        padded_labels[row, :len(sample_labels)] = sample_labels

    return images, {
        'boxes': padded_boxes,
        'labels': padded_labels
    }

In [45]:
def train_model(model, train_loader, num_epochs=100):
    """Train ``model`` with an MSE box loss plus a cross-entropy class loss.

    The model is moved to CUDA when available (otherwise CPU) and optimised
    with Adam (lr=1e-4).  Prediction slot ``k`` is compared with ground-truth
    object ``k``; only as many slots as both sides provide are used.

    Padded ground-truth slots are excluded from both losses.  They are
    recognised by an all-zero box, which is what ``collate_fn`` in this
    notebook writes into padding; the padded label ``0`` cannot be used for
    this because ``0`` is also a valid class id.

    Args:
        model: Module returning ``(bbox_pred, class_pred)`` with shapes
            ``(B, slots, 4)`` and ``(B, slots, num_classes)``.
        train_loader: Iterable of ``(images, {'boxes': ..., 'labels': ...})``
            batches with shapes ``(B, objects, 4)`` and ``(B, objects)``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        List with the mean loss of each epoch (``nan`` for an epoch in which
        no batch contained a ground-truth object).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for images, targets in train_loader:
            bbox_targets = targets['boxes'].to(device)
            class_targets = targets['labels'].to(device)
            if bbox_targets.dim() != 3:
                # Batch without any annotated object: nothing to learn from.
                continue

            images = images.to(device)
            optimizer.zero_grad()
            bbox_pred, class_pred = model(images)

            # Align predictions and targets on the slots both of them have, so
            # an image with more objects than prediction slots does not crash.
            num_slots = min(bbox_pred.size(1), bbox_targets.size(1))
            bbox_pred = bbox_pred[:, :num_slots]
            class_pred = class_pred[:, :num_slots]
            bbox_targets = bbox_targets[:, :num_slots]
            class_targets = class_targets[:, :num_slots]

            # True for real objects, False for zero-padded slots.
            valid = bbox_targets.abs().sum(dim=-1) > 0
            if not valid.any():
                continue

            # Boolean indexing flattens (B, slots, ...) to (num_valid, ...), so
            # the class count is taken from the prediction itself.
            loss_bbox = nn.functional.mse_loss(bbox_pred[valid], bbox_targets[valid])
            loss_class = criterion(class_pred[valid], class_targets[valid])
            loss = loss_bbox + loss_class

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches that actually contributed a loss.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    return epoch_losses

In [46]:
def evaluate_model(model, test_loader):
    """Print and return the per-object classification accuracy.

    Prediction slot ``k`` is compared with ground-truth object ``k``.  Real
    objects are told apart from padding by a non-zero box (``collate_fn`` in
    this notebook pads with all-zero boxes); the label cannot be used for that
    because ``0`` is also a valid class id.  Objects beyond the number of
    prediction slots are not scored.

    Args:
        model: Module returning ``(bbox_pred, class_pred)`` where
            ``class_pred`` has shape ``(B, slots, num_classes)``.
        test_loader: Iterable of ``(images, {'boxes': ..., 'labels': ...})``
            batches with shapes ``(B, objects, 4)`` and ``(B, objects)``.

    Returns:
        Fraction of scored objects whose predicted class matches the label, or
        ``0.0`` when no ground-truth object was found.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Inputs are moved to `device` below, so the model has to live there too.
    model.to(device)
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for images, targets in test_loader:
            gt_boxes = targets['boxes'].to(device)
            gt_labels = targets['labels'].to(device)
            if gt_boxes.dim() != 3:
                # Batch without any annotated object.
                continue

            images = images.to(device)
            _, class_pred = model(images)

            # Compare only the slots available on both sides.
            num_slots = min(class_pred.size(1), gt_labels.size(1))
            valid = gt_boxes[:, :num_slots].abs().sum(dim=-1) > 0
            pred_classes = class_pred[:, :num_slots].argmax(dim=-1)
            hits = (pred_classes == gt_labels[:, :num_slots]) & valid

            total += int(valid.sum().item())
            correct += int(hits.sum().item())

    if total == 0:
        print("Accuracy: n/a (no ground-truth objects found)")
        return 0.0

    accuracy = correct / total
    print(f"Accuracy: {accuracy}")
    return accuracy

In [1]:
def plot_predictions(image, bbox_pred, class_pred, class_names):
    """Show an image with its predicted boxes and class names.

    Args:
        image: Image tensor of shape ``(C, H, W)``.
        bbox_pred: Iterable of boxes ``[xmin, ymin, xmax, ymax]`` expressed as
            fractions of the image width/height.
        class_pred: Iterable of class-score vectors, one per box.
        class_names: Sequence mapping a class index to its display name.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image.detach().permute(1, 2, 0).cpu().numpy())

    # Scale boxes by the actual image size instead of assuming 256x256.
    img_height, img_width = image.shape[-2], image.shape[-1]

    for bbox, cls_scores in zip(bbox_pred, class_pred):
        if bbox.sum() > 0:  # Only plot non-zero boxes
            xmin, ymin, xmax, ymax = bbox.detach().cpu().numpy()
            left, top = xmin * img_width, ymin * img_height
            rect = patches.Rectangle((left, top),
                                     (xmax - xmin) * img_width,
                                     (ymax - ymin) * img_height,
                                     linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            cls_name = class_names[int(cls_scores.argmax())]
            ax.text(left, top, cls_name, fontsize=8, color='r')

    ax.axis('off')
    plt.show()
    # This function is called once per image in a loop; release the figure so
    # open figures do not accumulate.
    plt.close(fig)

In [None]:
# Main execution: train on the 'train' split, save the weights, report accuracy
# on the 'test' split and then step through the test predictions one by one.
if __name__ == "__main__":
    # Needed for the visualisation loop below; train_model/evaluate_model pick
    # their device the same way but keep it local.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    train_dataset = UnderwaterDataset(root_dir='data/USIS10K', split='train', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

    model = LightViT_YOLO(num_classes=7, max_objects=10)
    train_model(model, train_loader)

    # Save the model
    torch.save(model.state_dict(), 'multi_object_model.pth')

    # Evaluation
    test_dataset = UnderwaterDataset(root_dir='data/USIS10K', split='test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

    evaluate_model(model, test_loader)

    # Visualization
    class_names = ["wrecks/ruins", "fish", "reefs", "aquatic plants", "human divers", "robots", "sea-floor"]
    model.to(device)
    model.eval()

    with torch.no_grad():  # inference only, no autograd graph needed
        for images, targets in test_loader:
            images = images.to(device)
            bbox_pred, class_pred = model(images)

            # Iterate over the images actually present in the batch (the last
            # batch may be smaller than batch_size).
            for i in range(images.size(0)):
                # Real objects have a non-zero box; collate_fn pads with
                # all-zero boxes.  Counting labels != 0 would miss class 0.
                num_objects = int((targets['boxes'][i].abs().sum(dim=-1) > 0).sum().item())

                plot_predictions(images[i], bbox_pred[i, :num_objects], class_pred[i, :num_objects], class_names)
                input("Press Enter to continue...")