In [43]:
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoConfig
import logging 

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2', device = "mps") 
tokenizer = model.tokenizer

start_token, sep_token, pad_token_id = tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token_id

import os 

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

06/26/2025 10:45:24 - INFO - sentence_transformers.SentenceTransformer -   Load pretrained SentenceTransformer: paraphrase-multilingual-MiniLM-L12-v2


In [47]:
def line_statistics(file_name):
    """
    统计文件行数
    """
    if file_name is None:
        return 0

    return 6948 


In [36]:
import torch.nn as nn 
import torch 

class PoolingAverage(nn.Module):
    def __init__(self, eps=1e-12):
        super(PoolingAverage, self).__init__()
        self.eps = eps

    def forward(self, hidden_states, attention_mask):
        mul_mask = lambda x, m: x * torch.unsqueeze(m, dim=-1)
        reduce_mean = lambda x, m: torch.sum(mul_mask(x, m), dim=1) / (torch.sum(m, dim=1, keepdims=True) + self.eps)

        avg_output = reduce_mean(hidden_states, attention_mask)
        return avg_output

    def equal_forward(self, hidden_states, attention_mask):
        mul_mask = hidden_states * attention_mask.unsqueeze(-1)
        avg_output = torch.sum(mul_mask, dim=1) / (torch.sum(attention_mask, dim=1, keepdim=True) + self.eps)
        return avg_output


In [37]:





class DialogueTransformer(nn.Module):
    def __init__(self):
        super(DialogueTransformer, self).__init__()
        self.bert = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2', device = "mps") 

        hf_model = model._first_module().auto_model  # or model[0].auto_model
        self.config = hf_model.config

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.labels_data = None
        self.sample_nums = 10
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.avg = PoolingAverage(eps=1e-6)
        self.logger = logger  

    def forward(self, data, strategy='mean_by_role', output_attention=False):


        if len(data) == 7:
            input_ids, attention_mask, token_type_ids, role_ids, turn_ids, position_ids, labels = data
        else:
            input_ids, attention_mask, token_type_ids, role_ids, turn_ids, position_ids, labels, guids = data

        input_ids = input_ids.view(input_ids.size()[0] * input_ids.size()[1], input_ids.size()[-1])
        attention_mask = attention_mask.view(attention_mask.size()[0] * attention_mask.size()[1], attention_mask.size()[-1])
        token_type_ids = token_type_ids.view(token_type_ids.size()[0] * token_type_ids.size()[1], token_type_ids.size()[-1])
        role_ids = role_ids.view(role_ids.size()[0] * role_ids.size()[1], role_ids.size()[-1])
        turn_ids = turn_ids.view(turn_ids.size()[0] * turn_ids.size()[1], turn_ids.size()[-1])
        position_ids = position_ids.view(position_ids.size()[0] * position_ids.size()[1], position_ids.size()[-1])

        one_mask = torch.ones_like(role_ids)
        zero_mask = torch.zeros_like(role_ids)
        role_a_mask = torch.where(role_ids == 0, one_mask, zero_mask)
        role_b_mask = torch.where(role_ids == 1, one_mask, zero_mask)

        sep_token_id = self.bert.tokenizer.sep_token_id 
        a_attention_mask = (attention_mask * role_a_mask)
        b_attention_mask = (attention_mask * role_b_mask)

        self_output, pooled_output = self.encoder(input_ids, attention_mask, token_type_ids, position_ids, turn_ids, role_ids)

        q_self_output = self_output * a_attention_mask.unsqueeze(-1)
        r_self_output = self_output * b_attention_mask.unsqueeze(-1)

        self_output = self_output * attention_mask.unsqueeze(-1)
        w = torch.matmul(q_self_output, r_self_output.transpose(-1, -2))

        if turn_ids is not None:
            view_turn_mask = turn_ids.unsqueeze(1).repeat(1, self.bert.max_seq_length, 1)
            view_turn_mask_transpose = view_turn_mask.transpose(2, 1)
            view_range_mask = torch.where(abs(view_turn_mask_transpose - view_turn_mask) <= 1000,
                                          torch.ones_like(view_turn_mask),
                                          torch.zeros_like(view_turn_mask))
            filtered_w = w * view_range_mask

        q_cross_output = torch.matmul(filtered_w.permute(0, 2, 1), q_self_output)
        r_cross_output = torch.matmul(filtered_w, r_self_output)

        q_self_output = self.avg(q_self_output, a_attention_mask)
        q_cross_output = self.avg(q_cross_output, b_attention_mask)
        r_self_output = self.avg(r_self_output, b_attention_mask)
        r_cross_output = self.avg(r_cross_output, a_attention_mask)

        self_output = self.avg(self_output, attention_mask)
        q_self_output = q_self_output.view(-1, self.sample_nums, self.config.hidden_size)
        q_cross_output = q_cross_output.view(-1, self.sample_nums, self.config.hidden_size)
        r_self_output = r_self_output.view(-1, self.sample_nums, self.config.hidden_size)
        r_cross_output = r_cross_output.view(-1, self.sample_nums, self.config.hidden_size)

        self_output = self_output.view(-1, self.sample_nums, self.config.hidden_size)
        pooled_output = pooled_output.view(-1, self.sample_nums, self.config.hidden_size)

        output = self_output[:, 0, :]
        q_output = q_self_output[:, 0, :]
        r_output = r_self_output[:, 0, :]
        q_contrastive_output = q_cross_output[:, 0, :]
        r_contrastive_output = r_cross_output[:, 0, :]

        logit_q = []
        logit_r = []
        for i in range(self.sample_nums):
            cos_q = self.calc_cos(q_self_output[:, i, :], q_cross_output[:, i, :])
            cos_r = self.calc_cos(r_self_output[:, i, :], r_cross_output[:, i, :])
            logit_r.append(cos_r)
            logit_q.append(cos_q)

        logit_r = torch.stack(logit_r, dim=1)
        logit_q = torch.stack(logit_q, dim=1)

        loss_r = self.calc_loss(logit_r, labels)
        loss_q = self.calc_loss(logit_q, labels)

        if strategy not in ['mean', 'mean_by_role']:
            raise ValueError('Unknown strategy: [%s]' % strategy)

        output_dict = {'loss': loss_r + loss_q,
                       'final_feature': output if strategy == 'mean' else q_output + r_output,
                       'q_feature': q_output,
                       'r_feature': r_output,
                       'attention': w}

        return output_dict

    def encoder(self, *x):
        input_ids, attention_mask, token_type_ids, position_ids, turn_ids, role_ids = x    

        transformer = self.bert[0].auto_model
        output = transformer(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            output_hidden_states=True,
                            return_dict=True)
        all_output = output['hidden_states']
        pooler_output = output['pooler_output']
        return all_output[-1], pooler_output

    def calc_cos(self, x, y):
        cos = torch.cosine_similarity(x, y, dim=1)
        cos = cos / 1.0 # cos = cos / 2.0
        return cos

    def calc_loss(self, pred, labels):
        loss = -torch.mean(self.log_softmax(pred) * labels)
        return loss

    def get_result(self):
        return self.result

    def get_labels_data(self):
        return self.labels_data

