#### Step 1: Load Model
Before we begin transfer learning, we first have to load the pretrained model. There are two ways to do this: (1) load `model.pth`, which stores the model's architecture and weights together, or (2) instantiate the model class itself, defined in `pyg_model.py`, and then load the weights into it. Either way works, but (2) is safer with files of unknown origin, because loading a fully pickled model can execute arbitrary code. With (2), after building the model we load the state dictionary `model_state_dict.pth`, which maps each parameter name (e.g. `conv1.lin.weight`) to its trained tensor. These names are also what let us reference specific layers of the model, which is crucial for examining, extracting, or modifying its architecture.
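A minimal sketch of approach (2) on a toy model. `TinyNet` is hypothetical and only stands in for `relative_positioning`: just the weights are serialized, then loaded into a freshly constructed instance of the same class. The in-memory buffer replaces a real `.pth` file, and `weights_only=True` (available in recent PyTorch versions) restricts loading to tensors and plain containers.

```python
import io
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for the pretrained architecture."""
    def __init__(self, in_dim=1, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

torch.manual_seed(0)
trained = TinyNet()

# Serialize only the weights (an in-memory buffer stands in for a .pth file)
buf = io.BytesIO()
torch.save(trained.state_dict(), buf)
buf.seek(0)

# Rebuild the architecture from code, then load the saved weights into it
restored = TinyNet()
state_dict = torch.load(buf, weights_only=True)
restored.load_state_dict(state_dict)

x = torch.randn(4, 1)
same_output = torch.equal(trained(x), restored(x))
print(same_output)  # True: identical weights give identical outputs
```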

In [21]:
import torch
import copy
from models import relative_positioning

# Model parameters
num_node_features = 1
num_edge_features = 1
hidden_channels = 64
out_channels = 32

# Build the architecture (approach 2); its weights are random until the state dict is loaded
pretrained_model = relative_positioning(num_node_features=num_node_features, num_edge_features=num_edge_features, 
                             hidden_channels=hidden_channels, out_channels=out_channels)

# Load the trained weights. map_location="cpu" lets a state dict saved on a GPU load on a CPU-only machine.
model_sd_path = r"C:\Users\xmoot\Desktop\Models\PyTorch\ssl-seizure-detection\relative_positioning\jh101\jh101_12s_7min_model1_state_dict.pth"
pretrained_model.load_state_dict(torch.load(model_sd_path, map_location="cpu"))

<All keys matched successfully>
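The `<All keys matched successfully>` line is the printed form of the value returned by `load_state_dict`, a named tuple with `missing_keys` and `unexpected_keys`. With the default `strict=True`, any mismatch raises an error; with `strict=False`, mismatches are only reported. A toy sketch, where the `source` and `target` models are hypothetical and share only the `encoder` layer:

```python
import torch.nn as nn

# Hypothetical pretrained model: shared encoder plus an old head
source = nn.Sequential()
source.add_module("encoder", nn.Linear(4, 8))
source.add_module("old_head", nn.Linear(8, 2))

# Hypothetical downstream model: same encoder name, different head
target = nn.Sequential()
target.add_module("encoder", nn.Linear(4, 8))
target.add_module("new_head", nn.Linear(8, 1))

result = target.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)     # new_head.* keys: present in target, absent from source
print(result.unexpected_keys)  # old_head.* keys: present in source, absent from target

# When every key matches, the returned value prints as <All keys matched successfully>
print(nn.Linear(4, 8).load_state_dict(nn.Linear(4, 8).state_dict()))
```

The `encoder` weights are copied even with `strict=False`; only the mismatched heads are skipped.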

#### Step 2: Extract Layers
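In this step we extract the layers we want to reuse in the downstream supervised model. Here we need the `NNConv` and `GATConv` layers. `NNConv` computes its weights from the edge features using a separate network, `EdgeMLP` (a small multilayer perceptron), so that network is effectively part of the `NNConv` layer's parameters. PyG stores it inside the layer as the submodule `nn` (it shows up as `conv1.nn.mlp.*` in the parameter listing further down), so copying `conv1` already brings the trained edge network along; we also copy `edge_mlp` separately so we can inspect it on its own. You could grab the layers with plain assignment, e.g. `EdgeMLP_module = pretrained_model.edge_mlp`, but that only creates another reference to the same module. When we later build a frozen and an unfrozen version, freezing one would then freeze the other. `copy.deepcopy` instead gives each version its own independent copy of the weights.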
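A minimal sketch of why plain assignment causes trouble, using a hypothetical one-layer `pretrained` model instead of the real one. Freezing through an alias freezes the original module too, while a deep copy made beforehand keeps its own trainable tensors.

```python
import copy
import torch
import torch.nn as nn

# Hypothetical pretrained model with a single layer to transfer
pretrained = nn.Sequential(nn.Linear(1, 4))

alias = pretrained[0]                        # plain assignment: the same module object
independent = copy.deepcopy(pretrained[0])   # new module holding copies of the tensors

# Freeze through the alias, as method (2) will do later
for p in alias.parameters():
    p.requires_grad = False

print(alias is pretrained[0], independent is pretrained[0])       # True False
print(all(not p.requires_grad for p in pretrained[0].parameters()))  # True: original frozen too
print(all(p.requires_grad for p in independent.parameters()))      # True: copy unaffected
print(torch.equal(independent.weight, pretrained[0].weight))       # True: same values copied
```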

In [22]:
EdgeMLP_module = copy.deepcopy(pretrained_model.edge_mlp)
NNConv_module = copy.deepcopy(pretrained_model.conv1)
GATConv_module = copy.deepcopy(pretrained_model.conv2)

We can list the parameter names and shapes in a layer's state dictionary:

In [23]:
for name, tensor in EdgeMLP_module.state_dict().items():
    print(name, "\t", tensor.size())


mlp.0.weight 	 torch.Size([128, 1])
mlp.0.bias 	 torch.Size([128])
mlp.2.weight 	 torch.Size([64, 128])
mlp.2.bias 	 torch.Size([64])
mlp.4.weight 	 torch.Size([64, 64])
mlp.4.bias 	 torch.Size([64])
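The indices 0, 2 and 4, with no parameters at 1 and 3, suggest an `nn.Sequential` of three `Linear` layers separated by parameter-free activations. The sketch below is a hypothetical reconstruction that reproduces the printed shapes; the ReLU activations are an assumption. The final width of 64 fits what `NNConv` needs: its edge network must output `in_channels * out_channels` values per edge, here 1 × 64 for `conv1` (1 node feature in, 64 channels out).

```python
import torch
import torch.nn as nn

in_channels, out_channels = 1, 64   # conv1: 1 node feature in, 64 channels out
num_edge_features = 1

# Hypothetical reconstruction of EdgeMLP.mlp; ReLU at indices 1 and 3 is assumed
mlp = nn.Sequential(
    nn.Linear(num_edge_features, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, in_channels * out_channels),
)

shapes = {name: tuple(t.shape) for name, t in mlp.state_dict().items()}
for name, shape in shapes.items():
    print(f"mlp.{name}\t{shape}")

# Each edge gets in_channels * out_channels = 64 numbers, i.e. one [1, 64] weight matrix
edge_attr = torch.randn(10, num_edge_features)
per_edge_weights = mlp(edge_attr).view(-1, in_channels, out_channels)
print(per_edge_weights.shape)  # torch.Size([10, 1, 64])
```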


As a quick check that the extracted EdgeMLP works, we pass some random edge features through it:

In [24]:
# Create some dummy data
dummy_edge_attr = torch.randn(10, num_edge_features)  # 10 edges, each with `num_edge_features` features

# Run the data through the `edge_mlp` layer
output = EdgeMLP_module(dummy_edge_attr)
print(output)

tensor([[-0.3229, -0.0647, -0.2705, -0.2843, -0.1541, -0.2524, -0.4764,  0.0350,
         -0.5136,  0.8326, -0.3843, -0.1110,  0.3624,  0.5415, -0.3629, -0.2937,
          0.2259,  0.3114, -0.4199, -0.3076, -0.2893, -0.3380, -0.4476,  0.0408,
         -0.5259,  0.3565,  0.3233, -0.3103,  0.0072, -0.2885, -0.3414, -0.1583,
          0.5735, -0.1864, -0.1011,  0.0099, -0.1902, -0.4369, -0.0530, -0.1795,
         -0.4398, -0.4034,  0.6602,  0.7605, -0.1404,  0.4234, -0.1385, -0.1724,
         -0.3312,  0.7015,  0.6666,  0.8652,  0.2292, -0.3177, -0.2220, -0.0868,
          0.3910, -0.3581, -0.0492,  0.4877, -0.0851, -0.5695, -0.2949,  0.6281],
        [-0.4303, -0.2289, -0.2291, -0.2866, -0.1709, -0.4072, -0.4801,  0.0354,
         -0.6393,  0.8586, -0.3117, -0.1384,  0.2977,  0.6900, -0.4050, -0.2402,
          0.1145,  0.3382, -0.4300, -0.4596, -0.2861, -0.5100, -0.4041, -0.0090,
         -0.6471,  0.2327,  0.2302, -0.3289, -0.1161, -0.2566, -0.4129, -0.0981,
          0.5699, -0.1898, 
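Each row of the output above holds 64 values, one per-edge weight matrix of shape `[1, 64]` for `conv1`. The sketch below shows, in plain PyTorch, how `NNConv` uses such matrices, following the update rule in the PyG `NNConv` documentation: sum aggregation of `x_j @ W(e_ji)` over incoming edges, plus a root weight and a bias (these correspond to the `conv1.lin.weight` and `conv1.bias` parameters listed later). `nnconv_sketch` is a hypothetical helper for illustration, not PyG's implementation.

```python
import torch
import torch.nn as nn

def nnconv_sketch(x, edge_index, edge_attr, edge_nn, root, bias, in_channels, out_channels):
    """x'_i = root(x_i) + sum over edges j->i of x_j @ W(e_ji) + bias."""
    src, dst = edge_index                                              # messages flow src (j) -> dst (i)
    weights = edge_nn(edge_attr).view(-1, in_channels, out_channels)  # [E, in, out]
    messages = torch.bmm(x[src].unsqueeze(1), weights).squeeze(1)     # [E, out]
    aggregated = torch.zeros(x.size(0), out_channels, dtype=x.dtype)
    aggregated.index_add_(0, dst, messages)                            # sum messages per target node
    return aggregated + root(x) + bias

torch.manual_seed(0)
in_ch, out_ch = 1, 64
edge_nn = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), nn.Linear(128, in_ch * out_ch))
root = nn.Linear(in_ch, out_ch, bias=False)   # plays the role of conv1.lin
bias = torch.zeros(out_ch)                    # plays the role of conv1.bias

x = torch.randn(5, in_ch)                                       # 5 nodes, 1 feature each
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])   # a directed ring
edge_attr = torch.randn(edge_index.size(1), 1)

out = nnconv_sketch(x, edge_index, edge_attr, edge_nn, root, bias, in_ch, out_ch)
print(out.shape)  # torch.Size([5, 64])
```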

#### Step 3: Downstream Task
After extracting the layers and checking that they work, we can either (1) use the transferred layers and their weights as the initialization and keep training them, or (2) use the transferred layers but freeze their weights (i.e. they won't be updated during training). The code below uses method (1): the transferred layers form the first layers of the network, and new (untrained) layers are stacked on top of them. I've opted to add another `NNConv` and `GATConv` layer from PyG after the transferred `NNConv` and `GATConv` layers, followed by a `global_mean_pool` layer and two fully connected layers. Now we're ready to go!
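`global_mean_pool` turns node embeddings into one vector per graph by averaging over the nodes that share the same entry in the `batch` vector. The plain-PyTorch sketch below illustrates that behaviour; `mean_pool_sketch` is a hypothetical helper, not PyG's code. The `batch=None` branch averages over all nodes, which is consistent with the `[1, 1]` output of the single-graph test later in this section, where `example.batch` is `None`.

```python
import torch

def mean_pool_sketch(x, batch=None):
    """Average node features per graph; batch[i] is the graph index of node i."""
    if batch is None:                          # a single graph: average over all nodes
        return x.mean(dim=0, keepdim=True)
    num_graphs = int(batch.max()) + 1
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(1)
    return sums / counts

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = torch.tensor([0, 0, 1])       # nodes 0 and 1 belong to graph 0, node 2 to graph 1

pooled = mean_pool_sketch(x, batch)   # values [[2, 3], [10, 20]]
pooled_single = mean_pool_sketch(x)   # values [[14/3, 26/3]]
print(pooled)
print(pooled_single)
```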

In [25]:
import torch.nn as nn
from torch_geometric.nn import NNConv, GATConv
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from models import EdgeMLP


class seizure_detection1(nn.Module):
    def __init__(self, num_edge_features, prev_channels, hidden_channels, out_channels, fc_channels):
        super().__init__()
        
        # Transferred graph layers (deep copies from Step 2, trainable: method 1).
        # conv1 carries its own copy of the edge network as conv1.nn, so edge_mlp1
        # is registered here but not used in forward().
        self.edge_mlp1 = EdgeMLP_module
        self.conv1 = NNConv_module 
        self.conv2 = GATConv_module   

        # New graph layers. The transferred layers change the node feature size
        # to prev_channels, so it is passed explicitly here.
        self.edge_mlp2 = EdgeMLP(num_edge_features, prev_channels, hidden_channels)
        self.conv3 = NNConv(prev_channels, hidden_channels, self.edge_mlp2)  # New NNConv layer
        self.conv4 = GATConv(hidden_channels, out_channels, heads=1, concat=False) # New GATConv layer
        
        # Fully connected layers
        self.fc1 = nn.Linear(out_channels, fc_channels)
        self.fc2 = nn.Linear(fc_channels, 1)
        
    def forward(self, x, edge_index, edge_attr, batch, mode = "sigmoid"):
        # mode="sigmoid" returns probabilities; mode="linear" returns raw logits.
        # The print calls below are debug output of the tensor shapes.
        
        # NNConv layer 1
        print("Initial:", x.shape, edge_index.shape, edge_attr.shape)
        x = self.conv1(x, edge_index, edge_attr)
        print("After conv1:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 1
        x = self.conv2(x, edge_index)
        print("After conv2:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # NNConv layer 2
        x = self.conv3(x, edge_index, edge_attr)
        print("After conv3:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 2
        x = self.conv4(x, edge_index)
        print("After conv4:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # Global average pooling
        x = global_mean_pool(x, batch)
        
        # Fully connected layer 1
        x = self.fc1(x)
        x = F.relu(x)
        
        # Fully connected layer 2
        x = self.fc2(x)

        if mode == "sigmoid":
            x = torch.sigmoid(x)
            
        elif mode == "linear":
            pass
        
        return x
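With `mode="linear"` the model returns raw logits; with `mode="sigmoid"` it returns probabilities. For training a binary classifier, a common choice (an assumption about the setup, not something this notebook prescribes) is to keep the logits and use `nn.BCEWithLogitsLoss`, which folds the sigmoid into the loss in a numerically stable way. The sketch uses made-up logits: large raw outputs, like the 856.95 in the single-graph test below, saturate the sigmoid to exactly 1.0, and `nn.BCELoss` then clamps its log terms at -100, which badly understates the loss.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0], [-1.0], [900.0]])   # made-up outputs from mode="linear"
labels = torch.tensor([[1.0], [0.0], [0.0]])      # made-up targets

probs = torch.sigmoid(logits)                      # what mode="sigmoid" would return
stable = nn.BCEWithLogitsLoss()(logits, labels)    # works directly on logits
naive = nn.BCELoss()(probs, labels)                # works on saturated probabilities

print(probs.squeeze())               # last entry is exactly 1.0 in float32
print(stable.item(), naive.item())   # about 300.1 versus about 33.5
```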

Next, we load the list of `Data` objects prepared earlier. Each one holds a graph (`x`, `edge_index`, `edge_attr`) and its label `y`.

In [26]:
# Load data
from torch_geometric.loader import DataLoader

data_path = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_gr\jh101_pyg_Data.pt"
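# weights_only=False is needed to unpickle PyG Data objects on recent PyTorch versions,
# where torch.load defaults to weights_only=True; only use it on trusted files
data = torch.load(data_path, weights_only=False)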
loader = DataLoader(data, batch_size=32, shuffle=True)
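PyG's `DataLoader` merges the graphs of a mini-batch into one large disconnected graph: node and edge features are concatenated, each graph's `edge_index` is shifted by the number of nodes that come before it, and a `batch` vector records which graph each node belongs to (this is what `global_mean_pool` uses). A plain-PyTorch sketch of that collation with a hypothetical `collate_graphs` helper; PyG does the real work in `Batch.from_data_list`, with more bookkeeping.

```python
import torch

def collate_graphs(graphs):
    """graphs: list of (x, edge_index, edge_attr) tuples. Returns one merged graph."""
    xs, edge_indices, edge_attrs, batch = [], [], [], []
    offset = 0
    for graph_id, (x, edge_index, edge_attr) in enumerate(graphs):
        xs.append(x)
        edge_indices.append(edge_index + offset)   # shift node ids into the merged graph
        edge_attrs.append(edge_attr)
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs), torch.cat(edge_indices, dim=1), torch.cat(edge_attrs), torch.cat(batch)

g1 = (torch.randn(3, 1), torch.tensor([[0, 1], [1, 2]]), torch.randn(2, 1))  # 3 nodes, 2 edges
g2 = (torch.randn(2, 1), torch.tensor([[0], [1]]), torch.randn(1, 1))        # 2 nodes, 1 edge

x, edge_index, edge_attr, batch = collate_graphs([g1, g2])
print(edge_index)  # the second graph's edge 0 -> 1 becomes 3 -> 4
print(batch)       # tensor([0, 0, 0, 1, 1])
```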

Now let's pass a single graph through the model to check that the layer shapes line up.

In [27]:
# Model parameters
num_edge_features = 1
prev_channels = 64  # Node feature dimension output by the transferred layers, i.e. x has shape [num_nodes, prev_channels]
hidden_channels = 128
out_channels = 64
fc_channels = 32


# Build the model
unfrozen_model = seizure_detection1(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

# Try an example data point
example = data[0]

out = unfrozen_model(example.x, example.edge_index, example.edge_attr, example.batch, mode="linear")
print(out)

Initial: torch.Size([107, 1]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv1: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv2: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv3: torch.Size([107, 128]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv4: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
tensor([[856.9504]], grad_fn=<AddmmBackward0>)
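The `print` calls inside `forward` are handy for a one-off shape check, but they fire on every forward pass during training. An alternative sketch uses forward hooks (`register_forward_hook`, a standard `nn.Module` method) and removes them afterwards. It is shown on a toy `nn.Sequential` rather than `seizure_detection1`; for the real model, one would hook `conv1` to `conv4` by attribute instead of iterating over `named_children()`.

```python
import torch
import torch.nn as nn

# Toy model standing in for the GNN; layer widths mirror the shapes printed above
model = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))
recorded = {}

def make_hook(name):
    def hook(module, inputs, output):
        recorded[name] = tuple(output.shape)   # store this layer's output shape
    return hook

handles = [m.register_forward_hook(make_hook(name))
           for name, m in model.named_children() if isinstance(m, nn.Linear)]

model(torch.randn(107, 1))   # 107 nodes, as in the example graph
for h in handles:
    h.remove()               # detach the hooks so later forward passes are silent

print(recorded)  # {'0': (107, 64), '2': (107, 128), '4': (107, 1)}
```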


Now let's implement method (2), where the transferred layers are frozen.

In [28]:
import copy

frozen_EdgeMLP_module = copy.deepcopy(pretrained_model.edge_mlp)
frozen_NNConv_module = copy.deepcopy(pretrained_model.conv1)
frozen_GATConv_module = copy.deepcopy(pretrained_model.conv2)

# Freeze the layers
for module in (frozen_EdgeMLP_module, frozen_NNConv_module, frozen_GATConv_module):
    for param in module.parameters():
        param.requires_grad = False   # no gradients will be computed for these weights
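Setting `requires_grad = False` stops gradients from being computed for those tensors. A common companion step, assumed here rather than taken from this notebook, is to give the optimizer only the parameters that still require gradients. The toy sketch below uses a hypothetical `freeze` helper and checks that one optimizer step leaves the frozen layer unchanged while the new layer is updated.

```python
import copy
import torch
import torch.nn as nn

def freeze(module):
    """Hypothetical helper: disable gradients for every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad = False
    return module

torch.manual_seed(0)
backbone = freeze(nn.Linear(1, 8))     # stands in for the transferred layers
head = nn.Linear(8, 1)                 # stands in for the new, trainable layers
model = nn.Sequential(backbone, nn.ReLU(), head)

# Pass only trainable parameters to the optimizer
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-2)

before_backbone = copy.deepcopy(backbone.state_dict())
before_head = copy.deepcopy(head.state_dict())

loss = model(torch.randn(16, 1)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = all(torch.equal(before_backbone[k], v) for k, v in backbone.state_dict().items())
head_changed = any(not torch.equal(before_head[k], v) for k, v in head.state_dict().items())
print(backbone_unchanged, head_changed)  # True True
```

Note that `requires_grad` does not switch a layer between training and evaluation behaviour (e.g. dropout); that is controlled separately with `.train()` and `.eval()`.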

In [29]:
import torch.nn as nn
from torch_geometric.nn import NNConv, GATConv
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from models import EdgeMLP


class seizure_detection2(nn.Module):
    def __init__(self, num_edge_features, prev_channels, hidden_channels, out_channels, fc_channels):
        super().__init__()
        
        # Transferred graph layers, frozen in the previous cell (method 2).
        # conv1 carries its own copy of the edge network as conv1.nn, so edge_mlp1
        # is registered here but not used in forward().
        self.edge_mlp1 = frozen_EdgeMLP_module
        self.conv1 = frozen_NNConv_module 
        self.conv2 = frozen_GATConv_module   

        # New (trainable) graph layers. The transferred layers change the node
        # feature size to prev_channels, so it is passed explicitly here.
        self.edge_mlp2 = EdgeMLP(num_edge_features, prev_channels, hidden_channels)
        self.conv3 = NNConv(prev_channels, hidden_channels, self.edge_mlp2)  # New NNConv layer
        self.conv4 = GATConv(hidden_channels, out_channels, heads=1, concat=False) # New GATConv layer
        
        # Fully connected layers
        self.fc1 = nn.Linear(out_channels, fc_channels)
        self.fc2 = nn.Linear(fc_channels, 1)
        
    def forward(self, x, edge_index, edge_attr, batch, mode = "sigmoid"):
        # mode="sigmoid" returns probabilities; mode="linear" returns raw logits.
        # The print calls below are debug output of the tensor shapes.
        
        # NNConv layer 1
        print("Initial:", x.shape, edge_index.shape, edge_attr.shape)
        x = self.conv1(x, edge_index, edge_attr)
        print("After conv1:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 1
        x = self.conv2(x, edge_index)
        print("After conv2:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # NNConv layer 2
        x = self.conv3(x, edge_index, edge_attr)
        print("After conv3:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 2
        x = self.conv4(x, edge_index)
        print("After conv4:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # Global average pooling
        x = global_mean_pool(x, batch)
        
        # Fully connected layer 1
        x = self.fc1(x)
        x = F.relu(x)
        
        # Fully connected layer 2
        x = self.fc2(x)

        if mode == "sigmoid":
            x = torch.sigmoid(x)
            
        elif mode == "linear":
            pass
        
        return x

Then we can check which parameters of each model are frozen.

In [32]:
num_edge_features = 1
prev_channels = 64
hidden_channels = 128
out_channels = 64
fc_channels = 32

unfrozen_model = seizure_detection1(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

frozen_model = seizure_detection2(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

for name, param in unfrozen_model.named_parameters():
    print(f"Layer: {name}, Frozen: {not param.requires_grad}")
print("----------------------------------------------------")    
for name, param in frozen_model.named_parameters():
    print(f"Layer: {name}, Frozen: {not param.requires_grad}")

Layer: edge_mlp1.mlp.0.weight, Frozen: False
Layer: edge_mlp1.mlp.0.bias, Frozen: False
Layer: edge_mlp1.mlp.2.weight, Frozen: False
Layer: edge_mlp1.mlp.2.bias, Frozen: False
Layer: edge_mlp1.mlp.4.weight, Frozen: False
Layer: edge_mlp1.mlp.4.bias, Frozen: False
Layer: conv1.bias, Frozen: False
Layer: conv1.nn.mlp.0.weight, Frozen: False
Layer: conv1.nn.mlp.0.bias, Frozen: False
Layer: conv1.nn.mlp.2.weight, Frozen: False
Layer: conv1.nn.mlp.2.bias, Frozen: False
Layer: conv1.nn.mlp.4.weight, Frozen: False
Layer: conv1.nn.mlp.4.bias, Frozen: False
Layer: conv1.lin.weight, Frozen: False
Layer: conv2.att_src, Frozen: False
Layer: conv2.att_dst, Frozen: False
Layer: conv2.bias, Frozen: False
Layer: conv2.lin_src.weight, Frozen: False
Layer: edge_mlp2.mlp.0.weight, Frozen: False
Layer: edge_mlp2.mlp.0.bias, Frozen: False
Layer: edge_mlp2.mlp.2.weight, Frozen: False
Layer: edge_mlp2.mlp.2.bias, Frozen: False
Layer: edge_mlp2.mlp.4.weight, Frozen: False
Layer: edge_mlp2.mlp.4.bias, Frozen: 
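The per-tensor listing gets long. A compact alternative sketch counts trainable and frozen parameter elements per top-level submodule (e.g. `conv1`, `fc2`); `freeze_summary` is a hypothetical helper, shown here on a toy model whose first layer is frozen.

```python
from collections import defaultdict
import torch.nn as nn

def freeze_summary(model):
    """Count trainable and frozen parameter elements per top-level submodule."""
    counts = defaultdict(lambda: {"trainable": 0, "frozen": 0})
    for name, p in model.named_parameters():
        top = name.split(".", 1)[0]   # e.g. "conv1.nn.mlp.0.weight" -> "conv1"
        counts[top]["trainable" if p.requires_grad else "frozen"] += p.numel()
    return dict(counts)

model = nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1))
for p in model[0].parameters():
    p.requires_grad = False

summary = freeze_summary(model)
print(summary)  # {'0': {'trainable': 0, 'frozen': 8}, '1': {'trainable': 5, 'frozen': 0}}
```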

#### Step 4: Fine-Tuning for Seizure Detection