<a href="https://colab.research.google.com/github/vijaygwu/classideas/blob/main/SimpleCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import argparse

# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configuration class to hold hyperparameters and paths
class Config:
    """Static hyperparameters, file locations and preprocessing for the run."""

    # --- Data and checkpoint locations (edit before running) ---
    train_data_path = 'path/to/train/data'  # Specify your training data path
    test_data_path = 'path/to/test/data'    # Specify your test data path
    model_save_path = 'complex_cnn_torch.pth'

    # --- Optimisation hyperparameters ---
    epochs = 10
    learning_rate = 1e-3
    batch_size = 32

    # --- Class names for the binary task ---
    positive_label = "dog"
    negative_label = "cat"

    # Resize to a fixed 64x64, convert to a float tensor (8-bit pixels are
    # scaled to [0, 1]), then normalise every channel with mean 0.5 and
    # std 0.5, which maps values into [-1, 1].
    data_transform = transforms.Compose([
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ])

# CNN Architecture class
class ComplexCNN(nn.Module):
    """Small convolutional classifier for RGB images.

    Three conv -> batch-norm -> ReLU -> 2x2 max-pool stages are followed by
    global average pooling and a two-layer classifier head. Global pooling
    makes the head independent of the input resolution, so any spatial size
    of at least 8x8 is accepted (the training transform resizes to 64x64).

    Args:
        num_classes: Number of output logits. Defaults to 2 (negative /
            positive class), matching ``nn.CrossEntropyLoss`` training.
        in_channels: Number of input image channels (3 for RGB).
    """

    def __init__(self, num_classes=2, in_channels=3):
        super(ComplexCNN, self).__init__()
        self.features = nn.Sequential(
            self._conv_block(in_channels, 32),
            self._conv_block(32, 64),
            self._conv_block(64, 128),
        )
        # Collapse whatever spatial size remains to 1x1 per channel.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Build a 3x3 conv + batch-norm + ReLU stage that halves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Map a (batch, in_channels, H, W) tensor to (batch, num_classes) raw logits."""
        x = self.features(x)
        x = self.pool(x)
        return self.classifier(x)

# Dataset Handler class
class ImageDataset(Dataset):
    """Binary image-classification dataset built from a directory tree.

    Every image file under ``data_root`` (searched recursively) is labelled
    from its path. The innermost directory name that matches a label wins;
    otherwise the leading filename token (text before the first ``.``,
    ``_`` or ``-``) is used. The negative label maps to class 0 and the
    positive label to class 1. This supports both ``root/cat/x.jpg`` folder
    layouts and flat ``root/cat.123.jpg`` naming. Files matching neither
    label are skipped. Matching is case-insensitive.

    Args:
        data_root: Directory to scan.
        transform: Optional callable applied to each RGB PIL image.
        positive_label: Label text mapped to class 1; defaults to
            ``Config.positive_label``.
        negative_label: Label text mapped to class 0; defaults to
            ``Config.negative_label``.

    Raises:
        FileNotFoundError: If ``data_root`` is not a directory.
        ValueError: If no labelled image files are found.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, data_root, transform=None, positive_label=None, negative_label=None):
        if positive_label is None:
            positive_label = Config.positive_label
        if negative_label is None:
            negative_label = Config.negative_label
        if not os.path.isdir(data_root):
            raise FileNotFoundError(f"Data directory not found: {data_root}")

        self.data_root = data_root
        self.transform = transform
        label_map = {negative_label.lower(): 0, positive_label.lower(): 1}

        self.image_files = []
        self.labels = []
        for dirpath, dirnames, filenames in os.walk(data_root):
            dirnames.sort()  # sort in place so os.walk visits folders deterministically
            for filename in sorted(filenames):
                if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                path = os.path.join(dirpath, filename)
                label = self._label_for(os.path.relpath(path, data_root), label_map)
                if label is not None:
                    self.image_files.append(path)
                    self.labels.append(label)

        if not self.image_files:
            raise ValueError(
                f"No images labelled '{negative_label}' or '{positive_label}' "
                f"found under {data_root}"
            )

    @staticmethod
    def _label_for(rel_path, label_map):
        """Return the class index encoded in *rel_path*, or None if it has none."""
        parts = rel_path.split(os.sep)
        # Prefer the innermost directory (e.g. root/train/dog/x.jpg -> "dog").
        for part in reversed(parts[:-1]):
            label = label_map.get(part.lower())
            if label is not None:
                return label
        # Fall back to the leading filename token (e.g. "cat.123.jpg" -> "cat").
        token = parts[-1].lower()
        for sep in '._-':
            token = token.split(sep, 1)[0]
        return label_map.get(token)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*, with the transform applied."""
        path = self.image_files[idx]
        # convert('RGB') forces exactly 3 channels (drops alpha, expands
        # grayscale) so a 3-channel Normalize works for every file. The
        # context manager closes the file handle; convert() returns an
        # already-loaded copy, so closing is safe.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Plotting function for metrics
