<a href="https://colab.research.google.com/github/sasachichito/knowledge/blob/master/computer/%E8%87%AA%E7%84%B6%E8%A8%80%E8%AA%9E%E5%87%A6%E7%90%86_%E8%A6%81%E7%B4%84_%E6%8A%BD%E8%B1%A1%E5%9E%8B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 依存関係

In [None]:
!pip install transformers[ja] deep-translator

from urllib.request import urlopen
from bs4 import BeautifulSoup

def webpage_to_text(url, selector, timeout=30):
    """Download a web page and return the text of the elements matching *selector*.

    Args:
        url: Address of the page to fetch.
        selector: CSS selector handed to ``BeautifulSoup.select``.
        timeout: Seconds to wait for the server before ``urlopen`` gives up.

    Returns:
        The whitespace-stripped text of every matching element, concatenated
        in document order with no separator. Empty string when nothing matches.
    """
    with urlopen(url, timeout=timeout) as res:
        # Prefer the charset declared in the Content-Type header; fall back to
        # UTF-8, which is what this function always assumed before.
        charset = res.headers.get_content_charset() or 'utf-8'
        raw = res.read()

    try:
        html = raw.decode(charset, 'ignore')
    except LookupError:
        # The server announced a codec Python does not know.
        html = raw.decode('utf-8', 'ignore')

    soup = BeautifulSoup(html, 'html.parser')
    return ''.join(node.get_text(strip=True) for node in soup.select(selector))

def trim_last_halfway_sentence(text, period_char):
    """Drop the unfinished sentence at the end of *text*.

    Args:
        text: Text that may have been cut in the middle of a sentence.
        period_char: Sentence terminator (e.g. ``'.'`` or ``'。'``); may be
            longer than one character.

    Returns:
        *text* unchanged when it already ends with the terminator, otherwise
        everything up to and including the last terminator. When the
        terminator does not occur at all, the whole text counts as unfinished
        and an empty string is returned.
    """
    if text.endswith(period_char):
        return text

    # rpartition keeps the complete terminator, so multi-character terminators
    # are not cut in half; it yields ('', '', text) when nothing is found.
    head, terminator, _ = text.rpartition(period_char)
    return head + terminator

# 抽象型要約ライブラリ比較

In [None]:
from deep_translator import GoogleTranslator
from transformers import pipeline

def summary_text(text_ja, text_en):
    """Summarize *text_en* with several pretrained models and print the results.

    For each model the English summary is printed, followed by its Japanese
    translation (via deep_translator) so the quality can be judged in Japanese.

    Args:
        text_ja: Original Japanese text; only printed for reference.
        text_en: English text that is actually fed to the summarizers.
    """
    # Imported here because this cell does not import torch itself; without it
    # the function raises NameError unless a later cell has already been run.
    import torch

    print('==============origin================')
    print(text_ja)
    print(text_en)
    print('====================================')

    text = text_en

    # Use the GPU when one is available (pipeline device index 0), else CPU (-1).
    device = 0 if torch.cuda.is_available() else -1

    # Translator used only to show each summary in Japanese.
    translator_for_print_ja = GoogleTranslator(source='auto', target="ja")

    # (checkpoint name, do_sample) pairs; only the mT5 checkpoint samples.
    model_settings = (
        ("google/pegasus-large", False),
        ("google/pegasus-xsum", False),
        ("facebook/bart-large-cnn", False),
        ("csebuetnlp/mT5_multilingual_XLSum", True),
    )

    for model_name, do_sample in model_settings:
        print(f"\n>>>>> {model_name}\n")
        # Rebinding `summarizer` each round lets the previous model be freed.
        summarizer = pipeline("summarization", model=model_name, device=device)
        summary = summarizer(text, do_sample=do_sample)[0]['summary_text']
        print(summary)
        print(translator_for_print_ja.translate(summary))

