In [1]:
import torch
import torch.nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
import matplotlib.pyplot as plt

In [2]:
train_data_path = "../../data/cat_and_dog/train"
test_data_path = "../../data/cat_and_dog/validation"

LEARNING_RATE = 5e-4
EPOCHS = 30
BATCH_SIZE = 100

In [3]:
# RGB (3 channel) - (256 x 256)
alexnet = models.alexnet(pretrained=True)

In [None]:
alexnet

In [None]:
alexnet.classifier

In [4]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_loader(path, transform, batch_size):
    dataset = torchvision.datasets.ImageFolder(
        root = path,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=True
    )
    return loader

In [5]:
train_loader = get_loader(train_data_path, transform, BATCH_SIZE)

In [None]:
x, y = next(iter(train_loader))

display(y)
display(y.shape)

In [None]:
output = alexnet(x)
display(output)

In [None]:
output.shape

In [None]:
# https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
output.max(1)

In [None]:
plt.imshow(x[1].permute(1, 2, 0))

In [6]:
alexnet.classifier = torch.nn.Identity()

after_alex = torch.nn.Sequential(
    torch.nn.Linear(9216, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 2)
)

display(alexnet)
display(after_alex)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Identity()
)

Sequential(
  (0): Linear(in_features=9216, out_features=4096, bias=True)
  (1): ReLU()
  (2): Linear(in_features=4096, out_features=1024, bias=True)
  (3): ReLU()
  (4): Linear(in_features=1024, out_features=2, bias=True)
)

In [7]:
criteria = F.cross_entropy
alexnet = alexnet.cuda()
after_alex = after_alex.cuda()

optimizer = torch.optim.Adam(after_alex.parameters(), lr=LEARNING_RATE)

loss_sum = torch.tensor(0, dtype=torch.float32)
for epoch in range(EPOCHS):
    print("{} epoch".format(epoch))
    for x, y in train_loader:
        output = alexnet(x.cuda())
        output = after_alex(output)
        
        loss = criteria(output, y.cuda())

        loss_sum += loss.detach().item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("\t{}".format(loss_sum.mean()))
    loss_sum = torch.tensor(0, dtype=torch.float32)

0 epoch
	34.52383041381836
1 epoch
	2.2854971885681152
2 epoch
	1.2044789791107178
3 epoch
	0.5521777868270874
4 epoch
	0.2547612488269806
5 epoch
	0.09210491180419922
6 epoch
	0.03479543700814247
7 epoch
	0.01003146730363369
8 epoch
	0.005586165003478527
9 epoch
	0.004092842806130648
10 epoch
	0.00307043781504035
11 epoch
	0.0024808431044220924
12 epoch
	0.0020664262119680643
13 epoch
	0.0017626928165555
14 epoch
	0.0015110159292817116
15 epoch
	0.0013028169050812721
16 epoch
	0.0011368608102202415
17 epoch
	0.0009824037551879883
18 epoch
	0.0008366918191313744
19 epoch
	0.0006944608758203685
20 epoch
	0.0005627321661449969
21 epoch
	0.00044750457163900137
22 epoch
	0.000352485163602978
23 epoch
	0.00027527802740223706
24 epoch
	0.00021530153753701597
25 epoch
	0.00016607044381089509
26 epoch
	0.00013195992505643517
27 epoch
	0.00010523318633204326
28 epoch
	8.606433402746916e-05
29 epoch
	7.145165727706626e-05


In [8]:
test_loader = get_loader(test_data_path, transform, BATCH_SIZE)

total = correct = 0
for x, y in test_loader:
    output = alexnet(x.cuda())
    output = after_alex(output)
    
    val, pred_idx = output.max(1)
    correct += (pred_idx == y.cuda()).sum().item()
    total += val.shape[0]

print ("correct : {} // total : {}".format(correct, total))
print ("accuracy: {}".format(correct / total))

correct : 764 // total : 804
accuracy: 0.9502487562189055


<pre>
0 epoch
	48.37724685668945
1 epoch
	11.3847074508667
2 epoch
	8.798026084899902
3 epoch
	8.464077949523926
4 epoch
	5.918179988861084
5 epoch
	4.535010814666748
6 epoch
	3.8624236583709717
7 epoch
	2.679931879043579
8 epoch
	2.0708656311035156
9 epoch
	1.6488384008407593
10 epoch
	1.4238102436065674
11 epoch
	2.2943739891052246
12 epoch
	1.8169094324111938
13 epoch
	0.7539398670196533
14 epoch
	0.5866746306419373
15 epoch
	1.1908200979232788
16 epoch
	1.81795072555542
17 epoch
	0.9758766293525696
18 epoch
	0.577001690864563
19 epoch
	0.4210895895957947
20 epoch
	0.3758643567562103
21 epoch
	0.3736495077610016
22 epoch
	0.715544581413269
23 epoch
	1.090682864189148
24 epoch
	1.150400996208191
25 epoch
	0.29149603843688965
26 epoch
	0.3663492202758789
27 epoch
	0.19721843302249908
28 epoch
	0.1039135679602623
29 epoch
	0.35430073738098145


correct : 720 // total : 804
accuracy: 0.8955223880597015
</pre>

<pre>
0 epoch
	34.52383041381836
1 epoch
	2.2854971885681152
2 epoch
	1.2044789791107178
3 epoch
	0.5521777868270874
4 epoch
	0.2547612488269806
5 epoch
	0.09210491180419922
6 epoch
	0.03479543700814247
7 epoch
	0.01003146730363369
8 epoch
	0.005586165003478527
9 epoch
	0.004092842806130648
10 epoch
	0.00307043781504035
11 epoch
	0.0024808431044220924
12 epoch
	0.0020664262119680643
13 epoch
	0.0017626928165555
14 epoch
	0.0015110159292817116
15 epoch
	0.0013028169050812721
16 epoch
	0.0011368608102202415
17 epoch
	0.0009824037551879883
18 epoch
	0.0008366918191313744
19 epoch
	0.0006944608758203685
20 epoch
	0.0005627321661449969
21 epoch
	0.00044750457163900137
22 epoch
	0.000352485163602978
23 epoch
	0.00027527802740223706
24 epoch
	0.00021530153753701597
25 epoch
	0.00016607044381089509
26 epoch
	0.00013195992505643517
27 epoch
	0.00010523318633204326
28 epoch
	8.606433402746916e-05
29 epoch
	7.145165727706626e-05

    
correct : 764 // total : 804
accuracy: 0.9502487562189055
</pre>