# In [1]:
# Inference configuration.
# NOTE(review): `batch` is not referenced by any code visible in this file;
# presumably an intended batch size -- confirm before relying on it.
batch=32
# Root folder of one recording. The dataset class reads the sub-folders
# image_original, predict_head, predict_heatmap and image_skeleton beneath it.
data_path = r"C:\Datasets\Engagement\VIDEO000\00000"  # Provide your data path here
# Despite the name, this is the path of a checkpoint *file*; it is passed
# directly to torch.load.
ckpt_dir = r"C:\Datasets\Engagement\checkpoints\epoch_6.ckpt"

# In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import random
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Shared preprocessing pipeline: resize to 224x224, then convert to a tensor.
transform2 = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
class GazeDataset(Dataset):
    """Dataset yielding four aligned image modalities for every frame.

    The frame list is taken from the ``.png`` files found in
    ``<data_path>/image_original``. For each frame the matching files are
    opened from the sibling folders ``predict_head``, ``predict_heatmap`` and
    ``image_skeleton`` (the latter with a ``_rendered.png`` suffix).
    """

    def __init__(self,  data_path, transform=None):
        """
        Args:
            data_path (string): Directory containing the per-modality image folders.
            transform (callable, optional): Optional transform applied to the
                original image. When it is set, the other three modalities are
                passed through the module-level ``transform2``.
        """
        dir_path = os.path.join(data_path, "image_original")
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order (and therefore any result table built from it) is reproducible.
        png_files = sorted(f for f in os.listdir(dir_path) if f.endswith('.png'))
        self.data_frame = pd.DataFrame(png_files, columns=['imgID'])
        print(dir_path)
        self.data_path = data_path
        self.transform = transform
        # Collapses the four raw labels into two classes (3 and 1 -> 0, 2 and 4 -> 1).
        # Not used by this class itself; kept for external callers.
        self.gaze_to_int = {3: 0, 2: 1, 1: 0, 4: 1}
        self.num_classes = 2

    def __len__(self):
        """Return the number of frames found in ``image_original``."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load all modalities of frame ``idx``.

        Returns:
            tuple: ``(image, head, gaze_img, skeleton_img, img_id)`` where
            ``img_id`` is the frame's file name. The four images are
            transformed only when ``self.transform`` is set; otherwise the
            PIL images are returned as opened.
        """
        # Positional lookup: raises IndexError past the end, as the sequence
        # protocol expects, instead of the KeyError of label-based access.
        img_id = self.data_frame['imgID'].iloc[idx]

        # Join path components instead of concatenating "\folder" literals,
        # which are invalid escape sequences and only work on Windows.
        image = Image.open(os.path.join(self.data_path, "image_original", img_id))
        head = Image.open(os.path.join(self.data_path, "predict_head", img_id))
        gaze_img = Image.open(os.path.join(self.data_path, "predict_heatmap", img_id))
        skeleton_name = os.path.join(
            self.data_path, "image_skeleton", img_id.replace(".png", "_rendered.png")
        )
        skeleton_img = Image.open(skeleton_name)

        if self.transform:
            image = self.transform(image)
            head = transform2(head)
            gaze_img = transform2(gaze_img)
            skeleton_img = transform2(skeleton_img)

        return image, head, gaze_img, skeleton_img, img_id

# Create datasets
import pandas as pd
from sklearn.model_selection import train_test_split
  

# Build the ResNet-18 variant used for 2-class prediction on 10-channel input.
# No pretrained weights are requested: load_state_dict below is strict, so
# every parameter and buffer is replaced by the checkpoint anyway and the
# download would be wasted.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.conv1 = nn.Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(ckpt_dir, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
# Switch to inference mode so BatchNorm uses its stored running statistics
# instead of per-batch statistics (and stops updating them).
model.eval()

# Run the classifier over every frame under data_path and collect one
# prediction per image into `df` (columns: imgID, prediction).
# NOTE(review): the original passed an undefined name `transform` here, which
# raises NameError. `transform2` is the only preprocessing pipeline defined in
# this file; confirm it matches the preprocessing used during training.
predicted_dataset = GazeDataset(data_path=data_path, transform=transform2)

# Idempotent: guarantees BatchNorm layers run in inference mode for
# single-sample batches.
model.eval()

data = []
# Index explicitly rather than iterating the dataset object: it only
# implements __len__ and __getitem__.
for idx in tqdm(range(len(predicted_dataset))):
    image, head, gaze, skeleton_img, basename = predicted_dataset[idx]
    # Stack the modalities along the channel axis and add a batch dimension;
    # the model's first convolution expects 10 input channels in total.
    input_tensor = torch.cat([image, head, gaze, skeleton_img], dim=0).unsqueeze(0)
    input_tensor = input_tensor.to(device)

    # Get prediction
    with torch.no_grad():
        output = model(input_tensor)
        predicted = output.argmax(dim=1)
    data.append({'imgID': basename, 'prediction': predicted.item()})

# Passing the columns explicitly keeps the schema even when no images were found.
df = pd.DataFrame(data, columns=['imgID', 'prediction'])

    # Plotting
    #print( f"Prediction: {predicted.item()}")
    