# In [None]:
from __future__ import print_function
import numpy as np 
import pandas as pd 
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
from torch import nn, optim
from torchvision import transforms, models, datasets
import pickle
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from collections import defaultdict

warnings.filterwarnings('ignore')

# Define a function to read images
def read_image(image_path):
    """Load an image from disk as a 224x224 RGB numpy array.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``numpy.ndarray`` of shape ``(224, 224, 3)`` holding the pixels.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle once the pixels have been copied out.
    with Image.open(image_path) as source:
        # Force three channels so grayscale / RGBA / CMYK files do not break
        # the 3-channel normalisation and ResNet input used downstream.
        resized = source.convert("RGB").resize((224, 224))  # ResNet's input size
    return np.array(resized)

# Define a custom dataset class
class CustomDataset(Dataset):
    """Dataset of ``(image, phoneme embedding, emotion label)`` samples.

    Sample metadata is read from a pickle file. The loaded object is indexed
    with ``.iloc`` and must provide the columns ``'filename'``,
    ``'phoneme_embeddings'`` and ``'emotion'`` (presumably a pandas
    DataFrame -- confirm against the code that writes the pickle).

    Args:
        image_folder: Directory holding the ``<filename>.jpg`` images.
        pickle_file: Path of the pickle file with the sample metadata.
        transform: Optional callable applied to the PIL image of each sample.
    """

    def __init__(self, image_folder, pickle_file, transform=None):
        self.image_folder = image_folder
        self.pickle_file = pickle_file
        self.transform = transform
        self.data = self.load_pickle_data()

    def __len__(self):
        """Return the number of samples in the metadata table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, phoneme_embedding, emotion)`` for sample ``idx``.

        ``image`` is the numpy array from ``read_image`` when no transform is
        set, otherwise whatever the transform returns. ``phoneme_embedding``
        is a float32 tensor and ``emotion`` a plain ``int``.
        """
        # Fetch the row once instead of rebuilding it for every column.
        row = self.data.iloc[idx]
        image = read_image(os.path.join(self.image_folder, row['filename'] + ".jpg"))
        # nn.LSTM parameters are float32 by default, so hand the embedding
        # over as a float32 tensor regardless of how it was pickled
        # (float64 arrays or nested lists would otherwise fail downstream).
        phoneme_embedding = torch.as_tensor(row['phoneme_embeddings'], dtype=torch.float32)
        emotion = int(row['emotion'])

        if self.transform:
            image = Image.fromarray(image)  # torchvision transforms expect a PIL Image
            image = self.transform(image)

        return image, phoneme_embedding, emotion

    def load_pickle_data(self):
        """Load and return the object stored in ``self.pickle_file``."""
        # NOTE(security): unpickling can execute arbitrary code -- only point
        # pickle_file at files from a trusted source.
        with open(self.pickle_file, 'rb') as f:
            return pickle.load(f)

# Define the MultimodalModel class
class MultimodalModel(nn.Module):
    """Classifier fusing ResNet-18 image features with BiLSTM phoneme features.

    Pipeline: image -> ResNet-18 (classification head removed); phoneme
    sequence -> BiLSTM final hidden states; concatenation -> dense -> second
    BiLSTM (over a length-1 sequence) -> dense -> dropout -> class logits.

    Args:
        num_classes: Number of output logits.
        phoneme_input_size: Feature size of each phoneme-embedding timestep.
        hidden_size: Hidden size (per direction) of both BiLSTMs.
    """

    def __init__(self, num_classes, phoneme_input_size, hidden_size):
        super(MultimodalModel, self).__init__()

        # Pre-trained ResNet used as a feature extractor: replacing ``fc``
        # with Identity makes it emit the ``num_ftrs`` pooled features.
        self.resnet = models.resnet18(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        # BiLSTM over the phoneme embeddings
        self.bilstm1 = nn.LSTM(input_size=phoneme_input_size, hidden_size=hidden_size, batch_first=True, bidirectional=True)
        lstm_output_size = hidden_size * 2  # forward + backward direction

        # NOTE: this attention module is constructed but never called in
        # forward(); it is kept so existing state_dicts keep loading.
        self.attention = nn.MultiheadAttention(embed_dim=num_ftrs + lstm_output_size, num_heads=1)

        # Fusion layer over the concatenated image + phoneme features
        self.dense1 = nn.Linear(num_ftrs + lstm_output_size, 256)
        self.dropout = nn.Dropout(0.5)

        # Second BiLSTM, applied to the fused vector as a length-1 sequence
        self.bilstm2 = nn.LSTM(input_size=256, hidden_size=hidden_size, batch_first=True, bidirectional=True)

        # Classification head
        self.dense2 = nn.Linear(hidden_size * 2, 128)
        self.dense3 = nn.Linear(128, num_classes)

        # Not used by forward(); kept as an attribute for compatibility.
        self.pooling = nn.AdaptiveAvgPool1d(1)

    def forward(self, image_input, phoneme_input):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            image_input: Image batch fed to the ResNet.
            phoneme_input: Batched phoneme sequence laid out as
                ``(batch, seq_len, phoneme_input_size)`` (``batch_first=True``).
        """
        image_features = self.resnet(image_input)

        # Use the final hidden states of both directions. Slicing the output
        # at the last timestep would pair the forward state with a backward
        # state that has only seen the last token. For this single-layer
        # BiLSTM, h_n[-2] is the forward and h_n[-1] the backward direction.
        _, (hidden_state, _) = self.bilstm1(phoneme_input)
        phoneme_features = torch.cat((hidden_state[-2], hidden_state[-1]), dim=1)

        # Fuse both modalities
        combined = torch.cat((image_features, phoneme_features), dim=1)
        fused = F.relu(self.dense1(combined))

        # Second BiLSTM sees the fused vector as a one-step sequence, so the
        # last timestep is the only one and holds both directions' states.
        lstm_output2, _ = self.bilstm2(fused.unsqueeze(1))
        lstm_output2 = lstm_output2[:, -1, :]

        hidden = F.relu(self.dense2(lstm_output2))
        hidden = self.dropout(hidden)

        return self.dense3(hidden)

