Skip to content

Commit

Permalink
Mantéïa
Browse files Browse the repository at this point in the history
  • Loading branch information
ym001 committed Jun 6, 2020
1 parent 91c881d commit bc2d74f
Show file tree
Hide file tree
Showing 9 changed files with 610 additions and 36 deletions.
6 changes: 6 additions & 0 deletions Exemples/exemple_Summarize1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Minimal example: summarize a list of documents with Manteia."""
from Manteia.Summarize import Summarize

documents = [
    'What should you do before criticizing Pac-Man? WAKA WAKA WAKA mile in his shoe.',
    'What did Arnold Schwarzenegger say at the abortion clinic? Hasta last vista, baby.',
]

# NOTE(review): a Model (e.g. model_name='bart') can presumably be passed in
# as well — confirm against the Summarize API.
summarizer = Summarize(documents=documents)
9 changes: 9 additions & 0 deletions Exemples/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Smoke test: generate a (very) short summary of one sentence with pretrained BART."""
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
encoded = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

# Beam-search a summary capped at 5 tokens.
summary_ids = model.generate(encoded['input_ids'], num_beams=4, max_length=5, early_stopping=True)
decoded = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) for ids in summary_ids]
print(decoded)
20 changes: 20 additions & 0 deletions Manteia/ActiveLearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ def query(self,predictions,unlabeled_idx,nb_question):
print(entropie_trie[:nb_question])
return idx_entropie

class DAL():
    """
    The basic discriminative active-learning strategy: query the unlabeled
    examples that the discriminator scores as most likely to be unlabeled.
    """

    def __init__(self, verbose=False):
        # verbose: if True, print a tag on construction and the selected
        # (index, probability) pairs on each query.
        self.verbose = verbose
        if self.verbose:
            print('DAL')

    def query(self, predictions, unlabeled_idx, nb_question):
        """
        Select ``nb_question`` indices from the unlabeled pool to label next.

        predictions   : iterable of per-example score pairs, aligned with
                        ``unlabeled_idx``; ``p[1]`` is the probability of the
                        example being unlabeled.
        unlabeled_idx : indices of the unlabeled pool.
        nb_question   : number of indices to return.

        Returns the ``nb_question`` indices with the highest ``p[1]``,
        best first.
        """
        # Pair each unlabeled index with its probability of not being labeled,
        # then rank by that probability, highest first.
        dal = [(idx, p[1]) for idx, p in zip(unlabeled_idx, predictions)]
        dal = sorted(dal, key=itemgetter(1), reverse=True)
        idx_dal = [tup[0] for tup in dal[:nb_question]]
        if self.verbose:
            print(dal[:nb_question])
        return idx_dal




2 changes: 1 addition & 1 deletion Manteia/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def load(self):
if self.name=="trec":
self.load_trec()

if self.name=="agnews":
if self.name=="Agnews":
self.load_agnews()

