In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2

In [3]:
import torch
import torch.nn as nn
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset , DataLoader
from torchvision.models import resnet18
from torchvision import transforms

In [4]:
DATASET_NAME = 'cats_vs_dogs'
datasets = load_dataset(DATASET_NAME)
datasets

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/8.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/330M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/391M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23410 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 23410
    })
})

In [5]:
TEST_SIZE = 0.2
datasets = datasets ['train'].train_test_split(test_size = TEST_SIZE)


In [6]:
IMG_SIZE = 64
img_transforms = transforms.Compose ([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels =3),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485 , 0.456 , 0.406] ,
        [0.229 , 0.224 , 0.225]
    )
])

In [7]:
class CatDogDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        images = self.data[idx]['image']
        labels = self.data[idx]['labels']

        if self.transform:
            images = self.transform(images)

        labels = torch.tensor(labels, dtype=torch.long)

        return images, labels


In [8]:
TRAIN_BATCH_SIZE = 512
VAL_BATCH_SIZE = 256

train_dataset = CatDogDataset(datasets['train'], transform = img_transforms)
test_dataset = CatDogDataset (datasets['test'], transform = img_transforms )

train_loader = DataLoader ( train_dataset , batch_size = TRAIN_BATCH_SIZE ,
shuffle = True )
test_loader = DataLoader ( test_dataset , batch_size = VAL_BATCH_SIZE , shuffle =
False )

In [10]:
class CatDogModel(nn.Module):
    def __init__(self, n_classes):
        super(CatDogModel, self).__init__()

        resnet_model = resnet18(weights='IMAGENET1K_V1')
        self.backbone = nn.Sequential(*list(resnet_model.children())[:-1])

        for param in self.backbone.parameters():
            param.requires_grad = False

        in_features = resnet_model.fc.in_features
        self.fc = nn.Linear(in_features, n_classes)
    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


In [11]:
device = 'cuda' if torch.cuda.is_available () else 'cpu'
N_CLASSES = 2

model = CatDogModel(N_CLASSES).to(device)
test_input = torch.rand(1 , 3 , 224 , 224).to( device )
with torch.no_grad():
  output = model ( test_input )
  print(output.shape ) # (1 , 2)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 174MB/s]


torch.Size([1, 2])


In [12]:
import torch
import torch.nn as nn
import torch.optim as optim

# Hyperparameters
EPOCHS = 10
LR = 1e-3
WEIGHT_DECAY = 1e-5

# Assuming model, train_loader, test_loader, and device are already defined
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    train_losses = []
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    train_loss = sum(train_losses) / len(train_losses)

    val_losses = []
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_losses.append(loss.item())

    val_loss = sum(val_losses) / len(val_losses)

    print(f'EPOCH {epoch + 1}:\tTrain loss: {train_loss:.3f}\tVal loss: {val_loss:.3f}')


EPOCH 1:	Train loss: 0.642	Val loss: 0.599
EPOCH 2:	Train loss: 0.543	Val loss: 0.543
EPOCH 3:	Train loss: 0.526	Val loss: 0.533
EPOCH 4:	Train loss: 0.518	Val loss: 0.531
EPOCH 5:	Train loss: 0.514	Val loss: 0.526
EPOCH 6:	Train loss: 0.513	Val loss: 0.524
EPOCH 7:	Train loss: 0.505	Val loss: 0.529
EPOCH 8:	Train loss: 0.507	Val loss: 0.520
EPOCH 9:	Train loss: 0.504	Val loss: 0.521
EPOCH 10:	Train loss: 0.502	Val loss: 0.521


In [13]:
SAVE_PATH = 'catdog_weights.pt'
torch.save(model.state_dict(), SAVE_PATH)