In [None]:
from deeplotx import SoftmaxRegression, LongformerEncoder
from util import NUM_CLASSES
# Encoder used by `encode` below to embed every token; judging by the model id it is a
# multilingual XLM-RoBERTa Longformer checkpoint.
lf_encoder = LongformerEncoder(
    model_name_or_path='severinsimmler/xlm-roberta-longformer-base-16384',
)

In [None]:
import torch
from deeplotx.util import sha256
from vortezwohl.cache import LRUCache

CACHE = LRUCache(capacity=16384)


def encode(text: str, dtype: 'torch.dtype | None' = None) -> torch.Tensor:
    """Embed ``text`` with ``lf_encoder`` and average the result over dim -2.

    Results are memoised in ``CACHE`` under the SHA-256 of ``text``.

    :param text: the string to embed (a single token in this notebook).
    :param dtype: dtype of the returned tensor. When ``None``, the dtype of a
        global ``model`` is used if one has been defined. Otherwise the
        encoder's own output dtype is kept. The model is created in a later
        cell, so the embedding pass must not depend on it existing.
    :return: the averaged embedding tensor.
    """
    if dtype is None:
        # `model` is defined further down the notebook; tolerate its absence.
        dtype = getattr(globals().get('model'), 'dtype', None)
    key = sha256(text)
    if key in CACHE:
        cached = CACHE[key]
        # `.to` is a no-op when the cached tensor already has the requested dtype.
        return cached if dtype is None else cached.to(dtype)
    # Inference only: avoid building an autograd graph for tensors that get cached and stored.
    with torch.no_grad():
        emb = lf_encoder.encode(text, cls_only=False).mean(dim=-2, dtype=dtype)
    CACHE[key] = emb
    return emb

In [None]:
import os
import pandas as pd
base_path = './data/multilingual_wikineural'
languages = ['de', 'en', 'es', 'fr', 'it', 'nl', 'pl', 'pt', 'ru']


def _read_records(file_name: str) -> list:
    """Read ``base_path/file_name`` as CSV and return its rows as a list of dicts."""
    return pd.read_csv(os.path.join(base_path, file_name)).to_dict(orient='records')


# One entry per language in each split: {lang: [row_dict, ...]}.
train_data, valid_data, test_data = {}, {}, {}
for lang in languages:
    train_data[lang] = _read_records(f'train_{lang}.csv')
    valid_data[lang] = _read_records(f'val_{lang}.csv')
    test_data[lang] = _read_records(f'test_{lang}.csv')

In [None]:
def _to_pairs(split_data: dict) -> list:
    """Flatten ``{lang: [row, ...]}`` into ``[(row['tokens'], row['ner_tags']), ...]``.

    Languages are visited in the dict's insertion order and rows in their stored order.
    """
    return [
        (row['tokens'], row['ner_tags'])
        for rows in split_data.values()
        for row in rows
    ]


total_train_data = _to_pairs(train_data)
total_valid_data = _to_pairs(valid_data)
total_test_data = _to_pairs(test_data)

total_train_data[:2], total_valid_data[:2], total_test_data[:2]

In [None]:
# Flat, parallel per-split lists: dataset['token'][k] is paired with dataset['label'][k].
# The generator builds three independent dicts, each with its own fresh lists.
train_dataset, valid_dataset, test_dataset = ({'token': [], 'label': []} for _ in range(3))

def token_decode(s: str) -> list:
    """Parse a stringified token sequence from the WikiNEuRal CSV export.

    The cells appear to hold ``str(numpy_array)`` text such as
    ``"['EU' 'rejects' 'German']"``: whitespace-separated elements, possibly
    wrapped across lines. Each element is a Python string literal, so a token
    containing an apostrophe is rendered with double quotes (``"l'Italia"``).
    Every quoted literal is therefore extracted and unescaped with
    ``ast.literal_eval``, rather than splitting on ``' '``. Splitting on
    ``' '`` merges such sentences into one garbled token and strips genuine
    ``[`` / ``]`` tokens. Comma-separated list reprs are accepted as well.

    :param s: the serialized token sequence.
    :return: the tokens in order; ``[]`` for an empty sequence.
    """
    import ast
    import re
    # A single- or double-quoted Python string literal, allowing backslash escapes.
    quoted_literal = r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\""
    return [ast.literal_eval(lit) for lit in re.findall(quoted_literal, s)]

