In [1]:
import torch
import datasets

# Load the dataset
train_dataset = datasets.load_dataset("zh-plus/tiny-imagenet", split="train", cache_dir='images_dir')
valid_dataset = datasets.load_dataset("zh-plus/tiny-imagenet", split="valid", cache_dir='images_dir')

Found cached dataset parquet (/media/dnth/Active-Projects/vl-datasets/notebooks/images_dir/zh-plus___parquet/Maysee--tiny-imagenet-2eb6c3acd8ebc62a/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
Found cached dataset parquet (/media/dnth/Active-Projects/vl-datasets/notebooks/images_dir/zh-plus___parquet/Maysee--tiny-imagenet-2eb6c3acd8ebc62a/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


In [2]:
train_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
 'label': 0}

In [3]:
import torchvision.transforms as transforms

train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

valid_transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_valid(example_batch):
    """Apply valid_transforms across a batch."""
    example_batch["pixel_values"] = [
        valid_transform(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

In [4]:
train_dataset.set_transform(preprocess_train)
valid_dataset.set_transform(preprocess_valid)

In [5]:
train_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
 'label': 0,
 'pixel_values': tensor([[[ 0.6734,  0.6563,  0.6392,  ..., -0.9363, -0.8507, -0.7822],
          [ 0.7077,  0.6906,  0.6734,  ..., -0.9020, -0.7993, -0.7137],
          [ 0.8104,  0.7762,  0.7419,  ..., -0.8507, -0.6623, -0.5424],
          ...,
          [-0.5596, -0.6281, -0.7137,  ..., -0.7822, -0.8164, -0.8335],
          [-0.5424, -0.6109, -0.7308,  ..., -0.9020, -0.8678, -0.8507],
          [-0.5424, -0.6109, -0.7308,  ..., -0.9534, -0.9020, -0.8678]],
 
         [[ 0.4853,  0.4853,  0.5028,  ...,  0.1877,  0.2577,  0.3102],
          [ 0.5028,  0.5028,  0.5028,  ...,  0.2052,  0.3102,  0.3803],
          [ 0.5203,  0.5203,  0.5203,  ...,  0.2402,  0.4153,  0.5203],
          ...,
          [-0.5126, -0.5826, -0.6702,  ..., -0.1450, -0.2150, -0.2675],
          [-0.5126, -0.6001, -0.7227,  ..., -0.2850, -0.2850, -0.2850],
          [-0.5126, -0.6001, -0.7402,  ..., -0.3550, -0.3200, -0.3025]],


In [6]:
valid_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
 'label': 0,
 'pixel_values': tensor([[[2.1119, 2.1290, 2.1633,  ..., 2.2318, 2.2318, 2.2318],
          [2.1804, 2.1633, 2.1633,  ..., 2.2318, 2.2318, 2.2318],
          [2.2318, 2.2318, 2.1975,  ..., 2.2147, 2.2147, 2.2147],
          ...,
          [1.2728, 1.9578, 2.2489,  ..., 1.6838, 1.6667, 1.5982],
          [2.0777, 2.1119, 1.8208,  ..., 2.1462, 2.0777, 1.9920],
          [2.0263, 1.9578, 2.0777,  ..., 2.2318, 2.1975, 2.1975]],
 
         [[2.3936, 2.4111, 2.4111,  ..., 2.4111, 2.4111, 2.4111],
          [2.4286, 2.4111, 2.4111,  ..., 2.4111, 2.4111, 2.4111],
          [2.4286, 2.4286, 2.4111,  ..., 2.3936, 2.3936, 2.3936],
          ...,
          [1.4657, 2.1660, 2.4286,  ..., 1.8508, 1.8333, 1.7633],
          [2.4286, 2.4286, 2.1835,  ..., 2.3235, 2.3060, 2.2185],
          [2.4286, 2.4286, 2.4286,  ..., 2.4111, 2.4286, 2.4286]],
 
         [[2.5703, 2.5877, 2.6051,  ..., 2.6226, 2.6226, 2.6226],
    

In [7]:
train_dataset[0]["pixel_values"].shape

torch.Size([3, 64, 64])

In [8]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [9]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True , collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision

model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.features["label"].names))

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
from tqdm.auto import tqdm

num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

for epoch in tqdm(range(num_epochs), desc="Epochs"):
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
        inputs, labels = data["pixel_values"], data["labels"]
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} - Loss: {running_loss/len(train_loader)}")


Using device: cuda


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1 - Loss: 3.4840661308649556


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2 - Loss: 2.9641834250496477


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3 - Loss: 2.7713216058433514


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4 - Loss: 2.6330488709842457


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5 - Loss: 2.5439858345119544


In [13]:
correct = 0
total = 0
with torch.no_grad():
    for data in tqdm(valid_loader, desc="Validation"):
        inputs, labels = data["pixel_values"], data["labels"]
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total}%")


Validation:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 43.43%
