Skip to content

Commit

Permalink
Mantéïa
Browse files Browse the repository at this point in the history
  • Loading branch information
ym001 committed May 28, 2020
1 parent 644f5f1 commit 0909c2f
Show file tree
Hide file tree
Showing 14 changed files with 125 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Exemples/exemple_Classification1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

labels = ['funny','not funny']

model = Model(model_name ='bert')
model = Model(model_name ='bart')
cl=Classification(model,documents,labels,process_classif=True)
41 changes: 29 additions & 12 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self,model_name ='bert',model_type=None,task='classification',num_l
self.epochs = epochs
self.path = path
self.verbose = verbose
self.device = None
self.history = {}
self.history['loss'] = []
self.history['step'] = []
Expand Down Expand Up @@ -303,8 +304,8 @@ def fit(self,train_dataloader,validation_dataloader):
if self.early_stopping:
self.es=EarlyStopping(path=self.path)

if self.n_gpu > 1:
self.model = torch.nn.DataParallel(self.model)
#if self.n_gpu > 1:
# self.model = torch.nn.DataParallel(self.model)

loss_values = []

Expand Down Expand Up @@ -334,10 +335,10 @@ def fit(self,train_dataloader,validation_dataloader):

self.model.zero_grad()

if self.model_name != 'distilbert':
outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
if self.model_name != 'distilbert' and self.model_name != 'bart':
outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
else:
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)


#loss = outputs[0]
Expand Down Expand Up @@ -384,7 +385,7 @@ def fit(self,train_dataloader,validation_dataloader):

with torch.no_grad():

if self.model_name != 'distilbert':
if self.model_name != 'distilbert' and self.model_name != 'bart':
outputs = self.model(b_input_ids, token_type_ids=None,attention_mask=b_input_mask)
else:
outputs = self.model(b_input_ids,attention_mask=b_input_mask)
Expand Down Expand Up @@ -414,24 +415,34 @@ def fit(self,train_dataloader,validation_dataloader):
if self.verbose==True:
print("Early stopping")
break
#a la fin de l'entrainement on charge le meilleur model.
if self.early_stopping:
self.model.load_state_dict(torch.load(os.path.join(self.path,'state_dict_validation.pt')))

if self.verbose==True:
print("")
print("Training complete!")
"""
p_type='class' or 'probability' or 'logits'
"""
def predict(self,predict_dataloader,p_type='class'):
'''
if self.early_stopping:
#by torch
#pour charger uniquement la classe du modèle!
print('test')
self.load_type()
print('test')
self.load_class()
self.model.load_state_dict(torch.load(self.path+'state_dict_validation.pt'))
print('test')
self.model.load_state_dict(torch.load(os.path.join(self.path,'state_dict_validation.pt')))
print('test')
#by transformer
#self.model.from_pretrained(self.path)
if self.verbose==True:
print('loading model early...')
'''
#self.model.cuda()
self.model.to(self.device)

Expand Down Expand Up @@ -566,12 +577,16 @@ def save(self,file_name):
else:
print ("Successfully created the directory %s " % self.path)
self.model.to(torch.device('cpu'))
torch.save(self.model.module.state_dict(),os.path.join(self.path,file_name))
#torch.save(self.model.module.state_dict(),os.path.join(self.path,file_name))
torch.save(self.model.state_dict(),os.path.join(self.path,file_name))
self.model.to(self.device)

def load(self,file_name):
self.load_type()
self.load_class()
self.model.load_state_dict(torch.load(os.path.join(self.path,file_name)))
if self.device is None:
self.devices()
self.model.to(self.device)

def choose_from_top(probs, n=5):
Expand Down Expand Up @@ -734,8 +749,9 @@ def __init__(self, patience=2, delta=0,path=None, verbose=True):
self.acc_validation_min = 0
self.delta = delta
self.path=path
if os.path.isfile(self.path+'state_dict_validation.pt'):
os. remove(self.path+'state_dict_validation.pt')

if os.path.isfile(os.path.join(self.path,'state_dict_validation.pt')):
os. remove(os.path.join(self.path,'state_dict_validation.pt'))

def __call__(self, acc_validation , model,device_model):

Expand Down Expand Up @@ -773,7 +789,8 @@ def save_checkpoint(self, acc_validation, model,device_model):
device = torch.device('cpu')
model.to(device)
print(type(model))
torch.save(model.module.state_dict(),self.path+'state_dict_validation.pt')
#torch.save(model.module.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))
torch.save(model.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))

model.to(device_model)
#save by transformer
Expand Down
21 changes: 11 additions & 10 deletions Manteia/Summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ class Summarize:
"""
def __init__(self,model=None,documents = [],verbose=True):
def __init__(self,model=None,documents = [],verbose=True,process_summarize=False):

self.process_classif = process_classif
self.verbose = verbose
self.model = model
self.documents_train = documents_train
self.labels_train = labels_train
self.documents_test = documents_test
self.labels_test = labels_test
self.process_summarize = process_summarize
self.verbose = verbose
self.model = model
self.documents = documents

self.load_model()
inputs=self.process_text()
print(self.predict(inputs))
summary_ids = self.model.model.generate(inputs['input_ids'], num_beams=6, max_length=8, early_stopping=True)
print([self.model.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

#print(self.predict(inputs))

def load_model(self):
"""
Expand All @@ -66,6 +66,7 @@ def load_model(self):
"""
if self.model is None:
self.model = Model(model_name ='bart',model_type='bart-large-cnn',task='summarize')

self.model.load_type()
self.model.load_tokenizer()
self.model.load_class()
Expand All @@ -80,7 +81,7 @@ def process_text(self):
from Manteia.Summarize import Summarize
"""
inputs = self.model.tokenizer.batch_encode_plus(self.documents, max_length=1024, return_tensors='pt')
inputs = self.model.tokenizer.batch_encode_plus(self.documents, max_length=1024,pad_to_max_length=True, return_tensors='pt')
return inputs

