<a href="https://colab.research.google.com/github/vijaygwu/classideas/blob/main/MixupCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import random

# Run on the default CUDA device when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configuration
class Config:
    """Hyper-parameters, file locations and image transforms used by this script.

    The two dataset directories can be overridden without editing code via the
    ``DOGS_VS_CATS_TRAIN`` and ``DOGS_VS_CATS_TEST`` environment variables; when
    those are unset the original hard-coded paths are used.
    """
    # Directory of training JPEGs; labels are derived from the file names.
    train_data_path = os.environ.get(
        'DOGS_VS_CATS_TRAIN',
        '/Users/vraghavan/Desktop/ClassificationTest/dogs-vs-cats/train',
    )
    # Directory of held-out JPEGs.
    test_data_path = os.environ.get(
        'DOGS_VS_CATS_TEST',
        '/Users/vraghavan/Desktop/ClassificationTest/dogs-vs-cats/test1',
    )
    # Where train_and_evaluate writes, and the CLI reads, the model state_dict.
    model_save_path = 'mixup_cnn_torch.pth'
    epochs = 10
    learning_rate = 0.001
    batch_size = 32
    # A file whose name contains positive_label gets label 1; every other file gets 0.
    positive_label = "dog"
    negative_label = "cat"
    # Training-time augmentation; output is a 3 x 64 x 64 tensor, the size ComplexCNN expects.
    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; (v - 0.5) / 0.5 maps them to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Deterministic preprocessing for inference: resize only, same normalisation.
    test_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

# MixUp augmentation
def mixup_data(x, y, alpha=1.0, num_classes=2):
    """Blend each sample with a randomly chosen partner from the same batch (MixUp).

    Args:
        x: batch of inputs; the first dimension is the batch dimension.
        y: either 1-D integer class labels of length B (as yielded by the
            DataLoader) or an already soft/one-hot target of shape (B, C).
        alpha: Beta(alpha, alpha) concentration for the mixing weight. A value
            <= 0 disables mixing (weight 1.0) instead of raising.
        num_classes: number of classes used to one-hot encode 1-D labels;
            defaults to 2 to match ComplexCNN's two output logits.

    Returns:
        (mixed_x, mixed_y), where mixed_y is a float (B, C) tensor of class
        probabilities. That is a valid target for ``nn.CrossEntropyLoss``,
        unlike a 1-D blend of integer labels.
    """
    # Soft targets are required: averaging integer class ids is meaningless for CE.
    if y.dim() == 1:
        y = nn.functional.one_hot(y.long(), num_classes).to(x.dtype)
    else:
        y = y.to(x.dtype)

    lam = random.betavariate(alpha, alpha) if alpha > 0 else 1.0
    # Build the permutation on the inputs' own device rather than relying on a global.
    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]
    return mixed_x, mixed_y

