In [21]:
from dataset import ConllDataset

# Path to the CoNLL-06 training file (per its name, the first 1k WSJ training sentences).
file = '../data/english/train/wsj_train.first-1k.conll06'
# Wrap the file in a ConllDataset; later cells take its len() and index into it.
english_training_data = ConllDataset(file)

In [4]:
from features import FeatureMapping
from tqdm import trange, tqdm
import concurrent.futures

def train_feature_extractor(feature_extractor, dataset, end=0):
    """Run the feature extractor over the leading sentences of ``dataset``.

    Args:
        feature_extractor: object exposing ``get(sentence)``; it is called once
            per visited sentence and its return value is discarded.
        dataset: collection supporting ``len()`` and integer indexing.
        end: number of leading sentences to visit. The default ``0`` is a
            sentinel meaning "the whole dataset".

    Returns:
        The same ``feature_extractor`` object that was passed in.
    """
    # 0 is the "no limit" sentinel, so it cannot be used to request zero items.
    limit = len(dataset) if end == 0 else end

    print('Training feature extractor')
    for index in tqdm(range(limit)):
        feature_extractor.get(dataset[index])

    return feature_extractor


def get_training_data(dataset, feature_extractor, end=-1):
    """Sequentially build training examples from the leading sentences.

    NOTE: a later cell in this notebook defines another ``get_training_data``
    (a multiprocessing variant); once that cell runs it shadows this one.

    Args:
        dataset: collection supporting ``len()`` and integer indexing; each
            item must provide ``get_arcs()``.
        feature_extractor: object exposing ``get_permutations(sentence)`` and
            ``get(sentence)``. Its ``frozen`` attribute is set to ``True`` as
            a side effect and is not restored (presumably this stops the
            mapping from growing during extraction -- confirm against
            ``FeatureMapping``).
        end: number of leading sentences to process. The default ``-1`` is a
            sentinel meaning "the whole dataset".

    Returns:
        A list with one ``(x, x_true, arcs)`` tuple per processed sentence,
        where ``x`` is ``get_permutations(sentence)``, ``x_true`` is
        ``get(sentence)`` and ``arcs`` is ``sentence.get_arcs()``.
    """
    if end == -1:
        end = len(dataset)

    feature_extractor.frozen = True
    print('Extracting training dataset')

    output = []
    for i in tqdm(range(end)):
        # Fetch each sentence exactly once and reuse it for the arcs as well,
        # instead of indexing the dataset a second time.
        sentence = dataset[i]
        output.append((
            feature_extractor.get_permutations(sentence),
            feature_extractor.get(sentence),
            sentence.get_arcs(),
        ))

    return output

In [43]:
import itertools
from multiprocessing import Pool

