# "Kaggle Cats&Dogs"
> "Classification of images of Kaggle Cats&Dogs dataset"

- toc: false
- branch: master
- badges: true
- comments: true
- categories: [jupyter, pytorch, pytorch-lightning]
- hide: false
- search_exclude: true

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from tqdm import tqdm
import os
import cv2
import numpy as np

In [2]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import glob

class CatsDogsDS(Dataset):
    def __init__(self, files, labels):
        super().__init__()
        self.files = files
        self.labels = labels
        
    def __getitem__(self, ix):
        try:
            file = self.files[ix]
            label = file.split('\\')[-2]
            label = self.labels[label]
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            img = cv2.resize(img, (50, 50))[None]
            img = img/255.
            return img, label
        except:
            return np.zeros((1,50,50)), np.random.randint(2)
        
    def __len__(self):
        return len(self.files)
        
class CatsDogsDM(pl.LightningDataModule):
    def __init__(self, cats_dir, dogs_dir, labels, img_size):
        super().__init__()
        self.files = glob.glob(cats_dir+'/*.jpg') + glob.glob(dogs_dir+'/*.jpg')
        np.random.seed(10)
        np.random.shuffle(self.files)
        self.trn, self.val = train_test_split(self.files)
        self.trn_dataset = CatsDogsDS(self.trn, labels)
        self.val_dataset = CatsDogsDS(self.val, labels)
        
    def train_dataloader(self):
        return DataLoader(self.trn_dataset, 
                          batch_size=64, shuffle=True)
        
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)
        

IMG_SIZE = 50
cats = "D:\Study\kagglecatsanddogs_3367a\PetImages\Cat"
dogs = "D:\Study\kagglecatsanddogs_3367a\PetImages\Dog"
labels = {'Cat':0, 'Dog':1}
    
dm = CatsDogsDM(cats, dogs, labels, IMG_SIZE)

In [3]:
class Net(pl.LightningModule):
    def conv_layer(self, ni, no):
        return nn.Sequential(
            nn.Conv2d(ni,no,5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.2)
        )
    
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            self.conv_layer(1,32),
            self.conv_layer(32,64),
            self.conv_layer(64,128) # (bs, 128, 2, 2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(512,512),
            nn.ReLU(inplace=True),
            nn.Linear(512,2)
        )
        
        self.loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, x):
        x = self.backbone(x)
        x = x.view(len(x), -1)
        x = self.classifier(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.float())
        loss = self.loss_fn(y_hat, y)
        matches = [torch.argmax(i) == j for i,j in zip(y_hat,y)]
        acc = sum(matches)/len(matches)
        self.log('acc', acc, on_step=True, on_epoch=True,
                prog_bar=True, logger=True)
        return {'loss':loss, 'acc': acc}
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.float())
        loss = self.loss_fn(y_hat, y)
        matches = [torch.argmax(i) == j for i,j in zip(y_hat,y)]
        val_acc = sum(matches)/len(matches)
        self.log('val_acc', val_acc, on_step=True, on_epoch=True,
                prog_bar=True, logger=True)
        
        return {'val_loss':loss, 'val_acc': val_acc}
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)
        
    def predict(self, test):
        if test.shape[0] == 1:
            test = test[None]
            pred = self(test)
            return torch.argmax(pred).item()
        else:
            pred = self(test.float())
            return torch.argmax(pred, dim=1)
    
    def evaluate(self, testx, labels):
        preds = self.predict(testx)
        if isinstance(preds, int):
            return preds==labels
        else:
            matches = (preds==labels)
            acc = sum(matches)/len(matches)
            return acc
        
    def get_progress_bar_dict(self):
        tqdm_dict = super().get_progress_bar_dict()
        if 'v_num' in tqdm_dict:
            del tqdm_dict['v_num']
        return tqdm_dict

In [4]:
if __name__ == '__main__':
    net = Net()
    trainer = pl.Trainer(max_epochs=5, gpus=[0])
    trainer.fit(net, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | backbone   | Sequential       | 257 K 
1 | classifier | Sequential       | 263 K 
2 | loss_fn    | CrossEntropyLoss | 0     
------------------------------------------------
520 K     Trainable params
0         Non-trainable params
520 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [6]:
testset = dm.val_dataset
rnd = np.random.randint(0, len(testset))
testx = testset[rnd][0]
testy = testset[rnd][1]

batchset = iter(dm.val_dataloader())
batchx, batchy = next(batchset)

pred = net.predict(batchx)
acc = net.evaluate(batchx, batchy)
print("Accuracy", acc.item())

Accuracy 0.84375
