In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import sys
sys.path.append('../../transformers/src')

import math
import pandas as pd
import numpy as np

import torch
import torch.nn as nn

from transformers import AdamW, pipeline, PegasusForConditionalGeneration, PegasusTokenizer
from transformers import BartConfig
from transformers import AutoConfig
from transformers.models.bart.modeling_bart import EncoderLayer, SinusoidalPositionalEmbedding, LayerNorm

In [2]:
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("torch_device:",torch_device)

torch_device: cuda


In [116]:
model_name = 'google/pegasus-xsum'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
tokenizer.model_max_length

512

In [4]:
sample_text = [
    "Yuji Naraki [SAYS] Hi, John! [EOU] How are you? [EOU] [EOT] John [SAYS] I'm good. Thanks. [EOU] [EOT] Yuji Naraki [SAYS] Hi, John! [EOU] How are you? [EOU] [EOT]",
    "Naraki [SAYS] Good evening, Mr.Kim. [EOU] How was your today? [EOU] [EOT] Kim [SAYS] It is a pleasant day. [EOU] [EOT] Daive [SAYS] It is a pleasant day. [EOU] [EOT]"
]
special_tokens_dict = {'additional_special_tokens': ['[SAYS]','[EOU]','[EOT]']}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

batch = tokenizer.prepare_seq2seq_batch(sample_text, truncation=True, padding='longest')
print('  '.join([tokenizer.convert_ids_to_tokens(i) for i in batch['input_ids'][0]]))

▁Yu  ji  ▁Nar  aki  [SAYS]  ▁Hi  ,  ▁John  !  [EOU]  ▁How  ▁are  ▁you  ?  [EOU]  [EOT]  ▁John  [SAYS]  ▁I  '  m  ▁good  .  ▁Thanks  .  [EOU]  [EOT]  ▁Yu  ji  ▁Nar  aki  [SAYS]  ▁Hi  ,  ▁John  !  [EOU]  ▁How  ▁are  ▁you  ?  [EOU]  [EOT]  </s>


In [45]:
text_list

["Amanda [SAYS] I baked cookies. [EOU] Do you want some? [EOU] [EOT] Jerry [SAYS] Sure! [EOU] [EOT] Amanda [SAYS] I'll bring you tomorrow :-) [EOU] [EOT]\n"]

In [5]:
class TurnConverter():
    def __init__(self, speaker_num, eot_idx):
        self.speaker_num = speaker_num
        self.eot_idx=eot_idx
        self.current_speaker_id=1
    
    def init_speaker_id(self):
        self.current_speaker_id=1
    
    def change_speaker_id(self):
        if self.current_speaker_id==1:
            self.current_speaker_id = 2
        elif self.current_speaker_id==2:
            self.current_speaker_id = 1
    
    def convert_id_to_speaker_id(self, w_id):
        if w_id==0:
            return 0
        elif w_id==self.eot_idx:
            self.change_speaker_id()
        return self.current_speaker_id

    def convert_batch(self, input_ids):
        batch_speaker_ids = []
        for text_ids in input_ids:
            speaker_ids = []
            sc.init_speaker_id()
            for w_id in text_ids:
                # speaker_ids.append(sc.convert_id_to_speaker_id(w_id.item()))
                speaker_ids.append(sc.convert_id_to_speaker_id(w_id))
            batch_speaker_ids.append(speaker_ids)
        return torch.tensor(batch_speaker_ids)

# Operation Check
tc = TurnConverter(speaker_num = 2, eot_idx = tokenizer.convert_tokens_to_ids('[EOT]'))
batch_speaker_ids = tc.convert_batch(batch['input_ids'])
embed_speaker = nn.Embedding(3, 10, padding_idx=0)
embed_spk = embed_speaker(batch_speaker_ids)

NameError: name 'sc' is not defined

