#### Step 1: Load Model
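Before we begin transfer learning, we first have to load the pretrained model. This can be done in two ways: (1) load the full pickled model file (`model.pth`), which contains both the architecture and the weights, or (2) import the model class itself, defined in `pyg_model.py`, instantiate it, and load only the weights. Either works, but (2) is safer when dealing with files you didn't create, because unpickling a full model can execute arbitrary code. The cell below uses approach (1) and then also loads the state dictionary (`model_state_dict.pth`; in this example, `relative_positioning_state_dict.pth`). The state dictionary maps each layer's parameter names to their weight tensors, which is what lets us reference specific layers when examining, extracting, or modifying the model's architecture.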

In [1]:
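import torch
import copy
import sys
sys.path.append("../src")  # so the pickled model's class definitions can be imported

# Mac
model_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/relative_positioning.pth"
model_dict_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/relative_positioning_state_dict.pth"

# Load the full pickled model (architecture + weights).
# weights_only=False (available since PyTorch 1.13) is required on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True. Only do this for files you trust.
model = torch.load(model_path, weights_only=False)

# Load the state dictionary (parameter name -> tensor)
model_dict = torch.load(model_dict_path)

# Copy the weights into the model and switch to evaluation mode
model.load_state_dict(model_dict)
model.eval()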





relative_positioning(
  (embedder): gnn_embedder(
    (edge_mlp): EdgeMLP(
      (mlp): Sequential(
        (0): Linear(in_features=3, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=576, bias=True)
      )
    )
    (conv1): NNConv(9, 64, aggr=add, nn=EdgeMLP(
      (mlp): Sequential(
        (0): Linear(in_features=3, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=576, bias=True)
      )
    ))
    (conv2): GATConv(64, 32, heads=1)
    (fc1): Linear(in_features=32, out_features=64, bias=True)
    (fc2): Linear(in_features=64, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=256, bias=True)
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
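For comparison, here is a minimal sketch of approach (2): rebuild the architecture from code you control and read only tensors from disk. `TinyModel` is a hypothetical stand-in for the real class in `pyg_model.py`, and `weights_only=True` assumes PyTorch 1.13 or newer.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the project's model class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)


def load_from_state_dict(path):
    # Rebuild the architecture from trusted code...
    model = TinyModel()
    # ...and load only tensors (weights_only=True refuses arbitrary pickled objects).
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_state_dict.pth")
    original = TinyModel()
    torch.save(original.state_dict(), path)
    restored = load_from_state_dict(path)

# Every restored tensor should match the saved one exactly.
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

The trade-off is that approach (2) needs the class definition and its constructor arguments at load time, whereas approach (1) only needs the class to be importable.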

#### Step 2: Extract Layers
In this step we extract the layers we want to reuse in the downstream supervised model. Here we need the `NNConv` and `GATConv` layers. Our `NNConv` layer also depends on a separate module called `EdgeMLP` (a small multilayer perceptron that maps each edge's features to that edge's convolution weights), so we extract it too, since it effectively supplies the `NNConv` layer's parameters. You could assign it directly with `EdgeMLP_module = model.embedder.edge_mlp`, but that only creates another reference to the same module object: freezing or fine-tuning it later would also change the original model and every other reference to it, which causes problems when we want separate frozen and unfrozen copies. Instead, we use `copy.deepcopy` to create independent copies.

In [2]:
EdgeMLP_pretrained = copy.deepcopy(model.embedder.edge_mlp)
NNConv_pretrained = copy.deepcopy(model.embedder.conv1)
GATConv_pretrained = copy.deepcopy(model.embedder.conv2)
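To see why plain assignment is a problem, here is a small sketch on a hypothetical toy network (not the pretrained model): freezing an alias also freezes the layer inside the original network, while a deep copy keeps its own parameters.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical toy network standing in for the pretrained embedder.
toy_net = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))

alias = toy_net[0]                 # plain assignment: the very same module object
clone = copy.deepcopy(toy_net[0])  # independent module with its own tensors

# Freeze the alias; this also freezes toy_net[0], because they are one object.
for p in alias.parameters():
    p.requires_grad_(False)

print(alias is toy_net[0])                                  # True
print(toy_net[0].weight.requires_grad)                      # False
print(clone.weight.requires_grad)                           # True: the copy is unaffected
print(torch.equal(clone.weight, toy_net[0].weight))         # True: same values...
print(clone.weight.data_ptr() == toy_net[0].weight.data_ptr())  # False: ...separate storage
```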

We can inspect the name and shape of each parameter tensor in a layer's state dictionary:

In [3]:
# Print each parameter's name and shape
for name, tensor in EdgeMLP_pretrained.state_dict().items():
    print(name, "\t", tensor.size())


mlp.0.weight 	 torch.Size([128, 3])
mlp.0.bias 	 torch.Size([128])
mlp.2.weight 	 torch.Size([64, 128])
mlp.2.bias 	 torch.Size([64])
mlp.4.weight 	 torch.Size([576, 64])
mlp.4.bias 	 torch.Size([576])
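These shapes also determine the layer's size. The sketch below rebuilds a stand-in with the same widths (3 → 128 → 64 → 576; a plain `nn.Sequential`, not the project's `EdgeMLP` class) and counts its parameters with `numel()`.

```python
import torch.nn as nn

# Stand-in with the same layer shapes as the printed EdgeMLP (ReLUs have no parameters).
edge_mlp_like = nn.Sequential(
    nn.Linear(3, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 576),
)

# Number of scalar parameters per tensor, e.g. 128 * 3 = 384 for "0.weight".
counts = {name: p.numel() for name, p in edge_mlp_like.named_parameters()}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:10s} {n}")
print("total", total)  # 46208
```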


As a quick sanity check, we run some random input through the extracted EdgeMLP to verify that it is functional. Each of the 10 edges should produce a 576-dimensional output.

In [4]:
# Create some dummy data
dummy_edge_attr = torch.randn(10, 3)  # 10 edges, each with 3 edge features (the EdgeMLP's in_features)

# Run the data through the `edge_mlp` layer
output = EdgeMLP_pretrained(dummy_edge_attr)
print(output)

tensor([[ 0.1203,  0.0915, -0.0502,  ..., -0.2302,  0.1062, -0.0885],
        [ 0.1120,  0.1023,  0.0214,  ..., -0.1312,  0.3471,  0.1361],
        [ 0.3126,  0.0766,  0.1168,  ..., -0.3172,  0.1435, -0.1375],
        ...,
        [ 0.2829,  0.0081, -0.0191,  ..., -0.4392,  0.4150, -0.1764],
        [ 0.1666,  0.0201, -0.0104,  ..., -0.0859,  0.0363, -0.0681],
        [ 0.2758,  0.0096,  0.0423,  ..., -0.4276,  0.1904, -0.1761]],
       grad_fn=<AddmmBackward0>)
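The 576-wide output is not arbitrary: `NNConv(9, 64)` uses the EdgeMLP to turn each edge's features into a flattened 9 x 64 weight matrix (9 · 64 = 576). Below is a plain-PyTorch sketch of one edge-conditioned convolution step with `aggr=add`, x_i' = Θ x_i + Σ_{j ∈ N(i)} x_j · h(e_ij), using random stand-ins for the node features and the EdgeMLP output. It illustrates the formulation, not PyG's actual code.

```python
import torch

num_nodes, in_channels, out_channels = 4, 9, 64
x = torch.randn(num_nodes, in_channels)        # node features
edge_index = torch.tensor([[0, 1, 2, 3],       # source nodes j
                           [1, 2, 3, 0]])      # target nodes i
num_edges = edge_index.size(1)

# Stand-in for the EdgeMLP output: one flattened 576-vector per edge.
edge_weights_flat = torch.randn(num_edges, in_channels * out_channels)

# Reshape each 576-vector into a (9, 64) weight matrix, one per edge.
W = edge_weights_flat.view(-1, in_channels, out_channels)

# Message along edge j -> i: x_j @ W_ij, giving shape (num_edges, 64).
src, dst = edge_index
messages = torch.bmm(x[src].unsqueeze(1), W).squeeze(1)

# "add" aggregation: sum the incoming messages at each target node.
aggr = torch.zeros(num_nodes, out_channels).index_add_(0, dst, messages)

# Root weight (the role played by NNConv's `lin`) plus a bias term.
root = torch.nn.Linear(in_channels, out_channels, bias=False)
bias = torch.zeros(out_channels)
out = aggr + root(x) + bias
print(out.shape)  # torch.Size([4, 64])
```

This is also why the EdgeMLP has to travel together with `NNConv`: without it, the convolution has no per-edge weights.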


#### Step 3: Downstream Task
After extracting the layers and verifying that they work, we can either (1) use the layers and their pretrained weights as an initialization and keep training them, or (2) reuse the layers but freeze their weights (i.e., they won't be updated during training). The code below uses method (1) (`frozen=False`): the transferred layers form the first layers of the network, and new, untrained layers are added on top. I've opted to add another `NNConv` and `GATConv` layer from `PyG` after the existing pretrained `NNConv` and `GATConv` layers, followed by a `global_mean_pool` layer and two fully connected layers. Now we're ready to go!

In [5]:
from models import downstream1
config = {
    "hidden_channels": [64, 64, 32],
    "dropout": 0.1,
}

pretrained_layers = [EdgeMLP_pretrained, NNConv_pretrained, GATConv_pretrained]

model = downstream1(config, pretrained_layers, frozen=False)

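To confirm that the pretrained layers will be updated during fine-tuning, we check the `requires_grad` flag of each of their parameters (a parameter is frozen when `requires_grad` is `False`):

In [6]: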
def check_frozen_status(model):
    layers_to_check = ["conv1", "conv2"]  # Names of the layers in your model that are pretrained

    for layer_name in layers_to_check:
        layer = getattr(model, layer_name)
        for name, param in layer.named_parameters():
            print(f"Layer: {layer_name}, Parameter: {name}, Frozen: {not param.requires_grad}")

# Check if the pretrained layers are frozen or not
check_frozen_status(model)

Layer: conv1, Parameter: bias, Frozen: False
Layer: conv1, Parameter: nn.mlp.0.weight, Frozen: False
Layer: conv1, Parameter: nn.mlp.0.bias, Frozen: False
Layer: conv1, Parameter: nn.mlp.2.weight, Frozen: False
Layer: conv1, Parameter: nn.mlp.2.bias, Frozen: False
Layer: conv1, Parameter: nn.mlp.4.weight, Frozen: False
Layer: conv1, Parameter: nn.mlp.4.bias, Frozen: False
Layer: conv1, Parameter: lin.weight, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.0.weight, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.0.bias, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.2.weight, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.2.bias, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.4.weight, Frozen: False
Layer: conv1, Parameter: edge_mlp.mlp.4.bias, Frozen: False
Layer: conv2, Parameter: att_src, Frozen: False
Layer: conv2, Parameter: att_dst, Frozen: False
Layer: conv2, Parameter: bias, Frozen: False
Layer: conv2, Parameter: lin_src.weight, Frozen: False
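With method (2), these parameters would instead report `Frozen: True`. The following is a minimal sketch of what freezing typically involves in plain PyTorch, using a hypothetical two-layer toy model rather than `downstream1` (whose `frozen=True` implementation is not shown here): disable gradients on the pretrained layer and give the optimizer only the parameters that are still trainable.

```python
import torch
import torch.nn as nn

# Hypothetical model: one "pretrained" layer plus a newly added head.
toy_model = nn.ModuleDict({
    "conv1": nn.Linear(9, 64),  # stands in for a pretrained layer
    "head": nn.Linear(64, 1),   # new layer, trained from scratch
})

# Method (2): freeze the pretrained layer.
for p in toy_model["conv1"].parameters():
    p.requires_grad_(False)

# Pass only trainable parameters to the optimizer.
trainable = [p for p in toy_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

for name, p in toy_model.named_parameters():
    print(f"{name}: frozen={not p.requires_grad}")

# One training step: the frozen weights must not change.
before = toy_model["conv1"].weight.clone()
x_demo = torch.randn(5, 9)
loss = toy_model["head"](toy_model["conv1"](x_demo)).pow(2).mean()
loss.backward()
optimizer.step()
print(torch.equal(before, toy_model["conv1"].weight))  # True
```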


#### Fine-tuning on Downstream Task

We now load the supervised dataset and confirm that a batch passes through the downstream model.

In [7]:
import sys
import torch
sys.path.append("../src")
from preprocess import create_data_loaders

# Paths
data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/patient_pyg/jh101/supervised/jh101_run1.pt"
data = torch.load(data_path)

loaders, _ = create_data_loaders(data, val_ratio=0.2, test_ratio=0.1, batch_size=32, num_workers=4, model_id="downstream1", train_ratio=0.2)

0.2
Total number of examples in dataset: 1113.
Total number of examples used: 1113.
Number of training examples: 222. Number of training batches: 7.
Number of validation examples: 222. Number of validation batches: 7.
Number of test examples: 111. Number of test batches: 4.


In [8]:
train_loader, val_loader, test_loader = loaders

for batch in train_loader:
    print(f"Model output: {model(batch).size()}")
    break

Model output: torch.Size([32])
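The output has one value per graph because PyG stores a mini-batch of graphs as one large disconnected graph plus a `batch` vector giving each node's graph index, and `global_mean_pool` averages node embeddings per graph. Here is a plain-PyTorch sketch of that pooling step with toy numbers (an illustration, not PyG's implementation), followed by a hypothetical linear head and a squeeze to one score per graph.

```python
import torch

# Toy node embeddings for 3 graphs with 2, 3 and 1 nodes.
node_emb = torch.arange(12, dtype=torch.float32).view(6, 2)
batch_vec = torch.tensor([0, 0, 1, 1, 1, 2])  # graph index of each node
num_graphs = int(batch_vec.max()) + 1

# Sum embeddings per graph, then divide by the node count of each graph.
sums = torch.zeros(num_graphs, node_emb.size(1)).index_add_(0, batch_vec, node_emb)
counts = torch.bincount(batch_vec, minlength=num_graphs).clamp(min=1).unsqueeze(1)
pooled = sums / counts

# Hypothetical final layer: one output per graph, squeezed to shape (num_graphs,).
head = torch.nn.Linear(2, 1)
scores = head(pooled).squeeze(-1)
print(pooled)
print(scores.shape)  # torch.Size([3])
```

Here `pooled` is `[[1, 2], [6, 7], [10, 11]]`, one row per graph; with 32 graphs per batch the same logic yields `torch.Size([32])`.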


#### Automatic Transfer
If you want to do all of the above in one step, see below. Note that this is implemented in `train.py` when you select the `downstream1` or `downstream2` model.

In [10]:
import sys
import torch
sys.path.append("../src")

from preprocess import extract_layers
from models import downstream1, downstream2

# Mac
model_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/relative_positioning.pth"
model_dict_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/models/jh101/model/relative_positioning_state_dict.pth"

# PC
# model_path = r"C:\Users\xmoot\Desktop\Models\ssl-seizure-detection\jh101\model\temporal_shuffling.pth"
# model_dict_path = r"C:\Users\xmoot\Desktop\Models\ssl-seizure-detection\jh101\model\temporal_shuffling_state_dict.pth"

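# NOTE: the Mac paths above point to the relative_positioning checkpoint, while this call
# passes "temporal_shuffling" (matching the commented-out PC paths). Make sure this
# argument matches the checkpoint you actually load.
extracted_layers = extract_layers(model_path, model_dict_path, "temporal_shuffling")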

config1 = {
    "hidden_channels": [64, 64, 32],
    "dropout": 0.1,
}

config2 = {
    "hidden_channels": 32,
    "dropout": 0.1,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model1 = downstream1(config1, extracted_layers, frozen=False).to(device)
model2 = downstream2(config2, extracted_layers, frozen=False).to(device)
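The body of `extract_layers` is not shown here. As a rough idea of what such a helper can look like, here is a hypothetical sketch (`extract_submodules` is not a project function) that looks up submodules by their dotted names with `nn.Module.get_submodule` (PyTorch >= 1.9) and returns deep copies, mirroring the manual extraction from Step 2 on a toy model with the same attribute layout.

```python
import copy

import torch.nn as nn


def extract_submodules(model, names):
    """Hypothetical helper: independent copies of the named submodules."""
    return [copy.deepcopy(model.get_submodule(name)) for name in names]


# Toy model mirroring the attribute layout printed in Step 1 (layer types simplified).
class ToyEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.edge_mlp = nn.Linear(3, 576)
        self.conv1 = nn.Linear(9, 64)
        self.conv2 = nn.Linear(64, 32)


toy_pretrained = nn.Module()
toy_pretrained.embedder = ToyEmbedder()

layers = extract_submodules(
    toy_pretrained, ["embedder.edge_mlp", "embedder.conv1", "embedder.conv2"]
)
print([type(m).__name__ for m in layers])  # ['Linear', 'Linear', 'Linear']
```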

Load the data for supervised learning.

In [13]:
from preprocess import create_data_loaders

# Paths
# Mac
data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/patient_pyg/jh101/supervised/jh101_run1.pt"

# PC
# data_path = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_pyg\jh101\supervised\jh101_run1.pt"


data = torch.load(data_path)
loaders, _ = create_data_loaders(data, val_ratio=0.2, test_ratio=0.1, batch_size=32, num_workers=4, model_id="supervised")
train_loader = loaders[0]

None
Total number of examples in dataset: 1113.
Total number of examples used: 1113.
Number of training examples: 780. Number of training batches: 25.
Number of validation examples: 222. Number of validation batches: 7.
Number of test examples: 111. Number of test batches: 4.
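The split sizes printed above, and those from the earlier run with `train_ratio=0.2`, are consistent with simple truncating arithmetic. The sketch below is a hypothetical reconstruction (`split_sizes` is not part of `preprocess`) that reproduces those counts; the real `create_data_loaders` may shuffle, round, or interpret `train_ratio` differently.

```python
import math


def split_sizes(n, val_ratio, test_ratio, train_ratio=None, batch_size=32):
    """Hypothetical reconstruction of the printed split and batch counts."""
    n_val = int(n * val_ratio)    # 1113 * 0.2 -> 222
    n_test = int(n * test_ratio)  # 1113 * 0.1 -> 111
    # Assumption: without train_ratio, training gets the remaining examples.
    if train_ratio is None:
        n_train = n - n_val - n_test
    else:
        n_train = int(n * train_ratio)
    sizes = {"train": n_train, "val": n_val, "test": n_test}
    # The last batch may be partial, hence the ceiling.
    batches = {k: math.ceil(v / batch_size) for k, v in sizes.items()}
    return sizes, batches


print(split_sizes(1113, 0.2, 0.1))
print(split_sizes(1113, 0.2, 0.1, train_ratio=0.2))
```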


Finally, we can see that both models run end to end: each returns one value per graph in the batch (32 values for a batch of 32 graphs).

In [14]:
for batch in train_loader:
    batch = batch.to(device)
    print(f"Model Output: {model1(batch)}")
    print(f"Model Output: {model2(batch)}")
    break

Model Output: tensor([17.1265, 17.4924, 11.6877, 11.8998, 10.9818,  8.9247,  9.3065, 10.8685,
        10.9933,  5.1047,  9.5081,  8.6494, 11.8649, 11.8538,  6.7729, 15.1422,
        13.8074, 15.0400, 14.5431, 12.5139, 16.1494, 13.2782, 14.7469, 11.2534,
        11.3575, 16.6535, 24.3897, 10.8897, 11.1243, 18.1994, 14.5320, 18.4266],
       grad_fn=<SqueezeBackward1>)
Model Output: tensor([ 8.1952e-01, -2.0821e-01,  3.8645e-01, -1.0718e+00,  2.3025e-01,
         1.7284e-01,  3.4750e-04, -3.7224e-01,  1.8576e-01,  9.4093e-02,
        -4.1354e-01,  9.7826e-02, -1.8611e-01,  4.0696e-02, -4.6752e-01,
         2.5970e-01,  1.5231e-01, -4.4645e-02, -3.8504e-02,  4.7481e-01,
        -9.9378e-02,  1.0567e-01, -1.1266e-01, -9.6616e-03,  1.4535e-01,
        -2.0479e-01, -8.3742e-01,  5.2366e-02, -1.9126e-01,  3.5200e-01,
         7.2762e-01, -1.1013e+00], grad_fn=<SqueezeBackward1>)
