In [1]:
import pandas as pd
import torch

from torch.utils.data import Dataset, DataLoader

# from flows.context_flow import ContextFlow

In [2]:
class CsvDataset(Dataset):
    """Dataset pairing row ``i`` of one CSV file with row ``i`` of another.

    Both files are read with ``pandas.read_csv`` and every column is converted
    to ``torch.float``. Items are returned as ``(z, y)`` tuples.
    """

    def __init__(self, z_csv_file, y_csv_file):
        """Load both CSV files and cache them as float tensors.

        Parameters
        ----------
        z_csv_file : Path or file-like object accepted by ``pandas.read_csv``
            holding the ``z`` rows.
        y_csv_file : Path or file-like object accepted by ``pandas.read_csv``
            holding the ``y`` rows.

        Raises
        ------
        ValueError
            If the two files do not contain the same number of rows, since
            rows are paired by position.
        """
        self.df_z = pd.read_csv(z_csv_file)
        self.df_y = pd.read_csv(y_csv_file)

        if len(self.df_z) != len(self.df_y):
            raise ValueError(
                "z and y CSV files must have the same number of rows, "
                f"got {len(self.df_z)} and {len(self.df_y)}"
            )

        # Convert once up front: a per-item DataFrame.iloc lookup followed by a
        # tensor construction is far slower than indexing a ready-made tensor.
        self._z = torch.tensor(self.df_z.to_numpy(dtype="float32"))
        self._y = torch.tensor(self.df_y.to_numpy(dtype="float32"))

    def __len__(self):
        """Return the number of paired rows."""
        return len(self.df_z)

    def __getitem__(self, idx):
        """Return the ``(z, y)`` float tensors stored at position ``idx``."""
        return self._z[idx], self._y[idx]


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from nflows.flows.base import Flow
from nflows.distributions.normal import ConditionalDiagonalNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from sklearn.base import BaseEstimator, ClassifierMixin
from tqdm import trange

from utils_flow.context import generate_one_hot_context
from utils_flow.early_stopping.EarlyStopping import EarlyStopping


