In [3]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from cnn import CardDataset, CNN, rank_converter, suit_converter, decode_rank, decode_suit
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tqdm import trange
import os
import shutil

In [4]:
dataset = CardDataset()
train_size = int(0.95 * len(dataset))
validation_size = len(dataset) - train_size
train_subset, validation_subset = random_split(dataset, [train_size, validation_size])
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)

In [10]:
model = CNN()

loss_function = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

best_accuracy = 0
for epoch in trange(15):
    running_loss = 0
    model.train()
    for batch in train_loader:
        X_batch, y_batch = batch

        optimizer.zero_grad()

        y_pred = model.forward(X_batch)

        loss = loss_function(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    accuracy = 0
    val_ds, inds = validation_subset.dataset, validation_subset.indices
    model.eval()
    for ind in inds:
        X, y_true = dataset[ind]
        y_pred = np.argmax(model(X).detach().numpy())
        accuracy += (y_true == y_pred)
    if best_accuracy < accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), "model.pt")
    
    torch.save(model.state_dict(), f"model_{epoch}_new.pt")
    
    learning_rate *= 0.8

    print(f"Epoch {epoch} Accuracy:", accuracy / len(inds))

  7%|▋         | 1/15 [00:43<10:02, 43.02s/it]

Epoch 0 Accuracy: 0.5473087818696883


 13%|█▎        | 2/15 [01:36<10:37, 49.05s/it]

Epoch 1 Accuracy: 0.6311614730878187


 20%|██        | 3/15 [02:23<09:40, 48.34s/it]

Epoch 2 Accuracy: 0.8339943342776204


 27%|██▋       | 4/15 [03:13<08:58, 48.96s/it]

Epoch 3 Accuracy: 0.9722379603399434


 33%|███▎      | 5/15 [04:02<08:10, 49.05s/it]

Epoch 4 Accuracy: 0.9903682719546743


 40%|████      | 6/15 [04:53<07:25, 49.52s/it]

Epoch 5 Accuracy: 0.9915014164305949


 47%|████▋     | 7/15 [05:45<06:42, 50.32s/it]

Epoch 6 Accuracy: 0.9971671388101983


 53%|█████▎    | 8/15 [06:37<05:56, 50.94s/it]

Epoch 7 Accuracy: 0.9909348441926346


 60%|██████    | 9/15 [07:28<05:05, 50.95s/it]

Epoch 8 Accuracy: 0.9609065155807366


 67%|██████▋   | 10/15 [08:17<04:11, 50.39s/it]

Epoch 9 Accuracy: 0.9994334277620397


 73%|███████▎  | 11/15 [09:06<03:19, 49.92s/it]

Epoch 10 Accuracy: 0.996600566572238


 80%|████████  | 12/15 [09:56<02:30, 50.03s/it]

Epoch 11 Accuracy: 0.996600566572238


 87%|████████▋ | 13/15 [10:43<01:37, 48.99s/it]

Epoch 12 Accuracy: 0.996600566572238


 93%|█████████▎| 14/15 [11:29<00:47, 47.99s/it]

Epoch 13 Accuracy: 0.998300283286119


100%|██████████| 15/15 [12:14<00:00, 48.95s/it]

Epoch 14 Accuracy: 0.9943342776203966



