# <CENTER> Approximate nearest neighbors</CENTER>

## Assignment by Guillaume Pitel (translated from French)

Starting from the examples given here:
https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html

1) Perform a Latent Semantic Analysis of 20newsgroups (TruncatedSVD with
the randomized algorithm, in 128 dimensions, computed from the TF-IDF matrix).

2) Randomly draw 1 document from each newsgroup and compute its 8
nearest neighbors according to the cosine distance (normalize the vectors
representing the documents). The output must show the subjects of the posts and their
newsgroup of origin.

3) Implement a Product Quantization and compress all the document
representations (use sklearn's K-Means algorithm) with the following
parameters: K=256, running K-Means on slices of 8 dimensions (which gives
16 sub-clusterings of 256 clusters each).

4) Compute the neighbors of the same documents with the PQ version (computed directly
from the inter-cluster distances, not by reconstructing the full
vectors).

## 1. Preparation

**Disclaimer** <BR> 
We factored out as much code as we could. As a result, the preparation is long, since this material is reused with little or no modification for all questions. In other words, some attributes, functions or methods written here may only matter for later questions.<BR>
Once the preparation is complete, however, answering each question is very quick.

### 1.1. Imports 

In [1]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

### 1.2. Constants

In [2]:
# Number of clusters for K-Means performed in Q3 (called K in the assignment)
N_CLUSTERS = 256 

# Output dimension for the SVD performed in Q1
N_COMPONENTS = 128 

# Number of subclusterings made during Product Quantization (PQ), see Q3 & Q4
N_SUBCLUSTERINGS = 16 

# Number of different topics in 20Newsgroups
N_TOPICS = 20 

# Dimension of a subvector in PQ
PQ_DIM = 8 

# Seed for sklearn algorithms using a random state
RANDOM_STATE = 42 
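
As a quick sanity check on these constants, the sketch below verifies that the 16 slices of 8 dimensions tile the 128-dimensional vectors exactly, that 256 cluster ids fit in one unsigned byte, and what this implies for storage. It assumes the LSA vectors are stored as `float64`; the variable names `bytes_raw`, `bytes_pq` and `ratio` are ours, for illustration only.

```python
import numpy as np

N_CLUSTERS = 256
N_COMPONENTS = 128
N_SUBCLUSTERINGS = 16
PQ_DIM = 8

# The subvectors must tile the full vector exactly.
assert N_COMPONENTS == PQ_DIM * N_SUBCLUSTERINGS

# Cluster ids range over 0..255, which fits exactly in one unsigned byte.
assert N_CLUSTERS - 1 <= np.iinfo(np.uint8).max

# Storage per document (assumption: the LSA vectors are float64).
bytes_raw = N_COMPONENTS * np.dtype(np.float64).itemsize   # full vector
bytes_pq = N_SUBCLUSTERINGS * np.dtype(np.uint8).itemsize  # PQ code
ratio = bytes_raw // bytes_pq

print(bytes_raw, bytes_pq, ratio)  # prints: 1024 16 64
```

Under that assumption, each document shrinks from 1024 bytes to 16 bytes, at the price of approximate distances.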

### 1.3. A class handling the 20Newsgroups dataset 

#### 1.3.1. Definition

Let's build a handy class that performs most of the operations we want on the 20Newsgroups dataset.

In [3]:
class My20Newsgroups:
  '''
  This class performs various operations on the 20Newsgroups dataset.
 
  Attributes:
  * self.train: the training part of the 20Newsgroups dataset,
  which consists of 11314 documents.
  * self.data: alias for self.train.data
  * self.tf_idf: TF-IDF matrix for the dataset.
  * self.svd: LSA (aka truncated SVD) on the TF-IDF matrix (see Q1).
  * self.n_components: output dimension for the SVD.
  Passed in as the only argument to the constructor.
  * self.sample: a random sample of 20 documents, one for each topic
  (see Q2). Stored as the tuple of the indexes of these 20 documents.
  '''

  def __init__(self, n_components=N_COMPONENTS):
    '''
    Constructor.
    '''
    # Import the dataset
    self.train = fetch_20newsgroups(data_home=None, 
                                    shuffle=True, 
                                    subset='train', 
                                    random_state=RANDOM_STATE, 
                                    download_if_missing=True)
    self.data = self.train.data

    # Get the TF-IDF matrix
    train_counts = CountVectorizer().fit_transform(self.data)
    self.tf_idf = TfidfTransformer().fit_transform(train_counts)              
    
    # Perform a LSA (aka truncated SVD) on it
    self.n_components = n_components
    self.svd = TruncatedSVD(self.n_components).fit_transform(self.tf_idf)

    # Sampling - explanations about the code:
    # c represents a topic.
    # For a given c, np.where(self.train.target == c)[0] is the list of all
    # indexes of documents in the dataset having topic c.
    # np.random.choice() picks at random an index in this list.
    self.sample = tuple(np.random.choice(np.where(self.train.target == c)[0]) \
                   for c in range(N_TOPICS))
           
  def subject(self, i):
    '''
    Get the line "Subject:" (second line) for the i-th document.
    The word "Subject:" itself is removed.
    '''
    return self.data[i].split('\n')[1][len('Subject: ') :]
      
  def topic(self, i):
    '''
    Get the topic (=label, or class, or newsgroup) for the i-th document.
    Return it as a string, e.g. 'comp.sys.mac.hardware'.
    '''
    return self.train.target_names[self.train.target[i]]

  def its(self, i):
    '''
    Return the triple (i, topic, subject) for the i-th document.
    '''
    return (i, self.topic(i), self.subject(i))

#### 1.3.2. Instantiation

This takes a little time, because the LSA is performed at this stage. 

In [4]:
dataset = My20Newsgroups()
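
Note that the sampling in the constructor relies on NumPy's global random state, so `dataset.sample` may change from one run to the next. Should a reproducible sample be wanted, one possible sketch is shown below. The helper `sample_one_per_class` and its `seed` parameter are hypothetical (they are not part of the class above), and it is demonstrated on a made-up label array rather than on the real targets.

```python
import numpy as np

def sample_one_per_class(targets, n_classes, seed=42):
    '''
    Hypothetical helper: pick one index at random for each class
    0..n_classes-1, using a seeded generator for reproducibility.
    '''
    rng = np.random.default_rng(seed)
    # np.flatnonzero(targets == c) lists the indexes of the items of class c.
    return tuple(int(rng.choice(np.flatnonzero(targets == c)))
                 for c in range(n_classes))

# Toy labels standing in for self.train.target
toy_targets = np.array([0, 1, 2, 0, 1, 2, 0])
sample = sample_one_per_class(toy_targets, n_classes=3)
print(sample)
```

Calling the helper twice with the same seed returns the same tuple, which makes the answers to Q2 and Q4 comparable across runs.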

#### 1.3.3. A few tests

In [5]:
# Number of documents
len(dataset.data)

11314

In [6]:
# Print the document with index 42
i = 42
for line in dataset.data[i].split('\n'):
  print(line)

From: ab245@cleveland.Freenet.Edu (Sam Latonia)
Subject: Re: Need phone number for Western Digital (ESDI problem)
Organization: Case Western Reserve University, Cleveland, Ohio (USA)
Lines: 5
NNTP-Posting-Host: slc10.ins.cwru.edu


Western Digital 1-800-832-4778.....Sam
-- 
Gosh..I think I just installed a virus..It was called MS DOS6...
Don't copy that floppy..BURN IT...I just love Windows...CRASH...



In [7]:
# Subject of the document with index 42
dataset.subject(i)

'Re: Need phone number for Western Digital (ESDI problem)'
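
The `subject()` method assumes that the "Subject:" header is always the second line of a post. Some outputs in §2.2.1 (for instance `'ion: na'`, which is what one gets by slicing a second line reading `Organization: na`) suggest this is not always the case. A more defensive variant could scan the header block instead. The function `extract_subject` below is a hypothetical sketch, not the method used in this notebook, and it assumes that the headers end at the first blank line.

```python
def extract_subject(document):
    '''
    Hypothetical helper: return the content of the first "Subject:" header
    of a post, or '' if none is found before the first blank line.
    '''
    for line in document.split('\n'):
        if line == '':
            break  # assumption: a blank line ends the header block
        if line.startswith('Subject:'):
            return line[len('Subject:'):].strip()
    return ''

# Made-up post whose second line is not the subject
doc = "From: someone@example.org\nOrganization: na\nSubject: Hi there\n\nBody text"
print(extract_subject(doc))  # prints: Hi there
```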

In [8]:
# Topic (=class, returned as a string) of the document with index 42
dataset.topic(i)

'comp.sys.ibm.pc.hardware'

### 1.4. A class searching nearest neighbors in any dataset

#### 1.4.1. Definition

Let's build a class that takes a data matrix as the only argument to its constructor, and searches for the nearest neighbors of a given element according to one of two metrics: the cosine distance for Q2, or the distance given by PQ for Q3 and Q4.
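
For the cosine case, the class ranks neighbors with a cheaper quantity than the true cosine distance. The following sketch, on a made-up 4×2 matrix (not the notebook's data), illustrates why this is harmless: dropping the constant norm of the query row and the constant 1 does not change the ordering.

```python
import numpy as np

# Toy data: 4 vectors in dimension 2; row 0 is the query
X = np.array([[1.0, 0.0],
              [2.0, 0.1],
              [0.0, 3.0],
              [-1.0, 0.0]])
i = 0

norms = np.linalg.norm(X, axis=1)

# True cosine distance between row i and every row
true_cos_dist = 1.0 - (X @ X[i]) / (norms * norms[i])

# Cheaper surrogate: no division by the query norm, no constant term
surrogate = -(X @ X[i]) / norms

print(np.argsort(true_cos_dist))  # prints: [0 1 2 3]
print(np.argsort(surrogate))      # prints: [0 1 2 3]
```

Both orderings coincide because the surrogate equals the true distance minus 1, multiplied by the positive constant `norms[i]`.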

In [9]:
class NearestNeighborsFinder:
  '''  
  This class computes the nearest neighbors of a given element in
  a given dataset, w.r.t. various metrics.
  
  Attributes:
  * self.X: the data matrix, passed in as the only argument
  to the constructor.
  Vectors corresponding to the elements of the dataset must be
  stored as rows, and features as columns. If you followed the other
  convention, consider transposing X first.
  * self.norms: 1-D NumPy array of the Euclidean norms of the rows of X.
  * (self.PQ_labels, self.PQ_centroid_distances): result of PQ(),
  see the docstring of this method below. Used for Q3 and Q4.
  ''' 

  # For Q3 and Q4
  def PQ(self):
    '''
    Perform a PQ on the dataset, with parameters (constants) defined in §1.2.
    Return a pair (PQ_labels, PQ_centroid_distances) defined as follows.
    * 'PQ_labels' is a matrix of shape (n, r), where
    n = self.X.shape[0] and r = N_SUBCLUSTERINGS.
    Each row of 'PQ_labels' is the 1-D array of labels through PQ for the
    corresponding row of the dataset; in other words, PQ_labels[i,:] is
    the compression of self.X[i,:] through PQ.
    * 'PQ_centroid_distances' is a 3-D array of shape (r, k, k) with
    k = N_CLUSTERS. For the j-th K-Means performed during PQ,
    PQ_centroid_distances[j, :, :] is the matrix of the
    *squared* L2-distances between all centroids.
    '''
    # Shorter names for dims
    d, k, n, r = PQ_DIM, N_CLUSTERS, self.X.shape[0], N_SUBCLUSTERINGS 

    # Sanity check
    assert N_COMPONENTS == d * r, \
      "Incompatible values: N_COMPONENTS != PQ_DIM * N_SUBCLUSTERINGS"

    # Initialization 
    PQ_centroid_distances = np.zeros(shape=(r, k, k))
    PQ_labels = np.zeros(shape=(n, r), dtype=np.uint8)

    # Main loop
    for j in range(r):
      # Extract the matrix of subvectors
      sub_X = self.X[ : , j*d : (j+1)*d]
  
      # Train a KMeans on it
      kmeans = KMeans(n_clusters=k,
                      n_init=5,
                      max_iter=5,
                      random_state=RANDOM_STATE,
                      algorithm='elkan'
                     ).fit(sub_X)
  
      # Fill the labels accordingly
      PQ_labels[:, j] = kmeans.labels_

      # We compute the squared L2-distances between the centroids c_i as follows.
      # We have ||c_i-c_i'||**2 = G[i,i] + G[i',i'] - 2*G[i,i'],
      # where G is the Gram matrix of the centroids. If the centroids are
      # stored as rows in a matrix C, then we easily have G = C @ C.T.
      # Let D denote the 1-D array equal to the diagonal part of G.
      # Then, the square matrix of the G[i',i'] (i' being a column index)
      # is N = np.tile(D, (k, 1)), and the square matrix of the G[i,i]
      # (i being a row index) is N.T. Therefore, we have
      # (*): ||c_i-c_i'||**2 = (N.T)[i,i'] + N[i,i'] - 2*G[i,i'],
      # whence the following code.
      C = kmeans.cluster_centers_ # matrix of centroids
      G = C @ C.T # Gram matrix of centroids
      D = np.diag(G) # this is a 1-D array
      N = np.tile(D, (k,1)) # all rows of N equal D
      PQ_centroid_distances[j, :, :] = N.T + N - 2 * G # code for (*)

    return PQ_labels, PQ_centroid_distances
  
  # For Q2, Q3 and Q4
  def __init__(self, X):
    '''
    Constructor. 
    Compute the norms of the rows of X and perform PQ on the whole dataset.
    '''
    self.X = X
    self.norms = np.linalg.norm(self.X, axis=-1)
    self.PQ_labels, self.PQ_centroid_distances = self.PQ()

  # For Q2
  def cosine_distances(self, i):
    '''
    Return a 1-D array storing a surrogate for the cosine distances
    between the i-th row of self.X and all the rows of self.X.
    It is not the actual cosine distance, but an increasing function
    of it that is cheaper to compute. This changes nothing as far as
    nearest neighbor search is concerned.
    '''
    dot_products = self.X @ self.X[i,:].T
    # For the true cosine distance, we would have to return
    #
    # 1 - dot_products / (self.norms * N)
    #
    # with N = np.linalg.norm(self.X[i,:]).
    # But as far as nearest neighbor search from X[i,:] is concerned,
    # there is no need to divide by N, which is a positive constant,
    # nor to add 1, so we may return the following:
    return - dot_products / self.norms
  
  # For Q4
  def PQ_distances(self, i):
    '''
    Return the 1-D array of the approximate squared L2-distances between
    the i-th row of self.X and all the rows of self.X, computed from
    their PQ codes only (via the precomputed inter-centroid distances).
    '''
    n = self.X.shape[0]
    out = np.zeros(n)
    for h in range(n): # index for rows of X
      for j in range(N_SUBCLUSTERINGS): # index for subvectors, or K-Means
        label_h = self.PQ_labels[h, j]
        label_i = self.PQ_labels[i, j]  
        out[h] += self.PQ_centroid_distances[j, label_h, label_i] 
    return out
    
  # For Q2 and Q4
  def nearest_neighbors(self, n_neighbors, dist):
    '''
    Return, as a (n_neighbors+1)-tuple, the *indexes* of a given element
    x in the dataset and of its n_neighbors nearest neighbors.
    The index of x is returned first, then the one of its nearest
    neighbor, then the one of its second nearest neighbor, and so on.
    The neighbors are computed w.r.t. the 'dist' parameter,
    which must be a 1-D array containing the true distances,
    or *an increasing function of the distances*, between x and
    all elements in the dataset.
    '''
    return tuple(sorted
                 (range(self.X.shape[0]), key=lambda j: dist[j])
                 [: n_neighbors + 1])
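
The Gram-matrix identity used in `PQ()` to get the squared distances between centroids can be checked numerically. The sketch below uses a small random matrix `C` standing in for `kmeans.cluster_centers_` (an assumption for illustration), and compares the Gram-based formula with a brute-force computation.

```python
import numpy as np

rng = np.random.default_rng(0)
C = rng.normal(size=(5, 8))  # 5 made-up centroids of dimension 8

# Gram-based formula: ||c_i - c_i'||**2 = G[i,i] + G[i',i'] - 2*G[i,i']
G = C @ C.T
D = np.diag(G)
gram_dist = D[:, None] + D[None, :] - 2 * G  # broadcasting replaces np.tile

# Brute force: explicit differences between all pairs of rows
brute = ((C[:, None, :] - C[None, :, :]) ** 2).sum(axis=-1)

print(np.allclose(gram_dist, brute))  # prints: True
```

Because of rounding, the Gram-based values can come out very slightly negative where the true distance is zero; this is harmless for ranking, but clipping with `np.maximum(gram_dist, 0)` would be needed before taking square roots.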

#### 1.4.2. Instantiation

This takes a while (~1 min on a fast computer), because the PQ is performed at this stage.

In [10]:
nnf = NearestNeighborsFinder(dataset.svd)
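
The double loop in `PQ_distances()` runs in pure Python, which is slow for 11314 documents. The same table lookups can be expressed with NumPy fancy indexing, as in the following sketch. Both functions below are stand-alone rewrites for illustration (they take the labels and the distance table as arguments instead of reading them from `self`), and they are compared on made-up data.

```python
import numpy as np

def pq_distances_loop(labels, centroid_distances, i):
    '''Same logic as the method above, as a stand-alone function.'''
    n, r = labels.shape
    out = np.zeros(n)
    for h in range(n):
        for j in range(r):
            out[h] += centroid_distances[j, labels[h, j], labels[i, j]]
    return out

def pq_distances_vectorized(labels, centroid_distances, i):
    '''Hypothetical vectorized variant using fancy indexing.'''
    j = np.arange(labels.shape[1])
    # Entry [h, j] of the lookup is
    # centroid_distances[j, labels[h, j], labels[i, j]].
    return centroid_distances[j, labels, labels[i]].sum(axis=1)

# Made-up PQ output: 6 items, 3 sub-clusterings, 4 clusters each
rng = np.random.default_rng(0)
n, r, k = 6, 3, 4
labels = rng.integers(0, k, size=(n, r), dtype=np.uint8)
centroid_distances = rng.random((r, k, k))

a = pq_distances_loop(labels, centroid_distances, i=2)
b = pq_distances_vectorized(labels, centroid_distances, i=2)
print(np.allclose(a, b))  # prints: True
```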

### 1.5. A function searching nearest neighbors in the 20Newsgroups dataset

While the previous two classes work independently, searching for the nearest neighbors of a document in the 20Newsgroups dataset, as in Q2 and Q4, uses both classes through the function `nearest_neighbors()` defined below. 
Please read its docstring carefully to understand the format of our
answers to Q2 and Q4. Note in particular that we return 9-tuples instead of 8-tuples, because
we list the query document itself before its 8 nearest neighbors.

In [11]:
def nearest_neighbors(n_neighbors, dataset, nnf, distance):
  '''
  Parameters:
  * n_neighbors: the number of neighbors we want.
  * dataset: instance of the My20Newsgroups class.
  * nnf: instance of the NearestNeighborsFinder class, typically:
  nnf = NearestNeighborsFinder(dataset.svd)
  * distance: 
  must be either the string 'cos' for the cosine distance as in Q2,
  or the string 'PQ' for the distance given by PQ as in Q4.

  Return a dictionary of (key, value) pairs, where:
  * each key is the name of a topic (there are 20 of them),
  * each value is a (1 + n_neighbors)-tuple, where:
    - the 0th component deals with the randomly chosen document corresponding
    to this topic,
    - the next component deals with its closest neighbor, w.r.t. 'distance',
    - the next component deals with the second closest neighbor,
    and so on;
    - each component is itself a triple (index, topic, subject)
    corresponding to the relevant document.
  '''
  
  if distance not in ['cos', 'PQ']:
    raise ValueError("distance must be either 'cos' or 'PQ'")    
  
  dict_ = {}
  for c in range(N_TOPICS):
    doc = dataset.sample[c]
    key = dataset.topic(doc)
    
    if distance == 'cos':
      dist = nnf.cosine_distances(doc)
    else:
      dist = nnf.PQ_distances(doc)
    
    nn_idx = nnf.nearest_neighbors(n_neighbors, dist)
    value = tuple(dataset.its(i) for i in nn_idx)
    
    dict_[key] = value
  
  return dict_
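
A side note on the selection step: `NearestNeighborsFinder.nearest_neighbors()` sorts all the indexes even though only the first few are needed. A partial selection with `np.argpartition` is a common alternative. The helper `k_smallest_indices` below is a hypothetical sketch, not the method used above, shown on a made-up distance array.

```python
import numpy as np

def k_smallest_indices(dist, k):
    '''
    Hypothetical helper: indexes of the k smallest entries of dist,
    ordered from smallest to largest.
    '''
    dist = np.asarray(dist)
    k = min(k, dist.shape[0])
    # argpartition puts the k smallest entries first, in arbitrary order...
    candidates = np.argpartition(dist, k - 1)[:k]
    # ...so only these k candidates need to be sorted.
    order = np.argsort(dist[candidates], kind='stable')
    return tuple(int(c) for c in candidates[order])

dist = np.array([0.5, 0.0, 0.9, 0.1, 0.7])
print(k_smallest_indices(dist, 3))  # prints: (1, 3, 0)
```

When several distances are exactly equal (which can happen with PQ, since many documents may share a code), the two approaches may break ties differently.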

## 2. Answers

### 2.1. Question 1

**Our answer to this question is just the** `dataset.svd` **attribute.**

Note that the output dimension, namely `dataset.n_components`, is already set to the default value, which is 128. 
We merely check here that the output shape is correct: it must be `(11314, 128)`, as there are 11314 documents and we keep 128 features.

In [12]:
dataset.svd.shape

(11314, 128)

### 2.2. Question 2

#### 2.2.1. Sampling

In `dataset.sample`, we chose at random 1 document for each topic.
For each of them, let's print its `.its()` (=index, topic, subject):

In [13]:
for i in dataset.sample:
  print(dataset.its(i))

(8541, 'alt.atheism', 'Re: Genocide is Caused by Theism : Evidence?')
(6717, 'comp.graphics', 'Re: Rumours about 3DO ???')
(2557, 'comp.os.ms-windows.misc', 'Re: Utility for updating Win.ini and system.ini')
(6479, 'comp.sys.ibm.pc.hardware', 'ion: na')
(3555, 'comp.sys.mac.hardware', "Re: Centris 610 Video Problem - I'm having it also!")
(6945, 'comp.windows.x', 'Re: Animation with XPutImage()?')
(9931, 'misc.forsale', 'Onkyo 55w/ch integrated amp forsale:')
(6580, 'rec.autos', 'Re: New Alarm Proposal')
(4470, 'rec.motorcycles', 'Re: Observation re: helmets')
(10350, 'rec.sport.baseball', "I've found the secret!")
(5713, 'rec.sport.hockey', 'My Predictions of a classic playoff year!')
(8983, 'sci.crypt', 'Clipper chip -- technical details')
(3758, 'sci.electronics', 'Re: Illusion')
(1329, 'sci.med', 'Re: Migraines')
(993, 'sci.space', 'Re: Boom!  Whoosh......')
(5004, 'soc.religion.christian', 'Re: The arrogance of Christians')
(8281, 'talk.politics.guns', 'S4@psuvm.psu.edu>')
(10377,

#### 2.2.2. Nearest neighbors

We just need to call `nearest_neighbors()` (see §1.5) with the parameter `distance='cos'`:

In [14]:
nearest_neighbors(n_neighbors=8, dataset=dataset, nnf=nnf, distance='cos')

{'alt.atheism': ((8541,
   'alt.atheism',
   'Re: Genocide is Caused by Theism : Evidence?'),
  (8071, 'alt.atheism', 'Re: islamic genocide'),
  (3289, 'talk.religion.misc', 'Re: Albert Sabin'),
  (8900, 'alt.atheism', 'Re: free moral agency and Jeff Clark'),
  (5081, 'alt.atheism', 'Re: An Anecdote about Islam'),
  (8488, 'alt.atheism', 'Re: free moral agency and Jeff Clark'),
  (260, 'alt.atheism', 'Re: An Anecdote about Islam'),
  (9694, 'alt.atheism', 'Re: ISLAM: a clearer view'),
  (1534, 'alt.atheism', 'Re: Wholly Babble (Was Re: free moral agency)')),
 'comp.graphics': ((6717, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (5761, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (6142, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (10663, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (5328, 'comp.graphics', 'I need to make my VGA do shades.'),
  (11205, 'comp.sys.mac.hardware', 'Re: Disappointed by La Cie'),
  (9647, 'comp.windows.x', 'Re: Animation with XPutImage()?'),
 

### 2.3. Question 3

Recall from the docstring of the `NearestNeighborsFinder` class (§1.4) that **the reduction we want here, for all the documents in the dataset, is just its** `.PQ_labels` **attribute.** As far as our sample is concerned, we just need to extract the relevant 20 rows from it:

In [15]:
for i in dataset.sample:
  print("Its:", dataset.its(i), ", labels:", nnf.PQ_labels[i, :])

Its: (8541, 'alt.atheism', 'Re: Genocide is Caused by Theism : Evidence?') , labels: [242 208 238 114  36  45 180  55 224 176  26 115  16  13  52  53]
Its: (6717, 'comp.graphics', 'Re: Rumours about 3DO ???') , labels: [142 222 225 169 214  77  92  91 243 165  75 215 107   7  19   2]
Its: (2557, 'comp.os.ms-windows.misc', 'Re: Utility for updating Win.ini and system.ini') , labels: [ 13  85 164 145 107  39 149  49  88 223  85 146 121   6 204 218]
Its: (6479, 'comp.sys.ibm.pc.hardware', 'ion: na') , labels: [200  36 156 248  15  48 205 187 247 187 177 247 227 182   9 227]
Its: (3555, 'comp.sys.mac.hardware', "Re: Centris 610 Video Problem - I'm having it also!") , labels: [237  59 253  24 209  35 125 137  21 153 142  77  41 238 208  38]
Its: (6945, 'comp.windows.x', 'Re: Animation with XPutImage()?') , labels: [181 246 138 251 226 243 215 181 175  90  32  93  13  42 200 126]
Its: (9931, 'misc.forsale', 'Onkyo 55w/ch integrated amp forsale:') , labels: [123 158 251 171  31  14 229 129  9

### 2.4. Question 4

We just need to call `nearest_neighbors()` (see §1.5) with the parameter `distance='PQ'` (our sample is still the same):

In [16]:
nearest_neighbors(n_neighbors=8, dataset=dataset, nnf=nnf, distance='PQ')

{'alt.atheism': ((8541,
   'alt.atheism',
   'Re: Genocide is Caused by Theism : Evidence?'),
  (8071, 'alt.atheism', 'Re: islamic genocide'),
  (6651, 'alt.atheism', 'Re: The Inimitable Rushdie'),
  (3289, 'talk.religion.misc', 'Re: Albert Sabin'),
  (1785, 'sci.space', 'Re: Sunrise/ sunset times'),
  (252, 'alt.atheism', 'Re: Who Says the Apostles Were Tortured?'),
  (1534, 'alt.atheism', 'Re: Wholly Babble (Was Re: free moral agency)'),
  (3906, 'alt.atheism', "Re: A visit from the Jehovah's Witnesses"),
  (1236, 'misc.forsale', 'NEW AIRCRAFT TU-154M')),
 'comp.graphics': ((6717, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (6142, 'comp.graphics', 'Re: Rumours about 3DO ???'),
  (3306, 'sci.med', 'Re: Lasers for dermatologists'),
  (4501, 'sci.med', 'Re: Blood Glucose test strips'),
  (1038, 'sci.electronics', 'Homebuilt PAL (EPLD) programer?'),
  (5465,
   'comp.os.ms-windows.misc',
   'DOS6 - doublespace + stacker 3.0, is it okay?'),
  (9647, 'comp.windows.x', 'Re: Animation w