In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,random_split
import torch.optim as optim
from torchvision.transforms import transforms
from PIL import Image
import scipy
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from torch.utils.data import Subset

from utils import OxfordDataset

In [2]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
# Per-channel mean/std applied by Normalize after ToTensor(). These are the
# ImageNet statistics that torchvision documents for its pretrained models;
# they are shared by both pipelines so the two cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation pipeline: deterministic resize + centre crop to 224x224, no augmentation.
transform_v = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Training pipeline: same resize/crop, then random augmentation applied to the
# PIL image before it is converted to a tensor and normalised.
transform_t = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),        # flip 50% of images
    transforms.RandomRotation(degrees=15),         # rotate by up to +/-15 degrees
    transforms.ColorJitter(                        # colour augmentation
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [4]:
# Three separate OxfordDataset wrappers built from the same "flower_data" argument.
# Only the transform differs: augmentation for training, the deterministic
# pipeline for validation and test.
train_dataset, validation_dataset, test_dataset = (
    OxfordDataset("flower_data", transform=pipeline)
    for pipeline in (transform_t, transform_v, transform_v)
)


In [5]:
# 70 / 15 / 15 split sizes.
# The total is taken from `validation_dataset` rather than `train_dataset`
# because a later cell rebinds `train_dataset` to its Subset; sizing from that
# name would silently shrink the split if this cell were run again.
# Assumes all three OxfordDataset wrappers have the same length (they are built
# from the same "flower_data" argument) - TODO confirm.
num_samples = len(validation_dataset)
train_split = int(0.70 * num_samples)
validation_split = int(0.15 * num_samples)
# The test split takes the remainder so the three sizes always sum to num_samples.
test_split = num_samples - train_split - validation_split





In [6]:
# Shuffle every sample index once with a fixed seed so the split is reproducible.
# The length comes from `validation_dataset` because `train_dataset` is rebound
# to a Subset in a later cell; using it here would permute only the training
# subset's range if this cell were re-run.
# Assumes all three dataset wrappers have the same length - TODO confirm.
indices = torch.randperm(
    len(validation_dataset),
    generator=torch.Generator().manual_seed(42),
).tolist()

# Consecutive, non-overlapping slices of the permutation: train | validation | test.
train_indices = indices[:train_split]
validation_indices = indices[train_split:train_split + validation_split]
test_indices = indices[train_split + validation_split:]


In [7]:
def _full_dataset(ds):
    """Return the dataset underneath ``ds`` when ``ds`` is already a ``Subset``."""
    return ds.dataset if isinstance(ds, Subset) else ds


# `train_dataset` and `test_dataset` are rebound to their own subsets here.
# Without unwrapping, a second run of this cell would index into the
# already-shrunk Subset instead of the full dataset, nesting Subsets and letting
# the train split drift away from the validation/test splits.
train_dataset = Subset(_full_dataset(train_dataset), train_indices)
val_dataset = Subset(_full_dataset(validation_dataset), validation_indices)
test_dataset = Subset(_full_dataset(test_dataset), test_indices)

In [8]:
# Size of the held-out test split (use the len() builtin rather than the dunder).
len(test_indices)

1229

In [9]:
# Total number of shuffled sample indices, i.e. the combined size of the three splits.
len(indices)

8189

In [8]:
# Mini-batches of 32 for every split. Only the training loader reshuffles each
# epoch; validation and test keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
validation_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)


In [9]:
# Pull a single batch to sanity-check the label values/dtype and the image
# tensor shape. next(iter(...)) states the "first batch only" intent directly
# instead of a loop that breaks on its first iteration.
image, label = next(iter(train_loader))
print(label)
print(label[0].item())
print(image.shape)



tensor([77, 79, 30, 83, 58, 56, 80, 50, 50, 75, 22, 72, 39, 47, 36, 40, 95, 24,
        57, 33, 82, 63, 12, 69, 75, 86, 46, 87, 90,  2, 87, 78],
       dtype=torch.uint8)
77
torch.Size([32, 3, 224, 224])


