In [14]:
NUM_EPOCHS = 100
BATCH_SIZE = 16
MOMENTUM = 0.9
LR_DECAY = 0.0001
LR_INIT = 0.005
IMAGE_DIM = 227
NUM_CLASSES = 5
OUTPUT_PATH = "/content/drive/MyDrive/Projects/Cricket07/model.pth"

In [15]:
import numpy as np
FILE_NAME = "/content/drive/MyDrive/Projects/Cricket07/training_data.npz"

data = np.load(FILE_NAME)
X = data['X']
X = X[:, np.newaxis, :, :]
Y = data['Y']

In [16]:
import torch
from torch.utils.data import TensorDataset, DataLoader

tensor_x = torch.Tensor(X)
tensor_y = torch.Tensor(Y)

my_dataset = TensorDataset(tensor_x, tensor_y)
trainloader = DataLoader(my_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [17]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=NUM_CLASSES):
        super(AlexNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=11, stride=4),  # (8 x 96 x 55 x 55)
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (8 x 96 x 27 x 27)
            nn.Conv2d(96, 256, 5, padding=2),  # (8 x 256 x 27 x 27)
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (8 x 256 x 13 x 13)
            nn.Conv2d(256, 384, 3, padding=1),  # (8 x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),  # (8 x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),  # (8 x 256 x 13 x 13)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (8 x 256 x 6 x 6)
        )
        self.linear = nn.Sequential(
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=(256 * 6 * 6), out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 256 * 6 * 6)
        return self.linear(x)

# Example usage:
# model = AlexNet()
# x = torch.randn(1, 1, 227, 227)
# print(model(x).shape)

In [18]:
for img, cls in trainloader:
    print(img.shape)
    break

torch.Size([16, 1, 227, 227])


In [None]:
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.autograd.set_detect_anomaly(True)

alexnet = AlexNet(num_classes=NUM_CLASSES).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(alexnet.parameters(), lr=0.001, momentum=0.9)

print('Starting training...')

for epoch in tqdm(range(NUM_EPOCHS)):

    correct = 0
    total_samples = 0
    for imgs, classes in trainloader:
        imgs, classes = imgs.to(device), classes.to(device)

        optimizer.zero_grad()

        op = alexnet(imgs)
        output = op.clone()
        loss = F.cross_entropy(output, classes)

        loss.backward()
        optimizer.step()

        correct += (output.argmax(dim=1) == classes.argmax(dim=1)).sum().item()
        total_samples += classes.size(0)

    accuracy = (correct / total_samples) * 100
    print(f'Epoch {epoch + 1}: Loss: {loss.item():.4f}, Accuracy: {accuracy:.2f}%')

print("Finished training")
torch.save(alexnet.state_dict(), OUTPUT_PATH)

Starting training...


  1%|          | 1/100 [00:03<05:59,  3.63s/it]

Epoch 1: Loss: 1.6036, Accuracy: 16.85%


  2%|▏         | 2/100 [00:06<05:38,  3.46s/it]

Epoch 2: Loss: 1.5884, Accuracy: 19.49%


  3%|▎         | 3/100 [00:10<05:23,  3.34s/it]

Epoch 3: Loss: 1.6072, Accuracy: 18.90%


  4%|▍         | 4/100 [00:13<05:14,  3.27s/it]

Epoch 4: Loss: 1.6099, Accuracy: 21.06%


  5%|▌         | 5/100 [00:17<05:24,  3.42s/it]

Epoch 5: Loss: 1.5857, Accuracy: 21.45%


  6%|▌         | 6/100 [00:20<05:14,  3.35s/it]

Epoch 6: Loss: 1.6127, Accuracy: 20.47%


  7%|▋         | 7/100 [00:23<05:05,  3.28s/it]

Epoch 7: Loss: 1.6110, Accuracy: 20.96%


  8%|▊         | 8/100 [00:26<04:59,  3.25s/it]

Epoch 8: Loss: 1.5960, Accuracy: 21.16%


  9%|▉         | 9/100 [00:30<05:08,  3.40s/it]

Epoch 9: Loss: 1.5622, Accuracy: 24.00%


 10%|█         | 10/100 [00:33<04:59,  3.33s/it]

Epoch 10: Loss: 1.5278, Accuracy: 22.92%


 11%|█         | 11/100 [00:36<04:51,  3.27s/it]

Epoch 11: Loss: 1.6622, Accuracy: 28.31%


 12%|█▏        | 12/100 [00:39<04:45,  3.25s/it]

