In [None]:
import torch
import torch.nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# --- Optimisation hyper-parameters -------------------------------------------
EPOCHS = 30           # full passes over the training set
BATCH_SIZE = 100      # images per mini-batch
LEARNING_RATE = 5e-4  # step size handed to the optimizer

# --- Dataset locations (relative to the notebook's working directory) --------
train_data_path = "../../data/cat_and_dog/train"
test_data_path = "../../data/cat_and_dog/validation"

In [None]:
# Load torchvision's AlexNet with its pretrained weights.
# In this notebook the network is fed RGB (3 channel) images cropped to
# 256 x 256 (see the `transform` pipeline defined in a later cell).
alexnet = models.alexnet(pretrained=True)

In [None]:
# Show the full module tree of the loaded network.
alexnet

In [None]:
# Show only the `classifier` sub-module of the network.
alexnet.classifier

In [None]:
# Preprocessing applied to every image: scale the shorter side to 256 px,
# take the central 256 x 256 crop, convert to a float tensor, then normalize
# each channel (mean/std are the values torchvision documents for its
# pretrained models).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def get_loader(path, transform, batch_size, shuffle=True, num_workers=0):
    """Build a DataLoader over an image-folder dataset.

    Args:
        path: Root directory handed to ``torchvision.datasets.ImageFolder``.
        transform: Callable applied to every loaded image.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples each epoch. Defaults to
            ``True`` (the previous hard-coded behaviour); pass ``False`` for
            a deterministic order, e.g. during evaluation.
        num_workers: Number of loader worker processes. Defaults to ``0``
            (load in the main process), matching the previous behaviour.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = torchvision.datasets.ImageFolder(root=path, transform=transform)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

In [None]:
# Loader over the training images, using the shared preprocessing pipeline.
train_loader = get_loader(
    path=train_data_path,
    transform=transform,
    batch_size=BATCH_SIZE,
)

In [None]:
# Pull a single batch from the training loader to inspect it:
# x holds the images, y the integer class labels.
x, y = next(iter(train_loader))

# Show the labels and then their shape.
display(y, y.shape)

In [None]:
# Run the sampled batch through the network as loaded (the model has not
# been moved to the GPU yet at this point in the notebook).
output = alexnet(x)
display(output)

In [None]:
# Shape of the raw model output for the sampled batch.
output.shape

In [None]:
# Reference for reading the predicted indices (presumably the ImageNet
# class-index -> label table; verify):
# https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
# max over dim 1 gives, per sample, the largest score and the index where it occurs.
output.max(1)

In [None]:
# x was normalized by the preprocessing pipeline, so its values lie outside
# [0, 1] and imshow would clip them into distorted colours. Undo the
# per-channel normalization (same mean/std as in `transform`) before plotting.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
sample_image = (x[1] * _norm_std + _norm_mean).clamp(0, 1)

# imshow expects (height, width, channel); tensors are (channel, height, width).
plt.imshow(sample_image.permute(1, 2, 0))

In [None]:
# New 2-class (cat / dog) head to replace the pretrained classifier.
# 9216 = 256 * 6 * 6 is the flattened feature size torchvision's AlexNet
# passes to its classifier.
newClassifier = torch.nn.Sequential(
    torch.nn.Linear(9216, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 2)
)
# The attribute must be spelled `classifier`: forward() only uses that name.
# Assigning to a misspelled attribute merely registers an extra, unused
# sub-module and leaves the original head in place.
alexnet.classifier = newClassifier

display(alexnet)

In [None]:
# Train every parameter of the network with Adam and cross-entropy loss.
criteria = F.cross_entropy

# Use the GPU when one is available; otherwise fall back to the CPU instead
# of crashing on a hard-coded .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alexnet = alexnet.to(device)
# Make sure training-only layers (e.g. Dropout) are active even if an earlier
# cell left the model in eval mode.
alexnet.train()

optimizer = torch.optim.Adam(alexnet.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    print("{} epoch".format(epoch))
    loss_sum = 0.0
    num_batches = 0
    for x, y in train_loader:
        # Move copies to the device; x / y themselves stay on the CPU.
        output = alexnet(x.to(device))
        loss = criteria(output, y.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float.
        loss_sum += loss.item()
        num_batches += 1
    # Report the mean loss per batch so the figure does not scale with the
    # number of batches; guard against an empty loader.
    print("\t{}".format(loss_sum / max(num_batches, 1)))

In [None]:
# Measure classification accuracy on the validation images.
test_loader = get_loader(test_data_path, transform, BATCH_SIZE)

# Evaluate on whatever device the model currently lives on.
device = next(alexnet.parameters()).device

# Switch to eval mode so Dropout is disabled while scoring, and remember the
# previous mode so it can be restored afterwards.
was_training = alexnet.training
alexnet.eval()

total = correct = 0
try:
    # No gradients are needed for inference; skipping them saves memory/time.
    with torch.no_grad():
        for x, y in test_loader:
            output = alexnet(x.to(device))
            # max over dim 1 -> (best score, predicted class index) per sample
            val, pred_idx = output.max(1)
            correct += (pred_idx == y.to(device)).sum().item()
            total += val.shape[0]
finally:
    alexnet.train(was_training)

print ("correct : {} // total : {}".format(correct, total))
# Avoid a ZeroDivisionError when the loader yielded no samples.
print ("accuracy: {}".format(correct / total if total else float("nan")))

<pre>
0 epoch
	48.37724685668945
1 epoch
	11.3847074508667
2 epoch
	8.798026084899902
3 epoch
	8.464077949523926
4 epoch
	5.918179988861084
5 epoch
	4.535010814666748
6 epoch
	3.8624236583709717
7 epoch
	2.679931879043579
8 epoch
	2.0708656311035156
9 epoch
	1.6488384008407593
10 epoch
	1.4238102436065674
11 epoch
	2.2943739891052246
12 epoch
	1.8169094324111938
13 epoch
	0.7539398670196533
14 epoch
	0.5866746306419373
15 epoch
	1.1908200979232788
16 epoch
	1.81795072555542
17 epoch
	0.9758766293525696
18 epoch
	0.577001690864563
19 epoch
	0.4210895895957947
20 epoch
	0.3758643567562103
21 epoch
	0.3736495077610016
22 epoch
	0.715544581413269
23 epoch
	1.090682864189148
24 epoch
	1.150400996208191
25 epoch
	0.29149603843688965
26 epoch
	0.3663492202758789
27 epoch
	0.19721843302249908
28 epoch
	0.1039135679602623
29 epoch
	0.35430073738098145


correct : 720 // total : 804
accuracy: 0.8955223880597015
</pre>