In [14]:
import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
from torch.utils.data import ConcatDataset

### Initialization & Preprocessing

In [15]:
# Spatial size (in pixels) every image is resized to
IMG_HEIGHT, IMG_WIDTH = 48, 48

# Dataset folders, resolved relative to the current working directory
TRAIN_DATA_PATH = os.path.join(os.getcwd(), 'data', 'train')
TEST_DATA_PATH = os.path.join(os.getcwd(), 'data', 'test')

# Number of augmented copies generated for every "disgust" image
MULTIPLIER = 3

# Base preprocessing shared by both splits:
# grayscale conversion -> fixed-size resize -> tensor conversion
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(size=(IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

# One ImageFolder per split, both using the base preprocessing
train_dataset = ImageFolder(root=TRAIN_DATA_PATH, transform=transform)
test_dataset = ImageFolder(root=TEST_DATA_PATH, transform=transform)

# Only the test loader is built at this point; the training loader is created
# later, once the training set has been extended with augmented samples
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Oversample the "disgust" class, which has only a few training samples.

# Augmentation pipeline: the base preprocessing plus a random horizontal flip
# and a random rotation (RandomRotation(10))
augment_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Locate the "disgust" samples through the label list kept by ImageFolder
# (`targets`) so that only those images are decoded. Iterating over the
# dataset itself would load and transform every training image just to
# inspect its label.
disgust_label = train_dataset.class_to_idx['disgust']
disgust_indices = [idx for idx, target in enumerate(train_dataset.targets)
                   if target == disgust_label]

# (tensor, label) pairs of the original "disgust" images
disgust_dataset = [train_dataset[idx] for idx in disgust_indices]

# The augmentation pipeline ends in ToTensor, so the already-converted tensors
# are turned back into PIL images before being augmented
to_pil = ToPILImage()

# Create MULTIPLIER independently augmented copies of every "disgust" image
augmented_disgust_dataset = [(augment_transform(to_pil(image)), label)
                             for _ in range(MULTIPLIER)
                             for image, label in disgust_dataset]

# Append the augmented copies to the original training set
train_dataset = ConcatDataset([train_dataset, augmented_disgust_dataset])

# The training loader can now be built on the extended dataset
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Number of original "disgust" images
num_disgust_samples = len(disgust_dataset)

# NOTE: the figure printed below counts the augmented copies only; the
# combined training set additionally contains the num_disgust_samples originals.
print(f"Number of 'disgust' samples: {num_disgust_samples * MULTIPLIER}")

Number of 'disgust' samples: 1308


### Train and Evaluate

In [25]:
import torch
from torch import optim, nn
from cnn import FERModel
from tqdm import tqdm
from sklearn.metrics import classification_report
import numpy as np

# Build the classifier: one output per emotion class (the class_to_idx mapping
# printed in this notebook lists 7 classes)
num_classes = 7
model = FERModel(num_classes)

# Cross-entropy loss optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of passes over the training set
num_epochs = 100

# Put the model in training mode explicitly; this matters if FERModel contains
# mode-dependent layers such as dropout or batch norm (not visible from here)
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (inputs, labels) in progress_bar:
        # Gradients accumulate by default, so clear them before each step
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Show the mean loss over the batches seen so far in this epoch
        running_loss += loss.item()
        progress_bar.set_description(
            f"Epoch {epoch + 1} loss: {running_loss/(i+1)}")

    # Checkpoint after every completed epoch. Saving only once after the whole
    # loop loses all progress when a long run is interrupted part-way through.
    torch.save(model.state_dict(), 'model_high.pth')

# Evaluate the model on the test set and print a classification report
model.eval()
all_labels = []
all_predictions = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        # The predicted class is the index of the highest score
        predicted = outputs.argmax(dim=1)
        all_labels.extend(labels.numpy())
        all_predictions.extend(predicted.numpy())

# Report per-class metrics under the class names instead of bare indices
class_names = test_loader.dataset.classes
print(classification_report(all_labels, all_predictions,
                            labels=list(range(len(class_names))),
                            target_names=class_names))

print('Finished Training')

  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 1 loss: 1.894217251253522: 100%|██████████| 242/242 [00:56<00:00,  4.25it/s] 
Epoch 2 loss: 1.8077049373595182: 100%|██████████| 242/242 [00:54<00:00,  4.47it/s]
Epoch 3 loss: 1.7799091215961236: 100%|██████████| 242/242 [00:54<00:00,  4.43it/s]
Epoch 4 loss: 1.7683011598823484: 100%|██████████| 242/242 [00:57<00:00,  4.24it/s]
Epoch 5 loss: 1.7465792864807381: 100%|██████████| 242/242 [00:59<00:00,  4.10it/s]
Epoch 6 loss: 1.7288422628867726: 100%|██████████| 242/242 [00:58<00:00,  4.14it/s]
Epoch 7 loss: 1.721179388278772: 100%|██████████| 242/242 [00:58<00:00,  4.13it/s] 
Epoch 8 loss: 1.7081321413851966: 100%|██████████| 242/242 [00:58<00:00,  4.14it/s]
Epoch 9 loss: 1.6951826234494358: 100%|██████████| 242/242 [01:03<00:00,  3.80it/s]
Epoch 10 loss: 1.6812878265853757: 100%|██████████| 242/242 [01:04<00:00,  3.77it/s]
Epoch 11 loss: 1.6753918188662569: 100%|██████████| 242/242 [01:01<00:00,  3.95it/s]
Epoch 12 loss: 1.6650862802158704: 100%|██████████| 242/242 [01:01<00:00, 

KeyboardInterrupt: 

In [17]:
# Show the class-name -> index mapping of every dataset inside the combined
# training set; members without such a mapping print an empty line instead
for part in train_dataset.datasets:
    print(getattr(part, 'class_to_idx', ""))


{'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprise': 6}

