参考サイト (reference): [Philipp Schmid — "BERT Text Classification in a different language"](https://www.philschmid.de/bert-text-classification-in-a-different-language)

In [1]:
# install simpletransformers
!pip install simpletransformers
 
# check installed version
!pip freeze | grep simpletransformers
# simpletransformers==0.28.2

Collecting simpletransformers
  Downloading simpletransformers-0.70.1-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.4/42.4 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting requests (from simpletransformers)
  Using cached requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.47.0 (from simpletransformers)
  Using cached tqdm-4.66.4-py3-none-any.whl.metadata (57 kB)
Collecting regex (from simpletransformers)
  Using cached regex-2024.5.15-cp312-cp312-macosx_11_0_arm64.whl.metadata (40 kB)
Collecting transformers>=4.31.0 (from simpletransformers)
  Using cached transformers-4.42.3-py3-none-any.whl.metadata (43 kB)
Collecting datasets (from simpletransformers)
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting seqeval (from simpletransformers)
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.1 MB/s[0m eta 

In [5]:
!wget https://projects.fzai.h-da.de/iggsa/wp-content/uploads/2019/08/germeval2019GoldLabelsSubtask1_2.txt
!wget https://projects.fzai.h-da.de/iggsa/wp-content/uploads/2019/09/germeval2019.training_subtask1_2_korrigiert.txt

--2024-07-09 08:52:43--  https://projects.fzai.h-da.de/iggsa/wp-content/uploads/2019/08/germeval2019GoldLabelsSubtask1_2.txt
projects.fzai.h-da.de (projects.fzai.h-da.de) をDNSに問いあわせています... 141.100.10.111, 2001:67c:2184:fdfe::111
projects.fzai.h-da.de (projects.fzai.h-da.de)|141.100.10.111|:443 に接続しています... 接続しました。
HTTP による接続要求を送信しました、応答を待っています... 301 Moved Permanently
場所: https://fz.h-da.de/iggsa//wp-content/uploads/2019/08/germeval2019GoldLabelsSubtask1_2.txt [続く]
--2024-07-09 08:52:44--  https://fz.h-da.de/iggsa//wp-content/uploads/2019/08/germeval2019GoldLabelsSubtask1_2.txt
fz.h-da.de (fz.h-da.de) をDNSに問いあわせています... 141.100.10.111, 2001:67c:2184:fdfe::111
fz.h-da.de (fz.h-da.de)|141.100.10.111|:443 に接続しています... 接続しました。
HTTP による接続要求を送信しました、応答を待っています... 404 Not Found
2024-07-09 08:52:45 エラー 404: Not Found。

--2024-07-09 08:52:45--  https://projects.fzai.h-da.de/iggsa/wp-content/uploads/2019/09/germeval2019.training_subtask1_2_korrigiert.txt
projects.fzai.h-da.de (projects.fzai.h-da.de) 

In [None]:
import pandas as pd
 
class_list = ['INSULT','ABUSE','PROFANITY','OTHER']

# Both files are tab-separated with three columns: tweet text, task-1 label and
# task-2 label (only task2 is used below).
read_kwargs = dict(sep='\t', lineterminator='\n', encoding='utf8', names=["tweet", "task1", "task2"])
df1 = pd.read_csv('germeval2019GoldLabelsSubtask1_2.txt', **read_kwargs)
df2 = pd.read_csv('germeval2019.training_subtask1_2_korrigiert.txt', **read_kwargs)

# ignore_index gives the combined frame a unique 0..n-1 index instead of two
# overlapping ones.
df = pd.concat([df1, df2], ignore_index=True)

# Splitting rows only on '\n' leaves a trailing '\r' (CRLF line endings) on the
# last column, which would break the label lookup.
df['task2'] = df['task2'].str.replace('\r', '', regex=False)

# Label string -> integer class id (position in class_list), vectorized.
label_to_id = {label: idx for idx, label in enumerate(class_list)}
df['pred_class'] = df['task2'].map(label_to_id)
unknown_labels = df.loc[df['pred_class'].isna(), 'task2'].unique()
if len(unknown_labels):
    raise ValueError(f"unexpected task2 labels: {list(unknown_labels)}")
df['pred_class'] = df['pred_class'].astype('int64')

df = df[['tweet','pred_class']]

print(df.shape)
df.head()

In [None]:
from sklearn.model_selection import train_test_split
 
# Hold out 10 % of the tweets for evaluation.
# random_state makes the split reproducible across runs; stratify keeps the
# class distribution identical in the train and test splits.
train_df, test_df = train_test_split(
    df,
    test_size=0.10,
    random_state=42,
    stratify=df['pred_class'],
)

print('train shape: ',train_df.shape)
print('test shape: ',test_df.shape)

# train shape:  (6309, 2)
# test shape:  (702, 2)

In [None]:
from simpletransformers.classification import ClassificationModel
 
# define hyperparameter
import torch

# Hyperparameters for fine-tuning.
train_args = {
    "reprocess_input_data": True,
    "fp16": False,
    "num_train_epochs": 4,
    # Seed simpletransformers applies to its random number generators, so
    # repeated runs are reproducible.
    "manual_seed": 42,
}

# Create a ClassificationModel.
# model_type must match the checkpoint's architecture. distilbert-base-german-cased
# is a DistilBERT checkpoint, so "distilbert" is required; "bert" would build a
# BERT model whose parameters do not match the saved DistilBERT weights.
model = ClassificationModel(
    "distilbert", "distilbert-base-german-cased",
    num_labels=4,  # one output per entry in class_list
    args=train_args,
    # use_cuda defaults to True, and ClassificationModel raises when no CUDA
    # device is available (e.g. on Apple-silicon Macs).
    use_cuda=torch.cuda.is_available(),
)

In [None]:
# Fine-tune the model on the 90 % training split created above.
model.train_model(train_df=train_df)

In [None]:
from sklearn.metrics import f1_score, accuracy_score
 
def f1_multiclass(labels, preds, average='macro'):
    """F1 score for the 4-class task, macro-averaged by default.

    For single-label multiclass data, micro-averaged F1 is identical to accuracy
    (the earlier run reported the same value for 'acc' and 'f1'), so it added
    no information. Macro averaging gives every class equal weight regardless
    of how many examples it has.

    Args:
        labels: true class ids.
        preds: predicted class ids.
        average: forwarded to sklearn's f1_score; pass 'micro' to reproduce
            the previous metric.

    Returns:
        The F1 score as a float.
    """
    return f1_score(labels, preds, average=average)

result, model_outputs, wrong_predictions = model.eval_model(test_df, f1=f1_multiclass, acc=accuracy_score)
result

# Previous run, recorded when 'f1' was still micro-averaged (hence equal to 'acc');
# re-run to get the macro-F1 value:
# {'acc': 0.6894586894586895,
# 'eval_loss': 0.8673831869594075,
# 'f1': 0.6894586894586895,
# 'mcc': 0.25262380289641617}

In [None]:
import os
import tarfile
 
def pack_model(model_path='',file_name=''):
  """Archive the top-level files of ``model_path`` into ``<file_name>.tar.gz``.

  Only regular entries directly inside ``model_path`` are added; subdirectories
  are skipped. Members keep the ``model_path`` prefix, so extracting the archive
  in the working directory recreates ``model_path/``.

  Args:
    model_path: directory holding the saved model files.
    file_name: archive name without the ``.tar.gz`` suffix.

  Returns:
    The path of the written archive.

  Raises:
    FileNotFoundError: if ``model_path`` is not an existing directory.
  """
  # Only the first os.walk entry (model_path itself) is needed. next() avoids
  # walking the whole tree, and a missing directory becomes a clear error
  # rather than an IndexError.
  top_level = next(os.walk(model_path), None)
  if top_level is None:
    raise FileNotFoundError(f"model directory not found: {model_path!r}")
  _, _, files = top_level

  archive_path = file_name + '.tar.gz'
  with tarfile.open(archive_path, 'w:gz') as tar:
    # sorted() gives a deterministic member order; os.path.join avoids '//'
    # when model_path ends with a separator.
    for file in sorted(files):
      tar.add(os.path.join(model_path, file))
  return archive_path
 
# Pack the fine-tuned model. 'outputs' is simpletransformers' default
# output_dir (train_args does not override it), i.e. where train_model saved
# the model files.
pack_model('outputs','model_name')

In [None]:
import os
import tarfile
 
def unpack_model(model_name='', path='.'):
  """Extract ``<model_name>.tar.gz`` into ``path``.

  Args:
    model_name: archive name without the ``.tar.gz`` suffix.
    path: destination directory. The default '.' keeps the original behaviour
      of extracting into the current working directory.

  Raises:
    tarfile.FilterError: (Python versions with extraction filters) if a member
      would land outside ``path`` or is otherwise unsafe to extract.
  """
  # The context manager closes the archive even if extraction fails.
  with tarfile.open(f"{model_name}.tar.gz", "r:gz") as tar:
    if hasattr(tarfile, 'data_filter'):
      # The 'data' filter rejects absolute paths, '..' components and links
      # that escape the destination. Calling extractall() without a filter is
      # deprecated since Python 3.12.
      tar.extractall(path, filter='data')
    else:
      tar.extractall(path)
 
# Restore the packed model archive (model_name.tar.gz) into the working directory.
unpack_model(model_name='model_name')

In [None]:
from simpletransformers.classification import ClassificationModel
 
# define hyperparameter
import torch

# Same settings as used for training.
train_args = {
    "reprocess_input_data": True,
    "fp16": False,
    "num_train_epochs": 4,
}

# Re-create the ClassificationModel from our fine-tuned weights.
# model_type must match the saved architecture: the model was fine-tuned from
# the DistilBERT checkpoint distilbert-base-german-cased, so "distilbert" is
# required ("bert" would not match the saved weights).
# 'outputs/' is simpletransformers' default output_dir, i.e. where
# train_model saved the model.
model = ClassificationModel(
    "distilbert", 'outputs/',
    num_labels=4,  # one output per entry in class_list
    args=train_args,
    # use_cuda defaults to True, and ClassificationModel raises when no CUDA
    # device is available.
    use_cuda=torch.cuda.is_available(),
)

In [None]:
class_list = ['INSULT','ABUSE','PROFANITY','OTHER']

test_tweet1 = "Meine Mutter hat mir erzählt, dass mein Vater einen Wahlkreiskandidaten nicht gewählt hat, weil der gegen die Homo-Ehe ist"
test_tweet2 = "Frau #Böttinger meine Meinung dazu ist sie sollten uns mit ihrem Pferdegebiss nicht weiter belästigen #WDR"

# Classify both tweets in one batched call instead of one predict() per tweet;
# predictions holds one class index (into class_list) per input, in order.
predictions, raw_outputs = model.predict([test_tweet1, test_tweet2])

for label_id in predictions:
    print(class_list[label_id])
# OTHER
# INSULT