# Pick the compute device: CUDA when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data locations -- left empty here; fill in before running.
image_folder = ''
pickle_file = ''  # pickle holding the phoneme embeddings

# DataLoader / model hyperparameters
batch_size = 8
num_classes = 4          # number of output classes
phoneme_input_size = 44  # feature size passed to the first BiLSTM
hidden_size = 32         # per-direction hidden size of the BiLSTMs

# Deterministic steps shared by every split: resize the short side to 256,
# take the central 224x224 crop, then convert to a normalised tensor.
_resize_and_crop = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

# Training pipeline: random flip and a random rotation of up to 10 degrees
# are inserted between the crop and the tensor conversion. (A
# RandomResizedCrop was considered and dropped in favour of the center crop.)
train_transforms = transforms.Compose(
    _resize_and_crop
    + [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
    + _to_normalized_tensor
)

# Validation / test pipeline: the deterministic steps only.
val_test_transforms = transforms.Compose(_resize_and_crop + _to_normalized_tensor)

# Two views of the same samples: one with training augmentation, one with the
# deterministic validation/test pipeline. (The pickle is read once per view.)
dataset = CustomDataset(image_folder, pickle_file, transform=train_transforms)
eval_dataset = CustomDataset(image_folder, pickle_file, transform=val_test_transforms)

# Perform train-validation-test split
train_size = 0.8
val_size = 0.1
test_size = 0.1
holdout_size = val_size + test_size

# Split sample *indices*, not the dataset itself. Handing the Dataset to
# train_test_split indexes every sample up front, which loads all images into
# memory, freezes the random augmentation to a single draw, and applies the
# training augmentation to the validation and test samples as well.
all_indices = list(range(len(dataset)))
train_indices, temp_indices = train_test_split(
    all_indices, train_size=train_size, test_size=holdout_size, shuffle=True
)
val_indices, test_indices = train_test_split(
    temp_indices,
    train_size=val_size / holdout_size,
    test_size=test_size / holdout_size,
    shuffle=True,
)

# Lazy subsets: samples are loaded (and, for training, re-augmented) on access.
train_dataset = torch.utils.data.Subset(dataset, train_indices)
temp_dataset = torch.utils.data.Subset(eval_dataset, temp_indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_indices)

# Define data loaders (only the training data is reshuffled each epoch)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

from torch.optim.lr_scheduler import StepLR

# Build the multimodal model and place it on the selected device
# (nn.Module.to returns the module itself).
multimodal_model = MultimodalModel(num_classes, phoneme_input_size, hidden_size).to(device)

# Loss, optimizer, and a step schedule that multiplies the learning rate
# by 0.1 every 20 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(multimodal_model.parameters(), lr=0.0001)
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

# Training loop: one optimisation pass over train_loader per epoch, followed
# by an evaluation pass over val_loader.
num_epochs = 100
for epoch in range(num_epochs):
    multimodal_model.train()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0  # sum of per-sample losses over the epoch
    for images, phoneme_embeddings, labels in train_loader:
        images = images.to(device)
        phoneme_embeddings = phoneme_embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Forward pass
        outputs = multimodal_model(images, phoneme_embeddings)
        # Shift labels down by one to get the 0-based class indices that
        # CrossEntropyLoss expects (presumably the stored labels start at 1
        # -- confirm against the data).
        labels = labels - 1
        loss = criterion(outputs, labels)
        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays correct when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

    # Update learning rate once per epoch
    scheduler.step()

    train_accuracy = total_correct / total_samples
    train_average_loss = total_loss / total_samples
    train_accuracy = train_accuracy*100  # reported as a percentage
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Accuracy: {train_accuracy:.4f}, Train Average Loss: {train_average_loss:.4f}")

    # Evaluate on validation set
    multimodal_model.eval()
    with torch.no_grad():
        total_correct = 0
        total_samples = 0
        val_total_loss = 0.0  # sum of per-sample losses on the validation set
        class_accuracies = defaultdict(int)  # correct predictions per true class
        class_samples = defaultdict(int)     # samples per true class

        for images, phoneme_embeddings, labels in val_loader:
            images = images.to(device)
            phoneme_embeddings = phoneme_embeddings.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = multimodal_model(images, phoneme_embeddings)
            # Calculate loss on 0-based labels
            labels = labels - 1
            loss = criterion(outputs, labels)
            val_total_loss += loss.item() * images.size(0)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

            # Per-class tallies; tolist() moves each tensor to the host once
            # instead of once per element.
            for pred, label in zip(predicted.tolist(), labels.tolist()):
                class_samples[label] += 1
                if pred == label:
                    class_accuracies[label] += 1

        val_accuracy = total_correct / total_samples
        val_average_loss = val_total_loss / total_samples
        print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Accuracy: {val_accuracy:.4f}")

        # Class-balanced accuracy: mean of the per-class accuracies
        # (correct / samples) over the classes present in the validation set,
        # so every class counts equally regardless of its size.
        weighted_accuracy = sum(
            class_accuracies[label] / count for label, count in class_samples.items()
        ) / len(class_samples)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Weighted Validation Accuracy: {weighted_accuracy:.4f}")

# Evaluate on test set: a single pass over test_loader without gradients.
multimodal_model.eval()
with torch.no_grad():
    total_correct = 0
    total_samples = 0
    test_total_loss = 0.0  # sum of per-sample losses on the test set

    class_accuracies = defaultdict(int)  # correct predictions per true class
    class_samples = defaultdict(int)     # samples per true class

    for images, phoneme_embeddings, labels in test_loader:
        images = images.to(device)
        phoneme_embeddings = phoneme_embeddings.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = multimodal_model(images, phoneme_embeddings)

        # Shift labels down by one to get the 0-based class indices that
        # CrossEntropyLoss expects (presumably the stored labels start at 1
        # -- confirm against the data).
        labels = labels - 1
        loss = criterion(outputs, labels)
        # criterion returns the batch mean; weight by batch size so the
        # average stays correct when the last batch is smaller.
        test_total_loss += loss.item() * images.size(0)

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

        # Per-class tallies; tolist() moves each tensor to the host once
        # instead of once per element.
        for pred, label in zip(predicted.tolist(), labels.tolist()):
            class_samples[label] += 1
            if pred == label:
                class_accuracies[label] += 1

    test_accuracy = total_correct / total_samples
    test_average_loss = test_total_loss / total_samples
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Class-balanced accuracy: mean of the per-class accuracies
    # (correct / samples) over the classes present in the test set, so every
    # class counts equally regardless of its size.
    weighted_accuracy = sum(
        class_accuracies[label] / count for label, count in class_samples.items()
    ) / len(class_samples)
    print(f"Weighted Test Accuracy: {weighted_accuracy:.4f}")
