Skip to content

Commit

Permalink
Mantéïa
Browse files: browse the repository at this point in the history
  • Loading branch information
ym001 committed May 19, 2020
1 parent 3cc2d3d commit 3302b1d
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Exemples/exemple_Generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def main(args):
ds=Dataset('Short_Jokes')

model = Model(model_name ='gpt2-medium')
text_loader = Create_DataLoader_generation(ds.documents_train[:3000])
text_loader = Create_DataLoader_generation(ds.documents_train[:10000],batch_size=32)
model.load_tokenizer()
model.load_class()
model.devices()
Expand Down
4 changes: 2 additions & 2 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def Create_DataLoader_predict(inputs,masks,batch_size=16):
ss = SequentialSampler(td)
return DataLoader(td, sampler=ss, batch_size=batch_size)

def Create_DataLoader_generation(text):
return DataLoader(TextDataset(text), batch_size=1, shuffle=True)
def Create_DataLoader_generation(text,batch_size=16):
return DataLoader(TextDataset(text), batch_size=batch_size, shuffle=True)

class TextDataset():
def __init__(self,list_texts):
Expand Down
2 changes: 1 addition & 1 deletion Manteia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
reminiscent.
"""

__version__ = "0.0.19"
__version__ = "0.0.20"


from Manteia import Classification
Expand Down
23 changes: 20 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,26 @@ Example of use Classification :
Example of use for Generation:


from Manteia.Generation import Generation
Generation(seed='What do you do if a bird shits on your car?')
If you're a car owner, you're supposed to be able to call the police and have them take the bird off the car.
from Manteia.Generation import Generation
from Manteia.Dataset import Dataset
from Manteia.Model import *

ds=Dataset('Short_Jokes')

model = Model(model_name ='gpt2-medium')
text_loader = Create_DataLoader_generation(ds.documents_train[:10000],batch_size=32)
model.load_tokenizer()
model.load_class()
model.devices()
model.configuration(text_loader)

gn=Generation(model)

gn.model.fit_generation(text_loader)
output = model.predict_generation('What did you expect ?')
output_text = decode_text(output,model.tokenizer)
print(output_text)


[Documentation](https://manteia.readthedocs.io/en/latest/#)
[Pypi](https://pypi.org/project/Manteia/)
Expand Down
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
120 changes: 120 additions & 0 deletions notebook/notebook_Manteia_classification1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading roberta tokenizer...\n",
"Loading roberta class...\n",
"There are 2 GPU(s) available.\n",
"We will use the GPU: GeForce RTX 2080 Ti\n",
"\n",
"======== Epoch 1 / 4 ========\n",
"Training :\n",
" \n",
" Average training loss: 0.67\n",
" Training epoch took: 0:00:00\n",
"Validation :\n",
" \n",
" Accuracy: 0.00\n",
" Validation took: 0:00:00\n",
"\n",
"======== Epoch 2 / 4 ========\n",
"Training :\n",
" \n",
" Average training loss: 0.68\n",
" Training epoch took: 0:00:00\n",
"Validation :\n",
" \n",
" Accuracy: 0.00\n",
" Validation took: 0:00:00\n",
"\n",
"======== Epoch 3 / 4 ========\n",
"Training :\n",
" \n",
" Average training loss: 0.66\n",
" Training epoch took: 0:00:00\n",
"Validation :\n",
" \n",
" Accuracy: 0.00\n",
" Validation took: 0:00:00\n",
"\n",
"======== Epoch 4 / 4 ========\n",
"Training :\n",
" \n",
" Average training loss: 0.63\n",
" Training epoch took: 0:00:00\n",
"Validation :\n",
" \n",
" Accuracy: 0.00\n",
" Validation took: 0:00:00\n",
"\n",
"Training complete!\n",
"['not funny', 'not funny']\n"
]
}
],
"source": [
"from Manteia.Classification import Classification \n",
"from Manteia.Model import Model \n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"documents = [\n",
"\t\t\t' !?? What do you call a potato in space? Spudnik:::13 ;; // ',\n",
"\t\t\t'What should you do before criticizing Pac-Man? WAKA WAKA WAKA mile in his shoe.',\n",
"\t\t\t'What did Arnold Schwarzenegger say at the abortion clinic? Hasta last vista, baby.',\n",
"\t\t\t'Why do you never see elephants hiding in trees? \\'Cause they are freaking good at it',\n",
"\t\t\t'My son just got a tattoo of a heart, a spade, a club, and a diamond, all without my permission. I guess I\\'ll deal with him later.',\n",
"\t\t\t'Mom: \"Do you want this?\" Me: \"No.\" Mom: \"Ok I\\'ll give it to your brother.\" Me: \"No I want it.\"',\n",
"\t\t\t'Ibuprofen is my favorite headache medicine that also sounds like a reggae professor.',\n",
"\t\t\t'INTERVIEWER: Why do you want to work here? ME: *crumbs tumbling from my mouth* Oh, I don\\'t. I was just walking by and saw you had donuts.',\n",
"\t\t\t'I\\'ve struggled for years to be above the influence... But I\\'ve never been able to get that high',\n",
"\t\t\t'With Facebook, you can stay in touch with people you would otherwise never talk to, but that\\'s only one of the many awful things about it',\n",
"\t\t\t]\n",
"\t\t\t\n",
"labels = [\n",
"\t\t\t'funny','not funny','funny','not funny','funny','not funny','not funny','not funny','funny','not funny'\n",
"\t\t\t]\n",
"\t\t\t\n",
"model = Model(model_name ='roberta')\n",
"cl=Classification(model,documents,labels,process_classif=True)\n",
"print(cl.predict(documents[:2]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf_gpu",
"language": "python",
"name": "tf_gpu"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 3302b1d

Please sign in to comment.