Skip to content

Commit

Permalink
e
Browse files Browse the repository at this point in the history
  • Loading branch information
ym001 committed May 15, 2020
1 parent dba716e commit 219a1f7
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 35 deletions.
52 changes: 35 additions & 17 deletions Manteia/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,25 @@ class Dataset:
r"""
This is the class to give datasets.
Args:
dataset_name (:obj:`string`, optional, defaults to ''):
Name of the dataset.
Example::
* **name** - name of the dataset (str)
* **train** - load the dataset train Default: ‘True’.
* **test** - load the dataset test Default: ‘False’.
* **dev** - load the dataset dev Default: ‘False’.
* **description** - load description Default: ‘False’.
* **url** -
* **verbose** -
* **path** - Path to the data file.
.. code-block:: python
Attributes:
print('hello')
"""
def __init__(self,name='20newsgroups',path='./dataset',verbose=True):

def __init__(self,name='20newsgroups',train=True,test=False,dev=False,description=False,url=False,path='./dataset',verbose=True):
r"""
"""
self.name=name
self.path=path
self.verbose=verbose
Expand Down Expand Up @@ -91,6 +97,11 @@ def load(self):
self.load_Yelp_Review_Polarity()

def load_20newsgroups(self):
r"""
This is the function to give 20newsgroups datasets.
"""

if self.verbose:
print('Downloading 20newsgroups...')
#categorie = ['sci.crypt', 'sci.electronics','sci.med', 'sci.space']
Expand All @@ -102,7 +113,14 @@ def load_20newsgroups(self):
self.labels_train.append(categorie[twenty_train.target[i]])

def load_Yelp_Review_Polarity(self):

"""
Defines YelpReviewPolarity datasets.
The labels includes:
* 0 : Negative polarity.
* 1 : Positive polarity.
"""
self.path_dir = os.path.join(self.path,'yelp_review_polarity')
#!!!!!!!!!!!!!!!!!!!!
self.del_dir(self.path_dir)
Expand Down Expand Up @@ -158,7 +176,7 @@ def load_Yelp_Review_Full(self):
self.description+=row

