In [7]:
from spatialSSL.Dataloader import FullImageDatasetConstructor
from spatialSSL.Utils import split_dataset
from spatialSSL.Training import train
from spatialSSL.Models import GAT4
from spatialSSL.Training import train_epoch
from spatialSSL.Testing import test
from spatialSSL.Dataset import InMemoryGraphDataset
import torch.nn as nn
import numpy as np

In [2]:
# Fine tune
import torch
import zipfile

# Define a function to load the data from the ZIP file
def load_from_zip(zip_path, file_name):
    with zipfile.ZipFile(zip_path, 'r') as zipf:
        with zipf.open(file_name) as file:
            return torch.load(file)

# Load the pre_val_list for fine tuning
pre_val_list = load_from_zip('./processed_data/pre_training_data_img6_r30_n1_random_01.zip', 'pre_val_list.pt')
pre_train_list = load_from_zip('./processed_data/pre_training_data_img6_r30_n1_random_01.zip', 'pre_train_list.pt')

In [3]:
len(pre_train_list)

4

In [4]:
from sklearn.model_selection import train_test_split

###################### Now for testing
# Split the pre_val_list into 80% for tune_train and 20% for temporary validation/test
tune_train, temp_val_test = train_test_split(pre_train_list, test_size=0.20, random_state=42)

# Split the temporary validation/test into 50% for tune_val and 50% for tune_test
tune_val, tune_test = train_test_split(pre_val_list, test_size=0.50, random_state=42)

# Now, tune_train contains 80% of pre_val_list, tune_val contains 10%, and tune_test contains 10%

In [5]:
# Load into dataloader 

from torch_geometric.loader import DataLoader

# Create DataLoader objects for pre-training and pre-validation
tune_train_loader = DataLoader(tune_train, batch_size=1, shuffle=True)
tune_val_loader = DataLoader(tune_val, batch_size=1, shuffle=False)
tune_test_loader = DataLoader(tune_test, batch_size=1, shuffle=False)
# Now you can use pre_train_loader and pre_val_loader in your training loop


In [8]:
# Load pre-trained models

PRE_TRAINED_MODEL_PATH = "./models/img6_r30_n1_random_01_GAT4_0.001_weight.pt"





# Create Identity class
class Identity(nn.Module):
    def __init__(self):
        super(Identity,self).__init__()
        
    def forward(self,x):
        return x





In [23]:
# Load pre-trained model and add our layers

input_layer = 550
hidden_1 = 256
hidden_2 = 33
out_layer = 550

# Pretraining
# Define the device
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu') #"cpu"

def freeze_except_last(model, num_unfrozen_layers=1):
    # Convert model parameters into a list
    params = list(model.parameters())

    # Freeze all parameters
    for param in params:
        param.requires_grad = False

    # Unfreeze the last 'num_unfrozen_layers' layers
    for param in params[-num_unfrozen_layers:]:
        param.requires_grad = True

model = GAT4(input_layer, hidden_1,hidden_2, out_layer).to(device)

model.load_state_dict(torch.load(PRE_TRAINED_MODEL_PATH))
freeze_except_last(model, num_unfrozen_layers=1) # Unfreeze the last 2 layers

# Training code
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
criterion = nn.MSELoss()



In [24]:

# Define loss function and optimizer
#criterion = nn.MSELoss()
#optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 300
patience = 8

In [25]:
train(model=model, train_loader=tune_train_loader, val_loader=tune_val_loader, criterion=criterion, num_epochs= num_epochs, patience = patience, optimizer= optimizer,weight_loss = False ,model_path = './models/img6_r30_n1_random_01_GAT4_0.001_tuned_r30_n1_random_01_freeze_1_GAT4_0.001.pt')


Epoch 1/300, train loss: 0.4793, train r2: 0.0945, train mse: 0.4787,  val loss: 0.5247, val r2: 0.0874, val mse: 0.5247, Time: 1.4150s
Epoch 2/300, train loss: 0.4763, train r2: 0.0978, train mse: 0.4772,  val loss: 0.5240, val r2: 0.0889, val mse: 0.5244, Time: 0.8651s
Epoch 3/300, train loss: 0.4738, train r2: 0.1032, train mse: 0.4758,  val loss: 0.5229, val r2: 0.0875, val mse: 0.5239, Time: 0.8251s
Epoch 4/300, train loss: 0.4726, train r2: 0.1061, train mse: 0.4749,  val loss: 0.5224, val r2: 0.0897, val mse: 0.5235, Time: 0.8236s
Epoch 5/300, train loss: 0.4718, train r2: 0.1087, train mse: 0.4741,  val loss: 0.5217, val r2: 0.0911, val mse: 0.5231, Time: 0.8551s
Epoch 6/300, train loss: 0.4708, train r2: 0.1110, train mse: 0.4735,  val loss: 0.5210, val r2: 0.0924, val mse: 0.5228, Time: 0.8620s
Epoch 7/300, train loss: 0.4704, train r2: 0.1129, train mse: 0.4729,  val loss: 0.5200, val r2: 0.0936, val mse: 0.5224, Time: 0.8434s
Epoch 8/300, train loss: 0.4700, train r2: 0.114

KeyboardInterrupt: 

In [None]:
pre_trained_model = model.load_state_dict(torch.load(PRE_TRAINED_MODEL_PATH))