class ContextFlow(BaseEstimator, ClassifierMixin):
    """Scikit-learn style wrapper around a context-conditioned normalizing flow.

    The wrapped flow scores inputs given a context vector. Classification
    scores each sample under one context per class and normalizes the
    resulting log-likelihoods into class probabilities.
    """

    def __init__(self, z_dim=20, context_dim=10, num_layers=5, num_iter=1000, patience=10, device='cpu'):
        """
        Context Flow model wrapper.
        Parameters
        ----------
        z_dim : Number of dimensions in the latent space
        context_dim : Number of classes to model; also the width of the
            context vector the flow is conditioned on
        num_layers : Number of Flow layers
        num_iter : Maximum number of passes over the training data loader
        patience : Patience handed to the EarlyStopping helper
        device : Torch device used for training and inference
        """
        self.z_dim = z_dim
        self.context_dim = context_dim
        self.num_layers = num_layers
        self.num_classes = context_dim
        self.num_iter = num_iter
        self.patience = patience
        self.device = device

    def fit(self, dl):
        """Train a fresh flow by minimizing the negative log-likelihood.

        Parameters
        ----------
        dl : Iterable yielding ``(inputs, context)`` tensor batches, e.g. a
            ``DataLoader``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``dl`` yields no samples during an epoch.
        """
        self.model = get_context_flow(inputs_dim=self.z_dim, context_dim=self.context_dim, num_layers=self.num_layers)
        self.model.to(self.device)
        self.model.train()
        optimizer = optim.Adam(self.model.parameters())
        es = EarlyStopping(patience=self.patience)

        with trange(self.num_iter) as t:
            for _ in t:
                # Sample-weighted running sum so a short final batch does not
                # dominate the value handed to early stopping.
                epoch_loss = torch.zeros((), device=self.device)
                num_samples = 0

                for x, y in dl:
                    x, y = x.to(self.device), y.to(self.device)
                    optimizer.zero_grad()
                    loss = -self.model.log_prob(inputs=x, context=y).mean()
                    loss.backward()
                    optimizer.step()

                    # Detach so the accumulated value does not keep the
                    # autograd graph of every batch alive.
                    epoch_loss += loss.detach() * len(x)
                    num_samples += len(x)

                    t.set_postfix(loss=loss.item())

                if num_samples == 0:
                    raise ValueError("fit() received a data loader that yields no samples")

                # Still a tensor, as in the original call, in case the
                # EarlyStopping helper relies on tensor operations.
                if es.step(epoch_loss / num_samples):
                    break

        return self

    def predict_proba(self, x):
        """Return class probabilities for every row of ``x``.

        Each sample is scored under one context per class (built by
        ``generate_one_hot_context``) and the per-class log-likelihoods are
        passed through a softmax. This is Bayes' rule with a uniform prior
        over classes.

        Parameters
        ----------
        x : Tensor or array-like of inputs, one sample per row.

        Returns
        -------
        numpy.ndarray with one row per sample and one column per class; every
        row sums to 1.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float)
        # The model lives on self.device after fit(); inputs must follow it.
        x = x.to(self.device)

        self.model.eval()
        log_likelihoods = []

        with torch.no_grad():
            for i in range(self.num_classes):
                context = generate_one_hot_context(np.full(len(x), i), num_classes=self.num_classes)
                context = torch.as_tensor(context, dtype=torch.float, device=self.device)
                log_likelihoods.append(self.model.log_prob(x, context))

            # Softmax turns log-likelihoods into normalized probabilities.
            # Dividing log-likelihoods by their sum (the previous behavior)
            # is not a probability and inverts the ranking when they are
            # negative.
            y_prob = torch.softmax(torch.stack(log_likelihoods, dim=1), dim=1)

        return y_prob.cpu().numpy()

    def predict(self, x):
        """Return the index of the most probable class for every row of ``x``."""
        return self.predict_proba(x).argmax(axis=1)


def get_context_flow(inputs_dim, context_dim, num_layers):
    """Build a conditional masked-autoregressive flow.

    Parameters
    ----------
    inputs_dim : Number of features of the modelled inputs.
    context_dim : Number of features of the conditioning context.
    num_layers : Number of (reverse permutation, masked affine autoregressive
        transform) pairs stacked in the flow.

    Returns
    -------
    An ``nflows`` ``Flow`` whose base distribution is a diagonal normal
    parameterized by a linear encoding of the context.
    """
    # The linear encoder emits 2 * inputs_dim values per context row for the
    # conditional diagonal normal.
    context_encoder = nn.Linear(context_dim, 2 * inputs_dim)
    base_distribution = ConditionalDiagonalNormal(shape=[inputs_dim], context_encoder=context_encoder)

    # Each layer contributes a permutation followed by an autoregressive
    # transform, built in that order.
    layers = [
        step
        for _ in range(num_layers)
        for step in (
            ReversePermutation(features=inputs_dim),
            MaskedAffineAutoregressiveTransform(
                features=inputs_dim,
                hidden_features=2 * inputs_dim,
                context_features=context_dim,
            ),
        )
    ]

    return Flow(CompositeTransform(layers), base_distribution)


In [4]:
# Pair every row of the latent CSV with the same row of the label CSV
# (presumably gender labels, judging by the file name).
dataset_train = CsvDataset('data/z_train.csv', 'data/y_train_gender.csv')
# Shuffle each epoch: without it every epoch replays identical batches in the
# same order, which biases mini-batch gradient estimates if the CSV is sorted.
dl_train = DataLoader(dataset_train, batch_size=2048, shuffle=True)

In [5]:
# z_dim and context_dim must match the column counts of the z and y CSV files.
# Fall back to the CPU when no CUDA device is available so the notebook still
# runs end to end on machines without a GPU.
flow = ContextFlow(
    z_dim=128,
    context_dim=1,
    num_layers=5,
    num_iter=2,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [6]:
# Train the flow on the training loader. fit() returns the estimator itself,
# so the cell echoes its repr below the progress bar.
flow.fit(dl_train)

100%|██████████| 2/2 [01:07<00:00, 33.98s/it, loss=182]


ContextFlow(context_dim=1, device='cuda:0', num_iter=2, z_dim=128)

In [7]:
# Persist the trained flow. This pickles the whole model object rather than a
# state_dict, so loading it back needs the same class definitions importable.
# NOTE(review): tensors are saved on whatever device the model was trained on;
# loading on a machine without that device presumably needs
# torch.load(..., map_location=...) — confirm against the consumer of this file.
torch.save(flow.model, 'flow-gender.pkt')