In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Directory used as the notebook location. The notebook is assumed to be
# launched from its own directory, so the current working directory stands in
# for it (this is also what the former abspath("__file__") lookup on a plain
# string resolved to).
notebook_dir = os.path.abspath(os.getcwd())

# NOTE(review): '.' resolves to the notebook directory itself, not to its
# parent. The value is kept as-is so existing imports keep resolving the same
# way; switch to '..' if the project packages really live one level up.
parent_dir = os.path.abspath(os.path.join(notebook_dir, '.'))

# Make the project packages importable. Each path is appended at most once so
# that re-running this cell does not keep growing sys.path with duplicates.
for extra_path in (notebook_dir, parent_dir):
    if extra_path not in sys.path:
        sys.path.append(extra_path)


In [None]:
from functions.loader import getLoader
from functions.display_things import *
from functions.trainFuncs import a_proper_training
from functions.STGCN import STGCN
from TransferLearn_Evaluate_TL_FineTuning_Adapter import *


In [None]:

# Sweep configuration: every subsample value is evaluated for each
# (pretrain seed, finetune seed) pair, i.e. 3 x 3 = 9 runs per subsample.
# `subsample` is forwarded unchanged to do_da_test (presumably the fraction of
# fine-tuning data to use -- confirm against do_da_test).
subsamples = [0.05, 0.2, 0.5, 1]
pretrain_seeds = [42, 43, 44]
finetune_seeds = [47, 48, 51]


for subsample in subsamples:

    # Test loss at the best epoch of every run for this subsample value.
    mse = []
    for pretrain_seed in pretrain_seeds:
        print()
        print("\tpretrained_seed", ":", pretrain_seed)
        for finetune_seed in finetune_seeds:
            best_model, train_losses, val_losses, test_losses, best_epoch = do_da_test(
                pretrain_station="varnamo",
                finetune_station="varberg",
                pretrain_seed=pretrain_seed,
                finetune_seed=finetune_seed,
                epochs=30,
                subsample=subsample,
                verbose=False,
            )
            # Only the test loss at the epoch reported as best is kept.
            print("\t\tfinetune seed,", finetune_seed, ":", test_losses[best_epoch])
            mse.append(test_losses[best_epoch])

    # Aggregate over all seed combinations. np.std is the population standard
    # deviation (ddof=0); numpy is imported explicitly at the top of the
    # notebook instead of being picked up from a star import.
    avg = sum(mse) / len(mse)
    std_div = np.std(mse)

    print("for subsample", subsample, ", average:" , avg, ", std div:", std_div)
    print()
    

# Loading model

In [None]:
class Adapter(nn.Module):
    """Bottleneck adapter applied to the input before a wrapped network.

    The input features of each sample are flattened, passed through a
    two-layer bottleneck (input_size -> input_size // 2 -> input_size) and
    reshaped back before being handed to ``hidden_network``.

    Args:
        input_size: Flattened per-sample size, i.e.
            stations * seq_len * features of the batched input.
        hidden_network: Module called as ``hidden_network(data, inference)``
            on the adapted data object.
    """

    def __init__(self, input_size, hidden_network):
        super(Adapter, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size//2)
        self.fc2 = nn.Linear(input_size//2, input_size)
        self.hidden_network = hidden_network
        self.relu = nn.ReLU()

    def forward(self, data, inference):
        """Adapt ``data.x`` and run the wrapped network on it.

        Args:
            data: Object exposing ``x`` and ``batch``; ``data.x`` is
                overwritten in place with the adapted features.
            inference: Forwarded unchanged to ``hidden_network``.

        Returns:
            Whatever ``hidden_network(data, inference)`` returns.
        """
        x = reshape_to_batches(data.x, data.batch)
        batch_size, stations, seq_len, features = x.shape

        # Flatten everything but the batch dimension for the linear layers.
        # reshape (rather than view) also accepts non-contiguous input.
        x = x.reshape(batch_size, -1)

        # Bottleneck: down-project, non-linearity, up-project.
        x = self.fc2(self.relu(self.fc1(x)))

        # Restore the shape that was actually received. A hard-coded
        # (64, 5, 576, 1) breaks on any batch with a different size, such as
        # the last, smaller batch of an epoch.
        x = x.view(batch_size, stations, seq_len, features)

        data.x = reshape_from_batches(x)

        return self.hidden_network(data, inference)
    

In [None]:
# Restore the checkpoint saved under 'trained_on_varnamo.pth' into the
# already-constructed `model`.
model.load_state_dict(
    torch.load('Transfer Learning/trained_on_varnamo.pth')
)

# Wrap the restored model with the adapter and move everything to the GPU.
# NOTE(review): 2880 must equal the flattened per-sample input size seen by
# the adapter (presumably 5 stations x 576 steps x 1 feature -- confirm).
adapter_network = Adapter(2880, model).cuda()

# Freeze the wrapped network so that only the adapter layers can be trained.
for param in adapter_network.hidden_network.parameters():
    param.requires_grad_(False)

# Explicitly keep both adapter layers trainable.
for adapter_layer in (adapter_network.fc1, adapter_network.fc2):
    for param in adapter_layer.parameters():
        param.requires_grad_(True)

# NOTE(review): this reports on the wrapped `model`, not on `adapter_network`;
# confirm which of the two the count is meant to describe.
count_parameters(model)


In [None]:
# Optimise the parameters of the network that is actually being trained.
# Building the optimiser from `model.parameters()` would hand it only the
# wrapped network's weights, so the adapter layers (fc1 / fc2) of
# `adapter_network` would never be updated. Parameters with
# requires_grad=False are filtered out so no optimiser state is kept for them.
trainable_params = [p for p in adapter_network.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)
criterion = nn.MSELoss()


def lr_lambda(current_step: int, d_model: int, warmup_steps: int) -> float:
    """Noam-style learning-rate factor.

    Returns d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), where
    step is ``current_step + 1`` so that the first call (step 0) does not
    divide by zero.
    """
    current_step += 1
    return (d_model ** (-0.5)) * min((current_step ** (-0.5)), current_step * (warmup_steps ** (-1.5)))


d_model = transformer_hidden_size
# LambdaLR multiplies the optimiser's initial lr by the factor returned here.
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step, d_model, warmup_steps))

best_model, best_epoch, train_losses, val_losses, lrs = a_proper_training(
    epochs, adapter_network, optimizer, criterion, train_loader, val_loader, scheduler
)

# Persist the weights of the best model returned by the training routine.
torch.save(best_model.state_dict(), "trained_on_varnamo-finetuned_on_varberg_Adapter.pth")

plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
#plt.plot(lrs, label="learning rates")

plt.title("MSE Loss")
plt.legend()


In [None]:

# Put the best model into evaluation mode (eval() returns the module itself,
# assuming best_model is an nn.Module) and display its predictions on the
# test loader for `station`.
predictAndDisplay(station, test_loader, best_model.eval())