In [11]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm


In [12]:
def load_fer2013_dataset(csv_file, transform=None, image_size=(48, 48)):
    """Load a FER2013-style CSV into an in-memory list of ``(image, emotion)`` pairs.

    Every row must provide an ``emotion`` column and a ``pixels`` column holding
    space-separated grayscale intensities.

    Args:
        csv_file: Path or file-like object handed to ``pandas.read_csv``.
        transform: Optional callable applied to each PIL image. It runs exactly
            once per row, here at load time, so any random operation inside it
            is frozen into the returned list instead of being re-sampled on
            every epoch.
        image_size: ``(height, width)`` that each pixel string is reshaped to.

    Returns:
        A list of ``(image, emotion)`` tuples in CSV row order, where ``image``
        is the PIL image or whatever ``transform`` returned for it.

    Raises:
        ValueError: If a row's pixel count does not equal ``height * width``.
    """
    height, width = image_size
    expected_pixels = height * width
    dataframe = pd.read_csv(csv_file)
    dataset = []

    # Walk the two needed columns directly; DataFrame.iterrows() would build a
    # full Series object for every row just to read these two fields.
    columns = zip(dataframe['emotion'], dataframe['pixels'])
    for row_number, (emotion, pixels) in enumerate(columns):
        values = np.fromstring(pixels, dtype=int, sep=' ')
        if values.size != expected_pixels:
            # Fail with the offending row instead of an anonymous reshape error.
            raise ValueError(
                f"row {row_number}: expected {expected_pixels} pixel values "
                f"for image_size={image_size}, got {values.size}"
            )
        image = Image.fromarray(values.reshape(height, width).astype('uint8'))
        if transform:
            image = transform(image)
        dataset.append((image, emotion))

    return dataset

In [13]:
def mobilenet_v2_block(in_channels, out_channels, expansion_factor, stride):
    """Assemble one expand -> depthwise -> project stack as an ``nn.Sequential``.

    The stack is, in order:
      1. a 1x1 convolution widening ``in_channels`` to
         ``in_channels * expansion_factor`` channels,
      2. a 3x3 depthwise convolution (one group per channel) that carries
         ``stride``,
      3. a 1x1 convolution narrowing to ``out_channels``.

    Every convolution is bias-free and followed by ``BatchNorm2d``; the first
    two are additionally followed by ``ReLU6``. No skip connection is added
    around the stack.
    """
    hidden_channels = in_channels * expansion_factor

    layers = [
        # Pointwise expansion.
        nn.Conv2d(in_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(hidden_channels),
        nn.ReLU6(inplace=True),
        # Depthwise 3x3; this is the only layer that changes spatial resolution.
        nn.Conv2d(
            hidden_channels,
            hidden_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=hidden_channels,
            bias=False,
        ),
        nn.BatchNorm2d(hidden_channels),
        nn.ReLU6(inplace=True),
        # Pointwise projection, with no activation after its batch norm.
        nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
    ]
    return nn.Sequential(*layers)

def create_mobilenet_v2(num_classes=7):
    """Build the single-channel MobileNetV2-style classifier as a flat ``nn.Sequential``.

    Layout: a stride-2 3x3 stem convolution on 1-channel input, seventeen
    ``mobilenet_v2_block`` stacks, a 1x1 convolution to 1280 channels, global
    average pooling, dropout (p=0.2) and a linear layer producing
    ``num_classes`` outputs.
    """
    # (in_channels, out_channels, expansion_factor, stride) for each block, in order.
    block_specs = (
        (32, 16, 1, 1),
        (16, 24, 6, 2),
        (24, 24, 6, 1),
        (24, 32, 6, 2),
        (32, 32, 6, 1),
        (32, 32, 6, 1),
        (32, 64, 6, 2),
        (64, 64, 6, 1),
        (64, 64, 6, 1),
        (64, 64, 6, 1),
        (64, 96, 6, 1),
        (96, 96, 6, 1),
        (96, 96, 6, 1),
        (96, 160, 6, 2),
        (160, 160, 6, 1),
        (160, 160, 6, 1),
        (160, 320, 6, 1),
    )

    # Stem, body and head are instantiated in the same order they appear in the
    # final Sequential, so module indices match a hand-written layer list.
    stem = [
        nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU6(inplace=True),
    ]
    body = [mobilenet_v2_block(*spec) for spec in block_specs]
    head = [
        nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(1280),
        nn.ReLU6(inplace=True),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Dropout(0.2),
        nn.Linear(1280, num_classes),
    ]
    return nn.Sequential(*stem, *body, *head)

In [23]:
csv_file = 'fer2013/fer2013/fer2013.csv'

# Deterministic preprocessing only. load_fer2013_dataset applies this transform
# a single time per image while it builds the in-memory list, so a random op
# such as RandomHorizontalFlip would not act as per-epoch augmentation: each
# image would simply be mirrored (or not) once, permanently. The same
# `transform` also feeds the held-out split and predict_emotion, where random
# mirroring makes evaluation and inference non-deterministic.
transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # maps ToTensor's [0, 1] range to [-1, 1]
])

dataset = load_fer2013_dataset(csv_file, transform)

In [24]:
# Seed for the train/test partition. Without it, every re-run of this cell
# draws a different split, so results are not comparable between runs.
SPLIT_SEED = 42

