# Fine-Tuning

Ask AI
> Using ResNet18 from torchvision, how can I freeze all weights except those of the last layer for fine-tuning? I am using PyTorch 2.7 and torchvision 0.22.

In [None]:
import torch
from torch import nn, optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader

| Model           | Number of Parameters |
|-----------------|---------------------|
| SqueezeNet1_1   | ~1.24M              |
| MobileNet V2    | ~3.5M               |
| ResNet18        | ~11.7M              |
| VGG16           | ~138M               |

In [None]:
# ----- Backbone to fine-tune (uncomment exactly one) -----
# The chosen network is loaded with its ImageNet pretrained weights further below.
# model_choice = "squeezenet1_1"
model_choice = "mobilenet_v2"
# model_choice = "resnet18"
# model_choice = "vgg16"

# ----- Which weights to train (uncomment exactly one) -----
# True  -> freeze the pretrained weights and train only the classification head
#          (feature extraction: fast, few trainable parameters).
# False -> leave every weight trainable.
is_fine_tuned = True
# is_fine_tuned = False

In [None]:
# Build the selected backbone with its ImageNet-1k pretrained weights.
# Each entry maps a `model_choice` value to (constructor, pretrained weights enum).
model_builders = {
    "squeezenet1_1": (
        models.squeezenet1_1,
        models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
    ),
    "mobilenet_v2": (
        models.mobilenet_v2,
        models.MobileNet_V2_Weights.IMAGENET1K_V1,
    ),
    "resnet18": (
        models.resnet18,
        models.ResNet18_Weights.IMAGENET1K_V1,
    ),
    "vgg16": (
        models.vgg16,
        models.VGG16_Weights.IMAGENET1K_V1,
    ),
}

if model_choice not in model_builders:
    # ValueError is still an Exception, so existing handlers keep working; the
    # message now says which value was rejected and what is accepted.
    raise ValueError(
        f"Invalid model: {model_choice!r}; expected one of {sorted(model_builders)}"
    )

build_fn, pretrained_weights = model_builders[model_choice]
model = build_fn(weights=pretrained_weights)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Standard ImageNet normalization statistics used by torchvision's pretrained weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: random resized crop + horizontal flip as data augmentation.
    "train": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    ),
    # Validation: deterministic resize + center crop (no augmentation).
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    ),
}

# ImageFolder expects one sub-directory per class: dataset/<split>/<class_name>/<images>.
data_dir = "dataset"
image_datasets = {
    x: datasets.ImageFolder(root=f"{data_dir}/{x}", transform=data_transforms[x])
    for x in ["train", "val"]
}

# Shuffle only the training split: the validation metric does not depend on order,
# and a fixed order keeps evaluation reproducible.
dataloaders = {
    x: DataLoader(
        image_datasets[x], batch_size=32, shuffle=(x == "train"), num_workers=2
    )
    for x in ["train", "val"]
}

class_names = image_datasets["train"].classes
num_classes = len(class_names)

# ImageFolder assigns label indices from the sorted folder names of each split, so a
# missing or extra class folder in "val" would silently shift its labels.
if image_datasets["val"].classes != class_names:
    raise ValueError(
        f"Class folders differ between splits: train={class_names}, "
        f"val={image_datasets['val'].classes}"
    )

print(class_names)

In [None]:
from torchinfo import summary

# Layer-by-layer summary of the pretrained model for one 3x224x224 image.
# The call is the last expression of the cell so the notebook renders its result.
image_shape = (3, 224, 224)  # (channels, height, width)
input_size = (1, *image_shape)  # prepend a batch size of 1
summary(model, input_size=input_size)

In [None]:
# Report every learnable tensor: its qualified name, shape, and whether autograd
# tracks it (requires_grad=True means it can be updated by the optimizer).
for layer_name, weight in model.named_parameters():
    report = "\n".join(
        [
            f"Parameter name: {layer_name}",
            f"Shape: {weight.shape}",
            f"Requires grad: {weight.requires_grad}",
            "-" * 50,
        ]
    )
    print(report)

In [None]:
# Optionally freeze the pretrained weights, then swap the ImageNet classification
# head for one sized to this dataset.
#
# is_fine_tuned=True  -> freeze *every* pretrained parameter; only the new head trains.
# is_fine_tuned=False -> all parameters stay trainable.
#
# Freezing model.parameters() (not model.features.parameters()) is what "freeze all
# but the last layer" means. It also works for ResNet18, which has no `.features`
# submodule, and for VGG16 it no longer leaves the two hidden classifier layers trainable.
if is_fine_tuned:
    for param in model.parameters():
        param.requires_grad = False  # no gradients, so the optimizer never updates it

