In [1]:
import torch
import numpy as np
import os
import cv2
from tqdm import tqdm

In [2]:
REBUILD_DATA = False

In [3]:
class CatsVSDogs():
    CATS = "D:\git_projects\DataSets\PetImages\Cat"
    DOGS = "D:\git_projects\DataSets\PetImages\Dog"
    LABEL = { CATS:0, DOGS:1}
    training_data = []
    IMAGE_SIZE = 50
    cat_count = 0
    dog_count = 0
    
    def make_training_data(self):
        for label in self.LABEL:
            print(label)
            for f in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, f)
                    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img, (self.IMAGE_SIZE, self.IMAGE_SIZE))
                    self.training_data.append([np.array(img), np.eye(2)[self.LABEL[label]]])

                    if label == self.CATS:
                        self.cat_count += 1
                    elif label == self.DOGS:
                        self.dog_count += 1
                except Exception as e:
                    pass
        np.random.shuffle(self.training_data)
        np.save("D:\\git_projects\\DataSets\\PetImages\\training_data3.npy", self.training_data)
    
        print("CATS=> ", self.cat_count)
        print("DOGS=> ", self.dog_count)


if REBUILD_DATA:
    catsvdogs = CatsVSDogs()
    catsvdogs.make_training_data()

In [4]:
training_data = np.load("D:\\git_projects\\DataSets\\PetImages\\training_data3.npy", allow_pickle = True)

In [5]:
training_data.shape

(24946, 2)

In [6]:
import matplotlib.pyplot as plt
plt.imshow(training_data[0][0], cmap="gray")
plt.show()

<Figure size 640x480 with 1 Axes>

In [7]:
training_data[0][1]

array([0., 1.])

In [8]:
import torch.nn as nn
import torch.nn.functional as F

In [9]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        
        self.fc1 = nn.Linear(2*2*128, 256)
        self.fc2 = nn.Linear(256, 2)
    
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        
        x = x.view(-1, 128*2*2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return F.softmax(x, dim=1)

net = Net()        

In [10]:
from torch import optim
optimizer = optim.Adam(net.parameters(), lr = 0.001)

In [11]:
X = torch.Tensor([i[0] for i in training_data]).view(-1, 50*50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.1
val_size = int(len(X))*0.1
print(val_size)

2494.6000000000004


In [12]:
val_size = int(val_size)
train_X = X[:-val_size]
train_y = y[:-val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]

print(len(train_X))
print(len(test_X))

22452
2494


In [13]:
loss_function = nn.MSELoss()

In [16]:
EPOCHS = 3
BATCH_SIZE = 100

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
        batch_y = train_y[i:i+BATCH_SIZE]
        
        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(loss)

100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:06<00:00,  1.77it/s]


tensor(0.1789, grad_fn=<MseLossBackward>)


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:07<00:00,  1.77it/s]


tensor(0.1635, grad_fn=<MseLossBackward>)


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:08<00:00,  1.75it/s]


tensor(0.1523, grad_fn=<MseLossBackward>)


In [17]:
correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        predicted_class = torch.argmax(net(test_X[i].view(-1, 1, 50, 50)))
        if real_class == predicted_class:
            correct += 1
        total += 1

print("Accuracy=> ", correct/total)

100%|█████████████████████████████████████████████████████████████████████████████| 2494/2494 [00:13<00:00, 184.20it/s]


Accuracy=>  0.7113071371291099