In [6]:
class SpeakerConverter():
    def __init__(self, speaker_num, eot_idx):
        self.eot_idx=eot_idx
        self.current_speaker_id=1
        self.speaker_list = []
    
    def init_speaker_id(self):
        self.speaker_list = []
        self.current_speaker_id=1
    
    def change_speaker_id(self):
        if self.current_speaker_id==1:
            self.current_speaker_id = 2
        elif self.current_speaker_id==2:
            self.current_speaker_id = 1
    
    def convert_id_to_speaker_id(self, w_id):
        if w_id==0:
            return 0
        elif w_id==self.eot_idx:
            self.change_speaker_id()
        return self.current_speaker_id

    def convert_batch(self, input_ids):
        batch_speaker_ids = []
        for text_ids in input_ids:
            speaker_ids = []
            sc.init_speaker_id()
            for w_id in text_ids:
                # speaker_ids.append(sc.convert_id_to_speaker_id(w_id.item()))
                speaker_ids.append(sc.convert_id_to_speaker_id(w_id))
            batch_speaker_ids.append(speaker_ids)
        return torch.tensor(batch_speaker_ids)

# Operation Check
sc = SpeakerConverter(speaker_num = 2, eot_idx = tokenizer.convert_tokens_to_ids('[EOT]'))
batch_speaker_ids = sc.convert_batch(batch['input_ids'])
embed_speaker = nn.Embedding(99+1, 10, padding_idx=0)
embed_spk = embed_speaker(batch_speaker_ids)

In [7]:
says_token = tokenizer.convert_tokens_to_ids('[SAYS]')
eot_token  = tokenizer.convert_tokens_to_ids('[EOT]')
pad_token  = tokenizer.convert_tokens_to_ids('<pad>')
eod_token  = tokenizer.convert_tokens_to_ids('</s>')

In [8]:
class SpeakerConverter():
    def __init__(self, says_id, eot_id, eod_id=1, pad_id=0):
        self.says_id=says_id
        self.eot_idx=eot_id
        self.eod_id=eod_id
        self.pad_id=pad_id
        self.current_speaker_id=1
        self.speaker_list = []
    
    def init_attr(self):
        self.current_speaker_id=1
        self.speaker_list = []

    def get_speaker_id(self, speaker_name):
        if speaker_name in self.speaker_list:
            return self.speaker_list.index(speaker_name)+1
        else:
            self.speaker_list.append(speaker_name)
            return len(self.speaker_list)
        
    def change_speaker_id(self, speaker_ids):
        if self.eod_id == speaker_ids:
            self.current_speaker_id = 0
        else:
            self.current_speaker_id = self.get_speaker_id('_'.join([str(w_id) for w_id in speaker_ids]))
        return self.current_speaker_id
    
    def convert_batch(self, input_ids):
        batch_speaker_ids = []
        for text_idx, text_seq in enumerate(input_ids):
            speaker_ids = []
            self.init_attr()
            text_len = len(text_seq)
            for w_idx in range(text_len):
                if w_idx==0:
                    for i in range(w_idx, text_len):
                        if self.says_id == text_seq[i]:
                            speaker_ids.append(self.current_speaker_id)
                            self.change_speaker_id(text_seq[w_idx:i])
                            break
                elif self.eot_idx == text_seq[w_idx]:
                    for i in range(w_idx+1, text_len):
                        if self.eod_id == text_seq[i]:
                            speaker_ids.append(self.current_speaker_id)
                            self.change_speaker_id(self.eod_id)
                            break
                        elif self.says_id == text_seq[i]:
                            speaker_ids.append(self.current_speaker_id)
                            self.change_speaker_id(text_seq[w_idx+1:i])
                            break
                else:
                    speaker_ids.append(self.current_speaker_id)
            batch_speaker_ids.append(speaker_ids)
        return torch.tensor(batch_speaker_ids)

# Operation Check
sc = SpeakerConverter(
    says_id = tokenizer.convert_tokens_to_ids('[SAYS]'),
    eot_id = tokenizer.convert_tokens_to_ids('[EOT]'),
    eod_id = tokenizer.convert_tokens_to_ids('</s>'),
    pad_id = tokenizer.convert_tokens_to_ids('<pad>')
)
# print(batch['input_ids'])
batch_speaker_ids = sc.convert_batch(batch['input_ids'])
# print(batch_speaker_ids)
embed_speaker = nn.Embedding(99+1, 10, padding_idx=0)
embed_spk = embed_speaker(batch_speaker_ids)


