In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm
import numpy as np
import pandas as pd 
import os
from PIL import Image, ImageOps, ImageFilter 
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt 

In [24]:
torch.set_num_threads(os.cpu_count())

In [33]:
class DogBreedDataset(Dataset):
    def __init__(self, label_df, img_dir, labels, transform=None): 
        self.labels_df = label_df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

        # Dict mapping 'breed' to index, model will operate on integers and we will convert back to breeds 
        self.classes = labels

    def __len__(self): 
        return len(self.labels_df)
    
    def __getitem__(self, index): 
        row = self.labels_df.iloc[index]
        path = os.path.join(self.img_dir, row["id"]+ ".jpg")
        img = Image.open(path)
        img = img.convert("RGB")

        if self.transform: 
            img = self.transform(img)

        label = self.classes[row['breed']]

        return img, label

In [34]:
labels = pd.read_csv("data/labels.csv")
breeds = labels['breed'].unique() 

# Dict mapping 'breed' to index, model will operate on integers and we will convert back to breeds 
classes = {b: i for i, b in enumerate(breeds)}

In [35]:
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(), 
    transforms.RandomResizedCrop(224, scale = (0.8, 1.0)),
    transforms.RandomRotation(15), 
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3), 
    transforms.ToTensor(), 
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [36]:
dataset = DogBreedDataset(labels, "data/train", classes, train_transforms)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [37]:
num_classes = len(dataset.classes)

In [38]:
train_df, val_df = train_test_split(dataset.labels_df, 
                                    test_size=0.2, 
                                    stratify=dataset.labels_df['breed'])

train_df.to_csv("train_split.csv", index=False)
val_df.to_csv("val_split.csv", index=False)

In [39]:
train_dataset = DogBreedDataset(
    label_df=train_df,
    img_dir ='data/train',
    labels = classes,
    transform=train_transforms
)

val_dataset = DogBreedDataset(
    label_df = val_df,
    img_dir ='data/train',
    labels = classes, 
    transform=val_transforms
)

In [40]:
print(train_dataset.classes)
print(val_dataset.classes)

{'boston_bull': 0, 'dingo': 1, 'pekinese': 2, 'bluetick': 3, 'golden_retriever': 4, 'bedlington_terrier': 5, 'borzoi': 6, 'basenji': 7, 'scottish_deerhound': 8, 'shetland_sheepdog': 9, 'walker_hound': 10, 'maltese_dog': 11, 'norfolk_terrier': 12, 'african_hunting_dog': 13, 'wire-haired_fox_terrier': 14, 'redbone': 15, 'lakeland_terrier': 16, 'boxer': 17, 'doberman': 18, 'otterhound': 19, 'standard_schnauzer': 20, 'irish_water_spaniel': 21, 'black-and-tan_coonhound': 22, 'cairn': 23, 'affenpinscher': 24, 'labrador_retriever': 25, 'ibizan_hound': 26, 'english_setter': 27, 'weimaraner': 28, 'giant_schnauzer': 29, 'groenendael': 30, 'dhole': 31, 'toy_poodle': 32, 'border_terrier': 33, 'tibetan_terrier': 34, 'norwegian_elkhound': 35, 'shih-tzu': 36, 'irish_terrier': 37, 'kuvasz': 38, 'german_shepherd': 39, 'greater_swiss_mountain_dog': 40, 'basset': 41, 'australian_terrier': 42, 'schipperke': 43, 'rhodesian_ridgeback': 44, 'irish_setter': 45, 'appenzeller': 46, 'bloodhound': 47, 'samoyed': 

In [41]:
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,  
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)


In [42]:
# Test if the dataset itself is working
print("Checking first few dataset items...")
for i in range(3):
    try:
        x, y = train_dataset[i]
        print(f"Item {i}: Image shape = {x.shape}, Label = {y}")
    except Exception as e:
        print(f"Error at index {i}: {e}")


Checking first few dataset items...
Item 0: Image shape = torch.Size([3, 224, 224]), Label = 65
Item 1: Image shape = torch.Size([3, 224, 224]), Label = 113
Item 2: Image shape = torch.Size([3, 224, 224]), Label = 25


In [43]:
device  = "cpu"
cnn = models.resnet18(weights="IMAGENET1K_V1")

for param in cnn.parameters(): 
    param.requires_grad = False

cnn.fc = nn.Sequential(
    nn.Linear(cnn.fc.in_features, 256), 
    nn.ReLU(), 
    nn.Dropout(0.5), 
    nn.Linear(256, num_classes)
)

cnn=cnn.to(device)
print(cnn)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [44]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.fc.parameters(), lr=1e-4, weight_decay=1e-4)

In [47]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs): 
    for epoch in range(num_epochs): 
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in tqdm(train_loader): 
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs,1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {epoch_loss} | Accuracy: {epoch_acc}")


        model.eval()
        val_loss = 0.0
        val_corrects = 0
        with torch.no_grad(): 
            for inputs, labels in val_loader: 
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                val_loss += loss.item() * inputs.size(0)
                val_corrects += (preds == labels).sum().item()
        
        val_loss /= len(val_loader.dataset)
        val_acc = val_corrects / len(val_loader.dataset)
        print(f"Validation Loss: {val_loss} | Accuracy: {val_acc}")


In [48]:
train(cnn, train_loader, val_loader, criterion, optimizer, num_epochs=5)


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]


Epoch 1/5 | Train Loss: 4.445388658247616 | Accuracy: 0.08499449675920265
Validation Loss: 4.108311125235336 | Accuracy: 0.24058679706601466


100%|██████████| 128/128 [02:41<00:00,  1.26s/it]


Epoch 2/5 | Train Loss: 4.061136771986184 | Accuracy: 0.15115568056744527
Validation Loss: 3.597309883532722 | Accuracy: 0.34621026894865525


100%|██████████| 128/128 [03:03<00:00,  1.44s/it]


Epoch 3/5 | Train Loss: 3.643144564478733 | Accuracy: 0.21010150421915127
Validation Loss: 3.1117800629809316 | Accuracy: 0.4444987775061125


100%|██████████| 128/128 [02:49<00:00,  1.33s/it]


Epoch 4/5 | Train Loss: 3.28300648063906 | Accuracy: 0.26232114467408585
Validation Loss: 2.720641226873421 | Accuracy: 0.5080684596577018


100%|██████████| 128/128 [02:49<00:00,  1.33s/it]


Epoch 5/5 | Train Loss: 2.978823646467502 | Accuracy: 0.2979087684970038
Validation Loss: 2.3912030819284302 | Accuracy: 0.5643031784841076


In [49]:
model_state = cnn.state_dict()

torch.save(model_state, "model_1.0")