In [1]:
import torch 
import torch.nn as nn

import numpy as np

from PIL import Image 
from datasets import load_dataset 
from torch.utils.data import Dataset, DataLoader 
from torchvision.models import resnet18 
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DATASET_NAME = 'cats_vs_dogs'
datasets = load_dataset(DATASET_NAME, split="train[:1000]")
datasets

Dataset({
    features: ['image', 'labels'],
    num_rows: 1000
})

In [4]:
TEST_SIZE = 0.2 
datasets = datasets.train_test_split(test_size=TEST_SIZE)
datasets

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 800
    })
    test: Dataset({
        features: ['image', 'labels'],
        num_rows: 200
    })
})

In [5]:
IMG_SIZE = 64 
img_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

In [6]:
class CatDogDataset(Dataset):
    def __init__(self, data, transforms=None):
        self.data = data 
        self.transforms = transforms 

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        images = self.data[idx]['image']
        labels = self.data[idx]['labels']

        if self.transforms:
            images = self.transforms(images)

        labels = torch.tensor(labels, dtype=torch.long)

        return images, labels

In [7]:
TRAIN_BATCH_SIZE = 512
VAL_BATCH_SIZE = 256

train_dataset = CatDogDataset(
    datasets['train'], transforms=img_transforms
)
test_dataset = CatDogDataset(
    datasets['test'], transforms=img_transforms
)

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=True
)

In [8]:
class CatDogModel(nn.Module):
    def __init__(self, n_classes):
        super(CatDogModel, self).__init__()

        resnet_model = resnet18(weights='IMAGENET1K_V1')
        self.backbone = nn.Sequential(*list(resnet_model.children())[:-1])

        for param in self.backbone.parameters():
            param.requires_grad = False 

        in_features = resnet_model.fc.in_features 
        self.fc = nn.Linear(in_features, n_classes)

    def forward(self, X):
        x = self.backbone(X)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [10]:
N_CLASSES = 2
model = CatDogModel(N_CLASSES).to(device)
test_input = torch.rand(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(test_input)
    print(output.shape)

torch.Size([1, 2])


In [11]:
EPOCHS = 10 
LR = 1e-3 
WEIGHT_DECAY = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = torch.nn.CrossEntropyLoss()

In [17]:
for epoch in range(EPOCHS):
    train_losses = []
    train_correct = 0
    total_train = 0
    model.train()

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)

        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

       
        _, preds = torch.max(outputs, 1)  
        train_correct += (preds == labels).sum().item()
        total_train += labels.size(0)

        train_losses.append(loss.item())

    train_loss = sum(train_losses) / len(train_losses)
    train_acc = train_correct / total_train

    val_losses = []
    val_correct = 0
    total_val = 0
    model.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_losses.append(loss.item())

            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            total_val += labels.size(0)

    val_loss = sum(val_losses) / len(val_losses)
    val_acc = val_correct / total_val

    print(f'EPOCH {epoch+1}:\tTrain loss: {train_loss:.3f}, Train Acc: {train_acc:.3f}\tVal Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}')

EPOCH 1:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 2:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 3:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 4:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 5:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 6:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 7:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 8:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 9:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000
EPOCH 10:	Train loss: 0.001, Train Acc: 1.000	Val Loss: 0.001, Val Acc: 1.000


In [18]:
SAVE_PATH = 'models/weights/catdog_weights.pt'
torch.save(model.state_dict(), SAVE_PATH)