# Summarize a Japanese article: fetch it, keep roughly the first 800 characters
# (cut back to the last complete sentence), translate to English, then compare
# the summarizers on the translation.
# Alternative articles that were tried earlier:
# text_ja = webpage_to_text('https://xtech.nikkei.com/atcl/nxt/column/18/02828/050900001/', '.articleBody p')
# text_ja = webpage_to_text('https://xtech.nikkei.com/atcl/nxt/column/18/02783/051000018/', '.articleBody p')
text_ja = webpage_to_text('https://xtech.nikkei.com/atcl/nxt/column/18/02252/051400006/', '.articleBody p')
text_ja = trim_last_halfway_sentence(text_ja[:800], '。')
translator_to_en = GoogleTranslator(source='auto', target="en")
text_en = translator_to_en.translate(text_ja)
summary_text(text_ja, text_en)

# Summarize an English article: no translation step is needed, so an empty
# string is passed as the Japanese original.
text_en = webpage_to_text('https://edition.cnn.com/2024/05/13/politics/takeaways-michael-cohen-testimony-donald-trump-day-16/index.html', '.article__content p')
text_en = trim_last_halfway_sentence(text_en[:800], '.')
summary_text('', text_en)

# 要約クラス実装

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from deep_translator import GoogleTranslator
import numpy as np
import re
import sys
import torch
from enum import Enum


class ModelType(Enum):
    """Summarization checkpoints together with their token-layout settings.

    Each member's value is the tuple
    ``(model_name, max_size, add_special_token_func, pad)``:

    * ``model_name`` - checkpoint identifier passed to ``from_pretrained``.
    * ``max_size`` - length every model input is padded to.
    * ``add_special_token_func`` - callable that wraps a 1-D numpy array of
      token ids with the ids this checkpoint expects around its input.
    * ``pad`` - token id used to fill the input up to ``max_size``.

    NOTE(review): the concrete ids (0/2/1 for BART, 1/0 for Pegasus and mT5)
    are presumably the BOS/EOS/PAD ids of the respective tokenizers - confirm
    against ``tokenizer.all_special_ids``.
    """

    facebook_bart_large_cnn = (
        "facebook/bart-large-cnn",
        512,
        lambda tokens: np.append(np.insert(tokens, 0, 0), 2),  # prepend 0, append 2
        1,
    )
    google_pegasus_xsum = (
        "google/pegasus-xsum",
        512,
        lambda tokens: np.append(tokens, 1),  # append 1
        0,
    )
    csebuetnlp_mT5_multilingual_XLSum = (
        "csebuetnlp/mT5_multilingual_XLSum",
        512,
        lambda tokens: np.append(tokens, 1),  # append 1
        0,
    )

    def __init__(self, model_name, max_size, add_special_token_func, pad):
        self.model_name = model_name
        self.max_size = max_size
        # Stored under a private name: assigning it to
        # ``self.add_special_token_func`` would shadow the method below and
        # leave that method unreachable on every member.
        self._add_special_token_func = add_special_token_func
        self.pad = pad

    def add_special_token_func(self, tokens):
        """Return *tokens* with this checkpoint's special token ids added.

        Args:
            tokens: 1-D numpy array of token ids without special tokens.

        Returns:
            A new numpy array; the input array is not modified.
        """
        return self._add_special_token_func(tokens)


