#### Step 1: Load Model
Before we begin transfer learning, we first have to load the model. This can be done in two ways: (1) load `model.pth`, which includes the model's architecture and weights, or (2) instantiate the model class itself, defined in `pyg_model.py`. Either way works, but (2) is a bit safer when dealing with unknown files, because loading a fully pickled model can execute arbitrary code. After loading the model, we then load the state dictionary `model_state_dict.pth`, which maps each layer's parameter names to its weight tensors and lets us examine, extract, or modify specific layers. The cell below uses (1).

In [1]:
import torch
import copy
import sys
sys.path.append("../src")

model_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/temporal_shuffling.pth"
model_dict_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/temporal_shuffling_state_dict.pth"

# Load model
model = torch.load(model_path)

# Load state dictionary
model_dict = torch.load(model_dict_path)

# Set the state dictionary to the model
model.load_state_dict(model_dict)
model.eval()





temporal_shuffling(
  (encoder): gnn_encoder(
    (edge_mlp): EdgeMLP(
      (mlp): Sequential(
        (0): Linear(in_features=3, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=576, bias=True)
      )
    )
    (conv1): NNConv(9, 64, aggr=add, nn=EdgeMLP(
      (mlp): Sequential(
        (0): Linear(in_features=3, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=576, bias=True)
      )
    ))
    (conv2): GATConv(64, 32, heads=1)
    (fc1): Linear(in_features=32, out_features=64, bias=True)
    (fc2): Linear(in_features=64, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=256, bias=True)
  )
  (fc): Linear(in_features=512, out_features=1, bias=True)
)
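Option (2) from the start of this step (building the model from its class and loading only the weights) can be sketched as follows. This is a minimal illustration, not the project's code: `TinyEncoder` is a hypothetical stand-in for the class defined in `pyg_model.py`, and an in-memory buffer replaces the `.pth` file so the snippet runs on its own.

```python
import io

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the model class defined in pyg_model.py."""

    def __init__(self, in_features=3, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Pretend this is the trained model whose weights were saved to disk.
trained = TinyEncoder()
buffer = io.BytesIO()  # in-memory stand-in for a *_state_dict.pth file
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Option (2): build the architecture from code, then load only the tensors.
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)
restored = TinyEncoder()
restored.load_state_dict(state_dict)
restored.eval()

# Both models should now produce identical outputs for the same input.
x = torch.randn(4, 3)
same_output = torch.allclose(trained(x), restored(x))
print(same_output)
```

Passing `weights_only=True` restricts unpickling to tensors and basic containers, which is why this route is the more cautious one for files of unknown origin.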

#### Step 2: Extract Layers
In this step we extract the layers we want to use for the supervised model downstream. In this case, we need the NNConv and GATConv layers from our model, but since our NNConv depends on a separate layer called EdgeMLP (which is just a multilayer perceptron), we'll need that too, since it's essentially part of the NNConv layer's parameters. You can assign it the old-fashioned way using `EdgeMLP_module = model.encoder.edge_mlp`, but that only creates a second reference to the same module, which causes issues later on when we want two separate copies of `EdgeMLP_module` (one frozen, one unfrozen), so we use `copy.deepcopy` from the `copy` module instead.

In [2]:
EdgeMLP_pretrained = copy.deepcopy(model.encoder.edge_mlp)
NNConv_pretrained = copy.deepcopy(model.encoder.conv1)
GATConv_pretrained = copy.deepcopy(model.encoder.conv2)
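To see why `copy.deepcopy` matters here, the sketch below compares a plain assignment with a deep copy. It uses a small `nn.Linear` as a hypothetical stand-in for `model.encoder.edge_mlp`; the behaviour is the same for any `nn.Module`.

```python
import copy

import torch.nn as nn

# Hypothetical stand-in for model.encoder.edge_mlp.
source_layer = nn.Linear(3, 4)

alias = source_layer                 # plain assignment: same object, same parameters
clone = copy.deepcopy(source_layer)  # deep copy: independent parameters

# Freeze the clone only.
for param in clone.parameters():
    param.requires_grad = False

# The alias points at the original weight storage; the clone does not.
alias_shares_storage = alias.weight.data_ptr() == source_layer.weight.data_ptr()
clone_shares_storage = clone.weight.data_ptr() == source_layer.weight.data_ptr()

# Freezing the clone leaves the original layer trainable.
source_still_trainable = all(p.requires_grad for p in source_layer.parameters())
clone_frozen = not any(p.requires_grad for p in clone.parameters())

print(alias_shares_storage, clone_shares_storage)
print(source_still_trainable, clone_frozen)
```

Had the loop been run on `alias` instead, the original layer would have been frozen as well, since both names refer to the same module.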

We can examine the weights of a layer with the following:

In [3]:
for param_tensor in EdgeMLP_pretrained.state_dict():
    print(param_tensor, "\t", EdgeMLP_pretrained.state_dict()[param_tensor].size())


mlp.0.weight 	 torch.Size([128, 3])
mlp.0.bias 	 torch.Size([128])
mlp.2.weight 	 torch.Size([64, 128])
mlp.2.bias 	 torch.Size([64])
mlp.4.weight 	 torch.Size([576, 64])
mlp.4.bias 	 torch.Size([576])


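And here's a test running some random input through the EdgeMLP to verify it's functional. Each of the 10 edges should map to a 576-dimensional vector, since `NNConv(9, 64)` expects the edge network to output 9 × 64 = 576 values per edge.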

In [4]:
# Create some dummy data
dummy_edge_attr = torch.randn(10, 3)  # 10 edges, each with 3 edge features (matches in_features=3 above)

# Run the data through the `edge_mlp` layer
output = EdgeMLP_pretrained(dummy_edge_attr)
print(output)

tensor([[ 0.0321,  0.0009,  0.0084,  ...,  0.0310, -0.0705,  0.0445],
        [-0.0479,  0.0043, -0.0925,  ...,  0.0205, -0.1029,  0.1922],
        [-0.0017, -0.1137, -0.0789,  ...,  0.1040, -0.0402,  0.0833],
        ...,
        [-0.0005, -0.0056,  0.0161,  ...,  0.0091, -0.0817,  0.0355],
        [-0.0350, -0.0275, -0.0349,  ...,  0.0347, -0.0643,  0.1101],
        [-0.1118,  0.0496, -0.0055,  ...,  0.0199,  0.0280,  0.0592]],
       grad_fn=<AddmmBackward0>)


#### Step 3: Downstream Task
After extracting the layers and verifying that everything is functional, we can either (1) use the layers and their weights as initialization, or (2) use the layers but freeze the weights (i.e. they won't be updated during training). The cell below uses method (1) (`frozen=False`): the transferred layers become the first layers of the network, and newer (untrained) layers are added on top. I've opted to add another `NNConv` and `GATConv` layer from `PyG` after the transferred `NNConv` and `GATConv` layers, followed by a `global_mean_pool` layer and two fully connected layers. Now we're ready to go!

In [6]:
from models import supervised_downstream1
config = {
    "hidden_channels": [64, 64, 32],
    "dropout": 0.1,
}

pretrained_layers = [EdgeMLP_pretrained, NNConv_pretrained, GATConv_pretrained]

model = supervised_downstream1(config, pretrained_layers, frozen=False)

In [28]:
def check_frozen_status(model):
    layers_to_check = ["conv1", "conv2"]  # Names of the layers in your model that are pretrained

    for layer_name in layers_to_check:
        layer = getattr(model, layer_name)
        for name, param in layer.named_parameters():
            print(f"Layer: {layer_name}, Parameter: {name}, Frozen: {not param.requires_grad}")

# Check if the pretrained layers are frozen or not
check_frozen_status(model)


Layer: conv1, Parameter: bias, Frozen: True
Layer: conv1, Parameter: nn.mlp.0.weight, Frozen: True
Layer: conv1, Parameter: nn.mlp.0.bias, Frozen: True
Layer: conv1, Parameter: nn.mlp.2.weight, Frozen: True
Layer: conv1, Parameter: nn.mlp.2.bias, Frozen: True
Layer: conv1, Parameter: nn.mlp.4.weight, Frozen: True
Layer: conv1, Parameter: nn.mlp.4.bias, Frozen: True
Layer: conv1, Parameter: lin.weight, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.0.weight, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.0.bias, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.2.weight, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.2.bias, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.4.weight, Frozen: True
Layer: conv1, Parameter: edge_mlp.mlp.4.bias, Frozen: True
Layer: conv2, Parameter: att_src, Frozen: True
Layer: conv2, Parameter: att_dst, Frozen: True
Layer: conv2, Parameter: bias, Frozen: True
Layer: conv2, Parameter: lin_src.weight, Frozen: True


#### Finetuning on Downstream Task

In [24]:
from preprocess import create_data_loaders

# Paths
data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/patient_pyg/jh101/supervised/jh101_run1.pt"
data = torch.load(data_path)

loaders, _ = create_data_loaders(data, data_size=1.0, val_ratio=0.2, test_ratio=0.1, batch_size=32, num_workers=4, model_id="supervised")

Total number of examples in dataset: 1113.
Total number of examples used: 1113.
Number of training examples: 890. Number of training batches: 28.
Number of validation examples: 223. Number of validation batches: 7.
Number of test examples: 112. Number of test batches: 4.


In [25]:
train_loader, val_loader, test_loader = loaders

for batch in train_loader:
    print(f"Model output: {model(batch).size()}")
    break

Model output: torch.Size([32])
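With one value per graph in the batch, a single fine-tuning step for binary classification might look like the sketch below. Everything here is an assumption for illustration: the toy `nn.Sequential` stands in for the downstream model, random tensors replace the real loader, and the output is assumed to be raw logits (hence `BCEWithLogitsLoss`). If the real model already applies a sigmoid, `nn.BCELoss` would be the matching choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the downstream model (assumed to output one logit per example).
toy_model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 1))
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)

# Fake batch of 32 examples with binary labels.
features = torch.randn(32, 5)
labels = torch.randint(0, 2, (32,)).float()

# One optimization step.
toy_model.train()
optimizer.zero_grad()
logits = toy_model(features).squeeze(-1)  # shape [32], like the model output above
loss = criterion(logits, labels)
loss.backward()
optimizer.step()

print(logits.shape, loss.item())
```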


#### One Shot Transfer

In [7]:
from preprocess import extract_layers
from models import supervised_downstream1
model_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/temporal_shuffling.pth"
model_dict_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/temporal_shuffling_state_dict.pth"

extracted_layers = extract_layers(model_path, model_dict_path, "temporal_shuffling")

config = {
    "hidden_channels": [64, 64, 32],
    "dropout": 0.1,
}

model = supervised_downstream1(config, extracted_layers, frozen=False)

In [9]:
# Paths
from preprocess import create_data_loaders
data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/patient_pyg/jh101/supervised/jh101_run1.pt"
data = torch.load(data_path)

loaders, _ = create_data_loaders(data, data_size=1.0, val_ratio=0.2, test_ratio=0.1, batch_size=32, num_workers=4, model_id="supervised")

Total number of examples in dataset: 1113.
Total number of examples used: 1113.
Number of training examples: 890. Number of training batches: 28.
Number of validation examples: 223. Number of validation batches: 7.
Number of test examples: 112. Number of test batches: 4.


In [10]:
train_loader = loaders[0]

for batch in train_loader:
    print(f"Model Output: {model(batch).size()}")
    break

Model Output: torch.Size([32])


Now let's test the model.

In [27]:
# Model parameters
num_edge_features = 1
prev_channels = 64  # Dimension of the node features coming out of the transferred layers, i.e. x has shape [num_nodes, prev_channels]
hidden_channels = 128
out_channels = 64
fc_channels = 32


# Load model
unfrozen_model = seizure_detection1(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

# Try an example data point
example = data[0]

out = unfrozen_model(example.x, example.edge_index, example.edge_attr, example.batch, mode="linear")
print(out)

Initial: torch.Size([107, 1]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv1: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv2: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv3: torch.Size([107, 128]) torch.Size([2, 11342]) torch.Size([11342, 1])
After conv4: torch.Size([107, 64]) torch.Size([2, 11342]) torch.Size([11342, 1])
tensor([[856.9504]], grad_fn=<AddmmBackward0>)


Now let's implement method (2) with frozen layers.

In [28]:
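import copy

# NOTE: `pretrained_model` is not defined in the cells shown here. It is assumed to be
# the encoder of the pretrained model loaded in Step 1 (i.e. `model.encoder`).
frozen_EdgeMLP_module = copy.deepcopy(pretrained_model.edge_mlp)
frozen_NNConv_module = copy.deepcopy(pretrained_model.conv1)
frozen_GATConv_module = copy.deepcopy(pretrained_model.conv2)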

# Freeze the layers
for param in frozen_EdgeMLP_module.parameters():
    param.requires_grad = False

for param in frozen_NNConv_module.parameters():
    param.requires_grad = False

for param in frozen_GATConv_module.parameters():
    param.requires_grad = False
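The three loops above follow the same pattern, so they can be folded into a small helper. The sketch below also shows one common way to hand only the trainable parameters to the optimizer, and checks that a training step leaves the frozen weights untouched. `freeze` is a hypothetical helper name, and the `nn.Linear` layers are toy stand-ins for the transferred and new layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def freeze(module):
    """Disable gradient updates for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False
    return module


pretrained = freeze(nn.Linear(3, 4))  # toy stand-in for a transferred layer
head = nn.Linear(4, 1)                # new, trainable layer
toy_model = nn.Sequential(pretrained, nn.ReLU(), head)

# Only pass parameters that still require gradients to the optimizer.
optimizer = torch.optim.Adam(
    [p for p in toy_model.parameters() if p.requires_grad], lr=1e-2
)

frozen_before = pretrained.weight.clone()
head_before = [p.clone() for p in head.parameters()]

# One dummy training step.
loss = toy_model(torch.randn(8, 3)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(frozen_before, pretrained.weight)
head_changed = any(
    not torch.equal(before, after)
    for before, after in zip(head_before, head.parameters())
)
print(frozen_unchanged, head_changed)
```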

In [29]:
import torch.nn as nn
from torch_geometric.nn import NNConv, GATConv
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from models import EdgeMLP


class seizure_detection2(nn.Module):
    def __init__(self, num_edge_features, prev_channels, hidden_channels, out_channels, fc_channels):
        super(seizure_detection2, self).__init__()
        
        # Transferred graph layers (frozen copies created above)
        self.edge_mlp1 = frozen_EdgeMLP_module
        self.conv1 = frozen_NNConv_module 
        self.conv2 = frozen_GATConv_module   

        # New graph layers
        self.edge_mlp2 = EdgeMLP(num_edge_features, prev_channels, hidden_channels)  # The node feature dimension has changed after the transferred layers, so prev_channels is passed in explicitly.
        self.conv3 = NNConv(prev_channels, hidden_channels, self.edge_mlp2)  # New NNConv layer
        self.conv4 = GATConv(hidden_channels, out_channels, heads=1, concat=False) # New GATConv layer
        
        # Fully connected layers
        self.fc1 = nn.Linear(out_channels, fc_channels)
        self.fc2 = nn.Linear(fc_channels, 1)
        
    def forward(self, x, edge_index, edge_attr, batch, mode = "sigmoid"):
        # Forward pass (the print statements trace tensor shapes for debugging)
        
        # NNConv layer 1
        print("Initial:", x.shape, edge_index.shape, edge_attr.shape)
        x = self.conv1(x, edge_index, edge_attr)
        print("After conv1:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 1
        x = self.conv2(x, edge_index)
        print("After conv2:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # NNConv layer 2
        x = self.conv3(x, edge_index, edge_attr)
        print("After conv3:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # GATConv layer 2
        x = self.conv4(x, edge_index)
        print("After conv4:", x.shape, edge_index.shape, edge_attr.shape)
        x = F.relu(x)
        
        # Global average pooling
        x = global_mean_pool(x, batch)
        
        # Fully connected layer 1
        x = self.fc1(x)
        x = F.relu(x)
        
        # Fully connected layer 2
        x = self.fc2(x)

        if mode == "sigmoid":
            x = torch.sigmoid(x)
            
        elif mode == "linear":
            pass
        
        return x

We can then check which parameters in each model are frozen.

In [32]:
num_edge_features = 1
prev_channels = 64
hidden_channels = 128
out_channels = 64
fc_channels = 32

unfrozen_model = seizure_detection1(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

frozen_model = seizure_detection2(num_edge_features=num_edge_features, prev_channels=prev_channels, hidden_channels=hidden_channels, 
                          out_channels=out_channels, fc_channels=fc_channels)

for name, param in unfrozen_model.named_parameters():
    print(f"Layer: {name}, Frozen: {not param.requires_grad}")
print("----------------------------------------------------")    
for name, param in frozen_model.named_parameters():
    print(f"Layer: {name}, Frozen: {not param.requires_grad}")

Layer: edge_mlp1.mlp.0.weight, Frozen: False
Layer: edge_mlp1.mlp.0.bias, Frozen: False
Layer: edge_mlp1.mlp.2.weight, Frozen: False
Layer: edge_mlp1.mlp.2.bias, Frozen: False
Layer: edge_mlp1.mlp.4.weight, Frozen: False
Layer: edge_mlp1.mlp.4.bias, Frozen: False
Layer: conv1.bias, Frozen: False
Layer: conv1.nn.mlp.0.weight, Frozen: False
Layer: conv1.nn.mlp.0.bias, Frozen: False
Layer: conv1.nn.mlp.2.weight, Frozen: False
Layer: conv1.nn.mlp.2.bias, Frozen: False
Layer: conv1.nn.mlp.4.weight, Frozen: False
Layer: conv1.nn.mlp.4.bias, Frozen: False
Layer: conv1.lin.weight, Frozen: False
Layer: conv2.att_src, Frozen: False
Layer: conv2.att_dst, Frozen: False
Layer: conv2.bias, Frozen: False
Layer: conv2.lin_src.weight, Frozen: False
Layer: edge_mlp2.mlp.0.weight, Frozen: False
Layer: edge_mlp2.mlp.0.bias, Frozen: False
Layer: edge_mlp2.mlp.2.weight, Frozen: False
Layer: edge_mlp2.mlp.2.bias, Frozen: False
Layer: edge_mlp2.mlp.4.weight, Frozen: False
Layer: edge_mlp2.mlp.4.bias, Frozen: 

#### Step 4: Fine-Tuning for Seizure Detection