In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import numpy as np

class CustomFashionMNIST(Dataset):
    """FashionMNIST split exposed as float32 images scaled to [0, 1].

    The torchvision dataset is loaded once and its whole pixel array is
    converted up front, so ``__getitem__`` only wraps one image in a tensor.

    Args:
        root: Directory passed to ``datasets.FashionMNIST`` as its root.
        train: Load the training split if True, otherwise the test split.
        transform: Optional callable applied to each image tensor.
        download: Forwarded to ``datasets.FashionMNIST`` (default True,
            matching the previous always-download behaviour).
    """

    def __init__(self, root='./data', train=True, transform=None, download=True):
        self.original_dataset = datasets.FashionMNIST(
            root=root,
            train=train,
            download=download,
        )

        # Convert once: raw pixel values -> float32 scaled by 1/255.
        self.data = self.original_dataset.data.numpy().astype(np.float32) / 255.0
        self.targets = self.original_dataset.targets.numpy()
        self.transform = transform

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a float32 tensor with a leading channel axis added
        (an ``(H, W)`` image becomes ``(1, H, W)``), optionally passed through
        ``transform``; ``label`` is the class index as a Python ``int``.
        """
        # torch.tensor always copies, so an in-place transform can never write
        # back into the cached self.data array.
        image = torch.tensor(self.data[idx], dtype=torch.float32).unsqueeze(0)

        if self.transform is not None:
            image = self.transform(image)

        return image, int(self.targets[idx])

class FashionNet(nn.Module):
    """Three-conv-block CNN classifier for single-channel 28x28 images.

    Each 3x3 conv uses padding=1 (spatial size preserved) and the first two
    blocks are followed by 2x2 max-pooling, so a 28x28 input reaches the
    classifier head as 64 feature maps of 7x7 -- hence ``fc1``'s in-features.

    Args:
        num_classes: Number of output logits (default 10, as before).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map an ``(N, 1, 28, 28)`` batch to ``(N, num_classes)`` logits."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # (N, 32, 14, 14)
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # (N, 64, 7, 7)
        x = F.relu(self.bn3(self.conv3(x)))             # (N, 64, 7, 7)
        # Flatten per sample, keeping the batch dimension: an input of the
        # wrong spatial size now fails loudly in fc1 instead of being
        # silently regrouped into a different number of "samples".
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

def save_best_model(model, optimizer, epoch, accuracy, path='model.pt'):
    """Write a training checkpoint to ``path`` and report it.

    The caller decides when a model is "best"; this function always
    overwrites ``path`` with a dict holding ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``accuracy``.

    The checkpoint is first written to ``<path>.tmp`` and then moved into
    place with ``os.replace``, so an interruption mid-write cannot leave a
    truncated file where the previous checkpoint used to be.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Epoch index stored with the weights.
        accuracy: Accuracy value stored and printed (formatted as a percent).
        path: Destination file (default ``'model.pt'``, as before).
    """
    import os

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': accuracy
    }
    path = os.fspath(path)
    tmp_path = path + '.tmp'
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Best model saved to {path} with accuracy: {accuracy:.2f}%")

# Training configuration (tune here).
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Guard against an empty loader instead of dividing by zero.
    return running_loss / max(len(loader), 1)


def _evaluate(model, loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total if total else 0.0


if __name__ == "__main__":
    # Seed torch's global RNG: it drives weight init, dropout masks and the
    # shuffle order of the training DataLoader (GPU kernels may still add
    # small run-to-run nondeterminism).
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create datasets and loaders
    train_dataset = CustomFashionMNIST(train=True)
    test_dataset = CustomFashionMNIST(train=False)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Create model
    model = FashionNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    num_epochs = NUM_EPOCHS
    best_accuracy = 0.0

    print("Starting training...")

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        accuracy = _evaluate(model, test_loader, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.2f}%')

        # NOTE(review): the "best" checkpoint is selected on the test split,
        # which makes the reported best accuracy optimistically biased;
        # consider holding out a validation split for model selection.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            save_best_model(model, optimizer, epoch, accuracy)

    print("\ncompleted!")
    print(f"Best accuracy achieved: {best_accuracy:.2f}%")


Using device: cuda
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:01<00:00, 13.3MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 210kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 3.93MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 6.19MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Starting training...
Epoch [1/20], Average Loss: 0.3838
Epoch [1/20], Accuracy: 88.58%
Best model saved to model.pt with accuracy: 88.58%
Epoch [2/20], Average Loss: 0.2661
Epoch [2/20], Accuracy: 90.47%
Best model saved to model.pt with accuracy: 90.47%
Epoch [3/20], Average Loss: 0.2264
Epoch [3/20], Accuracy: 90.11%
Epoch [4/20], Average Loss: 0.1946
Epoch [4/20], Accuracy: 92.17%
Best model saved to model.pt with accuracy: 92.17%
Epoch [5/20], Average Loss: 0.1737
Epoch [5/20], Accuracy: 91.80%
Epoch [6/20], Average Loss: 0.1521
Epoch [6/20], Accuracy: 92.17%
Epoch [7/20], Average Loss: 0.1360
Epoch [7/20], Accuracy: 92.44%
Best model saved to model.pt with accuracy: 92.44%
Epoch [8/20], Average Loss: 0.1184
Epoch [8/20], Accuracy: 92.33%
Epoch [9/20], Average Loss: 0.1058
Epoch [9/20], Accuracy: 92.74%
Best model saved to model.pt with accuracy: 92.74%
Epoch [10/20], Average Loss: 0.0905
Epoch