In [1]:
from collections import defaultdict
import itertools
from pathlib import Path
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam 
from torch.utils.data import DataLoader

from eff.data import get_word_list
from eff.data.dataset import CLTSDataset
from eff.model import lstm
from eff.train import get_class_weights_balanced, get_train_test_valid_split, generate_batch
from eff.train.dataset import TrainDataset, UnmaskedTestSet, \
    ConsonantMaskingTestSet, VowelMaskingTestSet
from eff.train.scripts import train, test
from eff.util import constants
from eff.util.util import save_results

import numpy as np

# Output root; handed to save_results() by the per-language loop further down.
base_path = Path("./out") / "nelex_unique"

# Seed PyTorch's default CPU generator and the CUDA generator once, before any
# dataset or model is created.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# NumPy's global RNG is currently left unseeded.
# NOTE(review): if get_train_test_valid_split draws from NumPy's global RNG,
# the splits will differ between runs until the next line is enabled -- confirm.
# np.random.seed(0)

In [2]:
# Load the 'northeuralex' word list; its `.languages` drive the training loop below.
word_list = get_word_list('northeuralex')

loading forms for northeuralex: 100%|██████████| 121612/121612 [00:02<00:00, 42773.28it/s]


In [3]:
# --- Training loop settings -------------------------------------------------
# Mini-batch size used by every DataLoader (train, validation and all test sets).
batch_size = 32
# Forwarded to train(..., patience=patience).
patience = 2

# --- LstmLM architecture ----------------------------------------------------
embedding_size = 32   # passed as embedding_dim
hidden_size = 256     # passed as hidden_dim
n_layers = 2
dropout = 0.33

In [4]:
# Optional development subset: uncomment the block below to restrict the run to
# the listed language codes, matched against the part of `lang.id` after the "-".
# lang_ids_dev = ['fin', 'tur', 'mnc', 'hun', 'khk', 'arb', 'ain', 'ekk', 'hye', 'eus']
# languages = [
#     lang for lang in word_list.languages
#     if lang.id.split("-")[1] in lang_ids_dev
# ]
# print(len(languages))

# Default: iterate over every language contained in the word list.
languages = word_list.languages