# 80% of the samples go to training; the remainder is held out.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [25]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = create_mobilenet_v2()
# Kept as the cell's final expression so the notebook still renders the module summary.
model.to(device)

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU6(inplace=True)
  (3): Sequential(
    (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU6(inplace=True)
    (6): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (4): Sequential(
    (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=

In [26]:
# Classification objective applied to the model's raw outputs.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the model, learning rate 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [27]:
# Number of full passes over train_loader made by the training loop below.
num_epochs = 50

In [28]:
for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses; max(..., 1) keeps an empty loader from
    # raising ZeroDivisionError.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    # ---- Evaluation pass on the held-out loader, gradients disabled ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # argmax over the class dimension; avoids the legacy `.data`
            # attribute and does not overwrite IPython's `_` variable.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {accuracy * 100:.2f}%")

  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 1/50, Loss: 1.8365


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 22.01%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 2/50, Loss: 1.8208


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 24.64%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 3/50, Loss: 1.8137


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 24.37%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 4/50, Loss: 1.8054


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 25.20%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 5/50, Loss: 1.7925


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 25.62%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 6/50, Loss: 1.7797


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 26.43%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 7/50, Loss: 1.7663


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 27.08%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 8/50, Loss: 1.7512


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 28.99%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 9/50, Loss: 1.7241


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 29.90%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 10/50, Loss: 1.6938


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 30.84%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 11/50, Loss: 1.6584


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 33.27%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 12/50, Loss: 1.6148


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 34.27%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 13/50, Loss: 1.5783


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 35.85%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 14/50, Loss: 1.5363


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 36.19%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 15/50, Loss: 1.5003


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 37.20%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 16/50, Loss: 1.4641


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 37.96%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 17/50, Loss: 1.4298


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.03%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 18/50, Loss: 1.3904


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.53%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 19/50, Loss: 1.3609


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 37.55%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 20/50, Loss: 1.3257


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.95%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 21/50, Loss: 1.3058


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.62%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 22/50, Loss: 1.2627


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.66%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 23/50, Loss: 1.2347


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.60%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 24/50, Loss: 1.1957


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.90%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 25/50, Loss: 1.1709


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.73%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 26/50, Loss: 1.1397


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.15%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 27/50, Loss: 1.1276


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.34%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 28/50, Loss: 1.0832


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.94%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 29/50, Loss: 1.0706


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.67%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 30/50, Loss: 1.0342


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.81%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 31/50, Loss: 1.0072


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.11%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 32/50, Loss: 0.9916


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.15%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 33/50, Loss: 0.9629


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.58%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 34/50, Loss: 0.9357


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.67%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 35/50, Loss: 0.9229


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.73%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 36/50, Loss: 0.8882


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.59%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 37/50, Loss: 0.8762


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.36%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 38/50, Loss: 0.8538


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.57%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 39/50, Loss: 0.8211


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.72%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 40/50, Loss: 0.8085


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.81%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 41/50, Loss: 0.8048


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 40.09%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 42/50, Loss: 0.7756


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.16%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 43/50, Loss: 0.7550


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.91%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 44/50, Loss: 0.7453


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.76%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 45/50, Loss: 0.7256


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.19%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 46/50, Loss: 0.7063


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 40.14%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 47/50, Loss: 0.6923


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 40.29%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 48/50, Loss: 0.6814


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 39.76%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 49/50, Loss: 0.6771


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 38.79%


  0%|          | 0/898 [00:00<?, ?it/s]

Epoch 50/50, Loss: 0.6501


  0%|          | 0/225 [00:00<?, ?it/s]

Accuracy: 40.29%


In [None]:
from PIL import ImageOps

def predict_emotion(image_path, model, transform):
    """Classify the image at ``image_path`` and return the predicted class index.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: ``torch.nn.Module`` returning per-class scores for a batch.
        transform: Callable turning a PIL image into a tensor; a leading batch
            dimension of size 1 is added to its result before the forward pass.

    Returns:
        The index of the highest-scoring class, as a Python ``int``.

    The model is switched to eval mode for the forward pass and restored to its
    previous mode afterwards. The input is moved to the device of the model's
    first parameter, so the function does not depend on a global ``device``.

    NOTE(review): the image is histogram-equalized here, while the loading code
    visible in this notebook does not equalize training images. Confirm that
    this train/inference difference is intended.
    """
    # Close the file handle deterministically once the pixels are converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('L')  # convert the image to grayscale
    image = ImageOps.equalize(image)  # histogram equalization
    input_tensor = transform(image).unsqueeze(0)

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    # Eval mode makes BatchNorm use running statistics and disables Dropout;
    # no_grad skips building the autograd graph for a pure forward pass.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(input_tensor)
    finally:
        model.train(was_training)

    return output.argmax(dim=1).item()

# Display name for each predicted class index.
# NOTE(review): assumed to follow the CSV's `emotion` encoding
# (0=Angry ... 6=Neutral); confirm against the dataset documentation.
emotion_labels = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# Predict the emotion for a new image path.
image_path = 'path/to/your/image.jpg'
# Use the `model` trained above: no cell in this notebook defines `loaded_model`,
# so referencing it raises NameError on a fresh run.
predicted_emotion = predict_emotion(image_path, model, transform)
print(f"Predicted emotion: {emotion_labels[predicted_emotion]}")