In [None]:
import sys
import os
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import LambdaLR

# Resolve the directory that local packages are imported from.
# NOTE: "__file__" is a quoted string here, not the module variable, so
# abspath() resolves it against the current working directory and dirname()
# then returns that working directory.
notebook_dir = os.path.dirname(os.path.abspath("__file__"))

# Make modules located in that directory importable.
sys.path.append(notebook_dir)

# NOTE(review): joining with '.' resolves to notebook_dir itself, not to its
# parent, so the same path is appended a second time. If the parent directory
# was intended this should be '..' -- confirm against the repository layout.
parent_dir = os.path.abspath(os.path.join(notebook_dir, '.'))
sys.path.append(parent_dir)


In [None]:
from functions.loader import getLoader
from functions.display_things import *
from functions.trainFuncs import *
from functions.STGCN import *

In [None]:
# --- Experiment configuration -------------------------------------------
# Data selection / loader settings (all forwarded to getLoader).
random_seed = 42
station = "varnamo"
seq_len = 576
future_steps = 36
batch_size = 64

# Optimisation settings.
learning_rate = 0.01
epochs = 60
warmup_steps = int(0.2 * epochs)  # first 20% of the epochs
hyperN_start = 20  # forwarded to a_proper_training as `hyperN_start`


# Data Prep

In [None]:
# Build the three data loaders from the configuration values defined above.
train_loader, val_loader, test_loader = getLoader(
    station=station,
    future_steps=future_steps,
    seq_len=seq_len,
    batch_size=batch_size,
    random_seed=random_seed,
)

# Model Making

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv

# Fully connected directed graph over 5 nodes without self-loops, stored as a
# (2, 20) long tensor whose columns are (source, destination) pairs.
edge_index_single = torch.tensor(
    [[src, dst] for src in range(5) for dst in range(5) if src != dst],
    dtype=torch.long,
).t()

