Skip to content

Commit

Permalink
mantéïa
Browse files Browse the repository at this point in the history
  • Loading branch information
ym001 committed May 18, 2020
1 parent fe71557 commit 3cc2d3d
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 47 deletions.
45 changes: 21 additions & 24 deletions Exemples/exemple_Generation.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# exemple_Data.py
#
# Copyright 2020 Yves <yves@mercadier>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
#
#

from Manteia.Generation import Generation

from Manteia.Dataset import Dataset
from Manteia.Model import *

def main(args):

Generation(seed='What do you do if a bird shits on your car?')

ds=Dataset('Short_Jokes')

model = Model(model_name ='gpt2-medium')
text_loader = Create_DataLoader_generation(ds.documents_train[:3000])
model.load_tokenizer()
model.load_class()
model.devices()
model.configuration(text_loader)

gn=Generation(model)

gn.model.fit_generation(text_loader)
output = model.predict_generation('What did you expect ?')
output_text = decode_text(output,model.tokenizer)
print(output_text)

return 0

if __name__ == '__main__':
Expand Down
46 changes: 31 additions & 15 deletions Manteia/Generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import datetime
import gc
############
from .Model import *
from .Preprocess import Preprocess
from Manteia.Model import *
from Manteia.Preprocess import Preprocess