In [38]:
class BertFeatures():
    def __init__(self, input_ids, input_mask, segment_ids, role_ids, label_id, turn_ids=None, position_ids=None, guid=None):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.role_ids = role_ids
        self.turn_ids = turn_ids
        self.position_ids = position_ids
        self.label_id = label_id
        self.guid = guid

        self.batch_size = len(self.input_ids)
import codecs  

features = [] 

with codecs.open("datasets/doc2dial/train.tsv", "r", "utf") as f: 
    examples = [] 
    for line in f: 
        line = [s.strip() for s in line.split('\t') if s.strip()]
        role, session, label = line[0], line[1], line[2]

        examples.append((role, session, label))

    for example in examples: 
        samples = example[1].split("|")
        roles = [int(r) for r in example[0].split("|")] \
                if example[0].find("#") != -1 \
                else [int(r) for r in example[0]]
    
    
        sample_input_ids = []
        sample_segment_ids = []
        sample_role_ids = []
        sample_input_mask = []
        sample_turn_ids = []
        sample_position_ids = []

        for t, s in enumerate(samples):
            text_tokens = []
            text_turn_ids = []
            text_role_ids = []

            texts = s.split("#")

            # bert-token:     [cls]  token   [sep]  token
            # roberta-token:   <s>   token   </s>   </s> token


            text_tokens.append(start_token)
            text_turn_ids.append(0)
            text_role_ids.append(roles[0])

            for i, text in enumerate(texts): 

                 tokenized = tokenizer.tokenize(text)
                 text_tokens.extend(tokenized)
                 text_turn_ids.extend([i] * len(tokenized))
                 text_role_ids.extend([roles[i]] * len(tokenized))

                 if i != (len(text) - 1): 
                     text_tokens.append(sep_token)
                     text_turn_ids.append(i)
                     text_role_ids.append(roles[i])

            text_tokens = text_tokens[:model.max_seq_length]
            text_turn_ids = text_turn_ids[:model.max_seq_length]
            text_role_ids = text_role_ids[:model.max_seq_length]

            text_input_ids = tokenizer.convert_tokens_to_ids(text_tokens)


            text_input_ids += [pad_token_id] * (model.max_seq_length - len(text_tokens))
            text_input_mask = [1] * len(text_tokens) + [0] * (model.max_seq_length - len(text_tokens))
            text_segment_ids = [0] * model.max_seq_length
            text_position_ids = list(range(len(text_tokens))) + [0] * (model.max_seq_length - len(text_tokens))
            text_turn_ids += [0] * (model.max_seq_length - len(text_tokens))
            text_role_ids += [0] * (model.max_seq_length - len(text_tokens))


            assert len(text_input_ids) == model.max_seq_length
            assert len(text_input_mask) == model.max_seq_length
            assert len(text_segment_ids) == model.max_seq_length
            assert len(text_position_ids) == model.max_seq_length
            assert len(text_turn_ids) == model.max_seq_length
            assert len(text_role_ids) == model.max_seq_length

            sample_input_ids.append(text_input_ids)
            sample_turn_ids.append(text_turn_ids)
            sample_role_ids.append(text_role_ids)
            sample_segment_ids.append(text_segment_ids)
            sample_position_ids.append(text_position_ids)
            sample_input_mask.append(text_input_mask) 


        n_neg = 9
        label_id = [1] + [0] * n_neg
        bert_feature = BertFeatures(input_ids=sample_input_ids,
                                    input_mask=sample_input_mask,
                                    segment_ids=sample_segment_ids,
                                    role_ids=sample_role_ids,
                                    turn_ids=sample_turn_ids,
                                    position_ids=sample_position_ids,
                                    label_id=label_id,
                                    guid=[None] * (1 + n_neg))

        features.append(bert_feature)

