-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
tfidf.py
145 lines (120 loc) · 5.93 KB
/
tfidf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""Scikit learn interface for :class:`~gensim.models.tfidfmodel.TfidfModel`.
Follows scikit-learn API conventions to facilitate using gensim along with scikit-learn.
Examples
--------
>>> from gensim.test.utils import common_corpus, common_dictionary
>>> from gensim.sklearn_api import TfIdfTransformer
>>>
>>> # Transform the word counts inversely to their global frequency using the sklearn interface.
>>> model = TfIdfTransformer(dictionary=common_dictionary)
>>> tfidf_corpus = model.fit_transform(common_corpus)
"""
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError
from gensim.models import TfidfModel
import gensim
class TfIdfTransformer(TransformerMixin, BaseEstimator):
    """Base TfIdf module, wraps :class:`~gensim.models.tfidfmodel.TfidfModel`.

    For more information please have a look at `tf-idf <https://en.wikipedia.org/wiki/Tf%E2%80%93idf>`_.

    """

    def __init__(self, id2word=None, dictionary=None, wlocal=gensim.utils.identity,
                 wglobal=gensim.models.tfidfmodel.df2idf, normalize=True, smartirs="ntc",
                 pivot=None, slope=0.65):
        """

        Parameters
        ----------
        id2word : {dict, :class:`~gensim.corpora.Dictionary`}, optional
            Mapping from int id to word token, that was used for converting input data to bag of words format.
        dictionary : :class:`~gensim.corpora.Dictionary`, optional
            If specified it will be used to directly construct the inverse document frequency mapping.
        wlocal : function, optional
            Function for local weighting, default for `wlocal` is :func:`~gensim.utils.identity` which does nothing.
            Other options include :func:`math.sqrt`, :func:`math.log1p`, etc.
        wglobal : function, optional
            Function for global weighting, default is :func:`~gensim.models.tfidfmodel.df2idf`.
        normalize : bool, optional
            It dictates how the final transformed vectors will be normalized. `normalize=True` means set to unit length
            (default); `False` means don't normalize. You can also set `normalize` to your own function that accepts
            and returns a sparse vector.
        smartirs : str, optional
            SMART (System for the Mechanical Analysis and Retrieval of Text) Information Retrieval System,
            a mnemonic scheme for denoting tf-idf weighting variants in the vector space model.
            The mnemonic for representing a combination of weights takes the form XYZ,
            for example 'ntc', 'bpn' and so on, where the letters represents the term weighting of the document vector.

            Term frequency weighing:
                * `n` - natural,
                * `l` - logarithm,
                * `a` - augmented,
                * `b` - boolean,
                * `L` - log average.

            Document frequency weighting:
                * `n` - none,
                * `t` - idf,
                * `p` - prob idf.

            Document normalization:
                * `n` - none,
                * `c` - cosine.

            For more info, visit `"Wikipedia" <https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System>`_.
        pivot : float, optional
            It is the point around which the regular normalization curve is `tilted` to get the new pivoted
            normalization curve. In the paper `Amit Singhal, Chris Buckley, Mandar Mitra:
            "Pivoted Document Length Normalization" <http://singhal.info/pivoted-dln.pdf>`_ it is the point where the
            retrieval and relevance curves intersect.
            This parameter along with slope is used for pivoted document length normalization.
            Only when `pivot` is not None pivoted document length normalization will be applied else regular TfIdf
            is used.
        slope : float, optional
            It is the parameter required by pivoted document length normalization which determines the slope to which
            the `old normalization` can be tilted. This parameter only works when pivot is defined by user and is not
            None.

        """
        # The wrapped gensim model is created by `fit`; `None` marks the transformer as not fitted yet.
        self.gensim_model = None
        self.id2word = id2word
        self.dictionary = dictionary
        self.wlocal = wlocal
        self.wglobal = wglobal
        self.normalize = normalize
        self.smartirs = smartirs
        self.slope = slope
        self.pivot = pivot

    def fit(self, X, y=None):
        """Fit the model according to the given training data.

        Parameters
        ----------
        X : iterable of iterable of (int, int)
            Input corpus
        y : object, optional
            Not used by this method.

        Returns
        -------
        :class:`~gensim.sklearn_api.tfidf.TfIdfTransformer`
            The trained model.

        """
        self.gensim_model = TfidfModel(
            corpus=X, id2word=self.id2word, dictionary=self.dictionary, wlocal=self.wlocal,
            wglobal=self.wglobal, normalize=self.normalize, smartirs=self.smartirs,
            pivot=self.pivot, slope=self.slope
        )
        return self

    def transform(self, docs):
        """Get the tf-idf scores in BoW representation for `docs`

        Parameters
        ----------
        docs: {iterable of list of (int, number), list of (int, number)}
            Document or corpus in BoW format. A single document is recognised by its first element being a `tuple`;
            anything else is treated as a corpus. Empty input is treated as an empty corpus.

        Returns
        -------
        iterable of list (int, float) 2-tuples.
            The BOW representation of each document. Will have the same shape as `docs`.

        Raises
        ------
        :class:`sklearn.exceptions.NotFittedError`
            If `fit` has not been called yet.

        """
        if self.gensim_model is None:
            raise NotFittedError(
                "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
            )

        # One-shot iterables (e.g. generators) cannot be peeked at with `docs[0]`, so materialise them first.
        if not hasattr(docs, "__getitem__"):
            docs = list(docs)

        try:
            first = docs[0]
        except IndexError:
            # Nothing to transform: return an empty result instead of crashing on the peek.
            return []

        # A bare document (list of (id, value) tuples) is wrapped so that it is handled like a one-document corpus.
        if isinstance(first, tuple):
            docs = [docs]
        return [self.gensim_model[doc] for doc in docs]