In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights

from torchinfo import summary
from torchmetrics.functional.classification import multiclass_accuracy

In [4]:
# Load Data
img_dir = '../data/Snake Images copy/'

# Resnet 전처리
preprocessing = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),    # 이미지 크기 조절
    # - 256x256
    # - BILINEAR : 이미지 픽셀 사이의 값을 보간하여 새로운 픽셀 값을 생성
    # - InterpolationMode : 보간 방법
    transforms.CenterCrop(224),   # 이미지 중앙을 기준으로 224x224로 이미지를 자름
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 이미지 정규화
])

# 이미지 데이터셋 생성
# - ImageFolder : Folder 구조로 된 이미지 데이터셋을 로드
img_ds = ImageFolder(img_dir, transform=preprocessing)

print(img_ds.classes, img_ds.targets, img_ds.imgs, sep='\n')

['test', 'train']
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [5]:
# 데이터 로더 생성
img_dl = DataLoader(img_ds, batch_size=3, shuffle=True, drop_last=True)
for img, label in img_dl:
    print(img.shape, label)
    break

torch.Size([3, 3, 224, 224]) tensor([1, 1, 1])


In [6]:
# 모델 설계 및 설정
## 사전 학습된 모델 인스턴스 생성
res_model = resnet18(weights=ResNet18_Weights.DEFAULT)

## 모델의 출력 레이어(전결합층) 변경
# - in_features : 입력 특징의 개수, out_features : 출력 특징의 개수
res_model.fc = nn.Linear(in_features=512, out_features=3)   # out_features : 분류할 클래스 개수   

In [7]:
# Resnet18 Feature module parameters : requires_grad=true -> False
for name, param in res_model.named_parameters():
    print(name, param.requires_grad, end='==>')
    param.requires_grad = False
    print(param.requires_grad)
    
# Resnet18 Full Connected module parameters : requires_grad=false -> True
for name, param in res_model.fc.named_parameters():
    print(name, param.requires_grad, end='==>')
    param.requires_grad = True
    print(param.requires_grad)

conv1.weight True==>False
bn1.weight True==>False
bn1.bias True==>False
layer1.0.conv1.weight True==>False
layer1.0.bn1.weight True==>False
layer1.0.bn1.bias True==>False
layer1.0.conv2.weight True==>False
layer1.0.bn2.weight True==>False
layer1.0.bn2.bias True==>False
layer1.1.conv1.weight True==>False
layer1.1.bn1.weight True==>False
layer1.1.bn1.bias True==>False
layer1.1.conv2.weight True==>False
layer1.1.bn2.weight True==>False
layer1.1.bn2.bias True==>False
layer2.0.conv1.weight True==>False
layer2.0.bn1.weight True==>False
layer2.0.bn1.bias True==>False
layer2.0.conv2.weight True==>False
layer2.0.bn2.weight True==>False
layer2.0.bn2.bias True==>False
layer2.0.downsample.0.weight True==>False
layer2.0.downsample.1.weight True==>False
layer2.0.downsample.1.bias True==>False
layer2.1.conv1.weight True==>False
layer2.1.bn1.weight True==>False
layer2.1.bn1.bias True==>False
layer2.1.conv2.weight True==>False
layer2.1.bn2.weight True==>False
layer2.1.bn2.bias True==>False
layer3.0.con

In [None]:
# 학습 준비
optimizer = optim.Adam(res_model.parameters())
loss_fn = nn.CrossEntropyLoss()
EPOCHS = 5

# ... 등등 이대로 진행하면 됨