In [13]:
import torch
from torch import nn
import os
import random
from torch.utils import data
from tqdm import tqdm
import numpy as np
from copy import deepcopy
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import warnings
import csv
import re

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# NOTE(review): a `global` statement at module level has no effect; the name
# below is module-global either way.
global extracted_grads

# Filled by extract_grad_hook with one gradient tensor per backward pass.
extracted_grads = []
position = 1  # column index at which the trigger tokens are inserted
# With BERT inputs, index 1 places the trigger right after the [CLS] token.
# Random concatenation mode (disabled):
# position = random.randint(1,500)

BERT_path = 'PreTrainedModelBert'  # directory holding vocab.txt, config.json and pytorch_model.bin
# Tokenizer, configuration and classifier are all loaded from BERT_path when this cell runs.
tokenize = BertTokenizer.from_pretrained(os.path.join(BERT_path, 'vocab.txt'))
model_config = BertConfig.from_pretrained(os.path.join(BERT_path, 'config.json'))
Model = BertForSequenceClassification.from_pretrained(os.path.join(BERT_path, 'pytorch_model.bin'), config=model_config)

# Load model related information

# Print the number of Total Parameters
# total = [param.nelement() for param in Model.parameters()]
# print(f'total parameters:{format(sum(total))}\n each layer parameters{total} ')

'''
SNLI Data
'''


### Load data

def extract_text(s):
    """Strip parentheses from ``s`` and normalise runs of whitespace.

    Every ``(`` and ``)`` is deleted, each run of two or more whitespace
    characters is replaced by a single space, and leading/trailing
    whitespace is removed from the result.
    """
    # Drop the opening and closing parentheses.
    for paren_pattern in ('\\(', '\\)'):
        s = re.sub(paren_pattern, '', s)
    # Replace two or more consecutive whitespace characters with one space.
    collapsed = re.sub('\\s{2,}', ' ', s)
    return collapsed.strip()


def read_snli_binary_data(data_dir, is_train):
    """Read SNLI as a binary (entailment vs. contradiction) dataset.

    Args:
        data_dir: Directory containing the SNLI ``.txt`` files.
        is_train: If true, read ``snli_1.0_train.txt``; otherwise read
            ``snli_1.0_test.txt``.

    Returns:
        A ``(texts, labels)`` pair of tuples. Each text is the cleaned second
        and third tab-separated columns joined by a space; each label is 0
        for ``entailment`` and 1 for ``contradiction``. Rows with any other
        first column are skipped. Both tuples are empty when no row matches.
    """
    label_set = {'entailment': 0, 'contradiction': 1}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')
    texts, labels = [], []
    # Explicit encoding keeps the result independent of the platform default.
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header line
        # Stream line by line instead of loading the whole file into memory.
        for line in f:
            row = line.split('\t')
            # Keep only the rows whose label belongs to the binary task.
            if row[0] in label_set:
                texts.append(extract_text(row[1]) + ' ' + extract_text(row[2]))
                labels.append(label_set[row[0]])
    # Tuples match what the previous zip(*data) unpacking returned, and an
    # empty selection no longer raises ValueError.
    return tuple(texts), tuple(labels)


def read_snli_binary_test_data(data_dir, is_train):
    """Read only the ``contradiction`` rows of an SNLI split.

    Args:
        data_dir: Directory containing the SNLI ``.txt`` files.
        is_train: If true, read ``snli_1.0_train.txt``; otherwise read
            ``snli_1.0_test.txt``.

    Returns:
        A ``(texts, labels)`` pair of tuples. Each text is the cleaned second
        and third tab-separated columns joined by a space; every label is 1.
        Both tuples are empty when the file has no ``contradiction`` row.
    """
    # Alternative label selections kept for experimentation:
    # label_set = {'entailment': 0, 'contradiction': 1}
    # label_set = {'entailment': 0}
    label_set = {'contradiction': 1}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')
    texts, labels = [], []
    # Explicit encoding keeps the result independent of the platform default.
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header line
        # Stream line by line instead of loading the whole file into memory.
        for line in f:
            row = line.split('\t')
            # Keep only the rows whose label is selected above.
            if row[0] in label_set:
                texts.append(extract_text(row[1]) + ' ' + extract_text(row[2]))
                labels.append(label_set[row[0]])
    # Tuples match what the previous zip(*data) unpacking returned, and an
    # empty selection no longer raises ValueError.
    return tuple(texts), tuple(labels)