class Generation:
r"""
Expand All @@ -34,22 +34,38 @@ class Generation:
Example::
from Manteia.Generation import Generation
Generation(seed='What do you do if a bird shits on your car?')
Attributes:
"""
def __init__(self,model_name ='gpt2-medium',documents = None,seed = None):
from Manteia.Dataset import Dataset
from Manteia.Model import *
ds=Dataset('Short_Jokes')
model = Model(model_name ='gpt2-medium')
text_loader = Create_DataLoader_generation(ds.documents_train[:3000])
model.load_tokenizer()
model.load_class()
model.devices()
model.configuration(text_loader)
gn=Generation(model)
gn.model.fit_generation(text_loader)
output = model.predict_generation('What did you expect ?')
output_text = decode_text(output,model.tokenizer)
print(output_text)
"""
def __init__(self,model = None,documents = None,seed = None):

model = Model(model_name =model_name)
model.load()
model.BATCH_SIZE = 16
model.EPOCHS = 10
model.LEARNING_RATE = 3e-5
model.WARMUP_STEPS = 500
model.MAX_SEQ_LEN = 400
if model is None:self.model = Model(model_name ='gpt2-medium')
else : self.model=model
#model.load()
self.model.BATCH_SIZE = 16
self.model.EPOCHS = 2
self.model.LEARNING_RATE = 3e-5
self.model.WARMUP_STEPS = 500
self.model.MAX_SEQ_LEN = 400
if documents!=None:
text_loader = Create_DataLoader_generation(documents)
model.fit_generation(text_loader)
Expand Down
10 changes: 6 additions & 4 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def predict(self,predict_dataloader,p_type='class'):
return predictions

def fit_generation(self,text_loader):
self.model.to(self.device)

self.model.train()
optimizer = AdamW(self.model.parameters(), lr=self.LEARNING_RATE)
Expand All @@ -362,7 +363,7 @@ def fit_generation(self,text_loader):
if self.verbose==True:
print('EPOCH :'+str(epoch ))

for idx,text in enumerate(text_loader):
for step,text in enumerate(text_loader):

#################### "Fit as many joke sequences into MAX_SEQ_LEN sequence as possible" logic start ####
joke_tens = torch.tensor(self.tokenizer.encode(text[0])).unsqueeze(0).to(self.device)
Expand All @@ -386,25 +387,26 @@ def fit_generation(self,text_loader):
continue
################## Sequence ready, process it trough the model ##################

outputs = model(work_jokes_tens, labels=work_jokes_tens)
outputs = self.model(work_jokes_tens, labels=work_jokes_tens)
loss, logits = outputs[:2]
loss.backward()
sum_loss = sum_loss + loss.detach().data

proc_seq_count = proc_seq_count + 1
if proc_seq_count == BATCH_SIZE:
if proc_seq_count == self.batch_size:
proc_seq_count = 0
batch_count += 1
optimizer.step()
scheduler.step()
optimizer.zero_grad()
model.zero_grad()
self.model.zero_grad()

if batch_count == 100:
if self.verbose==True:
print("sum loss :"+str(sum_loss))
batch_count = 0
sum_loss = 0.0
progress(count=step+1, total=len(text_loader))


def predict_generation(self,seed):
Expand Down
Binary file modified docs/_build/doctrees/Generation.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
21 changes: 18 additions & 3 deletions docs/_build/html/Generation.html
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
<span id="generation"></span><h1>Generation<a class="headerlink" href="#module-Manteia.Generation" title="Permalink to this headline"></a></h1>
<span class="target" id="module-Generation"></span><dl class="py class">
<dt id="Manteia.Generation.Generation">
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Generation.</code><code class="sig-name descname">Generation</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">model_name</span><span class="o">=</span><span class="default_value">'gpt2-medium'</span></em>, <em class="sig-param"><span class="n">documents</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">seed</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Generation.Generation" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Generation.</code><code class="sig-name descname">Generation</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">model</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">documents</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">seed</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Generation.Generation" title="Permalink to this definition"></a></dt>
<dd><p>This is the class to gnerate text in categorie a NLP task.</p>
<p>Args:</p>
<blockquote>
Expand All @@ -183,11 +183,26 @@
</div></blockquote>
<p>Example:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">Manteia.Generation</span> <span class="kn">import</span> <span class="n">Generation</span>
<span class="kn">from</span> <span class="nn">Manteia.Dataset</span> <span class="kn">import</span> <span class="n">Dataset</span>
<span class="kn">from</span> <span class="nn">Manteia.Model</span> <span class="kn">import</span> <span class="o">*</span>

<span class="n">Generation</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="s1">&#39;What do you do if a bird shits on your car?&#39;</span><span class="p">)</span>
<span class="n">ds</span><span class="o">=</span><span class="n">Dataset</span><span class="p">(</span><span class="s1">&#39;Short_Jokes&#39;</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="n">model_name</span> <span class="o">=</span><span class="s1">&#39;gpt2-medium&#39;</span><span class="p">)</span>
<span class="n">text_loader</span> <span class="o">=</span> <span class="n">Create_DataLoader_generation</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">documents_train</span><span class="p">[:</span><span class="mi">3000</span><span class="p">])</span>
<span class="n">model</span><span class="o">.</span><span class="n">load_tokenizer</span><span class="p">()</span>
<span class="n">model</span><span class="o">.</span><span class="n">load_class</span><span class="p">()</span>
<span class="n">model</span><span class="o">.</span><span class="n">devices</span><span class="p">()</span>
<span class="n">model</span><span class="o">.</span><span class="n">configuration</span><span class="p">(</span><span class="n">text_loader</span><span class="p">)</span>

<span class="n">gn</span><span class="o">=</span><span class="n">Generation</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>

<span class="n">gn</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">fit_generation</span><span class="p">(</span><span class="n">text_loader</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict_generation</span><span class="p">(</span><span class="s1">&#39;What did you expect ?&#39;</span><span class="p">)</span>
<span class="n">output_text</span> <span class="o">=</span> <span class="n">decode_text</span><span class="p">(</span><span class="n">output</span><span class="p">,</span><span class="n">model</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_text</span><span class="p">)</span>
</pre></div>
</div>
<p>Attributes:</p>
</dd></dl>

</div>
Expand Down
2 changes: 1 addition & 1 deletion docs/_build/html/searchindex.js

Large diffs are not rendered by default.

File renamed without changes.

0 comments on commit 3cc2d3d

Please sign in to comment.