In [1]:
import os
from pathlib import Path
import pickle

# Import the remaining JAX related 
from gabenet.mcmc import sample_markov_chain
from gabenet.nets import MultinomialDirichletBelieve
import haiku as hk
import jax
import jax.numpy as jnp
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer



In [2]:
i_checkpoint = 0

ARTEFACT_DIR = Path(os.environ.get('ARTEFACT_DIR', './checkpoints/'))
ARTEFACT_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
files_train = fetch_20newsgroups(subset='all', )
cv = CountVectorizer(min_df=10, max_features=2_000)
X_train = cv.fit_transform(files_train.data)

In [4]:
X_train = X_train[:10].todense().astype(jnp.float32)

In [5]:
# Pseudo-random number generator sequence.
key_seq = hk.PRNGSequence(42)

m_samples, n_features = X_train.shape

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [6]:
def save_last_state(states: dict) -> dict:
    global i_checkpoint
    """Extract and dump last state to disk."""
    last_state = jax.tree_util.tree_map(lambda x: x[:, -1, ...], states)

    with open(ARTEFACT_DIR / f'state_{i_checkpoint}.pkl', 'wb') as fo:
        pickle.dump(last_state, fo)
        print(f'Saving checkpoint i={i_checkpoint}.')

    i_checkpoint += 1

    del states
    
    return last_state

In [7]:
import numpy as np
def network_size(m_samples, n_features):
    """Estimate size of 2-layer mult-dir net."""
    n_hidden_units = [int(n_features**0.25), int(n_features**0.5)]
    
    n_theta = (m_samples * n_hidden_units[-1], m_samples * n_hidden_units[-2])
    n_phi = [n_features * n_hidden_units[-1], n_hidden_units[-1] * n_hidden_units[-2]]
    n_r = n_hidden_units[:1]
    n_c = [1]

    n_rate = m_samples * np.array([ n_hidden_units[0] * n_hidden_units[1], n_features * n_hidden_units[1]])
    n_m = m_samples * np.array([n_hidden_units[0], n_hidden_units[1]])
    n_x = m_samples * np.array([n_features, n_hidden_units[0],n_hidden_units[1]])
    n_activation = np.append(n_m, *n_r)
    n_theta_overhead = n_theta
    n_network = sum(n_theta) + sum(n_phi) + sum(n_r) + sum(n_c) + n_m[1]
    n_overhead = sum(n_theta) + sum(n_phi) + sum(n_m)
    n_sampling = sum(n_rate) + sum(n_x) + sum(n_phi) + sum(n_activation) + sum(n_theta_overhead)
    
    return n_network, n_overhead, n_sampling

In [8]:
n_network, n_overhead, n_sampling = network_size(*X_train.shape)

# n_network, n_overhead, n_sampling = network_size(m_samples=1_000, n_features=13_000)
# n_network, n_overhead = network_size(11314, 15593)
n_total = n_network + n_overhead + n_sampling
print('Network:', n_network, f'parameters ({n_network*8e-6:.2f} MB)')
print('Overhead:', n_overhead, f'parameters ({n_overhead*8e-6:.2f} MB)')
print('Sampling', n_sampling, f'parameters ({n_sampling*8e-6:.2f} MB)')
n_GB = n_total * 8e-9
print('Total:', n_total, f'({n_GB:.2f} GB)')

Network: 182271 parameters (1.46 MB)
Overhead: 188264 parameters (1.51 MB)
Sampling 90502270 parameters (724.02 MB)
Total: 90872805 (0.73 GB)


In [9]:
@hk.transform_with_state
def kernel(n_hidden_units = (200, )):
    """Advance the Markov chain by one step."""
    model = MultinomialDirichletBelieve(n_hidden_units, n_features)
    # Do one Gibbs sampling step.
    model(X_train)
    
@hk.without_apply_rng
@hk.transform_with_state
def _log_prob():
    n_hidden_units = (int(n_features**0.25), int(n_features**0.5))
    model = MultinomialDirichletBelieve(n_hidden_units, n_features)
    return model.log_prob(X_train)

log_prob = jax.vmap(_log_prob.apply, in_axes=[None, 0, None])

In [11]:
params, states = sample_markov_chain(
    next(key_seq), 
    kernel=kernel, 
    n_samples=20, 
    n_burnin_steps=20,
    n_chains=jax.device_count(),
)
_ = states['multinomial_dirichlet_believe/~/cap_layer']['r'].block_until_ready()


 1012. 1013. 1014. 1015. 1016. 1017. 1018. 1019. 1020. 1021. 1022. 1023.
 1024. 1025. 1026. 1027. 1028. 1029. 1030. 1031. 1032. 1033. 1034. 1035.
 1036. 1037. 1038. 1039. 1040. 1041. 1042. 1043. 1044. 1045. 1046. 1047.
 1048. 1049. 1050. 1051. 1052. 1053. 1054. 1055. 1056. 1057. 1058. 1059.
 1060. 1061. 1062. 1063. 1064. 1065. 1066. 1067. 1068. 1069. 1070. 1071.
 1072. 1073. 1074. 1075. 1076. 1077. 1078. 1079. 1080. 1081. 1082. 1083.
 1084. 1085. 1086. 1087. 1088. 1089. 1090. 1091. 1092. 1093. 1094. 1095.
 1096. 1097. 1098. 1099. 1100. 1101. 1102. 1103. 1104. 1105. 1106. 1107.
 1108. 1109. 1110. 1111. 1112. 1113. 1114. 1115. 1116. 1117. 1118. 1119.
 1120. 1121. 1122. 1123. 1124. 1125. 1126. 1127. 1128. 1129. 1130. 1131.
 1132. 1133. 1134. 1135. 1136. 1137. 1138. 1139. 1140. 1141. 1142. 1143.
 1144. 1145. 1146. 1147. 1148. 1149. 1150. 1151. 1152. 1153. 1154. 1155.
 1156. 1157. 1158. 1159. 1160. 1161. 1162. 1163. 1164. 1165. 1166. 1167.
 1168. 1169. 1170. 1171. 1172. 1173. 1174. 1175. 11

In [12]:
last_state = save_last_state(states)

Saving checkpoint i=0.


In [13]:
for _ in range(10):
    params, states = sample_markov_chain(
        next(key_seq), 
        kernel=kernel,
        n_samples=80, 
        n_burnin_steps=0, 
        initial_state=last_state,
        params=params,
    )
    last_state = save_last_state(states)
    lls, _ = log_prob(params, last_state)
    print('Log likelihood', lls)

Saving checkpoint i=1.
