# **Lab-05**

**Aim:** To implement AlexNet architecture


In [None]:
!pip install torchmetrics

In [None]:
!unzip /content/Face-Images

In [3]:
!find . -name "*.DS_Store" -type f -delete

In [4]:
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor, transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import os

In [5]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))

    def __len__(self):
        return sum(len(files) for _, _, files in os.walk(self.root_dir))

    def __getitem__(self, idx):
        class_idx = 0
        while idx >= len(os.listdir(os.path.join(self.root_dir, self.classes[class_idx]))):
            idx -= len(os.listdir(os.path.join(self.root_dir, self.classes[class_idx])))
            class_idx += 1
        class_dir = os.path.join(self.root_dir, self.classes[class_idx])
        file_name = os.listdir(class_dir)[idx]
        image = Image.open(os.path.join(class_dir, file_name))
        if self.transform:
            image = self.transform(image)
        return image, class_idx


transform = transforms.Compose([
    transforms.Resize((227,227)),
    transforms.ToTensor()
])

traindataset = CustomDataset(root_dir='/content/Face Images/Final Training Images', transform=transform)
testdataset = CustomDataset(root_dir='/content/Face Images/Final Testing Images', transform=transform)
train = DataLoader(traindataset, batch_size=32, shuffle=True)
test = DataLoader(testdataset, batch_size=32, shuffle=True)

In [6]:
class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [11]:
import torchmetrics
from tqdm.auto import tqdm
import torch.optim as optim
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix

In [18]:
AlexNetModel=AlexNet(16)
lossfn=nn.CrossEntropyLoss()
optimizer=optim.SGD(params=AlexNetModel.parameters(),
                    lr=0.001,
                    momentum=0.9)
accuracy=torchmetrics.classification.Accuracy(task='multiclass',num_classes=16)

epochs=25


for i, data in enumerate(train, 0):
      inputs, labels = data
      optimizer.zero_grad()
      AlexNetModel.train()
      outputs = AlexNetModel(inputs)
      loss = lossfn(outputs, labels)
      loss.backward()
      optimizer.step()
      accuracy.update(outputs, labels)


print("finished training")


finished training


In [19]:
AlexNetModel.eval()
for data in test:
  inputs, labels = data
  outputs = AlexNetModel(inputs)
  loss = lossfn(outputs, labels)
  accuracy.update(outputs, labels)

print(f'Accuracy on test set: {accuracy.compute()*100}')

Accuracy on test set: 5.844155788421631


In [16]:
torch.save(AlexNetModel.state_dict(), '/content/model_weights.pth')