Epoch 12: Loss: 1.5706, Accuracy: 32.13%


 13%|█▎        | 13/100 [00:43<04:55,  3.40s/it]

Epoch 13: Loss: 1.5245, Accuracy: 31.15%


 14%|█▍        | 14/100 [00:46<04:45,  3.32s/it]

Epoch 14: Loss: 1.3427, Accuracy: 30.46%


 15%|█▌        | 15/100 [00:49<04:38,  3.27s/it]

Epoch 15: Loss: 1.4766, Accuracy: 32.71%


 16%|█▌        | 16/100 [00:53<04:32,  3.24s/it]

Epoch 16: Loss: 1.5389, Accuracy: 33.20%


 17%|█▋        | 17/100 [00:56<04:41,  3.40s/it]

Epoch 17: Loss: 1.5855, Accuracy: 30.85%


 18%|█▊        | 18/100 [00:59<04:32,  3.33s/it]

Epoch 18: Loss: 1.5764, Accuracy: 31.15%


 19%|█▉        | 19/100 [01:03<04:25,  3.28s/it]

Epoch 19: Loss: 1.4365, Accuracy: 33.20%


 20%|██        | 20/100 [01:06<04:22,  3.28s/it]

Epoch 20: Loss: 1.3791, Accuracy: 33.79%


 21%|██        | 21/100 [01:10<04:27,  3.39s/it]

Epoch 21: Loss: 1.3975, Accuracy: 33.40%


 22%|██▏       | 22/100 [01:13<04:19,  3.32s/it]

Epoch 22: Loss: 1.2828, Accuracy: 34.67%


 23%|██▎       | 23/100 [01:16<04:12,  3.28s/it]

Epoch 23: Loss: 1.4009, Accuracy: 34.97%


 24%|██▍       | 24/100 [01:19<04:09,  3.29s/it]

Epoch 24: Loss: 1.1798, Accuracy: 35.26%


 25%|██▌       | 25/100 [01:23<04:11,  3.36s/it]

Epoch 25: Loss: 1.7174, Accuracy: 34.77%


 26%|██▌       | 26/100 [01:26<04:04,  3.30s/it]

Epoch 26: Loss: 1.1663, Accuracy: 36.14%


 27%|██▋       | 27/100 [01:29<03:57,  3.26s/it]

Epoch 27: Loss: 1.5312, Accuracy: 34.57%


 28%|██▊       | 28/100 [01:32<03:56,  3.28s/it]

Epoch 28: Loss: 1.4615, Accuracy: 36.34%


 29%|██▉       | 29/100 [01:36<03:56,  3.34s/it]

Epoch 29: Loss: 1.3435, Accuracy: 38.49%


 30%|███       | 30/100 [01:39<03:50,  3.29s/it]

Epoch 30: Loss: 1.2828, Accuracy: 38.20%


 31%|███       | 31/100 [01:42<03:44,  3.25s/it]

Epoch 31: Loss: 1.5549, Accuracy: 36.43%


 32%|███▏      | 32/100 [01:46<03:45,  3.32s/it]

Epoch 32: Loss: 1.4121, Accuracy: 39.67%


 33%|███▎      | 33/100 [01:49<03:43,  3.34s/it]

Epoch 33: Loss: 1.1571, Accuracy: 41.63%


 34%|███▍      | 34/100 [01:52<03:36,  3.28s/it]

Epoch 34: Loss: 1.1014, Accuracy: 43.58%


 35%|███▌      | 35/100 [01:55<03:30,  3.24s/it]

Epoch 35: Loss: 1.2312, Accuracy: 44.07%


 36%|███▌      | 36/100 [01:59<03:33,  3.33s/it]

Epoch 36: Loss: 1.4632, Accuracy: 46.52%


 37%|███▋      | 37/100 [02:02<03:30,  3.34s/it]

Epoch 37: Loss: 1.7127, Accuracy: 45.64%


 38%|███▊      | 38/100 [02:05<03:23,  3.28s/it]

Epoch 38: Loss: 1.4802, Accuracy: 43.68%


 39%|███▉      | 39/100 [02:09<03:18,  3.25s/it]

Epoch 39: Loss: 1.2663, Accuracy: 46.13%


 40%|████      | 40/100 [02:12<03:21,  3.35s/it]

Epoch 40: Loss: 1.2914, Accuracy: 45.05%


 41%|████      | 41/100 [02:15<03:16,  3.32s/it]

