In [74]:
from deeplotx import SoftmaxRegression, LongformerEncoder
from util import NUM_CLASSES
lf_encoder = LongformerEncoder(model_name_or_path='severinsimmler/xlm-roberta-longformer-base-16384')

[DEBUG] 2025-08-03 00:23:26,988 deeplotx.embedding : LongformerEncoder initialized on device: cuda.


In [75]:
import torch
from deeplotx.util import sha256
from vortezwohl.cache import LRUCache

CACHE = LRUCache(capacity=16384)

def encode(text: str) -> torch.Tensor:
    key = sha256(text)
    if key in CACHE:
        return CACHE[key]
    emb = lf_encoder.encode(text, cls_only=False).mean(dim=-2, dtype=model.dtype)
    CACHE[key] = emb
    return emb

In [76]:
import os
import pandas as pd
base_path = './data/multilingual_wikineural'
languages = ['de', 'en', 'es', 'fr', 'it', 'nl', 'pl', 'pt', 'ru']
train_data = dict()
valid_data = dict()
test_data = dict()
for lang in languages:
    train_data[lang] = pd.read_csv(os.path.join(base_path, f'train_{lang}.csv')).to_dict(orient='records')
    valid_data[lang] = pd.read_csv(os.path.join(base_path, f'val_{lang}.csv')).to_dict(orient='records')
    test_data[lang] = pd.read_csv(os.path.join(base_path, f'test_{lang}.csv')).to_dict(orient='records')

In [77]:
total_train_data = []
total_valid_data = []
total_test_data = []
for lang, data in train_data.items():
    for d in data:
        total_train_data.append((d['tokens'], d['ner_tags'], d['lang']))
for lang, data in valid_data.items():
    for d in data:
        total_valid_data.append((d['tokens'], d['ner_tags'], d['lang']))
for lang, data in test_data.items():
    for d in data:
        total_test_data.append((d['tokens'], d['ner_tags'], d['lang']))

total_train_data[:2], total_valid_data[:2], total_test_data[:2]