def label_decode(s: str) -> list:
    """Parse a stringified integer tag sequence such as ``"[0 0 3 0]"``.

    Integers are extracted with a regex instead of ``split(' ')``. Runs of
    spaces (numpy pads mixed-width values, e.g. ``"[ 0 10  3]"``), line wraps,
    commas and the empty sequence ``"[]"`` are therefore all handled instead of
    raising ``ValueError`` on an empty field.

    :param s: the serialized tag sequence.
    :return: the tags as ints, in order.
    """
    import re
    return [int(num) for num in re.findall(r'-?\d+', s)]

def _extend_dataset(dataset: dict, pairs: list) -> int:
    """Decode serialized sentences and append their tokens/labels to ``dataset``.

    ``dataset['token']`` and ``dataset['label']`` are flat parallel lists, so a
    sentence whose decoded token and label counts differ would shift every later
    label relative to its token. Such sentences are skipped instead.

    :param dataset: dict with 'token' and 'label' lists, extended in place.
    :param pairs: ``(serialized_tokens, serialized_labels)`` tuples.
    :return: the number of skipped sentences.
    """
    skipped = 0
    for raw_tokens, raw_labels in pairs:
        tokens, labels = token_decode(raw_tokens), label_decode(raw_labels)
        if len(tokens) != len(labels):
            skipped += 1
            continue
        dataset['token'].extend(tokens)
        dataset['label'].extend(labels)
    return skipped


for _split_name, _dataset, _pairs in (('train', train_dataset, total_train_data),
                                      ('valid', valid_dataset, total_valid_data),
                                      ('test', test_dataset, total_test_data)):
    _skipped = _extend_dataset(_dataset, _pairs)
    if _skipped:
        print(f'Warning: skipped {_skipped} {_split_name} sentences with mismatched token/label counts.')

print(list(zip(train_dataset['token'][:23], train_dataset['label'][:23])))
f'{len(train_dataset["token"])} tokens to train in total.'

## 存储嵌入

In [None]:
def _embed_into(embeddings: list, tokens: list) -> None:
    """Append ``encode(token)`` for each token to ``embeddings``, keeping token order."""
    embeddings.extend(encode(tok) for tok in tokens)


# embeddings[k] corresponds to <split>_dataset['token'][k].
train_embeddings, valid_embeddings, test_embeddings = [], [], []
_embed_into(train_embeddings, train_dataset['token'])
print('TrainSet done\n')
_embed_into(valid_embeddings, valid_dataset['token'])
print('ValidSet done\n')
_embed_into(test_embeddings, test_dataset['token'])
print('TestSet done\n')

In [None]:
import os
import pickle

# Create the destination up front: otherwise `open` raises FileNotFoundError only
# after the (expensive) embedding pass has already finished.
processed_dir = './data/multilingual_wikineural_processed'
os.makedirs(processed_dir, exist_ok=True)

# One pickle per split holding parallel 'tokens' / 'embeddings' / 'labels' lists.
# NOTE: pickle.load executes arbitrary code — only reload these files from trusted sources.
for _file_name, _dataset, _embeddings in (('train.pkl', train_dataset, train_embeddings),
                                          ('valid.pkl', valid_dataset, valid_embeddings),
                                          ('test.pkl', test_dataset, test_embeddings)):
    with open(os.path.join(processed_dir, _file_name), 'wb') as f:
        pickle.dump({
            'tokens': _dataset['token'],
            'embeddings': _embeddings,
            'labels': _dataset['label'],
        }, f)

## 训练

In [None]:
# Classifier over the vectors returned by `encode`. input_dim must equal their last
# dimension (768 is assumed for this base-size encoder — TODO confirm).
model = SoftmaxRegression(
    input_dim=768,
    output_dim=NUM_CLASSES,
    num_layers=3,
    num_heads=4,
    head_layers=2,
    expansion_factor=1.25,
    dropout_rate=0.2,
    bias=True,
)
print(model)