Epoch 41: Loss: 1.3734, Accuracy: 46.23%


 42%|████▏     | 42/100 [02:19<03:09,  3.27s/it]

Epoch 42: Loss: 0.9423, Accuracy: 49.85%


 43%|████▎     | 43/100 [02:22<03:03,  3.23s/it]

Epoch 43: Loss: 1.4569, Accuracy: 45.54%


 44%|████▍     | 44/100 [02:25<03:08,  3.37s/it]

Epoch 44: Loss: 1.3157, Accuracy: 49.56%


 45%|████▌     | 45/100 [02:29<03:02,  3.32s/it]

Epoch 45: Loss: 0.9280, Accuracy: 51.81%


 46%|████▌     | 46/100 [02:32<02:56,  3.26s/it]

Epoch 46: Loss: 1.2612, Accuracy: 49.95%


 47%|████▋     | 47/100 [02:35<02:51,  3.23s/it]

Epoch 47: Loss: 1.5662, Accuracy: 54.16%


 48%|████▊     | 48/100 [02:39<02:56,  3.39s/it]

Epoch 48: Loss: 0.9702, Accuracy: 54.06%


 49%|████▉     | 49/100 [02:42<02:49,  3.32s/it]

Epoch 49: Loss: 1.2483, Accuracy: 52.60%


 50%|█████     | 50/100 [02:45<02:42,  3.26s/it]

Epoch 50: Loss: 0.7693, Accuracy: 57.49%


 51%|█████     | 51/100 [02:48<02:37,  3.22s/it]

Epoch 51: Loss: 1.3193, Accuracy: 51.32%


 52%|█████▏    | 52/100 [02:52<02:41,  3.36s/it]

Epoch 52: Loss: 1.3295, Accuracy: 53.67%


 53%|█████▎    | 53/100 [02:55<02:34,  3.29s/it]

Epoch 53: Loss: 0.8151, Accuracy: 56.12%


 54%|█████▍    | 54/100 [02:58<02:29,  3.26s/it]

Epoch 54: Loss: 1.7388, Accuracy: 55.04%


 55%|█████▌    | 55/100 [03:01<02:25,  3.23s/it]

Epoch 55: Loss: 0.8714, Accuracy: 58.77%


 56%|█████▌    | 56/100 [03:05<02:28,  3.38s/it]

Epoch 56: Loss: 1.3552, Accuracy: 57.69%


 57%|█████▋    | 57/100 [03:08<02:22,  3.31s/it]

Epoch 57: Loss: 0.7598, Accuracy: 61.41%


 58%|█████▊    | 58/100 [03:11<02:17,  3.26s/it]

Epoch 58: Loss: 1.3236, Accuracy: 57.49%


 59%|█████▉    | 59/100 [03:14<02:12,  3.24s/it]

Epoch 59: Loss: 0.7192, Accuracy: 58.96%


 60%|██████    | 60/100 [03:18<02:14,  3.37s/it]

Epoch 60: Loss: 1.2130, Accuracy: 57.98%


 61%|██████    | 61/100 [03:21<02:08,  3.30s/it]

Epoch 61: Loss: 0.8125, Accuracy: 59.84%


 62%|██████▏   | 62/100 [03:24<02:03,  3.26s/it]

Epoch 62: Loss: 0.9790, Accuracy: 62.78%


 63%|██████▎   | 63/100 [03:28<02:00,  3.26s/it]

Epoch 63: Loss: 0.5161, Accuracy: 60.53%


 64%|██████▍   | 64/100 [03:31<02:01,  3.36s/it]

Epoch 64: Loss: 0.9629, Accuracy: 63.76%


 65%|██████▌   | 65/100 [03:34<01:55,  3.30s/it]

Epoch 65: Loss: 1.0187, Accuracy: 64.74%


 66%|██████▌   | 66/100 [03:37<01:50,  3.25s/it]

Epoch 66: Loss: 0.6707, Accuracy: 61.70%


 67%|██████▋   | 67/100 [03:41<01:48,  3.28s/it]

Epoch 67: Loss: 1.1698, Accuracy: 61.41%


 68%|██████▊   | 68/100 [03:44<01:47,  3.36s/it]

Epoch 68: Loss: 0.7866, Accuracy: 63.27%


 69%|██████▉   | 69/100 [03:48<01:41,  3.29s/it]

Epoch 69: Loss: 1.0572, Accuracy: 63.76%


 70%|███████   | 70/100 [03:51<01:37,  3.24s/it]