if self.name=="DBPedia":
Expand Down
70 changes: 37 additions & 33 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Model:
Attributes:
"""
def __init__(self,model_name ='bert',model_type=None,task='classification',num_labels=0,epochs=None,MAX_SEQ_LEN = 128,early_stopping=False,path='./model',verbose=True):
def __init__(self,model_name ='bert',model_type=None,task='classification',num_labels=0,epochs=None,MAX_SEQ_LEN = 128,early_stopping=False,path='./model/',verbose=True):

self.model_name = model_name
self.model_type = model_type
Expand Down Expand Up @@ -297,7 +297,7 @@ def configuration(self,train_dataloader,batch_size = 16,epochs = 20,n_gpu=1):
self.devices()


def fit(self,train_dataloader,validation_dataloader):
def fit(self,train_dataloader=None,validation_dataloader=None):

self.model.to(self.device)
#self.model.cuda()
Expand All @@ -315,7 +315,8 @@ def fit(self,train_dataloader,validation_dataloader):

print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.epochs))
print('Training :')
#print('')
#print('Training :')

t0 = time.time()
total_loss = 0
Expand Down Expand Up @@ -365,21 +366,22 @@ def fit(self,train_dataloader,validation_dataloader):
if self.verbose:
progress(count=step+1, total=len(train_dataloader))

if self.verbose==True:
print("")
if self.verbose:
print(" Average training loss: {0:.2f}".format(avg_train_loss))
print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
print("Validation :")

t0 = time.time()

self.model.eval()
print("")

tab_logits = None
tab_labels = None
#print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
if validation_dataloader is not None:
if self.verbose:
print("")
#print("Validation :")

t0 = time.time()
self.model.eval()
tab_logits = None
tab_labels = None

for step,batch in enumerate(validation_dataloader):
for step,batch in enumerate(validation_dataloader):

batch = tuple(t.to(self.device) for t in batch)

Expand All @@ -405,30 +407,34 @@ def fit(self,train_dataloader,validation_dataloader):
progress(count=step+1, total=len(validation_dataloader))


acc_validation=accuracy(tab_logits, tab_labels)
if self.verbose==True:
print("")
print(" Accuracy: {0:.2f}".format(acc_validation))
print(" Validation took: {:}".format(format_time(time.time() - t0)))

if self.early_stopping:
self.es(acc_validation, self.model,self.device)
acc_validation=accuracy(tab_logits, tab_labels)
if self.verbose==True:
print("")
print(" Validation : Accuracy : {0:.2f}".format(acc_validation))
#print(" Validation took: {:}".format(format_time(time.time() - t0)))

if self.early_stopping:
self.es(acc_validation, self.model,self.device)

if self.es.early_stop:
if self.verbose==True:
print("Early stopping")
break
if self.es.early_stop:
if self.verbose==True:
print("")
print("Early stopping")
break
#a la fin de l'entrainement on charge le meilleur model.
if self.early_stopping:
self.model.load_state_dict(torch.load(os.path.join(self.path,'state_dict_validation.pt')))
self.model.to(self.device)


if self.verbose==True:
print("")
print("Training complete!")

"""
p_type='class' or 'probability' or 'logits'
"""
def predict(self,predict_dataloader,p_type='class'):
def predict(self,predict_dataloader,p_type='class',mode='eval'):
'''
if self.early_stopping:
#by torch
Expand All @@ -446,10 +452,11 @@ def predict(self,predict_dataloader,p_type='class'):
if self.verbose==True:
print('loading model early...')
'''
#self.model.cuda()
self.model.to(self.device)

self.model.eval()
if mode=='eval':
self.model.eval()
if mode=='train':
self.model.train()
predictions = None
if self.verbose:
print('Predicting :')
Expand Down Expand Up @@ -793,11 +800,8 @@ def save_checkpoint(self, acc_validation, model,device_model):
#save by torch
device = torch.device('cpu')
model.to(device)
if self.verbose:
print(type(model))
#torch.save(model.module.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))
torch.save(model.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))

model.to(device_model)
#save by transformer
#model.save_pretrained(self.path)
Expand Down
11 changes: 11 additions & 0 deletions Manteia/Utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import sys,time
import numpy as np

def coss_validation_idx(nb_pass, nb_docs):
    """
    Build index folds for ``nb_pass``-fold cross-validation over ``nb_docs`` documents.

    Returns ``(train_idx, test_idx)``: two lists of ``nb_pass`` index lists.
    Each test fold holds ``nb_docs // nb_pass`` randomly chosen indices; folds
    are mutually disjoint (previously they were drawn independently from the
    full index list, so the same document could appear in several test folds).
    Each train fold is the complement of its test fold. Up to ``nb_pass - 1``
    documents may be left out of every test fold when ``nb_docs`` is not a
    multiple of ``nb_pass``.
    """
    docs_idx = list(range(nb_docs))
    fold_size = int(len(docs_idx) / nb_pass)
    remaining = list(docs_idx)  # indices not yet used in any test fold
    train_idx, test_idx = [], []
    for _ in range(nb_pass):
        # Draw this fold only from indices no earlier fold has tested.
        test_pli_idx = list(np.random.choice(remaining, fold_size, replace=False))
        remaining = [idx for idx in remaining if idx not in test_pli_idx]
        train_pli_idx = [idx for idx in docs_idx if idx not in test_pli_idx]
        train_idx.append(train_pli_idx)
        test_idx.append(test_pli_idx)
    return train_idx, test_idx

def progress(count, total):
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
Expand Down
2 changes: 1 addition & 1 deletion Manteia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
reminiscent.
"""

__version__ = "0.0.31"
__version__ = "0.0.32"


from Manteia import Classification
Expand Down

0 comments on commit bc2d74f

Please sign in to comment.