class GCN(nn.Module):
    """Stack of ``GCNConv`` layers with an optional per-sample final layer.

    ``in_conv`` maps ``in_channels`` to ``gcn_hidden_channels`` and
    ``hidden_convs`` holds ``gcn_layers - 1`` further ``GCNConv`` layers of
    constant width. The last entry of ``hidden_convs`` is only applied when
    externally generated ``weights`` are passed to ``forward``.
    """

    def __init__(self, in_channels=1, gcn_hidden_channels=8, gcn_layers=1):
        super(GCN, self).__init__()
        self.in_conv = GCNConv(in_channels, gcn_hidden_channels)
        self.hidden_convs = nn.ModuleList(
            [GCNConv(gcn_hidden_channels, gcn_hidden_channels) for _ in range(gcn_layers - 1)]
        )

    def forward(self, x, edge_index, edge_attr, weights, batch):
        """Run the graph convolutions over ``x``.

        Args:
            x: Node features; cast to float before the first convolution.
            edge_index: Graph connectivity passed to every ``GCNConv``.
            edge_attr: Accepted for interface compatibility; not used here.
            weights: Falsy to skip the final hidden layer, otherwise a sequence
                whose first element yields one weight tensor per batch entry
                for ``hidden_convs[-1].lin``.
            batch: Node-to-graph assignment forwarded to ``reshape_to_batches``.

        Returns:
            The convolved node features. When ``weights`` is given, the result
            of ``reshape_from_batches`` on the per-sample outputs.

        Raises:
            IndexError: If ``weights`` is given but ``hidden_convs`` is empty
                (``gcn_layers == 1``).
        """
        x = x.float()

        x = self.in_conv(x, edge_index)
        # The last hidden layer is deliberately excluded from this loop; it is
        # only used in the generated-weights branch below.
        for conv in self.hidden_convs[:-1]:
            x = F.relu(x)
            x = conv(x, edge_index)
        x = F.relu(x)

        if weights:
            x = reshape_to_batches(x, batch)
            last_conv = self.hidden_convs[-1]
            # Size the buffer from the tensor being iterated rather than from a
            # notebook-level batch size, so a shorter final batch neither
            # leaves uninitialised rows nor overflows the buffer.
            result = torch.empty(x.size(0), x.size(1), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
            for i, (seq, weight) in enumerate(zip(x, weights[0])):
                # NOTE(review): assigning through `.data` bypasses autograd, so
                # no gradient reaches whatever produced `weight`, and the layer
                # keeps the last sample's weight after the loop.
                last_conv.lin.weight.data = weight  # Assuming weight is a tensor of appropriate shape
                result[i] = last_conv(seq, edge_index)
            x = reshape_from_batches(result)
        return x


class SimpleTransformer(nn.Module):
    """Encoder-decoder transformer producing ``seq_length`` output steps.

    The encoder consumes the embedded input sequence under a causal mask. The
    decoder is either teacher-forced on ``tgt`` (training) or run
    autoregressively from ``last_value`` (``inference=True``).
    """

    def __init__(self, input_size, hidden_layer_size, output_size, nhead, seq_length, num_layers=1, dropout=0.1):
        super(SimpleTransformer, self).__init__()

        self.seq_length = seq_length
        self.output_size = output_size
        self.hidden_layer_size = hidden_layer_size

        self.embeddingIn = nn.Linear(input_size, hidden_layer_size)
        self.embeddingTGT = nn.Linear(output_size, hidden_layer_size)

        self.PositionalEncoding = PositionalEncoding(max_len=1000, d_model=hidden_layer_size)

        # Layers are kept in ModuleLists (instead of nn.TransformerEncoder /
        # nn.TransformerDecoder) so individual layers can be addressed.
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=hidden_layer_size, nhead=nhead,
                                       dim_feedforward=4*hidden_layer_size, dropout=dropout,
                                       activation='gelu')
            for _ in range(num_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model=hidden_layer_size, nhead=nhead,
                                       dim_feedforward=4*hidden_layer_size, dropout=dropout,
                                       activation='gelu')
            for _ in range(num_layers)
        ])

        self.linear1 = nn.Linear(hidden_layer_size, output_size)

    def generate_square_subsequent_mask(self, sz):
        """Return an ``sz x sz`` float mask: ``-inf`` above the diagonal, 0 elsewhere."""
        mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
        return mask

    def _run_decoder(self, tgt_emb, memory, tgt_mask=None):
        """Feed ``tgt_emb`` through every decoder layer in order, chaining outputs."""
        decoder_output = tgt_emb
        for layer in self.decoder_layers:
            decoder_output = layer(decoder_output, memory, tgt_mask=tgt_mask)
        return decoder_output

    def forward(self, x, tgt=None, last_value=None, inference=False, weights=None):
        """Encode ``x`` and decode a prediction sequence.

        Args:
            x: Input of shape ``(batch, src_len, input_size)``.
            tgt: Teacher-forcing targets ``(batch, tgt_len, output_size)``.
                Required when ``inference`` is False, ignored otherwise.
            last_value: 2-D tensor (presumably ``(batch, 1)`` -- TODO confirm)
                holding the most recent observation; it becomes the first
                decoder input.
            inference: If True, decode autoregressively for ``seq_length`` steps.
            weights: Falsy to skip the final encoder layer; otherwise a
                sequence whose second element has one entry per batch sample.

        Returns:
            ``(batch, tgt_len, output_size)`` in training mode, or
            ``(batch, seq_length, output_size)`` in inference mode.

        Raises:
            ValueError: If ``inference`` is False and ``tgt`` is None.
        """
        last_value = torch.unsqueeze(last_value, dim=2)

        x = self.embeddingIn(x)
        x = self.PositionalEncoding(x)
        # The mask is applied to `x`, so place it on x's device; this also
        # keeps `tgt` optional during inference.
        enc_mask = self.generate_square_subsequent_mask(x.size(1)).to(x.device)
        x = x.permute(1, 0, 2)  # -> (src_len, batch, hidden)

        # All encoder layers except the last; the last one is only used in the
        # generated-weights branch below.
        for layer in self.encoder_layers[:-1]:
            x = layer(x, src_mask=enc_mask)

        if weights:
            # Run the final encoder layer sample by sample. Stacking on dim 1
            # derives the batch size from the data instead of a notebook-level
            # constant, so a shorter final batch is handled correctly.
            # NOTE(review): the paired generated weight is currently never
            # applied to the layer -- confirm the intended parameter.
            per_sample = [
                self.encoder_layers[-1](seq, src_mask=enc_mask)
                for seq, _weight in zip(x.permute(1, 0, 2), weights[1])
            ]
            memory = torch.stack(per_sample, dim=1)  # (src_len, batch, hidden)
        else:
            memory = x

        if inference:
            tgt_gen = last_value
            generated_sequence = torch.zeros((last_value.size(0), self.seq_length, self.output_size), device=x.device)

            for i in range(self.seq_length):
                tgt_emb = self.embeddingTGT(tgt_gen)
                tgt_emb = self.PositionalEncoding(tgt_emb)
                tgt_emb = tgt_emb.permute(1, 0, 2)

                decoder_output = self._run_decoder(tgt_emb, memory)

                # Only the newest position is turned into the next prediction.
                output_step = self.linear1(decoder_output[-1, :, :])
                output_step = output_step.unsqueeze(1)

                generated_sequence[:, i:i+1, :] = output_step

                tgt_gen = torch.cat((tgt_gen, output_step), dim=1)

                # Keep the decoder input at most seq_length steps long.
                if tgt_gen.size(1) > self.seq_length:
                    tgt_gen = tgt_gen[:, 1:, :]

            return generated_sequence

        if tgt is None:
            raise ValueError("tgt is required when inference=False")

        # Shift right: the decoder sees the last observation followed by all
        # targets except the final one.
        tgt_input = torch.cat([last_value, tgt[:, :-1]], dim=1)
        tgt_emb = self.embeddingTGT(tgt_input)
        tgt_emb = self.PositionalEncoding(tgt_emb)
        tgt_emb = tgt_emb.permute(1, 0, 2)

        tgt_mask = self.generate_square_subsequent_mask(tgt_emb.size(0)).to(tgt_emb.device)

        # NOTE(review): every decoder layer is now applied and chained. The
        # previous loop skipped the last layer and fed each layer the raw
        # embedding, so checkpoints trained that way hold an untrained final
        # decoder layer and need retraining.
        decoder_output = self._run_decoder(tgt_emb, memory, tgt_mask=tgt_mask)

        output = self.linear1(decoder_output)

        return output.permute(1, 0, 2)


