Skip to content

Commit

Permalink
Mantéïa
Browse files Browse the repository at this point in the history
  • Loading branch information
ym001 committed Jun 2, 2020
1 parent 0909c2f commit 91c881d
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 22 deletions.
3 changes: 2 additions & 1 deletion Manteia/ActiveLearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import math
from operator import itemgetter
import random

class RandomSampling():
"""
Expand All @@ -31,7 +32,7 @@ class UncertaintyEntropySampling():
The basic uncertainty sampling query strategy, querying the examples with the top entropy.
"""

def __init__(self,verbose=True):
def __init__(self,verbose=False):
self.verbose=verbose
if self.verbose:
print('UncertaintyEntropySampling')
Expand Down
12 changes: 6 additions & 6 deletions Manteia/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Dataset:
* **path** - Path to the data file.
"""
def __init__(self,name='20newsgroups',train=True,test=False,dev=False,classe=False,desc=False,path='./dataset',verbose=True):
def __init__(self,name='20newsgroups',train=True,test=False,dev=False,classe=True,desc=False,path='./dataset',verbose=True):
r"""
"""
self.name=name
Expand Down Expand Up @@ -507,23 +507,23 @@ def load_DBPedia(self):
download_and_extract(url, self.path)

self.path_classes = os.path.join(self.path_dir,'classes.txt')
classes=['']
classes=[]
if os.path.isfile(self.path_classes) and self.classe:
fi = open(self.path_classes, "r")
rows = fi.readlines()
for row in rows:
classes.append(row.strip())
classes.append(row.strip())
self.list_labels=classes

self.path_train = os.path.join(self.path_dir,'train.csv')
if os.path.isfile(self.path_train)and self.train:
fi = open(self.path_train, "r")
rows = fi.readlines()
for row in rows:
row.strip()
row=row.split(',')

self.documents_train.append(row[1].strip('"')+' '+row[2].strip('"'))
self.labels_train.append(classes[int(row[0])])
self.labels_train.append(classes[int(row[0])-1])
self.path_test = os.path.join(self.path_dir,'test.csv')
if os.path.isfile(self.path_test)and self.test:
fi = open(self.path_test, "r")
Expand Down Expand Up @@ -930,7 +930,7 @@ def load_pubmed_rct20k(self):
wget.download(url_test, out=path_dir)
wget.download(url_dev, out=path_dir)
if self.train:
path_file=os.path.join(path_dir,'train.txt')
path_file=os.path.join(self.path_dir,'train.txt')
fi = open(path_file, "r")
rows = fi.readlines()
for row in rows:
Expand Down
20 changes: 13 additions & 7 deletions Manteia/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def load_type(self):
self.model_type=model_dict[0]
else:
if self.model_type in model_dict:
print('type compatible')
if self.verbose:
print('type compatible : {}'.format(self.model_type))
else:
raise TypeError("{} Model type not in : {}".format(self.model_name,model_dict))

Expand Down Expand Up @@ -302,7 +303,7 @@ def fit(self,train_dataloader,validation_dataloader):
#self.model.cuda()

if self.early_stopping:
self.es=EarlyStopping(path=self.path)
self.es=EarlyStopping(path=self.path,verbose=self.verbose)

#if self.n_gpu > 1:
# self.model = torch.nn.DataParallel(self.model)
Expand Down Expand Up @@ -361,7 +362,8 @@ def fit(self,train_dataloader,validation_dataloader):
avg_train_loss = total_loss / len(train_dataloader)

loss_values.append(avg_train_loss)
progress(count=step+1, total=len(train_dataloader))
if self.verbose:
progress(count=step+1, total=len(train_dataloader))

if self.verbose==True:
print("")
Expand Down Expand Up @@ -399,7 +401,8 @@ def fit(self,train_dataloader,validation_dataloader):
else:tab_logits=np.append(tab_logits,np.argmax(logits, axis=1), axis=0)
if tab_labels is None:tab_labels=label_ids
else:tab_labels=np.append(tab_labels,label_ids, axis=0)
progress(count=step+1, total=len(validation_dataloader))
if self.verbose:
progress(count=step+1, total=len(validation_dataloader))


acc_validation=accuracy(tab_logits, tab_labels)
Expand Down Expand Up @@ -448,7 +451,8 @@ def predict(self,predict_dataloader,p_type='class'):

self.model.eval()
predictions = None
print('Predicting :')
if self.verbose:
print('Predicting :')

for step,batch in enumerate(predict_dataloader):

Expand All @@ -474,7 +478,8 @@ def predict(self,predict_dataloader,p_type='class'):
if p_type=='probability':
if predictions is None:predictions=torch.softmax(logits, dim=1).numpy()
else:predictions=np.append(predictions,torch.softmax(logits, dim=1).numpy(), axis=0)
progress(count=step+1, total=len(predict_dataloader))
if self.verbose:
progress(count=step+1, total=len(predict_dataloader))

return predictions

Expand Down Expand Up @@ -788,7 +793,8 @@ def save_checkpoint(self, acc_validation, model,device_model):
#save by torch
device = torch.device('cpu')
model.to(device)
print(type(model))
if self.verbose:
print(type(model))
#torch.save(model.module.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))
torch.save(model.state_dict(),os.path.join(self.path,'state_dict_validation.pt'))

Expand Down
42 changes: 39 additions & 3 deletions Manteia/Visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,42 @@ def plot_train(self,loss,accuracy,granularity=None):
if self.show:
plt.show()




def plot_multigraph(self,metric_dictionnary,title=None,labelx=None,labely=None,legend=True):
    """Draw several curves on a single figure, one per entry of ``metric_dictionnary``.

    Args:
        metric_dictionnary: mapping ``curve name -> {x: y}``. Every ``x`` and
            ``y`` is converted with ``float()``. The curve name is used as the
            legend label; a curve named ``'Base line'`` is drawn dashed.
        title: optional figure title.
        labelx: optional x-axis label.
        labely: optional y-axis label.
        legend: when truthy, a legend is added to the figure.

    Raises:
        ValueError / TypeError: propagated from ``float()`` when a key or a
            value of an inner mapping is not numeric. All conversions happen
            before any drawing call, so nothing is plotted in that case.

    Side effects:
        Saves the figure to ``os.path.join(self.path, self.name)`` when
        ``self.save`` is truthy and calls ``plt.show()`` when ``self.show``
        is truthy.
    """
    # Local import: keeps this method self-contained without touching the
    # module-level import block.
    from itertools import cycle

    # Convert every series up front so a bad value fails before any plotting.
    series = []
    for graph_name, points in metric_dictionnary.items():
        xs = []
        ys = []
        for key, value in points.items():
            xs.append(float(key))
            ys.append(float(value))
        series.append((graph_name, xs, ys))

    tab_color=['red','black','blue','green','orange','purple']
    # NOTE(review): the axis limits are set before plt.figure(1) is selected,
    # so they apply to whichever figure is current at call time - confirm
    # that this ordering is intended.
    plt.axis([0, 1, 0, 1])

    plt.figure(1)
    # The palette is cycled so that a 7th, 8th, ... series is still drawn;
    # zipping directly against the 6-colour list silently dropped them.
    for (graph_name, xs, ys), color in zip(series, cycle(tab_color)):
        if graph_name=='Base line':
            plt.plot(xs, ys,'--',label='Base line',color=color,linewidth=2)
        else:
            plt.plot(xs, ys,label=graph_name,color=color,linewidth=2)

    if title is not None:
        plt.title(title)
    if labelx is not None:
        plt.xlabel(labelx)
    if labely is not None:
        plt.ylabel(labely)
    if legend:
        plt.legend(fontsize='large')
    if self.save:
        path=os.path.join(self.path,self.name)
        plt.savefig(path, dpi=300)
    if self.show:
        plt.show()
2 changes: 1 addition & 1 deletion Manteia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
reminiscent.
"""