In [38]:
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier for 3x224x224 images.

    Every stage is a stride-2 3x3 convolution (padding 1), a ReLU and a 2x2
    max-pool, so each stage shrinks the spatial size twice. For a 224x224 input
    the feature map is 512x1x1 after the last stage; it is flattened and passed
    through a two-layer fully connected head with dropout in between.

    ``fc1`` is sized for exactly 512 flattened features, so inputs whose final
    feature map is not 1x1 will fail at ``fc1``.

    Args:
        num_classes: Number of output logits (defaults to 102).
        dropout: Drop probability applied between the two linear layers
            (defaults to 0.7).
    """

    def __init__(self, num_classes=102, dropout=0.7):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 224 -> 112 (conv) -> 56 (pool)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 56 -> 28 (conv) -> 14 (pool)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 14 -> 7 (conv) -> 3 (pool)

        self.conv4 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 3 -> 2 (conv) -> 1 (pool)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 1024)  # 512 channels * 1 * 1 spatial positions
        # Own activation for the head; previously this silently re-assigned relu4.
        self.relu5 = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores (logits)."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = self.pool4(self.relu4(self.conv4(x)))

        x = self.flatten(x)
        x = self.relu5(self.fc1(x))
        x = self.dropout(x)
        # No softmax here: the raw logits are returned.
        return self.fc2(x)



In [39]:
# Multi-class classification objective applied to the network's raw logits.
loss_function = nn.CrossEntropyLoss()

# Fresh network placed on the selected device, optimised with Adam.
# lr and weight_decay are both 5e-4.
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=5e-4)


In [40]:
def train_epoch(model, train_loader, optimizer, loss_function, device=None, log_interval=10):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable of ``(data, labels)`` batches that also exposes
            ``.dataset`` (its length is shown in the progress line).
        optimizer: Optimiser that updates ``model``'s parameters.
        loss_function: Callable ``(logits, labels) -> scalar loss tensor``.
        device: Device the batches are moved to. ``None`` (the default) uses the
            device of the model's first parameter, falling back to the CPU for a
            model without parameters.
        log_interval: Print a progress line every ``log_interval`` batches;
            ``0`` or ``None`` disables progress output.

    Returns:
        ``(mean_loss, accuracy)`` for the epoch: the mean of the per-batch losses
        and the percentage of correctly classified samples. ``(0.0, 0.0)`` when
        the loader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    window_loss = 0.0   # loss accumulated since the last progress line
    epoch_loss = 0.0    # loss accumulated over the whole epoch
    num_batches = 0
    correct = 0
    total = 0

    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        # Integer class labels may arrive in a narrow dtype (this notebook's
        # dataset yields uint8). int64 is the conventional dtype for
        # class-index targets, and some PyTorch versions may reject narrower ones.
        if not torch.is_floating_point(labels):
            labels = labels.long()

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1

        _, predicted = output.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if log_interval and (batch_idx + 1) % log_interval == 0:
            avg_loss = window_loss / log_interval
            acc = 100. * correct / total
            # `total` is the true number of samples seen so far; multiplying the
            # batch index by the current batch size is wrong on a short final batch.
            print(f' [{total}/{len(train_loader.dataset)}]'
                  f' Loss: {avg_loss:.3f} | Accuracy: {acc:.1f} %')
            window_loss = 0.0  # only the loss window resets; accuracy stays cumulative

    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return epoch_loss / num_batches, 100. * correct / total

In [41]:
def validation(model, validation_loader, device):
    """Return the top-1 accuracy, in percent, of ``model`` on ``validation_loader``.

    The model is put in evaluation mode and gradients are disabled for the pass.

    Args:
        model: Network to evaluate.
        validation_loader: Iterable of ``(data, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader yields
        no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in validation_loader:
            data, targets = data.to(device), targets.to(device)
            output = model(data)
            # Index of the largest logit along the class dimension.
            _, predicted = output.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        return 0.0
    return 100. * correct / total


In [None]:
num_epochs = 10

# Each epoch: one optimisation pass over the training loader, then an accuracy
# measurement on the validation loader.
for epoch in range(num_epochs):
    print(f"Epoch : {epoch}")
    train_epoch(model, train_loader, optimizer, loss_function)
    validation_acc = validation(model, validation_loader, device)
    print(f"Validation Accuracy : {validation_acc}")



Epoch : 0
 [320/5732] Loss: 2.659 | Accuracy: 30.9 %


In [17]:
# Final accuracy on the held-out test split, reusing the validation routine.
test_acc = validation(model=model, validation_loader=test_loader, device=device)
print(test_acc)

43.77542717656631