class STGCN(nn.Module):
    """Graph convolution over the station graph followed by a shared transformer.

    ``forward`` first runs ``GCN`` on the node features, regroups the result
    per batch entry and then applies the same ``SimpleTransformer`` to each
    slice along dimension 1 (presumably one slice per station -- TODO confirm
    against ``reshape_to_batches``).

    Args:
        in_channels: Input feature width of the first GCN layer.
        gcn_layers: Total number of GCN layers.
        hidden_channels: GCN width and transformer input size.
        transformer_hidden_size: Transformer model dimension.
        transformer_num_layers: Number of encoder and decoder layers.
        transformer_nhead: Number of attention heads.
        out_channels: Transformer output size.
        seq_length: Number of steps generated in inference mode. Defaults to
            36, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, gcn_layers, hidden_channels, transformer_hidden_size, transformer_num_layers, transformer_nhead, out_channels, seq_length=36):
        super(STGCN, self).__init__()
        print(f"\033[100mhidden_channels: {hidden_channels}   GCN hidden layers: {gcn_layers}   "
              f"transformer_hidden_size: {transformer_hidden_size}   transformer_num_layers: {transformer_num_layers}   "
              f"transformer_nhead: {transformer_nhead}\033[0m")

        self.GCN = GCN(in_channels=in_channels, gcn_hidden_channels=hidden_channels, gcn_layers=gcn_layers)

        # NOTE(review): only the transformer is moved to the GPU here; callers
        # in this notebook additionally call `.cuda()` on the whole model.
        self.transformer = SimpleTransformer(
            input_size=hidden_channels, hidden_layer_size=transformer_hidden_size,
            output_size=out_channels, seq_length=seq_length, num_layers=transformer_num_layers,
            nhead=transformer_nhead
        ).cuda()

    def forward(self, data, inference=False, weights=None):
        """Predict for every slice along dimension 1 of the batched features.

        Args:
            data: Batch object exposing ``x``, ``y``, ``edge_index``,
                ``edge_attr`` and ``batch``. ``data.x`` and ``data.edge_attr``
                are replaced in place by their float versions.
            inference: Forwarded to the transformer (autoregressive decoding).
            weights: Optional generated weights forwarded to both the GCN and
                the transformer.

        Returns:
            The per-slice transformer outputs stacked along dimension 1.
        """
        batch = data.batch
        label = data.y
        label = torch.squeeze(label, 2)

        data.x = data.x.float()
        data.edge_attr = data.edge_attr.float()
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Spatial processing
        x = self.GCN(x, edge_index, edge_attr, weights=weights, batch=batch)
        x = reshape_to_batches(x, batch)
        # Final element along dimension 1 of the raw input, used as the first
        # decoder input.
        last_value = reshape_to_batches(data.x[:, -1, :], batch)
        label = reshape_to_batches(label, batch)

        # Swap the first two dimensions so iteration yields one slice per
        # entry of dimension 1, each covering the whole batch.
        predictions = []

        for station_data, station_label, station_last_value in zip(x.permute(1, 0, 2, 3), label.permute(1, 0, 2, 3), last_value.permute(1, 0, 2)):
            output = self.transformer(station_data, station_label, station_last_value, inference, weights=weights)
            predictions.append(output)

        # Concatenate predictions for all stations
        predictions = torch.stack(predictions, dim=1)
        return predictions


class WeightGenerator(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear).

    Maps ``in_channels`` input features to a flat vector of ``out_channels``
    values through one hidden layer of width ``hidden_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_channels, out_features=hidden_channels)
        self.fc2 = nn.Linear(in_features=hidden_channels, out_features=out_channels)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class FusionNetwork(nn.Module):
    """STGCN whose last-layer weights can be produced by a ``WeightGenerator``.

    While ``hyper_on_state`` is False the wrapped STGCN runs on its own. Once
    switched on via ``hyper_on(True)``, every forward pass first generates
    per-graph weights from the raw node features and hands them to the STGCN.
    """

    def __init__(self, stgcn_params, weight_gen_params):
        super(FusionNetwork, self).__init__()
        self.stgcn = STGCN(**stgcn_params).cuda()
        self.weight_generator = WeightGenerator(**weight_gen_params).cuda()
        self.hyper_on_state = False

    def hyper_on(self, state):
        """Enable (True) or disable (False) the generated-weights path."""
        self.hyper_on_state = state

    def forward(self, data, inference=False):
        """Run the STGCN, optionally with generated last-layer weights.

        Args:
            data: Batch object exposing at least ``x``; forwarded to the STGCN.
            inference: Forwarded to the STGCN.

        Returns:
            The STGCN predictions.
        """
        if not self.hyper_on_state:
            return self.stgcn(data, inference)

        # One flat feature row per node. `reshape` also accepts non-contiguous
        # input, and the float cast mirrors the one STGCN.forward applies to
        # data.x, so the generator's float32 Linear layers accept the input.
        x_flat = data.x.reshape(data.x.size(0), -1).float()
        generated_weights = self.weight_generator(x_flat)

        # The first slice overwrites hidden_convs[-1].lin.weight inside GCN,
        # so take its size and shape from that parameter instead of
        # hard-coding (4, 4).
        gcn_weight_ref = self.stgcn.GCN.hidden_convs[-1].lin.weight
        gcn_numel = gcn_weight_ref.numel()

        # Rows are grouped in blocks of 5 consecutive nodes (presumably the 5
        # stations of one graph -- TODO confirm) and summed into one weight
        # set per block.
        nodes_per_graph = 5
        num_graphs = generated_weights.shape[0] // nodes_per_graph

        gcn_last_layer_weight = generated_weights[:, :gcn_numel].reshape(
            (num_graphs, nodes_per_graph, *gcn_weight_ref.shape)
        ).sum(dim=1)

        # NOTE(review): (36, 12) stays hard-coded; it matches 3 * 12 x 12 for
        # transformer_hidden_size == 12, but the consuming layer is not
        # identifiable from this code -- confirm before generalising.
        transformer_last_layer_weight = generated_weights[:, gcn_numel:].reshape(
            (num_graphs, nodes_per_graph, 36, 12)
        ).sum(dim=1)

        # Apply STGCN with the generated weights
        return self.stgcn(data, inference, weights=[gcn_last_layer_weight, transformer_last_layer_weight])

# Training

In [None]:

# Example usage:
# Keyword arguments for the STGCN backbone inside FusionNetwork.
stgcn_params = dict(
    in_channels=1,
    gcn_layers=3,
    hidden_channels=4,
    transformer_hidden_size=12,
    transformer_num_layers=2,
    transformer_nhead=2,
    out_channels=1,
)

# Keyword arguments for the WeightGenerator. `out_channels` evaluates to
# 448 = 16 + 432, i.e. the two slices FusionNetwork.forward reshapes to
# (4, 4) and (36, 12).
weight_gen_params = dict(
    in_channels=1 * 576,  # Adjust based on input dimensions
    hidden_channels=64,
    out_channels=4 * 4 + 2 * 12 * 12 + 150 - 6,  # Adjust based on the weight dimensions
)

model = FusionNetwork(stgcn_params, weight_gen_params).cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Noam-style learning-rate factor: linear warm-up followed by inverse-sqrt decay.
def lr_lambda(current_step: int, d_model: int, warmup_steps: int) -> float:
    """Return the multiplicative learning-rate factor for a scheduler step.

    Computes ``d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)``
    with ``step = current_step + 1`` so the first call (step 0) is defined.

    Args:
        current_step: Zero-based scheduler step.
        d_model: Model dimension used for the overall scale.
        warmup_steps: Number of warm-up steps. Values below 1 are treated as 1
            (no warm-up), because ``0 ** -1.5`` raises ``ZeroDivisionError``.

    Returns:
        The factor the scheduler multiplies the base learning rate by.
    """
    step = current_step + 1
    warmup = max(1, warmup_steps)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

# LambdaLR multiplies the optimizer's initial learning rate by the factor that
# lr_lambda returns for the current scheduler step.
d_model = stgcn_params['transformer_hidden_size']
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step, d_model, warmup_steps))    

# Now pass the scheduler to the training function
# NOTE(review): how often a_proper_training steps the scheduler, and how it
# uses hyperN_start, is defined in functions.trainFuncs and not visible here.
best_model, best_epoch, train_losses, val_losses, lrs = a_proper_training(
    epochs, model, optimizer, criterion, train_loader, val_loader, scheduler, hyperN_start=hyperN_start
)

# Persist the weights of the returned best model.
# NOTE(review): the file name hard-codes "varnamo" although `station` is configurable.
torch.save(best_model.state_dict(), "HyperNetwork_pretrained_on_varnamo.pth")

# Plot the returned training and validation loss series.
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
#plt.plot(lrs, label="learning rates")

plt.title("MSE Loss")
plt.legend()


In [None]:
# Disabled Optuna hyper-parameter search, kept as a bare string literal so that
# none of it executes when the cell runs.
# NOTE(review): if this is re-enabled, the checkpoint is written with a '.png'
# suffix although it is a torch.save of a state_dict, and the broad
# `except Exception` turns every failure into an `inf` objective value --
# both look unintended; confirm before reuse.
"""
import optuna
from optuna.trial import TrialState
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
import time


