In [6]:
%cd /mnt/home/tnguyen/projects/florah/florah-tree
%load_ext autoreload
%autoreload 2

import os
import h5py
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import torch_geometric
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from ml_collections import config_dict

from models import training_utils, models, models_utils, flows_utils

/mnt/home/tnguyen/projects/florah/florah-tree
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Location of the pickled debug dataset
dset_path = '/mnt/ceph/users/tnguyen/florah/datasets/experiments/GUREFT05-Nanc1.debug.pkl'

# Read the pickled list of networkx graphs from disk
with open(dset_path, 'rb') as dset_file:
    data = pickle.load(dset_file)

# Turn every networkx graph into a PyTorch Geometric graph
data = list(map(from_networkx, data))

def prepare_dataloader(data, train_frac=0.8, batch_size=1024, num_workers=1):
    """ Split a list of graphs into normalized train/validation loaders.

    The graphs are randomly shuffled, split into a training and a validation
    set, and their node features ``x`` are standardized with the mean and
    standard deviation computed on the training set only.

    The input list and the graphs it contains are NOT modified: the function
    works on clones. This makes it safe to call more than once on the same
    ``data`` (normalizing in place would re-normalize already-normalized
    features on the second call and return meaningless statistics).

    Parameters
    ----------
    data : list
        List of PyTorch Geometric graphs, each with a node feature tensor ``x``.
    train_frac : float, optional
        Fraction of the graphs used for training. Default: 0.8
    batch_size : int, optional
        Batch size of both loaders. Default: 1024
    num_workers : int, optional
        Number of worker processes of both loaders. Default: 1

    Returns
    -------
    train_loader : DataLoader
        Shuffled loader over the training graphs.
    val_loader : DataLoader
        Non-shuffled loader over the validation graphs.
    norm_dict : dict
        Dictionary with keys ``"x_mean"`` and ``"x_std"`` holding the
        per-feature statistics that were used for the normalization.

    Raises
    ------
    ValueError
        If the training split is empty.
    """
    num_total = len(data)
    num_train = int(num_total * train_frac)
    if num_train == 0:
        raise ValueError(
            "Training split is empty: got {} graphs with train_frac={}".format(
                num_total, train_frac))

    # shuffle through an index permutation and clone each graph, so that
    # neither the order nor the features of the caller's list are touched
    order = np.random.permutation(num_total)
    data = [data[i].clone() for i in order]

    # calculate the normalization statistics on the training set only
    x = torch.cat([d.x for d in data[:num_train]])
    x_mean = x.mean(dim=0)
    x_std = x.std(dim=0)
    # constant features (std == 0) or a single training row (std == NaN) would
    # otherwise turn the normalized features into inf/NaN
    x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))
    norm_dict = {
        "x_mean": list(x_mean.numpy()),
        "x_std": list(x_std.numpy()),
    }
    for d in data:
        d.x = (d.x - x_mean) / x_std

    train_loader = DataLoader(
        data[:num_train], batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(
        data[num_train:], batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, norm_dict

  data[key] = torch.tensor(value)


### Debugging stuff

In [8]:
# Stand-in configuration used only to debug the individual modules below.
# Nested dictionaries are converted to ConfigDict instances on construction.
config = config_dict.ConfigDict({
    'input_size': 3,
    'd_time': 1,
    'd_time_projection': 64,
    'd_feat_projection': 64,
    'sum_features': False,
    # featurizer parameters
    'featurizer': {
        'd_model': 64,
        'nhead': 4,
        'num_encoder_layers': 2,
        'dim_feedforward': 128,
        'use_embedding': True,
    },
    # rnn parameters
    'rnn': {
        'output_size': 64,
        'hidden_size': 64,
        'num_layers': 2,
    },
    # flows parameters
    'flows': {
        'hidden_size': 64,
        'num_blocks': 2,
        'num_layers': 2,
    },
})

# The featurizer and rnn sub-configs only hold constructor keywords,
# so they can be forwarded as they are.
featurizer = models.TransformerFeaturizer(
    input_size=config.input_size, **config.featurizer.to_dict())
rnn = models.GRUModel(
    input_size=config.d_feat_projection + config.d_time_projection,
    **config.rnn.to_dict())

# The flow models the non-time features and is conditioned on the
# concatenated rnn output and featurizer output.
flows = flows_utils.build_maf(
    features=config.input_size - config.d_time,
    hidden_features=config.flows.hidden_size,
    context_features=config.rnn.output_size + config.featurizer.d_model,
    num_layers=config.flows.num_layers,
    num_blocks=config.flows.num_blocks,
)

# Linear projections of the time and of the remaining features
time_projection_layer = nn.Linear(config.d_time, config.d_time_projection)
feat_projection_layer = nn.Linear(
    config.input_size - config.d_time, config.d_feat_projection)

In [9]:
train_loader, val_loader, norm_dict = prepare_dataloader(data, batch_size=32)
batch = next(iter(train_loader))

# Convert the graph batch into padded sequences; the first four entries
# are the inputs, their lengths, the targets and the target lengths
batch = training_utils.prepare_batch(batch, num_samples_per_graph=1)
padded_features, lengths, padded_out_features, out_lengths = (
    batch[i] for i in range(4))

# Split the targets into time (read from the first step only, last
# `d_time` columns) and features (all steps, remaining columns)
t_out = padded_out_features[:, 0, -config.d_time:]
f_out = padded_out_features[:, :, :-config.d_time]

# Prepend an all-zero starting token along the sequence axis, then split
# into the shifted RNN input and the target used for the flow loss
padded_rnn_features = nn.functional.pad(f_out, (0, 0, 1, 0), value=0)
padded_rnn_input = padded_rnn_features[:, :-1, :]
padded_rnn_output = padded_rnn_features[:, 1:, :]

# padded_features has shape [batch_size, seq_len, feature_size]
batch_size, seq_len, _ = padded_features.shape
out_seq_len = padded_out_features.shape[1]

# Padding masks built from the sequence lengths; with batch_first=True
# they are indexed as [batch_size, seq_len]
transformer_padding_mask = training_utils.create_padding_mask(
    lengths, seq_len, batch_first=True)
rnn_padding_mask = training_utils.create_padding_mask(
    out_lengths, out_seq_len, batch_first=True)

In [107]:
# Encode the sequences and zero out the padded positions so that they do
# not contribute to the sum over the sequence axis
x = featurizer(
    padded_features, src_key_padding_mask=transformer_padding_mask
).masked_fill(transformer_padding_mask[..., None], 0)
x2 = torch.sum(x, dim=1)

In [124]:
featurizer.eval()
x = featurizer(padded_features, src_key_padding_mask=transformer_padding_mask)

if config.sum_features:
    # zero out every padded position
    x = x.masked_fill(transformer_padding_mask.unsqueeze(-1), 0)
else:
    # count the non-padded positions per sequence and pick the output
    # at the last of them
    lengths = (transformer_padding_mask == 0).sum(dim=1)
    batch_index = torch.arange(batch_size, device=x.device)
    x = x[batch_index, lengths - 1]

In [36]:
   # NOTE(review): pasted fragment, not runnable as a cell on its own -- the
   # two statements are indented inconsistently and `original_lengths` is
   # not defined anywhere in the visible notebook.
   # Sum the first `length` steps of each sequence in the batch
   summed_features = torch.stack([torch.sum(seq[:length], dim=0) 
                                   for seq, length in zip(x, original_lengths)])

    # Select the feature at step `length - 1` for each sequence in the batch
    last_features = torch.stack([seq[length - 1] 
                                 for seq, length in zip(x, original_lengths)])

tensor([[ 1.0466,  1.6705, -0.3154,  ...,  0.1360, -0.3569,  0.6613],
        [ 1.0263,  1.7973, -0.5968,  ..., -0.0718, -0.2929,  0.8572],
        [ 0.9830,  1.8787, -0.8677,  ..., -0.2584, -0.2375,  1.0461],
        ...,
        [ 1.0903,  1.9821, -0.5649,  ..., -1.1361,  0.5769,  0.2221],
        [ 1.0754,  1.9864, -0.5998,  ..., -1.2472,  0.6409,  0.2055],
        [ 1.0594,  1.9078, -0.2549,  ..., -1.3134,  0.7642, -0.0770]],
       grad_fn=<IndexBackward0>)

In [None]:
        # NOTE(review): pasted fragment, indented as if inside a method.
        # masked_select keeps the entries where the mask is True and returns
        # them as a flattened 1-D tensor; the mask printed below is True on
        # padded positions, so this presumably keeps the padding rather than
        # the valid steps -- confirm before reusing.
        x = x.masked_select(
            transformer_padding_mask.unsqueeze(-1).repeat(1, 1, x.size(-1)))

In [15]:
# Display the padding mask: one row per sequence, True on padded positions
transformer_padding_mask

tensor([[False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False,  True],
        [False, False, False, False, False, False, False, False, False, False,
          True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True,
          True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True,
          True,  True,  True],
        [F

torch.Size([29, 13, 64])

In [10]:
# run the feature extractor and pool its output over the sequence axis
# (the output is recomputed here so that the cell does not depend on which
# of the scratch cells above was executed last)
x = featurizer(padded_features, src_key_padding_mask=transformer_padding_mask)
if config.sum_features:
    # padded positions must not contribute to the sum
    x = x.masked_fill(transformer_padding_mask.unsqueeze(-1), 0).sum(dim=1)
else:
    # x[:, -1] would read a padded position for every sequence shorter than
    # the longest one; take the last non-padded step of each sequence instead
    valid_lengths = transformer_padding_mask.eq(0).sum(dim=1)
    x = x[torch.arange(x.size(0), device=x.device), valid_lengths - 1]


In [62]:

# Linear projections of the time and of the shifted RNN input features
t_proj = time_projection_layer(t_out)
f_proj = feat_projection_layer(padded_rnn_input)

# Broadcast the time projection over every output step, concatenate it
# with the projected features and run the RNN
t_proj_seq = t_proj.unsqueeze(1).repeat(1, out_seq_len, 1)
x_rnn = torch.cat([f_proj, t_proj_seq], dim=-1)
x_rnn, hout = rnn(x_rnn, out_lengths, return_hidden_states=True)

# The flow context is the RNN output joined with the featurizer summary
# repeated over the output steps; only non-padded steps are kept
x_seq = x.unsqueeze(1).repeat(1, out_seq_len, 1)
rnn_valid_mask = ~rnn_padding_mask
flow_context = torch.cat([x_rnn, x_seq], dim=-1)[rnn_valid_mask]
x_flow = f_out[rnn_valid_mask]

### Make the class

In [98]:
class SequenceRegressor(pl.LightningModule):
    """
    A PyTorch Lightning module that combines a Transformer-based feature extractor,
    an RNN, and a normalizing flow to regress a sequence of features.

    Attributes
    ----------
    input_size : int
        The size of the input (time dimensions included).
    d_feat_projection : int
        The output dimension of the feature projection layer.
    d_time : int
        The number of time dimensions, stored in the last columns of the input.
    d_time_projection : int
        The output dimension of the time projection layer.
    sum_features : bool
        Whether to sum the features of the featurizer in the time dimension.
        If False, the output at the last non-padded step is used.
    num_samples_per_graph : int
        The number of samples per graph.
    batch_first : bool
        Whether the input is batch first. Always True.
    featurizer : nn.Module
        The featurizer.
    rnn : nn.Module
        The RNN.
    flows : nn.Module
        The normalizing flow.
    time_proj_layer : nn.Linear
        The time projection layer.
    feat_proj_layer : nn.Linear
        The feature projection layer.
    optimizer_args : dict
        Arguments for the optimizer.
    scheduler_args : dict
        Arguments for the scheduler.
    norm_dict : dict
        The normalization dictionary. For bookkeeping purposes only.
    """
    def __init__(
        self,
        input_size,
        featurizer_args,
        rnn_args,
        flows_args,
        optimizer_args=None,
        scheduler_args=None,
        sum_features=False,
        d_time = 1,
        d_time_projection = 128,
        d_feat_projection = 128,
        num_samples_per_graph=1,
        norm_dict=None,
    ):
        """
        Parameters
        ----------
        input_size : int
            The size of the input (time dimensions included)
        featurizer_args : dict
            Arguments for the featurizer
        rnn_args : dict
            Arguments for the RNN
        flows_args : dict
            Arguments for the normalizing flow
        optimizer_args : dict, optional
            Arguments for the optimizer. Default: None
        scheduler_args : dict, optional
            Arguments for the scheduler. If None or without a ``name`` entry,
            no scheduler is used. Default: None
        sum_features : bool, optional
            Whether to sum the features of the featurizer in the time dimension.
            Default: False
        d_time : int, optional
            The number of time dimensions in the input. Default: 1
        d_time_projection : int, optional
            The output dimension of the time projection layer. Default: 128
        d_feat_projection : int, optional
            The output dimension of the feature projection layer. Default: 128
        num_samples_per_graph : int, optional
            The number of samples per graph. Default: 1
        norm_dict : dict, optional
            The normalization dictionary. For bookkeeping purposes only.
            Default: None
        """
        super().__init__()
        self.input_size = input_size
        self.featurizer_args = featurizer_args
        self.rnn_args = rnn_args
        self.flows_args = flows_args
        self.optimizer_args = optimizer_args or {}
        self.scheduler_args = scheduler_args or {}
        self.sum_features = sum_features
        self.d_time = d_time
        self.d_time_projection = d_time_projection
        self.d_feat_projection = d_feat_projection
        self.num_samples_per_graph = num_samples_per_graph
        self.norm_dict = norm_dict
        self.batch_first = True # always True
        self.save_hyperparameters()

        self._setup_model()

    def _setup_model(self):
        """ Create the featurizer, the RNN, the flow and the projection layers.

        Raises
        ------
        ValueError
            If the featurizer or the RNN name is not supported.
        """

        # create the featurizer
        if self.featurizer_args.name == 'transformer':
            activation_fn = models_utils.get_activation(
                self.featurizer_args.activation)
            self.featurizer = models.TransformerFeaturizer(
                input_size=self.input_size,
                d_model=self.featurizer_args.d_model,
                nhead=self.featurizer_args.nhead,
                num_encoder_layers=self.featurizer_args.num_encoder_layers,
                dim_feedforward=self.featurizer_args.dim_feedforward,
                batch_first=self.batch_first,
                use_embedding=self.featurizer_args.use_embedding,
                activation_fn=activation_fn,
            )
        else:
            raise ValueError(
                f'Featurizer {self.featurizer_args.name} not supported')

        # create the rnn
        if self.rnn_args.name == 'gru':
            activation_fn = models_utils.get_activation(
                self.rnn_args.activation)
            self.rnn = models.GRUModel(
                # the RNN input is the concatenation of both projections; use
                # the module's own sizes, not a global config object
                input_size=self.d_feat_projection + self.d_time_projection,
                output_size=self.rnn_args.output_size,
                hidden_size=self.rnn_args.hidden_size,
                num_layers=self.rnn_args.num_layers,
                activation_fn=activation_fn,
            )
        else:
            raise ValueError(
                f'RNN {self.rnn_args.name} not supported')

        # create the flows; they model the non-time features and are
        # conditioned on the concatenated RNN and featurizer outputs
        self.flows = flows_utils.build_maf(
            features=self.input_size - self.d_time,
            hidden_features=self.flows_args.hidden_size,
            context_features=self.rnn_args.output_size + self.featurizer_args.d_model,
            num_layers=self.flows_args.num_layers,
            num_blocks=self.flows_args.num_blocks,
        )

        # create the projection layers
        self.time_proj_layer = nn.Linear(self.d_time, self.d_time_projection)
        self.feat_proj_layer = nn.Linear(
            self.input_size - self.d_time, self.d_feat_projection)

    def _prepare_batch(self, batch):
        """ Prepare the batch for training.

        Converts a graph batch into padded sequences, builds the shifted RNN
        input/output and the padding masks, and moves the tensors to the
        device of the model.

        Returns
        -------
        dict
            Dictionary with the tensors ``padded_features``,
            ``padded_rnn_input``, ``padded_rnn_output``,
            ``transformer_padding_mask``, ``rnn_padding_mask`` and ``t_out``,
            and the integers ``batch_size``, ``seq_len`` and ``out_seq_len``.
        """
        batch = training_utils.prepare_batch(
            batch, num_samples_per_graph=self.num_samples_per_graph)
        padded_features = batch[0]
        lengths = batch[1]
        padded_out_features = batch[2]
        out_lengths = batch[3]

        # Separate the time and feature dimensions of the output; the time is
        # read from the first step only
        t_out = padded_out_features[:, 0, -self.d_time:]
        f_out = padded_out_features[:, :, :-self.d_time]

        # add a starting token of all zeros to the first time step of padded_out_features
        # this is the input to the RNN
        padded_rnn_features = nn.functional.pad(f_out, (0, 0, 1, 0), value=0)

        # divide the rnn into the input and output component
        # the input will be fed into the RNN,
        # while the output will be used for the flow loss
        padded_rnn_input = padded_rnn_features[:, :-1, :]
        padded_rnn_output = padded_rnn_features[:, 1:, :]

        # padded_features is the input to the transformer
        # with shape [batch_size, seq_len, feature_size]
        batch_size, seq_len, _ = padded_features.size()
        out_seq_len = padded_out_features.size(1)

        # Create the padding masks from the sequence lengths; with
        # batch_first=True they are indexed as [batch_size, seq_len]
        transformer_padding_mask = training_utils.create_padding_mask(
            lengths, seq_len, batch_first=True)
        rnn_padding_mask = training_utils.create_padding_mask(
            out_lengths, out_seq_len, batch_first=True)

        # Move to the same device as the model
        padded_features = padded_features.to(self.device)
        padded_rnn_input = padded_rnn_input.to(self.device)
        padded_rnn_output = padded_rnn_output.to(self.device)
        transformer_padding_mask = transformer_padding_mask.to(self.device)
        rnn_padding_mask = rnn_padding_mask.to(self.device)
        t_out = t_out.to(self.device)

        # return a dictionary of the inputs
        return_dict = {
            'padded_features': padded_features,
            'padded_rnn_input': padded_rnn_input,
            'padded_rnn_output': padded_rnn_output,
            'transformer_padding_mask': transformer_padding_mask,
            'rnn_padding_mask': rnn_padding_mask,
            't_out': t_out,
            'batch_size': batch_size,
            'seq_len': seq_len,
            'out_seq_len': out_seq_len,
        }
        return return_dict

    def _pool_features(self, x, padding_mask=None):
        """ Reduce the featurizer output over the sequence dimension.

        Padded positions are excluded: the sum ignores them, and the "last"
        step is the last non-padded step of each sequence rather than the
        last column of the padded tensor.

        Parameters
        ----------
        x : torch.Tensor
            Featurizer output, indexed as [batch, step, feature].
        padding_mask : torch.Tensor, optional
            Boolean mask indexed as [batch, step], True on padded positions.
            If None, every position is treated as valid.
        """
        if padding_mask is None:
            return x.sum(dim=1) if self.sum_features else x[:, -1]
        if self.sum_features:
            return x.masked_fill(padding_mask.unsqueeze(-1), 0).sum(dim=1)
        valid_lengths = padding_mask.eq(0).sum(dim=1)
        batch_index = torch.arange(x.size(0), device=x.device)
        return x[batch_index, valid_lengths - 1]

    def forward(
        self, padded_features, padded_rnn_input, t_out,
        transformer_padding_mask=None, rnn_padding_mask=None
    ):
        """ Compute the context of the flow for every non-padded output step.

        Parameters
        ----------
        padded_features : torch.Tensor
            Padded input of the featurizer.
        padded_rnn_input : torch.Tensor
            Padded (shifted) input features of the RNN.
        t_out : torch.Tensor
            Output time of each sequence.
        transformer_padding_mask : torch.Tensor, optional
            Boolean mask of the featurizer input, True on padded positions.
            If None, no position is treated as padding. Default: None
        rnn_padding_mask : torch.Tensor, optional
            Boolean mask of the RNN input, True on padded positions.
            If None, no position is treated as padding. Default: None

        Returns
        -------
        torch.Tensor
            The flow context of all non-padded output steps, flattened over
            the batch and step dimensions.
        """
        # extract the features and pool them over the non-padded steps
        x = self.featurizer(
            padded_features, src_key_padding_mask=transformer_padding_mask)
        x = self._pool_features(x, transformer_padding_mask)

        # project the time and feature dimensions
        t_proj = self.time_proj_layer(t_out)
        f_proj = self.feat_proj_layer(padded_rnn_input)

        # create the input for the RNN
        out_seq_len = padded_rnn_input.size(1)  # lengths after padding
        if rnn_padding_mask is None:
            # no mask given: every output step is valid
            rnn_padding_mask = torch.zeros(
                padded_rnn_input.size(0), out_seq_len,
                dtype=torch.bool, device=padded_rnn_input.device)
        out_lengths = rnn_padding_mask.eq(0).sum(-1) # original lengths
        x_rnn = torch.cat(
            [f_proj, t_proj.unsqueeze(1).repeat(1, out_seq_len, 1)], dim=-1)
        x_rnn = self.rnn(x_rnn, out_lengths)

        # create the context and input for the flows
        flow_context = torch.cat(
            [x_rnn, x.unsqueeze(1).repeat(1, out_seq_len, 1)], dim=-1)
        flow_context = flow_context[~rnn_padding_mask]

        return flow_context

    def _shared_step(self, batch, log_name):
        """ Compute and log the negative log-likelihood of one batch.

        Shared by the training and the validation step, which only differ in
        the name the loss is logged under.
        """
        batch_dict = self._prepare_batch(batch)

        flow_context = self.forward(
            padded_features=batch_dict['padded_features'],
            padded_rnn_input=batch_dict['padded_rnn_input'],
            t_out=batch_dict['t_out'],
            transformer_padding_mask=batch_dict['transformer_padding_mask'],
            rnn_padding_mask=batch_dict['rnn_padding_mask'],
        )
        # the flow is only evaluated on the non-padded output steps
        x_flow = batch_dict['padded_rnn_output'][~batch_dict['rnn_padding_mask']]
        log_prob = self.flows.log_prob(x_flow, context=flow_context)
        loss = -log_prob.mean()

        # log the loss
        self.log(
            log_name, loss, on_step=True, on_epoch=True, logger=True,
            prog_bar=True, batch_size=batch_dict['batch_size'])
        return loss

    def training_step(self, batch, batch_idx):
        """ Return the training loss of the batch and log it as 'train_loss'. """
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """ Return the validation loss of the batch and log it as 'val_loss'. """
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """ Initialize optimizer and LR scheduler

        Returns
        -------
        torch.optim.Optimizer or dict
            The optimizer alone if no scheduler is configured, otherwise a
            dictionary with the optimizer and the scheduler configuration.

        Raises
        ------
        NotImplementedError
            If the optimizer or the scheduler name is not implemented.
        """

        # setup the optimizer; `.get` also works when the arguments were
        # left at their empty default
        optimizer_name = self.optimizer_args.get('name')
        if optimizer_name == "Adam":
            optimizer = torch.optim.Adam(
                self.parameters(), lr=self.optimizer_args.lr,
                weight_decay=self.optimizer_args.weight_decay)
        elif optimizer_name == "AdamW":
            optimizer = torch.optim.AdamW(
                self.parameters(), lr=self.optimizer_args.lr,
                weight_decay=self.optimizer_args.weight_decay)
        else:
            raise NotImplementedError(
                "Optimizer {} not implemented".format(optimizer_name))

        # setup the scheduler
        scheduler_name = self.scheduler_args.get('name')
        if scheduler_name is None:
            scheduler = None
        elif scheduler_name == 'ReduceLROnPlateau':
            scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 'min', factor=self.scheduler_args.factor,
                patience=self.scheduler_args.patience)
        elif scheduler_name == 'CosineAnnealingLR':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.scheduler_args.T_max,)
        else:
            raise NotImplementedError(
                "Scheduler {} not implemented".format(scheduler_name))

        if scheduler is None:
            return optimizer
        else:
            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': scheduler,
                    'monitor': 'train_loss',
                    'interval': 'epoch',
                    'frequency': 1
                }
            }

In [80]:
# Stand-in configuration for debugging the full model. Compared to the module
# level configuration it also carries the `name` and `activation` entries of
# every component and the optimizer/scheduler settings.
# Nested dictionaries are converted to ConfigDict instances on construction.
config = config_dict.ConfigDict({
    'input_size': 3,
    'd_time': 1,
    'd_time_projection': 64,
    'd_feat_projection': 64,
    'sum_features': False,
    # featurizer parameters
    'featurizer': {
        'name': 'transformer',
        'd_model': 64,
        'nhead': 4,
        'num_encoder_layers': 2,
        'dim_feedforward': 128,
        'use_embedding': True,
        'activation': {'name': 'identity'},
    },
    # rnn parameters
    'rnn': {
        'name': 'gru',
        'output_size': 64,
        'hidden_size': 64,
        'num_layers': 2,
        'activation': {'name': 'relu'},
    },
    # flows parameters
    'flows': {
        'name': 'maf',
        'hidden_size': 64,
        'num_blocks': 2,
        'num_layers': 2,
        'activation': {'name': 'tanh'},
    },
    # optimizer and scheduler configuration
    'optimizer': {
        'name': 'AdamW',
        'lr': 5e-4,
        'betas': (0.9, 0.98),
        'weight_decay': 1e-4,
        'eps': 1e-9,
    },
    'scheduler': {
        'name': 'CosineAnnealingLR',
        'T_max': 100,
    },
})

# Build the model from the sub-configurations
model = SequenceRegressor(
    input_size=config.input_size,
    featurizer_args=config.featurizer,
    rnn_args=config.rnn,
    flows_args=config.flows,
    optimizer_args=config.optimizer,
    scheduler_args=config.scheduler,
    sum_features=config.sum_features,
    d_time=config.d_time,
    d_time_projection=config.d_time_projection,
    d_feat_projection=config.d_feat_projection,
)

In [91]:
# Rebuild the loaders and draw a single training batch for the model
loaders_and_stats = prepare_dataloader(data, batch_size=32)
train_loader, val_loader, norm_dict = loaders_and_stats
batch = next(iter(train_loader))
