In [8]:
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.functional import F
from PIL import UnidentifiedImageError, Image
from torchvision.datasets import ImageFolder
import os
import warnings 

os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
warnings.filterwarnings("ignore", category=UserWarning)


In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available and enabled.")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and enabled.")
else:
    device = torch.device("cpu")
    print("CUDA and MPS are not available. Using CPU.")


MPS is available and enabled.


In [10]:
dataset_path = "hw02_test_train"
batch_size = 16
num_workers = 4

transform = transforms.Compose([transforms.Resize((24, 24)), transforms.ToTensor()])

train_dataset = datasets.ImageFolder(root=dataset_path + "/train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_dataset = datasets.ImageFolder(root=dataset_path + "/test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

def is_valid_image(path):
    try:
        im = Image.open(path)
        im.verify()
        return True
    except (IOError, SyntaxError, UnidentifiedImageError):
        return False

# Filter out invalid images from the datasets
train_dataset.samples = [(path, class_idx) for path, class_idx in train_dataset.samples if is_valid_image(path)]
test_dataset.samples = [(path, class_idx) for path, class_idx in test_dataset.samples if is_valid_image(path)]

# Update the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [11]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 5)  # 5 categories

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [12]:
num_epochs = 10
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")



Epoch 1/10, Loss: 1.6117841078310597




Epoch 2/10, Loss: 1.547875426253494




Epoch 3/10, Loss: 1.39657859169707




Epoch 4/10, Loss: 1.261272623830912




Epoch 5/10, Loss: 1.117411759434914




Epoch 6/10, Loss: 1.0692467154288778




Epoch 7/10, Loss: 0.9934233560854074




Epoch 8/10, Loss: 0.936789254752957




Epoch 9/10, Loss: 0.8331286232082211




Epoch 10/10, Loss: 0.7588192181927818


In [13]:
correct = 0
total = 0

model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient calculation
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the test dataset: {accuracy:.2f}%')



Accuracy on the test dataset: 60.85%