NUM_EPOCHS = 50

def objective(trial):
    print("\033[41m--------"+str(trial.number)+"-----------------------------------------------------------------------------\033[0m")
    try:
        # Suggest hyperparameters with even values
        hidden_channels = trial.suggest_int('hidden_channels', 2, 14, step=2)
        gcn_layers = trial.suggest_int('gcn_layers', 1, 4)
        transformer_num_layers = trial.suggest_int('transformer_num_layers', 1, 6)
        transformer_nhead = trial.suggest_int('transformer_nhead', 1, 6)
        factor = trial.suggest_int('factor', 2, 12, step=2)
        transformer_hidden_size = transformer_nhead * factor
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True) 
        
        print(hidden_channels, gcn_layers, transformer_num_layers, transformer_nhead, factor, transformer_hidden_size, learning_rate)

        model = STGCN(in_channels=1,
                      gcn_layers=gcn_layers,
                      hidden_channels=hidden_channels, 
                      transformer_hidden_size=transformer_hidden_size, 
                      transformer_num_layers=transformer_num_layers,
                      transformer_nhead=transformer_nhead,
                      out_channels=1).cuda()

        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.MSELoss()

        # Define the lambda function for scheduling with Noam-style learning rate decay
        def lr_lambda(current_step: int, d_model: int, warmup_steps: int) -> float:
            current_step+=1
            return (d_model ** (-0.5)) * min((current_step ** (-0.5)), current_step * (warmup_steps ** (-1.5)))

        warmup_steps = NUM_EPOCHS // 3#int(NUM_EPOCHS * 0.3)
        d_model = transformer_hidden_size
        scheduler = LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step, d_model, warmup_steps))    

        best_loss = float('inf')
        patience = 20  # Number of epochs to wait for improvement before stopping
        patience_counter = 0  # Counter for epochs without improvement
        best_model = None
        train_losses = list()
        val_losses = list()
        
        for epoch in range(NUM_EPOCHS):
            train_loss = train_epoch(epoch, optimizer, criterion, model, train_loader)
            val_loss = validate_epoch(epoch, criterion, model, val_loader)
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            scheduler.step()

            if val_loss < best_loss:
                best_loss = val_loss
                best_model = copy.deepcopy(model)
                patience_counter = 0  # Reset counter if improvement is observed
            else:
                patience_counter += 1  # Increment counter if no improvement

            if patience_counter >= patience:
                print(f"\033[34mStopping early at epoch {epoch} due to no improvement in validation loss.\033[0m")
                break  # Exit the loop if the model hasn't improved for 'patience' epochs
        
        plt.plot(train_losses, label="train")
        plt.plot(val_losses, label="val")
        plt.title("MSE Loss, lr=" + str(learning_rate))
        plt.legend()
        torch.save(best_model.state_dict(), f'trained_on_varnamo{trial.number}.png')
                
        print()
        return best_loss
    except Exception as e:
        print(e)
        print()
        return float('inf')

# Optimize hyperparameters
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=30)  # Define the number of trials

print("Best trial:")
trial = study.best_trial

print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
"""