In [32]:
import torch 
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn as nn
from torchsummary import summary

from torchvision.models import resnet50, ResNet50_Weights

In [33]:
# Channel statistics of ImageNet: the ResNet-50 weights loaded in the next cell
# were pre-trained on inputs normalized with exactly these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DATA_ROOT = "dataset 2"  # ImageFolder layout: one sub-folder per class
BATCH_SIZE = 32
SEED = 42  # fixes the shuffle order so Restart & Run All is reproducible

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # squash to 224x224; aspect ratio is not preserved
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
dataset = ImageFolder(root=DATA_ROOT, transform=transform)

# Dataloader for mini-batches. A dedicated seeded generator makes the shuffled
# batch order deterministic without touching the global torch RNG.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [34]:
# Load ResNet-50 with pre-trained ImageNet weights; ResNet50_Weights.DEFAULT
# resolves to the most up-to-date weights torchvision ships.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Replace the classification head -- the *output* layer, not the input layer:
# the new fc maps the backbone's pooled features to one logit per dataset class.
num_classes = len(dataset.classes)
input_dim = num_classes  # legacy alias kept for later cells; it is the head's output size
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Check the architecture. torchsummary's `summary` defaults to device="cuda";
# on a machine with a GPU it would feed CUDA tensors into this model, which has
# never been moved off the CPU, and fail with a tensor-type mismatch. State the
# device explicitly.
summary(model, input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [35]:
# Freeze the whole network -- every parameter, the new fc head included -- so
# nothing trains unless it is explicitly re-enabled by the code that follows.
for param in model.parameters():
    param.requires_grad_(False)

# unfreeze last layers for fine-tuning 