<a href="https://colab.research.google.com/github/skj092/Image-Classification-with-CIFAR-100/blob/main/CIFAR_10_PT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
# importing the necessary libraries
import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt 
from tqdm import tqdm 
from sklearn.metrics import accuracy_score

In [50]:
# config
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 32
epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [51]:
# dataset and dataloader
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=tfms, download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

valid_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=tfms, download=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [52]:

# visualizing the dataset 
def visualize_dataset(train_ds):
    plt.figure(figsize=(10, 10))
    plt.imshow(train_ds[0][0].permute(1, 2, 0).numpy())
    plt.title(str(train_ds[0][1]))
    plt.show()

In [53]:
# Define the network:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [54]:
# Defninig Loss and Optimizer
model = Net()
model.to(device)

# Testing on one batch
xb, yb = next(iter(train_loader))
xb = xb.to(device)
yb = yb.to(device)
print(model(xb[0:1]).shape)

torch.Size([1, 10])


In [55]:
# Validation
def Validation(dl):
    losses, accuracies = [], []
    with torch.no_grad():
        loop = tqdm(dl, desc='Validation')
        for data in loop:
            xb, yb = data[0].to(device), data[1].to(device)
            output = model(xb)
            loss = criterion(output, yb)
            predictions = torch.argmax(output, dim=1)
            acc = accuracy_score(yb.cpu().numpy(), predictions.cpu().numpy())
            losses.append(loss.item())
            accuracies.append(acc)
            loop.set_postfix(Loss=sum(losses)/len(losses), Accuracy=sum(accuracies)/len(accuracies))


# Training the Model
def train(model, train_dl, valid_dl, optimizer, criterion, epochs):
    # To store train accuracy and loss 
    train_acc, train_loss = [], []
    for epoch in range(epochs):
        running_loss = 0.0
        loop = tqdm(train_dl)
        for i, data in enumerate(loop):
            xb = data[0].to(device)
            yb = data[1].to(device)
            optimizer.zero_grad()
            output = model(xb)
            loss = criterion(output, yb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            loop.set_description(f"Epoch [{epoch}/{epochs}]")
            loop.set_postfix(loss=running_loss/len(train_dl))
        Validation(valid_dl)

In [56]:
# defining loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training the Model
train(model,train_loader,valid_loader, optimizer, criterion, 3)

# saving the model 
# torch.save(model.state_dict(), 'model.pkl')

Epoch [0/3]: 100%|██████████| 1563/1563 [00:25<00:00, 60.56it/s, loss=2.2]
Validation: 100%|██████████| 313/313 [00:03<00:00, 86.30it/s, Accuracy=0.299, Loss=1.96]
Epoch [1/3]: 100%|██████████| 1563/1563 [00:26<00:00, 59.69it/s, loss=1.79]
Validation: 100%|██████████| 313/313 [00:03<00:00, 84.97it/s, Accuracy=0.404, Loss=1.63]
Epoch [2/3]: 100%|██████████| 1563/1563 [00:25<00:00, 60.71it/s, loss=1.57]
Validation: 100%|██████████| 313/313 [00:03<00:00, 87.58it/s, Accuracy=0.453, Loss=1.49]
