In [1]:
import os
from PIL import Image

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Dataset, ConcatDataset, WeightedRandomSampler, Subset
import torchvision.transforms as transforms
import torchvision.models as models

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [3]:
def conv3x3(in_planes, out_planes, stride=1, padding=1, bias=False):
  return nn.Conv2d(in_planes, out_planes,
      kernel_size =3,
      stride = stride,
      padding = padding,
      bias = bias
  )

def conv1x1(in_planes, out_planes, stride=1, padding=0, bias=False):
  return nn.Conv2d(in_planes, out_planes,
      kernel_size =1,
      stride = stride,
      padding = padding,
      bias = bias
  )

In [4]:
# ResNet 18, 34에 쓰이는 Basic Block
class BasicBlock(nn.Module):
  mul = 1
  def __init__(self, in_planes, out_planes, stride=1):
    super(BasicBlock, self).__init__()

    self.conv1 = conv3x3(in_planes, out_planes, stride)
    self.conv2 = conv3x3(out_planes, out_planes, 1)

    self.bn1 = nn.BatchNorm2d(out_planes)
    self.bn2 = nn.BatchNorm2d(out_planes)

    self.shortcut = nn.Sequential()
    if stride != 1:
      self.shortcut = nn.Sequential(
          conv1x1(in_planes, out_planes, stride),
          nn.BatchNorm2d(out_planes)
      )

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = F.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out += self.shortcut(x)
    out = F.relu(out)
    return out



In [5]:
class ResNet(nn.Module):
  def __init__(self, block, num_blocks, num_classes=31):
    super(ResNet, self).__init__()

    # 7*7, 64 channels, stride 2 in paper
    self.in_planes = 64

    # RGB channel -> 64 channels
    self.conv = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3)
    self.bn = nn.BatchNorm2d(self.in_planes)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # self.gender_fc = nn.Linear(num_features, 2)  # 성별 분류
    # self.style_fc = nn.Linear(num_features, 31)  # 스타일 분류

    _layers = []
    outputs, strides = [64, 128, 256, 512], [1, 2, 2, 2]
    for i in range(4):
      _layers.append(self._make_layer(block, outputs[i], num_blocks[i], stride=strides[i]))
    self.layers = nn.Sequential(*_layers)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.linear = nn.Linear(512 * block.mul, num_classes)

  def _make_layer(self, block, out_planes, num_block, stride):
    layers = [ block(self.in_planes, out_planes, stride) ]
    self.in_planes = block.mul * out_planes
    for i in range(num_block - 1):
      layers.append(block(self.in_planes, out_planes, 1))

    return nn.Sequential(*layers)

  def forward(self, x):
    out = self.conv(x)
    out = self.bn(out)
    out = F.relu(out)
    out = self.maxpool(out)

    out = self.layers(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    return out

In [6]:
def ResNet18():
  return ResNet(BasicBlock, [2, 2, 2, 2])

In [7]:
# 남성 스타일과 여성 스타일을 성별 구분하여 인덱싱
style_dict = {
    # 남성 스타일 (8개)
    'M_bold': 0,
    'M_hiphop': 1,
    'M_hippie': 2,
    'M_ivy': 3,
    'M_metrosexual': 4,
    'M_mods': 5,
    'M_normcore': 6,
    'M_sportivecasual': 7,

    # 여성 스타일 (23개)
    'W_athleisure': 8,
    'W_bodyconscious': 9,
    'W_cityglam': 10,
    'W_classic': 11,
    'W_disco': 12,
    'W_ecology': 13,
    'W_feminine': 14,
    'W_genderless': 15,
    'W_grunge': 16,
    'W_hiphop': 17,
    'W_hippie': 18,
    'W_kitsch': 19,
    'W_lingerie': 20,
    'W_lounge': 21,
    'W_military': 22,
    'W_minimal': 23,
    'W_normcore': 24,
    'W_oriental': 25,
    'W_popart': 26,
    'W_powersuit': 27,
    'W_punk': 28,
    'W_space': 29,
    'W_sportivecasual': 30
}


In [8]:
class ImageDataset(Dataset):
  def __init__(self, image_dir, transform=None):
    self.image_dir = image_dir
    self.image_files = [ f for f in os.listdir(image_dir) if f.endswith('.jpg') and os.path.isfile(os.path.join(image_dir, f))]
    self.transform = transform

  def __len__(self):
    return len(self.image_files)

  def __getitem__(self, idx):
    img_name = self.image_files[idx]
    img_path = os.path.join(self.image_dir, img_name)
    image = Image.open(img_path).convert('RGB')  # RGB로 변환

    # 파일명에서 성별과 스타일 추출
    parts = img_name.split('_')
    try:
        gender = parts[-1].replace('.jpg', '')  # 성별
        style = parts[-2]  # 스타일
    except IndexError:
        print(f"Error parsing filename: {img_name}")
        return None, None  # or handle the error differently

    label = f"{gender}_{style}"

    # 라벨을 (성별, 스타일)로 설정 (성별을 0/1, 스타일을 별도의 인덱스로 설정)
    # gender_label = 0 if gender == 'M' else 1
    style_label = style_dict.get(f"{gender}_{style}")

    if self.transform:
      image = self.transform(image)

    return image, style_label

In [None]:
from google.colab import drive
drive.mount('/content/drive')

train_image_dir = '/content/drive/MyDrive/dataset/training_image'
val_image_dir = '/content/drive/MyDrive/dataset/validation_image'

resnet_transform = transforms.Compose([
    transforms.Resize((224, 224)), # resnet paper input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# transform_augmentation = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.CenterCrop((224,224)),
#     transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(degrees=15),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
center_crop_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop((224, 224))
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

])

train_dataset = ImageDataset(train_image_dir, transform=resnet_transform)

# class_indices = [[] for _ in range(31)]
# for i in range(len(train_dataset)):
#   _, label = train_dataset[i]
#   class_indices[label].append(i)

# calculate the number of examples to sample from each class
# max_class_size = max([len(class_indices[c]) for c in range(31)])
# class_weights = [max_class_size / len(class_indices[c]) for c in range(31)]
# num_samples = [int(class_weights[c] * len(class_indices[c])) for c in range(31)]

# create a WeightedRandomSampler to oversample the training set
# sampler = WeightedRandomSampler(weights=class_weights, num_samples=sum(num_samples), replacement=True)

train_crop = ImageDataset(train_image_dir, center_crop_transform)
train_dataset = ConcatDataset([train_dataset, train_crop])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, sampler=sampler) # 현재 논문상에서 256 배치사이즈를 사용했다고 함