def plot_metrics(losses, accuracies, metric_name='Accuracy'):
    """Plot per-epoch loss and a second metric in two side-by-side panels.

    Args:
        losses: Sequence of loss values, one per epoch (left panel).
        accuracies: Sequence of metric values, one per epoch (right panel).
        metric_name: Name used for the right panel's title and y-axis label.
    """
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(losses) + 1), losses, '-o')
    plt.title('Loss over epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')

    # Use its own x-range so the two series need not have identical lengths.
    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(accuracies) + 1), accuracies, '-o')
    plt.title(f'{metric_name} over epochs')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)

    plt.tight_layout()
    plt.show()

# Training and evaluation routine
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over *loader*: train when *optimizer* is given, else evaluate.

    Returns:
        ``(average_loss, accuracy)``, both weighted per sample, or
        ``(nan, nan)`` if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, seen = 0.0, 0, 0
    # Gradients are only needed when training; disabling them otherwise saves memory.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # criterion returns the batch mean; re-weight so a short final
            # batch does not skew the epoch average.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen


def train_and_evaluate():
    """Train ``ComplexCNN`` and evaluate it on the test set after every epoch.

    Uses the paths and hyperparameters in ``Config``. Saves the final
    ``state_dict`` to ``Config.model_save_path`` and plots the training
    loss/accuracy curves.

    Returns:
        dict with per-epoch lists under ``'train_loss'``, ``'train_acc'``,
        ``'test_loss'`` and ``'test_acc'``.
    """
    train_loader = DataLoader(
        ImageDataset(Config.train_data_path, transform=Config.data_transform),
        batch_size=Config.batch_size, shuffle=True
    )
    test_loader = DataLoader(
        ImageDataset(Config.test_data_path, transform=Config.data_transform),
        batch_size=Config.batch_size, shuffle=False
    )

    model = ComplexCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    for epoch in range(Config.epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        # Held-out evaluation: reports generalisation, not just fit to training data.
        test_loss, test_acc = _run_epoch(model, test_loader, criterion)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        print(
            f"Epoch [{epoch+1}/{Config.epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
        )

    torch.save(model.state_dict(), Config.model_save_path)
    plot_metrics(history['train_loss'], history['train_acc'], 'Accuracy')
    return history

# Visualizations
def visualize_loss_and_accuracy(losses, accuracies):
    """Plot per-epoch loss and accuracy side by side.

    Thin convenience wrapper around ``plot_metrics`` with the right panel
    labelled 'Accuracy'. It delegates instead of duplicating the plotting
    code, so the two stay consistent.

    Args:
        losses: Sequence of loss values, one per epoch.
        accuracies: Sequence of accuracy values, one per epoch.
    """
    plot_metrics(losses, accuracies, 'Accuracy')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train the model and/or Predict a given image.')
    parser.add_argument('--image_path', type=str, help='Path to the input image. If provided, model prediction will be executed.')
    # parse_known_args tolerates extra argv entries such as the
    # "-f <kernel.json>" injected by Jupyter/Colab. With those entries present,
    # parse_args() would exit with "unrecognized arguments".
    args, _unknown = parser.parse_known_args()

    if not args.image_path:
        # Training routine
        train_and_evaluate()
    else:
        # Prediction routine: load the trained weights.
        model = ComplexCNN().to(device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(Config.model_save_path, map_location=device))
        model.eval()

        # Force 3-channel RGB (RGBA/grayscale inputs would break the 3-channel
        # Normalize), and close the file handle once the pixels are loaded.
        with Image.open(args.image_path) as img:
            pil_image = img.convert('RGB')
        # Reuse the training transform so inference preprocessing cannot drift.
        input_image = Config.data_transform(pil_image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(input_image)
            predicted = outputs.argmax(dim=1).item()

        # Assumes class index 0 = negative label and 1 = positive label
        # (with the default Config: 'Cat', 'Dog').
        class_labels = [Config.negative_label.capitalize(), Config.positive_label.capitalize()]
        predicted_label = class_labels[predicted]

        # Display the image and prediction
        plt.imshow(pil_image)
        plt.title(f'Predicted: {predicted_label}')
        plt.axis('off')
        plt.show()
