# Connect to Google Drive

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so MyDrive (dataset and project folder) is reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Mounted at /content/drive


# Navigate to Classification folder

In [None]:
%%bash
cd /content/drive/MyDrive/Classification
pwd

/content/drive/MyDrive/Classification


# Install dependencies

In [None]:
pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.5.2-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.5.2-py3-none-any.whl (891 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m891.4/891.4 kB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.8 torchmetrics-1.5.2


# Train the AlexNet model on the playing-card dataset

In [None]:
"""
ThinhDV  11 Nov 2024
Train an AlexNet-style CNN from scratch (randomly initialised, no pre-trained weights) on a
custom ImageFolder dataset of playing-card images, saving a checkpoint each time the
validation accuracy improves.
"""
import torch
from torch import nn, save, load
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchmetrics.functional import accuracy
from torchvision.transforms import ToTensor, Resize
import numpy as np

from torch import nn

# Image Classifier Neural Network
class Alexnet(nn.Module):
    """
    AlexNet-style CNN (5 conv layers + 1 linear classifier) trained from scratch.

    Spatial size for a 224x224 input, using out = floor((in + 2*pad - kernel) / stride) + 1:
        conv1 11x11, stride 4   : 224 -> 54
        maxpool 2x2, stride 2   :  54 -> 27
        conv2 5x5, pad 2        :  27 -> 27
        maxpool 3x3, stride 2   :  27 -> 13
        conv3-5 3x3, pad 1      :  13 -> 13
        maxpool 3x3, stride 2   :  13 -> 6
    The flattened feature vector therefore has 256 * 6 * 6 = 9216 entries, which is what the
    final Linear layer expects. Input sizes that do not end at 6x6 fail there with a shape
    mismatch.

    :param num_classes: number of output logits.
    :param in_channels: number of channels of the input images (3 for RGB, 1 for grayscale).
    """

    def __init__(self, num_classes=100, in_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            #### Convolutional layers ####
            # Layer 1: in_channels -> 96 channels, 224 -> 54
            nn.Conv2d(in_channels, 96, kernel_size=(11, 11), stride=(4, 4)),
            nn.ReLU(),
            nn.BatchNorm2d(96),
            nn.MaxPool2d(2, 2),  # floor((54 - 2) / 2) + 1 = 27

            # Layer 2: 96 -> 256 channels; padding 2 keeps 27x27
            nn.Conv2d(96, 256, kernel_size=(5, 5), padding=(2, 2)),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(3, 2),  # floor((27 - 3) / 2) + 1 = 13

            # Layer 3: 256 -> 384 channels; padding 1 keeps 13x13
            nn.Conv2d(256, 384, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(384),

            # Layer 4: 384 -> 384 channels, stays 13x13
            nn.Conv2d(384, 384, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(384),

            # Layer 5: 384 -> 256 channels, stays 13x13
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(3, 2),  # floor((13 - 3) / 2) + 1 = 6

            #### Fully-connected classifier ####
            nn.Flatten(),
            nn.Dropout(0.5),
            # 256 channels * 6 * 6: the 6 comes from the size trace in the class docstring
            nn.Linear(256 * 6 * 6, num_classes),
        )
        # Keep the constructor arguments for callers that inspect the model
        self.num_classes = num_classes
        self.in_channels = in_channels

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images ``x``."""
        return self.model(x)


# Setup CUDA
def setup_cuda(seed=50, deterministic=False):
    """
    Seed the random number generators, configure cuDNN and pick the device to train on.

    :param seed: seed applied to Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).
    :param deterministic: if True, request deterministic cuDNN kernels and disable the cuDNN
        autotuner. The default (False) keeps ``cudnn.benchmark`` on, which is faster but means
        GPU runs are not bit-for-bit reproducible even with fixed seeds.
    :return: ``torch.device('cuda:0')`` when CUDA is available, otherwise ``torch.device('cpu')``.
    """
    import random

    torch.backends.cudnn.enabled = True
    # The benchmark autotuner may select different (non-deterministic) kernels between runs.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and every CUDA device, so no separate
    # torch.cuda.manual_seed call is needed.
    torch.manual_seed(seed)

    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def train_model():
    """
    Train the module-level ``model`` for one epoch over ``train_loader``.

    Relies on the globals ``model``, ``train_loader``, ``optimizer``, ``loss_fn``, ``device``
    and ``class_names`` that are created in the ``__main__`` block of this cell.

    :return: (training loss averaged per sample, macro-averaged training accuracy computed over
        all predictions of the epoch). The per-sample average assumes ``loss_fn`` uses mean
        reduction, as the default ``CrossEntropyLoss`` in the main block does.
    :raises ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    epoch_preds = []
    epoch_labels = []

    for img, label in tqdm(train_loader, ncols=80, desc='Training'):
        # Get a batch
        img, label = img.to(device, dtype=torch.float), label.to(device, dtype=torch.long)

        # Standard step: reset gradients, forward, loss, backward, update
        optimizer.zero_grad()
        logits = model(img)
        loss = loss_fn(logits, label)
        loss.backward()
        optimizer.step()

        # The loss is a batch mean; weight it by batch size so a smaller final batch does not
        # count as much as a full one.
        batch_size = label.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

        # Keep the predictions so accuracy is computed once over the whole epoch. Averaging
        # per-batch macro accuracies instead is biased by batch composition.
        epoch_preds.append(logits.argmax(dim=1).detach())
        epoch_labels.append(label)

    if num_samples == 0:
        raise ValueError('train_loader yielded no samples')

    train_acc = accuracy(torch.cat(epoch_preds), torch.cat(epoch_labels), task='multiclass',
                         average='macro', num_classes=len(class_names)).item()
    return total_loss / num_samples, train_acc


def validate_model():
    """
    Evaluate the module-level ``model`` on ``val_loader`` for one epoch without gradients.

    Relies on the globals ``model``, ``val_loader``, ``loss_fn``, ``device`` and ``class_names``
    that are created in the ``__main__`` block of this cell.

    :return: (validation loss averaged per sample, macro-averaged validation accuracy computed
        over all predictions of the split). The per-sample average assumes ``loss_fn`` uses mean
        reduction, as the default ``CrossEntropyLoss`` in the main block does.
    :raises ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    split_preds = []
    split_labels = []

    with torch.no_grad():
        for img, label in tqdm(val_loader, ncols=80, desc='Valid'):
            # Get a batch
            img, label = img.to(device, dtype=torch.float), label.to(device, dtype=torch.long)

            logits = model(img)
            loss = loss_fn(logits, label)

            # Weight the batch-mean loss by batch size to get a per-sample average
            batch_size = label.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            # Accumulate predictions; accuracy is computed once over the whole split so the
            # result does not depend on how samples are grouped into batches.
            split_preds.append(logits.argmax(dim=1))
            split_labels.append(label)

    if num_samples == 0:
        raise ValueError('val_loader yielded no samples')

    val_acc = accuracy(torch.cat(split_preds), torch.cat(split_labels), task='multiclass',
                       average='macro', num_classes=len(class_names)).item()
    return total_loss / num_samples, val_acc


if __name__ == "__main__":
    device = setup_cuda()

    # 1. Load the dataset; images are resized to the 224x224 input the Alexnet head is sized for
    transform = transforms.Compose([Resize((224, 224)), ToTensor()])
    train_dataset = ImageFolder(root='/content/drive/MyDrive/dataset/playcards/train', transform=transform)
    val_dataset = ImageFolder(root='/content/drive/MyDrive/dataset/playcards/valid', transform=transform)
    # Get class names (ImageFolder derives them from the sub-folder names)
    class_names = train_dataset.classes
    # ImageFolder numbers classes separately for each split. A missing or extra folder in one
    # split would silently shift the label indices and make the validation scores meaningless.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f'Train/valid class folders differ: {train_dataset.classes} vs {val_dataset.classes}'
        )

    # 2. Create data loaders. Validation does not need shuffling; a fixed order keeps the
    #    per-epoch numbers comparable.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # 3. Create a new model with randomly initialised weights (trained from scratch, not pre-trained)
    model = Alexnet(
        num_classes=len(class_names),
    ).to(device)

    # 4. Specify loss function and optimizer
    optimizer = Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    # 5. Train the model for 100 epochs
    max_acc = 0.0
    for epoch in range(100):

        # 5.1. Train the model over a single epoch
        train_loss, train_acc = train_model()

        # 5.2. Validate the model after training
        val_loss, val_acc = validate_model()

        # Report training metrics next to validation ones so over-fitting is visible
        print(f'Epoch {epoch}: Train loss = {train_loss}, Train accuracy: {train_acc}, '
              f'Validation loss = {val_loss}, Validation accuracy: {val_acc}')

        # 5.3. Save a checkpoint whenever the validation accuracy improves. The path is
        #      relative, so the file is written to the kernel's current working directory.
        if val_acc > max_acc:
            print(f'Validation accuracy increased ({max_acc} --> {val_acc}). Model saved')
            max_acc = val_acc
            torch.save(model.state_dict(), f'epoch_{epoch}_acc_{max_acc:.4f}.pt')



Training: 100%|█████████████████████████████████| 19/19 [05:53<00:00, 18.62s/it]
Valid: 100%|██████████████████████████████████████| 1/1 [00:14<00:00, 14.22s/it]


Epoch 0: Validation loss = 1.364609956741333, Validation accuracy: 0.3500000238418579
Validation accuracy increased (0 --> 0.3500000238418579). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:02<00:00,  6.34it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.85it/s]


Epoch 1: Validation loss = 1.1601483821868896, Validation accuracy: 0.45000001788139343
Validation accuracy increased (0.3500000238418579 --> 0.45000001788139343). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.89it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.81it/s]


Epoch 2: Validation loss = 0.9250661730766296, Validation accuracy: 0.6499999761581421
Validation accuracy increased (0.45000001788139343 --> 0.6499999761581421). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.48it/s]


Epoch 3: Validation loss = 0.8232851028442383, Validation accuracy: 0.6500000357627869
Validation accuracy increased (0.6499999761581421 --> 0.6500000357627869). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.15it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.44it/s]


Epoch 4: Validation loss = 0.9313212633132935, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.60it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.99it/s]


Epoch 5: Validation loss = 1.0416505336761475, Validation accuracy: 0.6000000238418579


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.83it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.84it/s]


Epoch 6: Validation loss = 0.9515305757522583, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.65it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.05it/s]


Epoch 7: Validation loss = 1.0055161714553833, Validation accuracy: 0.7000000476837158
Validation accuracy increased (0.6500000357627869 --> 0.7000000476837158). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.41it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 12.04it/s]


Epoch 8: Validation loss = 0.9059051275253296, Validation accuracy: 0.75
Validation accuracy increased (0.7000000476837158 --> 0.75). Model saved


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.24it/s]


Epoch 9: Validation loss = 1.1284377574920654, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.78it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.69it/s]


Epoch 10: Validation loss = 0.9008943438529968, Validation accuracy: 0.75


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.54it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.33it/s]


Epoch 11: Validation loss = 1.0997097492218018, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.53it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.90it/s]


Epoch 12: Validation loss = 0.8681515455245972, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.91it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.13it/s]


Epoch 13: Validation loss = 1.0704783201217651, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.85it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.29it/s]


Epoch 14: Validation loss = 1.1236168146133423, Validation accuracy: 0.75


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.58it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.10it/s]


Epoch 15: Validation loss = 1.1931809186935425, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.48it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.41it/s]


Epoch 16: Validation loss = 0.8250978589057922, Validation accuracy: 0.75


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.70it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.97it/s]


Epoch 17: Validation loss = 1.0993773937225342, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.69it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.45it/s]


Epoch 18: Validation loss = 1.0886213779449463, Validation accuracy: 0.75


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.40it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.70it/s]


Epoch 19: Validation loss = 1.037945032119751, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.58it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.01it/s]


Epoch 20: Validation loss = 1.0776081085205078, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.84it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.27it/s]


Epoch 21: Validation loss = 1.1368563175201416, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.89it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.66it/s]


Epoch 22: Validation loss = 1.1160423755645752, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.59it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.77it/s]


Epoch 23: Validation loss = 1.1460307836532593, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.61it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.35it/s]


Epoch 24: Validation loss = 1.0885813236236572, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.86it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.41it/s]


Epoch 25: Validation loss = 1.086698055267334, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.79it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.51it/s]


Epoch 26: Validation loss = 1.0848580598831177, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.65it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.49it/s]


Epoch 27: Validation loss = 1.0872857570648193, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.58it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.24it/s]


Epoch 28: Validation loss = 1.0994725227355957, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.83it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.59it/s]


Epoch 29: Validation loss = 1.1318150758743286, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.00it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.03it/s]


Epoch 30: Validation loss = 1.0926387310028076, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.56it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.92it/s]


Epoch 31: Validation loss = 1.1495927572250366, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.70it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.86it/s]


Epoch 32: Validation loss = 1.1949942111968994, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.88it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.17it/s]


Epoch 33: Validation loss = 1.2106438875198364, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.96it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.10it/s]


Epoch 34: Validation loss = 1.2050237655639648, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.56it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.42it/s]


Epoch 35: Validation loss = 1.1565439701080322, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.60it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.64it/s]


Epoch 36: Validation loss = 1.1754566431045532, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.81it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.53it/s]


Epoch 37: Validation loss = 1.221627950668335, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.98it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.42it/s]


Epoch 38: Validation loss = 1.225530743598938, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.52it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.14it/s]


Epoch 39: Validation loss = 1.1770763397216797, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.69it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.71it/s]


Epoch 40: Validation loss = 1.2493103742599487, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.82it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.75it/s]


Epoch 41: Validation loss = 1.2576210498809814, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.83it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.59it/s]


Epoch 42: Validation loss = 1.2375162839889526, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.43it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  7.93it/s]


Epoch 43: Validation loss = 1.2367222309112549, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.68it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.41it/s]


Epoch 44: Validation loss = 1.2552897930145264, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.88it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.48it/s]


Epoch 45: Validation loss = 1.2639870643615723, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.73it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.06it/s]


Epoch 46: Validation loss = 1.2551109790802002, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.47it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.78it/s]


Epoch 47: Validation loss = 1.2268118858337402, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.79it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.71it/s]


Epoch 48: Validation loss = 1.2287920713424683, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.81it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.20it/s]


Epoch 49: Validation loss = 1.2339894771575928, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.84it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.57it/s]


Epoch 50: Validation loss = 1.2787630558013916, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.42it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.09it/s]


Epoch 51: Validation loss = 1.309301733970642, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.83it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.16it/s]


Epoch 52: Validation loss = 1.2952810525894165, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.89it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.37it/s]


Epoch 53: Validation loss = 1.3169523477554321, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.04it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.01it/s]


Epoch 54: Validation loss = 1.301023244857788, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.30it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  5.57it/s]


Epoch 55: Validation loss = 1.2934041023254395, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.88it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.58it/s]


Epoch 56: Validation loss = 1.253033995628357, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.96it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.30it/s]


Epoch 57: Validation loss = 1.342456579208374, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.46it/s]


Epoch 58: Validation loss = 1.3323360681533813, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.32it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.91it/s]


Epoch 59: Validation loss = 1.3297107219696045, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.98it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.59it/s]


Epoch 60: Validation loss = 1.314546823501587, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.92it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.24it/s]


Epoch 61: Validation loss = 1.293379545211792, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.89it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.95it/s]


Epoch 62: Validation loss = 1.3622698783874512, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.31it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.03it/s]


Epoch 63: Validation loss = 1.3530024290084839, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.84it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.70it/s]


Epoch 64: Validation loss = 1.3253676891326904, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.91it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.57it/s]


Epoch 65: Validation loss = 1.3180063962936401, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.88it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.12it/s]


Epoch 66: Validation loss = 1.329479455947876, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.35it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.61it/s]


Epoch 67: Validation loss = 1.3122518062591553, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.99it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.98it/s]


Epoch 68: Validation loss = 1.3130521774291992, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.01it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.59it/s]


Epoch 69: Validation loss = 1.3040680885314941, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.97it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.17it/s]


Epoch 70: Validation loss = 1.2967294454574585, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.34it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.35it/s]


Epoch 71: Validation loss = 1.3132197856903076, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.12it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.23it/s]


Epoch 72: Validation loss = 1.3029236793518066, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.92it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.60it/s]


Epoch 73: Validation loss = 1.3579301834106445, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.96it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.77it/s]


Epoch 74: Validation loss = 1.291334867477417, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.14it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.98it/s]


Epoch 75: Validation loss = 1.2808332443237305, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.01it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.40it/s]


Epoch 76: Validation loss = 1.3351281881332397, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.02it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.53it/s]


Epoch 77: Validation loss = 1.2950793504714966, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  6.09it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.63it/s]


Epoch 78: Validation loss = 1.3396203517913818, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.34it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  7.87it/s]


Epoch 79: Validation loss = 1.325082778930664, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.90it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.57it/s]


Epoch 80: Validation loss = 1.373847484588623, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.97it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.04it/s]


Epoch 81: Validation loss = 1.3540942668914795, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.97it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.17it/s]


Epoch 82: Validation loss = 1.3665187358856201, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.39it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.07it/s]


Epoch 83: Validation loss = 1.3289134502410889, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.94it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.52it/s]


Epoch 84: Validation loss = 1.3697757720947266, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.98it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.48it/s]


Epoch 85: Validation loss = 1.3712691068649292, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.92it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.13it/s]


Epoch 86: Validation loss = 1.355165719985962, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.18it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.13it/s]


Epoch 87: Validation loss = 1.3710581064224243, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.90it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.57it/s]


Epoch 88: Validation loss = 1.3797996044158936, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.91it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.26it/s]


Epoch 89: Validation loss = 1.3998881578445435, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.54it/s]


Epoch 90: Validation loss = 1.3680737018585205, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.38it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  9.35it/s]


Epoch 91: Validation loss = 1.3722141981124878, Validation accuracy: 0.7000000476837158


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.38it/s]


Epoch 92: Validation loss = 1.3501875400543213, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.89it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.88it/s]


Epoch 93: Validation loss = 1.3510358333587646, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.92it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.67it/s]


Epoch 94: Validation loss = 1.3931069374084473, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.35it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.64it/s]


Epoch 95: Validation loss = 1.3631789684295654, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.84it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.76it/s]


Epoch 96: Validation loss = 1.4141521453857422, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 10.89it/s]


Epoch 97: Validation loss = 1.4054460525512695, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.96it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00,  8.96it/s]


Epoch 98: Validation loss = 1.4210631847381592, Validation accuracy: 0.6500000357627869


Training: 100%|█████████████████████████████████| 19/19 [00:03<00:00,  5.31it/s]
Valid: 100%|██████████████████████████████████████| 1/1 [00:00<00:00, 11.05it/s]

Epoch 99: Validation loss = 1.287003755569458, Validation accuracy: 0.6500000357627869





# Testing

In [None]:
import torch
from torch import nn, save, load
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchmetrics.functional import accuracy
from torchvision.transforms import ToTensor, Resize
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt

# --- Configuration ---------------------------------------------------------
# Dataset split directories; each is read with ImageFolder below
# (one sub-folder per class).
train_dir = '/content/drive/MyDrive/dataset/playcards/train'
valid_dir = '/content/drive/MyDrive/dataset/playcards/valid'
test_dir = '/content/drive/MyDrive/dataset/playcards/test'

# Loader / model-input settings.
NUM_WORKERS = os.cpu_count()  # NOTE(review): not referenced in this cell
BATCH_SIZE = 32                # NOTE(review): not referenced in this cell
IMG_SIZE = 224
patch_size = 16                # NOTE(review): unused here; looks like a ViT leftover

# Resize every image to a square IMG_SIZE x IMG_SIZE input, then convert the
# PIL image to a tensor.
manual_transforms = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor()]
)

# Index -> class-name lookup for the four ace classes.
# NOTE(review): not referenced in this cell; test_model derives names from ImageFolder.
CLASS = dict(enumerate([
    'ace_of_clubs',
    'ace_of_diamonds',
    'ace_of_hearts',
    'ace_of_spades',
]))


# Select the compute device (GPU when available, otherwise CPU).
def setup_cuda():
    """Enable cuDNN (with autotuning) and return the device to run on.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.
    """
    torch.backends.cudnn.enabled = True
    # Let cuDNN benchmark and cache the fastest conv algorithm; inputs here are
    # always resized to the same IMG_SIZE x IMG_SIZE shape, so this pays off.
    torch.backends.cudnn.benchmark = True
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


device = setup_cuda()

# The prediction transform is `manual_transforms`, defined once in the
# configuration section of this cell; it is intentionally not redefined here.


# Predict the class label of a single image file.
def predict_image(image_path, model, transform, class_names, device):
    """Classify one image file with ``model``.

    Args:
        image_path: path of the image file to classify.
        model: network whose output has one score per class along dim 1.
        transform: callable that turns a PIL image into a tensor.
        class_names: sequence mapping class index -> class name.
        device: torch device the model lives on.

    Returns:
        ``(img, predicted_label)``: the image as loaded and converted to RGB
        (a PIL image), and the name of the highest-scoring class.
    """
    model.eval()  # inference behaviour for BatchNorm / Dropout layers
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)  # batch of size 1
    with torch.no_grad():
        logits = model(batch)
        best_index = torch.max(logits, 1).indices
    return rgb_image, class_names[best_index.item()]


# Save one annotated prediction figure to disk.
def _save_prediction_figure(img, predicted_label, true_label, image_path, output_dir):
    """Write ``img`` with a 'Predicted: ...' title into ``output_dir`` as a PNG.

    The file is named ``<true_label>_<stem>_pred_<predicted_label>.png``. The
    true-class prefix keeps names unique: test images stored in different class
    folders may share a basename, and without the prefix two such images that
    receive the same prediction would overwrite each other.
    """
    fig, ax = plt.subplots()
    # White background and no axes, so only the image and its title are saved.
    fig.patch.set_facecolor('white')
    ax.set_facecolor('white')
    ax.axis('off')
    ax.imshow(img)
    ax.set_title(f'Predicted: {predicted_label}', fontsize=12, pad=10)

    image_name = os.path.splitext(os.path.basename(image_path))[0]
    output_image_path = os.path.join(
        output_dir, f"{true_label}_{image_name}_pred_{predicted_label}.png"
    )
    fig.savefig(output_image_path, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # avoid keeping one open figure per test image


# Predict every image of the test split and save the annotated results.
def test_model(checkpoint_path=None, output_dir="result"):
    """Run the trained Alexnet over the test split and save one figure per image.

    Loads the model weights from a ``state_dict`` checkpoint and classifies
    every file found by ``ImageFolder(test_dir)`` with ``predict_image``. For
    each image it writes a titled PNG into ``output_dir``. At the end it prints
    the test accuracy, measured against the class folder each image lives in.

    Args:
        checkpoint_path: path to the saved ``state_dict``. Defaults to
            ``/content/epoch_8_acc_0.7500.pt``.
        output_dir: directory for the annotated figures; created if missing.

    Returns:
        Fraction of test images whose predicted class name equals the name of
        their class folder (0.0 if there are no test images).
    """
    # 1. Class names come from the train split, so model output index i maps to
    #    the name it was trained with. NOTE(review): ImageFolder sorts folder
    #    names, so train/test indices only agree if both splits have the same classes.
    train_dataset = ImageFolder(root=train_dir)
    test_dataset = ImageFolder(root=test_dir)
    class_names = train_dataset.classes

    # 2. Build the Alexnet classifier and restore its weights.
    if checkpoint_path is None:
        # NOTE(review): a fixed, hand-picked checkpoint (epoch 8), not the last
        # training epoch; pass checkpoint_path to evaluate another one.
        checkpoint_path = os.path.join('/content/', 'epoch_8_acc_0.7500.pt')
    model = Alexnet(num_classes=len(class_names)).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code. It also avoids the
    # FutureWarning that recent torch versions emit when the flag is omitted.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print('Model loaded from checkpoint.')

    os.makedirs(output_dir, exist_ok=True)

    # 3. Predict each test image, tally correct predictions, and save a figure.
    #    ImageFolder.imgs holds (path, class_index) pairs.
    num_correct = 0
    for image_path, true_index in tqdm(test_dataset.imgs, desc='Testing'):
        img, predicted_label = predict_image(
            image_path, model, manual_transforms, class_names, device
        )
        true_label = test_dataset.classes[true_index]
        num_correct += int(predicted_label == true_label)
        _save_prediction_figure(img, predicted_label, true_label, image_path, output_dir)

    num_images = len(test_dataset.imgs)
    test_accuracy = num_correct / num_images if num_images else 0.0
    print(f'Test accuracy: {num_correct}/{num_images} = {test_accuracy:.4f}')
    return test_accuracy


# Executing this cell (where __name__ is '__main__') runs the test-set pass.
if '__main__' == __name__:
    test_model()

  model.load_state_dict(torch.load(file_checkpoint, device))


Model loaded from checkpoint.


Testing: 100%|██████████| 20/20 [00:20<00:00,  1.02s/it]