class SummarizeSlot:
    """A bucket of sentences whose tokens fit into one model input.

    Every entry of ``sentece_token_list`` is a ``[sentence, tokens]`` pair
    where ``tokens`` is a 1-D numpy array of token ids without special tokens.
    """

    # Positions kept free for the special tokens added in tensor_2d().
    SPECIAL_TOKEN_RESERVE = 2

    def __init__(self, model_type):
        # The attribute name (sic) is kept as-is in case other code reads it.
        self.sentece_token_list = []
        self.model_type = model_type

    def can_append(self, append_size):
        """Return True when *append_size* more tokens still fit in this slot."""
        capacity = self.model_type.max_size - self.SPECIAL_TOKEN_RESERVE
        return self.size() + append_size <= capacity

    def get_token(self, index):
        """Return the token array of the entry at *index*."""
        return self.sentece_token_list[index][1]

    def pop(self):
        """Remove and return the last ``[sentence, tokens]`` entry."""
        return self.sentece_token_list.pop()

    def append(self, sentence, tokens):
        """Add a sentence and its token array at the end of the slot."""
        self.sentece_token_list.append([sentence, tokens])

    def insert(self, index, sentence, tokens):
        """Add a sentence and its token array at position *index*."""
        self.sentece_token_list.insert(index, [sentence, tokens])

    def size(self):
        """Return the total number of tokens held (special tokens excluded)."""
        return sum(row[1].size for row in self.sentece_token_list)

    def tensor_2d(self):
        """Build the padded model input for this slot.

        Returns:
            A torch tensor of shape ``(1, max_size)``: the concatenated tokens
            wrapped with the model's special tokens and right-padded with the
            model's pad id.
        """
        max_size = self.model_type.max_size
        wrap = self.model_type.add_special_token_func

        if self.sentece_token_list:
            token_array = np.concatenate([row[1] for row in self.sentece_token_list])
        else:
            # np.concatenate rejects an empty list; an empty slot yields an
            # input consisting of special tokens and padding only.
            token_array = np.array([], dtype=np.int64)

        wrapped = wrap(token_array)
        overflow = wrapped.size - max_size
        if overflow > 0:
            # A slot filled past its capacity would give np.pad a negative
            # width. Drop the surplus content tokens and wrap again so the
            # special tokens survive the cut.
            token_array = token_array[:max(token_array.size - overflow, 0)]
            wrapped = wrap(token_array)

        padded = np.pad(  # pad up to max_size
            wrapped,
            (0, max(max_size - wrapped.size, 0)),
            constant_values=self.model_type.pad
        )
        return torch.from_numpy(padded.copy()).unsqueeze(0)  # returned as a rank-2 tensor