In [5]:
# For each language: build the datasets, train one LstmLM, evaluate it on the
# three test conditions (unmasked, vowel-masked, consonant-masked) and save
# everything under `base_path`.
for language in languages:
    print(language.name)

    # Fresh per-language containers; the nested defaultdicts allow writes of the
    # form container[a][b] = value without creating the inner dict first.
    datasets = defaultdict(lambda: defaultdict(lambda: {}))
    res = defaultdict(lambda: defaultdict(lambda: {}))

    # Keep the second "-"-separated field of the language id as the short id.
    lang_id = language.id.split("-")[1]
    clts_dataset = CLTSDataset(language.forms, unique_sequences=True)
    datasets['clts'] = clts_dataset

    train_words, valid_words, test_words = get_train_test_valid_split(
        clts_dataset.words, test_size=0.3, valid_size=0.1
    )
    print("N= {}".format(len(test_words)))

    # Arguments shared by every dataset derived from this language.
    alphabet_kwargs = dict(
        input_alphabet=clts_dataset.input_alphabet,
        output_alphabet=clts_dataset.output_alphabet,
        bipa=clts_dataset.bipa,
    )

    # Construction order (train, valid, unmasked, vowel, consonant) is kept as
    # in the original cell in case a dataset constructor draws random numbers.
    train_set = TrainDataset(words=train_words, masking=0.25, **alphabet_kwargs)
    valid_set = TrainDataset(words=valid_words, masking=0.25, **alphabet_kwargs)
    test_set = UnmaskedTestSet(words=test_words, **alphabet_kwargs)
    test_set_vowel = VowelMaskingTestSet(words=test_words, **alphabet_kwargs)
    test_set_consonant = ConsonantMaskingTestSet(words=test_words, **alphabet_kwargs)

    datasets['torch']['unmasked'] = test_set
    datasets['torch']['vowel_masking'] = test_set_vowel
    datasets['torch']['consonant_masking'] = test_set_consonant

    # NOTE(review): DataLoader defaults to shuffle=False, so the training set is
    # visited in dataset order every epoch -- confirm this is intended.
    loader_kwargs = dict(batch_size=batch_size, collate_fn=generate_batch)
    train_loader = DataLoader(train_set, **loader_kwargs)
    valid_loader = DataLoader(valid_set, **loader_kwargs)
    test_loader = DataLoader(test_set, **loader_kwargs)
    test_loader_vowel = DataLoader(test_set_vowel, **loader_kwargs)
    test_loader_consonant = DataLoader(test_set_consonant, **loader_kwargs)

    # Flatten all training target sequences into one list of label indices.
    train_labels = list(itertools.chain.from_iterable(
        t.cpu().tolist() for t in train_set._Y
    ))

    # Append PAD_IDX and every output index that never occurs in the training
    # targets, so each index of the output alphabet is present in `y`
    # (presumably required by get_class_weights_balanced -- TODO confirm).
    missing_labels = list(set(clts_dataset.output_alphabet.indices).difference(set(train_labels)))
    train_labels = train_labels + [clts_dataset.output_alphabet.PAD_IDX] + missing_labels
    weight = get_class_weights_balanced(
        ignore_classes=[clts_dataset.pad_idx, clts_dataset.mask_idx],
        classes=clts_dataset.output_alphabet.indices,
        y=train_labels,
    )

    criterion = CrossEntropyLoss(weight=weight)
    model = lstm.LstmLM(
        input_dim=len(clts_dataset.input_alphabet),
        output_dim=len(clts_dataset.output_alphabet),
        embedding_dim=embedding_size,
        hidden_dim=hidden_size,
        dropout=dropout,
        n_layers=n_layers,
        loss_fn=criterion
    )
    # print(model)
    model.to(constants.device)
    optimizer = Adam(model.parameters())

    train(model, train_loader, valid_loader, optimizer, criterion, patience=patience)

    # Evaluate on each test condition in a fixed order; the keys match the ones
    # used in datasets['torch'] above.
    test_loaders = {
        'unmasked': test_loader,
        'vowel_masking': test_loader_vowel,
        'consonant_masking': test_loader_consonant,
    }
    for condition, loader in test_loaders.items():
        logprobs, target_indices, targets = test(model, loader, criterion)
        res[condition]['logprobs'] = logprobs
        res[condition]['targets'] = targets
        res[condition]['indices'] = target_indices

    # Move model and loss to the CPU before handing them to save_results().
    model.to('cpu')
    criterion.to('cpu')

    save_results(base_path, lang_id, datasets, res, criterion, model)

# Example output of this cell for one language (apparently pasted from an earlier run):
# Hungarian
# N= 234
# Epoch	Loss	Perplexity
# 1	0.0816	0	
# 2	0.0786	0	
# 3	0.0772	0	
# 4	0.0744	0	
# 5	0.0755	0	
# Best epoch: 4, best valid loss: 0.07
# Test loss: 0.07218802484691653
# Test perplexity: 0
# Test loss: 0.06732728052590548
# Test perplexity: 0
# Test loss: 0.07600941311599863
# Test perplexity: 0

Finnish
N= 265
Epoch	Loss	Perplexity
1	0.097	0	
2	0.09	0	
3	0.089	0	
4	0.0882	0	
5	0.0875	0	
6	0.0881	0	
Best epoch: 5, best valid loss: 0.09
Test loss: 0.10789355421965977
Test perplexity: 0
Test loss: 0.09135610133677989
Test perplexity: 0
Test loss: 0.10945389225797833
Test perplexity: 0
North Karelian
N= 302
Epoch	Loss	Perplexity
1	0.1068	0	
2	0.1027	0	
3	0.0976	0	
4	0.0959	0	
5	0.0956	0	
6	0.0947	0	
7	0.0934	0	
8	0.0922	0	
9	0.0919	0	
10	0.0939	0	
Best epoch: 9, best valid loss: 0.09
Test loss: 0.07305502812593978
Test perplexity: 0
Test loss: 0.06679235273156284
Test perplexity: 0
Test loss: 0.07582287598919395
Test perplexity: 0
Olonets Karelian
N= 326
Epoch	Loss	Perplexity
1	0.1083	0	
2	0.1088	0	
Best epoch: 1, best valid loss: 0.11
Test loss: 0.0872304066558557
Test perplexity: 0
Test loss: 0.07528810029806093
Test perplexity: 0
Test loss: 0.08725657492327543
Test perplexity: 0
Veps
N= 219
Epoch	Loss	Perplexity
1	0.0954	0	
2	0.0906	0	
3	0.0859	0	
4	0.0846	0	
5	0.0843	0	
6	0.08