In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG so weight initialisation and the SubsetRandomSampler
# shuffling are reproducible under Restart & Run All (matches data_loader's
# default random_seed).
SEED = 42
torch.manual_seed(SEED)

  Referenced from: <FD9BEDA3-7FDE-3298-84BC-7D1F1F8E037D> /usr/local/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <9ACECC86-1DF2-3366-9859-844A5C7C6E0E> /usr/local/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
def data_loader(data_dir,
                batch_size,
                random_seed=42,
                valid_size=0.1,
                shuffle=True,
                test=False):
    """Build CIFAR-100 data loaders with images resized to 227x227 and normalised.

    Args:
        data_dir: directory where CIFAR-100 is stored / downloaded to.
        batch_size: number of images per batch.
        random_seed: seed for the train/validation index shuffle.
        valid_size: fraction of the training split held out for validation;
            must satisfy ``0 <= valid_size < 1``. Ignored when ``test`` is True.
        shuffle: for the train/validation path, whether to shuffle indices
            before splitting; for the test path, whether the loader shuffles.
        test: if True, return a single loader over the CIFAR-100 test split.

    Returns:
        A ``DataLoader`` over the test split when ``test`` is True; otherwise a
        ``(train_loader, valid_loader)`` tuple drawing disjoint index subsets of
        the CIFAR-100 training split.

    Raises:
        ValueError: if ``test`` is False and ``valid_size`` is outside [0, 1).
    """
    if not test and not 0.0 <= valid_size < 1.0:
        raise ValueError(
            'valid_size should be in the range [0, 1), got {}'.format(valid_size))

    # NOTE(review): these mean/std values are commonly quoted CIFAR-10
    # statistics rather than CIFAR-100 ones — confirm whether that is intended.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # Upscale to the 227x227 input size the network is built around.
    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])

    if test:
        test_dataset = datasets.CIFAR100(
            root=data_dir, train=False,
            download=True, transform=transform,
        )
        return torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=shuffle
        )

    # Both subsets come from the CIFAR-100 *training* split (the original code
    # mistakenly loaded CIFAR-10 here); the samplers keep their indices disjoint.
    train_dataset = datasets.CIFAR100(
        root=data_dir, train=True,
        download=True, transform=transform,
    )
    valid_dataset = datasets.CIFAR100(
        root=data_dir, train=True,
        download=True, transform=transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # A private RandomState yields the same permutation as
        # np.random.seed(random_seed) + np.random.shuffle, without resetting
        # NumPy's global RNG for the rest of the notebook.
        np.random.RandomState(random_seed).shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)


# CIFAR-100: a train/validation split carved out of the training set ...
train_loader, valid_loader = data_loader(
    data_dir='./data',
    batch_size=64,
)

# ... plus a loader over the held-out test set.
test_loader = data_loader(
    data_dir='./data',
    batch_size=64,
    test=True,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class VGG(nn.Module):
    """VGG-11 style network (configuration "A") with batch normalisation.

    Five convolutional stages, each ending in 2x2 max pooling, followed by a
    three-layer fully connected classifier. Every conv is followed by
    BatchNorm + ReLU, and the hidden FC layers by ReLU + Dropout, so the
    stacked layers do not collapse into a single linear map.

    The network returns raw class scores (logits). Pair it with
    ``nn.CrossEntropyLoss``, which applies log-softmax itself; a trailing
    Softmax would apply it twice.

    Args:
        num_classes: number of output scores. Defaults to 1000 to match the
            original final layer (CIFAR-100 only needs 100).
        hidden_features: width of the two hidden fully connected layers.
        dropout: dropout probability after each hidden FC layer.
    """

    # Output channels of each 3x3 conv; 'M' marks a 2x2 / stride-2 max-pool.
    _CFG = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')

    def __init__(self, num_classes=1000, hidden_features=4096, dropout=0.5):

        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, v, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(v),
                    nn.ReLU(inplace=True),
                ]
                in_channels = v

        layers += [
            # A 227x227 input arrives here as 7x7, so this pool is an identity
            # for that case; it also lets other input sizes (e.g. 32x32) work.
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(-3),
            nn.Linear(7 * 7 * 512, hidden_features),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, num_classes),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits, shape (batch, num_classes) for a (batch, 3, H, W) input."""
        return self.sequential(x)

In [4]:
# Build the network and move its parameters to the configured device so they
# match the batches the training loop sends there (Module.to returns self).
model = VGG().to(device)

In [9]:
# Adam hyper-parameters
lr = 1e-3
betas = (0.9, 0.999)

# Loss function and optimiser over all model parameters
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

In [10]:
total_step = len(train_loader)
num_epochs = 100

# Keep parameters on the same device as the batches below. Module.to moves the
# parameters in place, so the optimiser created earlier still tracks them.
model.to(device)

for epoch in range(num_epochs):
    # Training mode: BatchNorm uses batch statistics (and Dropout, if present, is active).
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass. CrossEntropyLoss accepts integer class indices directly,
        # so no one-hot target (and no per-element Python loop) is needed.
        outputs = model(images)
        loss = criterion(outputs, labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

    # `i` is the last batch index (no longer shadowed by an inner loop).
    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
          .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Evaluation mode: BatchNorm uses running statistics; gradients are not tracked.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report against the real number of validation images instead of a hardcoded count.
    accuracy = 100 * correct / total if total else float('nan')
    print('Accuracy of the network on the {} validation images: {} %'.format(total, accuracy))