In [1]:
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.models as models
import ssl


from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

ssl._create_default_https_context = ssl._create_unverified_context

In [2]:
##https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),        # 10 output classes
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x
    
model=AlexNet()

In [3]:
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs):

    train_losses = []
    valid_losses = []
    best_model_loss = 5
    
    for epoch in range(num_epochs):

        # Training
        for i, (data, labels) in enumerate(train_loader):
      
            prediction = model.forward(data)

            train_loss = criterion(prediction, labels)

            train_loss.backward()

            optimizer.step()

            optimizer.zero_grad()
        print(f'\rEpoch {epoch+1}, batch {i+1}/{len(train_loader)} - Loss: {train_loss}')

        train_losses.append(train_loss)
        writer.add_scalar("Loss/train ADAM", train_loss, epoch)

        # Validation
        for batch_nr, (data, labels) in enumerate(val_loader):
            prediction = model.forward(data)
            loss_val = criterion(prediction, labels)
            valid_losses.append(loss_val)
        print(f"loss validation: {loss_val}")
        #print(f"loss validation: {loss_val}","\n")

        if valid_losses[-1] < best_model_loss:
            print(f"\t > Found a better model, {best_model_loss} -> {valid_losses[-1]}")
            best_model = copy.deepcopy(model)
            best_model_loss = valid_losses[-1]

        writer.add_scalar("Loss/validation ADAM", loss_val, epoch)

    print(f"\nBest model loss: {best_model_loss}")
    return best_model, train_losses, valid_losses

In [4]:
def get_accuracy(network, loader):
    
    with torch.no_grad():
        correct = 0
        total = 0
        y_pred = []
        y_true = []

        for x, (data, labels) in enumerate(loader):

            prediction = network.forward(data)

            for i in range(len(data)):

                y_true.append(labels[i].item())
                y_pred.append(torch.argmax(prediction[i]).item())
                if y_true[i] == y_pred[i]:
                    correct += 1        
    
            total += float(len(data))
    
        score = correct/total

        accuracy = score

        return accuracy
    

In [5]:
LEARNING_RATE = 0.001
EPOCHS = 5
BATCH_SIZE = 1000
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)

validset, trainset = torch.utils.data.random_split(trainset, [10000, 40000])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,shuffle=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=BATCH_SIZE,shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,shuffle=False)



# Loss function , Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Training
trained_model, train_loss, valid_loss = train_model(model, criterion, optimizer, trainloader, validloader, EPOCHS)

# Testing
test_acc = get_accuracy(trained_model, testloader)
print(f"Model Accuracy (CIFAR10): {test_acc*100}%")
writer.flush()

Files already downloaded and verified
Files already downloaded and verified
Epoch 1, batch 40/40 - Loss: 1.855844259262085
loss validation: 1.964345932006836
	 > Found a better model, 5 -> 1.964345932006836
Epoch 2, batch 40/40 - Loss: 1.5980778932571411
loss validation: 1.5825409889221191
	 > Found a better model, 1.964345932006836 -> 1.5825409889221191
Epoch 3, batch 40/40 - Loss: 1.4516880512237549
loss validation: 1.4356021881103516
	 > Found a better model, 1.5825409889221191 -> 1.4356021881103516
Epoch 4, batch 40/40 - Loss: 1.2556418180465698
loss validation: 1.2733981609344482
	 > Found a better model, 1.4356021881103516 -> 1.2733981609344482
Epoch 5, batch 40/40 - Loss: 1.1666392087936401
loss validation: 1.174381971359253
	 > Found a better model, 1.2733981609344482 -> 1.174381971359253

Best model loss: 1.174381971359253
Model Accuracy (CIFAR10): 57.199999999999996%


In [6]:
model = nn.Sequential(
    model,
    nn.Linear(10, 10))
print(model)

Sequential(
  (0): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=False)
      (1): Linear(in_features=1024, out_features=4096, bias=True)

In [7]:
LEARNING_RATE = 0.01
EPOCHS = 5
num_classes=10
model_featext=trained_model


# Freeze all layers except the last few layers
for name, param in model_featext.named_parameters():
    if "Sequential.1" in name or "classifier.6" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Train the model
trained_model_featext, train_loss, valid_loss = train_model(model_featext, criterion, optimizer, trainloader, validloader, EPOCHS)


test_acc = get_accuracy(trained_model_featext, testloader)
print(f"Model Accuracy (AlexNet PreTrained): {test_acc*100}%")
writer.flush()

Epoch 1, batch 40/40 - Loss: 1.0603957176208496
loss validation: 1.1493136882781982
	 > Found a better model, 5 -> 1.1493136882781982
Epoch 2, batch 40/40 - Loss: 1.0956660509109497
loss validation: 1.1575943231582642
Epoch 3, batch 40/40 - Loss: 1.084247350692749
loss validation: 1.100934624671936
	 > Found a better model, 1.1493136882781982 -> 1.100934624671936
Epoch 4, batch 40/40 - Loss: 1.052489161491394
loss validation: 1.2534207105636597
Epoch 5, batch 40/40 - Loss: 1.1280475854873657
loss validation: 1.2305376529693604

Best model loss: 1.100934624671936
Model Accuracy (AlexNet PreTrained): 57.099999999999994%
