In [1]:
import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
from torch.utils.data import ConcatDataset

### Initialization & Preprocessing

In [2]:
IMG_HEIGHT = 48
IMG_WIDTH = 48

# Path to the training data
TRAIN_DATA_PATH = os.path.join(os.getcwd(), 'data', 'train')

# Path to the test data
TEST_DATA_PATH = os.path.join(os.getcwd(), 'data', 'test')

# Disgust Samples Multiplier
MULTIPLIER = 3

# Define your transformations
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor()
])

# Load the datasets
train_dataset = ImageFolder(TRAIN_DATA_PATH, transform=transform)
test_dataset = ImageFolder(TEST_DATA_PATH, transform=transform)

# Create the dataloader for validation set only, train data still needs to be augmented
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Oversampling the disgust samples since we don't have many samples

# Define additional transformations for data augmentation
augment_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Create a new dataset with only the "disgust" images
disgust_dataset = [img for img in train_dataset if img[1]
                   == train_dataset.class_to_idx['disgust']]

# Convert Tensor to PIL Image
to_pil = ToPILImage()

# Apply data augmentation to the "disgust" images
augmented_disgust_dataset = [(augment_transform(to_pil(img[0])), img[1])
                             for _ in range(MULTIPLIER) for img in disgust_dataset]

# Combine the original dataset with the augmented "disgust" images
train_dataset = ConcatDataset([train_dataset, augmented_disgust_dataset])

# Update the train DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Count the number of "disgust" samples in the train_dataset
num_disgust_samples = len(disgust_dataset)

print(f"Number of 'disgust' samples: {num_disgust_samples * MULTIPLIER}")

Number of 'disgust' samples: 1308


### Train and Evaluate

In [4]:
import torch
from torch import optim, nn
from cnnNew import FERModel
from preprocessing import train_loader, test_loader
from tqdm import tqdm
from sklearn.metrics import classification_report
import numpy as np

# Create an instance of the model
num_classes = 7  # replace with the number of classes in your dataset
model = FERModel(num_classes)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs to train for
num_epochs = 20

# Loop over the dataset multiple times
for epoch in range(num_epochs):
    running_loss = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=int(len(train_loader)))
    for i, data in progress_bar:
        # Get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        progress_bar.set_description(
            f"Epoch {epoch + 1} loss: {running_loss/(i+1)}")

# Save the trained model
torch.save(model.state_dict(), 'bestmodel.pth')

# Evaluate the model on the test set and print a classification report
model.eval()
all_labels = []
all_predictions = []
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        all_labels.extend(labels.numpy())
        all_predictions.extend(predicted.numpy())

print(classification_report(all_labels, all_predictions))

print('Finished Training')

  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 1 loss: 1.943797221361113: 100%|██████████| 242/242 [00:13<00:00, 17.48it/s] 
Epoch 2 loss: 1.94059882873346: 100%|██████████| 242/242 [00:13<00:00, 18.50it/s]  
Epoch 3 loss: 1.9365219729991: 100%|██████████| 242/242 [00:13<00:00, 17.85it/s]   
Epoch 4 loss: 1.9318340396092943: 100%|██████████| 242/242 [00:13<00:00, 18.36it/s]
Epoch 5 loss: 1.9300252031688847: 100%|██████████| 242/242 [00:13<00:00, 18.15it/s]
Epoch 6 loss: 1.925383195404179: 100%|██████████| 242/242 [00:13<00:00, 18.47it/s] 
Epoch 7 loss: 1.9256146998444865: 100%|██████████| 242/242 [00:13<00:00, 18.03it/s]
Epoch 8 loss: 1.922529495452061: 100%|██████████| 242/242 [00:13<00:00, 18.02it/s] 
Epoch 9 loss: 1.9242701239822324: 100%|██████████| 242/242 [00:13<00:00, 17.98it/s]
Epoch 10 loss: 1.92448389677962: 100%|██████████| 242/242 [00:13<00:00, 17.45it/s]  
Epoch 11 loss: 1.9240965656012543: 100%|██████████| 242/242 [00:13<00:00, 18.40it/s]
Epoch 12 loss: 1.9247833894304007: 100%|██████████| 242/242 [00:12<00:00, 

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       958
           1       0.02      1.00      0.03       111
           2       0.00      0.00      0.00      1024
           3       0.00      0.00      0.00      1774
           4       0.00      0.00      0.00      1233
           5       0.00      0.00      0.00      1247
           6       0.00      0.00      0.00       831

    accuracy                           0.02      7178
   macro avg       0.00      0.14      0.00      7178
weighted avg       0.00      0.02      0.00      7178

Finished Training


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