class SummarizeSlotList:
    """Split a text into model-sized slots and summarize each slot.

    The text is cut into sentences at *delimiter*; sentences are tokenized and
    packed greedily into ``SummarizeSlot`` objects so that every slot fits the
    model's input size. ``summarize`` returns one summary per slot.
    """

    # Upper bound on the extra summarize rounds performed by summarize_r().
    MAX_SUMMARIZE_ROUNDS = 3
    # Positions kept free per slot for special tokens (matches the budget
    # SummarizeSlot.can_append subtracts from max_size).
    SPECIAL_TOKEN_RESERVE = 2
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, text, delimiter, model_type, tokenizer=None, model=None):
        """Tokenize *text* sentence by sentence and distribute it over slots.

        Args:
            text: Text to summarize.
            delimiter: Sentence terminator used to split *text*.
            model_type: ``ModelType`` member describing the checkpoint.
            tokenizer: Optional already-loaded tokenizer; loaded from
                ``model_type.model_name`` when omitted.
            model: Optional already-loaded seq2seq model; loaded from
                ``model_type.model_name`` when omitted.
        """
        self.text = text
        self.delimiter = delimiter
        self.model_type = model_type
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_type.model_name)
        if model is None:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_type.model_name)
        self.tokenizer = tokenizer
        self.model = model
        self.slot_size = model_type.max_size
        self.slot_list = [SummarizeSlot(model_type)]

        for part in text.split(delimiter):
            # Skip blank fragments (e.g. the empty tail produced when the text
            # ends with the delimiter) instead of queueing a bare delimiter.
            if part.strip():
                self.append(part + delimiter)

    def append(self, sentence):
        """Tokenize *sentence* and store it in the first slot with room."""
        normalized = self._WHITESPACE_RE.sub(' ', sentence.strip())
        token_ids = self.tokenizer(
            [normalized],
            return_tensors="np",
            add_special_tokens=False,
            truncation=True,
            # Leave room for the special tokens so that even a single
            # over-long sentence fits into an empty slot.
            max_length=self.model_type.max_size - self.SPECIAL_TOKEN_RESERVE
        )["input_ids"][0]

        slot = self.get_appendable_slot(token_ids.size)
        slot.append(sentence, token_ids)

    def get_appendable_slot(self, append_size):
        """Return the slot that should receive *append_size* more tokens.

        The last slot is reused when it has room or is still empty; otherwise
        a new slot is created, registered and returned.
        """
        last_slot = self.slot_list[-1]
        # An empty slot is never abandoned: opening another one would not
        # provide more room and would leave an empty slot to be summarized.
        if last_slot.can_append(append_size) or last_slot.size() == 0:
            return last_slot

        new_slot = SummarizeSlot(self.model_type)
        self.slot_list.append(new_slot)
        return new_slot

    def summarize_r(self, max_slot_cnt, **kwargs):
        """Summarize repeatedly until at most *max_slot_cnt* summaries remain.

        The per-slot summaries are joined and summarized again, for at most
        ``MAX_SUMMARIZE_ROUNDS`` additional rounds.

        Args:
            max_slot_cnt: Number of summaries that is considered short enough.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List of summary strings from the last round performed.
        """
        result = self.summarize(**kwargs)

        for _ in range(self.MAX_SUMMARIZE_ROUNDS):
            if len(result) <= max_slot_cnt:
                break
            # Reuse the loaded tokenizer and model instead of loading the
            # checkpoint again for every round.
            one_more = SummarizeSlotList(
                ''.join(result),
                self.delimiter,
                self.model_type,
                tokenizer=self.tokenizer,
                model=self.model
            )
            result = one_more.summarize(**kwargs)

        return result

    def summarize(self, **kwargs):
        """Generate one summary per slot.

        Args:
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List of decoded summaries, one per slot, in slot order.
        """
        self.__flatten()

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.model = self.model.cuda()

        result = []
        for slot in self.slot_list:
            # Build the input unconditionally so the CPU path works as well.
            input_ids = slot.tensor_2d()
            if use_cuda:
                input_ids = input_ids.cuda()

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=input_ids,
                    **kwargs
                )[0]

            result.append(
                self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )
            )
        return result

    def slot_count(self):
        """Return the number of slots."""
        return len(self.slot_list)

    def __flatten(self):
        """Even out the last two slots.

        Greedy packing can leave a nearly empty final slot; sentences are
        moved from the end of the second-to-last slot to the front of the last
        one while the last slot is not larger and still has room.
        """
        if len(self.slot_list) <= 1:
            return

        last_slot = self.slot_list[-1]
        last2_slot = self.slot_list[-2]  # second-to-last slot

        while last_slot.size() <= last2_slot.size():
            if last2_slot.size() == 0:
                # Nothing left to move; avoids indexing into an empty slot.
                break

            if not last_slot.can_append(last2_slot.get_token(-1).size):
                break

            moved_sentence, moved_tokens = last2_slot.pop()
            last_slot.insert(0, moved_sentence, moved_tokens)

def make_trans_batch_list(text, delimiter, size):
    """Split *text* into batches of whole sentences for translation.

    Sentences are the pieces of *text* ending at *delimiter*. They are packed
    greedily so each batch stays within *size* characters; a single sentence
    longer than *size* becomes a batch of its own.

    Args:
        text: Text to split.
        delimiter: Sentence terminator.
        size: Maximum number of characters per batch.

    Returns:
        List of non-empty batch strings. Joining them reproduces *text*
        exactly. Empty list for empty input.
    """
    parts = text.split(delimiter)
    # Every part except the last was followed by a delimiter in the original
    # text; the last part is an unterminated tail (empty when the text ends
    # with the delimiter) and must not get a delimiter it never had.
    sentence_list = [part + delimiter for part in parts[:-1]]
    if parts[-1]:
        sentence_list.append(parts[-1])

    trans_batch_list = []
    batch_parts = []
    batch_len = 0
    for sentence in sentence_list:
        # Flush only a non-empty batch, so an over-long first sentence does
        # not produce an empty batch.
        if batch_parts and batch_len + len(sentence) > size:
            trans_batch_list.append(''.join(batch_parts))
            batch_parts = []
            batch_len = 0

        batch_parts.append(sentence)
        batch_len += len(sentence)

    if batch_parts:
        trans_batch_list.append(''.join(batch_parts))

    return trans_batch_list


