# In [0]:
import torch
import os
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# Quick look at the training root: print each directory entry next to the
# index it receives when the raw os.listdir() result is enumerated.
image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
class_folders = os.listdir(image_dir)
for label, class_dir in enumerate(class_folders):
    print(label, class_dir)

class generate_image_dataset(Dataset):
    """Image-classification dataset read from a folder-per-class layout.

    ``image_dir`` is expected to contain one sub-directory per class, each
    holding that class's image files.  Class sub-directories are visited in
    sorted name order and numbered from 0.  ``os.listdir`` returns entries in
    arbitrary order, so sorting is what guarantees that two roots with the
    same class folder names (e.g. train and test) get the same label mapping.

    Entries of ``image_dir`` that are not directories, and entries of a class
    folder that are not regular files, are ignored.

    Args:
        image_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self,image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = []  # full path of every sample
        self.labels = []       # integer label of every sample, parallel to image_paths
        self.class_name = {}   # label -> class folder name
        self.transform = transform

        # Only directories are classes; sorting makes the numbering deterministic.
        class_dirs = sorted(
            entry for entry in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, entry))
        )

        for label, class_dir in enumerate(class_dirs):
            self.class_name[label] = class_dir
            class_path = os.path.join(image_dir, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # Skip nested directories; they cannot be opened as images.
                if not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples found under ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self,idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given; ``label`` is the integer class index.
        """
        image_path = self.image_paths[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing for ShuffleNet: resize to 224x224, convert to a tensor and
# normalize with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ShuffleNet typically uses 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

train_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
test_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/test'

training_image_dataset = generate_image_dataset(image_dir=train_image_dir, transform=transform)
test_image_dataset = generate_image_dataset(image_dir=test_image_dir, transform=transform)

# Training batches are reshuffled every epoch.
train_image_loader = DataLoader(dataset=training_image_dataset, batch_size=32, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the batch
# composition, and therefore any batch-averaged metric, the same on every run.
test_image_loader = DataLoader(dataset=test_image_dataset, batch_size=32, shuffle=False)

# Sanity-check one training batch and display its first image.
# Only the first batch is pulled: iterating the whole loader just to print
# shapes would decode every training image before any real work starts.
for images, labels in train_image_loader:
    print(images.shape, labels.shape)
    img = images[0].numpy()
    label = labels[0].item()

    print(training_image_dataset.class_name[label])
    # Move the channel axis last, which is the layout plt.imshow expects.
    img = np.transpose(img, (1, 2, 0))

    # Undo the ImageNet normalization so the colors look natural.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = img * std + mean
    img = np.clip(img, 0, 1)

    print(img.shape)
    print(label)

    plt.imshow(img)
    break

# ShuffleNet Model Setup
def create_shufflenet_model(num_classes=3, pretrained=True):
    """Build a ShuffleNet v2 (1.0x) classifier with a ``num_classes``-way head.

    Args:
        num_classes: Number of output units of the replacement final layer.
        pretrained: If true, start from torchvision's ImageNet weights;
            otherwise the network is randomly initialized.

    Returns:
        The torchvision ShuffleNet v2 x1.0 model whose ``fc`` layer has been
        replaced by a freshly initialized ``nn.Linear``.
    """
    # torchvision >= 0.13 deprecates ``pretrained=`` in favor of ``weights=``.
    # Use the weights enum when it exists and fall back to the legacy keyword
    # on older installations, so both keep working without warnings.
    weights_enum = getattr(models, "ShuffleNet_V2_X1_0_Weights", None)
    if weights_enum is None:
        model = models.shufflenet_v2_x1_0(pretrained=pretrained)
    else:
        # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` maps to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.shufflenet_v2_x1_0(weights=weights)

    # Swap the final classifier for one sized to this task; its input width
    # is taken from the existing layer.
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    return model

# Alternative: Create ShuffleNet from scratch (no pre-training)
def create_shufflenet_no_pretrain(num_classes=3):
    """Build a randomly initialized ShuffleNet v2 (1.0x) classifier.

    Args:
        num_classes: Number of output classes of the network.

    Returns:
        A torchvision ShuffleNet v2 x1.0 model without pre-trained weights.
    """
    # No pre-trained weights is the default in both the legacy
    # (``pretrained=False``) and the current (``weights=None``) torchvision
    # APIs, so the deprecated keyword is simply not passed.
    model = models.shufflenet_v2_x1_0(num_classes=num_classes)
    return model

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Size the classifier head from the dataset itself: labels run from 0 to
# len(class_name) - 1, so the head needs exactly that many outputs.
num_classes = len(training_image_dataset.class_name)

# Create model - choose one of the following:

# Option 1: Pre-trained ShuffleNet (recommended for better performance)
model = create_shufflenet_model(num_classes=num_classes, pretrained=True).to(device)

# Option 2: ShuffleNet from scratch (uncomment to use)
# model = create_shufflenet_no_pretrain(num_classes=num_classes).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model created: {model.__class__.__name__}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 2

# Training loop: one full pass over the training loader per epoch, reporting
# the running batch loss every 10 batches and per-epoch loss/accuracy.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    correct = 0
    total = 0
    num_batches = len(train_image_loader)

    for batch_idx, (images, labels) in enumerate(train_image_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()
        # Assumes `criterion` returns the batch mean (the nn.CrossEntropyLoss
        # default). Weighting by batch size makes the epoch figure a true
        # per-sample average even when the last batch is smaller.
        running_loss += batch_loss * batch_size

        # Calculate accuracy; detach() replaces the legacy `.data` accessor.
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, '
                  f'Loss: {batch_loss:.4f}')

    epoch_accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Average Loss: {running_loss/total:.4f}, "
          f"Accuracy: {epoch_accuracy:.2f}%")

print("Training completed!")

# Evaluation on test set: accumulate loss and accuracy without tracking
# gradients, with the model switched to inference mode.
model.eval()
test_correct = 0
test_total = 0
test_loss = 0.0  # sum of per-sample losses

with torch.no_grad():
    for images, labels in test_image_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        batch_size = labels.size(0)
        # Assumes `criterion` returns the batch mean (the nn.CrossEntropyLoss
        # default). Weighting by batch size yields a per-sample average that
        # does not depend on how samples are grouped into batches.
        test_loss += loss.item() * batch_size

        predicted = outputs.argmax(dim=1)
        test_total += batch_size
        test_correct += (predicted == labels).sum().item()

if test_total == 0:
    # Fail with an explicit message instead of an opaque ZeroDivisionError.
    raise RuntimeError("Test loader produced no samples; cannot compute test metrics.")

test_accuracy = 100 * test_correct / test_total
print(f"Test Loss: {test_loss/test_total:.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Model summary. The class count and size are read from the model rather than
# hard-coded, so they stay consistent with the computed "Total parameters"
# line even though the classifier head has been replaced.
total_params = sum(p.numel() for p in model.parameters())
print("\nModel Summary:")
print("Architecture: ShuffleNet v2 1.0x")
print("Input size: 224x224x3")
print(f"Number of classes: {model.fc.out_features}")
print(f"Total parameters: {total_params:,}")
print(f"Model size: ~{total_params / 1e6:.1f}M parameters")
# NOTE(review): static estimate, not measured anywhere in this script.
print("Expected inference speed: ~10ms on mobile CPU")