# Replace the head in BOTH modes: the pretrained head produces 1000 ImageNet logits,
# which would not line up with `class_names`. Freshly constructed layers default to
# requires_grad=True, so the new head is always trainable.
if model_choice == "squeezenet1_1":
    # SqueezeNet classifies with a 1x1 convolution at classifier[1]
    model.classifier[1] = nn.Conv2d(
        model.classifier[1].in_channels,
        num_classes,
        kernel_size=(1, 1),
        stride=(1, 1),
    )
elif model_choice == "mobilenet_v2":
    # Replace the last linear layer for MobileNet v2
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
elif model_choice == "resnet18":
    # Replace fully connected layer for ResNet18
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_choice == "vgg16":
    # Replace the last linear layer in the classifier for VGG16
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
else:
    raise ValueError(f"Invalid model: {model_choice!r}")

# The new head is created on the CPU; move the model again so it shares the backbone's
# device (otherwise the forward pass fails with a device mismatch on CUDA).
model = model.to(device)

# NOTE: requires_grad=False does not stop BatchNorm layers from updating their running
# statistics while the model is in train() mode.

# Optimize only the parameters that still require gradients.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=0.001
)

# Use cross-entropy loss for multi-class classification
criterion = nn.CrossEntropyLoss()

In [None]:
from torchinfo import summary

# Summarize the model again after the head swap; the trainable vs. non-trainable
# parameter totals show the effect of freezing.
# The call is the cell's last expression so the notebook renders it.
batch_dim, channel_dim, height_dim, width_dim = 1, 3, 224, 224
input_size = (batch_dim, channel_dim, height_dim, width_dim)
summary(model, input_size=input_size)

In [None]:
# List each parameter again to confirm which tensors remain trainable
# (requires_grad=True) after freezing and replacing the head.
divider = "-" * 50
for tensor_name, tensor in model.named_parameters():
    print(
        f"Parameter name: {tensor_name}",
        f"Shape: {tensor.shape}",
        f"Requires grad: {tensor.requires_grad}",
        divider,
        sep="\n",
    )

In [None]:
num_epochs = 10

# Standard supervised loop; the reported loss is the sample-weighted mean over the
# whole training split for each epoch.
for epoch_idx in range(num_epochs):
    model.train()  # enables training behavior of Dropout / BatchNorm layers
    total_loss = 0.0

    for batch_x, batch_y in dataloaders["train"]:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()

        # loss is a per-batch mean; re-weight by batch size before accumulating
        total_loss += batch_loss.item() * batch_x.size(0)

    mean_loss = total_loss / len(image_datasets["train"])
    print(f"Epoch {epoch_idx + 1}/{num_epochs} - Loss: {mean_loss:.4f}")

In [None]:
# Top-1 accuracy on the validation split, computed without tracking gradients.
model.eval()
n_correct, n_seen = 0, 0
with torch.no_grad():
    for batch_x, batch_y in dataloaders["val"]:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        predicted = model(batch_x).argmax(dim=1)  # index of the highest logit
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

accuracy = 100 * n_correct / n_seen
print(f"Validation Accuracy: {accuracy:.2f}%")

In [None]:
import cv2
import torch
from torchvision import transforms
from PIL import Image

model.eval()

# Reuse the exact validation pipeline (resize 256 -> center crop 224 -> tensor ->
# ImageNet normalization) so live frames get the same preprocessing as evaluation.
preprocess = data_transforms["val"]

# Start webcam feed (index 0 opens the default camera)
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Raise instead of exit(): exit() would tear down the interpreter / notebook
    # kernel and discard the model that was just trained.
    raise RuntimeError("Error: Webcam access failed")

print("Press 'q' to quit.")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Frame capture failed")
            break

        # OpenCV delivers BGR frames; the training images were loaded as RGB.
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_pil = Image.fromarray(img)
        input_tensor = preprocess(img_pil)
        # Add a batch dimension and place the tensor on the same device as the model.
        input_batch = input_tensor.unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            outputs = model(input_batch)
            _, pred = torch.max(outputs, 1)
            predicted_class = class_names[pred.item()]

        # Display prediction on the frame
        cv2.putText(
            frame,
            f"Prediction: {predicted_class}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
            cv2.LINE_AA,
        )

        cv2.imshow("Webcam - Press q to quit", frame)

        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Always free the camera and close the window, even if inference raises.
    cap.release()
    cv2.destroyAllWindows()