def predict(self,inputs):
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ You can install it with pip :
For use with GPU and cuda we recommend the use of [Anaconda](https://www.anaconda.com/open-source) :

     __conda create -n manteia_env python=3.7__

     __conda activate manteia_env__

     __conda install pytorch__

     __pip install manteia__

Example of use Classification :
Expand Down
Binary file modified docs/_build/doctrees/Model.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/_build/doctrees/index.doctree
Binary file not shown.
25 changes: 24 additions & 1 deletion docs/_build/html/Model.html
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@

<dl class="py class">
<dt id="Manteia.Model.Model">
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Model.</code><code class="sig-name descname">Model</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">model_name</span><span class="o">=</span><span class="default_value">'bert'</span></em>, <em class="sig-param"><span class="n">model_type</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">num_labels</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">epochs</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">MAX_SEQ_LEN</span><span class="o">=</span><span class="default_value">128</span></em>, <em class="sig-param"><span class="n">early_stopping</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./model'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Model.Model" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Model.</code><code class="sig-name descname">Model</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">model_name</span><span class="o">=</span><span class="default_value">'bert'</span></em>, <em class="sig-param"><span class="n">model_type</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">task</span><span class="o">=</span><span class="default_value">'classification'</span></em>, <em class="sig-param"><span class="n">num_labels</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">epochs</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">MAX_SEQ_LEN</span><span class="o">=</span><span class="default_value">128</span></em>, <em class="sig-param"><span class="n">early_stopping</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./model'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Model.Model" title="Permalink to this definition"></a></dt>
<dd><p>This is the class to construct model.</p>
<p>Args:</p>
<blockquote>
Expand Down Expand Up @@ -219,6 +219,29 @@
</pre></div>
</div>
<p>Attributes:</p>
<dl class="py method">
<dt id="Manteia.Model.Model.predict">
<code class="sig-name descname">predict</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">predict_dataloader</span></em>, <em class="sig-param"><span class="n">p_type</span><span class="o">=</span><span class="default_value">'class'</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Model.Model.predict" title="Permalink to this definition"></a></dt>
<dd><dl>
<dt>if self.early_stopping:</dt><dd><p>#by torch
#pour charger uniquement la classe du modèle!
print(‘test’)
self.load_type()
print(‘test’)
self.load_class()
print(‘test’)
self.model.load_state_dict(torch.load(os.path.join(self.path,’state_dict_validation.pt’)))
print(‘test’)</p>
<p>#by transformer
#self.model.from_pretrained(self.path)
if self.verbose==True:</p>
<blockquote>
<div><p>print(‘loading model early…’)</p>
</div></blockquote>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py function">
Expand Down
21 changes: 21 additions & 0 deletions docs/_build/html/_sources/index.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@ of your results on several dataset ( 20newsgroups, Agnews Amazon Review Full, Am
| .. image:: images/train.png | .. image:: images/boxplot.png|
+---------------------------------+------------------------------+

Installation
------------

You can install it with pip :

     pip install Manteia

.. _Anaconda: https://www.anaconda.com/open-source>

For use with GPU and cuda we recommend the use of `Anaconda`_. :

     conda create -n manteia_env python=3.7

     conda activate manteia_env

     conda install pytorch

     pip install manteia



Classes
=======

Expand Down
4 changes: 4 additions & 0 deletions docs/_build/html/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,11 @@ <h2 id="P">P</h2>
<li><a href="Visualisation.html#Manteia.Visualisation.Visualisation.plot_boxplot">plot_boxplot() (Manteia.Visualisation.Visualisation method)</a>
</li>
<li><a href="Classification.html#Manteia.Classification.Classification.predict">predict() (Manteia.Classification.Classification method)</a>

<ul>
<li><a href="Model.html#Manteia.Model.Model.predict">(Manteia.Model.Model method)</a>
</li>
</ul></li>
<li>
Preprocess

Expand Down
10 changes: 10 additions & 0 deletions docs/_build/html/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ <h1>Simple Documentation Tutorial : Mantéïa<a class="headerlink" href="#simple
</tr>
</tbody>
</table>
<div class="section" id="installation">
<h2>Installation<a class="headerlink" href="#installation" title="Permalink to this headline"></a></h2>
<p>You can install it with pip :</p>
<p>     pip install Manteia</p>
<p>For use with GPU and cuda we recommend the use of <a class="reference external" href="https://www.anaconda.com/open-source&gt;">Anaconda</a>. :</p>
<p>     conda create -n manteia_env python=3.7</p>
<p>     conda activate manteia_env</p>
<p>     conda install pytorch</p>
<p>     pip install manteia</p>
</div>
</div>
<div class="section" id="classes">
<h1>Classes<a class="headerlink" href="#classes" title="Permalink to this headline"></a></h1>
Expand Down
Binary file modified docs/_build/html/objects.inv
Binary file not shown.

0 comments on commit 0909c2f

Please sign in to comment.