Epoch 70: Loss: 1.0296, Accuracy: 65.43%


 71%|███████   | 71/100 [03:54<01:35,  3.29s/it]

Epoch 71: Loss: 0.9039, Accuracy: 68.46%


 72%|███████▏  | 72/100 [03:57<01:33,  3.34s/it]

Epoch 72: Loss: 1.0626, Accuracy: 68.76%


 73%|███████▎  | 73/100 [04:01<01:28,  3.27s/it]

Epoch 73: Loss: 1.9611, Accuracy: 65.82%


 74%|███████▍  | 74/100 [04:04<01:24,  3.24s/it]

Epoch 74: Loss: 0.9274, Accuracy: 67.78%


 75%|███████▌  | 75/100 [04:07<01:22,  3.30s/it]

Epoch 75: Loss: 1.6156, Accuracy: 66.70%


 76%|███████▌  | 76/100 [04:11<01:19,  3.33s/it]

Epoch 76: Loss: 0.8404, Accuracy: 68.07%


 77%|███████▋  | 77/100 [04:14<01:15,  3.27s/it]

Epoch 77: Loss: 1.1328, Accuracy: 71.30%


 78%|███████▊  | 78/100 [04:17<01:11,  3.24s/it]

Epoch 78: Loss: 0.7497, Accuracy: 69.05%


 79%|███████▉  | 79/100 [04:20<01:09,  3.31s/it]

Epoch 79: Loss: 0.4237, Accuracy: 69.93%


 80%|████████  | 80/100 [04:24<01:06,  3.32s/it]

Epoch 80: Loss: 0.8043, Accuracy: 68.76%


 81%|████████  | 81/100 [04:27<01:02,  3.27s/it]

Epoch 81: Loss: 1.0126, Accuracy: 68.85%


 82%|████████▏ | 82/100 [04:30<00:58,  3.22s/it]

Epoch 82: Loss: 1.0117, Accuracy: 70.42%


 83%|████████▎ | 83/100 [04:34<00:56,  3.32s/it]

Epoch 83: Loss: 1.1280, Accuracy: 72.28%


 84%|████████▍ | 84/100 [04:37<00:52,  3.31s/it]

Epoch 84: Loss: 0.6691, Accuracy: 70.03%


 85%|████████▌ | 85/100 [04:40<00:48,  3.27s/it]

Epoch 85: Loss: 0.8851, Accuracy: 66.80%


 86%|████████▌ | 86/100 [04:43<00:45,  3.23s/it]

Epoch 86: Loss: 0.6963, Accuracy: 71.11%


 87%|████████▋ | 87/100 [04:47<00:43,  3.36s/it]

Epoch 87: Loss: 0.4870, Accuracy: 71.20%


 88%|████████▊ | 88/100 [04:50<00:39,  3.30s/it]

Epoch 88: Loss: 0.6239, Accuracy: 72.97%


 89%|████████▉ | 89/100 [04:53<00:35,  3.24s/it]

Epoch 89: Loss: 0.4347, Accuracy: 74.34%


 90%|█████████ | 90/100 [04:56<00:32,  3.21s/it]

Epoch 90: Loss: 0.4562, Accuracy: 70.52%


 91%|█████████ | 91/100 [05:00<00:30,  3.38s/it]

Epoch 91: Loss: 1.1838, Accuracy: 70.03%


 92%|█████████▏| 92/100 [05:03<00:26,  3.30s/it]

Epoch 92: Loss: 0.4926, Accuracy: 75.42%


 93%|█████████▎| 93/100 [05:06<00:22,  3.27s/it]

Epoch 93: Loss: 0.4342, Accuracy: 73.36%


 94%|█████████▍| 94/100 [05:09<00:19,  3.24s/it]

Epoch 94: Loss: 0.6402, Accuracy: 77.38%


 95%|█████████▌| 95/100 [05:13<00:16,  3.38s/it]

Epoch 95: Loss: 0.6758, Accuracy: 76.79%


 96%|█████████▌| 96/100 [05:16<00:13,  3.30s/it]

Epoch 96: Loss: 0.7174, Accuracy: 76.00%


 97%|█████████▋| 97/100 [05:19<00:09,  3.25s/it]

Epoch 97: Loss: 0.2586, Accuracy: 75.51%


 98%|█████████▊| 98/100 [05:23<00:06,  3.21s/it]

Epoch 98: Loss: 0.6360, Accuracy: 75.51%
