In [1]:
# Importing modules
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from sklearn.metrics import classification_report

In [2]:
# Data dir
train_dir = "./train"
val_dir = "./val"

# Transformation
transform = transforms.Compose([transforms.Resize(64),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

# Datasets
train_dataset = datasets.ImageFolder(train_dir, transform)
val_dataset = datasets.ImageFolder(val_dir, transform)

# Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

In [3]:
# SimpleCNNNet model
class SimpleCNN(nn.Module):
    
    def __init__(self, n_classes=len(train_dataset.classes)):
        super(SimpleCNN, self).__init__()
        
        # Feature learninig layer
        self.features = nn.Sequential(
            # First convolution layer
            nn.Conv2d(3, 16, kernel_size=6, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Second convolution layer
            nn.Conv2d(16, 64, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        
        # Average pooling layer
        self.avgpool = nn.AdaptiveAvgPool2d(6)
        
        # Classification layer
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(64*6*6, 64),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, n_classes))
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
# Model
model = SimpleCNN()
model

SimpleCNN(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(6, 6), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 64, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=6)
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=2304, out_features=64, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=64, out_features=16, bias=True)
    (5): ReLU()
    (6): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [5]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Device
device = torch.device("cpu")

# Loading model to device
model = model.to(device)

In [6]:
# Training
for epoch in range(1, 11):
    
    # Training loop
    model.train()
    for batch in train_dataloader:
        optimizer.zero_grad()
        X, y = batch
        X, y = X.to(device), y.to(device)
        y_hat = model(X)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        
    # Validation loop
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            X, y = batch
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            loss = criterion(y_hat, y)
            val_loss += loss.item()
    
    # Logging
    val_loss = val_loss/len(val_dataloader)
    print(f"Epoch: {epoch} | Validation loss: {val_loss}")

Epoch: 1 | Validation loss: 25.51868023591883
Epoch: 2 | Validation loss: 1.2126766592264175
Epoch: 3 | Validation loss: 0.6985159341026755
Epoch: 4 | Validation loss: 0.6962896199787364
Epoch: 5 | Validation loss: 0.6953094531508053
Epoch: 6 | Validation loss: 0.6948356733602636
Epoch: 7 | Validation loss: 0.6945925565326915
Epoch: 8 | Validation loss: 0.6944624921854805
Epoch: 9 | Validation loss: 0.694390637033126
Epoch: 10 | Validation loss: 0.6943500252331004