__version__ = "0.0.29"
__version__ = "0.0.31"


from Manteia import Classification
Expand Down
Binary file modified docs/_build/doctrees/ActiveLearning.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/Dataset.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/_build/html/ActiveLearning.html
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@

<dl class="py class">
<dt id="Manteia.ActiveLearning.UncertaintyEntropySampling">
<em class="property">class </em><code class="sig-prename descclassname">Manteia.ActiveLearning.</code><code class="sig-name descname">UncertaintyEntropySampling</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.ActiveLearning.UncertaintyEntropySampling" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">Manteia.ActiveLearning.</code><code class="sig-name descname">UncertaintyEntropySampling</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.ActiveLearning.UncertaintyEntropySampling" title="Permalink to this definition"></a></dt>
<dd><p>The basic uncertainty sampling query strategy, querying the examples with the top entropy.</p>
</dd></dl>

Expand Down
2 changes: 1 addition & 1 deletion docs/_build/html/Dataset.html
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
<span id="dataset"></span><h1>Dataset<a class="headerlink" href="#module-Manteia.Dataset" title="Permalink to this headline"></a></h1>
<span class="target" id="module-Dataset"></span><dl class="py class">
<dt id="Manteia.Dataset.Dataset">
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Dataset.</code><code class="sig-name descname">Dataset</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">name</span><span class="o">=</span><span class="default_value">'20newsgroups'</span></em>, <em class="sig-param"><span class="n">train</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">test</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">dev</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">classe</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">desc</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./dataset'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">Manteia.Dataset.</code><code class="sig-name descname">Dataset</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">name</span><span class="o">=</span><span class="default_value">'20newsgroups'</span></em>, <em class="sig-param"><span class="n">train</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">test</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">dev</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">classe</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">desc</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">path</span><span class="o">=</span><span class="default_value">'./dataset'</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#Manteia.Dataset.Dataset" title="Permalink to this definition"></a></dt>
<dd><p>This is the class description in order to get some dataset.</p>
<ul class="simple">
<li><p><strong>name</strong> - name of the dataset (str)</p></li>
Expand Down

0 comments on commit 91c881d

Please sign in to comment.