# CutMix augmentation
def cutmix_data(x, y, alpha=1.0, num_classes=2):
    """Paste a random rectangle from a partner image into each sample (CutMix).

    Args:
        x: batch of images shaped (B, C, H, W).
        y: 1-D integer labels of length B, or a soft/one-hot (B, C) target.
        alpha: Beta(alpha, alpha) concentration; a value <= 0 disables cutting.
        num_classes: classes used to one-hot encode 1-D labels (defaults to 2 to
            match ComplexCNN's two logits).

    Returns:
        (mixed_x, mixed_y). mixed_x is a new tensor, so the caller's ``x`` is not
        modified. mixed_y is a float (B, C) soft target whose weights equal the
        fraction of pixels actually taken from each source image. It can be
        passed directly to ``nn.CrossEntropyLoss``.
    """
    if y.dim() == 1:
        y = nn.functional.one_hot(y.long(), num_classes).to(x.dtype)
    else:
        y = y.to(x.dtype)

    index = torch.randperm(x.size(0), device=x.device)
    lam = random.betavariate(alpha, alpha) if alpha > 0 else 1.0

    # Box side lengths scale with sqrt(1 - lam) so its area is ~(1 - lam) of the image.
    height, width = x.size(2), x.size(3)
    cut_scale = (1.0 - lam) ** 0.5
    cut_h, cut_w = int(height * cut_scale), int(width * cut_scale)
    center_row, center_col = random.randrange(height), random.randrange(width)
    top, bottom = max(center_row - cut_h // 2, 0), min(center_row + cut_h // 2, height)
    left, right = max(center_col - cut_w // 2, 0), min(center_col + cut_w // 2, width)

    # Same region in source and destination, so shapes always match (even for an empty box).
    mixed_x = x.clone()
    mixed_x[:, :, top:bottom, left:right] = x[index, :, top:bottom, left:right]

    # Re-derive lam from the clipped box so label weights match the pasted area.
    lam = 1.0 - ((bottom - top) * (right - left)) / (height * width)
    mixed_y = lam * y + (1.0 - lam) * y[index]
    return mixed_x, mixed_y

# CNN Architecture
class ComplexCNN(nn.Module):
    """Three-stage convolutional classifier that outputs 2 logits per image.

    Each stage is conv -> LeakyReLU(0.2) -> 2x2 max-pool. ``fc1`` expects a
    128 x 6 x 6 feature map, which a 3 x 64 x 64 input produces
    (64 -> 32 -> 30 -> 15 -> 13 -> 6). Inputs that do not reduce to 6 x 6 will
    fail at ``fc1``. Dropout(0.5) is applied to the flattened features before ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict keys, so they must stay stable for
        # previously saved checkpoints to load.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # padding=1 keeps the spatial size
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(128 * 6 * 6, 512)
        self.fc2 = nn.Linear(512, 2)  # index 0 = cat (label 0), index 1 = dog (label 1)
        self.dropout = nn.Dropout(0.5)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        for conv_layer in (self.conv1, self.conv2, self.conv3):
            x = self.max_pool(self.leaky_relu(conv_layer(x)))
        flat = self.dropout(x.view(x.size(0), -1))
        hidden = self.leaky_relu(self.fc1(flat))
        return self.fc2(hidden)

# Dataset Handler
class ImageDataset(Dataset):
    """Flat directory of JPEG images with binary labels inferred from file names.

    A file is labelled 1 when ``Config.positive_label`` occurs in its name and 0
    otherwise (with the defaults, ``dog.1.jpg`` -> 1 and ``cat.1.jpg`` -> 0).

    Args:
        data_root: directory containing the images; it is not searched recursively.
        transform: optional callable applied to each RGB ``PIL.Image``.
    """

    # Matched case-insensitively, so ``.JPG`` / ``.jpeg`` files are not silently skipped.
    _EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, data_root, transform=None):
        # Sorted because os.listdir order is arbitrary; this keeps indices reproducible.
        self.image_files = sorted(
            name for name in os.listdir(data_root)
            if name.lower().endswith(self._EXTENSIONS)
        )
        self.data_root = data_root
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        image_path = os.path.join(self.data_root, file_name)
        # convert('RGB') forces 3 channels: grayscale or CMYK JPEGs would otherwise
        # break the 3-channel Normalize and conv1. The with-block closes the file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = 1 if Config.positive_label in file_name else 0
        return image, label

# Visualizations
def visualize_loss_and_accuracy(losses, accuracies):
    """Show per-epoch training loss and accuracy as two side-by-side line plots.

    Args:
        losses: per-epoch loss values.
        accuracies: per-epoch accuracy values, plotted against the same x-axis
            as ``losses`` (epochs 1 .. len(losses)).
    """
    epoch_axis = range(1, len(losses) + 1)
    panels = (
        (losses, 'Loss over epochs', 'Loss'),
        (accuracies, 'Accuracy over epochs', 'Accuracy'),
    )
    plt.figure(figsize=(12, 5))
    for position, (series, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, series, '-o')
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
    plt.show()

# Training routine
def _make_optimizer(optimizer_type, parameters):
    """Return the optimizer named by ``optimizer_type`` ('rmsprop', 'sgd'; anything else -> Adam)."""
    if optimizer_type == 'rmsprop':
        return optim.RMSprop(parameters, lr=Config.learning_rate)
    if optimizer_type == 'sgd':
        return optim.SGD(parameters, lr=Config.learning_rate, momentum=0.9)
    return optim.Adam(parameters, lr=Config.learning_rate)


def train_and_evaluate(optimizer_type, augmentation_type):
    """Train ComplexCNN on ``Config.train_data_path``, save its weights and plot curves.

    Args:
        optimizer_type: 'rmsprop', 'sgd', or anything else for Adam.
        augmentation_type: 'mixup' or 'cutmix'. Any other value trains on the
            un-augmented batches.

    Raises:
        ValueError: if no training images are found.

    Side effects:
        Writes the trained state_dict to ``Config.model_save_path`` and shows the
        loss/accuracy figure.
    """
    train_dataset = ImageDataset(Config.train_data_path, transform=Config.data_transform)
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found in {Config.train_data_path!r}")
    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
    # NOTE(review): Config.test_data_path was previously wrapped in a DataLoader that was
    # never used. It is not loaded here; add a held-out evaluation pass if needed.

    model = ComplexCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = _make_optimizer(optimizer_type, model.parameters())
    # `verbose=` is deprecated on LR schedulers; LR drops are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.1)

    train_losses, train_accuracies = [], []
    for epoch in range(Config.epochs):
        model.train()
        running_loss, correct_predictions, samples_seen = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            if augmentation_type == 'mixup':
                images, targets = mixup_data(images, labels)
            elif augmentation_type == 'cutmix':
                images, targets = cutmix_data(images, labels)
            else:
                targets = labels

            optimizer.zero_grad()
            outputs = model(images)
            # CrossEntropyLoss accepts class indices (B,) or class probabilities (B, C).
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Soft targets are scored against their dominant class; hard labels as-is.
            hard_targets = targets.argmax(dim=1) if targets.dim() > 1 else targets
            predicted = outputs.detach().argmax(dim=1)
            correct_predictions += (predicted == hard_targets).sum().item()
            samples_seen += images.size(0)

        average_loss = running_loss / len(train_loader)
        accuracy = correct_predictions / samples_seen
        train_losses.append(average_loss)
        train_accuracies.append(accuracy)
        print(f"Epoch [{epoch+1}/{Config.epochs}], Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}")

        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(average_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f"Reducing learning rate to {current_lr:.2e}")

    torch.save(model.state_dict(), Config.model_save_path)
    visualize_loss_and_accuracy(train_losses, train_accuracies)

# Test a single image
def test_single_image(image_path, model):
    """Classify one image with ``model`` and display it titled with the prediction.

    Args:
        image_path: path to an image file readable by PIL.
        model: network returning 2 logits per image. Index 0 is the label-0 class
            (cat) and index 1 is the label-1 class (dog), matching ImageDataset's
            file-name labelling.
    """
    model.eval()
    # Force 3 channels to match Normalize/conv1, and close the file once decoded.
    with Image.open(image_path) as img:
        pil_image = img.convert('RGB')

    with torch.no_grad():
        input_image = Config.test_transform(pil_image).unsqueeze(0)  # add batch dimension
        outputs = model(input_image.to(device))
        predicted = outputs.argmax(dim=1)

    # Order must match the training labels: files containing Config.positive_label
    # ("dog") are label 1, all others ("cat") label 0. The previous ['Dog', 'Cat']
    # ordering reported every prediction as the wrong animal.
    class_labels = ['Cat', 'Dog']
    predicted_label = class_labels[predicted.item()]

    plt.imshow(pil_image)
    plt.title(f'Predicted: {predicted_label}')
    plt.axis('off')
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train the model and/or Predict a given image.')
    parser.add_argument('--image_path', type=str, help='Path to the input image. If provided, model prediction will be executed.')
    parser.add_argument('--optimizer', type=str, choices=['adam', 'rmsprop', 'sgd'], default='adam', help='Choose optimizer (adam, rmsprop, sgd)')
    parser.add_argument('--augmentation', type=str, choices=['mixup', 'cutmix'], default='mixup', help='Choose augmentation technique (mixup, cutmix)')
    # parse_known_args: Jupyter/Colab kernels inject their own argv (e.g. `-f kernel.json`),
    # which would make parse_args() exit with "unrecognized arguments".
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Ignoring unrecognized arguments: {unknown_args}")

    if not args.image_path:
        train_and_evaluate(args.optimizer, args.augmentation)
    else:
        model = ComplexCNN().to(device)
        # map_location lets weights saved from a GPU run load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust
        # (or pass weights_only=True on PyTorch >= 1.13).
        model.load_state_dict(torch.load(Config.model_save_path, map_location=device))
        model.eval()
        test_single_image(args.image_path, model)