([("['Dieses' 'wiederum' 'basierte' 'auf' 'dem' 'gleichnamigen' 'Roman' 'von'\n 'Noël' 'Calef' '.']",
   '[0 0 0 0 0 0 0 0 1 2 0]',
   'de'),
  ("['Auf' 'Helgoland' 'starben' '2014' 'sieben' 'Jungvögel' 'und' 'fünf'\n 'Altvögel' 'als' 'Verstrickungsopfer' '.']",
   '[0 5 0 0 0 0 0 0 0 0 0 0]',
   'de')],
 [("['Die' 'Europameisterschaften' 'in' 'Dresden' 'schloss' 'sie' 'mit' 'dem'\n 'zweiten' 'Platz' 'in' 'der' 'Gesamtwertung' 'ab' '.']",
   '[0 0 0 5 0 0 0 0 0 0 0 0 0 0 0]',
   'de'),
  ('[\'Die\' \'Europameisterschaften\' \'im\' \'selben\' \'Jahr\' \'in\' \'Heerenveen\'\n \'konnte\' \'sie\' \'mit\' \'dem\' \'Gewinn\' \'von\' \'vier\' \'Goldmedaillen\' \'(\'\n \'1000\' \'"\' \'m\' \',\' \'1500\' \'"\' \'m\' \',\' \'3000\' \'"\' \'m\' \',\' \'Gesamtwertung\'\n \')\' \'deutlich\' \'für\' \'sich\' \'entscheiden\' \'.\']',
   '[0 0 0 0 0 0 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]',
   'de')],
 [("['Er' 'erklärte' ',' 'dass' 'der' 'Allgemeine' 'Nationalkongress'\n 'illega

In [None]:
from random import shuffle

train_tokens = []
valid_tokens = []
test_tokens = []
train_dataset = []
valid_dataset = []
test_dataset = []
max_other_token_count = 55_000  # 55k
other_token_count_for_each_lang = {
    'de': 0,
    'en': 0,
    'es': 0,
    'fr': 0,
    'it': 0,
    'nl': 0,
    'pl': 0,
    'pt': 0,
    'ru': 0
}

def token_decode(s: str) -> list:
    return [_.strip().replace("'", '').replace('[', '').replace(']', '') for _ in s.replace("' '", '||').replace('\n', '||').split('||')]

def label_decode(s: str) -> list:
    _labels = []
    label_lists = token_decode(s)
    for label_ls in label_lists:
        _labels.extend([int(_) for _ in label_ls.split(' ')])
    return _labels

for tokens, labels, lang in total_train_data:
    tokens, labels = token_decode(tokens), label_decode(labels)
    for i, token in enumerate(tokens):
        if token in train_tokens:
            continue
        train_tokens.append(token)
        if labels[i] > 0:
            train_dataset.append((token, labels[i]))
        else:
            if other_token_count_for_each_lang[lang] < max_other_token_count:
                other_token_count_for_each_lang[lang] += 1
                train_dataset.append((token, labels[i]))

for tokens, labels, lang in total_valid_data:
    tokens, labels = token_decode(tokens), label_decode(labels)
    for i, token in enumerate(tokens):
        if token in valid_tokens:
            continue
        valid_tokens.append(token)
        if labels[i] > 0:
            valid_dataset.append((token, labels[i]))
        else:
            if other_token_count_for_each_lang[lang] < max_other_token_count:
                other_token_count_for_each_lang[lang] += 1
                valid_dataset.append((token, labels[i]))
                
                
for tokens, labels, lang in total_test_data:
    tokens, labels = token_decode(tokens), label_decode(labels)
    for i, token in enumerate(tokens):
        if token in test_tokens:
            continue
        test_tokens.append(token)
        if labels[i] > 0:
            test_dataset.append((token, labels[i]))
        else:
            if other_token_count_for_each_lang[lang] < max_other_token_count:
                other_token_count_for_each_lang[lang] += 1
                test_dataset.append((token, labels[i]))

print('Dataset initialized', list(zip(train_dataset[:23], train_dataset[:23])))
print(f'other_token_count_for_each_lang: {other_token_count_for_each_lang}')
f'{len(train_dataset)} tokens to train in total.'

## 训练

In [99]:
model = SoftmaxRegression(input_dim=768, output_dim=NUM_CLASSES, num_heads=4, num_layers=3, expansion_factor=1.25, bias=True, dropout_rate=0.2, head_layers=2)
print(model)

Model_Name: SoftmaxRegression
In_Features: 768
Out_Features: 9
Device: cuda
Dtype: torch.float32
Total_Parameters: 49642017
Trainable_Parameters: 49642017
NonTrainable_Parameters: 0
-------------------------------
SoftmaxRegression(
  (multi_head_ffn_layers): ModuleList(
    (0-2): 3 x MultiHeadFeedForward(
      (expand_proj): Linear(in_features=768, out_features=3072, bias=True)
      (ffn_heads): ModuleList(
        (0-3): 4 x FeedForward(
          (ffn_layers): ModuleList(
            (0-1): 2 x FeedForwardUnit(
              (up_proj): Linear(in_features=768, out_features=960, bias=True)
              (down_proj): Linear(in_features=960, out_features=768, bias=True)
              (parametric_relu): PReLU(num_parameters=1)
              (layer_norm): LayerNorm((768,), eps=1e-09, elementwise_affine=True)
            )
          )
        )
      )
      (out_proj): Linear(in_features=3072, out_features=768, bias=True)
    )
  )
  (out_proj): Linear(in_features=768, out_features=9, 

In [96]:
# train
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter

In [100]:
train_step = 0
valid_step = 0
writer = SummaryWriter()

acc_train_loss = 0.
acc_valid_loss = 0.
eval_interval = 2000
log_interval = 200
valid_log_interval = 50

In [101]:
shuffle(train_dataset)
shuffle(valid_dataset)
shuffle(test_dataset)
print('Dataset shuffled', list(zip(train_dataset[:23], train_dataset[:23])))

Dataset shuffled [(('Parco', 5), ('Parco', 5)), (('Horemheb', 1), ('Horemheb', 1)), (('hiszpański', 0), ('hiszpański', 0)), (('australische', 7), ('australische', 7)), (('Ducum', 8), ('Ducum', 8)), (('.', 1), ('.', 1)), (('That "" 70', 7), ('That "" 70', 7)), (('Uffenbach', 1), ('Uffenbach', 1)), (('2', 8), ('2', 8)), (('Short', 8), ('Short', 8)), (('Grenville "s" ministry', 1), ('Grenville "s" ministry', 1)), (('al-Chums', 5), ('al-Chums', 5)), (('Bahia', 5), ('Bahia', 5)), (('Marc', 1), ('Marc', 1)), (('Hitchcock', 2), ('Hitchcock', 2)), (('Cross', 1), ('Cross', 1)), (('Joseph', 1), ('Joseph', 1)), (('CA', 3), ('CA', 3)), (('McCartney', 2), ('McCartney', 2)), (('Emmen', 4), ('Emmen', 4)), (('Miami', 3), ('Miami', 3)), (('.', 1), ('.', 1)), (('Herat', 5), ('Herat', 5))]


In [102]:
from random import randint
from util import one_hot

elastic_net_param = {
    'alpha': 2e-4,
    'rho': 0.2
}
learning_rate = 2e-6
num_epochs = 1500
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    model.train()
    for i, (_token, _label) in enumerate(train_dataset):
        _one_hot_label = one_hot(_label).to(model.dtype).to(model.device)
        outputs = model.forward(encode(_token))
        loss = loss_function(outputs, _one_hot_label) + model.elastic_net(alpha=elastic_net_param['alpha'], rho=elastic_net_param['rho'])
        acc_train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if train_step % log_interval == 0 and train_step > 0:
            writer.add_scalar('train/loss', acc_train_loss / log_interval, train_step)
            print(f'- Train Step {train_step} Loss {acc_train_loss / log_interval} \\'
                  f'\nToken={_token}'
                  f'\nPred={outputs.tolist()}'
                  f'\nLabel={_one_hot_label.tolist()}', flush=True)
            acc_train_loss = 0.
        train_step += 1
        if train_step % eval_interval == 0:
            model.eval()
            rand_idx = randint(0, len(valid_dataset) - 501)
            with torch.no_grad():
                for _i, (__token, __label) in enumerate(valid_dataset[rand_idx: rand_idx + 500]):
                    _one_hot_label = one_hot(__label).to(model.dtype).to(model.device)
                    outputs = model.forward(encode(__token))
                    loss = loss_function(outputs, _one_hot_label)
                    acc_valid_loss += loss.item()
                    if valid_step % valid_log_interval == 0 and valid_step > 0:
                        writer.add_scalar('valid/loss', acc_valid_loss / valid_log_interval, valid_step)
                        print(f'- Valid Step {valid_step} Loss {acc_valid_loss / valid_log_interval} \\'
                              f'\nToken={__token}'
                              f'\nPred={outputs.tolist()}'
                              f'\nLabel={_one_hot_label.tolist()}', flush=True)
                        acc_valid_loss = 0.
                    valid_step += 1
            model.train()

- Train Step 200 Loss 37.720984592437745 \
Token=Sappho
Pred=[0.10309843719005585, 0.06281554698944092, 0.1830260306596756, 0.025760382413864136, 0.03563135489821434, 0.4473961591720581, 0.030544010922312737, 0.0492185577750206, 0.06250952929258347]
Label=[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
- Train Step 400 Loss 37.16333711624146 \
Token=the
Pred=[0.03310875594615936, 0.020944811403751373, 0.036317843943834305, 0.0073010982014238834, 0.007972911931574345, 0.8601081371307373, 0.008484329096972942, 0.010676349513232708, 0.015085704624652863]
Label=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
- Train Step 600 Loss 36.89213706970215 \
Token=Estados
Pred=[0.16568142175674438, 0.052781667560338974, 0.05279934033751488, 0.012112121097743511, 0.014443262480199337, 0.6478871703147888, 0.013667537830770016, 0.016477560624480247, 0.02414986491203308]
Label=[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]


KeyboardInterrupt: 

In [None]:
model.eval()
test_tokens = ['Mike', 'John', 'Smith', 'London', 'NYC', 'HongKong', 'China', 'South Africa', 'Korea']
with torch.no_grad():
    for _tok in test_tokens:
        _dist = model.forward(encode(_tok))
        print(f'Token={_tok}, Class={torch.argmax(_dist)}', flush=True)
model.train()
'Test finished.'