# 1.数据准备

In [8]:
# from datasets import load_dataset

# dataset = load_dataset("imdb")

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel
from torch.distributions import MultivariateNormal, Categorical
from torch.utils.data import Dataset, DataLoader

## Load data
df_b = pd.read_csv('/kaggle/input/intent-dataset/data/banking/train.tsv', sep='\t')
df_c = pd.read_csv('/kaggle/input/intent-dataset/data/clinc/train.tsv', sep='\t')
df_b_te = pd.read_csv('/kaggle/input/intent-dataset/data/banking/test.tsv', sep='\t')
df_b_dev = pd.read_csv('/kaggle/input/intent-dataset/data/banking/dev.tsv', sep='\t')

In [3]:
df = df_b
df_te = df_b_te
num_classes = df['label'].nunique()#*2

In [4]:
label_mapping = {v:i for i,v in enumerate(df['label'].unique())}
df['label_num'] = df['label'].map(label_mapping)
df_te['label_num'] = df_te['label'].map(label_mapping)

In [24]:
import torch.nn.functional as F

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CustomDataset(Dataset):
    def __init__(self, dataframe, num_classes=20):
        self.data = dataframe
        self.num_classes = num_classes
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        text = self.data.loc[index, 'text']
#         encoded_input = tokenizer(text, return_tensors='pt', add_special_tokens=True, max_length=128, padding='max_length')
#         label = self.data.loc[index, 'label_num']
#         one_hot_label = F.one_hot(torch.tensor(label), num_classes=self.num_classes)
#         inputs_ids = encoded_input['input_ids'].squeeze(0)
#         attention_mask = encoded_input['attention_mask'].squeeze(0)
#         return inputs_ids, attention_mask, one_hot_label
        return {'text': text}

In [25]:
dataset = CustomDataset(df, num_classes)
val_dataset = CustomDataset(df_te, num_classes)

# 2.模型定义

In [11]:
class CFG:
    batch_size = 12
    ngpu = 2

In [12]:
# 加载BERT和GPT-2
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')

# 保证GPT-2的词汇表与BERT的一致
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

class VAE(nn.Module):
    def __init__(self, bert_model, gpt2_model):
        super(VAE, self).__init__()
        self.bert_encoder = bert_model
        self.latent_dim = bert_model.config.hidden_size
        
        self.fc_mu = nn.Linear(self.latent_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.latent_dim, self.latent_dim)
        self.bn_mu = nn.BatchNorm1d(self.latent_dim)  # 添加批量归一化层
#         self.bn_var = nn.BatchNorm1d(self.latent_dim)  # 添加批量归一化层
        
        self.gpt2_decoder = gpt2_model
        self.gpt2_config = gpt2_model.config
        
    def encode(self, input_ids, attention_mask=None):
        outputs = self.bert_encoder(input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state[:, 0, :]  # 使用[CLS] token表示
        mu = self.bn_mu(self.fc_mu(hidden_state))
        logvar = self.fc_logvar(hidden_state)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def kl_divergence(self, z, z_l):
        q_z = MultivariateNormal(z, torch.eye(z.shape[1]).cuda())
        p_z_components = MultivariateNormal(z_l, torch.eye(z_l.shape[1]).cuda())
        log_q_z = q_z.log_prob(z)
        ## 当L=1
        log_p_z = p_z_components.log_prob(z_l)
        return torch.mean(log_q_z - log_p_z)
    
    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask=None):
        mu, logvar = self.encode(input_ids, attention_mask)
        z = self.reparameterize(mu, logvar)
        z_l = self.reparameterize(mu, logvar)
        kl_div = self.kl_divergence(z, z_l)
        # 将latent vector重复多次以匹配decoder_input_ids的序列长度
        latent_hidden = z.unsqueeze(1).repeat(1, decoder_input_ids.size(1), 1)
        # 将labels参数设置为与input_ids相同，因为再语言建模任务中，模型的目标是预测下一个单词，因此目标标签通常是输入数据的右移版本。
        gpt2_outputs = self.gpt2_decoder(inputs_embeds=latent_hidden, attention_mask=decoder_attention_mask, labels=decoder_input_ids)
        return gpt2_outputs.logits, kl_div

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化模型
model = VAE(bert_model, gpt2_model).to(device)

# model = nn.DataParallel(model, device_ids=list(range(CFG.ngpu)))

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

# 3.损失函数和优化器

In [13]:
from torch.optim import Adam

# def vae_loss_function(recon_logits, target_ids, mu, logvar):
#     recon_loss = nn.CrossEntropyLoss()(recon_logits.view(-1, recon_logits.size(-1)), target_ids.view(-1))
#     kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     return recon_loss + kld_loss

