In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from model_v2 import CardClassifier

train_path = "/content/drive/MyDrive/playing_cards/train" #/content/drive/MyDrive/playing_cards/train, ../../playing_cards/train
valid_path = "/content/drive/MyDrive/playing_cards/valid" #/content/drive/MyDrive/playing_cards/valid, ../../playing_cards/valid

In [3]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
#set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [7]:
train_dataset = datasets.ImageFolder(train_path, transform=transform)
valid_dataset = datasets.ImageFolder(valid_path, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)



In [8]:
model = CardClassifier().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [9]:
num_epochs = 100  # Set the number of training epochs
early_stopper = EarlyStopper(patience=3, min_delta=10)  # Set early stopping configuration

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        # Move inputs and labels to GPU
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Calculate average training loss for the epoch
    average_train_loss = running_loss / len(train_loader)

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            # Move inputs and labels to GPU
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Calculate average validation loss and accuracy for the epoch
        average_val_loss = val_loss / len(valid_loader)
        accuracy = correct / total

        print(f"Epoch {epoch+1}/{num_epochs} => "
              f"Train Loss: {average_train_loss:.4f}, "
              f"Validation Loss: {average_val_loss:.4f}, "
              f"Validation Accuracy: {accuracy * 100:.2f}%")

        # Check for early stopping
        if early_stopper.early_stop(average_val_loss):
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break


Epoch 1/100 => Train Loss: 3.0056, Validation Loss: 1.7711, Validation Accuracy: 42.64%
Epoch 2/100 => Train Loss: 1.6750, Validation Loss: 1.0617, Validation Accuracy: 68.30%
Epoch 3/100 => Train Loss: 0.8561, Validation Loss: 1.0411, Validation Accuracy: 75.09%
Epoch 4/100 => Train Loss: 0.4019, Validation Loss: 1.1850, Validation Accuracy: 79.25%
Epoch 5/100 => Train Loss: 0.1916, Validation Loss: 1.1290, Validation Accuracy: 80.75%
Epoch 6/100 => Train Loss: 0.1266, Validation Loss: 1.3260, Validation Accuracy: 81.13%
Epoch 7/100 => Train Loss: 0.0909, Validation Loss: 1.4696, Validation Accuracy: 80.00%
Epoch 8/100 => Train Loss: 0.0661, Validation Loss: 1.2836, Validation Accuracy: 82.26%
Epoch 9/100 => Train Loss: 0.0399, Validation Loss: 1.4530, Validation Accuracy: 80.75%
Epoch 10/100 => Train Loss: 0.0502, Validation Loss: 1.5176, Validation Accuracy: 83.77%
Epoch 11/100 => Train Loss: 0.0593, Validation Loss: 1.2361, Validation Accuracy: 81.51%
Epoch 12/100 => Train Loss: 0.

In [10]:
torch.save(model.state_dict(), "card_classifier.pth")