In [1]:
%cd ..

/home/atanu/Documents/swatah/mobilenetv2


In [2]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

train_data = datasets.ImageFolder('custom_dataset/train', transform=transform)
val_data = datasets.ImageFolder('custom_dataset/test', transform=transform)

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
val_loader = DataLoader(val_data, batch_size=4)

In [3]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V2)

num_classes = len(train_data.classes)
model.classifier[1] = nn.Linear(model.last_channel, num_classes)

In [4]:
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
epochs = 2

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss:.3f}, Accuracy: {100.*correct/total:.2f}%")

Epoch 1/2, Loss: 1357.555, Accuracy: 88.64%
Epoch 2/2, Loss: 917.403, Accuracy: 92.40%


In [6]:
model.eval()
val_correct = 0
val_total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        val_total += labels.size(0)
        val_correct += predicted.eq(labels).sum().item()

print(f"Validation Accuracy: {100.*val_correct/val_total:.2f}%")

Validation Accuracy: 93.82%


In [None]:
torch.save(model.state_dict(), 'models/mobilenetv2_custom.pth')