### In this notebook we will create an image classifier to detect playing cards

We will tackle this problem in three parts:

1. PyTorch `Dataset`
2. PyTorch model
3. PyTorch training loop

Almost every PyTorch training pipeline follows this three-part pattern.

In [22]:
pip install timm

Note: you may need to restart the kernel to use updated packages.


In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

## Setting up the dataset

Would you bake a cake without first gathering the ingredients? The data comes first.

In [24]:
class PlayingCard(Dataset):
    """Thin ``Dataset`` wrapper around torchvision's ``ImageFolder``.

    Indexing returns exactly what ``ImageFolder`` returns for that position: an
    ``(image, class_index)`` pair, with ``transform`` already applied to the image.
    """

    def __init__(self, data_dir, transform = None):
        """Index the images under ``data_dir`` (ImageFolder expects one sub-folder per class).

        Args:
            data_dir: root directory handed to ``ImageFolder``.
            transform: optional callable applied to each loaded image.
        """
        super().__init__()
        folder = ImageFolder(root=data_dir, transform=transform)
        self.data = folder

    def __len__(self):
        """Return how many samples exist; ``DataLoader`` relies on this to plan batches."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return the single sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

    @property
    def classes(self):
        """Class names discovered by ``ImageFolder``, ordered by class index."""
        folder = self.data
        return folder.classes

In [21]:
# Preprocessing applied to every image as it is loaded.
# The model in this notebook uses an ImageNet-pretrained EfficientNet-B0
# (pretrained=True), whose weights expect inputs normalized with the standard
# ImageNet channel statistics, so apply the same normalization here.
IMAGE_SIZE = (128, 128)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) with values in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [26]:
# Build the training dataset from the class sub-folders under playing_cards/train,
# applying the preprocessing pipeline defined above to each image.
dataset = PlayingCard(data_dir='playing_cards/train', transform=transform)

In [30]:
# Pull one sample to sanity-check the preprocessing: expect a channels-first image tensor.
# The last expression is displayed by Jupyter, so no print() is needed.
image, label = dataset[100]
image.shape

torch.Size([3, 128, 128])


### DataLoader
Batching our dataset

In [31]:
# Serve the dataset in reshuffled mini-batches of 32 samples each.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

In [32]:
# Grab a single batch to inspect. next(iter(...)) raises StopIteration on an empty
# loader, whereas `for ...: break` would silently leave `image` holding the
# single sample from the earlier cell.
image, labels = next(iter(dataloader))

In [35]:
# Show the batch tensor shape, the label tensor shape, and the raw class indices.
(image.shape, labels.shape, labels)

(torch.Size([32, 3, 128, 128]),
 torch.Size([32]),
 tensor([14, 49, 36, 51, 16, 42, 27, 14, 39, 21,  8, 32, 46, 46, 30, 46, 17, 26,
         32,  5, 39, 52, 22, 15, 18, 43, 30, 38, 22,  5, 28, 37]))

# Making a PyTorch Model

In [None]:
class SimpleCardClassifier(nn.Module):
    """Card classifier: a timm EfficientNet-B0 backbone followed by a fresh linear head.

    Args:
        num_class: number of logits produced per image (default 53).
        pretrained: forwarded to ``timm.create_model``. ``True`` (the default, matching
            the original behaviour) loads pretrained weights; ``False`` builds a randomly
            initialised backbone, which is handy offline or for quick tests.

    ``forward`` returns logits shaped ``(B, num_class)``.
    NOTE(review): assumes input is a float batch shaped ``(B, 3, H, W)``, as produced by
    the DataLoader in this notebook.
    """

    def __init__(self, num_class= 53, pretrained=True):
        super().__init__()
        self.base_model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        # Keep every child except the last one (presumably timm's ImageNet classification
        # head) as the feature extractor. The dropped child stays registered under
        # `base_model`, so its parameters still show up in parameters() although
        # forward() never uses them.
        self.features = nn.Sequential(*list(self.base_model.children())[:-1])
        # Read the feature width from the backbone instead of hard-coding 1280, so the
        # head always matches the backbone's output size.
        enet_out_size = self.base_model.num_features
        self.classifier = nn.Linear(enet_out_size, num_class)

    def forward(self, x):
        """Map an image batch to per-class logits."""
        x = self.features(x)
        # No-op when the pooling layer already flattens to (B, C); guards against a
        # (B, C, 1, 1) output reaching the Linear layer.
        x = torch.flatten(x, 1)
        output = self.classifier(x)
        return output

In [None]:
# Size the output head from the classes ImageFolder actually found, rather than
# relying on the constructor's hard-coded default of 53.
model = SimpleCardClassifier(num_class=len(dataset.classes))
model

In [None]:
# Smoke test (replaces a leftover debug print): run the sample batch through the
# still-untrained model and check we get one row of logits per image.
# Temporarily switch to eval mode so any BatchNorm running statistics are not
# updated by this check, then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(image)
model.train(was_training)
logits.shape