def load_snli_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a ``TensorDataset`` and return a ``DataLoader`` over it.

    Batches are shuffled exactly when ``is_train`` is true.
    """
    tensor_dataset = data.TensorDataset(*data_arrays)
    loader = data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=is_train)
    return loader


def load_snli_data(train_batch_size, test_batch_iter, num_steps=500):
    """Tokenize the SNLI splits and build the train and test data loaders.

    Args:
        train_batch_size: Batch size of the (shuffled) training loader.
        test_batch_iter: Batch size of the (unshuffled) test loader.
        num_steps: ``max_length`` passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        ``(train_iter, test_iter)``; every batch is an
        ``(input_ids, token_type_ids, labels)`` triple.
    """
    train_texts, train_labels = read_snli_binary_data('snli_1.0', is_train=True)
    test_texts, test_labels = read_snli_binary_test_data('snli_1.0', is_train=False)

    # Both splits are tokenized with the same settings.
    tokenizer_options = dict(return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    train_encoding = tokenize(train_texts, **tokenizer_options)
    test_encoding = tokenize(test_texts, **tokenizer_options)

    # NOTE(review): only input_ids and token_type_ids are kept, so padded
    # positions are not masked downstream; adding the attention mask would
    # change the batch layout that every consumer unpacks.
    train_tensors = (train_encoding['input_ids'],
                     train_encoding['token_type_ids'],
                     torch.tensor(train_labels))
    test_tensors = (test_encoding['input_ids'],
                    test_encoding['token_type_ids'],
                    torch.tensor(test_labels))

    train_iter = load_snli_array(train_tensors, train_batch_size)
    test_iter = load_snli_array(test_tensors, test_batch_iter, is_train=False)
    return train_iter, test_iter


### Train

def train(net, train_iter, lr, num_epochs, device):
    """Fine-tune ``net`` with AdamW and print loss/accuracy after each epoch.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        train_iter: Iterable of ``(input_ids, token_type_ids, labels)`` batches.
        lr: Learning rate for AdamW.
        num_epochs: Number of passes over ``train_iter``.
        device: Sequence of devices; only ``device[0]`` is used.

    Returns:
        None. Progress and metrics are printed.
    """
    print('---------------------------start---------------------')
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
    net = net.to(device[0])
    # Defined up front so the final report also works when num_epochs == 0.
    train_loss = float('nan')
    for epoch in range(num_epochs):
        net.train()
        print(f' epoch {epoch + 1}')
        train_losses = []
        train_accs = []
        for batch in tqdm(train_iter):
            a, b, y = batch
            a = a.to(device[0])
            b = b.to(device[0])
            y = y.to(device[0])
            outputs = net(input_ids=a, token_type_ids=b, labels=y)
            logits = outputs.logits
            l = outputs.loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            acc = (logits.argmax(dim=-1) == y).float().mean()
            # Keep plain floats: storing the loss tensor itself would hold a
            # device tensor with autograd history for every step of the epoch.
            train_losses.append(l.item())
            train_accs.append(acc.item())
        print("Learning rate for epoch %d：%f" % (epoch + 1, optimizer.param_groups[0]['lr']))
        # Guard against an empty iterator instead of dividing by zero.
        num_batches = max(len(train_losses), 1)
        train_loss = sum(train_losses) / num_batches
        train_acc = sum(train_accs) / num_batches
        print(f"[ Train | {epoch + 1:03d}/{num_epochs:03d} ] loss = {train_loss:.5f}   acc = {train_acc:.5f}")
    print('Training process has finished.')
    print('the loss of model {:.3f}'.format(train_loss))


def evaluate_no(net, test_iter):
    """Return the mean of per-batch accuracies with no trigger inserted.

    The model is moved to the module-level ``device[0]``, switched to eval
    mode, and run under ``torch.no_grad()``.
    """
    target = device[0]
    net = net.to(target)
    net.eval()
    batch_accuracies = []
    with torch.no_grad():
        for input_ids, segment_ids, labels in tqdm(test_iter):
            input_ids = input_ids.to(target)
            segment_ids = segment_ids.to(target)
            labels = labels.to(target)
            outputs = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels)
            predictions = outputs.logits.argmax(dim=-1)
            batch_accuracies.append((predictions == labels).float().mean())
    # Averaged over batches, not over individual examples.
    return sum(batch_accuracies) / len(test_iter)


### Trigger Token

def init_trigger_tokens(trigger, num_trigger_tokens):
    """Return a 1-D tensor of ``num_trigger_tokens`` copies of token id 1996.

    The ``trigger`` argument is accepted but not used; the initial trigger is
    always id 1996 (commented in this file as the token 'the').
    """
    initial_ids = [1996 for _ in range(num_trigger_tokens)]  # 1996 means 'the'
    return torch.tensor(initial_ids)


def evaluate(net, test_iter, trigger_token_tensor):
    """Return the mean of per-batch accuracies with the trigger spliced in.

    The trigger ids are inserted into every input at column ``position``,
    together with the same number of zero segment ids. The model runs in
    eval mode under ``torch.no_grad()`` on the module-level ``device[0]``.
    """
    target = device[0]
    net = net.to(target)
    net.eval()
    # One row of trigger ids and a matching row of zero segment ids.
    trigger_row = deepcopy(trigger_token_tensor).unsqueeze(0)
    segment_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    batch_accuracies = []
    with torch.no_grad():
        for input_ids, segment_ids, labels in tqdm(test_iter):
            # Repeat the rows once per example, then splice them in.
            input_ids = torch.cat(
                (input_ids[:, :position],
                 trigger_row.repeat_interleave(input_ids.shape[0], dim=0),
                 input_ids[:, position:]),
                dim=1)
            segment_ids = torch.cat(
                (segment_ids[:, :position],
                 segment_row.repeat_interleave(segment_ids.shape[0], dim=0),
                 segment_ids[:, position:]),
                dim=1)
            input_ids = input_ids.to(target)
            segment_ids = segment_ids.to(target)
            labels = labels.to(target)
            outputs = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels)
            predictions = outputs.logits.argmax(dim=-1)
            batch_accuracies.append((predictions == labels).float().mean())
    # Averaged over batches, not over individual examples.
    return sum(batch_accuracies) / len(test_iter)


def extract_grad_hook(net, grad_in, grad_out):
    """Backward hook: record the first output gradient, averaged over dim 0.

    The averaged tensor is appended to the module-level ``extracted_grads``
    list; ``net`` and ``grad_in`` are not used.
    """
    mean_grad = grad_out[0].mean(dim=0)
    extracted_grads.append(mean_grad)


def add_hook(net):
    """Register ``extract_grad_hook`` on the first ``nn.Embedding`` in ``net``.

    Args:
        net: Module whose sub-modules are searched in ``net.modules()`` order.

    Returns:
        The handle returned by ``register_backward_hook``; call ``.remove()``
        on it to detach the hook.

    Raises:
        ValueError: If ``net`` contains no ``nn.Embedding`` module. The
            previous code failed here with an UnboundLocalError.
    """
    for module in net.modules():
        if isinstance(module, nn.Embedding):
            # NOTE(review): newer PyTorch releases recommend
            # register_full_backward_hook; confirm it fires for an embedding
            # whose integer input needs no gradient before switching.
            return module.register_backward_hook(extract_grad_hook)
    raise ValueError('net contains no nn.Embedding module to hook')


def get_gradient(net, test_iter, trigger_token_tensor):
    """Run one backward pass per batch with the trigger spliced into the inputs.

    The model is put in train mode on the module-level ``device[0]``.
    Gradients are zeroed before every backward pass and no optimizer step is
    taken, so the weights are left unchanged. Nothing is returned; any
    backward hook registered on the model sees each pass.
    """
    target = device[0]
    net = net.to(target)
    net.train()
    # One row of trigger ids and a matching row of zero segment ids.
    trigger_row = deepcopy(trigger_token_tensor).unsqueeze(0)
    segment_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    # The optimizer is only used to clear gradients between batches.
    optimizer = torch.optim.AdamW(net.parameters())
    for input_ids, segment_ids, labels in tqdm(test_iter):
        input_ids = torch.cat(
            (input_ids[:, :position],
             trigger_row.repeat_interleave(input_ids.shape[0], dim=0),
             input_ids[:, position:]),
            dim=1)
        segment_ids = torch.cat(
            (segment_ids[:, :position],
             segment_row.repeat_interleave(segment_ids.shape[0], dim=0),
             segment_ids[:, position:]),
            dim=1)
        input_ids = input_ids.to(target)
        segment_ids = segment_ids.to(target)
        labels = labels.to(target)
        loss = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels).loss
        optimizer.zero_grad()
        loss.backward()


def process_gradient(length, num_trigger_tokens):
    """Average the collected gradients and keep the trigger-token rows.

    Args:
        length: Number of leading entries of the module-level
            ``extracted_grads`` list to average (the caller passes the number
            of batches).
        num_trigger_tokens: Number of trigger tokens inserted at ``position``.

    Returns:
        Rows ``position`` to ``position + num_trigger_tokens`` (exclusive) of
        the element-wise mean of the selected gradients.

    Raises:
        ValueError: If there is no gradient to average.
    """
    grads = extracted_grads[:length]
    if not grads:
        raise ValueError('no extracted gradients to average')
    # The previous loop ran over range(1, length - 1) and so silently dropped
    # the last gradient. Stacking on a new leading axis includes all of them
    # and leaves the entries of extracted_grads untouched.
    stacked = torch.stack(grads, dim=0)
    average_grad = stacked.mean(dim=0)[position:position + num_trigger_tokens]
    return average_grad


def hotflip_attack(averaged_grad, embedding_matrix,
                   num_candidates=1, increase_loss=False):
    """Rank vocabulary entries by the dot product of gradient and embedding.

    Args:
        averaged_grad: 2-D tensor with one gradient row per trigger position.
        embedding_matrix: 2-D tensor with one embedding row per vocabulary id.
        num_candidates: Number of ids to return per trigger position.
        increase_loss: If false, the scores are negated before ranking.

    Returns:
        A NumPy array of vocabulary ids: one id per position when
        ``num_candidates`` is 1, otherwise ``num_candidates`` ids per position
        ordered from highest to lowest score.
    """
    grad = averaged_grad.cpu().unsqueeze(0)
    vocab = embedding_matrix.cpu()
    # scores[0, i, k] = <gradient at position i, embedding of id k>
    scores = torch.einsum("bij,kj->bik", (grad, vocab))
    if not increase_loss:
        # lower versus increase the class probability.
        scores = scores * -1
    if num_candidates > 1:  # get top k options
        top_ids = torch.topk(scores, num_candidates, dim=2)[1]
        return top_ids.detach().cpu().numpy()[0]  # Return candidates
    best_ids = scores.max(2)[1]
    return best_ids[0].detach().cpu().numpy()


def collection_attack(net, test_iter, num_candidates, num_epoch, trigger='the',
                      num_trigger_tokens=3):
    """Run the full trigger search: gradients, candidates, greedy selection.

    Args:
        net: Classifier under attack.
        test_iter: Iterable of ``(input_ids, token_type_ids, labels)`` batches.
        num_candidates: Number of candidate ids proposed per trigger position
            in every round.
        num_epoch: Number of attack rounds.
        trigger: Passed through to ``init_trigger_tokens``.
        num_trigger_tokens: Number of trigger tokens to insert.

    Returns:
        ``(trigger_token_tensor, valid_acc)``: the final trigger ids and the
        accuracy measured with them.
    """
    trigger_token_tensor = init_trigger_tokens(trigger, num_trigger_tokens)
    print(f'Concatenation location:{position}')
    valid_acc = evaluate(net, test_iter, trigger_token_tensor)
    print(f'Initial trigger tokens state：the accuracy {valid_acc:.5f}')
    embedding_weight = get_embedding_weight(net)
    for i in range(num_epoch):
        extracted_grads.clear()
        hook = add_hook(net)
        try:
            get_gradient(net, test_iter, trigger_token_tensor)
        finally:
            # Always detach the hook, even if the gradient pass fails, so it
            # cannot keep firing on later backward passes.
            hook.remove()
        average_grad = process_gradient(len(test_iter), num_trigger_tokens)
        hot_token = hotflip_attack(average_grad, embedding_weight, num_candidates, increase_loss=True)
        hot_token_tensor = torch.from_numpy(hot_token)
        if hot_token_tensor.dim() == 1:
            # hotflip_attack returns a 1-D array when num_candidates == 1, but
            # select_best_candid indexes candidates as [position, candidate].
            hot_token_tensor = hot_token_tensor.unsqueeze(1)
        trigger_token_tensor, valid_acc = select_best_candid(net, test_iter, hot_token_tensor, trigger_token_tensor,
                                                             valid_acc)
        print(f'after {i + 1} rounds of attacking\ntriggers: {trigger_token_tensor} \nthe accuracy :{valid_acc:.5f} ')
    return trigger_token_tensor, valid_acc  # Return the final trigger tokens (trigger length) and the accuracy after the attack


def get_embedding_weight(net):
    """Return the weight of the first ``nn.Embedding`` found in ``net``.

    Args:
        net: Module whose sub-modules are searched in ``net.modules()`` order.

    Returns:
        The ``weight`` parameter of the first embedding module.

    Raises:
        ValueError: If ``net`` contains no ``nn.Embedding`` module. The
            previous code failed here with an UnboundLocalError.
    """
    for module in net.modules():
        if isinstance(module, nn.Embedding):
            return module.weight
    raise ValueError('net contains no nn.Embedding module')


def select_best_candid(net, test_iter, candid_trigger, trigger_token, valid_acc):
    """Greedily pick, per trigger position, the candidate that lowers accuracy.

    For each position ``i`` and each candidate ``candid_trigger[i, j]``, the
    candidate replaces position ``i`` of the current trigger, the trigger is
    inserted into every input at column ``position``, and the mean per-batch
    accuracy is measured. A candidate is kept only if that accuracy is
    strictly lower than the best seen so far.

    Args:
        net: Classifier under attack; it is switched to eval mode.
        test_iter: Iterable of ``(input_ids, token_type_ids, labels)`` batches.
        candid_trigger: 2-D tensor of candidate ids indexed
            ``[position, candidate]``.
        trigger_token: 1-D tensor holding the current trigger ids; it is not
            modified.
        valid_acc: Accuracy obtained with the current trigger.

    Returns:
        ``(trigger, valid_acc)``: the updated trigger ids and their accuracy.
    """
    n = torch.tensor([0] * len(trigger_token))
    n = n.unsqueeze(0)
    # Clone first: unsqueeze() returns a view, so the in-place updates below
    # would otherwise write into the caller's tensor.
    trigger_token = trigger_token.clone().unsqueeze(0)
    net.eval()
    # Accuracy is only measured here, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(candid_trigger.shape[0]):
            trigger_token_temp = trigger_token.clone()
            for j in range(candid_trigger.shape[1]):
                trigger_token_temp[0, i] = candid_trigger[i, j]
                valid_accs = []
                for batch in tqdm(test_iter):
                    a, b, y = batch
                    a = torch.cat((a[:, :position], trigger_token_temp.repeat_interleave(a.shape[0], dim=0),
                                   a[:, position:]), dim=1)
                    b = torch.cat((b[:, :position], n.repeat_interleave(b.shape[0], dim=0),
                                   b[:, position:]), dim=1)
                    a = a.to(device[0])
                    b = b.to(device[0])
                    y = y.to(device[0])
                    outputs = net(input_ids=a, token_type_ids=b, labels=y)
                    acc = (outputs.logits.argmax(dim=-1) == y).float().mean()
                    valid_accs.append(acc)
                temp = sum(valid_accs) / len(test_iter)
                if temp < valid_acc:
                    valid_acc = temp
                    trigger_token[0, i] = candid_trigger[i, j]
    return trigger_token[0], valid_acc  # Return the final trigger token and the accuracy after the attack


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at PreTrainedModelBert/pytorch_model.bin and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Alternative datasets (disabled):
# train_iter, test_iter = load_imdb_data(10)
# train_iter, test_iter = load_sst_data(10)
# Build the SNLI loaders: training batch size 20, test batch size 3.
train_iter, test_iter = load_snli_data(20, 3)
# Data preprocessing and loading are done; report the number of batches in each loader.
print("reading data finished\n")
print(len(train_iter))
print(len(test_iter))

reading data finished

18331
1079


In [4]:
def try_all_gpus():
    """Return one ``torch.device`` per visible CUDA GPU, or ``[cpu]`` if none."""
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{index}') for index in range(gpu_count)]

# Module-level device list; the evaluation and attack functions read device[0].
device = try_all_gpus()
# device = [torch.device('cpu')]
# Fine-tune for 3 epochs with learning rate 5e-6.
train(Model, train_iter, 5e-6, 3, device)  # base BERT

# train(Model, train_iter, 5e-5, 3, device) # else BERT
# The accuracy of the model on the test set when no trigger token is concatenated

evaluate_no(Model, test_iter)

---------------------------start---------------------
 epoch 1


  0%|          | 0/18331 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 18331/18331 [18:05<00:00, 16.88it/s]


Learning rate for epoch 1：0.000005
[ Train | 001/003 ] loss = 0.18846   acc = 0.92234
 epoch 2


100%|██████████| 18331/18331 [17:58<00:00, 17.00it/s]


Learning rate for epoch 2：0.000005
[ Train | 002/003 ] loss = 0.10166   acc = 0.96271
 epoch 3


100%|██████████| 18331/18331 [18:06<00:00, 16.87it/s]


Learning rate for epoch 3：0.000005
[ Train | 003/003 ] loss = 0.06908   acc = 0.97572
Training process has finished.
the loss of model 0.069


100%|██████████| 1123/1123 [00:07<00:00, 157.81it/s]


tensor(0.9599, device='cuda:0')

In [5]:
# Save the whole fine-tuned module object (not just a state_dict) to disk.
torch.save(Model, 'Bert_snli.bin')

In [15]:
# Reload the saved module and re-check its accuracy without any trigger.
# NOTE(review): this unpickles a full module; recent PyTorch releases default
# torch.load to weights_only=True — confirm this call still works there, and
# only load files from a trusted source.
model = torch.load('Bert_snli.bin')
evaluate_no(model, test_iter)

100%|██████████| 1079/1079 [00:06<00:00, 154.70it/s]


tensor(0.9725, device='cuda:0')

In [26]:
# Run the trigger search: 5 candidates per position, 10 rounds, 2 trigger tokens.
collection_attack(model, test_iter, 5, 10, trigger='the', num_trigger_tokens=2)

Concatenation location:1


100%|██████████| 1079/1079 [00:06<00:00, 154.46it/s]


Initial trigger tokens state：the accuracy 0.96756


100%|██████████| 1079/1079 [00:59<00:00, 18.04it/s]
100%|██████████| 1079/1079 [00:08<00:00, 123.95it/s]
100%|██████████| 1079/1079 [00:08<00:00, 123.58it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.52it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.56it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.97it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.82it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.98it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.52it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.56it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.71it/s]


after 1 rounds of attacking
triggers: tensor([19089, 22833]) 
the accuracy :0.95891 


100%|██████████| 1079/1079 [00:59<00:00, 18.06it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.35it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.45it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.34it/s]
100%|██████████| 1079/1079 [00:08<00:00, 119.97it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.85it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.37it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.20it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.18it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.05it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.63it/s]


after 2 rounds of attacking
triggers: tensor([27056, 18545]) 
the accuracy :0.95428 


100%|██████████| 1079/1079 [01:00<00:00, 17.97it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.26it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.32it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.19it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.42it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.63it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.47it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.47it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.58it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.67it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.68it/s]


after 3 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.02it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.79it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.43it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.76it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.60it/s]
100%|██████████| 1079/1079 [00:08<00:00, 122.13it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.12it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.51it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.00it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.34it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.10it/s]


after 4 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.05it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.07it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.27it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.76it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.14it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.54it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.63it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.58it/s]
100%|██████████| 1079/1079 [00:09<00:00, 116.94it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.44it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.16it/s]


after 5 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.00it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.79it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.22it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.54it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.23it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.70it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.03it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.70it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.62it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.88it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.62it/s]


after 6 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.07it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.53it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.38it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.84it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.24it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.42it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.77it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.15it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.54it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.48it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.63it/s]


after 7 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.09it/s]
100%|██████████| 1079/1079 [00:08<00:00, 119.95it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.21it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.65it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.60it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.39it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.59it/s]
100%|██████████| 1079/1079 [00:08<00:00, 119.91it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.57it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.13it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.94it/s]


after 8 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 18.03it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.42it/s]
100%|██████████| 1079/1079 [00:08<00:00, 121.49it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.71it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.58it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.02it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.79it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.85it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.05it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.20it/s]
100%|██████████| 1079/1079 [00:09<00:00, 115.84it/s]


after 9 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 


100%|██████████| 1079/1079 [00:59<00:00, 17.99it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.39it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.05it/s]
100%|██████████| 1079/1079 [00:09<00:00, 117.28it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.54it/s]
100%|██████████| 1079/1079 [00:09<00:00, 118.49it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.56it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.24it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.27it/s]
100%|██████████| 1079/1079 [00:08<00:00, 120.10it/s]
100%|██████████| 1079/1079 [00:09<00:00, 119.67it/s]

after 10 rounds of attacking
triggers: tensor([26940, 18545]) 
the accuracy :0.94810 





(tensor([26940, 18545]), tensor(0.9481, device='cuda:0'))