In [1]:
import torch
import torchvision 
from torchvision import datasets, models, transforms

In [13]:
vgg19 = models.vgg19_bn(pretrained=True, progress=True)
print(vgg19)

Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /home/ec2-user/.cache/torch/checkpoints/vgg19_bn-c79401a0.pth


HBox(children=(FloatProgress(value=0.0, max=574769405.0), HTML(value='')))


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 25

In [16]:
for param in vgg19.features.parameters():
    param.requires_grad=False

In [15]:
import torch.nn as nn
classes = ['bacterial', 'normal', 'virus']
in_f = vgg19.classifier[6].in_features
vgg19.classifier[6] = nn.Linear(in_features=in_f, out_features=len(classes), bias=True)
print(vgg19.classifier[6])

Linear(in_features=4096, out_features=3, bias=True)


In [18]:
import torch.optim as optim 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg19.classifier.parameters(), lr=0.001)

In [20]:
train_dir = 'data/workdir'
test_dir = 'data/testdir'
image_transformer = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_data = datasets.ImageFolder(train_dir, transform=image_transformer)
test_data = datasets.ImageFolder(test_dir, transform=image_transformer)
print("Length of the Training Data in Total: ")
print(len(train_data))
print("Length of the Test Data in Total: ")
print(len(test_data))
print("Details about the Training Data: ")
print(train_data)
print("Details about the Test Data: ")
print(test_data)

batch_size = 20
num_workers=0

train_loader = torch.utils.data.DataLoader(train_data, 
                                           batch_size=batch_size, 
                                           num_workers=num_workers, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          num_workers=num_workers, 
                                          shuffle=True)

Length of the Training Data in Total: 
5233
Length of the Test Data in Total: 
624
Details about the Training Data: 
Dataset ImageFolder
    Number of datapoints: 5233
    Root location: data/workdir
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )
Details about the Test Data: 
Dataset ImageFolder
    Number of datapoints: 624
    Root location: data/testdir
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )


In [None]:
epochs = 2
for epoch in range(1, epochs+1):
    train_loss = 0.0
    for batch, (data, label) in enumerate(train_loader):
        optimizer.zero_grad()
        output=vgg19(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        if batch % 20 == 19:
            print(epoch, batch+1, train_loss/20)
            train_loss = 0.0

1 20 0.9483803659677505
1 40 0.8648304760456085
1 60 0.8221193641424179
1 80 0.8131710946559906
1 100 0.7558881729841233