In [9]:
embed_spk.shape

torch.Size([2, 44, 10])

In [14]:
tokenizer.convert_tokens_to_ids("[SAYS]")

96103

In [15]:
tokenizer.convert_tokens_to_ids("[EOT]")

96105

In [84]:
for t_i in range(2):
    for w_i in range(len(batch_speaker_ids[0])):
        print(tokenizer.convert_ids_to_tokens(batch['input_ids'][t_i][w_i]), batch_speaker_ids[t_i][w_i].item())

▁Yu 1
ji 1
▁Nar 1
aki 1
[SAYS] 1
▁Hi 1
, 1
▁John 1
! 1
[EOU] 1
▁How 1
▁are 1
▁you 1
? 1
[EOU] 1
[EOT] 1
▁John 2
[SAYS] 2
▁I 2
' 2
m 2
▁good 2
. 2
▁Thanks 2
. 2
[EOU] 2
[EOT] 2
▁Yu 1
ji 1
▁Nar 1
aki 1
[SAYS] 1
▁Hi 1
, 1
▁John 1
! 1
[EOU] 1
▁How 1
▁are 1
▁you 1
? 1
[EOU] 1
[EOT] 1
</s> 0
▁Nar 1
aki 1
[SAYS] 1
▁Good 1
▁evening 1
, 1
▁Mr 1
. 1
Kim 1
. 1
[EOU] 1
▁How 1
▁was 1
▁your 1
▁today 1
? 1
[EOU] 1
[EOT] 1
▁Kim 2
[SAYS] 2
▁It 2
▁is 2
▁a 2
▁pleasant 2
▁day 2
. 2
[EOU] 2
[EOT] 2
▁Dai 3
ve 3
[SAYS] 3
▁It 3
▁is 3
▁a 3
▁pleasant 3
▁day 3
. 3
[EOU] 3
[EOT] 3
</s> 0
<pad> 0
<pad> 0
<pad> 0
<pad> 0


In [57]:
input_ids = batch['input_ids']
for text_idx in range(len(input_ids)):
    # init_speaker_list()
    speaker_list = []
    text_seq = input_ids[text_idx]
    for w_idx in range(len(text_seq)):
        if w_idx==0:
            for i in range(w_idx, len(text_seq)):
                if says_token == text_seq[i]:
                    print("\t", ''.join([tokenizer.convert_ids_to_tokens(w_id) for w_id in text_seq[w_idx:i]]))
                    print("\t", get_speaker_id(text_seq[w_idx:i]))
                    break
        print(tokenizer.convert_ids_to_tokens(text_seq[w_idx]))
        if eot_token == text_seq[w_idx]:
            for i in range(w_idx+1, len(text_seq)):
                # print(i)
                if eod_token == text_seq[i]:
                    print("\t", "end")
                    print("\t", 0)
                    break
                elif says_token == text_seq[i]:
                    print("\t", ''.join([tokenizer.convert_ids_to_tokens(w_id) for w_id in text_seq[w_idx+1:i]]))
                    print("\t", get_speaker_id(text_seq[w_idx+1:i]))
                    break
                

	 ▁Yuji▁Naraki
	 1
▁Yu
ji
▁Nar
aki
[SAYS]
▁Hi
,
▁John
!
[EOU]
▁How
▁are
▁you
?
[EOU]
[EOT]
	 ▁John
	 2
▁John
[SAYS]
▁I
'
m
▁good
.
▁Thanks
.
[EOU]
[EOT]
	 end
	 0
</s>
<pad>
	 ▁Naraki
	 1
▁Nar
aki
[SAYS]
▁Good
▁evening
,
▁Mr
.
Kim
.
[EOU]
▁How
▁was
▁your
▁today
?
[EOU]
[EOT]
	 ▁Kim
	 2
▁Kim
[SAYS]
▁It
▁is
▁a
▁pleasant
▁day
.
[EOU]
[EOT]
	 end
	 0
</s>
