<a href="https://colab.research.google.com/github/paro2708/SER517_Group35_Capstone/blob/GazeRefineNet/Trial_EyeNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
import torchvision.models as models
import torch
from torch import nn
import numpy as np
import math
import torch.nn.functional as F

In [46]:
class EyeNet(nn.Module):
  def __init__(self, side='left', use_rnn=True):
    super(EyeNet, self).__init__()
    self.resnet = models.resnet18(
        #block = models.resnet.BasicBlock,
        #layers = [2,2,2,2],
        pretrained=False,        # Not using pre-trained weights
        num_classes=128,        # (IMP)Number of output classes - needs to be defined based on number of eye features needed
        norm_layer=nn.InstanceNorm2d,  # Normalization layer - layers need to be added
    )
    self.use_rnn = use_rnn

    # Optional recurrent component - GRUCell
    if use_rnn:
            self.rnn = nn.GRUCell(input_size=128, hidden_size=128)
            #self.fc_gaze = nn.Linear(128, 3)  # Output size for gaze direction

    self.fc_gaze = nn.Sequential(
                    nn.Linear(128, 128),
                    nn.SELU(inplace=True),
                    nn.Linear(128, 2, bias=False),
                    nn.Tanh(),
                )  # Output size for gaze direction
    self.fc_pupil = nn.Sequential(
                      nn.Linear(128, 128),
                      nn.SELU(inplace=True),
                      nn.Linear(128, 1),
                      nn.ReLU(inplace=True),
                  )  # Output size for pupil size

  def forward(self, input_eye_image, rnn_output=None):

    features = self.resnet(input_eye_image)

    if self.use_rnn:
      rnn_features=features
      batch_size, feature_size = features.shape
      hidden = torch.zeros(batch_size, 128, device=rnn_features.device)
      previous_results = []
      output=[]

      for i in range(batch_size):  # Loop through layers

        if rnn_output is not None:
            previous_results = output[i-1] if i>0 else None

        GRUResult= self.rnn(rnn_features,hidden)

        if isinstance(GRUResult, tuple):
          rnn_features=GRUResult[0]
          output[i] = GRUResult
        else:
          rnn_features = GRUResult
          output.append(GRUResult)

      features=rnn_features

    #to calculate point of gaze
    gaze_direction = (0.5 * np.pi) * self.fc_gaze(features)
    gaze_direction_vector= convert_angles_to_vector(gaze_direction)
    #x1,y1,x2,y2 need to be taken from meta data files for each eye - side to be given as input to eyenet
    print(gaze_direction_vector[0][2])
    origin = calculate_gaze_origin_direction(torch.tensor([0. ,0. ,gaze_direction_vector[0][2]]), z1=0, z2=0)
    point_of_gaze_mm = calculate_intersection_with_screen(origin,gaze_direction_vector)
    #Hard coding device dimenions from meta data
    #screen_pixels from meta
    screen_size_mm = [123.8 , 53.7]
    screen_size_pixels = [568 , 320]
    point_of_gaze_px = mm_to_pixels(point_of_gaze_mm,screen_size_mm, screen_size_pixels) # need to get from screen.json
    pupil_size =self.fc_pupil(features)
    print("Gaze Direction shape before linear layer:", gaze_direction.shape)
    print("Pupil Size shape before linear layer:", pupil_size.shape)
    return gaze_direction, pupil_size , point_of_gaze_px

#Converting pitch and yaw to a vector - to convert gaze direction to a vector

def convert_angles_to_vector(angles):
    # Check if the input angles are 2-dimensional (pitch and yaw)
    if angles.shape[1] == 2:
        sine_values = torch.sin(angles)
        cosine_values = torch.cos(angles)
        # Construct and return the direction vector
        return torch.stack([cosine_values[:, 0] * sine_values[:, 1], sine_values[:, 0], cosine_values[:, 0] * cosine_values[:, 1]], dim=1)
    # Normalize the vector if the input is 3-dimensional
    elif angles.shape[1] == 3:
        return F.normalize(angles, dim=1)
    else:
        # Raise an error for unsupported input dimensions
        raise ValueError(f'Unexpected input dimensions: {angles.shape}')

def apply_transformation(T, vec):
    if vec.shape[1] == 2:
        vec = convert_angles_to_vector(vec)
    vec = vec.reshape(-1, 3, 1)
    h_vec = F.pad(vec, pad=(0, 0, 0, 1), value=1.0)
    if T.size(-2) != 4 or T.size(-1) != 4:
        raise ValueError("Transformation matrix T must be of shape [4, 4]")
    return torch.matmul(T, h_vec)[:, :3, 0]


