In [4]:
import torch

In [2]:
import torch.nn as nn

import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

Set device to GPU if available

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using:", device)
print(torch.version.cuda)

Using: cuda
12.6


In [9]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])
train_data=datasets.ImageFolder('data/captured_images/',transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)



In [10]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # [B, 32, 112, 112]
        x = self.pool(F.relu(self.conv2(x)))  # [B, 64, 56, 56]
        x = x.view(-1, 64 * 56 * 56)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleCNN(num_classes=len(train_data.classes)).to(device)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

Epoch 1, Loss: 6.9276
Epoch 2, Loss: 4.1575
Epoch 3, Loss: 4.1518
Epoch 4, Loss: 4.0878
Epoch 5, Loss: 3.9891
Epoch 6, Loss: 3.9271
Epoch 7, Loss: 3.8194
Epoch 8, Loss: 3.7327
Epoch 9, Loss: 3.5972
Epoch 10, Loss: 3.4839
