# In [0]:
import torch
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import timm  # Install with: pip install timm
import torch.nn as nn
import torch.optim as optim000
import matplotlib.pyplot as plt
import numpy as np

def _list_class_dirs(root):
    """Return the class sub-directory names of *root*, sorted alphabetically.

    Hidden entries (e.g. ``.ipynb_checkpoints``) and plain files are skipped so
    they are never mistaken for classes. Sorting makes the label assigned to each
    class deterministic, and identical for every split that contains the same
    class folders; raw ``os.listdir`` order is arbitrary.
    """
    return sorted(
        entry
        for entry in os.listdir(root)
        if not entry.startswith('.') and os.path.isdir(os.path.join(root, entry))
    )


# Check available classes (label index, class folder name)
image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
for label, class_dir in enumerate(_list_class_dirs(image_dir)):
    print(label, class_dir)


class generate_image_dataset(Dataset):
    """Image-classification dataset laid out as ``image_dir/<class_folder>/<image_file>``.

    Each visible sub-directory of *image_dir* is one class. Its integer label is
    the sub-directory's index in sorted order.

    Attributes:
        image_dir: root directory passed to the constructor.
        image_paths: full path of every sample, grouped by label, then sorted by file name.
        labels: integer label of every sample, parallel to ``image_paths``.
        class_name: mapping from integer label to class folder name.
        transform: optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = []
        self.labels = []
        self.class_name = {}
        self.transform = transform

        for label, class_dir in enumerate(_list_class_dirs(image_dir)):
            self.class_name[label] = class_dir
            class_path = os.path.join(image_dir, class_dir)
            # Sorted for a reproducible sample order; hidden files and nested
            # folders are not images and would fail in __getitem__.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is converted to RGB and passed through ``transform`` when one is set.
        """
        image_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() loads the pixels first.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing for the Swin Transformer backbone: resize to a fixed 224x224
# input, convert the PIL image to a tensor, then apply ImageNet mean/std
# normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

train_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
test_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/test'

training_image_dataset = generate_image_dataset(image_dir=train_image_dir, transform=transform)
test_image_dataset = generate_image_dataset(image_dir=test_image_dir, transform=transform)

# Each split assigns labels from its own class folders. A missing, extra or
# differently ordered folder in one split would silently shift the test
# labels, so fail loudly instead.
if training_image_dataset.class_name != test_image_dataset.class_name:
    raise ValueError(
        "Train/test class folders do not match: "
        f"{training_image_dataset.class_name} vs {test_image_dataset.class_name}"
    )

train_image_loader = DataLoader(dataset=training_image_dataset, batch_size=32, shuffle=True)
# Shuffling does not change test accuracy; a fixed order keeps evaluation reproducible.
test_image_loader = DataLoader(dataset=test_image_dataset, batch_size=32, shuffle=False)

# Check data shape on the first batch only (skipped if the loader is empty).
first_batch = next(iter(train_image_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f"Image batch shape: {images.shape}, Label batch shape: {labels.shape}")

# Visualize the first sample of one training batch. The ImageNet normalization
# is undone (x * std + mean) and the result clipped to [0, 1] for imshow.
sample_batch = next(iter(train_image_loader), None)
if sample_batch is not None:
    images, labels = sample_batch
    label = labels[0].item()
    class_title = f"Class: {training_image_dataset.class_name[label]}"
    print(class_title)

    # Channel-first tensor -> channel-last array for matplotlib, then denormalize.
    img = images[0].numpy().transpose(1, 2, 0)
    img = np.clip(
        img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]),
        0,
        1,
    )

    plt.imshow(img)
    plt.title(class_title)
    plt.axis('off')
    plt.show()

class SwinTransformerClassifier(nn.Module):
    """Swin Transformer backbone from timm followed by a custom two-layer MLP head.

    Args:
        num_classes: number of output logits.
        model_name: timm model identifier used for the backbone.
        pretrained: whether timm should load pretrained backbone weights.
    """

    def __init__(self, num_classes, model_name='swin_tiny_patch4_window7_224', pretrained=True):
        super().__init__()

        # num_classes=0 removes timm's classifier but keeps timm's own global
        # pooling, so the backbone returns pooled features [batch, num_features].
        # The previous global_pool='' + transpose/AdaptiveAvgPool1d approach
        # assumed [B, L, C] token features. NOTE(review): recent timm releases
        # return NHWC [B, H, W, C] for Swin, which that code could not pool.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
        )

        # Width of the pooled feature vector produced by the backbone.
        self.feature_dim = self.backbone.num_features

        # Custom classifier head.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape [batch, num_classes] for the image batch *x*."""
        pooled = self.backbone(x)  # [batch, feature_dim]
        return self.classifier(pooled)

# Alternative simpler approach using timm's built-in classifier
class SwinTransformerSimple(nn.Module):
    """Thin wrapper around a timm Swin model built with a *num_classes*-way head."""

    def __init__(self, num_classes, model_name='swin_tiny_patch4_window7_224', pretrained=True):
        super().__init__()
        # timm builds the architecture (optionally with pretrained weights) and
        # sizes its classifier head to num_classes outputs.
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        """Run the wrapped timm model on *x* and return its logits."""
        logits = self.model(x)
        return logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Size the head from the data instead of hard-coding it, so adding or removing
# a class folder cannot silently mismatch the classifier output.
num_classes = len(training_image_dataset.class_name)

# Initialize model - you can choose between the two approaches above
# Option 1: Custom classifier head
# model = SwinTransformerClassifier(num_classes=num_classes, model_name='swin_tiny_patch4_window7_224', pretrained=True).to(device)

# Option 2: Simple approach (recommended)
model = SwinTransformerSimple(num_classes=num_classes, model_name='swin_tiny_patch4_window7_224', pretrained=True).to(device)

# Print model info
print(f"Model: {model}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# The file-level import binds torch.optim as `optim000` (typo), which leaves
# `optim` undefined; bind it here so the optimizer can be built.
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)  # AdamW: Adam with decoupled weight decay
epochs = 2

# Training loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct_predictions = 0
    total_samples = 0

    for batch_idx, (images, labels) in enumerate(train_image_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss() averages over the batch by default. Re-weight by
        # batch size so a short final batch does not skew the epoch average.
        running_loss += loss.item() * batch_size
        total_samples += batch_size
        correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()

        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(train_image_loader)}, "
                  f"Loss: {loss.item():.4f}")

    if total_samples == 0:
        raise ValueError("train_image_loader yielded no samples; nothing to train on")

    epoch_loss = running_loss / total_samples
    epoch_accuracy = 100 * correct_predictions / total_samples

    print(f"Epoch {epoch+1}/{epochs} - Average Loss: {epoch_loss:.4f}, "
          f"Training Accuracy: {epoch_accuracy:.2f}%")

print("Training completed!")

# Evaluation function
def evaluate_model(model, test_loader, device):
    """Compute, print and return the top-1 accuracy of *model* on *test_loader*.

    Args:
        model: classifier whose output is scored with ``argmax`` over dim 1
            (i.e. logits of shape [batch, num_classes]).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()  # switch dropout/normalization layers to inference behaviour
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# Evaluate the trained model on the held-out test split.
test_accuracy = evaluate_model(model=model, test_loader=test_image_loader, device=device)