Token indices sequence length is longer than the specified maximum sequence length for this model (141 > 128). Running this sequence through the model will result in indexing errors


In [39]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, DistributedSampler
import torch 
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_role_ids = torch.tensor([f.role_ids for f in features], dtype=torch.long)
all_turn_ids = torch.tensor([f.turn_ids for f in features], dtype=torch.long)
all_position_ids = torch.tensor([f.position_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

train_data = TensorDataset(all_input_ids,
                            all_input_mask,
                            all_segment_ids,
                            all_role_ids,
                            all_turn_ids,
                            all_position_ids,
                            all_label_ids)

train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data,
                        sampler=train_sampler,
                        batch_size=5)

In [49]:
from optimization import BERTAdam


d2vmodel = DialogueTransformer().to(torch.device("mps"))

transformer = d2vmodel.bert[0].auto_model
param_optimizer = list(transformer.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}]
optimizer = BERTAdam(optimizer_grouped_parameters, lr=1e-5, warmup=0.1, t_total=line_statistics("datasets/doc2dial/train.tsv"))

steps = 0 
d2vmodel.train()
for epoch in range(10):
    for step, batch in enumerate(train_loader):
        batch = tuple(t.to(torch.device("mps")) for t in batch)
        output_dict = d2vmodel(batch, strategy='mean_by_role')
        loss = output_dict['loss']

        loss = loss / 1.0
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        if steps % 20 == 0: 
            print("Loss: ", loss.item())
        if steps % 100 == 0:
            d2vmodel.train()

        steps += 1 

06/26/2025 10:47:14 - INFO - sentence_transformers.SentenceTransformer -   Load pretrained SentenceTransformer: paraphrase-multilingual-MiniLM-L12-v2


tensor(0.4607, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4600, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4607, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4602, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4608, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4604, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4600, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4602, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4604, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4602, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4607, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4601, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4601, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4599, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4602, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4604, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4598, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4598, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.4

KeyboardInterrupt: 