def apply_rotation(T, vec):
    if vec.shape[1] == 2:
        vec = convert_angles_to_vector(vec)
    vec = vec.reshape(-1, 3, 1)
    if T.dim() == 2:
        T = T.unsqueeze(0)  # Add a batch dimension if it's missing
    elif T.dim() != 3:
        raise ValueError("T must be a 2D or 3D tensor")
    R = T[:, :3, :3]
    return torch.matmul(R, vec).reshape(-1, 3)

# To calculate point of gaze, gaze origin assumed to be 0,0,0
def calculate_intersection_with_screen(o, direction):

    # Ensure o and direction are 2D tensors [N, 3]
    if o.dim() == 1:
        o = o.unsqueeze(0)  # Add batch dimension if necessary
    if direction.dim() == 1:
        direction = direction.unsqueeze(0)  # Add batch dimension if necessary

    rotation = torch.tensor([
    [0.99970895052,-0.017290327698, 0.0168244000524],
    [-0.0110340490937,0.292467236519, 0.9562118053443],
    [-0.0214538034052,-0.956119179726,0.292191326618]
    ], dtype=torch.float32)

    # Assuming no translation, and the camera is at the origin of the world space
    camera_transformation_matrix = torch.eye(4)
    camera_transformation_matrix[:3, :3] = rotation
    inverse_camera_transformation_matrix = torch.inverse(camera_transformation_matrix)

    # De-rotate gaze vector
    inv_rotation = torch.inverse(rotation)
    direction = direction.reshape(-1, 3, 1)
    direction = torch.matmul(inv_rotation, direction)

    direction = apply_rotation(inverse_camera_transformation_matrix, direction)
    o = apply_transformation(inverse_camera_transformation_matrix, o)

    # Assuming o = (0, 0, 0) for simplicity
    # Solve for t when z = 0
    t = -o[:, 2] / direction[:, 2]

    # Calculate intersection point in millimeters
    p_x = o[:, 0] + t * direction[:, 0]
    p_y = o[:, 1] + t * direction[:, 1]

    return torch.stack([p_x, p_y], dim=-1)

def mm_to_pixels(intersection_mm, screen_size_mm, screen_size_pixels):
    # Unpack screen dimensions
    screen_height_mm, screen_width_mm = screen_size_mm
    screen_height_px, screen_width_px = screen_size_pixels

    # Calculate pixels per millimeter
    ppmm_x = screen_width_px #/ screen_width_mm
    ppmm_y = screen_height_px #/ screen_height_mm

    # Convert intersection point from mm to pixels
    intersection_px = intersection_mm * torch.tensor([ppmm_x, ppmm_y])
    return intersection_px

def calculate_gaze_origin_direction(z_gd, z1=0, z2=0):
    # Convert points to tensors
    x1 = 293
    x2 = 346
    y1 = 406
    y2 = 405
    point1 = torch.tensor([x1, y1, z1], dtype=torch.float32)
    point2 = torch.tensor([x2, y2, z2], dtype=torch.float32)

    # Calculate the vector pointing from point1 to point2
    direction_vector = point2 - point1

    # Normalize the vector to get a unit vector
    unit_vector = direction_vector / torch.norm(direction_vector)

    unit_vector= unit_vector + z_gd

    return unit_vector

In [47]:
from torchvision import transforms
from PIL import Image
import torch


eyenet= EyeNet()

# Load and preprocess the image
image_path = '/content/sample_data/0.jpg'
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open(image_path)
img_tensor = preprocess(img)  # Add batch dimension

# # Define a transform to convert the torch tensor to PIL image
# transform = transforms.ToPILImage()

# # Apply the transform to the torch tensor
# image = transform(img_tensor)

# # Save the PIL image to a file
# image.save("/content/sample_data/image.png")


# print(img_tensor.unsqueeze(0).shape)

# #%debug
# # Assuming the EyeNet model is already defined and initialized as eyenet
gaze_direction, pupil_size, point_of_gaze_px = eyenet(img_tensor.unsqueeze(0))

print("Predicted Gaze Direction:", gaze_direction)
print("Predicted Pupil Size:", pupil_size)
print("Predicted Point of Gaze:", point_of_gaze_px)
magnitude = torch.norm(gaze_direction, p=2)
magnitude = magnitude *(180/np.pi)
print("Normalized Gaze Direction Magnitude(in radians):", magnitude.item())

tensor(0.9792, grad_fn=<SelectBackward0>)
Gaze Direction shape before linear layer: torch.Size([1, 2])
Pupil Size shape before linear layer: torch.Size([1, 1])
Predicted Gaze Direction: tensor([[-0.1964, -0.0566]], grad_fn=<MulBackward0>)
Predicted Pupil Size: tensor([[0.0453]], grad_fn=<ReluBackward0>)
Predicted Point of Gaze: tensor([[ 306.7181, -612.0364]], grad_fn=<MulBackward0>)
Normalized Gaze Direction Magnitude(in radians): 11.711485862731934
