#### imports

In [0]:
pip install torch torchvision matplotlib pillow


In [0]:
dbutils.library.restartPython()

In [0]:
import torch
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.models.segmentation import deeplabv3_resnet50
import matplotlib.pyplot as plt
import numpy as np

# Folder-per-class image dataset used for both the train and test splits.
class generate_image_dataset(Dataset):
    """Image-classification dataset read from an ``image_dir/<class>/<image>`` tree.

    Each immediate (non-hidden) subdirectory of ``image_dir`` is one class.
    Class indices are assigned in sorted order of the subdirectory names, so
    two datasets built from directories with the same class folders (e.g.
    train and test) agree on the index -> class mapping regardless of the
    order in which the filesystem lists entries.

    Attributes:
        image_dir: Root directory passed to the constructor.
        image_paths: Full path of every sample, grouped by class.
        labels: Integer class index for each entry of ``image_paths``.
        class_name: Mapping from class index to subdirectory name.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = []
        self.labels = []
        self.class_name = {}
        self.transform = transform

        # os.listdir order is arbitrary: sort so label ids are reproducible and
        # consistent across splits. Stray files (e.g. README, .DS_Store) and
        # hidden folders at the root are not classes and are skipped.
        class_dirs = sorted(
            entry
            for entry in os.listdir(image_dir)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(image_dir, entry))
        )
        for label, class_dir in enumerate(class_dirs):
            self.class_name[label] = class_dir
            class_path = os.path.join(class_path_root := image_dir, class_dir) if False else os.path.join(image_dir, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # Skip hidden files and nested directories, which would
                # otherwise fail later when opened as images.
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was given.
        """
        image_path = self.image_paths[idx]
        # The context manager releases the file handle promptly. convert()
        # returns a fully loaded copy that does not depend on the open file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing: resize every image to a fixed 224x224 so samples can be
# stacked into batches, convert to a float tensor in [0, 1], then normalize
# with the ImageNet channel statistics used by the pre-trained backbone.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed size so the DataLoader can batch samples
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean/std
])

train_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
test_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/test'

training_image_dataset = generate_image_dataset(image_dir=train_image_dir, transform=transform)
test_image_dataset = generate_image_dataset(image_dir=test_image_dir, transform=transform)

# Each dataset assigns label ids from its own directory listing. If the two
# splits do not expose the same class folders, the test label ids would not
# match the model's output indices and every metric below would be wrong.
if training_image_dataset.class_name != test_image_dataset.class_name:
    raise ValueError(
        "Train/test class mappings differ: "
        f"{training_image_dataset.class_name} vs {test_image_dataset.class_name}"
    )

# Batch size 16; lower it if accelerator memory is tight.
train_image_loader = DataLoader(dataset=training_image_dataset, batch_size=16, shuffle=True)
test_image_loader = DataLoader(dataset=test_image_dataset, batch_size=16, shuffle=False)

class DeepLabV3Classifier(nn.Module):
    """Image classifier on top of the ResNet-50 backbone of a pre-trained DeepLabV3.

    The backbone's final feature map (``'out'``, 2048 channels) is globally
    average-pooled and fed to a small MLP head with dropout that produces
    ``num_classes`` logits.

    The full DeepLabV3 model stays registered as ``self.deeplab``, so
    ``state_dict`` keys are unchanged. Only its backbone is used in ``forward``.

    Args:
        num_classes: Number of output logits.
        freeze_backbone: If True (default), backbone weights get
            ``requires_grad = False`` and the backbone is held in eval mode
            even when ``model.train()`` is called. To fine-tune the whole
            network, pass False and optimize ``model.parameters()``.
    """

    def __init__(self, num_classes, freeze_backbone=True):
        super(DeepLabV3Classifier, self).__init__()
        # Load pre-trained DeepLab v3. NOTE: newer torchvision versions
        # deprecate `pretrained=` in favour of `weights=`.
        self.deeplab = deeplabv3_resnet50(pretrained=True)

        # Reuse the ResNet-50 encoder. This is the same module object as
        # self.deeplab.backbone, not a copy.
        self.backbone = self.deeplab.backbone
        self.freeze_backbone = freeze_backbone

        # Global average pooling followed by the classification head.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 512),  # backbone 'out' feature map has 2048 channels
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set training mode for the module, but keep a frozen backbone in eval mode.

        ``requires_grad = False`` prevents weight updates, but BatchNorm
        layers still update their running mean/variance whenever they are in
        training mode. torchvision's ResNet-50 contains BatchNorm layers.
        Holding the frozen backbone in eval mode keeps its pre-trained
        statistics intact. ``model.eval()`` also routes through here.
        """
        super().train(mode)
        # Instances created before this flag existed were always frozen.
        if getattr(self, "freeze_backbone", True):
            self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for the image batch ``x``."""
        # The backbone returns a dict of feature maps; 'out' is the last stage.
        features = self.backbone(x)['out']
        # (N, C, 1, 1) -> (N, C)
        pooled = torch.flatten(self.global_avg_pool(features), 1)
        return self.classifier(pooled)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Size the output layer from the data rather than hard-coding it, so adding or
# removing a class folder cannot silently desynchronize labels and logits.
num_classes = len(training_image_dataset.class_name)
model = DeepLabV3Classifier(num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
# The backbone is frozen, so only the classification head is optimized.
# For full fine-tuning, unfreeze the backbone and pass model.parameters()
# instead, typically with a smaller learning rate (e.g. 1e-4).
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

epochs = 10

# Train for `epochs` passes over the training set, reporting the per-sample
# average loss and accuracy after each epoch.
print("Starting training with DeepLab v3 feature extractor...")
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_image_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss (default reduction) returns the batch mean. Weight it
        # by batch size so the epoch average stays per-sample even when the
        # last batch is smaller.
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        if batch_idx % 10 == 0:
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

    if total == 0:
        raise ValueError("Training loader produced no samples; check train_image_dir.")
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Average Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

print("Training completed!")

# Evaluation: top-1 accuracy of the trained model on the held-out test set.
model.eval()  # disables dropout; BatchNorm uses its running statistics
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_image_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    test_accuracy = 100 * correct / total
    print(f'Test Accuracy: {test_accuracy:.2f}%')
else:
    # An empty test split would otherwise raise ZeroDivisionError here.
    test_accuracy = float('nan')
    print('Test Accuracy: n/a (test set is empty)')

# Visualize the first test image with its true and predicted class names.
model.eval()

# Same ImageNet statistics as the Normalize step in `transform`; used here to
# undo normalization for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

first_batch = next(iter(test_image_loader), None)
if first_batch is not None:
    images, labels = first_batch
    with torch.no_grad():
        # Only the first image is shown, so only it needs a forward pass.
        outputs = model(images[:1].to(device))
    predicted_idx = outputs.argmax(dim=1)[0].item()
    true_idx = labels[0].item()

    img = images[0].cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC for matplotlib
    img = np.clip(std * img + mean, 0, 1)

    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    # The true label is an index into the test dataset's mapping. The
    # prediction indexes the classes the model was trained on, i.e. the
    # training dataset's mapping.
    plt.title(f'True: {test_image_dataset.class_name[true_idx]}, '
              f'Predicted: {training_image_dataset.class_name[predicted_idx]}')
    plt.axis('off')
    plt.show()