def vae_loss_function(recon_logits, target_ids, kl_div):
    recon_loss = nn.CrossEntropyLoss()(recon_logits.view(-1, recon_logits.size(-1)), target_ids.view(-1))
    return recon_loss + kl_div


optimizer = Adam(model.parameters(), lr=1e-4)

# 4.训练模型

In [26]:
from torch.utils.data import DataLoader

# 数据加载器
# train_loader = DataLoader(dataset['train'], batch_size=CFG.batch_size, shuffle=True)
train_loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)
# 训练循环
num_epochs = 10
model.train()

epoch_steps = len(train_loader)
print(f'trian_loader length: {epoch_steps}')

best_loss = 1_000
for epoch in range(num_epochs):
    losses = []
    for i, batch in enumerate(train_loader):
        inputs = bert_tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True, max_length=512)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        decoder_inputs = gpt2_tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True, max_length=512)
        decoder_input_ids = decoder_inputs['input_ids']
        decoder_attention_mask = decoder_inputs['attention_mask']

        input_ids, attention_mask, decoder_input_ids, decoder_attention_mask = input_ids.to(device), attention_mask.to(device), decoder_input_ids.to(device), decoder_attention_mask.to(device) 

        optimizer.zero_grad()
        recon_logits, kl_div = model(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
        loss = vae_loss_function(recon_logits, decoder_input_ids, kl_div)

        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i!=0 and i%50 == 0:
            print(f"Epoch {epoch}: [{i}/{epoch_steps}], Loss: {np.mean(losses):.2f}")
            if best_loss > np.mean(losses):
                model_path_tr = f'model_best_tr.pth'
                torch.save(model.state_dict(), model_path_tr)
                best_loss = np.mean(losses)
        
    model_path = f'model_epoch{epoch}.pth'
    torch.save(model.state_dict(), model_path)

trian_loader length: 2251
Epoch 0: [0/2251], Loss: 7.99
Epoch 0: [50/2251], Loss: 5.15
Epoch 0: [100/2251], Loss: 4.87
Epoch 0: [150/2251], Loss: 4.70
Epoch 0: [200/2251], Loss: 4.62
Epoch 0: [250/2251], Loss: 4.55
Epoch 0: [300/2251], Loss: 4.51
Epoch 0: [350/2251], Loss: 4.47
Epoch 0: [400/2251], Loss: 4.43
Epoch 0: [450/2251], Loss: 4.42
Epoch 0: [500/2251], Loss: 4.40
Epoch 0: [550/2251], Loss: 4.34
Epoch 0: [600/2251], Loss: 4.31
Epoch 0: [650/2251], Loss: 4.28
Epoch 0: [700/2251], Loss: 4.23
Epoch 0: [750/2251], Loss: 4.19
Epoch 0: [800/2251], Loss: 4.17
Epoch 0: [850/2251], Loss: 4.12
Epoch 0: [900/2251], Loss: 4.10
Epoch 0: [950/2251], Loss: 4.08
Epoch 0: [1000/2251], Loss: 4.06
Epoch 0: [1050/2251], Loss: 4.04
Epoch 0: [1100/2251], Loss: 4.01
Epoch 0: [1150/2251], Loss: 4.00
Epoch 0: [1200/2251], Loss: 3.98
Epoch 0: [1250/2251], Loss: 3.97
Epoch 0: [1300/2251], Loss: 3.97
Epoch 0: [1350/2251], Loss: 3.98
Epoch 0: [1400/2251], Loss: 3.97
Epoch 0: [1450/2251], Loss: 3.96
Epoch 0

# 5.推理

In [None]:
# underlying_model = model.module

# underlying_model.eval()

# with torch.no_grad():
#     sample_text = "This is a test sentence."
#     inputs = bert_tokenizer(sample_text, return_tensors='pt')
#     input_ids = inputs['input_ids']
#     attention_mask = inputs['attention_mask']

#     mu, logvar = underlying_model.encode(input_ids, attention_mask)
#     z = underlying_model.reparameterize(mu, logvar)

#     # 使用起始token生成新文本
#     decoder_input_ids = gpt2_tokenizer("<|endoftext|>", return_tensors='pt').input_ids

#     # 将latent vector重复多次以匹配decoder_input_ids的序列长度
#     latent_hidden = z.unsqueeze(1).repeat(1, decoder_input_ids.size(1), 1)

#     outputs = underlying_model.gpt2_decoder(inputs_embeds=latent_hidden, labels=decoder_input_ids)
#     predicted_ids = torch.argmax(outputs.logits, dim=-1)

#     generated_text = gpt2_tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
#     print(f"Generated Text: {generated_text}")