In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import os

In [2]:
# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# download the COCO dataset
!wget http://images.cocodataset.org/zips/train2017.zip
!wget http://images.cocodataset.org/zips/val2017.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

# unzip the files
!unzip train2017.zip
!unzip val2017.zip
!unzip annotations_trainval2017.zip


In [None]:
# define the transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # shift/scale every channel with mean 0.5 and std 0.5
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# define the dataset
# NOTE(review): `transform` is applied to the image only; the annotation dicts
# are returned unchanged, so any box coordinates presumably still refer to the
# original (un-resized) image - confirm before using them.
trainset = torchvision.datasets.CocoDetection(root='./train2017', annFile='./annotations/instances_train2017.json', transform=transform)
valset = torchvision.datasets.CocoDetection(root='./val2017', annFile='./annotations/instances_val2017.json', transform=transform)


def collate_detection_batch(batch):
    """Collate a list of ``(image, annotations)`` samples into one batch.

    The images are stacked into a single tensor (they all share the size set
    by ``Resize`` above). The annotations stay a plain Python list with one
    entry per image, because every image carries a different number of
    annotation dicts and the default collate function cannot merge lists of
    unequal length.

    Args:
        batch: list of ``(image_tensor, annotations)`` pairs from the dataset.

    Returns:
        ``(images, targets)`` where ``images`` is the stacked image tensor and
        ``targets`` is a list holding each image's annotations.
    """
    images, targets = zip(*batch)
    return torch.stack(images), list(targets)


# define the dataloader
# NOTE(review): with num_workers > 0 the collate function is handed to worker
# processes; this is fine with the "fork" start method, but a function defined
# in a notebook may not be picklable under "spawn" - verify on your platform.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_detection_batch)
valloader = torch.utils.data.DataLoader(valset, batch_size=4, shuffle=False, num_workers=2, collate_fn=collate_detection_batch)


In [None]:
# plot a grid of random images from the dataset (3 x 3 = 9 by default)
def plot_random_images(dataset, num_cols=3, num_rows=3):
    """Show a ``num_rows`` x ``num_cols`` grid of randomly chosen images.

    Args:
        dataset: indexable collection with ``len()``; the first element of
            each item must be a tensor exposing ``.numpy()``. The tensor is
            transposed with ``(1, 2, 0)``, i.e. it is assumed to be laid out
            channels-first - TODO confirm for other datasets.
        num_cols: number of grid columns.
        num_rows: number of grid rows.

    Indices are drawn independently, so the same image may appear twice.
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column; by default matplotlib collapses such grids to a 1-D array (or a
    # bare Axes for 1 x 1) and 2-D indexing would fail.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
    for ax in axes.ravel():
        index = np.random.randint(len(dataset))
        image = dataset[index][0].numpy().transpose(1, 2, 0)
        # rescale from [-1, 1] to [0, 1] for display; assumes the images were
        # normalized with mean 0.5 / std 0.5
        image = (image + 1) / 2
        ax.imshow(image)
        ax.axis('off')
    plt.show()

# preview a random sample of the training images
plot_random_images(dataset=trainset)


In [None]:
# number of categories listed in the training annotations
num_of_classes = len(trainset.coco.getCatIds())

# report the size of each split, then the class count
print('Size of the training dataset:', len(trainset))
print('Size of the validation dataset:', len(valset))
print('Number of classes in the dataset:', num_of_classes)


In [None]:
# use timm to load a pretrained ResNet18 model
import timm

# Ask timm for a classification head with `num_of_classes` outputs instead of
# overwriting `model.fc` with a hard-coded 512-wide layer: the head width then
# follows the chosen architecture and timm's own bookkeeping of the head stays
# consistent. The backbone keeps its pretrained weights; the new head is
# freshly initialized.
model = timm.create_model('resnet18', pretrained=True, num_classes=num_of_classes)
model = model.to(device)

# define the loss function and the optimizer
# NOTE(review): CrossEntropyLoss expects a single class index per image, while
# CocoDetection yields a list of annotations per image - confirm how the
# training labels are derived.
criterion = nn.CrossEntropyLoss()