def get_training_data(dataset, feature_extractor, multiprocess=8):
    """Build training examples, optionally fanning work out to worker processes.

    A single worker pool is created for the whole dataset (rather than one
    pool per batch), and ``multiprocess <= 1`` runs serially in-process.

    Args:
        dataset: iterable of sentence objects; each must provide ``get_arcs()``
            and, when a pool is used, be picklable.
        feature_extractor: object exposing ``get_permutations(sentence)`` and
            ``get(sentence)``. Its ``frozen`` attribute is set to ``True`` as
            a side effect and is not restored. When a pool is used the bound
            methods are sent to the workers, so the workers operate on pickled
            copies of this object; anything they mutate is not reflected here.
        multiprocess: number of worker processes. Values ``<= 1`` disable
            multiprocessing.

    Returns:
        A list with one ``[x, x_true, arcs]`` list per sentence, in dataset
        order, where ``x`` is ``get_permutations(sentence)``, ``x_true`` is
        ``get(sentence)`` and ``arcs`` is ``sentence.get_arcs()``.
    """
    feature_extractor.frozen = True
    print('Extracting training dataset')

    instances = list(dataset)
    total = len(instances)

    if multiprocess <= 1 or total <= 1:
        # Serial path: no process start-up or pickling overhead.
        permutations = [
            feature_extractor.get_permutations(instance)
            for instance in tqdm(instances, total=total)
        ]
        gold_features = [feature_extractor.get(instance) for instance in instances]
    else:
        # Chunk the work so the callable is not re-pickled for every single
        # sentence, while still letting the progress bar advance regularly.
        chunksize = max(1, total // (multiprocess * 4))
        with Pool(multiprocess) as pool:
            # imap yields results in input order, so the bar tracks real progress.
            # The results must be consumed before the pool is torn down.
            permutations = list(tqdm(
                pool.imap(feature_extractor.get_permutations, instances, chunksize),
                total=total,
            ))
            gold_features = pool.map(feature_extractor.get, instances)

    return [
        [permutation, gold, instance.get_arcs()]
        for permutation, gold, instance in zip(permutations, gold_features, instances)
    ]

In [17]:
# Create a fresh FeatureMapping and run it over every training sentence
# (train_feature_extractor returns the extractor it was given).
feature_extractor = train_feature_extractor(FeatureMapping(), english_training_data)

Training feature extractor


100%|██████████| 1000/1000 [01:21<00:00, 12.23it/s]


In [44]:
# Extract the training examples for every sentence using 32 worker processes.
train_dataset = get_training_data(
    english_training_data,
    feature_extractor,
    multiprocess=32,
)

Extracting training dataset


 80%|███████▉  | 799/1000 [05:41<01:25,  2.34it/s]


KeyboardInterrupt: 

In [4]:
# Persist the feature extractor to disk; the same path is read back with FeatureMapping.load.
FeatureMapping.save(feature_extractor, './feature_extractor-first-1k-en.p')

In [5]:
from features import FeatureMapping
# Restore the feature extractor from the file written earlier via FeatureMapping.save.
feature_extractor = FeatureMapping.load('./feature_extractor-first-1k-en.p')

In [59]:
import gzip
import pickle

# Cache the extracted training examples as a gzip-compressed pickle,
# using the newest pickle protocol available.
with gzip.open('train.p', 'wb') as stream:
    pickle.dump(train_dataset, stream, pickle.HIGHEST_PROTOCOL)

In [60]:
import gzip
import pickle

# Read the cached training examples back from the gzip-compressed pickle.
# pickle.load can execute arbitrary code, so only open files this notebook wrote.
with gzip.open('train.p', 'rb') as stream:
    train_dataset = pickle.load(stream)

# Already loaded (original note; presumably train_dataset was still in memory
# when this cell was run -- confirm before relying on it)

In [1]:
from model import AveragePerceptron
from features import FeatureMapping
import gzip, pickle

# Restore the saved feature extractor. The previous throwaway FeatureMapping()
# instance was overwritten on the very next line, so it is no longer built.
feature_extractor = FeatureMapping.load('./feature_extractor-first-1k-en.p')

# Read the cached training examples back from the gzip-compressed pickle.
# pickle.load can execute arbitrary code, so only open files this notebook wrote.
with gzip.open('train.p', 'rb') as stream:
    train_dataset = pickle.load(stream)

# Size the perceptron by the number of features the extractor reports.
model = AveragePerceptron(dim=feature_extractor.feature_count())

In [2]:
# Train the perceptron on the cached examples for 10 epochs
# (the recorded output shows a progress bar and a UAS score per epoch).
model.train(train_dataset, epoch=10)
# Bare last expression: displays the learned weight array as the cell output.
model.weight

Start training at 10 epochs


100%|██████████| 100/100 [00:00<00:00, 169.73it/s]


UAS: 0.6406533575317604


100%|██████████| 100/100 [00:00<00:00, 174.42it/s]


UAS: 0.652994555353902


100%|██████████| 100/100 [00:00<00:00, 175.73it/s]


UAS: 0.6584392014519056


100%|██████████| 100/100 [00:00<00:00, 181.22it/s]


UAS: 0.6595281306715064


100%|██████████| 100/100 [00:00<00:00, 184.74it/s]


UAS: 0.6696914700544465


100%|██████████| 100/100 [00:00<00:00, 157.13it/s]


UAS: 0.6595281306715064


100%|██████████| 100/100 [00:00<00:00, 173.49it/s]


UAS: 0.6617059891107078


100%|██████████| 100/100 [00:00<00:00, 178.77it/s]


UAS: 0.6606170598911071


100%|██████████| 100/100 [00:00<00:00, 181.39it/s]


UAS: 0.6631578947368421


100%|██████████| 100/100 [00:00<00:00, 193.99it/s]

UAS: 0.6656987295825771
Finished training.





array([1954.5400000000072, -314.1599999999992, 0.08999999999999941, 0.0,
       4772.260000000016, -5935.500000000004, 58204.33999999988,
       77730.50000000015, 65485.01000000007, 11574.089999999976,
       70577.32000000015, 17086.67], dtype=object)