def load_Yahoo_Answers(self):
"""
r"""
Example Yahoo_Answers::
from Manteia.Dataset import Dataset
Expand All @@ -171,7 +189,7 @@ def load_Yahoo_Answers(self):
"""
self.path_dir = os.path.join(self.path,'yahoo_answers')
#!!!!!!!!!!!!!!!!!!!!
self.del_dir(self.path_dir)
#self.del_dir(self.path_dir)
#!!!!!!!!!!!!!!!!!!!!
if not os.path.isdir(self.path_dir):
os.mkdir(self.path_dir)
Expand Down Expand Up @@ -249,7 +267,7 @@ def load_Amazon_Review_Polarity(self):

self.path_dir = os.path.join(self.path,'amazon_review_polarity')
#!!!!!!!!!!!!!!!!!!!!
self.del_dir(self.path_dir)
#self.del_dir(self.path_dir)
#!!!!!!!!!!!!!!!!!!!!
if not os.path.isdir(self.path_dir):
os.mkdir(self.path_dir)
Expand Down Expand Up @@ -307,7 +325,7 @@ def load_Amazon_Review_Full(self):
self.description+=row

def load_DBPedia(self):
"""
r"""
Example DBPedia::
from Manteia.Dataset import Dataset
Expand Down Expand Up @@ -452,7 +470,7 @@ def load_yelp(self):
print("\tCompleted!")

def load_drugscom(self):
"""
r"""
Example pubmed_rct20k::
from Manteia.Dataset import Dataset
Expand Down Expand Up @@ -560,7 +578,7 @@ def load_SST_B(self):
self.labels_test = df_test['label'].values

def load_pubmed_rct20k(self):
"""
r"""
Example pubmed_rct20k::
from Manteia.Dataset import Dataset
Expand Down
6 changes: 3 additions & 3 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,6 @@ def pad_sequence(sequence=None,MAX_SEQ_LEN=None,pad='post'):
def encode_label(labels,list_labels):

def label_int(label):
print(list_labels)
print(label)
if label in list_labels:
idx_label=list_labels.index(label)
return idx_label
Expand Down Expand Up @@ -626,7 +624,7 @@ def __call__(self, acc_validation , model,device_model):
def save_checkpoint(self, acc_validation, model,device_model):
'''Saves model when validation loss decrease.'''
if self.verbose:
print('Validation accuracy increased ({:.6f} --> {acc_validation:.6f}). Saving model ...'.format(self.acc_validation_min))
print('Validation accuracy increased ({:.6f} --> {:.6f}). Saving model ...'.format(self.acc_validation_min,acc_validation))
if not os.path.isdir(self.path):
# define the name of the directory to be created
try:
Expand All @@ -638,7 +636,9 @@ def save_checkpoint(self, acc_validation, model,device_model):
#save by torch
device = torch.device('cpu')
model.to(device)
print(type(model))
torch.save(model.module.state_dict(),self.path+'state_dict_validation.pt')

model.to(device_model)
#save by transformer
#model.save_pretrained(self.path)
Expand Down
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ You can install it with pip :
Example of use Classification :


from Manteia.Classification import Classification
documents=['a text','text b']
labels=['a','b']'
Classification(model_name ='roberta',documents,labels,process=True)
from Manteia.Classification import Classification
from Manteia.Model import Model
documents = ['What should you do before criticizing Pac-Man? WAKA WAKA WAKA mile in his shoe.','What did Arnold Schwarzenegger say at the abortion clinic? Hasta last vista, baby.',]
labels = ['funny','not funny']
model = Model(model_name ='roberta')
cl=Classification(model,documents,labels,process_classif=True)


Example of use Generation :
Expand Down
Binary file modified docs/_build/doctrees/Dataset.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
43 changes: 33 additions & 10 deletions docs/_build/html/Dataset.html
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,27 @@
<span id="dataset"></span><h1>Dataset<a class="headerlink" href="#module-Manteia.Dataset" title="Permalink to this headline"></a></h1>
<span class="target" id="module-Dataset"></span><dl class="py class">
<dt id="Manteia.Dataset.Dataset">
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Dataset.</code><code class="sig-name descname">Dataset</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">name</span><span class="o">=</span><span class="default_value">'20newsgroups'</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./dataset'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Dataset.</code><code class="sig-name descname">Dataset</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">name</span><span class="o">=</span><span class="default_value">'20newsgroups'</span></em>, <em class="sig-param"><span class="n">train</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">test</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">dev</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">description</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">url</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./dataset'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset" title="Permalink to this definition"></a></dt>
<dd><p>This is the class to give datasets.</p>
<p>Args:</p>
<blockquote>
<div><dl class="simple">
<dt>dataset_name (<code class="xref py py-obj docutils literal notranslate"><span class="pre">string</span></code>, optional, defaults to ‘’):</dt><dd><p>Name of the dataset.</p>
</dd>
</dl>
</div></blockquote>
<p>Example:</p>
<p>Attributes:</p>
<ul class="simple">
<li><p><strong>name</strong> - name of the dataset (str)</p></li>
<li><p><strong>train</strong> - load the dataset train Default: ‘True’.</p></li>
<li><p><strong>test</strong> - load the dataset test Default: ‘False’.</p></li>
<li><p><strong>dev</strong> - load the dataset dev Default: ‘False’.</p></li>
<li><p><strong>description</strong> - load description Default: ‘False’.</p></li>
<li><p><strong>url</strong> -</p></li>
<li><p><strong>verbose</strong> -</p></li>
<li><p><strong>path</strong> - Path to the data file.</p></li>
</ul>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="s1">&#39;hello&#39;</span><span class="p">)</span>
</pre></div>
</div>
<dl class="py method">
<dt id="Manteia.Dataset.Dataset.load_20newsgroups">
<code class="sig-name descname">load_20newsgroups</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset.load_20newsgroups" title="Permalink to this definition"></a></dt>
<dd><p>This is the function to give 20newsgroups datasets.</p>
</dd></dl>

<dl class="py method">
<dt id="Manteia.Dataset.Dataset.load_DBPedia">
<code class="sig-name descname">load_DBPedia</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset.load_DBPedia" title="Permalink to this definition"></a></dt>
Expand Down Expand Up @@ -209,6 +219,19 @@
</div>
</dd></dl>

<dl class="py method">
<dt id="Manteia.Dataset.Dataset.load_Yelp_Review_Polarity">
<code class="sig-name descname">load_Yelp_Review_Polarity</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset.load_Yelp_Review_Polarity" title="Permalink to this definition"></a></dt>
<dd><dl class="simple">
<dt>Defines YelpReviewPolarity datasets.</dt><dd><p>The labels includes:</p>
<ul class="simple">
<li><p>0 : Negative polarity.</p></li>
<li><p>1 : Positive polarity.</p></li>
</ul>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="Manteia.Dataset.Dataset.load_drugscom">
<code class="sig-name descname">load_drugscom</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset.load_drugscom" title="Permalink to this definition"></a></dt>
Expand Down
4 changes: 4 additions & 0 deletions docs/_build/html/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ <h2 id="G">G</h2>
<h2 id="L">L</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_20newsgroups">load_20newsgroups() (Manteia.Dataset.Dataset method)</a>
</li>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_DBPedia">load_DBPedia() (Manteia.Dataset.Dataset method)</a>
</li>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_drugscom">load_drugscom() (Manteia.Dataset.Dataset method)</a>
Expand All @@ -298,6 +300,8 @@ <h2 id="L">L</h2>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_pubmed_rct20k">load_pubmed_rct20k() (Manteia.Dataset.Dataset method)</a>
</li>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_Yahoo_Answers">load_Yahoo_Answers() (Manteia.Dataset.Dataset method)</a>
</li>
<li><a href="Dataset.html#Manteia.Dataset.Dataset.load_Yelp_Review_Polarity">load_Yelp_Review_Polarity() (Manteia.Dataset.Dataset method)</a>
</li>
</ul></td>
</tr></table>
Expand Down
Binary file modified docs/_build/html/objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/_build/html/searchindex.js

Large diffs are not rendered by default.

0 comments on commit 219a1f7

Please sign in to comment.