#create new training set with oversampled examples
# oversampled_train_dataset = Subset(train_dataset, indices=list(sampler))

# Sampling the subset
# oversampled_train_dataset.transform = transform_augmentation
# subset_train_loader = train_loader = DataLoader(oversampled_train_dataset, batch_size=64, sampler=sampler)
# 이부분은 추후 실험을 통해서 조정을 해야할듯.

# train_dataset = ConcatDataset([train_dataset, oversampled_train_dataset])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, sampler=sampler) # 현재 논문상에서 256 배치사이즈를 사용했다고 함


val_dataset = ImageDataset(val_image_dir, transform=resnet_transform)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
print(f"{len(train_dataset)} train dataset 길이")
# print(f"{len(oversampled_train_dataset)} train dataset 길이")
print(f"{len(val_dataset)} validation dataset 길이")

Mounted at /content/drive


In [11]:
def evaluate(model, val_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0

    with torch.no_grad():
        for batch_in, batch_out in val_loader:
            batch_in = batch_in.to(device)
            batch_out = batch_out.to(device)
            # target = batch_out.view(-1)
            outputs = model(batch_in)
            loss = criterion(outputs, batch_out)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == batch_out).sum().item()
            total += batch_out.size(0)
    accuracy = correct / total
    avg_loss = total_loss / len(val_loader)
    # for inputs, labels in val_loader:
    #     with torch.no_grad():
            # outputs = model(inputs.to(device))
            # loss = criterion(outputs, labels.to(device))

            # total_loss += loss.item()
            # _, predicted = torch.max(outputs, 1)
            # correct += (predicted == labels.to(device)).sum().item()
            # total += labels.size(0)

    # accuracy = correct / total if total > 0 else 0
    # avg_loss = total_loss / len(val_loader) if len(val_loader) > 0 else 0

    return accuracy, avg_loss

In [12]:
model = ResNet18().to(device)

epochs = 10

criterion = nn.CrossEntropyLoss()  # loss func 정의

# 논문에서 제시한 optim 사용
optimizer = optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-4)

for epoch in range(epochs):
  model.train()
  running_loss = 0.0
  for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs = model(inputs.to(device))

    loss = criterion(outputs, labels.to(device))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

  print(f"Epoch {epoch+1}/{epochs}, Training Loss: {running_loss / len(train_loader)}")
  accuracy, avg_loss = evaluate(model, val_loader, criterion)
  print(f"Epoch {epoch+1}/{epochs}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_loss:.4f}")

Epoch [1/10], Loss: 3.6024
Epoch [1/10], Loss: 4.0140
Epoch [1/10], Loss: 4.2099
Epoch [1/10], Loss: 4.0249
Epoch [1/10], Loss: 3.5301
Epoch [1/10], Loss: 3.2848
Epoch [1/10], Loss: 3.4614
Epoch [1/10], Loss: 3.4896
Epoch [1/10], Loss: 3.6256
Epoch [1/10], Loss: 3.4224
Epoch [1/10], Loss: 3.3265
Epoch [1/10], Loss: 3.5610
Epoch [1/10], Loss: 3.4043
Epoch [1/10], Loss: 3.3171
Epoch [1/10], Loss: 3.4513
Epoch [1/10], Loss: 3.5223
Epoch [1/10], Loss: 3.2885
Epoch [1/10], Loss: 3.5071
Epoch [1/10], Loss: 3.2822
Epoch [1/10], Loss: 3.3951
Epoch [1/10], Loss: 3.6531
Epoch [1/10], Loss: 3.3944
Epoch [1/10], Loss: 3.4382
Epoch [1/10], Loss: 3.2250
Epoch [1/10], Loss: 3.3126
Epoch [1/10], Loss: 3.3880
Epoch [1/10], Loss: 3.3937
Epoch [1/10], Loss: 3.5647
Epoch [1/10], Loss: 3.2431
Epoch [1/10], Loss: 3.4455
Epoch [1/10], Loss: 3.5067
Epoch [1/10], Loss: 3.3445
Epoch [1/10], Loss: 3.5228
Epoch [1/10], Loss: 3.2232
Epoch [1/10], Loss: 3.4072
Epoch [1/10], Loss: 3.3198
Epoch [1/10], Loss: 3.4364
E

KeyboardInterrupt: 