# Translators with auto-detected source language: to English for the
# summarization models, and to Japanese for displaying the results.
translator_to_en = GoogleTranslator(source='auto', target="en")
translator_to_ja = GoogleTranslator(source='auto', target="ja")

# Each entry is [page URL, CSS selector of the article text on that page].
url_list = [
    ['https://xtech.nikkei.com/atcl/nxt/column/18/02252/051400006/', '.articleBody p'],
    ['https://xtech.nikkei.com/atcl/nxt/column/18/02252/052100016/', '.articleBody p'],
    ['https://xtech.nikkei.com/atcl/nxt/column/18/00001/09325/', '.articleBody p'],
    ['https://xtech.nikkei.com/atcl/nxt/column/18/00154/02062/', '.articleBody p'],
    ['https://xtech.nikkei.com/atcl/nxt/column/18/00138/051601526/', '.articleBody p'],
    ['https://ainow.ai/2024/05/28/276412/', '.article_area'],
    ['https://aws.amazon.com/jp/blogs/news/gen-ai-usecase-daiichikosho/', 'article'],
    ['https://techblog.lycorp.co.jp/ja/20240527a', 'article'],
    ['https://prtimes.jp/main/html/rd/p/000000095.000034517.html', 'article'],
    ['https://prtimes.jp/main/html/rd/p/000000094.000073671.html', 'article'],
]

# Download every article as plain text.
text_ja_list = [webpage_to_text(url[0], url[1]) for url in url_list]

# Translate each article to English in batches of at most 2000 characters,
# split at the Japanese full stop, and join the translated batches.
text_en_list = []
for text_ja in text_ja_list:
  trans_batch_list = make_trans_batch_list(text_ja, '。', 2000)
  text_en = ''.join(translator_to_en.translate(batch) for batch in trans_batch_list)
  text_en_list.append(text_en)

# Generation presets to compare. Each entry has a display 'name' and the
# 'kwargs' that the main loop forwards to the summarizer (and from there to
# model.generate). The commented-out entries are alternative presets (beam
# search, temperature, top-k and several top-p values) kept for experiments;
# only the active entry is run.
kwargs_list = [
    # {
    #     'name' : 'beam',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         # 'num_beams': 5,
    #         'num_beams': 10,
    #         # 'num_beams': 30,
    #         'no_repeat_ngram_size': 2,
    #         'early_stopping': True
    #     }
    # },
    # {
    #     'name' : 'sample_temp',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         'do_sample': True,
    #         # 'temperature': 0.5,
    #         'temperature': 0.7,
    #         # 'temperature': 1.2,
    #     }
    # },
    # {
    #     'name' : 'sample_top_k',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         'do_sample': True,
    #         # 'top_k': 30,
    #         'top_k': 50,
    #         # 'top_k': 100,
    #     }
    # },
    # {
    #     'name' : 'sample_top_p_0.7',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         'do_sample': True,
    #         'top_p': 0.7,
    #     }
    # },
    # {
    #     'name' : 'sample_top_p_0.8',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         'do_sample': True,
    #         'top_p': 0.8,
    #     }
    # },
    # {
    #     'name' : 'sample_top_p_0.9',
    #     'kwargs' : {
    #         'min_length': 10,
    #         'max_length': 30,
    #         'do_sample': True,
    #         'top_p': 0.9,
    #     }
    # },
    {
        'name' : 'sample_top_p_0.95',
        'kwargs' : {
            'min_length': 10,
            'max_length': 100,
            'do_sample': True,
            'top_p': 0.95,
        }
    },
]


