<a href="https://colab.research.google.com/github/paro2708/SER517_Group35_Capstone/blob/GazeRefineNet/GazeRefineNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GazeRefineNet(nn.Module):
    """Prototype gaze estimator: a small CNN on the left-eye image, an LSTM, and a linear head.

    Pipeline implemented by ``forward``:
    conv1 -> ReLU -> 2x2 max-pool -> conv2 -> ReLU -> adaptive max-pool to
    28x28 -> flatten -> LSTM -> linear layer producing 2 values.

    Only ``x_eye_left`` is consumed; ``x_eye_right`` and ``x_screen`` are
    accepted for interface compatibility but are currently ignored.
    """

    # Spatial grid the conv features are pooled to before flattening. Together
    # with conv2's 64 output channels this is what the LSTM's
    # input_size (64 * 28 * 28) is built for.
    _FEATURE_HW = (28, 28)

    def __init__(self):
        super().__init__()
        # CNN layers for feature extraction. kernel_size=5 with padding=2 and
        # stride=1 keeps the spatial size; only the pooling layer shrinks it.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)

        # RNN layer for temporal processing; expects 64 channels on a 28x28 grid.
        self.rnn = nn.LSTM(input_size=64 * 28 * 28, hidden_size=128, num_layers=1, batch_first=True)

        # Output layer: 2 values (assumed to be a 2D gaze point).
        self.fc1 = nn.Linear(128, 2)

    def forward(self, x_eye_left, x_eye_right, x_screen):
        """Run the network on the left-eye image(s).

        Args:
            x_eye_left: tensor of shape (N, 3, H, W). The N entries are fed to
                the LSTM as the time steps of a single sequence.
            x_eye_right: unused.
            x_screen: unused.

        Returns:
            Tensor of shape (1, 2): the linear head applied to the LSTM output
            at the last time step.
        """
        x = F.relu(self.conv1(x_eye_left))
        x = self.pool(x)
        x = F.relu(self.conv2(x))

        # Bring the feature map to the fixed 28x28 grid the LSTM was sized for.
        # Without this, only 56x56 inputs produce 64*28*28 features and any
        # other resolution (e.g. 224x224) fails inside the LSTM with a shape
        # mismatch. For 56x56 inputs the map is already 28x28, so this is an
        # identity and previous results are unchanged.
        x = F.adaptive_max_pool2d(x, self._FEATURE_HW)
        x = x.flatten(start_dim=1)  # (N, 64 * 28 * 28)

        # unsqueeze(0) + batch_first=True => one sequence of length N.
        # NOTE(review): if dim 0 of the input is meant to be an independent
        # batch rather than consecutive frames, this should be unsqueeze(1) —
        # confirm the intended semantics before changing.
        x, _ = self.rnn(x.unsqueeze(0))

        # Pass the last time step's RNN output through the fully connected layer.
        x = self.fc1(x[:, -1, :])
        return x

# Demo: build the network and push one set of random inputs through it.
model = GazeRefineNet()

# All three example inputs share one layout: (batch_size, channels, height, width).
example_shape = (1, 3, 224, 224)

# Left-eye image: the only input the simplified forward pass actually reads.
x_eye_left = torch.randn(*example_shape)
# Right-eye image and screen content are passed to satisfy the signature only;
# they are not used in this simplified example.
x_eye_right = torch.randn(*example_shape)
x_screen = torch.randn(*example_shape)

# Run the forward pass and show the raw prediction.
output = model(x_eye_left, x_eye_right, x_screen)
print(output)