<a href="https://colab.research.google.com/github/paro2708/SER517_Group35_Capstone/blob/GazeRefineNet/GazeRefineNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GazeRefineNet(nn.Module):
    """Small convolutional network that maps an initial heatmap to a refined one.

    The network is three stages: two 3x3 convolutions at input resolution,
    a single stride-2 3x3 convolution ("backbone"), and a 1x1 convolution
    followed by a sigmoid that emits a one-channel map with values in [0, 1].

    Args:
        config: Mapping with the keys
            ``'load_screen_content'`` -- when truthy the network expects a
            4-channel input built by concatenating ``'screen_frame'`` and
            ``'heatmap_initial'``; otherwise it expects a 1-channel input.
            ``'use_skip_connections'`` -- stored on ``self.do_skip``.
    """

    def __init__(self, config):
        super().__init__()
        # Keep the flag on the instance so that forward() always agrees with
        # the channel count chosen below, independent of any module-level
        # variable that may (or may not) exist where the class is used.
        self.load_screen_content = bool(config['load_screen_content'])
        # NOTE(review): do_skip is stored but not read anywhere in this class;
        # no skip connections are wired into forward().
        self.do_skip = config['use_skip_connections']

        # Input channels depend on whether screen content is concatenated.
        in_channels = 4 if self.load_screen_content else 1

        # Initial convolutions; padding=1 with 3x3 kernels keeps spatial size.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.InstanceNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Simplified backbone: one stride-2 convolution, which roughly halves
        # the spatial resolution.
        self.backbone = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # Additional layers would be added here...
        )

        # 1x1 convolution down to one channel; sigmoid bounds it to [0, 1].
        self.final_conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, input_dict):
        """Compute the refined heatmap.

        Args:
            input_dict: Mapping that contains ``'heatmap_initial'`` and, when
                the model was built with ``load_screen_content`` enabled,
                also ``'screen_frame'``. The tensors are concatenated along
                dim 1, so their channel counts must sum to 4.
                # assumes screen_frame has 3 channels and heatmap_initial
                # has 1 -- TODO confirm against the data loader.

        Returns:
            A one-channel tensor with values in [0, 1]. Because of the
            stride-2 backbone its spatial size is about half of the input's.
        """
        if self.load_screen_content:
            input_image = torch.cat(
                [input_dict['screen_frame'], input_dict['heatmap_initial']],
                dim=1,
            )
        else:
            input_image = input_dict['heatmap_initial']

        x = self.initial_conv(input_image)
        x = self.backbone(x)
        final_heatmap = self.final_conv(x)

        return final_heatmap

# Run-time options for the network; adjust these to match the experiment.
config = dict(
    load_screen_content=True,
    use_skip_connections=True,
)

# Build the network from the options above.
model = GazeRefineNet(config)