# For every article: print the Japanese original, summarize its English
# translation with each active preset, and print the summaries in Japanese.
for text_ja, text_en in zip(text_ja_list, text_en_list):
  print("\n" + '-'*50 + "\n" + text_ja + "\n" + '-'*50)

  # NOTE: constructing SummarizeSlotList loads the tokenizer and model, so the
  # checkpoint is loaded once per article here.
  # Switch the model by enabling exactly one of the following lines.
  summarize_slot_list = SummarizeSlotList(text_en, '.', ModelType.facebook_bart_large_cnn)
  # summarize_slot_list = SummarizeSlotList(text_en, '.', ModelType.google_pegasus_xsum)
  # summarize_slot_list = SummarizeSlotList(text_en, '.', ModelType.csebuetnlp_mT5_multilingual_XLSum)

  for kwargs in kwargs_list:
    print("\n" + kwargs.get('name') + "\n")
    # Re-summarize until at most 3 summaries remain; the commented line below
    # is the single-pass alternative.
    result_list = summarize_slot_list.summarize_r(3, **kwargs.get('kwargs'))
    # result_list = summarize_slot_list.summarize(**kwargs.get('kwargs'))
    for result in result_list:
      print(translator_to_ja.translate(result))


# スペシャルトークンの確認

In [None]:
from transformers import AutoTokenizer
import numpy as np
import torch

# Inspect which special token ids a tokenizer adds to English and Japanese
# input. Enable exactly one model_name line to choose the tokenizer.
# model_name = "facebook/bart-large-cnn"
model_name = "google/pegasus-xsum"
# model_name = "csebuetnlp/mT5_multilingual_XLSum"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Print all special token ids and their string forms for this tokenizer.
print("スペシャルトークン:")
print(tokenizer.all_special_ids)
print(tokenizer.all_special_tokens)

# Check which token has ID 105
# print(tokenizer.decode(torch.tensor([105], dtype=torch.int32), skip_special_tokens=False))

# English sample, tokenized with special tokens enabled and padded/truncated
# to 200 tokens.
text_en = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."
token_ids_en = tokenizer(
    text_en,
    return_tensors="np",
    add_special_tokens=True,
    truncation=True,
    padding="max_length",
    max_length=200
)["input_ids"][0]

print("英語文字列トークン:")
print(token_ids_en)

# Special token ids that actually occur in the encoded English sample.
common_elements = np.intersect1d(tokenizer.all_special_ids, token_ids_en)
print("使用されるスペシャルトークン:")
print(common_elements)


# Japanese sample, tokenized with the same settings as the English one.
text_ja = "IT業界が属する情報通信業は、総じて若年層の伸びが大きい。 ６０～６４歳層は退職や再雇用による人手不足を補う戦力と位置付けられ、従来より高い賃金が提示されている。製造業がマイナス圏に陥るのは60年代後半まで。ITエンジニアの仕事は多岐にわたり、スキルの有無や習熟度によって賃金は変わります。 ITSSレベルが上がると年収も上がります。賃金は仕事の「価値」に対して支払われます。ある価値を生み出すためには、何らかのスキルが必要だからです。"
token_ids_ja = tokenizer(
    text_ja,
    return_tensors="np",
    add_special_tokens=True,
    truncation=True,
    padding="max_length",
    max_length=200
)["input_ids"][0]

print("日本語文字列トークン:")
print(token_ids_ja)

# Special token ids that actually occur in the encoded Japanese sample.
common_elements = np.intersect1d(tokenizer.all_special_ids, token_ids_ja)
print("使用されるスペシャルトークン:")
print(common_elements)