In [None]:
# train
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter

# Global step counters and running loss sums, updated by the training loop below.
train_step, valid_step = 0, 0
acc_train_loss, acc_valid_loss = 0., 0.
# Average losses are logged every `log_interval` steps; validation runs every
# `eval_interval` training steps.
log_interval = 200
eval_interval = 2000
writer = SummaryWriter()

In [None]:
from random import randint
from util import one_hot

import random

elastic_net_param = {
    'alpha': 2e-4,
    'rho': 0.2
}
learning_rate = 2e-5
num_epochs = 1500
valid_window = 1000  # consecutive validation tokens scored per evaluation
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Reuse the embeddings computed in the "store embeddings" cell instead of calling
# encode() per step: CACHE holds only 16384 entries, so most tokens would be re-encoded.
assert len(train_embeddings) == len(train_dataset['token'])
assert len(valid_embeddings) == len(valid_dataset['token'])


def _as_model_input(emb: torch.Tensor) -> torch.Tensor:
    """Move a precomputed embedding onto the model's dtype and device."""
    return emb.to(dtype=model.dtype, device=model.device)


train_order = list(range(len(train_dataset['token'])))
for epoch in range(num_epochs):
    model.train()
    # Tokens are stored grouped by language and sentence; visit them in a fresh
    # random order each epoch so per-sample updates are not applied one language at a time.
    random.shuffle(train_order)
    for i in train_order:
        _label = train_dataset['label'][i]
        _one_hot_label = one_hot(_label).to(model.dtype).to(model.device)
        outputs = model.forward(_as_model_input(train_embeddings[i]))
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax to its input; if
        # SoftmaxRegression already returns probabilities this softmaxes twice — confirm.
        loss = loss_function(outputs, _one_hot_label) + model.elastic_net(alpha=elastic_net_param['alpha'], rho=elastic_net_param['rho'])
        acc_train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Count the step before checking, so each logged average covers exactly log_interval losses.
        train_step += 1
        if train_step % log_interval == 0:
            writer.add_scalar('train/loss', acc_train_loss / log_interval, train_step)
            print(f'- Train Step {train_step} Loss {acc_train_loss / log_interval} \nPred={outputs.tolist()} \nLabel={_one_hot_label.tolist()}', flush=True)
            acc_train_loss = 0.
        if train_step % eval_interval == 0:
            model.eval()
            num_valid = len(valid_dataset['token'])
            # Random window of up to valid_window tokens; max(...) keeps randint valid on small sets.
            rand_idx = randint(0, max(0, num_valid - valid_window))
            with torch.no_grad():
                for _j in range(rand_idx, min(rand_idx + valid_window, num_valid)):
                    # Index by the absolute position _j so the label stays aligned with its token.
                    _label = valid_dataset['label'][_j]
                    _one_hot_label = one_hot(_label).to(model.dtype).to(model.device)
                    outputs = model.forward(_as_model_input(valid_embeddings[_j]))
                    loss = loss_function(outputs, _one_hot_label)
                    acc_valid_loss += loss.item()
                    valid_step += 1
                    if valid_step % log_interval == 0:
                        writer.add_scalar('valid/loss', acc_valid_loss / log_interval, valid_step)
                        print(f'- Valid Step {valid_step} Loss {acc_valid_loss / log_interval} \nPred={outputs.tolist()} \nLabel={_one_hot_label.tolist()}', flush=True)
                        acc_valid_loss = 0.
            model.train()

In [None]:
def _print_distribution(tok: str) -> None:
    """Print the model's output for a single token as a plain list."""
    dist = model.forward(encode(tok))
    print(f'Token {tok} Dist {dist.tolist()}', flush=True)


# Spot-check the trained classifier on a few hand-picked person and place names.
model.eval()
test_tokens = ['Mike', 'John', 'Smith', 'London', 'NYC', 'HongKong', 'China', 'South Africa', 'Korea']
with torch.no_grad():
    for probe in test_tokens:
        _print_distribution(probe)
model.train()
'Test finished.'