In [1]:
import time

import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp

from sklearn.datasets import fetch_20newsgroups

import matplotlib.pyplot as plt
import seaborn as sns

from src.data_prepocessing import ArticlesDataset

In [2]:
data = fetch_20newsgroups(data_home='./data/', subset='all').data
df = pd.DataFrame(data, columns=['text'])
df

Unnamed: 0,text
0,From: Mamatha Devineni Ratnam <mr47+@andrew.cm...
1,From: mblawson@midway.ecn.uoknor.edu (Matthew ...
2,From: hilmi-er@dsv.su.se (Hilmi Eren)\nSubject...
3,From: guyd@austin.ibm.com (Guy Dawson)\nSubjec...
4,From: Alexander Samuel McDiarmid <am2o+@andrew...
...,...
18841,From: jim.zisfein@factory.com (Jim Zisfein) \n...
18842,From: rdell@cbnewsf.cb.att.com (richard.b.dell...
18843,From: westes@netcom.com (Will Estes)\nSubject:...
18844,From: steve@hcrlgw (Steven Collins)\nSubject: ...


In [61]:
df = pd.read_csv('./data/bbc_text_cls.csv')
maxlen = np.quantile(df.text.apply(len), q=0.99)
print(len(df))
df = df[df.text.apply(len) < maxlen].reset_index(drop=True)
df

2225


Unnamed: 0,text,labels
0,Ad sales boost Time Warner profit\n\nQuarterly...,business
1,Dollar gains on Greenspan speech\n\nThe dollar...,business
2,Yukos unit buyer faces loan claim\n\nThe owner...,business
3,High fuel prices hit BA's profits\n\nBritish A...,business
4,Pernod takeover talk lifts Domecq\n\nShares in...,business
...,...,...
2197,TV's future down the phone line\n\nInternet TV...,tech
2198,Cebit fever takes over Hanover\n\nThousands of...,tech
2199,BT program to beat dialler scams\n\nBT is intr...,tech
2200,Spam e-mails tempt net shoppers\n\nComputer us...,tech


In [37]:
a = jnp.array([1, 2, 3, 4, 5, 6, 7], dtype=int)
a = jnp.array(list(zip([1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7])), dtype=int)
bounds = jnp.array([0, 2, 7], dtype=int)
i = 2
t = 3
n = len(a)

shifts = jnp.arange(i * 2, -1, -1)

max_shift = i * 2 + n

pad_token = 0

result_matrix = np.full((max_shift, 2 * i + 1, t), fill_value=pad_token, dtype=int)

row_indices = jnp.arange(n)[:, None] + shifts[None, :]
col_indices = jnp.arange(2 * i + 1)

result_matrix[row_indices, col_indices] = a[:, None]
mask = jnp.all(result_matrix[:, i, :] == pad_token, axis=1)
result_matrix = result_matrix[~mask]

print(result_matrix.shape)
print(result_matrix)

(7, 5, 3)
[[[0 0 0]
  [0 0 0]
  [1 1 1]
  [2 2 2]
  [3 3 3]]

 [[0 0 0]
  [1 1 1]
  [2 2 2]
  [3 3 3]
  [4 4 4]]

 [[1 1 1]
  [2 2 2]
  [3 3 3]
  [4 4 4]
  [5 5 5]]

 [[2 2 2]
  [3 3 3]
  [4 4 4]
  [5 5 5]
  [6 6 6]]

 [[3 3 3]
  [4 4 4]
  [5 5 5]
  [6 6 6]
  [7 7 7]]

 [[4 4 4]
  [5 5 5]
  [6 6 6]
  [7 7 7]
  [0 0 0]]

 [[5 5 5]
  [6 6 6]
  [7 7 7]
  [0 0 0]
  [0 0 0]]]


In [15]:
gamma = 0.4
# C_i = gamma * (1 - gamma)**i
suffix_context_weights = np.cumprod(np.full(i, (1 - gamma))) * gamma
prefix_context_weights = suffix_context_weights[::-1]
context_weights = np.concatenate([
    prefix_context_weights,
    [0.],
    suffix_context_weights,
])
context_weights

array([0.144, 0.24 , 0.   , 0.24 , 0.144])

In [16]:
attn_matrix = np.ones((n, 2 * i + 1), dtype=int)

# prefix attention
ignored_mask_prefix = np.rot90(~np.triu(np.ones(i, dtype=bool)))

prefix_bounds = bounds[:-1]
shifts = np.ones((len(prefix_bounds), i), dtype=int)
shifts[:, 0] = prefix_bounds
shifts = np.cumsum(shifts, axis=1)

prefix_columns = np.arange(i)

attn_matrix[shifts.flatten()[:, None], prefix_columns] = np.tile(ignored_mask_prefix, reps=i).T

# suffix attention
ignored_mask_suffix = np.rot90(~np.tril(np.ones(i, dtype=bool)))

suffix_bounds = np.array(bounds[1:]) - i
shifts = np.ones((len(suffix_bounds), i), dtype=int)
shifts[:, 0] = suffix_bounds
shifts = np.cumsum(shifts, axis=1)

suffix_columns = np.arange(i + 1, i * 2 + 1)

attn_matrix[shifts.flatten()[:, None], suffix_columns] = np.tile(ignored_mask_suffix, reps=i).T

attn_matrix

array([[0, 0, 1, 1, 0],
       [0, 1, 1, 0, 0],
       [0, 0, 1, 1, 1],
       [0, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 0],
       [1, 1, 1, 0, 0]])

In [13]:
context_matrix = context_weights * attn_matrix
context_matrix /= context_matrix.sum(axis=1, keepdims=True)
context_matrix

array([[0.        , 0.        , 0.        , 1.        , 0.        ],
       [0.        , 1.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.625     , 0.375     ],
       [0.        , 0.38461538, 0.        , 0.38461538, 0.23076923],
       [0.1875    , 0.3125    , 0.        , 0.3125    , 0.1875    ],
       [0.23076923, 0.38461538, 0.        , 0.38461538, 0.        ],
       [0.375     , 0.625     , 0.        , 0.        , 0.        ]])

In [146]:
np.sum(context_matrix[..., None] * result_matrix, axis=1)

array([[2.        , 2.        , 2.        ],
       [1.        , 1.        , 1.        ],
       [4.375     , 4.375     , 4.375     ],
       [4.46153846, 4.46153846, 4.46153846],
       [5.        , 5.        , 5.        ],
       [5.53846154, 5.53846154, 5.53846154],
       [5.625     , 5.625     , 5.625     ]])

In [3]:
class ContextTopicModelDebug():
    """
    Topic model which uses local context of words
    """

    def __init__(
            self,
            ctx_len: int,
            vocab_size: int,
            n_topics: int = 10,
            gamma: float = 0.6,
            reg_list: list = None,
            eps: float = 1e-12,
    ):
        """
        Args:
            ctx_len: one-sided context size
            max_len: max length of a document, W_d
            vocab_size: corpus vocabulary size, W
            n_topics: number of topics, T
            regularizations: list of regularizations (see `add_regularization` method)
            eps: parameter set for balance between numerical stability and precision

        Note:
            - Total context of a word on `i`-th index is ctx_len words to the left,\\
            `ctx_len` words to the right, and the word itself
        """
        self.ctx_len = ctx_len
        self.vocab_size = vocab_size
        self.n_topics = n_topics
        self._gamma = gamma
        self._eps = eps

        self._context_weights_1d = self._calc_context_weights_1d()

        self._regularizations = dict()
        if reg_list is not None:
            for reg in reg_list:
                self.add_regularization(reg)

    def _norm(self, x: jax.Array) -> jax.Array:
        assert jnp.any(~jnp.isnan(x)), jnp.sum(x)
        # take x+ = max(x, 0) element-wise (perform projection on positive simplex)
        x = jnp.maximum(x, jnp.zeros_like(x))
        # normalize values in non-zero rows to 1
        # (mapping from the positive simplex to the unit simplex)
        norm = x.sum(axis=0)
        x = jnp.where(norm > self._eps, x / norm, jnp.zeros_like(x))
        return x

    def _calc_context_weights_1d(self) -> jax.Array:
        # C_i = gamma * (1 - gamma)**i
        suffix_context_weights = np.cumprod(np.full(self.ctx_len, (1 - self._gamma))) * self._gamma
        prefix_context_weights = suffix_context_weights[::-1]
        context_weights = np.concatenate([
            prefix_context_weights,
            [0.],  # ignoring the word itself when calculating the context
            suffix_context_weights,
        ])
        return context_weights  # (C, )

    def _create_context_coeff_matrix(self, batch_size: int, attn_bounds: jax.Array) -> jax.Array:
        attn_matrix = np.ones((batch_size, self.ctx_len * 2 + 1), dtype=bool)  # True where to attend

        # prefix attention (zero out words from the previous document)
        ignored_mask_prefix = np.rot90(~np.triu(np.ones((self.ctx_len, self.ctx_len), dtype=bool)))

        prefix_bounds = attn_bounds[:-1]
        shifts = np.ones((len(prefix_bounds), self.ctx_len), dtype=int)
        shifts[:, 0] = prefix_bounds
        shifts = np.cumsum(shifts, axis=1)

        prefix_columns = np.arange(self.ctx_len)
        attn_matrix[shifts.reshape(-1, 1), prefix_columns] = np.tile(ignored_mask_prefix, reps=len(prefix_bounds)).T

        # suffix attention (zero out words from the next document)
        ignored_mask_suffix = np.rot90(~np.tril(np.ones((self.ctx_len, self.ctx_len), dtype=bool)))

        suffix_bounds = attn_bounds[1:] - self.ctx_len
        shifts = np.ones((len(suffix_bounds), self.ctx_len), dtype=int)
        shifts[:, 0] = suffix_bounds
        shifts = np.cumsum(shifts, axis=1)

        suffix_columns = np.arange(self.ctx_len + 1, self.ctx_len * 2 + 1)
        attn_matrix[shifts.reshape(-1, 1), suffix_columns] = np.tile(ignored_mask_suffix, reps=len(suffix_bounds)).T

        # calculate context weights with respect to attention and normalize weights
        context_matrix = self._context_weights_1d * attn_matrix
        context_matrix /= context_matrix.sum(axis=1, keepdims=True)
        return context_matrix  # (I, C)

    def _construct_context_tensor(self, data: jax.Array) -> jax.Array:
        shifts = np.arange(self.ctx_len * 2, -1, -1)
        max_shift = self.ctx_len * 2 + len(data)

        pad_token = -10
        shifted_matrix = np.zeros((max_shift, self.ctx_len * 2 + 1, self.n_topics))
        shifted_matrix = np.full((max_shift, self.ctx_len * 2 + 1, self.n_topics), pad_token, dtype=float)

        row_indices = np.arange(len(data))[:, None] + shifts[None, :]
        col_indices = np.arange(self.ctx_len * 2 + 1)

        shifted_matrix[row_indices, col_indices] = data[:, None]

        pad_context_mask = jnp.all(jnp.isclose(shifted_matrix[:, self.ctx_len, :], pad_token), axis=1)
        shifted_matrix = shifted_matrix[~pad_context_mask]
        print(shifted_matrix.dtype)
        return shifted_matrix  # (I, C, T)

    def add_regularization(self, reg, tag: str = None):
        """
        Add `reg` regularization to the model with `tag` identifier \\
        Note:
        - `reg` has to be a child of base `Regularization` class
        - `tag` will use the name of the class by default
        """
        if tag is None:
            tag = reg.__name__
        if not isinstance(reg, Regularization):
            raise TypeError(f'Regularization [{tag}] has to be a subclass of Regularization class')

        try:
            self._regularizations[tag] = jax.grad(reg)
        except Exception:
            raise

    def _compose_regularizations(self):
        regs = self._regularizations.values()
        sum_reg = lambda x: sum([1.0, ] + [reg(x) for reg in regs])
        return jax.jit(jax.grad(sum_reg))

    def fit(self, data: jax.Array, doc_bounds, max_iter: int = 1000, tol: float = 1e-3, seed: int = 0):
        """
        Args:
            data: matrix of shape (D, W_d), containing tokenized words of each document
            max_iter: max number of iterations
            tol: early stopping threshold
            seed: random seed
        """
        key = jax.random.key(seed)
        self.phi = jax.random.uniform(
            key=key,
            shape=(self.vocab_size, self.n_topics),
        )  # (W, T)
        self.n_t = jnp.full(
            shape=(self.n_topics, ),
            fill_value=len(data) / self.n_topics,
        )  # (T, )
        grad_regularization = self._compose_regularizations()

        self.phi = self._norm(self.phi)
        t_cur = time.time()
        for it in range(max_iter):
            # Calculate phi' (words -> topics) matrix (phi with old p_{ti})
            print(np.array(self.phi).shape, np.array(self.n_t).shape, f'{time.time() - t_cur:.01f}')
            phi_hatch = self._norm(self.phi.T * self.n_t[:, None]).T  # (W, T)
            print(np.array(phi_hatch).shape, f'{time.time() - t_cur:.01f}')

            # Create theta (documents -> topics) matrix
            phi_it_hatch = jnp.take_along_axis(phi_hatch, indices=data[:, None], axis=0)  # (I, T)
            print(phi_it_hatch.shape, f'{time.time() - t_cur:.01f}')
            phi_it_hatch_block = self._construct_context_tensor(phi_it_hatch)  # (I, C, T)
            print(phi_it_hatch_block.shape, f'{time.time() - t_cur:.01f}')
            context_matrix = self._create_context_coeff_matrix(batch_size=data.shape[0], attn_bounds=doc_bounds)
            theta_it = np.sum(context_matrix[..., None] * phi_it_hatch_block, axis=1)  # (I, T)
            print(np.array(theta_it).shape, f'{time.time() - t_cur:.01f}')

            # Update p_{ti} - topic probability distribution for i-th context
            phi_it = jnp.take_along_axis(self.phi, indices=data[:, None], axis=0)  # (I, T)
            print(np.array(phi_it).shape, f'{time.time() - t_cur:.01f}')
            p_ti = self._norm((phi_it * theta_it).T).T  # (I, T)
            print(np.array(p_ti).shape, f'{time.time() - t_cur:.01f}')

            # Update n_{t} - topic probability distribution
            self.n_t = jnp.sum(p_ti, axis=0)  # (T, )
            print(np.array(self.n_t).shape, f'{time.time() - t_cur:.01f}')

            # Update phi (words -> topics) matrix (phi with new p_{ti})
            indices = data.flatten()  # (I, )
            phi_new = jnp.add.at(jnp.zeros_like(self.phi), indices, p_ti, inplace=False)  # (W, T)
            print(np.array(phi_new).shape, f'{time.time() - t_cur:.01f}')
            phi_new += self.phi * grad_regularization(self.phi)  # (W, T)
            print(np.array(phi_new).shape, f'{time.time() - t_cur:.01f}')
            phi_new = self._norm(phi_new)  # (W, T)

            diff_norm = jnp.linalg.norm(phi_new - self.phi)
            print(-jnp.sum(1 * jnp.log(np.sum(theta_it * phi_it, axis=1) + self._eps)) / len(data))
            # calculate perplexity
            res_diff_norm = jnp.exp(-jnp.sum(1 * jnp.log(np.sum(theta_it * phi_it, axis=1) + self._eps)) / len(data))
            print(f'Iteration [{it}/{max_iter}], update diff norm: {diff_norm:.04f}, perplexity: {res_diff_norm:.04f}')
            self.phi = phi_new
            if diff_norm < tol:
                break

In [4]:
dataset = ArticlesDataset(df.text.tolist())
len(dataset)

3414473

In [5]:
%%time

model = ContextTopicModelDebug(
    ctx_len=10,
    vocab_size=len(dataset.vocab),
    gamma=0.1,
    n_topics=10,
)
model.fit(*dataset.data)

(107671, 10) (10,) 0.0
(107671, 10) 0.2
(3414473, 10) 0.3
float64
(3414473, 21, 10) 2.4
(3414473, 10) 4.0
(3414473, 10) 4.0
(3414473, 10) 4.3
(10,) 4.3
(107671, 10) 4.5
(107671, 10) 4.6
11.597601
Iteration [0/1000], update diff norm: 0.1312, perplexity: 108836.3828
(107671, 10) (10,) 4.8
(107671, 10) 4.8
(3414473, 10) 4.8
float64
(3414473, 21, 10) 6.6
(3414473, 10) 8.2
(3414473, 10) 8.2
(3414473, 10) 8.3
(10,) 8.3
(107671, 10) 8.5
(107671, 10) 8.5
8.647864
Iteration [1/1000], update diff norm: 0.0261, perplexity: 5697.9648
(107671, 10) (10,) 8.5
(107671, 10) 8.5
(3414473, 10) 8.6


KeyboardInterrupt: 

In [107]:
# maxlen = 500, context = 10
topk = jnp.argsort(model.phi, axis=0, descending=True)[:10, :].T  # (T, W_{top})
reverse_vocab = {value: key for key, value in dataset.vocab.items()}

for t in topk:
    print('\t'.join([reverse_vocab[int(idx)] for idx in t]))
    print()

year	gun	law	government	u	bike	state	israeli	new	american

window	x	use	system	problem	get	need	anyone	drive	file

edu	writes	c	article	subject	apr	cc	news	andrew	uiuc

line	organization	subject	posting	host	nntp	university	distribution	ca	x

one	would	people	think	know	like	time	say	thing	get

thanks	e	please	mail	email	help	advance	p	u	address

com	writes	subject	article	apr	org	gov	netcom	nasa	access

PAD	ryerson	ryevm	keith	sola	employer	acps	tmi	teddy	polytechnical

game	team	player	year	win	last	v	co	play	pitt

PAD	rainer	elin	hochreiter	eeam	keith	excepted	omission	bye	donoghue

